From ff2acf298fb2016387f868bca4b0333d85025fcc Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Wed, 2 Nov 2022 09:28:04 +0100 Subject: [PATCH] Various fixes for exception handling in contexts --- aiosched/__init__.py | 2 +- aiosched/context.py | 50 ++++++++++++----------------- aiosched/internals/syscalls.py | 11 +++++-- aiosched/kernel.py | 23 +++++++------ aiosched/task.py | 1 - tests/context_catch.py | 4 +-- tests/context_wait.py | 2 +- tests/nested_context_catch_inner.py | 8 ++--- tests/nested_context_catch_outer.py | 6 ++-- tests/raw_catch.py | 6 ++-- 10 files changed, 56 insertions(+), 57 deletions(-) diff --git a/aiosched/__init__.py b/aiosched/__init__.py index d77430f..3868389 100644 --- a/aiosched/__init__.py +++ b/aiosched/__init__.py @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. """ from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context -from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint +from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join import aiosched.task import aiosched.errors import aiosched.context diff --git a/aiosched/context.py b/aiosched/context.py index 17d1a8a..abd9da4 100644 --- a/aiosched/context.py +++ b/aiosched/context.py @@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +from functools import partial + from aiosched.task import Task from aiosched.errors import Cancelled from aiosched.internals.syscalls import ( @@ -57,11 +59,9 @@ class TaskContext(Task): # Do we gather multiple exceptions from # children tasks? self.gather: bool = gather - # Do we wrap any other task contexts? - self.inner: TaskContext | None = None async def spawn( - self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs + self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs ) -> Task: """ Spawns a child task @@ -69,8 +69,8 @@ class TaskContext(Task): task = await spawn(func, *args, **kwargs) task.context = self - await join(task) self.tasks.append(task) + await join(task) return task async def __aenter__(self): @@ -101,14 +101,9 @@ class TaskContext(Task): self.exc = exc if not self.silent: raise self.exc - if self.inner: - for task in self.inner.tasks: - try: - await wait(task) - except BaseException: - await self.inner.cancel(False) - self.inner.propagate = False - await close_context() + finally: + await close_context(self) + self.entry_point.propagate = True # Task method wrappers @@ -119,21 +114,22 @@ class TaskContext(Task): and cancelling them """ - if self.inner: - await self.inner.cancel(propagate) for task in self.tasks: if task is self.entry_point: continue - await cancel(task) + if isinstance(task, Task): + await cancel(task) + else: + task: TaskContext + await task.cancel(propagate) self.cancelled = True - await close_context() self.propagate = False if propagate: - if isinstance(self.entry_point, TaskContext): - self.entry_point: TaskContext - await self.entry_point.cancel() - else: + if isinstance(self.entry_point, Task): await cancel(self.entry_point) + else: + self.entry_point: TaskContext + await self.entry_point.cancel(propagate) def done(self) -> bool: """ @@ -142,17 +138,9 @@ class TaskContext(Task): """ for task in self.tasks: - if task is self.entry_point: - continue if not task.done(): return False - if ( - not isinstance(self.entry_point, TaskContext) - and not self.entry_point.done() - ): - return False - if self.inner: - return self.inner.done() + return True @property def state(self) -> int: @@ -206,6 +194,10 @@ class TaskContext(Task): def joiners(self, joiners: set[Task]): self.entry_point.joiners = joiners + @property + def coroutine(self): + return self.entry_point.coroutine + def __hash__(self): return self.entry_point.__hash__() diff --git a/aiosched/internals/syscalls.py b/aiosched/internals/syscalls.py index 98e45a2..b0255fb 100644 --- a/aiosched/internals/syscalls.py +++ b/aiosched/internals/syscalls.py @@ -139,7 +139,8 @@ async def wait(task: Task) -> Any | None: multiple times by multiple tasks. Returns immediately if the task has completed already, but exceptions are - propagated only once + propagated only once. Returns the task's + return value, if it has one :param task: The task to wait for :type task: :class: Task @@ -149,6 +150,10 @@ async def wait(task: Task) -> Any | None: current = await current_task() if task is current: raise SchedulerError("a task cannot join itself") + if current not in task.joiners: + # Luckily we use a set, so this has O(1) + # complexity + await join(task) # Waiting implies joining! await syscall("wait", task) if task.exc and task.state != TaskState.CANCELLED and task.propagate: task.propagate = False @@ -226,9 +231,9 @@ async def set_context(ctx): await syscall("set_context", ctx) -async def close_context(): +async def close_context(ctx): """ Closes the current task context """ - await syscall("close_context") + await syscall("close_context", ctx) diff --git a/aiosched/kernel.py b/aiosched/kernel.py index 2862af7..7e96807 100644 --- a/aiosched/kernel.py +++ b/aiosched/kernel.py @@ -233,6 +233,10 @@ class FIFOKernel: if not self.run_ready: return # No more tasks to run! self.current_task = self.run_ready.popleft() + # We nullify the exception object just in case the + # entry point raised and caught an error so that + # self.start() doesn't raise it again at the end + self.current_task.exc = None self.debugger.before_task_step(self.current_task) # Some debugging and internal chatter here self.current_task.state = TaskState.RUN @@ -477,25 +481,24 @@ class FIFOKernel: Sets the current task context. This is implemented as simply wrapping the current task inside the context and replacing the - Task object with the TaskContext one + Task object with the TaskContext one. This + may also wrap another task context into a + new one, but the loop doesn't need to care + about that: the API is designed exactly for + this """ ctx.entry_point = self.current_task - if isinstance(self.current_task, TaskContext): - self.current_task.inner = ctx - else: - ctx.tasks.append(ctx.entry_point) - self.current_task.context = ctx + ctx.tasks.append(ctx.entry_point) + self.current_task.context = ctx self.current_task = ctx self.reschedule_running() - def close_context(self): + def close_context(self, ctx: TaskContext): """ - Closes the context associated with the current - task + Closes the given context """ - ctx: TaskContext = self.current_task task = ctx.entry_point task.context = None self.current_task = task diff --git a/aiosched/task.py b/aiosched/task.py index 57d7e35..198be10 100644 --- a/aiosched/task.py +++ b/aiosched/task.py @@ -20,7 +20,6 @@ from enum import Enum, auto from typing import Coroutine, Any from dataclasses import dataclass, field - class TaskState(Enum): """ An enumeration of task states diff --git a/tests/context_catch.py b/tests/context_catch.py index e53d904..964e633 100644 --- a/tests/context_catch.py +++ b/tests/context_catch.py @@ -1,5 +1,5 @@ import aiosched -from raw_catch import child +from raw_catch import child_raises from debugger import Debugger @@ -8,7 +8,7 @@ async def main(children: list[tuple[str, int]]): async with aiosched.with_context() as ctx: print("[main] Spawning children") for name, delay in children: - await ctx.spawn(child, name, delay) + await ctx.spawn(child_raises, name, delay) print("[main] Children spawned") before = aiosched.clock() except BaseException as err: diff --git a/tests/context_wait.py b/tests/context_wait.py index 60ee16b..af7f93b 100644 --- a/tests/context_wait.py +++ b/tests/context_wait.py @@ -14,4 +14,4 @@ async def main(children: list[tuple[str, int]]): if __name__ == "__main__": - aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=Debugger()) + aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None) diff --git a/tests/nested_context_catch_inner.py b/tests/nested_context_catch_inner.py index b1f67d6..f285371 100644 --- a/tests/nested_context_catch_inner.py +++ b/tests/nested_context_catch_inner.py @@ -1,5 +1,5 @@ import aiosched -from raw_catch import child as errorer +from raw_catch import child_raises from raw_wait import child as successful from debugger import Debugger @@ -19,7 +19,7 @@ async def main( async with aiosched.with_context() as ctx2: print("[main] Spawning children in second context") for name, delay in children_inner: - await ctx2.spawn(errorer, name, delay) + await ctx2.spawn(child_raises, name, delay) print("[main] Children spawned") print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") @@ -27,7 +27,7 @@ async def main( if __name__ == "__main__": aiosched.run( main, - [("first", 1), ("second", 2)], - [("third", 3), ("fourth", 4)], + [("first", 1), ("third", 3)], + [("second", 2), ("fourth", 4)], debugger=None, ) diff --git a/tests/nested_context_catch_outer.py b/tests/nested_context_catch_outer.py index 774192b..3110c49 100644 --- a/tests/nested_context_catch_outer.py +++ b/tests/nested_context_catch_outer.py @@ -1,5 +1,5 @@ import aiosched -from raw_catch import child +from raw_catch import child_raises from debugger import Debugger @@ -12,12 +12,12 @@ async def main( before = aiosched.clock() print("[main] Spawning children in first context") for name, delay in children_outer: - await ctx.spawn(child, name, delay) + await ctx.spawn(child_raises, name, delay) print("[main] Children spawned") async with aiosched.with_context() as ctx2: print("[main] Spawning children in second context") for name, delay in children_inner: - await ctx2.spawn(child, name, delay) + await ctx2.spawn(child_raises, name, delay) print("[main] Children spawned") except BaseException as err: print(f"[main] Child raised an exception -> {type(err).__name__}: {err}") diff --git a/tests/raw_catch.py b/tests/raw_catch.py index 1b9c1df..953f3fb 100644 --- a/tests/raw_catch.py +++ b/tests/raw_catch.py @@ -2,17 +2,17 @@ import aiosched from debugger import Debugger -async def child(name: str, n: int): +async def child_raises(name: str, n: int): before = aiosched.clock() print(f"[child {name}] Sleeping for {n} seconds") await aiosched.sleep(n) - print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds") + print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds, raising now!") raise TypeError("waa") async def main(n: int): print("[main] Spawning child") - task = await aiosched.spawn(child, "raise", n) + task = await aiosched.spawn(child_raises, "raise", n) print("[main] Waiting for child") before = aiosched.clock() try: