Many fixes to nested scope handling, added stress test and scope tree test

Mattia Giambirtone 2023-06-27 17:58:12 +02:00 committed by nocturn9x
parent e0f2e87cad
commit 9e6ee1e104
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
15 changed files with 336 additions and 241 deletions

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (structio)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (StructuredIO)" project-jdk-type="Python SDK" />
</project>

View File

@ -37,6 +37,7 @@ from structio import thread, parallel
from structio.path import Path
from structio.signals import set_signal_handler, get_signal_handler
from structio import signals as _signals
from structio import util
def run(
@ -168,4 +169,5 @@ __all__ = [
"parallel",
"get_signal_handler",
"set_signal_handler",
"util"
]

View File

@ -36,7 +36,7 @@ class TaskScope:
# Data about inner and outer scopes.
# This is used internally to make sure
# nesting task scopes works as expected
self.inner: TaskScope | None = None
self.inner: list[TaskScope] = []
self.outer: TaskScope | None = None
# Which tasks do we contain?
self.tasks: list[Task] = []
@ -53,20 +53,14 @@ class TaskScope:
def get_actual_timeout(self):
"""
Returns the effective timeout of the whole
cancel scope. This is different from the
self.timeout parameter because cancel scopes
can be nested, and we might have a parent with
a lower timeout than us
:return:
cancel scope
"""
if self.outer is None:
return self.timeout
current = self.inner
while current:
if current.shielded:
for child in self.children:
if child.shielded:
return float("inf")
current = current.inner
return min([self.timeout, self.outer.get_actual_timeout()])
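
For clarity, here is a minimal standalone sketch of the new timeout semantics (Scope is a hypothetical stand-in, not the real TaskScope): a nested scope's effective timeout is capped by every ancestor's, while a shielded descendant pins the effective timeout to infinity.

class Scope:
    def __init__(self, timeout, outer=None, shielded=False):
        self.timeout = timeout
        self.outer = outer
        self.shielded = shielded
        self.inner = []  # direct children, mirroring the new list-based field
        if outer:
            outer.inner.append(self)

    @property
    def children(self):
        # Recursively flattened subtree, like TaskScope.children above
        result = []
        for child in self.inner:
            result.append(child)
            result.extend(child.children)
        return result

    def get_actual_timeout(self):
        if self.outer is None:
            return self.timeout
        if any(child.shielded for child in self.children):
            return float("inf")
        return min(self.timeout, self.outer.get_actual_timeout())

root = Scope(float("inf"))
parent = Scope(5, outer=root)
child = Scope(10, outer=parent)
assert child.get_actual_timeout() == 5  # capped by the parent's lower deadline
shielded = Scope(1, outer=child, shielded=True)
assert child.get_actual_timeout() == float("inf")  # shielded descendant wins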
def __enter__(self):
@ -81,12 +75,29 @@ class TaskScope:
return self.silent
return False
# Just a recursive helper
def _get_children(self, lst=None):
if lst is None:
lst = []
for child in self.inner:
lst.append(child)
child._get_children(lst)
return lst
@property
def children(self) -> list["TaskScope"]:
"""
Gets all the scopes contained within this one
"""
return self._get_children()
def done(self):
"""
Returns whether the task scope has finished executing
"""
if self.inner and not self.inner.done():
if not all(child.done() for child in self.children):
return False
return all(task.done() for task in self.tasks)
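
A minimal sketch of the new completion check, using a hypothetical Node stand-in: a scope is done only when every scope in its recursively flattened subtree is done and all of its own tasks are done.

class Node:
    def __init__(self):
        self.inner = []  # direct child scopes
        self.tasks = []  # booleans standing in for task.done()

    def _get_children(self, lst=None):
        if lst is None:
            lst = []
        for child in self.inner:
            lst.append(child)
            child._get_children(lst)
        return lst

    def done(self):
        if not all(child.done() for child in self._get_children()):
            return False
        return all(self.tasks)

root, mid, leaf = Node(), Node(), Node()
root.inner.append(mid)
mid.inner.append(leaf)
leaf.tasks = [True, False]
assert not root.done()  # one unfinished task in the subtree propagates up
leaf.tasks = [True, True]
assert root.done()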
@ -103,10 +114,7 @@ class TaskPool:
self.entry_point: Task | None = None
self.scope: TaskScope = TaskScope(timeout=float("inf"))
# Data about inner and outer pools.
# This is used internally to make sure
# nesting task pools works as expected
self.inner: TaskPool | None = None
# This pool's parent
self.outer: TaskPool | None = None
# Have we errored out?
self.error: BaseException | None = None
@ -128,6 +136,8 @@ class TaskPool:
raise exc_val.with_traceback(exc_tb)
elif not self.done():
await suspend()
else:
await checkpoint()
except Cancelled as e:
self.error = e
self.scope.cancelled = True
@ -136,7 +146,6 @@ class TaskPool:
self.scope.cancel()
finally:
current_loop().close_pool(self)
self.scope.__exit__(exc_type, exc_val, exc_tb)
self._closed = True
if self.error:
raise self.error
@ -159,4 +168,5 @@ class TaskPool:
executing until it is picked by the scheduler later on
"""
return current_loop().spawn(func, *args)
self.scope.tasks.append(current_loop().spawn(func, *args))
return self.scope.tasks[-1]

View File

@ -53,13 +53,10 @@ class FIFOKernel(BaseKernel):
self._sigint_handled: bool = False
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue(self.clock)
# All task scopes we handle
self.scopes: list[TaskScope] = []
self.pool = TaskPool()
self.current_pool = self.pool
self.current_scope = self.current_pool.scope
self.pool.scope.shielded = True
self.current_scope = self.pool.scope
self.current_scope.shielded = False
self.scopes.append(self.current_scope)
self._closing = False
def get_closest_deadline(self):
@ -125,21 +122,22 @@ class FIFOKernel(BaseKernel):
return True
if any([self.run_queue, self.paused, self.io_manager.pending()]):
return False
for scope in self.scopes:
if not scope.done():
return False
if not self.pool.done():
return False
return True
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args,
ki_protected: bool = False, pool: TaskPool = None):
if isinstance(func, partial):
name = func.func.__name__ or repr(func.func)
else:
name = func.__name__ or repr(func)
task = Task(name, func(*args), self.current_pool)
if pool is None:
pool = self.current_pool
task = Task(name, func(*args), pool)
# We inject our magic secret variable into the coroutine's stack frame, so
# we can look it up later
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
task.pool.scope.tasks.append(task)
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, ki_protected)
self.run_queue.append(task)
self.event("on_task_spawn")
return task
@ -147,16 +145,7 @@ class FIFOKernel(BaseKernel):
def spawn_system_task(
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
):
if isinstance(func, partial):
name = func.func.__name__ or repr(func.func)
else:
name = func.__name__ or repr(func)
task = Task(name, func(*args), self.pool)
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, True)
task.pool.scope.tasks.append(task)
self.run_queue.append(task)
self.event("on_task_spawn")
return task
return self.spawn(func, *args, ki_protected=True, pool=self.pool)
def signal_notify(self, sig: int, frame: FrameType):
match sig:
@ -207,14 +196,10 @@ class FIFOKernel(BaseKernel):
def throw(self, task: Task, err: BaseException):
if task.done():
return
if task.state == TaskState.PAUSED:
self.paused.discard(task)
elif task.state == TaskState.IO:
self.io_manager.release_task(task)
self.handle_errors(partial(task.coroutine.throw, err), task)
def reschedule(self, task: Task):
if task.done() or task in self.run_queue:
if task.done():
return
self.run_queue.append(task)
@ -250,7 +235,7 @@ class FIFOKernel(BaseKernel):
self.reschedule_running()
def check_scopes(self):
for scope in self.scopes:
for scope in self.pool.scope.children:
if scope.get_actual_timeout() <= self.clock.current_time():
error = TimedOut("timed out")
error.scope = scope
@ -275,11 +260,11 @@ class FIFOKernel(BaseKernel):
"""
while not self.done():
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
if self.run_queue and not self.skip:
self.handle_errors(self.step)
self.skip = False
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
if self.io_manager.pending():
self.io_manager.wait_io()
self.wakeup()
@ -355,17 +340,20 @@ class FIFOKernel(BaseKernel):
self.io_manager.release_task(task)
self.paused.discard(task)
def _reschedule_scope(self, scope: TaskScope):
while scope.done() and scope is not self.pool.scope:
self.reschedule(scope.owner)
scope = scope.outer
def on_success(self, task: Task):
"""
The given task has exited gracefully: hooray!
"""
# TODO: Anything else?
task.pool: TaskPool
for waiter in task.waiters:
self.reschedule(waiter)
if task.pool.done():
self.reschedule(task.pool.entry_point)
self._reschedule_scope(task.pool.scope)
task.waiters.clear()
self.event("on_task_exit", task)
self.io_manager.release_task(task)
@ -378,6 +366,7 @@ class FIFOKernel(BaseKernel):
self.event("on_exception_raised", task, task.exc)
for waiter in task.waiters:
self.reschedule(waiter)
self._reschedule_scope(task.pool.scope)
if task.pool.scope.owner is not self.current_task:
self.throw(task.pool.scope.owner, task.exc)
task.waiters.clear()
@ -393,19 +382,16 @@ class FIFOKernel(BaseKernel):
self.reschedule(waiter)
task.waiters.clear()
self.release(task)
if task.pool.done():
self.reschedule(task.pool.entry_point)
self._reschedule_scope(task.pool.scope)
def init_scope(self, scope: TaskScope):
scope.owner = self.current_task
self.current_scope.inner = scope
self.current_scope.inner.append(scope)
scope.outer = self.current_scope
self.current_scope = scope
self.scopes.append(scope)
def close_scope(self, scope: TaskScope):
self.current_scope = scope.outer
self.scopes.pop()
def cancel_task(self, task: Task):
if task.done():
@ -435,9 +421,8 @@ class FIFOKernel(BaseKernel):
# to
if self.current_task in scope.tasks and self.current_task is not scope.owner:
self.current_task.pending_cancellation = True
inner = scope.inner
if inner and not inner.shielded:
self.cancel_scope(inner)
for child in filter(lambda c: not c.shielded, scope.children):
self.cancel_scope(child)
for task in scope.tasks:
if task is self.current_task:
continue
@ -455,11 +440,11 @@ class FIFOKernel(BaseKernel):
def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool
pool.entry_point = self.current_task
self.current_pool.inner = pool
self.current_pool = pool
def close_pool(self, pool: TaskPool):
self.current_pool = pool.outer
self.close_scope(pool.scope)
def suspend(self):
self.current_task.state = TaskState.PAUSED
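
To summarize the kernel-side bookkeeping, a standalone sketch (Tracker and new_scope are hypothetical stand-ins) of what init_scope/close_scope now do: entering a scope links it under the current one and makes it current, and leaving restores the outer scope. This tree replaces the old flat self.scopes list.

from types import SimpleNamespace

def new_scope():
    # Bare scope record with just the tree links
    return SimpleNamespace(inner=[], outer=None)

class Tracker:
    def __init__(self, root):
        self.current = root

    def init_scope(self, scope):
        scope.outer = self.current
        self.current.inner.append(scope)
        self.current = scope

    def close_scope(self, scope):
        self.current = scope.outer

root = new_scope()
t = Tracker(root)
a, b = new_scope(), new_scope()
t.init_scope(a)
t.init_scope(b)  # b nests under a
assert root.inner == [a] and a.inner == [b]
t.close_scope(b)
t.close_scope(a)
assert t.current is root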

View File

@ -1,4 +1,5 @@
import inspect
import structio
import functools
from threading import local
from structio.abc import (
@ -86,6 +87,9 @@ def run(
raise StructIOException(
"structio.run() requires an async function as its first argument!"
)
waker = structio.util.wakeup_fd.WakeupFd()
watcher = structio.signals.signal_watcher
waker.set_wakeup_fd()
new_event_loop(
kernel(
clock=clock,
@ -95,5 +99,6 @@ def run(
tools=tools,
)
)
current_loop().spawn_system_task(watcher, waker.reader)
return current_loop().start(func, *args)

View File

@ -1,6 +1,9 @@
from enum import Enum, auto
from dataclasses import dataclass, field
from typing import Coroutine, Any, Callable
from itertools import count
_counter = count()
class TaskState(Enum):
@ -31,6 +34,8 @@ class Task:
pool: "TaskPool" = field(repr=False)
# The state of the task
state: TaskState = field(default=TaskState.INIT)
# Used for debugging
id: int = field(default_factory=lambda: next(_counter))
# What error did the task raise, if any?
exc: BaseException | None = None
# The task's return value, if any
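
The debugging id above uses a small, self-contained pattern: a module-level itertools.count feeding a dataclass field's default_factory, so every new instance gets a unique, monotonically increasing id. A minimal sketch:

from dataclasses import dataclass, field
from itertools import count

_counter = count()

@dataclass
class Thing:
    # Each instantiation pulls the next integer from the shared counter
    id: int = field(default_factory=lambda: next(_counter))

assert Thing().id == 0 and Thing().id == 1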

View File

@ -307,6 +307,8 @@ class AsyncSocket(AsyncResource):
self.socket = sock
self.socket.setblocking(False)
self.connected: bool = False
self.write_lock = structio.util.misc.ThereCanBeOnlyOne("another task is writing to this socket")
self.read_lock = structio.util.misc.ThereCanBeOnlyOne("another task is reading from this socket")
async def __aenter__(self):
return self
@ -322,15 +324,16 @@ class AsyncSocket(AsyncResource):
assert max_size >= 1, "max_size must be >= 1"
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
data = self.socket.recv(max_size, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recv(max_size, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
"""
@ -338,14 +341,15 @@ class AsyncSocket(AsyncResource):
"""
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
buf = bytearray(size)
pos = 0
while pos < size:
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
if n == 0:
raise ResourceBroken("incomplete read detected")
pos += n
return bytes(buf)
with self.read_lock:
buf = bytearray(size)
pos = 0
while pos < size:
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
if n == 0:
raise ResourceBroken("incomplete read detected")
pos += n
return bytes(buf)
async def connect(self, address):
"""
@ -399,16 +403,17 @@ class AsyncSocket(AsyncResource):
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
sent_no = 0
while data:
try:
sent_no = self.socket.send(data, flags)
await checkpoint()
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
data = data[sent_no:]
with self.write_lock:
sent_no = 0
while data:
try:
sent_no = self.socket.send(data, flags)
await checkpoint()
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
data = data[sent_no:]
async def shutdown(self, how):
"""
@ -497,45 +502,48 @@ class AsyncSocket(AsyncResource):
Wrapper socket method
"""
while True:
try:
data = self.socket.recvfrom(buffersize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recvfrom(buffersize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
async def recv_into(self, buffer, nbytes=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recv_into(buffer, nbytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recv_into(buffer, nbytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recvfrom_into(buffer, bytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recvfrom_into(buffer, bytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
async def sendto(self, bytes, flags_or_address, address=None):
"""
@ -547,15 +555,16 @@ class AsyncSocket(AsyncResource):
else:
address = flags_or_address
flags = 0
while True:
try:
data = self.socket.sendto(bytes, flags, address)
await checkpoint()
return data
except WantWrite:
await wait_writable(self._fd)
except WantRead:
await wait_readable(self._fd)
with self.write_lock:
while True:
try:
data = self.socket.sendto(bytes, flags, address)
await checkpoint()
return data
except WantWrite:
await wait_writable(self._fd)
except WantRead:
await wait_readable(self._fd)
async def getpeername(self):
"""
@ -592,39 +601,42 @@ class AsyncSocket(AsyncResource):
Wrapper socket method
"""
while True:
try:
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
with self.read_lock:
while True:
try:
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self._fd)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.sendmsg(buffers, ancdata, flags, address)
await checkpoint()
return data
except WantRead:
await wait_writable(self._fd)
with self.write_lock:
while True:
try:
data = self.socket.sendmsg(buffers, ancdata, flags, address)
await checkpoint()
return data
except WantRead:
await wait_writable(self._fd)
def __repr__(self):
return f"AsyncSocket({self.socket})"

View File

@ -2,11 +2,9 @@
import signal
from collections import defaultdict
from types import FrameType
from structio.io.socket import AsyncSocket, socketpair
from structio.io.socket import AsyncSocket
from typing import Callable, Any, Coroutine
from structio.thread import AsyncThreadQueue
from structio.core.task import Task
from structio.core.run import current_loop
@ -14,8 +12,6 @@ _sig_data = AsyncThreadQueue(float("inf"))
_sig_handlers: dict[
signal.Signals, Callable[[Any, Any], Coroutine[Any, Any, Any]] | None
] = defaultdict(lambda: None)
_watcher: Task | None = None
_reader, _writer = socketpair()
def _handle(sig: int, frame: FrameType):
@ -42,10 +38,6 @@ def set_signal_handler(
None is returned
"""
global _watcher
if not _watcher:
signal.set_wakeup_fd(_writer.fileno())
_watcher = current_loop().spawn_system_task(signal_watcher, _reader)
# Raises an appropriate error
sig = signal.Signals(sig)
match sig:

View File

@ -1,6 +1,8 @@
# Support module for running synchronous functions in worker threads
# as if they were coroutines, and for submitting asynchronous work to
# the event loop from a synchronous thread
from functools import partial
import structio
import threading
from collections import deque
@ -12,7 +14,6 @@ from structio.sync import Event, Semaphore, Queue
from structio.util.ki import enable_ki_protection
from structio.exceptions import StructIOException
_storage = threading.local()
# Max number of concurrent threads that can
# be spawned by run_in_worker before blocking
@ -62,7 +63,8 @@ class AsyncThreadEvent(Event):
return
# We can't just call super().set() because that
# will call current_loop(), and we may have been
# called from a non-async thread
# called from an async thread that doesn't have a
# loop
loop: BaseKernel = _storage.parent_loop
for task in self._tasks:
loop.reschedule(task)
@ -154,6 +156,7 @@ def _threaded_runner(
rq: AsyncThreadQueue,
rsq: AsyncThreadQueue,
evt: AsyncThreadEvent,
writer,
*args,
):
try:
@ -162,6 +165,7 @@ def _threaded_runner(
_storage.parent_loop = parent_loop
_storage.rq = rq
_storage.rsq = rsq
_storage.wakeup = writer
result = f(*args)
except BaseException as e:
rsq.put_sync((False, e))
@ -175,7 +179,7 @@ def _threaded_runner(
@enable_ki_protection
async def _coroutine_request_handler(
events: AsyncThreadQueue, results: AsyncThreadQueue
events: AsyncThreadQueue, results: AsyncThreadQueue, sock: "structio.socket.AsyncSocket"
):
"""
Runs coroutines on behalf of a thread spawned by structio and
@ -183,10 +187,8 @@ async def _coroutine_request_handler(
"""
while True:
data = await events.get()
if not data:
break
coro = data
await sock.receive(1)
coro = await events.get()
try:
result = await coro
except BaseException as e:
@ -195,76 +197,6 @@ async def _coroutine_request_handler(
await results.put((True, result))
@enable_ki_protection
async def _wait_for_thread(
events: AsyncThreadQueue,
results: AsyncThreadQueue,
termination_event: AsyncThreadEvent,
cancellable: bool = False,
):
"""
Waits for a thread spawned by structio to complete and
returns its result. Exceptions are also propagated
"""
async with structio.create_pool() as pool:
# If the operation is cancellable, then we're not
# shielded
pool.scope.shielded = not cancellable
# Spawn a coroutine to process incoming requests from
# the new async thread. We can't await it because it
# needs to run in the background
pool.spawn(_coroutine_request_handler, events, results)
# Wait for the thread to terminate
await termination_event.wait()
# Worker thread has exited: we no longer need to process
# any requests, so we shut our request handler down
await events.put(None)
# Wait for the final result from the thread
success, data = await results.get()
if success:
return data
raise data
@enable_ki_protection
async def _spawn_supervised_thread(f, cancellable: bool = False, *args):
# Thread termination event
terminate = AsyncThreadEvent()
# Request queue. This is where the thread
# sends coroutines to run
rq = AsyncThreadQueue(0)
# Results queue. This is where we put the result
# of the coroutines in the request queue
rsq = AsyncThreadQueue(0)
# This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
# The termination event is necessary so that _wait_for_thread can know when to shut
# down (and, by extension, shut down its workers too). The request and result queues
# are used to send coroutines and their results back and forth when using run_coro from
# within the "asynchronous thread". Trying to reduce the amount of primitives turns out
# to be very hard, because we'd have at least 3 different things (_wait_for_thread,
# _threaded_runner and _coroutine_request_handler) trying to work on the same resources, which is
# a hellish nightmare to synchronize properly. For example, _coroutine_request_handler *could* just
# use a single queue for sending data back and forth, but since it runs in a while loop in order to
# handle more than one request, as soon as it would put any data onto the queue and then go to the
# next iteration in the loop, it would (likely, but not always, as it depends on how things get
# scheduled) immediately call get() again, get something out of queue that it doesn't expect and
# crash horribly. So this separation is necessary to retain my sanity
threading.Thread(
target=_threaded_runner,
args=(f, current_loop(), rq, rsq, terminate, *args),
# We start cancellable threads in daemonic mode so that
# the main thread doesn't get stuck waiting on them forever
# when their associated async counterpart gets cancelled. This
# is due to the fact that there's really no way to "kill" a thread
# (and for good reason!), so we just pretend nothing happened and go
# about our merry way, hoping the thread dies eventually I guess
name="structio-worker-thread",
daemon=cancellable,
).start()
return await _wait_for_thread(rq, rsq, terminate, cancellable)
@enable_ki_protection
async def run_in_worker(
sync_func,
@ -319,7 +251,59 @@ async def run_in_worker(
# we run out of slots and proceed once
# we have more
async with _storage.max_workers:
return await _spawn_supervised_thread(sync_func, cancellable, *args)
# Thread termination event
terminate = AsyncThreadEvent()
# Request queue. This is where the thread
# sends coroutines to run
rq = AsyncThreadQueue(0)
# Results queue. This is where we put the result
# of the coroutines in the request queue
rsq = AsyncThreadQueue(0)
# This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
# The termination event is necessary so that run_in_worker can know when to shut
# down (and, by extension, shut down its request handler too). The request and result
# queues are used to send coroutines and their results back and forth when using
# run_coro from within the "asynchronous thread". Trying to reduce the amount of
# primitives turns out to be very hard, because we'd have at least 3 different things
# (run_in_worker, _threaded_runner and _coroutine_request_handler) trying to work on
# the same resources, which is a hellish nightmare to synchronize properly. For example,
# _coroutine_request_handler *could* just use a single queue for sending data back and
# forth, but since it runs in a while loop in order to handle more than one request, as
# soon as it put any data onto the queue and moved on to the next iteration of the loop,
# it would (likely, but not always, as it depends on how things get scheduled)
# immediately call get() again, pull something out of the queue that it doesn't expect,
# and crash horribly. So this separation is necessary to retain my sanity
async with structio.create_pool() as pool:
# If the operation is cancellable, then we're not
# shielded
pool.scope.shielded = not cancellable
# Spawn a coroutine to process incoming requests from
# the new async thread. We can't await it because it
# needs to run in the background
wakeup = structio.util.wakeup_fd.WakeupFd()
handler = pool.spawn(_coroutine_request_handler, rq, rsq, wakeup.reader)
# Start the worker thread
threading.Thread(
target=_threaded_runner,
args=(sync_func, current_loop(), rq, rsq, terminate, wakeup, *args),
name="structio-worker-thread",
# We start cancellable threads in daemonic mode so that
# the main thread doesn't get stuck waiting on them forever
# when their associated async counterpart gets cancelled. This
# is due to the fact that there's really no way to "kill" a thread
# (and for good reason!), so we just pretend nothing happened and go
# about our merry way, hoping the thread dies eventually I guess
daemon=cancellable,
).start()
# Wait for the thread to terminate
await terminate.wait()
# Worker thread has exited: we no longer need to process
# any requests, so we shut our request handler down
handler.cancel()
# Wait for the final result from the thread
success, data = await rsq.get()
if success:
return data
raise data
@enable_ki_protection
@ -327,10 +311,13 @@ def run_coro(
async_func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args, **kwargs
):
"""
Submits a coroutine for execution to the event loop, passing any
arguments along the way. Return values and exceptions are propagated
and from the point of view of the calling thread, this call blocks
until the coroutine returns
Submits a coroutine for execution to the event loop from another thread,
passing any arguments along the way. Return values and exceptions are
propagated, and from the point of view of the calling thread this call
blocks until the coroutine returns. The thread must be async flavored,
meaning it must be able to communicate back and forth with the event
loop running in the main thread (in practice, this means only threads
spawned with run_in_worker are able to call this)
"""
try:
@ -338,8 +325,10 @@ def run_coro(
raise StructIOException("cannot be called from async context")
except StructIOException:
pass
if not hasattr(_storage, "parent_loop"):
if not is_async_thread():
raise StructIOException("run_coro requires a running loop in another thread!")
# Wake up the event loop if it's blocked in a call to select() or similar I/O routine
_storage.wakeup.wakeup()
_storage.rq.put_sync(async_func(*args, **kwargs))
success, data = _storage.rsq.get_sync()
if success:
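
A hedged usage sketch of the contract described in the docstring, using only names that appear elsewhere in this commit (structio.thread.run_in_worker, structio.thread.run_coro, structio.sleep): run_coro may only be called from a worker thread started through run_in_worker, which wires up the request queue and the wakeup fd.

import structio

def blocking_work():
    # Runs in a worker thread, yet can still run coroutines on the main
    # loop by proxy (the wakeup fd added in this commit un-blocks the
    # loop's I/O wait so the request is picked up promptly)
    structio.thread.run_coro(structio.sleep, 1)
    return 42

async def main():
    result = await structio.thread.run_in_worker(blocking_work)
    assert result == 42

structio.run(main)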

View File

@ -0,0 +1,6 @@
from . import misc, ki, wakeup_fd
__all__ = ["misc",
"ki",
"wakeup_fd"]

View File

@ -25,3 +25,6 @@ class ThereCanBeOnlyOne:
def __exit__(self, *args):
self._acquired = False
__all__ = ["ThereCanBeOnlyOne"]

View File

@ -0,0 +1,23 @@
from structio.io.socket import socketpair
import signal
class WakeupFd:
"""
A thin wrapper over a socket pair used in set_wakeup_fd
and for thread wakeup events
"""
def __init__(self):
self.reader, self.writer = socketpair()
def set_wakeup_fd(self):
signal.set_wakeup_fd(self.writer.socket.fileno())
def wakeup(self):
try:
self.writer.socket.send(b"\x00")
except BlockingIOError:
pass
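
The class above wraps the classic self-pipe/wakeup pattern. A standalone sketch with plain stdlib sockets, independent of structio: a blocked select() returns as soon as another thread (or a signal handler registered via signal.set_wakeup_fd) writes a byte to the writer end.

import select
import socket
import threading

reader, writer = socket.socketpair()
# Non-blocking writer mirrors the BlockingIOError handling above
writer.setblocking(False)

def wake_later():
    writer.send(b"\x00")

threading.Timer(0.1, wake_later).start()
# Blocks until the wakeup byte arrives, then drains it
select.select([reader], [], [])
reader.recv(1)
print("woken up")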

tests/scope_tree.py (new file, 20 lines)
View File

@ -0,0 +1,20 @@
import structio
async def child(k):
print("[child] I'm alive! Spawning sleeper")
async with structio.create_pool() as p:
p.spawn(structio.sleep, k)
print("[child] I'm done sleeping!")
async def main(n: int, k):
print(f"[main] Spawning {n} children in their own pools, each sleeping for {k} seconds")
t = structio.clock()
async with structio.create_pool() as p:
for _ in range(n):
p.spawn(child, k)
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
# Should exit in ~2 seconds
structio.run(main, 10, 2)

tests/stress_test.py (new file, 26 lines)
View File

@ -0,0 +1,26 @@
import datetime as dtt
import structio
async def task():
for i in range(100):
await structio.sleep(0.01)
async def main(tests: list[int]):
print("[main] Starting stress test, aggregate results will be printed at the end")
results = []
for N in tests:
print(f"[main] Starting test with {N} tasks")
start = dtt.datetime.utcnow()
async with structio.create_pool() as p:
for _ in range(N):
p.spawn(task)
end = dtt.datetime.utcnow()
results.append((end - start).total_seconds())
print(f"[main] Test with {N} tasks completed in {results[-1]:.2f} seconds")
results = " ".join((f'{r:0>5.2f}' for r in results))
print(f"[main] Results: {results}")
structio.run(main, [10, 100, 1000, 10000])

View File

@ -2,16 +2,16 @@ import structio
import time
def fake_async_sleeper(n):
print(f"[thread] About to sleep for {n} seconds")
def fake_async_sleeper(n, name: str = ""):
print(f"[thread{f' {name}' if name else ''}] About to sleep for {n} seconds")
t = time.time()
if structio.thread.is_async_thread():
print(f"[thread] I have async superpowers!")
print(f"[thread{f' {name}' if name else ''}] I have async superpowers!")
structio.thread.run_coro(structio.sleep, n)
else:
print(f"[thread] Using old boring time.sleep :(")
print(f"[thread{f' {name}' if name else ''}] Using old boring time.sleep :(")
time.sleep(n)
print(f"[thread] Slept for {time.time() - t:.2f} seconds")
print(f"[thread{f' {name}' if name else ''}] Slept for {time.time() - t:.2f} seconds")
return n
@ -33,5 +33,22 @@ async def main_timeout(n, k):
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
async def main_multiple(n, k):
print(f"[main] Spawning {n} worker threads each sleeping for {k} seconds")
t = structio.clock()
async with structio.create_pool() as pool:
for i in range(n):
pool.spawn(structio.thread.run_in_worker, fake_async_sleeper, k, str(i))
print(f"[main] Workers spawned")
# Keep in mind that there is some overhead associated with running worker threads,
# not to mention that it gets tricky with how the OS schedules them and whatnot. So,
# it's unlikely that all threads finish exactly at the same time and that we exit in
# k seconds, even just because there's a lot of back and forth going on under the hood
# between structio and the worker threads themselves
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
structio.run(main, 2)
structio.run(main_timeout, 5, 3)
structio.run(main_multiple, 5, 3)