Added experimental task-level cancellation and waiting primitives
This commit is contained in:
parent
15d0a0674f
commit
d8b2066126
|
@ -5,7 +5,7 @@
|
|||
<sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.10 (structio)" jdkType="Python SDK" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (structio)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (StructuredIO)" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -4,9 +4,11 @@ from structio.core.kernels.fifo import FIFOKernel
|
|||
from structio.core.managers.io.simple import SimpleIOManager
|
||||
from structio.core.managers.signals.sigint import SigIntManager
|
||||
from structio.core.time.clock import DefaultClock
|
||||
from structio.core.syscalls import sleep
|
||||
from structio.core.syscalls import sleep, suspend as _suspend
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.exceptions import Cancelled, TimedOut
|
||||
from structio.core import task
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.sync import Event, Queue, MemoryChannel, Semaphore
|
||||
from structio.core.abc import Channel, Stream, ChannelReader, ChannelWriter
|
||||
|
||||
|
@ -66,6 +68,25 @@ def clock():
|
|||
return _run.current_loop().clock.current_time()
|
||||
|
||||
|
||||
async def _join(self: Task):
|
||||
self.waiters.add(_run.current_task())
|
||||
await _suspend()
|
||||
if self.state == TaskState.CRASHED:
|
||||
raise self.exc
|
||||
return self.result
|
||||
|
||||
|
||||
def _cancel(self: Task):
|
||||
_run.current_loop().cancel_task(self)
|
||||
|
||||
|
||||
task._joiner = _join
|
||||
|
||||
_cancel.__name__ = Task.cancel.__name__
|
||||
_cancel.__doc__ = Task.cancel.__doc__
|
||||
Task.cancel = _cancel
|
||||
|
||||
|
||||
__all__ = ["run",
|
||||
"sleep",
|
||||
"create_pool",
|
||||
|
|
|
@ -487,6 +487,14 @@ class BaseKernel(ABC):
|
|||
# Pool for system tasks
|
||||
self.pool: "TaskPool" = None
|
||||
|
||||
@abstractmethod
|
||||
def cancel_task(self, task: Task):
|
||||
"""
|
||||
Cancels the given task individually
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def signal_notify(self, sig: int, frame: FrameType):
|
||||
"""
|
||||
|
|
|
@ -265,8 +265,11 @@ class FIFOKernel(BaseKernel):
|
|||
|
||||
# TODO: Anything else?
|
||||
task.pool: TaskPool
|
||||
for waiter in task.waiters:
|
||||
self.reschedule(waiter)
|
||||
if task.pool.done() and task is not self.entry_point:
|
||||
self.reschedule(task.pool.entry_point)
|
||||
task.waiters.clear()
|
||||
self.event("on_task_exit", task)
|
||||
self.io_manager.release_task(task)
|
||||
|
||||
|
@ -276,29 +279,11 @@ class FIFOKernel(BaseKernel):
|
|||
"""
|
||||
|
||||
self.event("on_exception_raised", task, task.exc)
|
||||
task.pool.scope.cancel()
|
||||
current = task.pool.scope
|
||||
while current and current is not self.pool.scope:
|
||||
# Unroll nested task scopes until one of
|
||||
# them catches the exception, or we reach
|
||||
# the topmost one (i.e. ours), in which case
|
||||
# we'll crash later
|
||||
current.cancel()
|
||||
# We re-raise the original exception into
|
||||
# the parent of the task scope
|
||||
# TODO: Implement something akin to trio.MultiError, or (better)
|
||||
# ExceptionGroup (which is Python 3.11+ only)
|
||||
self.throw(current.owner, task.exc)
|
||||
if current.owner.done():
|
||||
# The scope's entry point has managed
|
||||
# the exception and has exited, we can
|
||||
# proceed!
|
||||
break
|
||||
current = current.outer
|
||||
for waiter in task.waiters:
|
||||
self.reschedule(waiter)
|
||||
self.throw(task.pool.scope.owner, task.exc)
|
||||
task.waiters.clear()
|
||||
self.release(task)
|
||||
self.current_scope = task.pool.scope.outer
|
||||
self.current_pool = task.pool.outer
|
||||
|
||||
def on_cancel(self, task: Task):
|
||||
"""
|
||||
|
@ -306,6 +291,11 @@ class FIFOKernel(BaseKernel):
|
|||
cancellation exception
|
||||
"""
|
||||
|
||||
for waiter in task.waiters:
|
||||
self.reschedule(waiter)
|
||||
task.waiters.clear()
|
||||
if task.pool.done() and task is not self.entry_point:
|
||||
self.reschedule(task.pool.entry_point)
|
||||
self.release(task)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Coroutine, Any
|
||||
from typing import Coroutine, Any, Callable
|
||||
|
||||
|
||||
class TaskState(Enum):
|
||||
|
@ -13,6 +13,9 @@ class TaskState(Enum):
|
|||
IO: int = auto()
|
||||
|
||||
|
||||
_joiner: Callable[[Any, Any], Coroutine[Any, Any, Any]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
|
@ -38,6 +41,8 @@ class Task:
|
|||
next_deadline: Any = -1
|
||||
# Is cancellation pending?
|
||||
pending_cancellation: bool = False
|
||||
# Any task explicitly joining us?
|
||||
waiters: set["Task"] = field(default_factory=set)
|
||||
|
||||
def done(self):
|
||||
"""
|
||||
|
@ -56,3 +61,18 @@ class Task:
|
|||
"""
|
||||
|
||||
return self.coroutine.__hash__()
|
||||
|
||||
# These are patched later at import time!
|
||||
def __await__(self):
|
||||
"""
|
||||
Wait for the task to complete and return/raise appropriately (returns when cancelled)
|
||||
"""
|
||||
|
||||
return _joiner(self).__await__()
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Cancels the given task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import structio
|
||||
import random
|
||||
|
||||
|
||||
async def waiter(ch: structio.ChannelReader):
|
||||
print("[waiter] Waiter is alive!")
|
||||
async with ch:
|
||||
while True:
|
||||
print("[waiter] Awaiting events")
|
||||
evt: structio.Event = await ch.receive()
|
||||
if not evt:
|
||||
break
|
||||
print("[waiter] Received event, waiting to be triggered")
|
||||
await evt.wait()
|
||||
print("[waiter] Event triggered")
|
||||
print("[waiter] Done!")
|
||||
|
||||
|
||||
async def sender(ch: structio.ChannelWriter, n: int):
|
||||
print("[sender] Sender is alive!")
|
||||
async with ch:
|
||||
for _ in range(n):
|
||||
print("[sender] Sending event")
|
||||
ev = structio.Event()
|
||||
await ch.send(ev)
|
||||
t = random.random()
|
||||
print(f"[sender] Sent event, sleeping {t:.2f} seconds")
|
||||
await structio.sleep(t)
|
||||
print("[sender] Setting the event")
|
||||
ev.set()
|
||||
await ch.send(None)
|
||||
print("[sender] Done!")
|
||||
|
||||
|
||||
async def main(n: int):
|
||||
print("[main] Parent is alive")
|
||||
channel = structio.MemoryChannel(1)
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(waiter, channel.reader)
|
||||
pool.spawn(sender, channel.writer, n)
|
||||
print("[main] Children spawned")
|
||||
print("[main] Done!")
|
||||
|
||||
|
||||
structio.run(main, 3)
|
|
@ -1,14 +1,15 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def successful(name: str, n: int):
|
||||
async def successful(name: str, n):
|
||||
before = structio.clock()
|
||||
print(f"[child {name}] Sleeping for {n} seconds")
|
||||
await structio.sleep(n)
|
||||
print(f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds")
|
||||
return n
|
||||
|
||||
|
||||
async def failing(name: str, n: int):
|
||||
async def failing(name: str, n):
|
||||
before = structio.clock()
|
||||
print(f"[child {name}] Sleeping for {n} seconds")
|
||||
await structio.sleep(n)
|
||||
|
|
|
@ -2,7 +2,7 @@ import structio
|
|||
from nested_pool_inner_raises import successful, failing
|
||||
|
||||
|
||||
async def main(
|
||||
async def main_simple(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
|
@ -22,9 +22,39 @@ async def main(
|
|||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
async def main_nested(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as p1:
|
||||
print(f"[main] Spawning children in first context ({hex(id(p1))})")
|
||||
for name, delay in children_outer:
|
||||
p1.spawn(failing, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p2:
|
||||
print(f"[main] Spawning children in second context ({hex(id(p2))})")
|
||||
for name, delay in children_inner:
|
||||
p2.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p3:
|
||||
print(f"[main] Spawning children in third context ({hex(id(p3))})")
|
||||
for name, delay in children_inner:
|
||||
p3.spawn(successful(), name, delay)
|
||||
print("[main] Children spawned")
|
||||
except TypeError:
|
||||
print("[main] TypeError caught!")
|
||||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
structio.run(
|
||||
main,
|
||||
main_simple,
|
||||
[("second", 2), ("third", 3)],
|
||||
[("first", 1), ("fourth", 4)],
|
||||
)
|
||||
structio.run(
|
||||
main_nested,
|
||||
[("second", 2), ("third", 3)],
|
||||
[("first", 1), ("fourth", 4)],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
import structio
|
||||
|
||||
from nested_pool_inner_raises import successful, failing
|
||||
|
||||
|
||||
async def main_cancel(i):
|
||||
print("[main] Parent is alive, spawning child")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
task: structio.Task = pool.spawn(successful, "test", i * 2)
|
||||
print(f"[main] Child spawned, waiting {i} seconds before canceling it")
|
||||
await structio.sleep(i)
|
||||
print("[main] Cancelling child")
|
||||
task.cancel()
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_wait_successful():
|
||||
print("[main] Parent is alive, spawning (and explicitly waiting for) child")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
print(f"[main] Child has returned: {await pool.spawn(successful, 'test', 5)}")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_wait_failing():
|
||||
print("[main] Parent is alive, spawning (and explicitly waiting for) child")
|
||||
t = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as pool:
|
||||
print(f"[main] Child has returned: {await pool.spawn(failing, 'test', 5)}")
|
||||
except TypeError:
|
||||
print(f"[main] TypeError caught!")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
# Total time should be about 15s
|
||||
structio.run(main_cancel, 5)
|
||||
structio.run(main_wait_successful)
|
||||
structio.run(main_wait_failing)
|
||||
|
Loading…
Reference in New Issue