"""
aiosched: Yet another Python async scheduler

Copyright (C) 2022 nocturn9x

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Coroutine, Callable

from aiosched.task import Task
from aiosched.internals.syscalls import (
    spawn,
    wait,
    cancel,
    set_context,
    close_context,
    current_task,
    sleep,
    throw,
    set_scope,
    close_scope,
)
from aiosched.sync import Event


class TaskScope:
    """
    An asynchronous context manager that can interrupt its body by
    throwing a TimeoutError into the task that entered it once a
    given timeout has elapsed
    """

    def __init__(self, timeout: int | float = 0.0, silent: bool = False):
        """
        Object constructor

        :param timeout: How long the body is allowed to run before a
            TimeoutError is thrown into it. The value is passed as-is to
            the sleep() syscall (presumably seconds -- confirm against
            aiosched.internals.syscalls). A falsy value (the default)
            disables the timeout entirely
        :param silent: If True, a TimeoutError caused by *this* scope's
            own timeout is suppressed when the scope exits
        """
        self.timeout = timeout
        self.silent = silent
        # Data about inner and outer scopes (not assigned in this class;
        # presumably maintained by the set_scope/close_scope syscalls)
        self.inner: TaskScope | None = None
        self.outer: TaskScope | None = None
        # The task running _timeout_worker, if a timeout was requested
        self.waiter: Task | None = None
        # The task that entered this scope
        self.entry_point: Task | None = None
        # Set by the timeout worker right before it throws, so that
        # __aexit__ only silences our own timeouts
        self.timed_out: bool = False
        # Can we be cancelled?
        self.cancellable: bool = True
        # Task scope of our timeout worker (None until the worker runs)
        self.timeout_scope: TaskScope | None = None

    async def _timeout_worker(self):
        """
        Background task: sleeps for self.timeout and then, if the
        scope's entry point is still running, marks the scope as timed
        out and throws a TimeoutError into the entry point
        """
        async with TaskScope() as scope:
            self.timeout_scope = scope
            # We can't let this task be cancelled
            # because this is the only safeguard of
            # our timeouts: if this crashes, then
            # timeouts don't work at all!
            scope.cancellable = False
            await sleep(self.timeout)
            if not self.entry_point.done():
                self.timed_out = True
                await throw(self.entry_point, TimeoutError("timed out"))

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface: records
        the current task as the entry point, registers the scope and,
        if a timeout was requested, spawns the timeout worker
        """
        self.entry_point = await current_task()
        await set_scope(self)
        if self.timeout:
            self.waiter = await spawn(self._timeout_worker)
        return self

    async def __aexit__(self, exc_type: type, exception: Exception, tb):
        """
        Implements the asynchronous context manager interface: stops
        the timeout worker if it is still pending, closes the scope and
        decides whether to suppress the in-flight exception

        :return: self.silent when the exception is the TimeoutError
            produced by this scope's own timeout, None otherwise (so
            every other exception propagates normally)
        """
        if self.timeout and not self.waiter.done():
            # The body finished before the timeout fired, so the worker
            # is no longer needed. Its scope only exists once the worker
            # has actually started running: if the body exited before
            # the worker was first scheduled, there is no scope to
            # re-enable cancellation on yet
            if self.timeout_scope is not None:
                self.timeout_scope.cancellable = True
            await cancel(self.waiter, block=True)
        # Task scopes are sick: Nathaniel, you're an effing genius.
        await close_scope(self)
        if isinstance(exception, TimeoutError) and self.timed_out:
            # This way we only silence our own timeouts and not
            # someone else's!
            return self.silent


class TaskPool:
    """
    An asynchronous context manager that automatically waits
    for all tasks spawned within it and cancels itself when
    an exception occurs. Pools can be nested and will cancel
    inner ones if an exception is raised inside them
    """

    def __init__(self, gather: bool = True) -> None:
        """
        Object constructor

        :param gather: Whether to gather multiple exceptions from
            children tasks (currently unused, see the TODO below)
        """
        # All the tasks that belong to this context (not appended to in
        # this class; presumably filled by the kernel once set_context
        # has registered the pool -- confirm)
        self.tasks: list[Task] = []
        # Whether we have been cancelled or not
        self.cancelled: bool = False
        # The context's entry point
        self.entry_point: Task | TaskPool | None = None
        # Do we gather multiple exceptions from
        # children tasks?
        self.gather: bool = gather  # TODO: Implement
        # Have we crashed?
        self.error: BaseException | None = None
        # Data about inner and outer contexts
        self.inner: TaskPool | None = None
        self.outer: TaskPool | None = None
        # Triggered by an inner pool when it exits, waking up our
        # __aexit__ if it is waiting on that inner pool
        self.event: Event = Event()

    async def spawn(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Task:
        """
        Spawns a child task by forwarding func and its arguments to the
        spawn() syscall, and returns the resulting Task
        """
        return await spawn(func, *args, **kwargs)

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface: records
        the current task as the entry point and registers the pool as
        the active context
        """
        self.entry_point = await current_task()
        await set_context(self)
        return self

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc: Exception, tb
    ):
        """
        Implements the asynchronous context manager interface, waiting
        for all the tasks spawned inside the context and handling
        exceptions

        If waiting raises an Exception or KeyboardInterrupt, the whole
        pool is cancelled (unless it already was) and the error is
        recorded in self.error. The outermost pool (one without an
        outer pool) re-raises the recorded error. The exception passed
        in from the body is not suppressed (this method returns None)
        """
        try:
            for task in self.tasks:
                # This forces the interpreter to stop at the
                # end of the block and wait for all
                # children to exit
                await wait(task)
            if self.inner:
                # We wait for inner contexts to terminate
                await self.event.wait()
        except (Exception, KeyboardInterrupt) as err:
            if not self.cancelled:
                await self.cancel()
            self.error = err
        finally:
            self.entry_point.propagate = True
            await close_context(self)
            self.entry_point.context = None
            if self.outer is not None:
                # We reschedule the entry point of the outer
                # context once we're done
                await self.outer.event.trigger()
            if self.error is not None and self.outer is None:
                raise self.error

    async def cancel(self):
        """
        Cancels the entire context, iterating over all of its
        tasks (waiting for each cancellation to complete) and
        then recursively cancelling the inner context, if any
        """
        for task in self.tasks:
            await cancel(task, block=True)
        if self.inner:
            await self.inner.cancel()
        self.cancelled = True

    def done(self) -> bool:
        """
        Returns whether all the tasks inside the context
        and the context's entry point have exited
        """
        return all(task.done() for task in self.tasks) and self.entry_point.done()

    def __repr__(self):
        """
        Implements repr(self)
        """
        return f"TaskPool([{', '.join(map(repr, self.tasks))}])"