"""
aiosched: Yet another Python async scheduler

Copyright (C) 2022 nocturn9x

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.task import Task
from aiosched.internals.syscalls import (
    spawn,
    wait,
    cancel,
    set_context,
    close_context,
    join,
    current_task,
)
from typing import Any, Coroutine, Callable


class TaskContext(Task):
    """
    An asynchronous context manager that automatically waits
    for all tasks spawned within it and cancels its children
    when waiting on any of them raises. A TaskContext object
    behaves like a regular task and the event loop treats it
    like a single unit rather than a collection of tasks (in
    fact, the event loop doesn't even know, nor care about,
    whether the current task is a task context or not, which
    is by design). Task attributes such as state, result, exc,
    propagate, name, joiners and coroutine are forwarded to the
    context's entry point. Contexts can be nested and will
    cancel inner ones if an exception is raised inside them
    """

    def __init__(
        self,
        silent: bool = False,
        gather: bool = True,
        timeout: int | float = 0.0,
    ) -> None:
        """
        Object constructor

        :param silent: if True, an exception raised while waiting
            on the children is stored on the context (``self.exc``)
            instead of being re-raised from ``__aexit__``
        :param gather: whether to gather multiple exceptions from
            children tasks (not implemented yet)
        :param timeout: for how long tasks inside the context are
            allowed to run (not implemented yet)
        """

        # NOTE: super().__init__() is not called; the task attributes
        # this class exposes are forwarded to the entry point through
        # the properties defined below
        # All the tasks that belong to this context
        self.tasks: list[Task] = []
        # Whether we have been cancelled or not
        self.cancelled: bool = False
        # The context's entry point (needed to disguise ourselves as a task ;))
        self.entry_point: Task | TaskContext | None = None
        # Do we ignore exceptions?
        self.silent: bool = silent
        # Do we gather multiple exceptions from
        # children tasks?
        self.gather: bool = gather  # TODO: Implement
        # For how long do we allow tasks inside us
        # to run?
        self.timeout: int | float = timeout  # TODO: Implement
        # Re-entrancy guard for cancel(): nested contexts may reach
        # us again through their entry point while we are cancelling
        self._cancelling: bool = False

    async def spawn(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Task:
        """
        Spawns a child task running func(*args, **kwargs), binds it
        to this context, records it in self.tasks and joins it

        :return: the newly spawned task
        """

        # This resolves to the module-level ``spawn`` syscall, not
        # to this method (class scope is not searched here)
        task = await spawn(func, *args, **kwargs)
        task.context = self
        self.tasks.append(task)
        await join(task)
        return task

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface
        by registering this context through set_context()
        """

        await set_context(self)
        return self

    def __eq__(self, other):
        """
        Implements self == other

        Two contexts are compared through the parent class'
        equality. A context compared with a plain task is
        equal to it if the task equals our entry point.
        """

        if isinstance(other, TaskContext):
            # super().__eq__ is already bound to self, so only
            # the other operand must be passed
            return super().__eq__(other)
        elif isinstance(other, Task):
            return other == self.entry_point
        return False

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb
    ):
        """
        Implements the asynchronous context manager interface,
        waiting for all the tasks spawned inside the context
        and handling exceptions

        If waiting on a child raises, every child is cancelled
        (without propagating to the entry point), the exception
        is stored on self.exc and re-raised unless the context
        is silent. The context is always closed on exit.
        """

        # NOTE(review): exc_type/exc (an exception raised in the
        # ``async with`` body itself) are not inspected here, so such
        # an exception does not by itself cancel the children — confirm
        # this is handled by the event loop
        try:
            for task in self.tasks:
                # This forces the interpreter to stop at the
                # end of the block and wait for all
                # children to exit
                if task is self.entry_point:
                    # We don't wait on the entry
                    # point because that's us!
                    # Besides, even if we tried,
                    # wait() would raise an error
                    # to avoid a deadlock
                    continue
                await wait(task)
        except BaseException as err:
            # Named ``err`` so we don't shadow (and, at the end of the
            # handler, implicitly delete) the ``exc`` parameter
            await self.cancel(False)
            self.exc = err
            if not self.silent:
                raise err
        finally:
            await close_context(self)
            self.entry_point.propagate = True

    # Task method wrappers

    async def cancel(self, propagate: bool = True):
        """
        Cancels the entire context, iterating over all
        of its tasks (which includes inner contexts)
        and cancelling them

        :param propagate: if True, the entry point is
            cancelled as well (and inner contexts are asked
            to propagate to their own entry points)
        """

        if self._cancelling:
            # We are already cancelling further up the call stack
            # (e.g. an inner context whose entry point is us)
            return
        self._cancelling = True
        try:
            for task in self.tasks:
                if task is self.entry_point:
                    continue
                # TaskContext subclasses Task, so it must be checked
                # first or inner contexts would never get their own
                # cancel() called
                if isinstance(task, TaskContext):
                    await task.cancel(propagate)
                else:
                    await cancel(task)
            self.cancelled = True
            self.propagate = False
            if propagate:
                if isinstance(self.entry_point, TaskContext):
                    await self.entry_point.cancel(propagate)
                else:
                    await cancel(self.entry_point)
        finally:
            self._cancelling = False

    def done(self) -> bool:
        """
        Returns whether all the tasks inside the
        context have exited
        """

        return all(task.done() for task in self.tasks)

    @property
    def state(self) -> int:
        return self.entry_point.state

    @state.setter
    def state(self, state: int):
        self.entry_point.state = state

    @property
    def result(self) -> Any:
        return self.entry_point.result

    @result.setter
    def result(self, result: Any):
        self.entry_point.result = result

    @property
    def exc(self) -> BaseException:
        return self.entry_point.exc

    @exc.setter
    def exc(self, exc: BaseException):
        self.entry_point.exc = exc

    @property
    def propagate(self) -> bool:
        return self.entry_point.propagate

    @propagate.setter
    def propagate(self, val: bool):
        self.entry_point.propagate = val

    @property
    def name(self):
        return self.entry_point.name

    def throw(self, err: BaseException):
        """
        Throws err into every child task (ignoring the case
        where a child re-raises it) and finally into the
        entry point, whose outcome is not suppressed
        """

        for task in self.tasks:
            if task is self.entry_point:
                # The entry point is handled once, below
                continue
            try:
                task.throw(err)
            except type(err):
                # ``except`` needs an exception class, not an instance
                continue
        self.entry_point.throw(err)

    @property
    def joiners(self) -> set[Task]:
        return self.entry_point.joiners

    @joiners.setter
    def joiners(self, joiners: set[Task]):
        self.entry_point.joiners = joiners

    @property
    def coroutine(self):
        return self.entry_point.coroutine

    def __hash__(self):
        # Defined alongside __eq__ so the class stays hashable
        return self.entry_point.__hash__()

    def run(self, what: Any | None = None):
        """
        Forwards run(what) to the entry point and
        returns its result
        """

        return self.entry_point.run(what)

    def __del__(self):
        """
        Context destructor: calls __del__ on every task
        belonging to the context
        """

        # NOTE(review): assumes every entry of self.tasks defines
        # __del__ — confirm against Task
        for task in self.tasks:
            task.__del__()

    def __repr__(self):
        """
        Implements repr(self)
        """

        children = ", ".join(repr(task) for task in self.tasks)
        return f"TaskContext([{children}])"