diff --git a/aiosched/context.py b/aiosched/context.py index f6aeff6..1f4ea79 100644 --- a/aiosched/context.py +++ b/aiosched/context.py @@ -21,7 +21,8 @@ from aiosched.internals.syscalls import ( wait, cancel, join, - current_task + current_task, + sleep ) from typing import Any, Coroutine, Callable @@ -52,10 +53,21 @@ class TaskContext: self.gather: bool = gather # TODO: Implement # For how long do we allow tasks inside us # to run? - self.timeout: int | float = timeout # TODO: Implement + self.timeout: int | float = timeout + self.timed_out: bool = False # Have we crashed? self.error: BaseException | None = None + async def _timeout_worker(self): + await sleep(self.timeout) + if not self.done(): + self.error = TimeoutError("timed out") + self.timed_out = True + for task in self.tasks: + if task is self.entry_point or task.done(): + continue + await cancel(task, block=True) + async def spawn( self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs ) -> Task: @@ -95,6 +107,8 @@ class TaskContext: exceptions """ + if self.timeout: + waiter = await spawn(self._timeout_worker) try: for task in self.tasks: # This forces the interpreter to stop at the @@ -112,11 +126,13 @@ class TaskContext: await self.cancel(False) self.error = exc finally: + if self.timeout and not waiter.done(): + await cancel(waiter, block=True) self.entry_point.propagate = True if self.silent: return - if self.entry_point.exc: - raise self.entry_point.exc + if self.error: + raise self.error # Task method wrappers diff --git a/tests/context_timeout.py b/tests/context_timeout.py new file mode 100644 index 0000000..e510ea0 --- /dev/null +++ b/tests/context_timeout.py @@ -0,0 +1,16 @@ +import aiosched +from raw_wait import child + + +async def main(children: list[tuple[str, int]]): + print("[main] Spawning children") + async with aiosched.with_context(timeout=4, silent=True) as ctx: + for name, delay in children: + await ctx.spawn(child, name, delay) + print("[main] Children spawned") + before = aiosched.clock() + print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") + + +if __name__ == "__main__": + aiosched.run(main, [("first", 2), ("second", 4), ("third", 6), ("fourth", 8)], debugger=None) diff --git a/tests/context_wait.py b/tests/context_wait.py index 9deee48..2c3d140 100644 --- a/tests/context_wait.py +++ b/tests/context_wait.py @@ -2,7 +2,6 @@ import aiosched from raw_wait import child - async def main(children: list[tuple[str, int]]): print("[main] Spawning children") async with aiosched.with_context() as ctx: