""" aiosched: Yet another Python async scheduler Copyright (C) 2022 nocturn9x Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https:www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from aiosched.task import Task from aiosched.internals.syscalls import ( spawn, wait, cancel, set_context, close_context, current_task, sleep, throw, set_scope, close_scope, get_current_scope ) from aiosched.sync import Event from typing import Any, Coroutine, Callable class TaskScope: def __init__(self, timeout: int | float = 0.0, silent: bool = False): self.timeout = timeout self.silent = silent self.inner: TaskScope | None = None self.outer: TaskScope | None = None self.pools: list[TaskPool] = list() self.waiter: Task | None = None self.entry_point: Task | None = None self.timed_out: bool = False async def _timeout_worker(self): await sleep(self.timeout) for pool in self.pools: if not pool.done(): self.timed_out = True await pool.cancel() if pool.entry_point is not self.entry_point: await cancel(pool.entry_point, block=True) if not self.entry_point.done(): self.timed_out = True # raise TimeoutError("timed out") await throw(self.entry_point, TimeoutError("timed out")) async def __aenter__(self): self.entry_point = await current_task() await set_scope(self) if self.timeout: self.waiter = await spawn(self._timeout_worker) return self async def __aexit__(self, exc_type: type, exception: Exception, tb): await close_scope(self) if not self.waiter.done(): await cancel(self.waiter, block=True) if exception is not None: return self.silent class TaskPool: """ An asynchronous context manager that automatically waits for all tasks spawned within it and cancels itself when an exception occurs. Contexts can be nested and will cancel inner ones if an exception is raised inside them """ def __init__(self, gather: bool = True) -> None: """ Object constructor """ # All the tasks that belong to this context self.tasks: list[Task] = [] # Whether we have been cancelled or not self.cancelled: bool = False # The context's entry point self.entry_point: Task | TaskPool | None = None # Do we gather multiple exceptions from # children tasks? self.gather: bool = gather # TODO: Implement # Have we crashed? self.error: BaseException | None = None # Data about inner and outer contexts self.inner: TaskPool | None = None self.outer: TaskPool | None = None self.event: Event = Event() async def spawn( self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs ) -> Task: """ Spawns a child task """ task = await spawn(func, *args, **kwargs) task.context = self self.tasks.append(task) return task async def __aenter__(self): """ Implements the asynchronous context manager interface """ self.entry_point = await current_task() scope = await get_current_scope() if scope: scope.pools.append(self) await set_context(self) return self async def __aexit__(self, exc_type: Exception, exc: Exception, tb): """ Implements the asynchronous context manager interface, waiting for all the tasks spawned inside the context and handling exceptions """ try: for task in self.tasks: # This forces the interpreter to stop at the # end of the block and wait for all # children to exit await wait(task) if self.inner: # We wait for inner contexts to terminate await self.event.wait() except (Exception, KeyboardInterrupt) as exc: if not self.cancelled: await self.cancel() self.error = exc finally: self.entry_point.propagate = True await close_context(self) self.entry_point.context = None if self.outer: # We reschedule the entry point of the outer # context once we're done await self.outer.event.trigger() if self.error and not self.outer: raise self.error async def cancel(self): """ Cancels the entire context, iterating over all of its tasks (which includes inner contexts) and cancelling them """ for task in self.tasks: await cancel(task, block=True) if self.inner: await self.inner.cancel() self.cancelled = True def done(self) -> bool: """ Returns whether all the tasks inside the context have exited """ for task in self.tasks: if not task.done(): return False return self.entry_point.done() def __repr__(self): """ Implements repr(self) """ result = "TaskPool([" for i, task in enumerate(self.tasks): result += repr(task) if i < len(self.tasks) - 1: result += ", " result += "])" return result