structio/structio/thread.py

289 lines
8.9 KiB
Python
Raw Normal View History

# Support module for running synchronous functions as
# coroutines in worker threads and for submitting asynchronous
# work to the event loop from a synchronous thread
import threading
from collections import deque
import structio
from structio.sync import Event, Semaphore, Queue
from structio.util.ki import enable_ki_protection
from structio.core.syscalls import checkpoint
from structio.core.abc import BaseKernel
from structio.core.run import current_loop
from typing import Callable, Any, Coroutine
from structio.core.exceptions import StructIOException, TimedOut, Cancelled
class _ThreadStorage(threading.local):
    """
    Per-thread state shared by the helpers in this module.

    Subclassing threading.local makes __init__ run once in every
    thread that touches the storage, so each thread gets its own
    worker limit. With a plain threading.local, only the thread that
    first imported this module would have one, and run_in_worker
    would fail with AttributeError in any other thread.
    """

    def __init__(self):
        super().__init__()
        # Max number of concurrent threads that can
        # be spawned by run_in_worker before blocking
        self.max_workers = Semaphore(50)


_storage = _ThreadStorage()
def is_async_thread() -> bool:
    """
    Returns whether the calling thread has been linked to an event
    loop through this module's thread-local storage.

    This is true in worker threads started by run_in_worker. It is
    also true in a loop thread that has already called run_in_worker.
    """

    try:
        _storage.parent_loop
    except AttributeError:
        return False
    else:
        return True
class AsyncThreadEvent(Event):
    """
    An extension of the regular event class that is safe
    to use both from worker threads and from async code.

    Async tasks block on wait(), plain threads block on
    wait_sync(), and set() may be called from either side.
    """

    def __init__(self):
        super().__init__()
        # Serializes set() against the "already set?" checks
        # performed by the waiters
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Like wait(), but synchronous: blocks the calling
        thread until the event is set.
        """

        with self._lock:
            if self.is_set():
                return
            ev = threading.Event()
            self._workers.append(ev)
        # Block outside the lock, otherwise set() could never
        # acquire it to wake us up
        ev.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Waits asynchronously until the event is set.
        """

        with self._lock:
            if self.is_set():
                return
        # NOTE(review): the calling task is registered as a waiter by
        # super().wait(), i.e. after the lock above has been released.
        # A set() coming from another thread in that window may miss it
        # unless Event.wait() re-checks the flag before suspending; confirm.
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up every waiting task and thread.
        Does nothing if the event is already set.
        """

        with self._lock:
            if self.is_set():
                return
            # We can't just call super().set() because that
            # will call current_loop(), and we may have been
            # called from a non-async thread. Prefer the loop
            # running in this thread, if there is one. run_in_worker
            # only records _storage.parent_loop the first time it runs
            # in a thread, so that value is stale once the thread has
            # started a new event loop.
            try:
                loop: BaseKernel = current_loop()
            except StructIOException:
                loop = _storage.parent_loop
            for task in self._tasks:
                loop.reschedule(task)
            # Awake all threads blocked in wait_sync()
            for evt in self._workers:
                evt.set()
            self._set = True
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue class that is safe to
    use both from threaded and from asynchronous code.

    get()/put() are for async tasks; get_sync()/put_sync() are
    for plain threads. A falsy max size (e.g. 0) makes the queue
    unbounded. All accesses to the container happen under a
    lock. A counterpart is only woken up after the container has
    been updated. Woken waiters re-check the queue, because another
    consumer or producer may have acted in the meantime.
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        # Guards container, getters and putters: they are touched
        # by the event loop thread and by worker threads alike
        self._lock = threading.Lock()

    def _try_put(self, item):
        """
        If there is room, appends item, wakes up one pending getter
        and returns None. If the queue is full, registers and returns
        a fresh AsyncThreadEvent that is set once a slot may have been
        freed; the caller must then wait on it and retry.
        """

        with self._lock:
            if self.maxsize and len(self.container) >= self.maxsize:
                evt = AsyncThreadEvent()
                self.putters.append(evt)
                return evt
            # The item must be in the container *before* a getter is
            # woken up, or that getter could find the queue empty
            self.container.append(item)
            if self.getters:
                self.getters.popleft().set()
            return None

    def _try_get(self):
        """
        Returns (item, None) after popping the oldest item and waking
        up one pending putter. If the queue is empty, returns
        (None, evt), where evt is a freshly registered AsyncThreadEvent
        that is set once an item may be available. Callers must test
        evt, not item, since None is a legitimate item.
        """

        with self._lock:
            if not self.container:
                evt = AsyncThreadEvent()
                self.getters.append(evt)
                return None, evt
            # Free the slot first, then let a blocked putter fill it
            item = self.container.popleft()
            if self.putters:
                self.putters.popleft().set()
            return item, None

    @enable_ki_protection
    async def get(self):
        """
        Removes and returns the oldest item, waiting
        asynchronously while the queue is empty.
        """

        while True:
            item, evt = self._try_get()
            if evt is None:
                return item
            await evt.wait()

    @enable_ki_protection
    async def put(self, item):
        """
        Appends item, waiting asynchronously while the queue is full.
        """

        while True:
            evt = self._try_put(item)
            if evt is None:
                break
            await evt.wait()
        await checkpoint()

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """

        while True:
            evt = self._try_put(item)
            if evt is None:
                return
            evt.wait_sync()

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """

        while True:
            item, evt = self._try_get()
            if evt is None:
                return item
            evt.wait_sync()
def _threaded_runner(f, q: AsyncThreadQueue, parent_loop: BaseKernel, rq: AsyncThreadQueue,
                     rsq: AsyncThreadQueue, evt: AsyncThreadEvent, *args):
    """
    Body of every worker thread started by run_in_worker.

    Links the thread to its parent loop and to the request/result
    queues used by run_coro, then runs f(*args). The outcome is
    reported through q as a (success, value_or_exception) pair.
    evt is always set on the way out, so the loop side knows the
    thread is done.
    """

    try:
        # Thread-local links read by run_coro and AsyncThreadEvent.set()
        links = (("parent_loop", parent_loop), ("rq", rq), ("rsq", rsq))
        for attr_name, attr_value in links:
            setattr(_storage, attr_name, attr_value)
        outcome = (True, f(*args))
        q.put_sync(outcome)
    except BaseException as exc:
        q.put_sync((False, exc))
    finally:
        evt.set()
@enable_ki_protection
async def _async_waiter(events, results: AsyncThreadQueue):
    """
    Serves run_coro() requests coming from a worker thread.

    Each coroutine received on events is awaited in the event loop,
    and its outcome is sent back on results as a (success, value)
    pair. A falsy item (None is used as the sentinel) stops the loop.
    """

    coro = await events.get()
    while coro:
        try:
            await results.put((True, await coro))
        except BaseException as exc:
            # NOTE(review): an exception raised by results.put() itself
            # (e.g. if its checkpoint() raises after the item was already
            # appended) is also reported here. That can queue a second
            # result for a single request.
            await results.put((False, exc))
        coro = await events.get()
@enable_ki_protection
async def _wait_for_thread(events, results: AsyncThreadQueue, evt: AsyncThreadEvent, cancellable: bool = False):
    """
    Supervises a worker thread from the event loop side.

    Serves the thread's run_coro requests (via _async_waiter) until
    evt signals that the thread has finished, then shuts the server
    down. Unless cancellable is True, the pool doing this is shielded.
    """

    async with structio.create_pool() as group:
        group.scope.shielded = not cancellable
        # Handle requests coming from the worker thread
        group.spawn(_async_waiter, events, results)
        # Block until the worker thread has exited
        await evt.wait()
        # No more requests can arrive: the None sentinel stops the waiter
        await events.put(None)
@enable_ki_protection
async def _async_runner(f, cancellable: bool = False, *args):
    """
    Runs f(*args) in a brand new worker thread. Returns its result,
    or raises its exception, once the thread reports back.
    """

    loop = current_loop()
    done = AsyncThreadEvent()
    # Carries the single (success, value) outcome of f
    outcome_queue = AsyncThreadQueue(1)
    # Coroutines submitted by the thread through run_coro...
    requests = AsyncThreadQueue(0)
    # ...and their outcomes
    replies = AsyncThreadQueue(0)
    loop.current_pool.spawn(_wait_for_thread, requests, replies, done, cancellable)
    worker = threading.Thread(
        target=_threaded_runner,
        args=(f, outcome_queue, loop, requests, replies, done, *args),
        name="structio-worker-thread",
        daemon=cancellable,
    )
    worker.start()
    success, data = await outcome_queue.get()
    # NOTE(review): by now f has returned. run_coro, the only reader
    # of this queue, can only be called from inside f, so this
    # sentinel appears to have no consumer.
    await replies.put(None)
    if not success:
        raise data
    return data
@enable_ki_protection
async def run_in_worker(sync_func,
                        *args,
                        cancellable: bool = False,
                        ):
    """
    Calls the given synchronous function in a separate worker thread,
    turning it into an async operation. The result of the call is
    returned, and any exceptions are propagated back to the caller.

    If cancellable is False (the default), the operation cannot be
    cancelled in any way. If cancellable is True, cancellation causes
    this function to return early and abruptly drops the thread. Keep
    in mind the thread is likely to keep running in the background:
    structio makes no effort to stop it (it can't). If you use
    cancellable=True, make sure the operation you're performing is
    side-effect-free. For example, the async version of getaddrinfo()
    uses run_in_worker with cancellable=True. That avoids hogging the
    event loop during domain name resolution while still being able to
    fail properly, since no one really cares if a random DNS lookup
    keeps running in the background.

    :raises StructIOException: if called outside of an async context
    """

    try:
        loop = current_loop()
    except StructIOException:
        raise StructIOException("cannot be called from sync context") from None
    # Record the loop this thread belongs to the first time through.
    # AsyncThreadEvent.set() reads it to reschedule waiting tasks.
    if not hasattr(_storage, "parent_loop"):
        _storage.parent_loop = loop
    # This will automatically block once we run out of
    # slots and proceed as soon as one frees up
    async with _storage.max_workers:
        return await loop.current_pool.spawn(_async_runner, sync_func, cancellable, *args)
@enable_ki_protection
def run_coro(async_func: Callable[..., Coroutine[Any, Any, Any]],
             *args, **kwargs):
    """
    Submits a coroutine for execution to the parent event loop of the
    calling worker thread, passing any arguments along the way. From
    the point of view of the calling thread, this call blocks until the
    coroutine returns. Its return value is returned, and any exception
    it raises is re-raised here.

    :raises StructIOException: if called from async context, or from a
        thread that was not started by run_in_worker
    """

    # The raise must live outside the try: inside it, it would be
    # swallowed by the very except clause meant for current_loop()
    try:
        current_loop()
    except StructIOException:
        pass
    else:
        raise StructIOException("cannot be called from async context")
    # rq/rsq are only set up by _threaded_runner. parent_loop alone is
    # not enough, since run_in_worker also sets it in the loop thread.
    if not hasattr(_storage, "rq"):
        raise StructIOException("run_coro requires a running loop in another thread!")
    _storage.rq.put_sync(async_func(*args, **kwargs))
    success, data = _storage.rsq.get_sync()
    if success:
        return data
    raise data
def set_max_worker_count(count: int):
    """
    Sets a new value for the maximum number of concurrent worker
    threads structio is allowed to spawn. The limit is kept in
    thread-local storage, so it applies to run_in_worker calls made
    from the calling thread.
    """

    # Thread-local storage is used instead of a module-level
    # "global"; the limiter is replaced wholesale
    limiter = Semaphore(count)
    _storage.max_workers = limiter
def get_max_worker_count() -> int:
    """
    Returns the maximum number of concurrent worker threads
    structio is allowed to spawn, as configured for the
    calling thread.
    """

    limiter = _storage.max_workers
    return limiter.max_size