2023-05-18 00:06:21 +02:00
|
|
|
# Support module for running synchronous functions in worker
# threads as if they were coroutines, and for submitting
# asynchronous work to the event loop from a synchronous thread
|
|
|
|
import threading
|
|
|
|
from collections import deque
|
2023-05-18 09:55:10 +02:00
|
|
|
|
|
|
|
import structio
|
|
|
|
from structio.sync import Event, Semaphore, Queue
|
2023-05-18 00:06:21 +02:00
|
|
|
from structio.util.ki import enable_ki_protection
|
2023-05-18 09:55:10 +02:00
|
|
|
from structio.core.syscalls import checkpoint
|
|
|
|
from structio.core.abc import BaseKernel
|
|
|
|
from structio.core.run import current_loop, current_task
|
2023-05-18 00:06:21 +02:00
|
|
|
|
|
|
|
|
|
|
|
# Thread-local namespace holding the module's state (the worker
# limit below and the parent_loop attribute set by the runners)
_storage = threading.local()
# Upper bound on how many worker threads run_in_worker
# may have in flight at once before further calls block
_storage.max_workers = Semaphore(50)
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncThreadEvent(Event):
    """
    A variant of the regular event that can be
    waited on and set both from plain threads
    and from asynchronous code
    """

    def __init__(self):
        super().__init__()
        # Serializes checks of the set flag and updates to the waiter list
        self._lock = threading.Lock()
        # One threading.Event for every thread parked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Blocking counterpart of wait(): parks the
        calling thread until the event is set, or
        returns right away if it already is
        """

        with self._lock:
            if self.is_set():
                return
            waiter = threading.Event()
            self._workers.append(waiter)
        # Block only after releasing the lock, so set() can acquire it
        waiter.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Waits until the event is set, returning
        immediately if it already is
        """

        with self._lock:
            already_set = self.is_set()
        if already_set:
            return
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up both the tasks and
        the threads waiting on it. Does nothing if the
        event is already set
        """

        with self._lock:
            if self.is_set():
                return
            # super().set() is not an option here: it would
            # call current_loop(), and this method may be
            # running in a thread that is not an async one
            parent: BaseKernel = _storage.parent_loop
            for waiting_task in self._tasks:
                parent.reschedule(waiting_task)
            # Release every thread blocked in wait_sync()
            for thread_event in self._workers:
                thread_event.set()
            self._set = True
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncThreadQueue(Queue):
    """
    An extension of the regular queue
    class that is safe to use both from
    threaded and asynchronous code.

    The container, getters, putters and maxsize
    attributes used below are presumably provided
    by the Queue base class (not visible here)
    """

    def __init__(self, max_size):
        super().__init__(max_size)
        # Guards the bookkeeping of waiting getters and putters
        self._lock = threading.Lock()

    @enable_ki_protection
    async def get(self):
        """
        Removes and returns the item at the front of
        the queue, waiting for one if the queue is empty
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if not self.container:
                # Nothing to return yet: register ourselves as a waiter
                evt = AsyncThreadEvent()
                self.getters.append(evt)
            if self.putters:
                # A slot is about to free up: wake the oldest putter
                self.putters.popleft().set()
        if evt:
            await evt.wait()
        return self.container.popleft()

    @enable_ki_protection
    async def put(self, item):
        """
        Appends the given item to the queue, waiting
        for a free slot if the queue is full
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if self.maxsize and self.maxsize == len(self.container):
                # Queue is full: register ourselves as a waiter
                evt = AsyncThreadEvent()
                self.putters.append(evt)
            if self.getters:
                # An item is about to arrive: wake the oldest getter
                self.getters.popleft().set()
        if evt:
            await evt.wait()
        self.container.append(item)
        await checkpoint()

    @enable_ki_protection
    def put_sync(self, item):
        """
        Like put(), but synchronous
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if self.maxsize and self.maxsize == len(self.container):
                # Keep a reference to the event we register, exactly
                # like put() does: without it the wait below is never
                # reached and a full queue would not block the caller
                evt = AsyncThreadEvent()
                self.putters.append(evt)
            if self.getters:
                # NOTE(review): the getter is woken before the item is
                # appended below; confirm a woken getter cannot run
                # before the append when this is called from a thread
                self.getters.popleft().set()
        if evt:
            evt.wait_sync()
        self.container.append(item)

    @enable_ki_protection
    def get_sync(self):
        """
        Like get(), but synchronous
        """

        evt: AsyncThreadEvent | None = None
        with self._lock:
            if not self.container:
                # Nothing to return yet: register ourselves as a waiter
                evt = AsyncThreadEvent()
                self.getters.append(evt)
            if self.putters:
                # A slot is about to free up: wake the oldest putter
                self.putters.popleft().set()
        if evt:
            evt.wait_sync()
        return self.container.popleft()
|
|
|
|
|
|
|
|
|
|
|
|
def _threaded_runner(f, q: AsyncThreadQueue, parent_loop: BaseKernel, *args):
    """
    Body of a worker thread: calls f(*args) and reports
    the outcome through the given queue as a two-tuple.
    On success the tuple is (True, return value), otherwise
    it is (False, exception that was raised)
    """

    try:
        # Record the loop that spawned us in this thread's storage
        _storage.parent_loop = parent_loop
        outcome = (True, f(*args))
        q.put_sync(outcome)
    except BaseException as exc:
        q.put_sync((False, exc))
|
|
|
|
|
|
|
|
|
|
|
|
async def _async_runner(f, *args):
    """
    Runs f(*args) in a freshly started thread and waits
    for its outcome without blocking the event loop.
    Returns whatever f returned, or raises what it raised
    """

    # Single-slot queue over which the thread hands back its outcome
    queue = AsyncThreadQueue(1)
    worker = threading.Thread(
        target=_threaded_runner,
        name="structio-worker-thread",
        args=(f, queue, current_loop(), *args),
    )
    worker.start()
    success, data = await queue.get()
    if not success:
        # data is the exception caught inside the worker thread
        raise data
    return data
|
|
|
|
|
|
|
|
|
|
|
|
async def run_in_worker(
    sync_func,
    *args,
):
    """
    Run the given synchronous function in a separate
    worker thread, making it usable as an async operation.
    Its return value is handed back to the caller and any
    exception it raises is propagated as well. Keep in mind
    that threaded operations are not usually cancellable
    (i.e. the async operation will fail when cancelled, but
    the thread keeps running until it terminates, since there
    is no simple and reliable way to stop a thread anywhere)
    """

    if not hasattr(_storage, "parent_loop"):
        _storage.parent_loop = current_loop()
    # Entering the semaphore blocks automatically once we
    # run out of worker slots and proceeds once one is free
    async with _storage.max_workers, structio.create_pool() as pool:
        result = await pool.spawn(_async_runner, sync_func, *args)
        return result
|
2023-05-18 00:06:21 +02:00
|
|
|
|
|
|
|
|
|
|
|
def set_max_worker_count(count: int):
    """
    Replaces the limit on how many worker threads
    structio may run concurrently with a new value
    """

    # Stored as an attribute of _storage so that
    # no "global" statement is ever needed
    limiter = Semaphore(count)
    _storage.max_workers = limiter
|
|
|
|
|
|
|
|
|
|
|
|
def get_max_worker_count() -> int:
    """
    Returns the current limit on how many worker
    threads structio may run concurrently
    """

    limiter = _storage.max_workers
    return limiter.max_size
|
|
|
|
|
|
|
|
|
|
|
|
|