# Support module for running synchronous functions as
# coroutines in worker threads and for submitting asynchronous
# work to the event loop from a synchronous thread
|
||
|
import threading
|
||
|
from typing import Callable, Any, Coroutine
|
||
|
from collections import deque
|
||
|
from structio.sync import Event, Semaphore
|
||
|
from structio.util.ki import enable_ki_protection
|
||
|
|
||
|
|
||
|
# Module state is kept as attributes of this object rather than in
# module-level globals, so functions can rebind it without "global".
# NOTE(review): attributes of a threading.local() are per-thread, so
# max_workers as assigned below is only visible from the thread that
# imports this module -- confirm no other thread is expected to read it
_storage = threading.local()
# Max number of concurrent threads that can
# be spawned by run_in_worker before blocking
_storage.max_workers = Semaphore(50)
|
||
|
|
||
|
|
||
|
class AsyncThreadEvent(Event):
    """
    An extension of the regular event
    class that is safe to utilize both
    from threads and from async code
    """

    def __init__(self):
        super().__init__()
        # Serializes the is_set() checks, set() and the
        # bookkeeping of the threads waiting on the event
        self._lock = threading.Lock()
        # One threading.Event per thread blocked in wait_sync()
        self._workers: deque[threading.Event] = deque()

    @enable_ki_protection
    def wait_sync(self):
        """
        Like wait(), but synchronous: blocks the
        calling thread until the event is set.
        Returns immediately if it is already set
        """

        with self._lock:
            if self.is_set():
                return
            ev = threading.Event()
            self._workers.append(ev)
        # Block only after the lock has been released:
        # set() must acquire it in order to wake us up
        ev.wait()

    @enable_ki_protection
    async def wait(self):
        """
        Waits until the event is set. Returns
        immediately if it is already set
        """

        with self._lock:
            if self.is_set():
                return
        # The lock is released before suspending, so that
        # set() is free to acquire it while we are waiting
        await super().wait()

    @enable_ki_protection
    def set(self):
        """
        Sets the event, waking up both the coroutines
        waiting in wait() and the threads blocked in
        wait_sync(). Does nothing if the event is
        already set
        """

        with self._lock:
            if self.is_set():
                return
            # Awakes all coroutines
            super().set()
            # Awakes all threads, removing each entry as we go
            # so that woken events are not kept alive needlessly
            while self._workers:
                self._workers.popleft().set()
|
||
|
|
||
|
|
||
|
async def run_in_worker(
    func: Callable[..., Any],
    *args,
    **kwargs,
):
    """
    Meant to run the given synchronous function
    in a worker thread.

    NOTE: this is still a stub. It only waits for a free
    worker slot and gives it back right away: func is
    never called, args and kwargs are ignored and None
    is always returned

    :param func: The synchronous callable that is
        meant to be run in a worker thread
    :param args: Positional arguments, presumably to be
        forwarded to func once this is implemented
    :param kwargs: Keyword arguments, presumably to be
        forwarded to func once this is implemented
    """

    async with _storage.max_workers:
        # This will automatically block once
        # we run out of slots and proceed once
        # we have more
        pass  # TODO
|
||
|
|
||
|
|
||
|
def set_max_worker_count(count: int):
    """
    Replaces the limit on how many worker threads
    structio may have running at the same time

    :param count: The new maximum number of
        concurrent worker threads
    """

    # The limit lives as an attribute of _storage, so
    # no "global" statement is needed to rebind it. A
    # brand new semaphore is installed on every call
    limiter = Semaphore(count)
    _storage.max_workers = limiter
|
||
|
|
||
|
|
||
|
def get_max_worker_count() -> int:
    """
    Returns how many worker threads structio
    may have running at the same time
    """

    # The limit is the capacity of the semaphore
    # that is currently installed in _storage
    limiter = _storage.max_workers
    return limiter.max_size
|
||
|
|
||
|
|
||
|
|