structio/structio/thread.py

385 lines
13 KiB
Python

# Support module for running synchronous functions as
# coroutines in worker threads and for submitting asynchronous
# work to the event loop from a synchronous thread
from functools import partial
import structio
import threading
from collections import deque
from structio.abc import Kernel
from structio.core.run import current_loop
from typing import Callable, Any, Coroutine
from structio.core.syscalls import checkpoint
from structio.sync import Event, Semaphore, Queue
from structio.util.ki import enable_ki_protection
from structio.exceptions import StructIOException
from itertools import count as _count
# Default max number of concurrent threads that can
# be spawned by run_in_worker before blocking
_DEFAULT_MAX_WORKERS = 50


class _ThreadStorage(threading.local):
    """
    Thread-local storage for this module.

    Attributes assigned on a plain threading.local instance are only
    visible from the thread that assigned them, so a limit assigned once
    at import time would be missing in every other thread (for example
    when an event loop runs outside the importing thread). The __init__
    of a threading.local subclass runs once per thread, on that thread's
    first access, which gives every thread its own default limit.
    """

    def __init__(self):
        # Max number of concurrent threads that can be spawned by
        # run_in_worker from this thread before blocking
        self.max_workers = Semaphore(_DEFAULT_MAX_WORKERS)


_storage = _ThreadStorage()
# Source of unique ids used to name worker threads
_worker_id = _count()
def is_async_thread() -> bool:
    """
    Returns whether the calling thread is an "async flavored" worker
    thread, i.e. one spawned by run_in_worker() that can use run_coro()
    to submit work to its parent event loop.
    """
    # Only _threaded_runner() stores coro_runner. parent_loop is also
    # stored in the event loop's own thread by run_in_worker(), so checking
    # it would wrongly report that thread as an async worker.
    return hasattr(_storage, "coro_runner")
class AsyncThreadEvent(Event):
    """
    An extension of the regular event
    class that is safe to utilize both
    from threads and from async code.

    Async waiters go through the regular Event machinery, while
    synchronous waiters (see wait_sync()) block on a threading.Event
    of their own. set() wakes up both kinds.
    """

    def __init__(self):
        super().__init__()
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()
        # The event loop async waiters are suspended on. It is recorded by
        # wait(), so set() can reschedule those waiters even when it is
        # called from a thread that has no loop (or thread-local data) of
        # its own.
        self._waiter_loop: Kernel | None = None

    @enable_ki_protection
    def wait_sync(self):
        """
        Like wait(), but synchronous: blocks the
        calling thread until the event is set
        """
        with self._lock:
            if self.is_set():
                return
            ev = threading.Event()
            self._workers.append(ev)
        # Block outside the lock, so that set() can acquire it and wake us
        ev.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Waits until the event is set, returning
        immediately if it already is
        """
        with self._lock:
            if self.is_set():
                return
            # Remember which loop our task lives on, for set()
            self._waiter_loop = current_loop()
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up every async task and
        every thread that is waiting on it
        """
        with self._lock:
            if self.is_set():
                return
            # We can't just call super().set() because that
            # will call current_loop(), and we may have been
            # called from an async thread that doesn't have a
            # loop. A loop is only needed if there are async
            # waiters to reschedule.
            if self._tasks:
                loop = self._waiter_loop
                if loop is None:
                    loop: Kernel = _storage.parent_loop
                for task in self._tasks:
                    loop.reschedule(task)
            # Awakes all threads
            for evt in self._workers:
                evt.set()
            self._set = True
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue
    class that is safe to use both from
    threaded and asynchronous code.

    Every access to the container and to the getters/putters wait
    lists happens while holding a threading lock. An item is added to
    (or removed from) the container *before* a waiting counterpart is
    woken up, and woken waiters re-check the queue before proceeding.
    A woken getter can therefore never pop from an empty container
    just because it ran before the putter that woke it had appended
    its item. A falsy max size (e.g. 0) makes the queue unbounded.
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        self._lock = threading.Lock()

    # The _locked_* helpers below must only be called while holding self._lock

    def _locked_is_full(self) -> bool:
        # Only a truthy maxsize bounds the queue
        return bool(self.maxsize) and len(self.container) >= self.maxsize

    def _locked_push(self, item: Any):
        self.container.append(item)
        # The item is already in place: let a blocked getter take it
        if self.getters:
            self.getters.popleft().set()

    def _locked_pop(self) -> Any:
        item = self.container.popleft()
        # A slot has just been freed: let a blocked putter use it
        if self.putters:
            self.putters.popleft().set()
        return item

    def _locked_add_getter(self) -> "AsyncThreadEvent":
        evt = AsyncThreadEvent()
        self.getters.append(evt)
        # Defensively wake a blocked putter (if any) before going to
        # sleep, so that both sides can never end up waiting on each other
        if self.putters:
            self.putters.popleft().set()
        return evt

    def _locked_add_putter(self) -> "AsyncThreadEvent":
        evt = AsyncThreadEvent()
        self.putters.append(evt)
        # Same as in _locked_add_getter(), the other way around
        if self.getters:
            self.getters.popleft().set()
        return evt

    @enable_ki_protection
    async def get(self):
        """
        Removes and returns the next item,
        waiting if the queue is empty
        """
        while True:
            evt: AsyncThreadEvent | None = None
            with self._lock:
                if not self.container:
                    evt = self._locked_add_getter()
            if evt is not None:
                await evt.wait()
            # Checkpoint *before* taking the item: if the checkpoint
            # raises (presumably on cancellation), the item stays queued
            await checkpoint()
            with self._lock:
                if self.container:
                    return self._locked_pop()
            # Somebody else took the item first: wait again

    @enable_ki_protection
    async def put(self, item):
        """
        Adds an item to the queue, waiting
        if the queue is full
        """
        while True:
            with self._lock:
                if not self._locked_is_full():
                    self._locked_push(item)
                    break
                evt = self._locked_add_putter()
            await evt.wait()
        await checkpoint()

    @enable_ki_protection
    def get_noblock(self) -> Any:
        # Same as the regular queue, but serialized with threaded callers
        with self._lock:
            return super().get_noblock()

    @enable_ki_protection
    def put_noblock(self, item: Any):
        # Same as the regular queue, but serialized with threaded callers
        with self._lock:
            return super().put_noblock(item)

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """
        while True:
            with self._lock:
                if not self._locked_is_full():
                    self._locked_push(item)
                    return
                evt = self._locked_add_putter()
            evt.wait_sync()

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """
        while True:
            with self._lock:
                if self.container:
                    return self._locked_pop()
                evt = self._locked_add_getter()
            evt.wait_sync()
# Just a bunch of private helpers to run sync/async functions
def _threaded_runner(
    f,
    parent_loop: Kernel,
    rq: AsyncThreadQueue,
    rsq: AsyncThreadQueue,
    evt: AsyncThreadEvent,
    coro_runner: "structio.util.wakeup_fd.WakeupFd",
    supervisor: "structio.util.wakeup_fd.WakeupFd",
    *args,
):
    """
    Entry point of every worker thread spawned by run_in_worker().

    Calls f(*args) and sends its outcome to rsq as a
    (success, value_or_exception) pair. Afterwards, no matter
    what, the supervisor wakeup fd is poked and evt is set, so
    that the awaiting run_in_worker() knows the thread is done.
    """
    try:
        # Publish the communication channels in thread-local
        # storage: that's where run_coro() looks for them
        for attr, value in (
            ("parent_loop", parent_loop),
            ("rq", rq),
            ("rsq", rsq),
            ("coro_runner", coro_runner),
            ("supervisor", supervisor),
        ):
            setattr(_storage, attr, value)
        outcome = f(*args)
    except BaseException as error:
        rsq.put_sync((False, error))
    else:
        rsq.put_sync((True, outcome))
    finally:
        # Poke the event loop first, then signal termination
        supervisor.wakeup()
        evt.set()
@enable_ki_protection
async def _coroutine_request_handler(
    coroutines: AsyncThreadQueue,
    results: AsyncThreadQueue,
    reader: "structio.socket.AsyncSocket",
):
    """
    Background task serving run_coro() requests from a worker
    thread: each submitted coroutine is awaited here, inside the
    event loop, and its outcome is sent back to the thread as a
    (success, value_or_exception) pair
    """
    while True:
        # The worker pokes its wakeup fd before queuing a request,
        # so wait for that poke first
        await reader.receive(1)
        request = await coroutines.get()
        try:
            value = await request
        except BaseException as error:
            # Ship the failure to the thread, which re-raises it
            await results.put((False, error))
            continue
        await results.put((True, value))
@enable_ki_protection
async def run_in_worker(
    sync_func,
    *args,
    cancellable: bool = False,
):
    """
    Call the given synchronous function in a separate
    worker thread, turning it into an async operation.

    Must be called from an asynchronous context (a StructIOException
    is raised otherwise). The result of the call is returned, and any
    exception that occurs is propagated back to the caller. Positional
    arguments are forwarded to sync_func (use functools.partial to pass
    keyword arguments).

    This is semantically identical to just calling the function from
    within the async context, but it has two added benefits:
    1) it is partially cancellable (with a catch, see below);
    2) if the function performs a long-running blocking operation,
       calling it in the main thread would cause structio's event loop
       to grind to a halt. Timeouts and cancellations stop working, I/O
       doesn't get scheduled, and all sorts of nasty things happen (or
       rather, don't happen, since no work is getting done).

    In short, don't make long-running sync calls in the main thread: use
    a worker. Also, don't do CPU-bound work in a worker either, or you're
    likely to negatively affect the main thread anyway. CPython tends to
    starve I/O-bound threads when CPU-bound ones are running; for that
    kind of work, you might want to spawn an entire separate process.

    Now, onto cancellation. If cancellable is False (the default), the
    operation cannot be cancelled in any way. Even if you set a task
    scope with a timeout or explicitly cancel the pool where this
    function is awaited, its effects won't be visible until after the
    thread has exited. If cancellable is True, cancellation makes this
    function exit early and abruptly drop the thread. Keep in mind that
    the thread is likely to keep running in the background, as structio
    doesn't make any effort to stop it (it can't). If you use
    cancellable=True, make sure the operation you're performing is
    side-effect-free, or you might get nasty deadlocks or race conditions.

    Note: if the number of currently active worker threads is equal to the
    value of get_max_worker_count(), this function blocks until a slot is
    available and then proceeds normally.

    :raises StructIOException: if called from a synchronous context
    """
    try:
        loop = current_loop()
    except StructIOException:
        raise StructIOException("cannot be called from sync context")
    # AsyncThreadEvent.set() may find the loop to reschedule tasks on
    # through _storage.parent_loop. Worker threads spawned by us (the only
    # ones with a coro_runner) must keep pointing at their parent loop. Any
    # other thread gets it refreshed on every call, so that it never refers
    # to an event loop from a previous run that has since exited.
    if not hasattr(_storage, "coro_runner"):
        _storage.parent_loop = loop
    # This will automatically block once
    # we run out of slots and proceed once
    # we have more
    async with _storage.max_workers:
        # Thread termination event
        terminate = AsyncThreadEvent()
        # Request queue. This is where the thread
        # sends coroutines to run
        rq = AsyncThreadQueue(0)
        # Results queue. This is where we put the results of the coroutines
        # in the request queue, and where the thread puts its final result
        rsq = AsyncThreadQueue(0)
        # This looks like a lot of bookkeeping to do synchronization, but it
        # all has a purpose. The termination event tells us when the thread
        # has terminated. The request and result queues are used to send
        # coroutines and their results back and forth when run_coro() is
        # used from within the "asynchronous thread".
        async with structio.create_pool() as pool:
            # If the operation is cancellable, then we're not
            # shielded
            pool.scope.shielded = not cancellable
            worker_id = next(_worker_id)
            # Poked by the worker before each run_coro() request
            wakeup = structio.util.wakeup_fd.WakeupFd()
            # Poked by the worker when it finishes
            wakeup2 = structio.util.wakeup_fd.WakeupFd()
            # NOTE(review): neither WakeupFd is explicitly closed here. If
            # they own OS-level resources (e.g. sockets), they may need to
            # be released once the worker exits. Confirm against WakeupFd.
            # Spawn a coroutine to process incoming requests from
            # the new async thread. We can't await it because it
            # needs to run in the background
            handler = pool.spawn(_coroutine_request_handler, rq, rsq, wakeup.reader)
            # Start the worker thread
            threading.Thread(
                target=_threaded_runner,
                args=(
                    sync_func,
                    loop,
                    rq,
                    rsq,
                    terminate,
                    wakeup,
                    wakeup2,
                    *args,
                ),
                name=f"structio-worker-thread-{worker_id}",
                # Cancellable threads are started in daemonic mode, so
                # the main thread doesn't get stuck waiting on them forever
                # when their async counterpart gets cancelled. There's no
                # way to "kill" a thread (and for good reason!), so we
                # pretend nothing happened and go about our merry way,
                # hoping the thread dies eventually.
                daemon=cancellable,
            ).start()
            # Ensure we get poked by the worker thread
            await wakeup2.reader.receive(1)
            # Wait for the thread to terminate
            await terminate.wait()
            # Worker thread has exited: we no longer need to process
            # any requests, so we shut our request handler down
            handler.cancel()
        # Fetch the final result from the thread once the pool has shut
        # down. get_noblock() is used because the result must already be
        # there (if this raises WouldBlock, then it's a bug).
        success, data = rsq.get_noblock()
    if success:
        return data
    raise data
@enable_ki_protection
def run_coro(
    async_func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
):
    """
    Submits a coroutine for execution to the event loop from another thread,
    passing any arguments along the way. Return values and exceptions are
    propagated, and from the point of view of the calling thread this call
    blocks until the coroutine returns. The thread must be async flavored,
    meaning it must be able to communicate back and forth with the event
    loop running in the main thread (in practice, this means only threads
    spawned with run_in_worker are able to call this)

    :raises StructIOException: if called from an async context, or from
        a thread that isn't async flavored or whose parent loop is done
    """
    try:
        current_loop()
    except StructIOException:
        pass
    else:
        raise StructIOException("cannot be called from async context")
    if not is_async_thread() or _storage.parent_loop.done():
        raise StructIOException("run_coro requires a running loop in another thread!")
    # Build the coroutine *before* poking the event loop. If this raises
    # (e.g. because of bad arguments), no wakeup is sent for a request that
    # will never arrive, which would leave the request handler parked in
    # get() instead of waiting on the wakeup fd.
    coro = async_func(*args, **kwargs)
    # Wake up the event loop if it's blocked in a call to select() or similar I/O routine
    _storage.coro_runner.wakeup()
    _storage.rq.put_sync(coro)
    success, data = _storage.rsq.get_sync()
    if success:
        return data
    raise data
def set_max_worker_count(count: int):
    """
    Sets a new value for the maximum number of concurrent worker
    threads structio is allowed to spawn. The limit lives in
    thread-local storage, so it applies to run_in_worker() calls
    made from the calling thread
    """
    limiter = Semaphore(count)
    # Kept in thread-local storage rather than in a module-level
    # variable, which spares us from using the "global" statement
    _storage.max_workers = limiter
def get_max_worker_count() -> int:
    """
    Returns the maximum number of concurrent worker threads
    structio is allowed to spawn, as recorded in the calling
    thread's storage
    """
    limiter = _storage.max_workers
    return limiter.max_size