diff --git a/aiosched/context.py b/aiosched/context.py index 3edea10..15d4b93 100644 --- a/aiosched/context.py +++ b/aiosched/context.py @@ -16,31 +16,43 @@ See the License for the specific language governing permissions and limitations under the License. """ from aiosched.task import Task -from aiosched.internals.syscalls import spawn, wait, cancel +from aiosched.errors import Cancelled +from aiosched.internals.syscalls import spawn, wait, cancel, set_context, close_context, join from typing import Any, Coroutine, Callable class TaskContext(Task): """ - An asynchronous task context that automatically waits - for all tasks spawned within it. A TaskContext object - behaves like a task and is handled as a single unit - inside the event loop + An asynchronous context manager that automatically waits + for all tasks spawned within it and cancels itself when + an exception occurs. A TaskContext object behaves like + a regular task and the event loop treats it like a single + unit rather than a collection of tasks (in fact, the event + loop doesn't even know whether the current task is a task + context or not, which is by design). TaskContexts can be + nested and will cancel inner ones if an exception is raised + inside them """ - def __init__(self) -> None: + def __init__(self, silent: bool = False, gather: bool = True) -> None: """ Object constructor """ - # All the tasks that belong to this context. This - # includes any inner contexts contained within this - # one - self.tasks: list[Task | "TaskContext"] = [] + # All the tasks that belong to this context + self.tasks: list[Task] = [] # Whether we have been cancelled or not self.cancelled: bool = False - super().__init__(f"TaskContext object at {hex(id(self))}", None) - + # The context's entry point (needed to forward run() calls and the like) + self.entry_point: Task | TaskContext | None = None + # Do we ignore exceptions? + self.silent: bool = silent + # Do we gather multiple exceptions from + # children tasks? + self.gather: bool = gather + # Do we wrap any other task contexts? + self.inner: TaskContext | None = None + async def spawn( self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs ) -> Task: @@ -49,6 +61,8 @@ class TaskContext(Task): """ task = await spawn(func, *args, **kwargs) + task.context = self + await join(task) self.tasks.append(task) return task @@ -57,44 +71,136 @@ class TaskContext(Task): Implements the asynchronous context manager interface """ + await set_context(self) return self async def __aexit__(self, exc_type: Exception, exc: Exception, tb): """ Implements the asynchronous context manager interface, waiting - for all the tasks spawned inside the context + for all the tasks spawned inside the context and handling + exceptions """ - for task in self.tasks: - # This forces the interpreter to stop at the - # end of the block and wait for all - # children to exit - try: + try: + for task in self.tasks: + # This forces the interpreter to stop at the + # end of the block and wait for all + # children to exit + if task is self.entry_point: + continue await wait(task) - self.tasks.remove(task) - except BaseException: - self.tasks.remove(task) - await self.cancel() - raise + except BaseException as exc: + await self.cancel(False) + self.exc = exc + if not self.silent: + raise self.exc + if self.inner: + for task in self.inner.tasks: + try: + await wait(task) + except BaseException: + await self.inner.cancel(False) + self.inner.propagate = False + await close_context() - async def cancel(self): + # Task method wrappers + + async def cancel(self, propagate: bool = True): """ Cancels the entire context, iterating over all - of its tasks and cancelling them + of its tasks (which includes inner contexts) + and cancelling them """ + if self.inner: + await self.inner.cancel(propagate) for task in self.tasks: + if task is self.entry_point: + continue await cancel(task) self.cancelled = True - self.tasks = [] + await close_context() + self.propagate = False + if propagate: + if isinstance(self.entry_point, TaskContext): + self.entry_point: TaskContext + await self.entry_point.cancel() + else: + await cancel(self.entry_point) def done(self) -> bool: """ - Returns True if all the tasks inside the - context have exited, False otherwise + Returns whether all the tasks inside the + context have exited """ - return all([task.done() for task in self.tasks]) + for task in self.tasks: + if task is self.entry_point: + continue + if not task.done(): + return False + if not isinstance(self.entry_point, TaskContext) and not self.entry_point.done(): + return False + if self.inner: + return self.inner.done() + + @property + def state(self) -> int: + return self.entry_point.state + + @state.setter + def state(self, state: int): + self.entry_point.state = state + + @property + def result(self) -> Any: + return self.entry_point.result + + @result.setter + def result(self, result: Any): + self.entry_point.result = result + + @property + def exc(self) -> BaseException: + return self.entry_point.exc + + @exc.setter + def exc(self, exc: BaseException): + self.entry_point.exc = exc + + @property + def propagate(self) -> bool: + return self.entry_point.propagate + + @propagate.setter + def propagate(self, val: bool): + self.entry_point.propagate = val + + @property + def name(self): + return self.entry_point.name + + def throw(self, err: BaseException): + for task in self.tasks: + try: + task.throw(err) + except err: + continue + self.entry_point.throw(err) + + @property + def joiners(self) -> set[Task]: + return self.entry_point.joiners + + @joiners.setter + def joiners(self, joiners: set[Task]): + self.entry_point.joiners = joiners + + def __hash__(self): + return self.entry_point.__hash__() + + def run(self, what: Any | None = None): + return self.entry_point.run(what) def __del__(self): """ @@ -109,4 +215,13 @@ class TaskContext(Task): Implements repr(self) """ - return f"TaskContext({self.tasks})" + result = "TaskContext([" + for i, task in enumerate(self.tasks): + if task is self.entry_point: + result += repr(self.entry_point) + else: + result += repr(task) + if i < len(self.tasks) - 1: + result += ", " + result += "])" + return result diff --git a/aiosched/errors.py b/aiosched/errors.py index 06ba2fb..5e9a23b 100644 --- a/aiosched/errors.py +++ b/aiosched/errors.py @@ -15,6 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +import traceback from aiosched.task import Task @@ -67,3 +68,37 @@ class Cancelled(BaseException): """ task: Task + + +class ErrorStack(SchedulerError): + """ + This exception wraps multiple exceptions and + shows each individual traceback of them when + printed. This is to ensure that no exception is + lost even if 2 or more tasks raise at the + same time or during cancellation of other + tasks + """ + + def __init__(self, errors: list[BaseException]): + """ + Object constructor + """ + + super().__init__() + self.errors = errors + + def __str__(self): + """ + Returns str(self) + """ + + tracebacks = "" + for i, err in enumerate(self.errors): + if i not in (1, len(self.errors)): + tracebacks += ( + f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n" + ) + else: + tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}" + return f"Multiple errors occurred:\n{tracebacks}" diff --git a/aiosched/internals/syscalls.py b/aiosched/internals/syscalls.py index 6cbbb40..fb87ff1 100644 --- a/aiosched/internals/syscalls.py +++ b/aiosched/internals/syscalls.py @@ -103,43 +103,71 @@ async def suspend(): await syscall("suspend") +async def current_task() -> Task: + """ + Returns the currently running + task object + """ + + return await syscall("get_current_task") + + +async def join(task: Task): + """ + Tells the event loop that the current task + wants to wait on the given one, but without + waiting for its completion + """ + + await syscall("join", task) + + async def wait(task: Task) -> Any | None: """ Waits for the completion of a given task and returns its return value. Can be called multiple times by multiple tasks. - Raises an error if the task has - completed already. Please note that - exceptions are propagated, too + Returns immediately if the task has + completed already, but exceptions are + propagated only once :param task: The task to wait for :type task: :class: Task :returns: The task's return value, if any """ - if task.done(): - raise SchedulerError(f"task {task.name!r} has completed already") + current = await current_task() + if task is current: + raise SchedulerError("a task cannot join itself") await syscall("wait", task) - if task.exc: + if task.exc and task.state != TaskState.CANCELLED and task.propagate: + task.propagate = False raise task.exc return task.result -async def cancel(task: Task): +async def cancel(task: Task, block: bool = False): """ Cancels the given task. Note that cancellations may not happen immediately if the task is blocked in an uninterruptible - state + state. If block equals False, the default, + this function returns immediately, otherwise + it waits for the task to receive the cancellation :param task: The task to wait for :type task: :class: Task + :param block: Whether to wait for the task to be + actually cancelled or not, defaults to False + :type block: bool, optional """ await syscall("cancel", task) - if task.state != TaskState.CANCELLED: - raise SchedulerError(f"task {task.name!r} ignored cancellation") + if block: + await wait(task) + if not task.state == TaskState.CANCELLED: + raise SchedulerError(f"task {task.name!r} ignored cancellation") async def closing(stream): @@ -170,3 +198,19 @@ async def wait_writable(stream): """ await syscall("perform_io", stream, EVENT_WRITE) + + +async def set_context(ctx): + """ + Sets the current task context + """ + + await syscall("set_context", ctx) + + +async def close_context(): + """ + Closes the current task context + """ + + await syscall("close_context") diff --git a/aiosched/kernel.py b/aiosched/kernel.py index cef91eb..71ad9a6 100644 --- a/aiosched/kernel.py +++ b/aiosched/kernel.py @@ -24,6 +24,7 @@ from aiosched.internals.queues import TimeQueue from aiosched.util.debugging import BaseDebugger from typing import Callable, Any, Coroutine from aiosched.errors import InternalError, ResourceBusy, Cancelled, ResourceClosed, ResourceBroken +from aiosched.context import TaskContext from selectors import DefaultSelector, BaseSelector @@ -77,6 +78,8 @@ class FIFOKernel: self.data: dict[Task, Any] = {} # The currently running task self.current_task: Task | None = None + # The loop's entry point + self.entry_point: Task | None = None def __repr__(self): """ @@ -118,7 +121,7 @@ class FIFOKernel: """ if not self.done() and not force: - raise InternalError("cannot shut down a running event loop") + self.current_task.throw(InternalError("cannot shut down a running event loop")) for task in self.all(): self.cancel(task) @@ -183,10 +186,7 @@ class FIFOKernel: Reschedules the currently running task """ - if self.current_task: - self.run_ready.append(self.current_task) - else: - raise InternalError("aiosched is not running") + self.run_ready.append(self.current_task) def suspend(self): """ @@ -215,6 +215,8 @@ class FIFOKernel: while self.current_task.done(): # We need to make sure we don't try to execute # exited tasks that are on the running queue + if not self.run_ready: + return # No more tasks to run! self.current_task = self.run_ready.popleft() self.debugger.before_task_step(self.current_task) # Some debugging and internal chatter here @@ -227,18 +229,15 @@ class FIFOKernel: else: # Run a single step with the calculation (i.e. until a yield # somewhere) - method, args, kwargs = self.current_task.run( - self.data.get(self.current_task) - ) - self.data.pop(self.current_task, None) - if not hasattr(self, method) and not callable(getattr(self, method)): + method, args, kwargs = self.current_task.run(self.data.pop(self.current_task, None)) + if not hasattr(self, method) or not callable(getattr(self, method)): # This if block is meant to be triggered by other async # libraries, which most likely have different trap names and behaviors # compared to us. If you get this exception, and you're 100% sure you're # not mixing async primitives from other libraries, then it's a bug! - raise InternalError( + self.current_task.throw(InternalError( "Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?" - ) from None + )) # Sneaky method call, thanks to David Beazley for this ;) getattr(self, method)(*args, **kwargs) self.debugger.after_task_step(self.current_task) @@ -286,16 +285,18 @@ class FIFOKernel: Starts the event loop from a synchronous context """ - entry_point = Task(func.__name__ or str(func), func(*args, **kwargs)) - self.run_ready.append(entry_point) + self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs)) + self.run_ready.append(self.entry_point) self.debugger.on_start() try: self.run() finally: self.debugger.on_exit() - if entry_point.exc: - raise entry_point.exc - return entry_point.result + if self.entry_point.exc and self.entry_point.context is None and self.entry_point.propagate: + # Contexts already manage exceptions for us, + # no need to raise it manually + raise self.entry_point.exc + return self.entry_point.result def io_release(self, resource): """ @@ -344,23 +345,17 @@ class FIFOKernel: def cancel(self, task: Task): """ - Schedules the task to be cancelled later - or does so straight away if it is safe to do so + Attempts to cancel the given task or + schedules cancellation for later if + it fails """ - self.reschedule_running() - match task.state: - case TaskState.IO: - self.io_release_task(task) - case TaskState.PAUSED: - self.paused.discard(task) - case TaskState.INIT, TaskState.CANCELLED, TaskState.CRASHED: - return + self.io_release_task(task) + self.paused.discard(task) self.handle_task_run(partial(task.throw, Cancelled(task)), task) - if task.state == TaskState.CANCELLED: - self.debugger.after_cancel(task) - else: + if task.state != TaskState.CANCELLED: task.pending_cancellation = True + self.reschedule_running() def handle_task_run(self, func: Callable, task: Task | None = None): """ @@ -398,6 +393,7 @@ class FIFOKernel: task = task or self.current_task task.state = TaskState.CANCELLED task.pending_cancellation = False + self.debugger.after_cancel(task) self.wait(task) except BaseException as err: # Any other exception is caught here @@ -425,13 +421,24 @@ class FIFOKernel: def wait(self, task: Task): """ Makes the current task wait for completion of the given one + by only rescheduling it once the given task has finished + executing """ if task.done(): + self.paused.discard(task) + self.io_release_task(task) self.run_ready.extend(task.joiners) - task.joiners = {} - else: - task.joiners.add(self.current_task) + + def join(self, task: Task): + """ + Tells the event loop that the current task + wants to wait on the given one, but without + actually waiting for its completion + """ + + task.joiners.add(self.current_task) + self.reschedule_running() def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs): """ @@ -444,6 +451,44 @@ class FIFOKernel: self.run_ready.append(task) self.reschedule_running() + def set_context(self, ctx: TaskContext): + """ + Sets the current task context. This is + implemented as simply wrapping the current + task inside the context and replacing the + Task object with the TaskContext one + """ + + ctx.entry_point = self.current_task + if isinstance(self.current_task, TaskContext): + self.current_task.inner = ctx + else: + ctx.tasks.append(ctx.entry_point) + self.current_task.context = ctx + self.current_task = ctx + self.reschedule_running() + + def close_context(self): + """ + Closes the context associated with the current + task + """ + + ctx: TaskContext = self.current_task + task = ctx.entry_point + task.context = None + self.current_task = task + self.reschedule_running() + + def get_current_task(self): + """ + Returns the current task to an asynchronous + caller + """ + + self.data[self.current_task] = self.current_task + self.reschedule_running() + def perform_io(self, resource, evt_type: int): """ Registers the given resource inside our selector to perform I/O multiplexing @@ -495,4 +540,4 @@ class FIFOKernel: # If we get here, two tasks are trying to read or write on the same resource at the same time raise ResourceBusy( "The given resource is being read from/written to from another task" - ) from None + ) diff --git a/aiosched/runtime.py b/aiosched/runtime.py index 9e70cb2..d40d7c4 100644 --- a/aiosched/runtime.py +++ b/aiosched/runtime.py @@ -88,13 +88,14 @@ def run( get_event_loop().start(func, *args, **kwargs) -def with_context() -> TaskContext: +def with_context(*args, **kwargs) -> TaskContext: """ Creates and returns a new TaskContext - object + object. All positional and keyword arguments + are passed to TaskContext's constructor """ - return TaskContext() + return TaskContext(*args, **kwargs) def clock() -> float: diff --git a/aiosched/task.py b/aiosched/task.py index b1e4c2c..57d7e35 100644 --- a/aiosched/task.py +++ b/aiosched/task.py @@ -77,28 +77,39 @@ class Task: paused_when: float = 0.0 # The next deadline, in terms of the absolute clock of the loop, associated to the task next_deadline: float = 0.0 + # Is this task within a context? This is needed to fix a bug that would occur when + # the event loop tries to raise the exception caused by first task that kicked the + # loop even if that context already ignored said error + context: "TaskContext" = None + # We propagate exception only at the first call to wait() + propagate: bool = True def run(self, what: Any | None = None): """ - Simple abstraction layer over a coroutine's send method + Simple abstraction layer over a coroutine's send method. + Does nothing if the task has already exited :param what: The object that has to be sent to the coroutine, defaults to None :type what: Any, optional """ + if self.done(): + return return self.coroutine.send(what) def throw(self, err: BaseException): """ - Simple abstraction layer over a coroutine's throw method + Simple abstraction layer over a coroutine's throw method. + Does nothing if the task has already exited :param err: The exception that has to be raised inside the task :type err: BaseException """ - self.exc = err + if self.done(): + return return self.coroutine.throw(err) def __hash__(self): @@ -125,9 +136,5 @@ class Task: Task destructor """ - try: - self.coroutine.close() - except RuntimeError: - pass # TODO: This is kinda bad if self.last_io: warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O") diff --git a/tests/context_catch.py b/tests/context_catch.py index 465bb44..730cd50 100644 --- a/tests/context_catch.py +++ b/tests/context_catch.py @@ -9,7 +9,7 @@ async def main(children: list[tuple[str, int]]): print("[main] Spawning children") for name, delay in children: await ctx.spawn(child, name, delay) - print(f"[main] Spawned {len(ctx.tasks)} children") + print("[main] Children spawned") before = aiosched.clock() except BaseException as err: print(f"[main] Child raised an exception -> {type(err).__name__}: {err}") diff --git a/tests/context_silent_catch.py b/tests/context_silent_catch.py new file mode 100644 index 0000000..2afac99 --- /dev/null +++ b/tests/context_silent_catch.py @@ -0,0 +1,19 @@ +import aiosched +from catch import child +from debugger import Debugger + + +async def main(children: list[tuple[str, int]]): + async with aiosched.with_context(silent=True) as ctx: + print("[main] Spawning children") + for name, delay in children: + await ctx.spawn(child, name, delay) + print("[main] Children spawned") + before = aiosched.clock() + if ctx.exc: + print(f"[main] Child raised an exception -> {type(ctx.exc).__name__}: {ctx.exc}") + print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") + + +if __name__ == "__main__": + aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None) diff --git a/tests/context_wait.py b/tests/context_wait.py index 05cdff3..8f9cac6 100644 --- a/tests/context_wait.py +++ b/tests/context_wait.py @@ -8,10 +8,10 @@ async def main(children: list[tuple[str, int]]): async with aiosched.with_context() as ctx: for name, delay in children: await ctx.spawn(child, name, delay) - print(f"[main] Spawned {len(ctx.tasks)} children") + print("[main] Children spawned") before = aiosched.clock() print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") if __name__ == "__main__": - aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None) + aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None) \ No newline at end of file diff --git a/tests/nested_context_catch_inner.py b/tests/nested_context_catch_inner.py new file mode 100644 index 0000000..2317702 --- /dev/null +++ b/tests/nested_context_catch_inner.py @@ -0,0 +1,26 @@ +import aiosched +from catch import child as errorer +from wait import child as successful +from debugger import Debugger + + +async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]): + before = aiosched.clock() + async with aiosched.with_context() as ctx: + print("[main] Spawning children in first context") + for name, delay in children_outer: + await ctx.spawn(successful, name, delay) + print("[main] Children spawned") + # An exception in an outer context cancels everything + # inside it, but an exception in an inner context does + # not affect outer ones + async with aiosched.with_context() as ctx2: + print("[main] Spawning children in second context") + for name, delay in children_inner: + await ctx2.spawn(errorer, name, delay) + print("[main] Children spawned") + print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") + + +if __name__ == "__main__": + aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None) diff --git a/tests/nested_context_catch_outer.py b/tests/nested_context_catch_outer.py new file mode 100644 index 0000000..cd96614 --- /dev/null +++ b/tests/nested_context_catch_outer.py @@ -0,0 +1,26 @@ +import aiosched +from catch import child +from debugger import Debugger + + +# TODO: This crashes 1 second later than it should be +async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]): + try: + async with aiosched.with_context() as ctx: + before = aiosched.clock() + print("[main] Spawning children in first context") + for name, delay in children_outer: + await ctx.spawn(child, name, delay) + print("[main] Children spawned") + async with aiosched.with_context() as ctx2: + print("[main] Spawning children in second context") + for name, delay in children_inner: + await ctx2.spawn(child, name, delay) + print("[main] Children spawned") + except BaseException as err: + print(f"[main] Child raised an exception -> {type(err).__name__}: {err}") + print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") + + +if __name__ == "__main__": + aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None) diff --git a/tests/nested_context_wait.py b/tests/nested_context_wait.py new file mode 100644 index 0000000..0ac9b24 --- /dev/null +++ b/tests/nested_context_wait.py @@ -0,0 +1,22 @@ +import aiosched +from wait import child +from debugger import Debugger + + +async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]): + async with aiosched.with_context() as ctx: + before = aiosched.clock() + print("[main] Spawning children in first context") + for name, delay in children_outer: + await ctx.spawn(child, name, delay) + print("[main] Children spawned") + async with aiosched.with_context() as ctx2: + print("[main] Spawning children in second context") + for name, delay in children_inner: + await ctx2.spawn(child, name, delay) + print("[main] Children spawned") + print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") + + +if __name__ == "__main__": + aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)