2022-10-18 17:26:58 +02:00
|
|
|
"""
|
|
|
|
aiosched: Yet another Python async scheduler
|
|
|
|
|
|
|
|
Copyright (C) 2022 nocturn9x
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
|
|
|
from aiosched.task import Task
|
2022-10-19 11:54:32 +02:00
|
|
|
from aiosched.internals.syscalls import (
|
|
|
|
spawn,
|
|
|
|
wait,
|
|
|
|
cancel,
|
|
|
|
set_context,
|
|
|
|
close_context,
|
|
|
|
join,
|
|
|
|
)
|
2022-10-18 17:26:58 +02:00
|
|
|
from typing import Any, Coroutine, Callable
|
|
|
|
|
|
|
|
|
|
|
|
class TaskContext(Task):
    """
    An asynchronous context manager that automatically waits
    for all tasks spawned within it and cancels itself when
    an exception occurs. A TaskContext object behaves like
    a regular task and the event loop treats it like a single
    unit rather than a collection of tasks (in fact, the event
    loop doesn't even know whether the current task is a task
    context or not, which is by design). TaskContexts can be
    nested and will cancel inner ones if an exception is raised
    inside them.

    Most Task attributes (state, result, exc, propagate, name,
    joiners, coroutine) are forwarded to the context's entry
    point, so they only work once entry_point has been set.
    entry_point starts as None and is never assigned in this
    class. NOTE(review): presumably the set_context() syscall
    sets it; confirm.
    """

    def __init__(self, silent: bool = False, gather: bool = True) -> None:
        """
        Object constructor

        :param silent: If True, an exception raised while waiting on
            the children is stored in self.exc but not re-raised when
            the context exits
        :param gather: Whether multiple exceptions from children should
            be gathered (stored only; not read anywhere in this class)
        """

        # Task.__init__ is intentionally not called: Task attributes are
        # forwarded to the entry point by the properties defined below
        # All the tasks that belong to this context
        self.tasks: list[Task] = []
        # Whether we have been cancelled or not
        self.cancelled: bool = False
        # The context's entry point (needed to forward run() calls and the like)
        self.entry_point: Task | TaskContext | None = None
        # Do we ignore exceptions?
        self.silent: bool = silent
        # Do we gather multiple exceptions from
        # children tasks?
        self.gather: bool = gather

    async def spawn(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Task:
        """
        Spawns a child task running func(*args, **kwargs), attaches it
        to this context, joins it and returns it
        """

        # `spawn` here is the module-level syscall, not this method
        task = await spawn(func, *args, **kwargs)
        task.context = self
        self.tasks.append(task)
        await join(task)
        return task

    async def __aenter__(self):
        """
        Implements the asynchronous context manager interface by
        registering this context through the set_context() syscall
        """

        await set_context(self)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb,
    ):
        """
        Implements the asynchronous context manager interface, waiting
        for all the tasks spawned inside the context and handling
        exceptions.

        If waiting on a child raises, every task in the context is
        cancelled (without propagating to the entry point). The error
        is then stored in self.exc and re-raised unless the context is
        silent. The context is always closed on the way out.

        NOTE(review): an exception raised inside the ``async with`` body
        (``exc``) is not inspected here. Children are still awaited
        normally and that exception propagates afterwards.
        """

        try:
            # This forces the interpreter to stop at the end of the
            # block and wait for all children to exit. The entry point
            # is skipped (presumably it is the task running this code,
            # so waiting on it could never complete)
            for task in self.tasks:
                if task is self.entry_point:
                    continue
                await wait(task)
        except BaseException as err:
            # Named `err` so it does not shadow the `exc` parameter
            await self.cancel(False)
            self.exc = err
            if not self.silent:
                raise self.exc
        finally:
            await close_context(self)
            self.entry_point.propagate = True

    # Task method wrappers

    async def cancel(self, propagate: bool = True):
        """
        Cancels the entire context, iterating over all
        of its tasks (which includes inner contexts)
        and cancelling them. When propagate is True, the
        context's entry point is cancelled as well
        """

        for task in self.tasks:
            if task is self.entry_point:
                continue
            # TaskContext is a subclass of Task, so the more specific type
            # has to be checked first. Otherwise inner contexts never run
            # their own cancel() logic
            if isinstance(task, TaskContext):
                # Everything this inner context lives in is already being
                # torn down. Propagating upward could re-enter this very
                # method (e.g. if its entry point is this context)
                await task.cancel(False)
            else:
                await cancel(task)
        self.cancelled = True
        # Forwarded to the entry point through the propagate property
        self.propagate = False
        if propagate:
            if isinstance(self.entry_point, TaskContext):
                await self.entry_point.cancel(propagate)
            else:
                await cancel(self.entry_point)

    def done(self) -> bool:
        """
        Returns whether all the tasks inside the
        context have exited
        """

        return all(task.done() for task in self.tasks)

    @property
    def state(self) -> int:
        """
        The entry point's state
        """

        return self.entry_point.state

    @state.setter
    def state(self, state: int):
        self.entry_point.state = state

    @property
    def result(self) -> Any:
        """
        The entry point's result
        """

        return self.entry_point.result

    @result.setter
    def result(self, result: Any):
        self.entry_point.result = result

    @property
    def exc(self) -> BaseException:
        """
        The exception stored on the entry point
        """

        return self.entry_point.exc

    @exc.setter
    def exc(self, exc: BaseException):
        self.entry_point.exc = exc

    @property
    def propagate(self) -> bool:
        """
        The entry point's propagate flag
        """

        return self.entry_point.propagate

    @propagate.setter
    def propagate(self, val: bool):
        self.entry_point.propagate = val

    @property
    def name(self):
        """
        The entry point's name
        """

        return self.entry_point.name

    def throw(self, err: BaseException):
        """
        Throws err into every child of the context, ignoring it when a
        child lets that same exception type propagate, and finally
        throws it into the entry point (whose outcome is not caught)
        """

        for task in self.tasks:
            # The entry point gets the exception exactly once, below
            if task is self.entry_point:
                continue
            try:
                task.throw(err)
            # `except err` (an instance) raises TypeError as soon as an
            # exception is actually matched against it; use its class
            except type(err):
                continue
        self.entry_point.throw(err)

    @property
    def joiners(self) -> set[Task]:
        """
        The entry point's joiners
        """

        return self.entry_point.joiners

    @joiners.setter
    def joiners(self, joiners: set[Task]):
        self.entry_point.joiners = joiners

    @property
    def coroutine(self):
        """
        The entry point's coroutine
        """

        return self.entry_point.coroutine

    def __hash__(self):
        """
        Hashes like the entry point
        """

        return self.entry_point.__hash__()

    def run(self, what: Any | None = None):
        """
        Forwards run() to the entry point
        """

        return self.entry_point.run(what)

    def __del__(self):
        """
        Context destructor: calls __del__() on every task in the
        context. NOTE(review): this also reaches the entry point if it
        is stored in self.tasks; confirm that is intended.
        """

        for task in self.tasks:
            task.__del__()

    def __repr__(self):
        """
        Implements repr(self)
        """

        # The entry point, if present in self.tasks, is rendered with its
        # own repr like any other task
        return f"TaskContext([{', '.join(repr(task) for task in self.tasks)}])"
|