This repository has been archived on 2023-05-12. You can view files and clone it, but cannot push or open issues or pull requests.
aiosched/aiosched/context.py

197 lines
6.2 KiB
Python

"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.task import Task
from aiosched.internals.syscalls import (
spawn,
wait,
cancel,
set_context,
close_context,
current_task,
sleep,
throw,
set_scope,
close_scope
)
from aiosched.sync import Event
from typing import Any, Coroutine, Callable
class TaskScope:
def __init__(self, timeout: int | float = 0.0, silent: bool = False):
self.timeout = timeout
self.silent = silent
self.inner: TaskScope | None = None
self.outer: TaskScope | None = None
self.waiter: Task | None = None
self.entry_point: Task | None = None
self.timed_out: bool = False
# Can we be cancelled?
self.cancellable: bool = True
# Task scope of our timeout worker
self.timeout_scope: TaskScope | None = None
async def _timeout_worker(self):
async with TaskScope() as scope:
self.timeout_scope = scope
# We can't let this task be cancelled
# because this is the only safeguard of
# our timeouts: if this crashes, then
# timeouts don't work at all!
scope.cancellable = False
await sleep(self.timeout)
if not self.entry_point.done():
self.timed_out = True
await throw(self.entry_point, TimeoutError("timed out"))
async def __aenter__(self):
self.entry_point = await current_task()
await set_scope(self)
if self.timeout:
self.waiter = await spawn(self._timeout_worker)
return self
async def __aexit__(self, exc_type: type, exception: Exception, tb):
if self.timeout and not self.waiter.done():
# Well, looks like we finished before our worker.
# Thanks for your help! Now die.
self.timeout_scope.cancellable = True
await cancel(self.waiter, block=True)
# Task scopes are sick: Nathaniel, you're an effing genius.
await close_scope(self)
if isinstance(exception, TimeoutError) and self.timed_out:
# This way we only silence our own timeouts and not
# someone else's!
return self.silent
class TaskPool:
    """
    An asynchronous context manager that automatically waits
    for all tasks spawned within it and cancels itself when
    an exception occurs. Pools can be nested and will
    cancel inner ones if an exception is raised inside them
    """

    def __init__(self, gather: bool = True) -> None:
        """
        Object constructor

        :param gather: Whether to gather multiple exceptions from
            children tasks. Currently stored but not acted upon
        """
        # All the tasks that belong to this context
        self.tasks: list[Task] = []
        # Whether we have been cancelled or not
        self.cancelled: bool = False
        # The context's entry point
        self.entry_point: Task | TaskPool | None = None
        # Do we gather multiple exceptions from
        # children tasks?
        self.gather: bool = gather  # TODO: Implement
        # Have we crashed?
        self.error: BaseException | None = None
        # Data about inner and outer contexts
        self.inner: TaskPool | None = None
        self.outer: TaskPool | None = None
        # Triggered by an inner pool when it finishes, so that
        # our __aexit__ can stop waiting on it
        self.event: Event = Event()

    async def spawn(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Task:
        """
        Spawns a child task running func(*args, **kwargs) and
        returns it.

        This method does not append the task to self.tasks itself;
        presumably the scheduler does so for the active context
        (see set_context in __aenter__) - TODO confirm
        """
        return await spawn(func, *args, **kwargs)

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface:
        records the current task as the entry point and makes
        this pool the active context
        """
        self.entry_point = await current_task()
        await set_context(self)
        return self

    async def __aexit__(self, exc_type: type, exc: Exception, tb):
        """
        Implements the asynchronous context manager interface, waiting
        for all the tasks spawned inside the context and handling
        exceptions.

        If waiting raises, the whole pool is cancelled and the error is
        recorded in self.error. It is re-raised here only for the
        outermost pool; a nested pool leaves it in self.error (presumably
        for the outer pool or the scheduler to handle - TODO confirm).

        NOTE(review): the exception raised by the ``async with`` body
        itself (exc) is not inspected here.
        """
        try:
            for task in self.tasks:
                # This forces the interpreter to stop at the
                # end of the block and wait for all
                # children to exit
                await wait(task)
            if self.inner:
                # We wait for inner contexts to terminate
                await self.event.wait()
        except (Exception, KeyboardInterrupt) as err:
            if not self.cancelled:
                await self.cancel()
            self.error = err
        finally:
            self.entry_point.propagate = True
            await close_context(self)
            self.entry_point.context = None
            if self.outer:
                # We reschedule the entry point of the outer
                # context once we're done
                await self.outer.event.trigger()
            if self.error and not self.outer:
                raise self.error

    async def cancel(self):
        """
        Cancels the entire context, iterating over all
        of its tasks (which includes inner contexts)
        and cancelling them
        """
        for task in self.tasks:
            await cancel(task, block=True)
        if self.inner:
            await self.inner.cancel()
        self.cancelled = True

    def done(self) -> bool:
        """
        Returns whether all the tasks inside the
        context, as well as its entry point, have exited
        """
        # all() short-circuits on the first unfinished task, so the
        # entry point is only checked once every child is done
        return all(task.done() for task in self.tasks) and self.entry_point.done()

    def __repr__(self):
        """
        Implements repr(self)
        """
        return f"TaskPool([{', '.join(map(repr, self.tasks))}])"