197 lines
6.2 KiB
Python
197 lines
6.2 KiB
Python
"""
|
|
aiosched: Yet another Python async scheduler
|
|
|
|
Copyright (C) 2022 nocturn9x
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
from aiosched.task import Task
|
|
from aiosched.internals.syscalls import (
|
|
spawn,
|
|
wait,
|
|
cancel,
|
|
set_context,
|
|
close_context,
|
|
current_task,
|
|
sleep,
|
|
throw,
|
|
set_scope,
|
|
close_scope
|
|
)
|
|
from aiosched.sync import Event
|
|
from typing import Any, Coroutine, Callable
|
|
|
|
|
|
class TaskScope:
|
|
def __init__(self, timeout: int | float = 0.0, silent: bool = False):
|
|
self.timeout = timeout
|
|
self.silent = silent
|
|
self.inner: TaskScope | None = None
|
|
self.outer: TaskScope | None = None
|
|
self.waiter: Task | None = None
|
|
self.entry_point: Task | None = None
|
|
self.timed_out: bool = False
|
|
# Can we be cancelled?
|
|
self.cancellable: bool = True
|
|
# Task scope of our timeout worker
|
|
self.timeout_scope: TaskScope | None = None
|
|
|
|
async def _timeout_worker(self):
|
|
async with TaskScope() as scope:
|
|
self.timeout_scope = scope
|
|
# We can't let this task be cancelled
|
|
# because this is the only safeguard of
|
|
# our timeouts: if this crashes, then
|
|
# timeouts don't work at all!
|
|
scope.cancellable = False
|
|
await sleep(self.timeout)
|
|
if not self.entry_point.done():
|
|
self.timed_out = True
|
|
await throw(self.entry_point, TimeoutError("timed out"))
|
|
|
|
async def __aenter__(self):
|
|
self.entry_point = await current_task()
|
|
await set_scope(self)
|
|
if self.timeout:
|
|
self.waiter = await spawn(self._timeout_worker)
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type: type, exception: Exception, tb):
|
|
if self.timeout and not self.waiter.done():
|
|
# Well, looks like we finished before our worker.
|
|
# Thanks for your help! Now die.
|
|
self.timeout_scope.cancellable = True
|
|
await cancel(self.waiter, block=True)
|
|
# Task scopes are sick: Nathaniel, you're an effing genius.
|
|
await close_scope(self)
|
|
if isinstance(exception, TimeoutError) and self.timed_out:
|
|
# This way we only silence our own timeouts and not
|
|
# someone else's!
|
|
return self.silent
|
|
|
|
|
|
class TaskPool:
    """
    An asynchronous context manager that automatically waits
    for all tasks spawned within it and cancels itself when
    an exception occurs. Pools can be nested and will
    cancel inner ones if an exception is raised inside them
    """

    def __init__(self, gather: bool = True) -> None:
        """
        Object constructor

        :param gather: Whether exceptions from multiple children
            should be gathered. Currently the value is only stored
            (see the TODO below)
        """

        # All the tasks that belong to this context. Nothing in
        # this class appends to the list: presumably the scheduler
        # fills it in as tasks are spawned (TODO confirm)
        self.tasks: list[Task] = []
        # Whether we have been cancelled or not
        self.cancelled: bool = False
        # The context's entry point
        self.entry_point: Task | TaskPool | None = None
        # Do we gather multiple exceptions from
        # children tasks?
        self.gather: bool = gather  # TODO: Implement
        # Have we crashed?
        self.error: BaseException | None = None
        # Data about inner and outer contexts
        self.inner: TaskPool | None = None
        self.outer: TaskPool | None = None
        # Waited on by __aexit__ when there is an inner pool and
        # triggered by that inner pool once it has exited
        self.event: Event = Event()

    async def spawn(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Task:
        """
        Spawns a child task by forwarding the given function
        and its arguments to the spawn() primitive

        :param func: The async function to run as a new task
        :param args: Positional arguments forwarded along with func
        :param kwargs: Keyword arguments forwarded along with func
        :return: Whatever spawn() returns (annotated as a Task)
        """

        return await spawn(func, *args, **kwargs)

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface:
        records the current task as the pool's entry point,
        registers the pool via set_context() and returns it
        """

        self.entry_point = await current_task()
        await set_context(self)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb,
    ):
        """
        Implements the asynchronous context manager interface, waiting
        for all the tasks spawned inside the context and handling
        exceptions.

        If waiting raises an Exception or a KeyboardInterrupt, the
        pool is cancelled (unless it already was) and the error is
        stored in self.error. The context is then closed regardless
        of the outcome and the stored error is raised again, but only
        by a pool that has no outer pool
        """

        try:
            for task in self.tasks:
                # This forces the interpreter to stop at the
                # end of the block and wait for all
                # children to exit
                await wait(task)
            if self.inner:
                # We wait for inner contexts to terminate
                await self.event.wait()
        except (Exception, KeyboardInterrupt) as err:
            if not self.cancelled:
                await self.cancel()
            self.error = err
        finally:
            self.entry_point.propagate = True
            await close_context(self)
            self.entry_point.context = None
            if self.outer:
                # We reschedule the entry point of the outer
                # context once we're done
                await self.outer.event.trigger()
            if self.error and not self.outer:
                # Only a pool without an outer one raises
                # the error it has recorded
                raise self.error

    async def cancel(self):
        """
        Cancels the entire context, iterating over all
        of its tasks (which includes inner contexts)
        and cancelling them
        """

        for task in self.tasks:
            await cancel(task, block=True)
        if self.inner:
            await self.inner.cancel()
        self.cancelled = True

    def done(self) -> bool:
        """
        Returns whether all the tasks inside the
        context have exited. When every child is done,
        the result is that of the entry point's done()
        """

        return all(task.done() for task in self.tasks) and self.entry_point.done()

    def __repr__(self):
        """
        Implements repr(self): the repr() of each task in
        the pool, comma-separated, wrapped in "TaskPool([...])"
        """

        described = ", ".join(map(repr, self.tasks))
        return f"TaskPool([{described}])"
|