From d8b2066126bcf49c3c05838eb9b6d3ebb24a6a7d Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Tue, 16 May 2023 15:48:19 +0200 Subject: [PATCH] Added experimental task-level cancellation and waiting primitives --- .idea/StructuredIO.iml | 2 +- .idea/misc.xml | 2 +- structio/__init__.py | 23 +++++++++++++++- structio/core/abc.py | 8 ++++++ structio/core/kernels/fifo.py | 32 ++++++++-------------- structio/core/task.py | 22 ++++++++++++++- tests/event_channel.py | 45 +++++++++++++++++++++++++++++++ tests/nested_pool_inner_raises.py | 5 ++-- tests/nested_pool_outer_raises.py | 34 +++++++++++++++++++++-- tests/task_handling.py | 41 ++++++++++++++++++++++++++++ 10 files changed, 185 insertions(+), 29 deletions(-) create mode 100644 tests/event_channel.py create mode 100644 tests/task_handling.py diff --git a/.idea/StructuredIO.iml b/.idea/StructuredIO.iml index ddb026d..ce769e3 100644 --- a/.idea/StructuredIO.iml +++ b/.idea/StructuredIO.iml @@ -5,7 +5,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index df4a621..00cc026 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/structio/__init__.py b/structio/__init__.py index 9244789..b2b96e1 100644 --- a/structio/__init__.py +++ b/structio/__init__.py @@ -4,9 +4,11 @@ from structio.core.kernels.fifo import FIFOKernel from structio.core.managers.io.simple import SimpleIOManager from structio.core.managers.signals.sigint import SigIntManager from structio.core.time.clock import DefaultClock -from structio.core.syscalls import sleep +from structio.core.syscalls import sleep, suspend as _suspend from structio.core.context import TaskPool, TaskScope from structio.core.exceptions import Cancelled, TimedOut +from structio.core import task +from structio.core.task import Task, TaskState from structio.sync import Event, Queue, MemoryChannel, Semaphore from structio.core.abc import Channel, Stream, ChannelReader, ChannelWriter @@ -66,6 +68,25 @@ def clock(): return _run.current_loop().clock.current_time() +async def _join(self: Task): + self.waiters.add(_run.current_task()) + await _suspend() + if self.state == TaskState.CRASHED: + raise self.exc + return self.result + + +def _cancel(self: Task): + _run.current_loop().cancel_task(self) + + +task._joiner = _join + +_cancel.__name__ = Task.cancel.__name__ +_cancel.__doc__ = Task.cancel.__doc__ +Task.cancel = _cancel + + __all__ = ["run", "sleep", "create_pool", diff --git a/structio/core/abc.py b/structio/core/abc.py index 65941fd..ec9b7c3 100644 --- a/structio/core/abc.py +++ b/structio/core/abc.py @@ -487,6 +487,14 @@ class BaseKernel(ABC): # Pool for system tasks self.pool: "TaskPool" = None + @abstractmethod + def cancel_task(self, task: Task): + """ + Cancels the given task individually + """ + + return NotImplemented + @abstractmethod def signal_notify(self, sig: int, frame: FrameType): """ diff --git a/structio/core/kernels/fifo.py b/structio/core/kernels/fifo.py index 8df9efa..0713f67 100644 --- a/structio/core/kernels/fifo.py +++ b/structio/core/kernels/fifo.py @@ -265,8 +265,11 @@ class FIFOKernel(BaseKernel): # TODO: Anything else? task.pool: TaskPool + for waiter in task.waiters: + self.reschedule(waiter) if task.pool.done() and task is not self.entry_point: self.reschedule(task.pool.entry_point) + task.waiters.clear() self.event("on_task_exit", task) self.io_manager.release_task(task) @@ -276,29 +279,11 @@ class FIFOKernel(BaseKernel): """ self.event("on_exception_raised", task, task.exc) - task.pool.scope.cancel() - current = task.pool.scope - while current and current is not self.pool.scope: - # Unroll nested task scopes until one of - # them catches the exception, or we reach - # the topmost one (i.e. ours), in which case - # we'll crash later - current.cancel() - # We re-raise the original exception into - # the parent of the task scope - # TODO: Implement something akin to trio.MultiError, or (better) - # ExceptionGroup (which is Python 3.11+ only) - self.throw(current.owner, task.exc) - if current.owner.done(): - # The scope's entry point has managed - # the exception and has exited, we can - # proceed! - break - current = current.outer + for waiter in task.waiters: + self.reschedule(waiter) self.throw(task.pool.scope.owner, task.exc) + task.waiters.clear() self.release(task) - self.current_scope = task.pool.scope.outer - self.current_pool = task.pool.outer def on_cancel(self, task: Task): """ @@ -306,6 +291,11 @@ class FIFOKernel(BaseKernel): cancellation exception """ + for waiter in task.waiters: + self.reschedule(waiter) + task.waiters.clear() + if task.pool.done() and task is not self.entry_point: + self.reschedule(task.pool.entry_point) self.release(task) def init_scope(self, scope: TaskScope): diff --git a/structio/core/task.py b/structio/core/task.py index 3aa3175..2727a80 100644 --- a/structio/core/task.py +++ b/structio/core/task.py @@ -1,6 +1,6 @@ from enum import Enum, auto from dataclasses import dataclass, field -from typing import Coroutine, Any +from typing import Coroutine, Any, Callable class TaskState(Enum): @@ -13,6 +13,9 @@ class TaskState(Enum): IO: int = auto() +_joiner: Callable[[Any, Any], Coroutine[Any, Any, Any]] | None = None + + @dataclass class Task: """ @@ -38,6 +41,8 @@ class Task: next_deadline: Any = -1 # Is cancellation pending? pending_cancellation: bool = False + # Any task explicitly joining us? + waiters: set["Task"] = field(default_factory=set) def done(self): """ @@ -56,3 +61,18 @@ class Task: """ return self.coroutine.__hash__() + + # These are patched later at import time! + def __await__(self): + """ + Wait for the task to complete and return/raise appropriately (returns when cancelled) + """ + + return _joiner(self).__await__() + + def cancel(self): + """ + Cancels the given task + """ + + return NotImplemented diff --git a/tests/event_channel.py b/tests/event_channel.py new file mode 100644 index 0000000..0f061eb --- /dev/null +++ b/tests/event_channel.py @@ -0,0 +1,45 @@ +import structio +import random + + +async def waiter(ch: structio.ChannelReader): + print("[waiter] Waiter is alive!") + async with ch: + while True: + print("[waiter] Awaiting events") + evt: structio.Event = await ch.receive() + if not evt: + break + print("[waiter] Received event, waiting to be triggered") + await evt.wait() + print("[waiter] Event triggered") + print("[waiter] Done!") + + +async def sender(ch: structio.ChannelWriter, n: int): + print("[sender] Sender is alive!") + async with ch: + for _ in range(n): + print("[sender] Sending event") + ev = structio.Event() + await ch.send(ev) + t = random.random() + print(f"[sender] Sent event, sleeping {t:.2f} seconds") + await structio.sleep(t) + print("[sender] Setting the event") + ev.set() + await ch.send(None) + print("[sender] Done!") + + +async def main(n: int): + print("[main] Parent is alive") + channel = structio.MemoryChannel(1) + async with structio.create_pool() as pool: + pool.spawn(waiter, channel.reader) + pool.spawn(sender, channel.writer, n) + print("[main] Children spawned") + print("[main] Done!") + + +structio.run(main, 3) diff --git a/tests/nested_pool_inner_raises.py b/tests/nested_pool_inner_raises.py index dab250f..16edb88 100644 --- a/tests/nested_pool_inner_raises.py +++ b/tests/nested_pool_inner_raises.py @@ -1,14 +1,15 @@ import structio -async def successful(name: str, n: int): +async def successful(name: str, n): before = structio.clock() print(f"[child {name}] Sleeping for {n} seconds") await structio.sleep(n) print(f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds") + return n -async def failing(name: str, n: int): +async def failing(name: str, n): before = structio.clock() print(f"[child {name}] Sleeping for {n} seconds") await structio.sleep(n) diff --git a/tests/nested_pool_outer_raises.py b/tests/nested_pool_outer_raises.py index 965b6ef..25d0331 100644 --- a/tests/nested_pool_outer_raises.py +++ b/tests/nested_pool_outer_raises.py @@ -2,7 +2,7 @@ import structio from nested_pool_inner_raises import successful, failing -async def main( +async def main_simple( children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]] ): before = structio.clock() @@ -22,9 +22,39 @@ async def main( print(f"[main] Children exited in {structio.clock() - before:.2f} seconds") +async def main_nested( + children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]] +): + before = structio.clock() + try: + async with structio.create_pool() as p1: + print(f"[main] Spawning children in first context ({hex(id(p1))})") + for name, delay in children_outer: + p1.spawn(failing, name, delay) + print("[main] Children spawned") + async with structio.create_pool() as p2: + print(f"[main] Spawning children in second context ({hex(id(p2))})") + for name, delay in children_inner: + p2.spawn(successful, name, delay) + print("[main] Children spawned") + async with structio.create_pool() as p3: + print(f"[main] Spawning children in third context ({hex(id(p3))})") + for name, delay in children_inner: + p3.spawn(successful(), name, delay) + print("[main] Children spawned") + except TypeError: + print("[main] TypeError caught!") + print(f"[main] Children exited in {structio.clock() - before:.2f} seconds") + + if __name__ == "__main__": structio.run( - main, + main_simple, + [("second", 2), ("third", 3)], + [("first", 1), ("fourth", 4)], + ) + structio.run( + main_nested, [("second", 2), ("third", 3)], [("first", 1), ("fourth", 4)], ) diff --git a/tests/task_handling.py b/tests/task_handling.py new file mode 100644 index 0000000..c3119d2 --- /dev/null +++ b/tests/task_handling.py @@ -0,0 +1,41 @@ +import structio + +from nested_pool_inner_raises import successful, failing + + +async def main_cancel(i): + print("[main] Parent is alive, spawning child") + t = structio.clock() + async with structio.create_pool() as pool: + task: structio.Task = pool.spawn(successful, "test", i * 2) + print(f"[main] Child spawned, waiting {i} seconds before canceling it") + await structio.sleep(i) + print("[main] Cancelling child") + task.cancel() + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +async def main_wait_successful(): + print("[main] Parent is alive, spawning (and explicitly waiting for) child") + t = structio.clock() + async with structio.create_pool() as pool: + print(f"[main] Child has returned: {await pool.spawn(successful, 'test', 5)}") + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +async def main_wait_failing(): + print("[main] Parent is alive, spawning (and explicitly waiting for) child") + t = structio.clock() + try: + async with structio.create_pool() as pool: + print(f"[main] Child has returned: {await pool.spawn(failing, 'test', 5)}") + except TypeError: + print(f"[main] TypeError caught!") + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +# Total time should be about 15s +structio.run(main_cancel, 5) +structio.run(main_wait_successful) +structio.run(main_wait_failing) +