385 lines
13 KiB
Python
385 lines
13 KiB
Python
# Support module for running synchronous functions as
# coroutines in worker threads and for submitting asynchronous
# work to the event loop from a synchronous thread
|
|
from functools import partial
|
|
|
|
import structio
|
|
import threading
|
|
from collections import deque
|
|
from structio.abc import Kernel
|
|
from structio.core.run import current_loop
|
|
from typing import Callable, Any, Coroutine
|
|
from structio.core.syscalls import checkpoint
|
|
from structio.sync import Event, Semaphore, Queue
|
|
from structio.util.ki import enable_ki_protection
|
|
from structio.exceptions import StructIOException
|
|
from itertools import count as _count
|
|
|
|
|
|
# Per-thread storage. Worker threads started by run_in_worker() stash
# their parent loop, request/result queues and wakeup handles in here
# (see _threaded_runner), which is how run_coro() finds them later
_storage = threading.local()
# Max number of concurrent threads that can
# be spawned by run_in_worker before blocking
# NOTE(review): _storage is a threading.local(), so this attribute only
# exists in the thread that executes this assignment (i.e. the one that
# imports the module) -- confirm that this is the intended scope
_storage.max_workers = Semaphore(50)
# Ever-increasing counter used to give each worker thread a unique name
_worker_id = _count()
|
|
|
|
|
|
def is_async_thread() -> bool:
    """
    Returns whether the calling thread has a parent event
    loop registered in its thread-local storage
    """

    try:
        _storage.parent_loop
    except AttributeError:
        return False
    return True
|
|
|
|
|
|
class AsyncThreadEvent(Event):
    """
    A flavor of the regular event class that
    can be waited on and set both from worker
    threads and from asynchronous code
    """

    def __init__(self):
        super().__init__()
        # Serializes the "is it set?" checks with set()
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Blocking counterpart of wait(): parks the calling
        thread until the event is set. Returns immediately
        if the event has been set already
        """

        with self._lock:
            if self.is_set():
                return
            waiter = threading.Event()
            self._workers.append(waiter)
        # Block outside of the lock, or set() could never acquire it
        waiter.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Suspends the calling task until the event is set.
        Returns immediately if the event has been set already
        """

        with self._lock:
            already_set = self.is_set()
        if not already_set:
            await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up every task and every thread
        waiting on it. Does nothing if the event is already set
        """

        with self._lock:
            if not self.is_set():
                # Calling super().set() is not an option, since it
                # goes through current_loop() and the calling thread
                # may be an async thread with no loop of its own. We
                # use the loop saved in thread-local storage instead
                kernel: Kernel = _storage.parent_loop
                for waiting_task in self._tasks:
                    kernel.reschedule(waiting_task)
                # Release all the threads blocked in wait_sync()
                for waiter in self._workers:
                    waiter.set()
                self._set = True
|
|
|
|
|
|
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue
    class that is safe to use both from
    threaded and asynchronous code
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        # Guards the container and the getters/putters wait lists
        self._lock = threading.Lock()

    @enable_ki_protection
    async def get(self):
        """
        Pops and returns the oldest item in the queue, suspending
        the calling task while the queue is empty. One pending
        putter, if any, is woken up along the way
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if not self.container:
                evt = AsyncThreadEvent()
                self.getters.append(evt)
            if self.putters:
                self.putters.popleft().set()
        if evt:
            await evt.wait()
        await checkpoint()
        return self.container.popleft()

    @enable_ki_protection
    async def put(self, item):
        """
        Appends the given item to the queue, suspending the
        calling task while the queue is full. One pending
        getter, if any, is woken up along the way
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if self.maxsize and self.maxsize == len(self.container):
                evt = AsyncThreadEvent()
                self.putters.append(evt)
            else:
                # The item must be stored *before* a getter is woken
                # up and while we still hold the lock: a getter that
                # lives in another thread may resume right away, and
                # would otherwise pop from a still-empty container
                self.container.append(item)
            if self.getters:
                self.getters.popleft().set()
        if evt:
            # The queue was full: wait for a getter to make room
            await evt.wait()
            self.container.append(item)
        await checkpoint()

    @enable_ki_protection
    def get_noblock(self) -> Any:
        """
        Like get(), but never blocks: simply defers to the
        parent class' implementation
        """

        return super().get_noblock()

    @enable_ki_protection
    def put_noblock(self, item: Any):
        """
        Like put(), but never blocks: simply defers to the
        parent class' implementation
        """

        return super().put_noblock(item)

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if self.maxsize and self.maxsize == len(self.container):
                evt = AsyncThreadEvent()
                self.putters.append(evt)
            else:
                # Same reasoning as in put(): store the item before
                # any getter gets a chance to run
                self.container.append(item)
            if self.getters:
                self.getters.popleft().set()
        if evt:
            # The queue was full: block until a getter makes room
            evt.wait_sync()
            self.container.append(item)

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if not self.container:
                evt = AsyncThreadEvent()
                self.getters.append(evt)
            if self.putters:
                self.putters.popleft().set()
        if evt:
            evt.wait_sync()
        return self.container.popleft()
|
|
|
|
|
|
# Just a bunch of private helpers to run sync/async functions
|
|
|
|
|
|
def _threaded_runner(
    f,
    parent_loop: Kernel,
    rq: AsyncThreadQueue,
    rsq: AsyncThreadQueue,
    evt: AsyncThreadEvent,
    coro_runner: "structio.util.wakeup_fd.WakeupFd",
    supervisor: "structio.util.wakeup_fd.WakeupFd",
    *args,
):
    """
    This is the actual function where our worker thread "lives".
    It calls f(*args) and reports the outcome on the results queue
    as a (success, payload) tuple, where payload is either the
    return value or the exception that was raised

    :param f: The synchronous function to run
    :param parent_loop: The event loop that spawned this thread
    :param rq: Request queue, where run_coro() submits coroutines
    :param rsq: Results queue, carrying both the coroutines' outcomes
        and the final outcome of f
    :param evt: Event that is set once the thread is done
    :param coro_runner: Wakeup handle that run_coro() pokes after
        submitting a coroutine
    :param supervisor: Wakeup handle poked once the thread is done
    :param args: Positional arguments for f
    """

    try:
        # Setup thread-local storage so future calls
        # to run_coro() can find this stuff
        _storage.parent_loop = parent_loop
        _storage.rq = rq
        _storage.rsq = rsq
        _storage.coro_runner = coro_runner
        _storage.supervisor = supervisor
        result = f(*args)
    except BaseException as e:
        # Everything is forwarded to the caller of run_in_worker,
        # which re-raises it in the async context
        rsq.put_sync((False, e))
    else:
        rsq.put_sync((True, result))
    finally:
        # Notify run_in_worker that the thread has exited. This has to
        # happen *before* the loop is poked: if the poke came first, the
        # waiting task could wake up, find the event still unset and go
        # back to sleep, at which point set() would merely reschedule it
        # from this thread, which is not guaranteed to interrupt a loop
        # that is blocked waiting for I/O
        evt.set()
        # Wakeup the event loop. We use the argument directly rather
        # than going through thread-local storage
        supervisor.wakeup()
|
|
|
|
|
|
@enable_ki_protection
async def _coroutine_request_handler(
    coroutines: AsyncThreadQueue,
    results: AsyncThreadQueue,
    reader: "structio.socket.AsyncSocket",
):
    """
    Background task that serves a worker thread spawned by structio:
    it awaits the coroutines the thread submits and sends their
    outcome back as a (success, payload) tuple. It loops forever,
    so it is up to whoever spawned it to cancel it

    :param coroutines: The queue the thread submits coroutines to
    :param results: The queue the outcomes are written to
    :param reader: Read end of the wakeup channel: one byte is
        received before each coroutine is fetched
    """

    while True:
        # Sleep until the thread pokes us
        await reader.receive(1)
        job = await coroutines.get()
        try:
            outcome = await job
        except BaseException as exc:
            # Failures travel back to the thread too
            await results.put((False, exc))
            continue
        await results.put((True, outcome))
|
|
|
|
|
|
@enable_ki_protection
async def run_in_worker(
    sync_func,
    *args,
    cancellable: bool = False,
):
    """
    Call the given synchronous function in a separate
    worker thread, turning it into an async operation.
    Must be called from an asynchronous context (a
    StructIOException is raised otherwise). The result
    of the call is returned, and any exceptions that occur
    are propagated back to the caller. This is semantically
    identical to just calling the function itself from within
    the async context, but it has the added benefit of 1) Being
    partially cancellable (with a catch, see below) and 2) If
    the function performs some long-running blocking operation,
    calling it in the main thread is not advisable, as it would
    cause structio's event loop to grind to a halt, meaning that
    timeouts and cancellations don't work, I/O doesn't get scheduled,
    and all sorts of nasty things happen (or rather, don't happen,
    since no work is getting done). In short, don't do long-running
    sync calls in the main thread, use a worker. Also, don't do any
    CPU-bound work in it, or you're likely to negatively affect the main
    thread anyway because CPython is weird and likes to starve-out I/O
    bound threads if there's some CPU-bound workers running (for that kind
    of work, you might want to spawn an entire separate process instead).
    Now, onto cancellations: If cancellable equals False, then the operation
    cannot be canceled in any way (this is the default option). This means
    that even if you set a task scope with a timeout or explicitly cancel
    the pool where this function is awaited, its effects won't be visible
    until after the thread has exited. If cancellable equals True, cancellation
    will cause this function to return early and to abruptly drop the thread:
    keep in mind that it is likely to keep running in the background, as
    structio doesn't make any effort to stop it (it can't). If you call this
    with cancellable=True, make sure the operation you're performing is side-effect-free,
    or you might get nasty deadlocks or race conditions happening.

    Note: If the number of current active thread workers is equal to the value of get_max_worker_count(),
    this function blocks until a slot is available and then proceeds normally.

    :param sync_func: The synchronous function to call
    :param args: Positional arguments for sync_func
    :param cancellable: Whether cancellation may abandon the thread
    :raises StructIOException: If called outside of an async context
    """

    try:
        # Always refresh the saved loop instead of only setting it once:
        # a reference cached by a previous call would go stale if a new
        # event loop is later run in this same thread
        _storage.parent_loop = current_loop()
    except StructIOException:
        raise StructIOException("cannot be called from sync context") from None
    # This will automatically block once
    # we run out of slots and proceed once
    # we have more
    async with _storage.max_workers:
        # Thread termination event
        terminate = AsyncThreadEvent()
        # Request queue. This is where the thread
        # sends coroutines to run
        rq = AsyncThreadQueue(0)
        # Results queue. This is where we put the result
        # of the coroutines in the request queue
        rsq = AsyncThreadQueue(0)
        # This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
        # The termination event is necessary so that we can know when the thread has terminated,
        # no surprises there I'd say. The request and result queues are used to send coroutines
        # and their results back and forth when using run_coro from within the "asynchronous thread"
        async with structio.create_pool() as pool:
            # If the operation is cancellable, then we're not
            # shielded
            pool.scope.shielded = not cancellable
            worker_id = next(_worker_id)
            # NOTE(review): neither wakeup handle is explicitly closed in
            # here; if WakeupFd owns OS-level resources, confirm that they
            # get released some other way
            wakeup = structio.util.wakeup_fd.WakeupFd()
            wakeup2 = structio.util.wakeup_fd.WakeupFd()
            # Spawn a coroutine to process incoming requests from
            # the new async thread. We can't await it because it
            # needs to run in the background
            handler = pool.spawn(_coroutine_request_handler, rq, rsq, wakeup.reader)
            # Start the worker thread
            threading.Thread(
                target=_threaded_runner,
                args=(
                    sync_func,
                    current_loop(),
                    rq,
                    rsq,
                    terminate,
                    wakeup,
                    wakeup2,
                    *args,
                ),
                name=f"structio-worker-thread-{worker_id}",
                # We start cancellable threads in daemonic mode so that
                # the main thread doesn't get stuck waiting on them forever
                # when their associated async counterpart gets cancelled. This
                # is due to the fact that there's really no way to "kill" a thread
                # (and for good reason!), so we just pretend nothing happened and go
                # about our merry way, hoping the thread dies eventually I guess
                daemon=cancellable,
            ).start()
            # Ensure we get poked by the worker thread
            await wakeup2.reader.receive(1)
            # Wait for the thread to terminate
            await terminate.wait()
            # Worker thread has exited: we no longer need to process
            # any requests, so we shut our request handler down
            handler.cancel()
            # Fetch the final result from the thread. We use get_noblock()
            # because we know the result should already be there, so the operation
            # should not block (and if this raises WouldBlock, then it's a bug)
            success, data = rsq.get_noblock()
            if success:
                return data
            raise data
|
|
|
|
|
|
@enable_ki_protection
def run_coro(
    async_func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
):
    """
    Submits a coroutine for execution to the event loop from another thread,
    passing any arguments along the way. Return values and exceptions are
    propagated, and from the point of view of the calling thread this call
    blocks until the coroutine returns. The thread must be async flavored,
    meaning it must be able to communicate back and forth with the event
    loop running in the main thread (in practice, this means only threads
    spawned with run_in_worker are able to call this)

    :param async_func: The asynchronous function to call
    :param args: Positional arguments for async_func
    :param kwargs: Keyword arguments for async_func
    :raises StructIOException: If called from an async context, from a
        thread that is not async flavored, or if the parent loop is done
    """

    try:
        current_loop()
    except StructIOException:
        # No loop in this thread: that's exactly what we want
        pass
    else:
        raise StructIOException("cannot be called from async context")
    if not is_async_thread() or _storage.parent_loop.done():
        raise StructIOException("run_coro requires a running loop in another thread!")
    # Submit the coroutine first and only then wake up the event loop (which
    # may be blocked in a call to select() or similar I/O routine). Doing it
    # the other way around lets the request handler wake up, find the queue
    # still empty and go back to sleep, after which the only thing that could
    # resume it is a reschedule issued from this thread, which is not
    # guaranteed to interrupt a loop that is blocked waiting for I/O
    _storage.rq.put_sync(async_func(*args, **kwargs))
    _storage.coro_runner.wakeup()
    # Block until the request handler sends the outcome back
    success, data = _storage.rsq.get_sync()
    if success:
        return data
    raise data
|
|
|
|
|
|
def set_max_worker_count(count: int):
    """
    Replaces the limit on the number of concurrent worker
    threads structio is allowed to spawn with the given one

    :param count: The new maximum number of concurrent workers
    """

    # The limit is an attribute of the storage object rather
    # than a module-level variable, so that rebinding it does
    # not require a "global" declaration
    limiter = Semaphore(count)
    _storage.max_workers = limiter
|
|
|
|
|
|
def get_max_worker_count() -> int:
    """
    Returns the current limit on the number of concurrent
    worker threads structio is allowed to spawn
    """

    limiter = _storage.max_workers
    return limiter.max_size
|