structio/structio/thread.py

268 lines
7.8 KiB
Python
Raw Normal View History

# Support module for running synchronous functions as
# coroutines in worker threads, and for submitting asynchronous
# work to the event loop from a synchronous thread
import threading
from collections import deque
import structio
from structio.sync import Event, Semaphore, Queue
from structio.util.ki import enable_ki_protection
from structio.core.syscalls import checkpoint
from structio.core.abc import BaseKernel
from structio.core.run import current_loop
from typing import Callable, Any, Coroutine
from structio.core.exceptions import StructIOException
class _ThreadStorage(threading.local):
    """
    Per-thread state for this module.

    Attributes assigned on a plain threading.local() at import time
    only exist in the thread that imported the module; defaults set
    in __init__ of a threading.local subclass are instead initialized
    separately in every thread that touches the object.
    """

    def __init__(self):
        super().__init__()
        # Max number of concurrent threads that can be spawned
        # by run_in_worker (from this thread) before blocking.
        # NOTE(review): worker threads also get one (unused); this
        # assumes Semaphore() can be constructed without a running
        # loop — TODO confirm
        self.max_workers = Semaphore(50)


_storage = _ThreadStorage()
def is_async_thread() -> bool:
    """
    Tells whether a parent event loop has been recorded for the
    calling thread. Worker threads started by run_in_worker() always
    have one; the thread running the loop gets one once it has called
    run_in_worker().
    """
    try:
        _storage.parent_loop
    except AttributeError:
        return False
    return True
class AsyncThreadEvent(Event):
    """
    An extension of the regular event
    class that is safe to utilize both
    from threads and from async code.

    Threads block in wait_sync(), async tasks in wait();
    set() wakes up both kinds of waiter.
    """

    def __init__(self):
        super().__init__()
        # Makes the "already set?" check atomic with respect to
        # a concurrent set() coming from another thread
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Like wait(), but synchronous: blocks the
        calling thread until the event is set
        """

        with self._lock:
            if self.is_set():
                return
            ev = threading.Event()
            self._workers.append(ev)
        # Block *outside* the lock, otherwise set() could
        # never acquire it to wake us up
        ev.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Waits asynchronously until the event is set
        """

        with self._lock:
            if self.is_set():
                return
        # Never suspend while holding a threading lock: the suspended
        # task would keep it held and any set() call (from a thread or
        # from the loop itself) would block forever trying to acquire it.
        # NOTE(review): a set() landing between the check above and the
        # registration presumably done by Event.wait() can still be
        # missed; closing that window needs cooperation from Event.wait()
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up every async task and every
        thread waiting on it. Does nothing if already set
        """

        with self._lock:
            if self.is_set():
                return
            # We can't just call super().set() because that
            # will call current_loop(), and we may have been
            # called from a non-async thread
            if self._tasks:
                # Only look the loop up when there are async waiters to
                # reschedule, so a thread without a parent loop can still
                # wake up threads blocked in wait_sync()
                loop: BaseKernel = _storage.parent_loop
                for task in self._tasks:
                    loop.reschedule(task)
            # Wake up all waiting threads
            for evt in self._workers:
                evt.set()
            self._set = True
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue
    class that is safe to use both from
    threaded and asynchronous code.

    Every inspection or mutation of the shared state (container,
    getters, putters) happens under a threading lock, and waiters
    re-check the queue after being woken up, since another consumer
    or producer may have got there first. A falsy maxsize means the
    queue never fills up.
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        # Guards container/getters/putters, which are touched both
        # by the event loop thread and by worker threads
        self._lock = threading.Lock()

    def _try_get(self):
        """
        Non-blocking half of get(). Returns (True, item) after popping
        an item (waking one blocked putter, if any), or (False, event)
        after registering a new getter event to wait on when empty
        """

        with self._lock:
            if self.container:
                item = self.container.popleft()
                # A slot just freed up: let one blocked putter in
                if self.putters:
                    self.putters.popleft().set()
                return True, item
            evt = AsyncThreadEvent()
            self.getters.append(evt)
            return False, evt

    def _try_put(self, item):
        """
        Non-blocking half of put(). Appends item (waking one blocked
        getter, if any) and returns None, or registers and returns a
        new putter event to wait on when the queue is full
        """

        with self._lock:
            if not self.maxsize or len(self.container) < self.maxsize:
                # Append *before* waking a getter, so that the
                # woken getter actually finds the item
                self.container.append(item)
                if self.getters:
                    self.getters.popleft().set()
                return None
            evt = AsyncThreadEvent()
            self.putters.append(evt)
            return evt

    @enable_ki_protection
    async def get(self):
        """
        Pops and returns an item, waiting
        asynchronously while the queue is empty
        """

        while True:
            got_item, value = self._try_get()
            if got_item:
                return value
            # NOTE(review): if cancelled here, the registered event stays
            # in self.getters and may absorb one later wake-up
            await value.wait()

    @enable_ki_protection
    async def put(self, item):
        """
        Appends item, waiting asynchronously
        while the queue is full
        """

        while (evt := self._try_put(item)) is not None:
            # NOTE(review): same cancellation caveat as get()
            await evt.wait()
        await checkpoint()

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """

        while (evt := self._try_put(item)) is not None:
            evt.wait_sync()

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """

        while True:
            got_item, value = self._try_get()
            if got_item:
                return value
            value.wait_sync()
def _threaded_runner(f, q: AsyncThreadQueue, parent_loop: BaseKernel, rq: AsyncThreadQueue,
                     rsq: AsyncThreadQueue, *args):
    """
    Body of every worker thread spawned by _async_runner.

    Records, in this thread's local storage, the loop that spawned it
    and the request/results queues used by run_coro(). It then calls
    f(*args) and reports the outcome on q as a (success, value) pair:
    (True, return_value) on success, or (False, exception) on failure.
    """
    try:
        # State later read by run_coro() and AsyncThreadEvent.set()
        _storage.parent_loop, _storage.rq, _storage.rsq = parent_loop, rq, rsq
        outcome = (True, f(*args))
        q.put_sync(outcome)
    except BaseException as exc:
        # Any failure, including a failed delivery of the successful
        # outcome, is shipped back to the waiting task
        q.put_sync((False, exc))
@enable_ki_protection
async def _wait_for_thread(events, results: AsyncThreadQueue):
    """
    Serves run_coro() requests made by a worker thread.

    Each coroutine received on `events` is awaited, and a
    (success, value) pair is sent back on `results`. The loop stops on
    the first falsy item (_async_runner sends None once the thread is
    done). It runs inside a shielded TaskScope, presumably so that
    cancellation of the caller cannot interrupt a request half-way.
    """
    with structio.TaskScope(shielded=True):
        while coro := await events.get():
            try:
                outcome = (True, await coro)
                await results.put(outcome)
            except BaseException as exc:
                # Forward the failure to the thread blocked in run_coro()
                await results.put((False, exc))
@enable_ki_protection
async def _async_runner(f, *args):
    """
    Runs f(*args) in a new worker thread and returns its result,
    re-raising any exception it raised.

    A helper task (_wait_for_thread) is spawned in the current pool to
    serve the run_coro() calls the thread makes while it runs. The
    helper is always told to stop once this coroutine is done, even
    when it is cancelled.
    """
    loop = current_loop()
    # Carries the single (success, value) outcome of f
    queue = AsyncThreadQueue(1)
    # Request queue: coroutines submitted via run_coro()
    rq = AsyncThreadQueue(0)
    # Results queue: outcomes of those coroutines
    rsq = AsyncThreadQueue(0)
    loop.current_pool.spawn(_wait_for_thread, rq, rsq)
    th = threading.Thread(target=_threaded_runner, args=(f, queue, loop, rq, rsq, *args),
                          name="structio-worker-thread")
    try:
        th.start()
        success, data = await queue.get()
    finally:
        # Always stop the helper, even if the thread failed to start or
        # we were cancelled while waiting: it runs shielded, so it would
        # otherwise presumably keep waiting on rq forever.
        # NOTE(review): after a cancellation the thread keeps running and
        # any run_coro() it issues from then on will not be served
        await rq.put(None)
    if success:
        return data
    raise data
@enable_ki_protection
async def run_in_worker(sync_func,
                        *args,
                        ):
    """
    Call the given synchronous function in a separate
    worker thread, turning it into an async operation.
    The result of the call is returned, and any exceptions
    are propagated back to the caller. Note that threaded
    operations are not usually cancellable (i.e. the async
    operation will fail when cancelled, but the thread will
    continue running until termination, as there is no simple
    and reliable way to stop a thread anywhere).

    Raises StructIOException when called without a running loop
    in the calling thread (e.g. from a worker thread).
    """
    try:
        loop = current_loop()
    except StructIOException:
        raise StructIOException("cannot be called from sync context")
    # Refresh on every call: caching only the first loop would leave a
    # stale one behind if this thread later runs another event loop,
    # and AsyncThreadEvent.set() reschedules tasks on this loop
    _storage.parent_loop = loop
    # This will automatically block once
    # we run out of slots and proceed once
    # we have more
    async with _storage.max_workers:
        # We inject a worker task into the current
        # pool so waiting for the thread is handled
        # as if it were a task
        return await loop.current_pool.spawn(_async_runner, sync_func, *args)
@enable_ki_protection
def run_coro(async_func: Callable[..., Coroutine[Any, Any, Any]],
             *args, **kwargs):
    """
    Submits a coroutine for execution to the event loop, passing any
    arguments along the way. Return values and exceptions are
    propagated back and, from the point of view of the calling thread,
    this call blocks until the coroutine returns.

    Must be called from a worker thread started by run_in_worker():
    raises StructIOException when called from async context or from
    a thread that has no parent loop to submit to.
    """
    try:
        current_loop()
    except StructIOException:
        # No loop running in this thread: this is what we want
        pass
    else:
        # Raised outside the try so the except clause above cannot
        # swallow it
        raise StructIOException("cannot be called from async context")
    # The request/results queues are only set up by worker threads,
    # so check for them rather than for parent_loop (which the loop
    # thread itself may also have)
    if not hasattr(_storage, "rq"):
        raise StructIOException("run_coro requires a running loop in another thread!")
    _storage.rq.put_sync(async_func(*args, **kwargs))
    success, data = _storage.rsq.get_sync()
    if success:
        return data
    raise data
def set_max_worker_count(count: int):
    """
    Sets a new value for the maximum number of concurrent
    worker threads structio is allowed to spawn via
    run_in_worker() from the calling thread (the setting
    lives in thread-local storage).

    Raises ValueError if count is smaller than 1.
    """
    if count < 1:
        # A semaphore with no slots would presumably make every
        # future run_in_worker() call block forever
        raise ValueError(f"max worker count must be >= 1, got {count!r}")
    # Stored on _storage rather than in a module-level
    # variable, so no "global" statement is needed
    _storage.max_workers = Semaphore(count)
def get_max_worker_count() -> int:
    """
    Returns the maximum number of concurrent worker threads
    structio is allowed to spawn from the calling thread,
    i.e. the size of this thread's worker semaphore
    """
    semaphore = _storage.max_workers
    return semaphore.max_size