structio/structio/thread.py

205 lines
5.8 KiB
Python
Raw Normal View History

# Support module for running synchronous functions in
# worker threads as if they were coroutines, and for
# submitting asynchronous work to the event loop from
# a synchronous thread
import threading
from collections import deque
import structio
from structio.sync import Event, Semaphore, Queue
from structio.util.ki import enable_ki_protection
from structio.core.syscalls import checkpoint
from structio.core.abc import BaseKernel
from structio.core.run import current_loop, current_task
class _WorkerStorage(threading.local):
    """
    Per-thread state for this module.

    Because this subclasses threading.local, __init__ runs again
    the first time each new thread touches the object, so every
    thread gets a default worker limit. Previously, only the thread
    that imported the module had one; any other thread hit an
    AttributeError on max_workers.
    """

    def __init__(self):
        super().__init__()
        # Max number of concurrent threads that can
        # be spawned by run_in_worker before blocking
        self.max_workers = Semaphore(50)


_storage = _WorkerStorage()
class AsyncThreadEvent(Event):
    """
    An extension of the regular event
    class that is safe to utilize both
    from threads and from async code
    """

    def __init__(self):
        super().__init__()
        # Serializes the "is it set?" checks against set() so
        # threads and async tasks observe a consistent state
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Like wait(), but synchronous: blocks the
        calling thread until the event is set
        """
        with self._lock:
            if self.is_set():
                return
            ev = threading.Event()
            self._workers.append(ev)
        # Block outside the lock, otherwise set() could
        # never acquire it in order to wake us up
        ev.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Asynchronously wait until the event is set.
        Returns immediately if it already is
        """
        with self._lock:
            already_set = self.is_set()
        if already_set:
            return
        # The lock must NOT be held across the await. set() is
        # usually called from a worker thread and needs the lock
        # to wake us up, so suspending while holding it would
        # deadlock both sides.
        # NOTE(review): presumably super().wait() re-checks is_set()
        # and then registers this task in self._tasks; a set() from
        # another thread landing exactly between those two steps could
        # still be missed. Closing that window fully requires
        # registering the task under self._lock — confirm.
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Set the event, waking up every waiting
        task and every thread in wait_sync()
        """
        with self._lock:
            if self.is_set():
                return
            # Flip the flag first so that a waiter checking is_set()
            # from here on returns immediately instead of registering
            self._set = True
            if self._tasks:
                # We can't just call super().set() because that
                # will call current_loop(), and we may have been
                # called from a non-async thread. The loop is only
                # looked up when there is something to reschedule,
                # since threads that never stored one would raise here
                loop: BaseKernel = _storage.parent_loop
                for task in self._tasks:
                    loop.reschedule(task)
            # Awakes all threads
            for evt in self._workers:
                evt.set()
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue
    class that is safe to use both from
    threaded and asynchronous code
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        # Protects container, getters and putters, which are
        # touched both by the event loop and by worker threads
        self._lock = threading.Lock()

    def _try_put(self, item):
        """
        Enqueue item if there is room and wake up one getter.

        Returns None on success. If the queue is full, returns a
        freshly registered putter event (an AsyncThreadEvent) to
        wait on before retrying. Must be called with self._lock held.
        """
        # A falsy maxsize means the queue is unbounded
        if self.maxsize and len(self.container) >= self.maxsize:
            evt = AsyncThreadEvent()
            self.putters.append(evt)
            return evt
        # The item goes in *before* a getter is woken up: a getter
        # running in another thread could otherwise try to pop from
        # a container that is still empty
        self.container.append(item)
        if self.getters:
            self.getters.popleft().set()
        return None

    def _try_get(self):
        """
        Dequeue an item if one is available and wake up one putter.

        Returns (True, item) on success. If the queue is empty,
        returns (False, evt), where evt is a freshly registered getter
        event to wait on before retrying. Must be called with
        self._lock held.
        """
        if self.container:
            item = self.container.popleft()
            if self.putters:
                self.putters.popleft().set()
            return True, item
        evt = AsyncThreadEvent()
        self.getters.append(evt)
        return False, evt

    @enable_ki_protection
    async def get(self):
        """
        Remove and return an item, waiting
        asynchronously while the queue is empty
        """
        # Loop because another consumer (possibly in another
        # thread) may grab the item between our wakeup and retry
        while True:
            with self._lock:
                ok, value = self._try_get()
            if ok:
                return value
            await value.wait()

    @enable_ki_protection
    async def put(self, item):
        """
        Add an item, waiting asynchronously
        while the queue is full
        """
        while True:
            with self._lock:
                evt = self._try_put(item)
            if evt is None:
                break
            await evt.wait()
        await checkpoint()

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """
        while True:
            with self._lock:
                evt = self._try_put(item)
            if evt is None:
                return
            # Previously this never waited, so a full queue
            # would silently grow past maxsize
            evt.wait_sync()

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """
        while True:
            with self._lock:
                ok, value = self._try_get()
            if ok:
                return value
            value.wait_sync()
def _threaded_runner(f, q: AsyncThreadQueue, parent_loop: BaseKernel, *args):
    """
    Thread entry point: call f(*args) and deliver the outcome
    through q as a (success, value) pair. The value is either
    the return value or the exception f raised.
    """
    try:
        # Remember which loop to reschedule async waiters on
        # when events are set from this thread
        _storage.parent_loop = parent_loop
        result = f(*args)
    except BaseException as e:
        q.put_sync((False, e))
    else:
        # Kept outside the try: a failure while delivering the
        # result must not be re-reported as a second queue item
        q.put_sync((True, result))
async def _async_runner(f, *args):
    """
    Run f(*args) in a brand new worker thread and wait
    asynchronously for its outcome, returning the result
    or re-raising the exception it produced
    """
    results = AsyncThreadQueue(1)
    worker = threading.Thread(
        target=_threaded_runner,
        args=(f, results, current_loop(), *args),
        name="structio-worker-thread",
    )
    worker.start()
    ok, payload = await results.get()
    if not ok:
        raise payload
    return payload
async def run_in_worker(sync_func,
                        *args,
                        ):
    """
    Call the given synchronous function in a separate
    worker thread, turning it into an async operation.
    The result of the call is returned, and any exceptions
    are propagated back to the caller. Note that threaded
    operations are not usually cancellable (i.e. the async
    operation will fail when cancelled, but the thread will
    continue running until termination, as there is no simple
    and reliable way to stop a thread anywhere)
    """
    # Always (re)bind the loop for this thread instead of only
    # the first time. A thread may run more than one event loop
    # over its lifetime, and AsyncThreadEvent.set() reschedules
    # waiting tasks on whatever loop is stored here.
    _storage.parent_loop = current_loop()
    # This will automatically block once
    # we run out of slots and proceed once
    # we have more
    async with _storage.max_workers:
        async with structio.create_pool() as pool:
            return await pool.spawn(_async_runner, sync_func, *args)
def set_max_worker_count(count: int):
    """
    Change how many worker threads structio may run
    concurrently, as seen from the calling thread's storage
    """
    # Kept on the module's thread-local storage
    # rather than in a module-level global
    limiter = Semaphore(count)
    _storage.max_workers = limiter
def get_max_worker_count() -> int:
    """
    Return the current limit on concurrently running
    worker threads that structio may spawn
    """
    limiter = _storage.max_workers
    return limiter.max_size