Many fixes to nested scope handling, added stress test and scope tree test
This commit is contained in:
parent
e0f2e87cad
commit
9e6ee1e104
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (structio)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (StructuredIO)" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -37,6 +37,7 @@ from structio import thread, parallel
|
|||
from structio.path import Path
|
||||
from structio.signals import set_signal_handler, get_signal_handler
|
||||
from structio import signals as _signals
|
||||
from structio import util
|
||||
|
||||
|
||||
def run(
|
||||
|
@ -168,4 +169,5 @@ __all__ = [
|
|||
"parallel",
|
||||
"get_signal_handler",
|
||||
"set_signal_handler",
|
||||
"util"
|
||||
]
|
||||
|
|
|
@ -36,7 +36,7 @@ class TaskScope:
|
|||
# Data about inner and outer scopes.
|
||||
# This is used internally to make sure
|
||||
# nesting task scopes works as expected
|
||||
self.inner: TaskScope | None = None
|
||||
self.inner: list[TaskScope] = []
|
||||
self.outer: TaskScope | None = None
|
||||
# Which tasks do we contain?
|
||||
self.tasks: list[Task] = []
|
||||
|
@ -53,20 +53,14 @@ class TaskScope:
|
|||
def get_actual_timeout(self):
|
||||
"""
|
||||
Returns the effective timeout of the whole
|
||||
cancel scope. This is different from the
|
||||
self.timeout parameter because cancel scopes
|
||||
can be nested, and we might have a parent with
|
||||
a lower timeout than us
|
||||
:return:
|
||||
cancel scope
|
||||
"""
|
||||
|
||||
if self.outer is None:
|
||||
return self.timeout
|
||||
current = self.inner
|
||||
while current:
|
||||
if current.shielded:
|
||||
for child in self.children:
|
||||
if child.shielded:
|
||||
return float("inf")
|
||||
current = current.inner
|
||||
return min([self.timeout, self.outer.get_actual_timeout()])
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -81,12 +75,29 @@ class TaskScope:
|
|||
return self.silent
|
||||
return False
|
||||
|
||||
# Just a recursive helper
|
||||
def _get_children(self, lst=None):
|
||||
if lst is None:
|
||||
lst = []
|
||||
for child in self.inner:
|
||||
lst.append(child)
|
||||
child._get_children(lst)
|
||||
return lst
|
||||
|
||||
@property
|
||||
def children(self) -> list["TaskScope"]:
|
||||
"""
|
||||
Gets all the scopes contained within this one
|
||||
"""
|
||||
|
||||
return self._get_children()
|
||||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task scope has finished executing
|
||||
"""
|
||||
|
||||
if self.inner and not self.inner.done():
|
||||
if not all(child.done() for child in self.children):
|
||||
return False
|
||||
return all(task.done() for task in self.tasks)
|
||||
|
||||
|
@ -103,10 +114,7 @@ class TaskPool:
|
|||
|
||||
self.entry_point: Task | None = None
|
||||
self.scope: TaskScope = TaskScope(timeout=float("inf"))
|
||||
# Data about inner and outer pools.
|
||||
# This is used internally to make sure
|
||||
# nesting task pools works as expected
|
||||
self.inner: TaskPool | None = None
|
||||
# This pool's parent
|
||||
self.outer: TaskPool | None = None
|
||||
# Have we errored out?
|
||||
self.error: BaseException | None = None
|
||||
|
@ -128,6 +136,8 @@ class TaskPool:
|
|||
raise exc_val.with_traceback(exc_tb)
|
||||
elif not self.done():
|
||||
await suspend()
|
||||
else:
|
||||
await checkpoint()
|
||||
except Cancelled as e:
|
||||
self.error = e
|
||||
self.scope.cancelled = True
|
||||
|
@ -136,7 +146,6 @@ class TaskPool:
|
|||
self.scope.cancel()
|
||||
finally:
|
||||
current_loop().close_pool(self)
|
||||
self.scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
self._closed = True
|
||||
if self.error:
|
||||
raise self.error
|
||||
|
@ -159,4 +168,5 @@ class TaskPool:
|
|||
executing until it is picked by the scheduler later on
|
||||
"""
|
||||
|
||||
return current_loop().spawn(func, *args)
|
||||
self.scope.tasks.append(current_loop().spawn(func, *args))
|
||||
return self.scope.tasks[-1]
|
||||
|
|
|
@ -53,13 +53,10 @@ class FIFOKernel(BaseKernel):
|
|||
self._sigint_handled: bool = False
|
||||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue(self.clock)
|
||||
# All task scopes we handle
|
||||
self.scopes: list[TaskScope] = []
|
||||
self.pool = TaskPool()
|
||||
self.current_pool = self.pool
|
||||
self.current_scope = self.current_pool.scope
|
||||
self.pool.scope.shielded = True
|
||||
self.current_scope = self.pool.scope
|
||||
self.current_scope.shielded = False
|
||||
self.scopes.append(self.current_scope)
|
||||
self._closing = False
|
||||
|
||||
def get_closest_deadline(self):
|
||||
|
@ -125,21 +122,22 @@ class FIFOKernel(BaseKernel):
|
|||
return True
|
||||
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
||||
return False
|
||||
for scope in self.scopes:
|
||||
if not scope.done():
|
||||
return False
|
||||
if not self.pool.done():
|
||||
return False
|
||||
return True
|
||||
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args,
|
||||
ki_protected: bool = False, pool: TaskPool = None):
|
||||
if isinstance(func, partial):
|
||||
name = func.func.__name__ or repr(func.func)
|
||||
else:
|
||||
name = func.__name__ or repr(func)
|
||||
task = Task(name, func(*args), self.current_pool)
|
||||
if pool is None:
|
||||
pool = self.current_pool
|
||||
task = Task(name, func(*args), pool)
|
||||
# We inject our magic secret variable into the coroutine's stack frame, so
|
||||
# we can look it up later
|
||||
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
|
||||
task.pool.scope.tasks.append(task)
|
||||
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, ki_protected)
|
||||
self.run_queue.append(task)
|
||||
self.event("on_task_spawn")
|
||||
return task
|
||||
|
@ -147,16 +145,7 @@ class FIFOKernel(BaseKernel):
|
|||
def spawn_system_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
):
|
||||
if isinstance(func, partial):
|
||||
name = func.func.__name__ or repr(func.func)
|
||||
else:
|
||||
name = func.__name__ or repr(func)
|
||||
task = Task(name, func(*args), self.pool)
|
||||
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, True)
|
||||
task.pool.scope.tasks.append(task)
|
||||
self.run_queue.append(task)
|
||||
self.event("on_task_spawn")
|
||||
return task
|
||||
return self.spawn(func, *args, ki_protected=True, pool=self.pool)
|
||||
|
||||
def signal_notify(self, sig: int, frame: FrameType):
|
||||
match sig:
|
||||
|
@ -207,14 +196,10 @@ class FIFOKernel(BaseKernel):
|
|||
def throw(self, task: Task, err: BaseException):
|
||||
if task.done():
|
||||
return
|
||||
if task.state == TaskState.PAUSED:
|
||||
self.paused.discard(task)
|
||||
elif task.state == TaskState.IO:
|
||||
self.io_manager.release_task(task)
|
||||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||
|
||||
def reschedule(self, task: Task):
|
||||
if task.done() or task in self.run_queue:
|
||||
if task.done():
|
||||
return
|
||||
self.run_queue.append(task)
|
||||
|
||||
|
@ -250,7 +235,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.reschedule_running()
|
||||
|
||||
def check_scopes(self):
|
||||
for scope in self.scopes:
|
||||
for scope in self.pool.scope.children:
|
||||
if scope.get_actual_timeout() <= self.clock.current_time():
|
||||
error = TimedOut("timed out")
|
||||
error.scope = scope
|
||||
|
@ -275,11 +260,11 @@ class FIFOKernel(BaseKernel):
|
|||
"""
|
||||
|
||||
while not self.done():
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
if self.run_queue and not self.skip:
|
||||
self.handle_errors(self.step)
|
||||
self.skip = False
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
if self.io_manager.pending():
|
||||
self.io_manager.wait_io()
|
||||
self.wakeup()
|
||||
|
@ -355,17 +340,20 @@ class FIFOKernel(BaseKernel):
|
|||
self.io_manager.release_task(task)
|
||||
self.paused.discard(task)
|
||||
|
||||
def _reschedule_scope(self, scope: TaskScope):
|
||||
while scope.done() and scope is not self.pool.scope:
|
||||
self.reschedule(scope.owner)
|
||||
scope = scope.outer
|
||||
|
||||
def on_success(self, task: Task):
|
||||
"""
|
||||
The given task has exited gracefully: hooray!
|
||||
"""
|
||||
|
||||
# TODO: Anything else?
|
||||
task.pool: TaskPool
|
||||
for waiter in task.waiters:
|
||||
self.reschedule(waiter)
|
||||
if task.pool.done():
|
||||
self.reschedule(task.pool.entry_point)
|
||||
self._reschedule_scope(task.pool.scope)
|
||||
task.waiters.clear()
|
||||
self.event("on_task_exit", task)
|
||||
self.io_manager.release_task(task)
|
||||
|
@ -378,6 +366,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.event("on_exception_raised", task, task.exc)
|
||||
for waiter in task.waiters:
|
||||
self.reschedule(waiter)
|
||||
self._reschedule_scope(task.pool.scope)
|
||||
if task.pool.scope.owner is not self.current_task:
|
||||
self.throw(task.pool.scope.owner, task.exc)
|
||||
task.waiters.clear()
|
||||
|
@ -393,19 +382,16 @@ class FIFOKernel(BaseKernel):
|
|||
self.reschedule(waiter)
|
||||
task.waiters.clear()
|
||||
self.release(task)
|
||||
if task.pool.done():
|
||||
self.reschedule(task.pool.entry_point)
|
||||
self._reschedule_scope(task.pool.scope)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
scope.owner = self.current_task
|
||||
self.current_scope.inner = scope
|
||||
self.current_scope.inner.append(scope)
|
||||
scope.outer = self.current_scope
|
||||
self.current_scope = scope
|
||||
self.scopes.append(scope)
|
||||
|
||||
def close_scope(self, scope: TaskScope):
|
||||
self.current_scope = scope.outer
|
||||
self.scopes.pop()
|
||||
|
||||
def cancel_task(self, task: Task):
|
||||
if task.done():
|
||||
|
@ -435,9 +421,8 @@ class FIFOKernel(BaseKernel):
|
|||
# to
|
||||
if self.current_task in scope.tasks and self.current_task is not scope.owner:
|
||||
self.current_task.pending_cancellation = True
|
||||
inner = scope.inner
|
||||
if inner and not inner.shielded:
|
||||
self.cancel_scope(inner)
|
||||
for child in filter(lambda c: not c.shielded, scope.children):
|
||||
self.cancel_scope(child)
|
||||
for task in scope.tasks:
|
||||
if task is self.current_task:
|
||||
continue
|
||||
|
@ -455,11 +440,11 @@ class FIFOKernel(BaseKernel):
|
|||
def init_pool(self, pool: TaskPool):
|
||||
pool.outer = self.current_pool
|
||||
pool.entry_point = self.current_task
|
||||
self.current_pool.inner = pool
|
||||
self.current_pool = pool
|
||||
|
||||
def close_pool(self, pool: TaskPool):
|
||||
self.current_pool = pool.outer
|
||||
self.close_scope(pool.scope)
|
||||
|
||||
def suspend(self):
|
||||
self.current_task.state = TaskState.PAUSED
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
import structio
|
||||
import functools
|
||||
from threading import local
|
||||
from structio.abc import (
|
||||
|
@ -86,6 +87,9 @@ def run(
|
|||
raise StructIOException(
|
||||
"structio.run() requires an async function as its first argument!"
|
||||
)
|
||||
waker = structio.util.wakeup_fd.WakeupFd()
|
||||
watcher = structio.signals.signal_watcher
|
||||
waker.set_wakeup_fd()
|
||||
new_event_loop(
|
||||
kernel(
|
||||
clock=clock,
|
||||
|
@ -95,5 +99,6 @@ def run(
|
|||
tools=tools,
|
||||
)
|
||||
)
|
||||
current_loop().spawn_system_task(watcher, waker.reader)
|
||||
return current_loop().start(func, *args)
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Coroutine, Any, Callable
|
||||
from itertools import count
|
||||
|
||||
_counter = count()
|
||||
|
||||
|
||||
class TaskState(Enum):
|
||||
|
@ -31,6 +34,8 @@ class Task:
|
|||
pool: "TaskPool" = field(repr=False)
|
||||
# The state of the task
|
||||
state: TaskState = field(default=TaskState.INIT)
|
||||
# Used for debugging
|
||||
id: int = field(default_factory=lambda: next(_counter))
|
||||
# What error did the task raise, if any?
|
||||
exc: BaseException | None = None
|
||||
# The task's return value, if any
|
||||
|
|
|
@ -307,6 +307,8 @@ class AsyncSocket(AsyncResource):
|
|||
self.socket = sock
|
||||
self.socket.setblocking(False)
|
||||
self.connected: bool = False
|
||||
self.write_lock = structio.util.misc.ThereCanBeOnlyOne("another task is writing on this socket")
|
||||
self.read_lock = structio.util.misc.ThereCanBeOnlyOne("another task is writing on this socket")
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
@ -322,15 +324,16 @@ class AsyncSocket(AsyncResource):
|
|||
assert max_size >= 1, "max_size must be >= 1"
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv(max_size, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv(max_size, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
|
@ -338,14 +341,15 @@ class AsyncSocket(AsyncResource):
|
|||
"""
|
||||
|
||||
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
|
||||
buf = bytearray(size)
|
||||
pos = 0
|
||||
while pos < size:
|
||||
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
|
||||
if n == 0:
|
||||
raise ResourceBroken("incomplete read detected")
|
||||
pos += n
|
||||
return bytes(buf)
|
||||
with self.read_lock:
|
||||
buf = bytearray(size)
|
||||
pos = 0
|
||||
while pos < size:
|
||||
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
|
||||
if n == 0:
|
||||
raise ResourceBroken("incomplete read detected")
|
||||
pos += n
|
||||
return bytes(buf)
|
||||
|
||||
async def connect(self, address):
|
||||
"""
|
||||
|
@ -399,16 +403,17 @@ class AsyncSocket(AsyncResource):
|
|||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
sent_no = 0
|
||||
while data:
|
||||
try:
|
||||
sent_no = self.socket.send(data, flags)
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
data = data[sent_no:]
|
||||
with self.write_lock:
|
||||
sent_no = 0
|
||||
while data:
|
||||
try:
|
||||
sent_no = self.socket.send(data, flags)
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
data = data[sent_no:]
|
||||
|
||||
async def shutdown(self, how):
|
||||
"""
|
||||
|
@ -497,45 +502,48 @@ class AsyncSocket(AsyncResource):
|
|||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom(buffersize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom(buffersize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recv_into(self, buffer, nbytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv_into(buffer, nbytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv_into(buffer, nbytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom_into(buffer, bytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom_into(buffer, bytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def sendto(self, bytes, flags_or_address, address=None):
|
||||
"""
|
||||
|
@ -547,15 +555,16 @@ class AsyncSocket(AsyncResource):
|
|||
else:
|
||||
address = flags_or_address
|
||||
flags = 0
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendto(bytes, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
with self.write_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendto(bytes, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def getpeername(self):
|
||||
"""
|
||||
|
@ -592,39 +601,42 @@ class AsyncSocket(AsyncResource):
|
|||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendmsg(buffers, ancdata, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_writable(self._fd)
|
||||
with self.write_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendmsg(buffers, ancdata, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncSocket({self.socket})"
|
||||
|
|
|
@ -2,11 +2,9 @@
|
|||
import signal
|
||||
from collections import defaultdict
|
||||
from types import FrameType
|
||||
|
||||
from structio.io.socket import AsyncSocket, socketpair
|
||||
from structio.io.socket import AsyncSocket
|
||||
from typing import Callable, Any, Coroutine
|
||||
from structio.thread import AsyncThreadQueue
|
||||
from structio.core.task import Task
|
||||
from structio.core.run import current_loop
|
||||
|
||||
|
||||
|
@ -14,8 +12,6 @@ _sig_data = AsyncThreadQueue(float("inf"))
|
|||
_sig_handlers: dict[
|
||||
signal.Signals, Callable[[Any, Any], Coroutine[Any, Any, Any]] | None
|
||||
] = defaultdict(lambda: None)
|
||||
_watcher: Task | None = None
|
||||
_reader, _writer = socketpair()
|
||||
|
||||
|
||||
def _handle(sig: int, frame: FrameType):
|
||||
|
@ -42,10 +38,6 @@ def set_signal_handler(
|
|||
None is returned
|
||||
"""
|
||||
|
||||
global _watcher
|
||||
if not _watcher:
|
||||
signal.set_wakeup_fd(_writer.fileno())
|
||||
_watcher = current_loop().spawn_system_task(signal_watcher, _reader)
|
||||
# Raises an appropriate error
|
||||
sig = signal.Signals(sig)
|
||||
match sig:
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Support module for running synchronous functions as
|
||||
# coroutines into worker threads and to submit asynchronous
|
||||
# work to the event loop from a synchronous thread
|
||||
from functools import partial
|
||||
|
||||
import structio
|
||||
import threading
|
||||
from collections import deque
|
||||
|
@ -12,7 +14,6 @@ from structio.sync import Event, Semaphore, Queue
|
|||
from structio.util.ki import enable_ki_protection
|
||||
from structio.exceptions import StructIOException
|
||||
|
||||
|
||||
_storage = threading.local()
|
||||
# Max number of concurrent threads that can
|
||||
# be spawned by run_in_worker before blocking
|
||||
|
@ -62,7 +63,8 @@ class AsyncThreadEvent(Event):
|
|||
return
|
||||
# We can't just call super().set() because that
|
||||
# will call current_loop(), and we may have been
|
||||
# called from a non-async thread
|
||||
# called from an async thread that doesn't have a
|
||||
# loop
|
||||
loop: BaseKernel = _storage.parent_loop
|
||||
for task in self._tasks:
|
||||
loop.reschedule(task)
|
||||
|
@ -154,6 +156,7 @@ def _threaded_runner(
|
|||
rq: AsyncThreadQueue,
|
||||
rsq: AsyncThreadQueue,
|
||||
evt: AsyncThreadEvent,
|
||||
writer,
|
||||
*args,
|
||||
):
|
||||
try:
|
||||
|
@ -162,6 +165,7 @@ def _threaded_runner(
|
|||
_storage.parent_loop = parent_loop
|
||||
_storage.rq = rq
|
||||
_storage.rsq = rsq
|
||||
_storage.wakeup = writer
|
||||
result = f(*args)
|
||||
except BaseException as e:
|
||||
rsq.put_sync((False, e))
|
||||
|
@ -175,7 +179,7 @@ def _threaded_runner(
|
|||
|
||||
@enable_ki_protection
|
||||
async def _coroutine_request_handler(
|
||||
events: AsyncThreadQueue, results: AsyncThreadQueue
|
||||
events: AsyncThreadQueue, results: AsyncThreadQueue, sock: "structio.socket.AsyncSocket"
|
||||
):
|
||||
"""
|
||||
Runs coroutines on behalf of a thread spawned by structio and
|
||||
|
@ -183,10 +187,8 @@ async def _coroutine_request_handler(
|
|||
"""
|
||||
|
||||
while True:
|
||||
data = await events.get()
|
||||
if not data:
|
||||
break
|
||||
coro = data
|
||||
await sock.receive(1)
|
||||
coro = await events.get()
|
||||
try:
|
||||
result = await coro
|
||||
except BaseException as e:
|
||||
|
@ -195,76 +197,6 @@ async def _coroutine_request_handler(
|
|||
await results.put((True, result))
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
async def _wait_for_thread(
|
||||
events: AsyncThreadQueue,
|
||||
results: AsyncThreadQueue,
|
||||
termination_event: AsyncThreadEvent,
|
||||
cancellable: bool = False,
|
||||
):
|
||||
"""
|
||||
Waits for a thread spawned by structio to complete and
|
||||
returns its result. Exceptions are also propagated
|
||||
"""
|
||||
|
||||
async with structio.create_pool() as pool:
|
||||
# If the operation is cancellable, then we're not
|
||||
# shielded
|
||||
pool.scope.shielded = not cancellable
|
||||
# Spawn a coroutine to process incoming requests from
|
||||
# the new async thread. We can't await it because it
|
||||
# needs to run in the background
|
||||
pool.spawn(_coroutine_request_handler, events, results)
|
||||
# Wait for the thread to terminate
|
||||
await termination_event.wait()
|
||||
# Worker thread has exited: we no longer need to process
|
||||
# any requests, so we shut our request handler down
|
||||
await events.put(None)
|
||||
# Wait for the final result from the thread
|
||||
success, data = await results.get()
|
||||
if success:
|
||||
return data
|
||||
raise data
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
async def _spawn_supervised_thread(f, cancellable: bool = False, *args):
|
||||
# Thread termination event
|
||||
terminate = AsyncThreadEvent()
|
||||
# Request queue. This is where the thread
|
||||
# sends coroutines to run
|
||||
rq = AsyncThreadQueue(0)
|
||||
# Results queue. This is where we put the result
|
||||
# of the coroutines in the request queue
|
||||
rsq = AsyncThreadQueue(0)
|
||||
# This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
|
||||
# The termination event is necessary so that _wait_for_thread can know when to shut
|
||||
# down (and, by extension, shut down its workers too). The request and result queues
|
||||
# are used to send coroutines and their results back and forth when using run_coro from
|
||||
# within the "asynchronous thread". Trying to reduce the amount of primitives turns out
|
||||
# to be very hard, because we'd have at least 3 different things (_wait_for_thread,
|
||||
# _threaded_runner and _coroutine_request_handler) trying to work on the same resources, which is
|
||||
# a hellish nightmare to synchronize properly. For example, _coroutine_request_handler *could* just
|
||||
# use a single queue for sending data back and forth, but since it runs in a while loop in order to
|
||||
# handle more than one request, as soon as it would put any data onto the queue and then go to the
|
||||
# next iteration in the loop, it would (likely, but not always, as it depends on how things get
|
||||
# scheduled) immediately call get() again, get something out of queue that it doesn't expect and
|
||||
# crash horribly. So this separation is necessary to retain my sanity
|
||||
threading.Thread(
|
||||
target=_threaded_runner,
|
||||
args=(f, current_loop(), rq, rsq, terminate, *args),
|
||||
# We start cancellable threads in daemonic mode so that
|
||||
# the main thread doesn't get stuck waiting on them forever
|
||||
# when their associated async counterpart gets cancelled. This
|
||||
# is due to the fact that there's really no way to "kill" a thread
|
||||
# (and for good reason!), so we just pretend nothing happened and go
|
||||
# about our merry way, hoping the thread dies eventually I guess
|
||||
name="structio-worker-thread",
|
||||
daemon=cancellable,
|
||||
).start()
|
||||
return await _wait_for_thread(rq, rsq, terminate, cancellable)
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
async def run_in_worker(
|
||||
sync_func,
|
||||
|
@ -319,7 +251,59 @@ async def run_in_worker(
|
|||
# we run out of slots and proceed once
|
||||
# we have more
|
||||
async with _storage.max_workers:
|
||||
return await _spawn_supervised_thread(sync_func, cancellable, *args)
|
||||
# Thread termination event
|
||||
terminate = AsyncThreadEvent()
|
||||
# Request queue. This is where the thread
|
||||
# sends coroutines to run
|
||||
rq = AsyncThreadQueue(0)
|
||||
# Results queue. This is where we put the result
|
||||
# of the coroutines in the request queue
|
||||
rsq = AsyncThreadQueue(0)
|
||||
# This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
|
||||
# The termination event is necessary so that _wait_for_thread can know when to shut
|
||||
# down (and, by extension, shut down its workers too). The request and result queues
|
||||
# are used to send coroutines and their results back and forth when using run_coro from
|
||||
# within the "asynchronous thread". Trying to reduce the amount of primitives turns out
|
||||
# to be very hard, because we'd have at least 3 different things (_wait_for_thread,
|
||||
# _threaded_runner and _coroutine_request_handler) trying to work on the same resources, which is
|
||||
# a hellish nightmare to synchronize properly. For example, _coroutine_request_handler *could* just
|
||||
# use a single queue for sending data back and forth, but since it runs in a while loop in order to
|
||||
# handle more than one request, as soon as it would put any data onto the queue and then go to the
|
||||
# next iteration in the loop, it would (likely, but not always, as it depends on how things get
|
||||
# scheduled) immediately call get() again, get something out of queue that it doesn't expect and
|
||||
# crash horribly. So this separation is necessary to retain my sanity
|
||||
async with structio.create_pool() as pool:
|
||||
# If the operation is cancellable, then we're not
|
||||
# shielded
|
||||
pool.scope.shielded = not cancellable
|
||||
# Spawn a coroutine to process incoming requests from
|
||||
# the new async thread. We can't await it because it
|
||||
# needs to run in the background
|
||||
wakeup = structio.util.wakeup_fd.WakeupFd()
|
||||
handler = pool.spawn(_coroutine_request_handler, rq, rsq, wakeup.reader)
|
||||
# Start the worker thread
|
||||
threading.Thread(
|
||||
target=_threaded_runner,
|
||||
args=(sync_func, current_loop(), rq, rsq, terminate, wakeup, *args),
|
||||
name="structio-worker-thread",
|
||||
# We start cancellable threads in daemonic mode so that
|
||||
# the main thread doesn't get stuck waiting on them forever
|
||||
# when their associated async counterpart gets cancelled. This
|
||||
# is due to the fact that there's really no way to "kill" a thread
|
||||
# (and for good reason!), so we just pretend nothing happened and go
|
||||
# about our merry way, hoping the thread dies eventually I guess
|
||||
daemon=cancellable,
|
||||
).start()
|
||||
# Wait for the thread to terminate
|
||||
await terminate.wait()
|
||||
# Worker thread has exited: we no longer need to process
|
||||
# any requests, so we shut our request handler down
|
||||
handler.cancel()
|
||||
# Wait for the final result from the thread
|
||||
success, data = await rsq.get()
|
||||
if success:
|
||||
return data
|
||||
raise data
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
|
@ -327,10 +311,13 @@ def run_coro(
|
|||
async_func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Submits a coroutine for execution to the event loop, passing any
|
||||
arguments along the way. Return values and exceptions are propagated
|
||||
and from the point of view of the calling thread, this call blocks
|
||||
until the coroutine returns
|
||||
Submits a coroutine for execution to the event loop from another thread,
|
||||
passing any arguments along the way. Return values and exceptions are
|
||||
propagated, and from the point of view of the calling thread this call
|
||||
blocks until the coroutine returns. The thread must be async flavored,
|
||||
meaning it must be able to communicate back and forth with the event
|
||||
loop running in the main thread (in practice, this means only threads
|
||||
spawned with run_in_worker are able to call this)
|
||||
"""
|
||||
|
||||
try:
|
||||
|
@ -338,8 +325,10 @@ def run_coro(
|
|||
raise StructIOException("cannot be called from async context")
|
||||
except StructIOException:
|
||||
pass
|
||||
if not hasattr(_storage, "parent_loop"):
|
||||
if not is_async_thread():
|
||||
raise StructIOException("run_coro requires a running loop in another thread!")
|
||||
# Wake up the event loop if it's blocked in a call to select() or similar I/O routine
|
||||
_storage.wakeup.wakeup()
|
||||
_storage.rq.put_sync(async_func(*args, **kwargs))
|
||||
success, data = _storage.rsq.get_sync()
|
||||
if success:
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from . import misc, ki, wakeup_fd
|
||||
|
||||
|
||||
__all__ = ["misc",
|
||||
"ki",
|
||||
"wakeup_fd"]
|
|
@ -25,3 +25,6 @@ class ThereCanBeOnlyOne:
|
|||
|
||||
def __exit__(self, *args):
|
||||
self._acquired = False
|
||||
|
||||
|
||||
__all__ = ["ThereCanBeOnlyOne"]
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
from structio.io.socket import socketpair
|
||||
import signal
|
||||
|
||||
|
||||
class WakeupFd:
|
||||
"""
|
||||
A thin wrapper over a socket pair used in set_wakeup_fd
|
||||
and for thread wakeup events
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reader, self.writer = socketpair()
|
||||
|
||||
def set_wakeup_fd(self):
|
||||
signal.set_wakeup_fd(self.writer.socket.fileno())
|
||||
|
||||
def wakeup(self):
|
||||
try:
|
||||
self.writer.socket.send(b"\x00")
|
||||
except BlockingIOError:
|
||||
pass
|
||||
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def child(k):
|
||||
print("[child] I'm alive! Spawning sleeper")
|
||||
async with structio.create_pool() as p:
|
||||
p.spawn(structio.sleep, k)
|
||||
print("[child] I'm done sleeping!")
|
||||
|
||||
|
||||
async def main(n: int, k):
|
||||
print(f"[main] Spawning {n} children in their own pools, each sleeping for {k} seconds")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as p:
|
||||
for _ in range(n):
|
||||
p.spawn(child, k)
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
# Should exit in ~2 seconds
|
||||
structio.run(main, 10, 2)
|
|
@ -0,0 +1,26 @@
|
|||
import datetime as dtt
|
||||
import structio
|
||||
|
||||
|
||||
async def task():
|
||||
for i in range(100):
|
||||
await structio.sleep(0.01)
|
||||
|
||||
|
||||
async def main(tests: list[int]):
|
||||
print("[main] Starting stress test, aggregate results will be printed at the end")
|
||||
results = []
|
||||
for N in tests:
|
||||
print(f"[main] Starting test with {N} tasks")
|
||||
start = dtt.datetime.utcnow()
|
||||
async with structio.create_pool() as p:
|
||||
for _ in range(N):
|
||||
p.spawn(task)
|
||||
end = dtt.datetime.utcnow()
|
||||
results.append((end - start).total_seconds())
|
||||
print(f"[main] Test with {N} tasks completed in {results[-1]:.2f} seconds")
|
||||
results = " ".join((f'{r:0>5.2f}' for r in results))
|
||||
print(f"[main] Results: {results}")
|
||||
|
||||
|
||||
structio.run(main, [10, 100, 1000, 10000])
|
|
@ -2,16 +2,16 @@ import structio
|
|||
import time
|
||||
|
||||
|
||||
def fake_async_sleeper(n):
|
||||
print(f"[thread] About to sleep for {n} seconds")
|
||||
def fake_async_sleeper(n, name: str = ""):
|
||||
print(f"[thread{f' {name}' if name else ''}] About to sleep for {n} seconds")
|
||||
t = time.time()
|
||||
if structio.thread.is_async_thread():
|
||||
print(f"[thread] I have async superpowers!")
|
||||
print(f"[thread{f' {name}' if name else ''}] I have async superpowers!")
|
||||
structio.thread.run_coro(structio.sleep, n)
|
||||
else:
|
||||
print(f"[thread] Using old boring time.sleep :(")
|
||||
print(f"[thread{f' {name}' if name else ''}] Using old boring time.sleep :(")
|
||||
time.sleep(n)
|
||||
print(f"[thread] Slept for {time.time() - t:.2f} seconds")
|
||||
print(f"[thread{f' {name}' if name else ''}] Slept for {time.time() - t:.2f} seconds")
|
||||
return n
|
||||
|
||||
|
||||
|
@ -33,5 +33,22 @@ async def main_timeout(n, k):
|
|||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_multiple(n, k):
|
||||
print(f"[main] Spawning {n} worker threads each sleeping for {k} seconds")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
for i in range(n):
|
||||
pool.spawn(structio.thread.run_in_worker, fake_async_sleeper, k, str(i))
|
||||
print(f"[main] Workers spawned")
|
||||
# Keep in mind that there is some overhead associated with running worker threads,
|
||||
# not to mention that it gets tricky with how the OS schedules them and whatnot. So,
|
||||
# it's unlikely that all threads finish exactly at the same time and that we exit in
|
||||
# k seconds, even just because there's a lot of back and forth going on under the hood
|
||||
# between structio and the worker threads themselves
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main, 2)
|
||||
structio.run(main_timeout, 5, 3)
|
||||
structio.run(main_multiple, 5, 3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue