This repository has been archived on 2023-05-12. You can view files and clone it, but cannot push or open issues or pull requests.
aiosched/aiosched/context.py

197 lines
6.2 KiB
Python
Raw Normal View History

2022-10-18 17:26:58 +02:00
"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.task import Task
from aiosched.internals.syscalls import (
spawn,
wait,
cancel,
2023-04-28 16:04:30 +02:00
set_context,
close_context,
2023-04-22 12:52:44 +02:00
current_task,
2023-04-28 16:04:30 +02:00
sleep,
throw,
set_scope,
2023-05-10 11:04:27 +02:00
close_scope
)
2023-04-28 16:04:30 +02:00
from aiosched.sync import Event
2022-10-18 17:26:58 +02:00
from typing import Any, Coroutine, Callable
2023-04-28 16:04:30 +02:00
class TaskScope:
def __init__(self, timeout: int | float = 0.0, silent: bool = False):
self.timeout = timeout
self.silent = silent
self.inner: TaskScope | None = None
self.outer: TaskScope | None = None
self.waiter: Task | None = None
self.entry_point: Task | None = None
self.timed_out: bool = False
# Can we be cancelled?
self.cancellable: bool = True
# Task scope of our timeout worker
2023-05-10 11:04:27 +02:00
self.timeout_scope: TaskScope | None = None
2023-04-28 16:04:30 +02:00
async def _timeout_worker(self):
async with TaskScope() as scope:
self.timeout_scope = scope
# We can't let this task be cancelled
2023-05-10 11:04:27 +02:00
# because this is the only safeguard of
# our timeouts: if this crashes, then
# timeouts don't work at all!
scope.cancellable = False
await sleep(self.timeout)
if not self.entry_point.done():
2023-04-28 16:04:30 +02:00
self.timed_out = True
await throw(self.entry_point, TimeoutError("timed out"))
2023-04-28 16:04:30 +02:00
async def __aenter__(self):
self.entry_point = await current_task()
await set_scope(self)
if self.timeout:
self.waiter = await spawn(self._timeout_worker)
return self
async def __aexit__(self, exc_type: type, exception: Exception, tb):
if self.timeout and not self.waiter.done():
# Well, looks like we finished before our worker.
# Thanks for your help! Now die.
self.timeout_scope.cancellable = True
2023-04-28 16:04:30 +02:00
await cancel(self.waiter, block=True)
# Task scopes are sick: Nathaniel, you're an effing genius.
2023-05-10 22:56:34 +02:00
await close_scope(self)
if isinstance(exception, TimeoutError) and self.timed_out:
# This way we only silence our own timeouts and not
# someone else's!
2023-04-28 16:04:30 +02:00
return self.silent
class TaskPool:
2022-10-18 17:26:58 +02:00
"""
2022-10-19 11:31:45 +02:00
An asynchronous context manager that automatically waits
for all tasks spawned within it and cancels itself when
an exception occurs. Pools can be nested and will
cancel inner ones if an exception is raised inside them
2022-10-18 17:26:58 +02:00
"""
2023-04-28 16:04:30 +02:00
def __init__(self, gather: bool = True) -> None:
2022-10-18 17:26:58 +02:00
"""
Object constructor
"""
2022-10-19 11:31:45 +02:00
# All the tasks that belong to this context
self.tasks: list[Task] = []
2022-10-18 17:26:58 +02:00
# Whether we have been cancelled or not
self.cancelled: bool = False
2023-04-28 16:04:30 +02:00
# The context's entry point
self.entry_point: Task | TaskPool | None = None
2022-10-19 11:31:45 +02:00
# Do we gather multiple exceptions from
# children tasks?
self.gather: bool = gather # TODO: Implement
# Have we crashed?
self.error: BaseException | None = None
2023-04-28 16:04:30 +02:00
# Data about inner and outer contexts
self.inner: TaskPool | None = None
self.outer: TaskPool | None = None
self.event: Event = Event()
2023-04-22 12:52:44 +02:00
2022-10-18 17:26:58 +02:00
async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
2022-10-18 17:26:58 +02:00
) -> Task:
"""
Spawns a child task
"""
return await spawn(func, *args, **kwargs)
2022-10-18 17:26:58 +02:00
async def __aenter__(self):
"""
Implements the asynchronous context manager interface
"""
self.entry_point = await current_task()
2023-04-28 16:04:30 +02:00
await set_context(self)
2022-10-18 17:26:58 +02:00
return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
"""
Implements the asynchronous context manager interface, waiting
2022-10-19 11:31:45 +02:00
for all the tasks spawned inside the context and handling
exceptions
2022-10-18 17:26:58 +02:00
"""
2022-10-19 11:31:45 +02:00
try:
for task in self.tasks:
# This forces the interpreter to stop at the
# end of the block and wait for all
# children to exit
2022-10-18 17:26:58 +02:00
await wait(task)
2023-04-28 16:04:30 +02:00
if self.inner:
# We wait for inner contexts to terminate
await self.event.wait()
except (Exception, KeyboardInterrupt) as err:
2023-04-28 16:04:30 +02:00
if not self.cancelled:
await self.cancel()
self.error = err
finally:
self.entry_point.propagate = True
2023-04-28 16:04:30 +02:00
await close_context(self)
self.entry_point.context = None
if self.outer:
# We reschedule the entry point of the outer
# context once we're done
await self.outer.event.trigger()
if self.error and not self.outer:
2023-04-22 12:52:44 +02:00
raise self.error
2022-10-19 11:31:45 +02:00
2023-04-28 16:04:30 +02:00
async def cancel(self):
2022-10-18 17:26:58 +02:00
"""
Cancels the entire context, iterating over all
2022-10-19 11:31:45 +02:00
of its tasks (which includes inner contexts)
and cancelling them
2022-10-18 17:26:58 +02:00
"""
for task in self.tasks:
2023-04-28 16:04:30 +02:00
await cancel(task, block=True)
if self.inner:
await self.inner.cancel()
2022-10-18 17:26:58 +02:00
self.cancelled = True
def done(self) -> bool:
"""
2022-10-19 11:31:45 +02:00
Returns whether all the tasks inside the
context have exited
2022-10-18 17:26:58 +02:00
"""
2022-10-19 11:31:45 +02:00
for task in self.tasks:
if not task.done():
return False
2023-04-28 16:04:30 +02:00
return self.entry_point.done()
2022-10-18 17:26:58 +02:00
def __repr__(self):
"""
Implements repr(self)
"""
2023-04-28 16:04:30 +02:00
result = "TaskPool(["
2022-10-19 11:31:45 +02:00
for i, task in enumerate(self.tasks):
result += repr(task)
2022-10-19 11:31:45 +02:00
if i < len(self.tasks) - 1:
result += ", "
result += "])"
return result