Work on context exception handling
This commit is contained in:
parent
ca1e8a157b
commit
6690263b55
|
@ -16,31 +16,43 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
from aiosched.task import Task
|
from aiosched.task import Task
|
||||||
from aiosched.internals.syscalls import spawn, wait, cancel
|
from aiosched.errors import Cancelled
|
||||||
|
from aiosched.internals.syscalls import spawn, wait, cancel, set_context, close_context, join
|
||||||
from typing import Any, Coroutine, Callable
|
from typing import Any, Coroutine, Callable
|
||||||
|
|
||||||
|
|
||||||
class TaskContext(Task):
|
class TaskContext(Task):
|
||||||
"""
|
"""
|
||||||
An asynchronous task context that automatically waits
|
An asynchronous context manager that automatically waits
|
||||||
for all tasks spawned within it. A TaskContext object
|
for all tasks spawned within it and cancels itself when
|
||||||
behaves like a task and is handled as a single unit
|
an exception occurs. A TaskContext object behaves like
|
||||||
inside the event loop
|
a regular task and the event loop treats it like a single
|
||||||
|
unit rather than a collection of tasks (in fact, the event
|
||||||
|
loop doesn't even know whether the current task is a task
|
||||||
|
context or not, which is by design). TaskContexts can be
|
||||||
|
nested and will cancel inner ones if an exception is raised
|
||||||
|
inside them
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, silent: bool = False, gather: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Object constructor
|
Object constructor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# All the tasks that belong to this context. This
|
# All the tasks that belong to this context
|
||||||
# includes any inner contexts contained within this
|
self.tasks: list[Task] = []
|
||||||
# one
|
|
||||||
self.tasks: list[Task | "TaskContext"] = []
|
|
||||||
# Whether we have been cancelled or not
|
# Whether we have been cancelled or not
|
||||||
self.cancelled: bool = False
|
self.cancelled: bool = False
|
||||||
super().__init__(f"TaskContext object at {hex(id(self))}", None)
|
# The context's entry point (needed to forward run() calls and the like)
|
||||||
|
self.entry_point: Task | TaskContext | None = None
|
||||||
|
# Do we ignore exceptions?
|
||||||
|
self.silent: bool = silent
|
||||||
|
# Do we gather multiple exceptions from
|
||||||
|
# children tasks?
|
||||||
|
self.gather: bool = gather
|
||||||
|
# Do we wrap any other task contexts?
|
||||||
|
self.inner: TaskContext | None = None
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||||
) -> Task:
|
) -> Task:
|
||||||
|
@ -49,6 +61,8 @@ class TaskContext(Task):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = await spawn(func, *args, **kwargs)
|
task = await spawn(func, *args, **kwargs)
|
||||||
|
task.context = self
|
||||||
|
await join(task)
|
||||||
self.tasks.append(task)
|
self.tasks.append(task)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
@ -57,44 +71,136 @@ class TaskContext(Task):
|
||||||
Implements the asynchronous context manager interface
|
Implements the asynchronous context manager interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
await set_context(self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
||||||
"""
|
"""
|
||||||
Implements the asynchronous context manager interface, waiting
|
Implements the asynchronous context manager interface, waiting
|
||||||
for all the tasks spawned inside the context
|
for all the tasks spawned inside the context and handling
|
||||||
|
exceptions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for task in self.tasks:
|
try:
|
||||||
# This forces the interpreter to stop at the
|
for task in self.tasks:
|
||||||
# end of the block and wait for all
|
# This forces the interpreter to stop at the
|
||||||
# children to exit
|
# end of the block and wait for all
|
||||||
try:
|
# children to exit
|
||||||
|
if task is self.entry_point:
|
||||||
|
continue
|
||||||
await wait(task)
|
await wait(task)
|
||||||
self.tasks.remove(task)
|
except BaseException as exc:
|
||||||
except BaseException:
|
await self.cancel(False)
|
||||||
self.tasks.remove(task)
|
self.exc = exc
|
||||||
await self.cancel()
|
if not self.silent:
|
||||||
raise
|
raise self.exc
|
||||||
|
if self.inner:
|
||||||
|
for task in self.inner.tasks:
|
||||||
|
try:
|
||||||
|
await wait(task)
|
||||||
|
except BaseException:
|
||||||
|
await self.inner.cancel(False)
|
||||||
|
self.inner.propagate = False
|
||||||
|
await close_context()
|
||||||
|
|
||||||
async def cancel(self):
|
# Task method wrappers
|
||||||
|
|
||||||
|
async def cancel(self, propagate: bool = True):
|
||||||
"""
|
"""
|
||||||
Cancels the entire context, iterating over all
|
Cancels the entire context, iterating over all
|
||||||
of its tasks and cancelling them
|
of its tasks (which includes inner contexts)
|
||||||
|
and cancelling them
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.inner:
|
||||||
|
await self.inner.cancel(propagate)
|
||||||
for task in self.tasks:
|
for task in self.tasks:
|
||||||
|
if task is self.entry_point:
|
||||||
|
continue
|
||||||
await cancel(task)
|
await cancel(task)
|
||||||
self.cancelled = True
|
self.cancelled = True
|
||||||
self.tasks = []
|
await close_context()
|
||||||
|
self.propagate = False
|
||||||
|
if propagate:
|
||||||
|
if isinstance(self.entry_point, TaskContext):
|
||||||
|
self.entry_point: TaskContext
|
||||||
|
await self.entry_point.cancel()
|
||||||
|
else:
|
||||||
|
await cancel(self.entry_point)
|
||||||
|
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if all the tasks inside the
|
Returns whether all the tasks inside the
|
||||||
context have exited, False otherwise
|
context have exited
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return all([task.done() for task in self.tasks])
|
for task in self.tasks:
|
||||||
|
if task is self.entry_point:
|
||||||
|
continue
|
||||||
|
if not task.done():
|
||||||
|
return False
|
||||||
|
if not isinstance(self.entry_point, TaskContext) and not self.entry_point.done():
|
||||||
|
return False
|
||||||
|
if self.inner:
|
||||||
|
return self.inner.done()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> int:
|
||||||
|
return self.entry_point.state
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, state: int):
|
||||||
|
self.entry_point.state = state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> Any:
|
||||||
|
return self.entry_point.result
|
||||||
|
|
||||||
|
@result.setter
|
||||||
|
def result(self, result: Any):
|
||||||
|
self.entry_point.result = result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exc(self) -> BaseException:
|
||||||
|
return self.entry_point.exc
|
||||||
|
|
||||||
|
@exc.setter
|
||||||
|
def exc(self, exc: BaseException):
|
||||||
|
self.entry_point.exc = exc
|
||||||
|
|
||||||
|
@property
|
||||||
|
def propagate(self) -> bool:
|
||||||
|
return self.entry_point.propagate
|
||||||
|
|
||||||
|
@propagate.setter
|
||||||
|
def propagate(self, val: bool):
|
||||||
|
self.entry_point.propagate = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.entry_point.name
|
||||||
|
|
||||||
|
def throw(self, err: BaseException):
|
||||||
|
for task in self.tasks:
|
||||||
|
try:
|
||||||
|
task.throw(err)
|
||||||
|
except err:
|
||||||
|
continue
|
||||||
|
self.entry_point.throw(err)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def joiners(self) -> set[Task]:
|
||||||
|
return self.entry_point.joiners
|
||||||
|
|
||||||
|
@joiners.setter
|
||||||
|
def joiners(self, joiners: set[Task]):
|
||||||
|
self.entry_point.joiners = joiners
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return self.entry_point.__hash__()
|
||||||
|
|
||||||
|
def run(self, what: Any | None = None):
|
||||||
|
return self.entry_point.run(what)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""
|
"""
|
||||||
|
@ -109,4 +215,13 @@ class TaskContext(Task):
|
||||||
Implements repr(self)
|
Implements repr(self)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return f"TaskContext({self.tasks})"
|
result = "TaskContext(["
|
||||||
|
for i, task in enumerate(self.tasks):
|
||||||
|
if task is self.entry_point:
|
||||||
|
result += repr(self.entry_point)
|
||||||
|
else:
|
||||||
|
result += repr(task)
|
||||||
|
if i < len(self.tasks) - 1:
|
||||||
|
result += ", "
|
||||||
|
result += "])"
|
||||||
|
return result
|
||||||
|
|
|
@ -15,6 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
import traceback
|
||||||
from aiosched.task import Task
|
from aiosched.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,3 +68,37 @@ class Cancelled(BaseException):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task: Task
|
task: Task
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorStack(SchedulerError):
|
||||||
|
"""
|
||||||
|
This exception wraps multiple exceptions and
|
||||||
|
shows each individual traceback of them when
|
||||||
|
printed. This is to ensure that no exception is
|
||||||
|
lost even if 2 or more tasks raise at the
|
||||||
|
same time or during cancellation of other
|
||||||
|
tasks
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, errors: list[BaseException]):
|
||||||
|
"""
|
||||||
|
Object constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.errors = errors
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
Returns str(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
tracebacks = ""
|
||||||
|
for i, err in enumerate(self.errors):
|
||||||
|
if i not in (1, len(self.errors)):
|
||||||
|
tracebacks += (
|
||||||
|
f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}"
|
||||||
|
return f"Multiple errors occurred:\n{tracebacks}"
|
||||||
|
|
|
@ -103,43 +103,71 @@ async def suspend():
|
||||||
await syscall("suspend")
|
await syscall("suspend")
|
||||||
|
|
||||||
|
|
||||||
|
async def current_task() -> Task:
|
||||||
|
"""
|
||||||
|
Returns the currently running
|
||||||
|
task object
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await syscall("get_current_task")
|
||||||
|
|
||||||
|
|
||||||
|
async def join(task: Task):
|
||||||
|
"""
|
||||||
|
Tells the event loop that the current task
|
||||||
|
wants to wait on the given one, but without
|
||||||
|
waiting for its completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
await syscall("join", task)
|
||||||
|
|
||||||
|
|
||||||
async def wait(task: Task) -> Any | None:
|
async def wait(task: Task) -> Any | None:
|
||||||
"""
|
"""
|
||||||
Waits for the completion of a
|
Waits for the completion of a
|
||||||
given task and returns its
|
given task and returns its
|
||||||
return value. Can be called
|
return value. Can be called
|
||||||
multiple times by multiple tasks.
|
multiple times by multiple tasks.
|
||||||
Raises an error if the task has
|
Returns immediately if the task has
|
||||||
completed already. Please note that
|
completed already, but exceptions are
|
||||||
exceptions are propagated, too
|
propagated only once
|
||||||
|
|
||||||
:param task: The task to wait for
|
:param task: The task to wait for
|
||||||
:type task: :class: Task
|
:type task: :class: Task
|
||||||
:returns: The task's return value, if any
|
:returns: The task's return value, if any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if task.done():
|
current = await current_task()
|
||||||
raise SchedulerError(f"task {task.name!r} has completed already")
|
if task is current:
|
||||||
|
raise SchedulerError("a task cannot join itself")
|
||||||
await syscall("wait", task)
|
await syscall("wait", task)
|
||||||
if task.exc:
|
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
|
||||||
|
task.propagate = False
|
||||||
raise task.exc
|
raise task.exc
|
||||||
return task.result
|
return task.result
|
||||||
|
|
||||||
|
|
||||||
async def cancel(task: Task):
|
async def cancel(task: Task, block: bool = False):
|
||||||
"""
|
"""
|
||||||
Cancels the given task. Note that
|
Cancels the given task. Note that
|
||||||
cancellations may not happen immediately
|
cancellations may not happen immediately
|
||||||
if the task is blocked in an uninterruptible
|
if the task is blocked in an uninterruptible
|
||||||
state
|
state. If block equals False, the default,
|
||||||
|
this function returns immediately, otherwise
|
||||||
|
it waits for the task to receive the cancellation
|
||||||
|
|
||||||
:param task: The task to wait for
|
:param task: The task to wait for
|
||||||
:type task: :class: Task
|
:type task: :class: Task
|
||||||
|
:param block: Whether to wait for the task to be
|
||||||
|
actually cancelled or not, defaults to False
|
||||||
|
:type block: bool, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await syscall("cancel", task)
|
await syscall("cancel", task)
|
||||||
if task.state != TaskState.CANCELLED:
|
if block:
|
||||||
raise SchedulerError(f"task {task.name!r} ignored cancellation")
|
await wait(task)
|
||||||
|
if not task.state == TaskState.CANCELLED:
|
||||||
|
raise SchedulerError(f"task {task.name!r} ignored cancellation")
|
||||||
|
|
||||||
|
|
||||||
async def closing(stream):
|
async def closing(stream):
|
||||||
|
@ -170,3 +198,19 @@ async def wait_writable(stream):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await syscall("perform_io", stream, EVENT_WRITE)
|
await syscall("perform_io", stream, EVENT_WRITE)
|
||||||
|
|
||||||
|
|
||||||
|
async def set_context(ctx):
|
||||||
|
"""
|
||||||
|
Sets the current task context
|
||||||
|
"""
|
||||||
|
|
||||||
|
await syscall("set_context", ctx)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_context():
|
||||||
|
"""
|
||||||
|
Closes the current task context
|
||||||
|
"""
|
||||||
|
|
||||||
|
await syscall("close_context")
|
||||||
|
|
|
@ -24,6 +24,7 @@ from aiosched.internals.queues import TimeQueue
|
||||||
from aiosched.util.debugging import BaseDebugger
|
from aiosched.util.debugging import BaseDebugger
|
||||||
from typing import Callable, Any, Coroutine
|
from typing import Callable, Any, Coroutine
|
||||||
from aiosched.errors import InternalError, ResourceBusy, Cancelled, ResourceClosed, ResourceBroken
|
from aiosched.errors import InternalError, ResourceBusy, Cancelled, ResourceClosed, ResourceBroken
|
||||||
|
from aiosched.context import TaskContext
|
||||||
from selectors import DefaultSelector, BaseSelector
|
from selectors import DefaultSelector, BaseSelector
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,6 +78,8 @@ class FIFOKernel:
|
||||||
self.data: dict[Task, Any] = {}
|
self.data: dict[Task, Any] = {}
|
||||||
# The currently running task
|
# The currently running task
|
||||||
self.current_task: Task | None = None
|
self.current_task: Task | None = None
|
||||||
|
# The loop's entry point
|
||||||
|
self.entry_point: Task | None = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""
|
"""
|
||||||
|
@ -118,7 +121,7 @@ class FIFOKernel:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self.done() and not force:
|
if not self.done() and not force:
|
||||||
raise InternalError("cannot shut down a running event loop")
|
self.current_task.throw(InternalError("cannot shut down a running event loop"))
|
||||||
for task in self.all():
|
for task in self.all():
|
||||||
self.cancel(task)
|
self.cancel(task)
|
||||||
|
|
||||||
|
@ -183,10 +186,7 @@ class FIFOKernel:
|
||||||
Reschedules the currently running task
|
Reschedules the currently running task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.current_task:
|
self.run_ready.append(self.current_task)
|
||||||
self.run_ready.append(self.current_task)
|
|
||||||
else:
|
|
||||||
raise InternalError("aiosched is not running")
|
|
||||||
|
|
||||||
def suspend(self):
|
def suspend(self):
|
||||||
"""
|
"""
|
||||||
|
@ -215,6 +215,8 @@ class FIFOKernel:
|
||||||
while self.current_task.done():
|
while self.current_task.done():
|
||||||
# We need to make sure we don't try to execute
|
# We need to make sure we don't try to execute
|
||||||
# exited tasks that are on the running queue
|
# exited tasks that are on the running queue
|
||||||
|
if not self.run_ready:
|
||||||
|
return # No more tasks to run!
|
||||||
self.current_task = self.run_ready.popleft()
|
self.current_task = self.run_ready.popleft()
|
||||||
self.debugger.before_task_step(self.current_task)
|
self.debugger.before_task_step(self.current_task)
|
||||||
# Some debugging and internal chatter here
|
# Some debugging and internal chatter here
|
||||||
|
@ -227,18 +229,15 @@ class FIFOKernel:
|
||||||
else:
|
else:
|
||||||
# Run a single step with the calculation (i.e. until a yield
|
# Run a single step with the calculation (i.e. until a yield
|
||||||
# somewhere)
|
# somewhere)
|
||||||
method, args, kwargs = self.current_task.run(
|
method, args, kwargs = self.current_task.run(self.data.pop(self.current_task, None))
|
||||||
self.data.get(self.current_task)
|
if not hasattr(self, method) or not callable(getattr(self, method)):
|
||||||
)
|
|
||||||
self.data.pop(self.current_task, None)
|
|
||||||
if not hasattr(self, method) and not callable(getattr(self, method)):
|
|
||||||
# This if block is meant to be triggered by other async
|
# This if block is meant to be triggered by other async
|
||||||
# libraries, which most likely have different trap names and behaviors
|
# libraries, which most likely have different trap names and behaviors
|
||||||
# compared to us. If you get this exception, and you're 100% sure you're
|
# compared to us. If you get this exception, and you're 100% sure you're
|
||||||
# not mixing async primitives from other libraries, then it's a bug!
|
# not mixing async primitives from other libraries, then it's a bug!
|
||||||
raise InternalError(
|
self.current_task.throw(InternalError(
|
||||||
"Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?"
|
"Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?"
|
||||||
) from None
|
))
|
||||||
# Sneaky method call, thanks to David Beazley for this ;)
|
# Sneaky method call, thanks to David Beazley for this ;)
|
||||||
getattr(self, method)(*args, **kwargs)
|
getattr(self, method)(*args, **kwargs)
|
||||||
self.debugger.after_task_step(self.current_task)
|
self.debugger.after_task_step(self.current_task)
|
||||||
|
@ -286,16 +285,18 @@ class FIFOKernel:
|
||||||
Starts the event loop from a synchronous context
|
Starts the event loop from a synchronous context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
|
self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
|
||||||
self.run_ready.append(entry_point)
|
self.run_ready.append(self.entry_point)
|
||||||
self.debugger.on_start()
|
self.debugger.on_start()
|
||||||
try:
|
try:
|
||||||
self.run()
|
self.run()
|
||||||
finally:
|
finally:
|
||||||
self.debugger.on_exit()
|
self.debugger.on_exit()
|
||||||
if entry_point.exc:
|
if self.entry_point.exc and self.entry_point.context is None and self.entry_point.propagate:
|
||||||
raise entry_point.exc
|
# Contexts already manage exceptions for us,
|
||||||
return entry_point.result
|
# no need to raise it manually
|
||||||
|
raise self.entry_point.exc
|
||||||
|
return self.entry_point.result
|
||||||
|
|
||||||
def io_release(self, resource):
|
def io_release(self, resource):
|
||||||
"""
|
"""
|
||||||
|
@ -344,23 +345,17 @@ class FIFOKernel:
|
||||||
|
|
||||||
def cancel(self, task: Task):
|
def cancel(self, task: Task):
|
||||||
"""
|
"""
|
||||||
Schedules the task to be cancelled later
|
Attempts to cancel the given task or
|
||||||
or does so straight away if it is safe to do so
|
schedules cancellation for later if
|
||||||
|
it fails
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.reschedule_running()
|
self.io_release_task(task)
|
||||||
match task.state:
|
self.paused.discard(task)
|
||||||
case TaskState.IO:
|
|
||||||
self.io_release_task(task)
|
|
||||||
case TaskState.PAUSED:
|
|
||||||
self.paused.discard(task)
|
|
||||||
case TaskState.INIT, TaskState.CANCELLED, TaskState.CRASHED:
|
|
||||||
return
|
|
||||||
self.handle_task_run(partial(task.throw, Cancelled(task)), task)
|
self.handle_task_run(partial(task.throw, Cancelled(task)), task)
|
||||||
if task.state == TaskState.CANCELLED:
|
if task.state != TaskState.CANCELLED:
|
||||||
self.debugger.after_cancel(task)
|
|
||||||
else:
|
|
||||||
task.pending_cancellation = True
|
task.pending_cancellation = True
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def handle_task_run(self, func: Callable, task: Task | None = None):
|
def handle_task_run(self, func: Callable, task: Task | None = None):
|
||||||
"""
|
"""
|
||||||
|
@ -398,6 +393,7 @@ class FIFOKernel:
|
||||||
task = task or self.current_task
|
task = task or self.current_task
|
||||||
task.state = TaskState.CANCELLED
|
task.state = TaskState.CANCELLED
|
||||||
task.pending_cancellation = False
|
task.pending_cancellation = False
|
||||||
|
self.debugger.after_cancel(task)
|
||||||
self.wait(task)
|
self.wait(task)
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
# Any other exception is caught here
|
# Any other exception is caught here
|
||||||
|
@ -425,13 +421,24 @@ class FIFOKernel:
|
||||||
def wait(self, task: Task):
|
def wait(self, task: Task):
|
||||||
"""
|
"""
|
||||||
Makes the current task wait for completion of the given one
|
Makes the current task wait for completion of the given one
|
||||||
|
by only rescheduling it once the given task has finished
|
||||||
|
executing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if task.done():
|
if task.done():
|
||||||
|
self.paused.discard(task)
|
||||||
|
self.io_release_task(task)
|
||||||
self.run_ready.extend(task.joiners)
|
self.run_ready.extend(task.joiners)
|
||||||
task.joiners = {}
|
|
||||||
else:
|
def join(self, task: Task):
|
||||||
task.joiners.add(self.current_task)
|
"""
|
||||||
|
Tells the event loop that the current task
|
||||||
|
wants to wait on the given one, but without
|
||||||
|
actually waiting for its completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
task.joiners.add(self.current_task)
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
|
def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -444,6 +451,44 @@ class FIFOKernel:
|
||||||
self.run_ready.append(task)
|
self.run_ready.append(task)
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
|
def set_context(self, ctx: TaskContext):
|
||||||
|
"""
|
||||||
|
Sets the current task context. This is
|
||||||
|
implemented as simply wrapping the current
|
||||||
|
task inside the context and replacing the
|
||||||
|
Task object with the TaskContext one
|
||||||
|
"""
|
||||||
|
|
||||||
|
ctx.entry_point = self.current_task
|
||||||
|
if isinstance(self.current_task, TaskContext):
|
||||||
|
self.current_task.inner = ctx
|
||||||
|
else:
|
||||||
|
ctx.tasks.append(ctx.entry_point)
|
||||||
|
self.current_task.context = ctx
|
||||||
|
self.current_task = ctx
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
|
def close_context(self):
|
||||||
|
"""
|
||||||
|
Closes the context associated with the current
|
||||||
|
task
|
||||||
|
"""
|
||||||
|
|
||||||
|
ctx: TaskContext = self.current_task
|
||||||
|
task = ctx.entry_point
|
||||||
|
task.context = None
|
||||||
|
self.current_task = task
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
|
def get_current_task(self):
|
||||||
|
"""
|
||||||
|
Returns the current task to an asynchronous
|
||||||
|
caller
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.data[self.current_task] = self.current_task
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def perform_io(self, resource, evt_type: int):
|
def perform_io(self, resource, evt_type: int):
|
||||||
"""
|
"""
|
||||||
Registers the given resource inside our selector to perform I/O multiplexing
|
Registers the given resource inside our selector to perform I/O multiplexing
|
||||||
|
@ -495,4 +540,4 @@ class FIFOKernel:
|
||||||
# If we get here, two tasks are trying to read or write on the same resource at the same time
|
# If we get here, two tasks are trying to read or write on the same resource at the same time
|
||||||
raise ResourceBusy(
|
raise ResourceBusy(
|
||||||
"The given resource is being read from/written to from another task"
|
"The given resource is being read from/written to from another task"
|
||||||
) from None
|
)
|
||||||
|
|
|
@ -88,13 +88,14 @@ def run(
|
||||||
get_event_loop().start(func, *args, **kwargs)
|
get_event_loop().start(func, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def with_context() -> TaskContext:
|
def with_context(*args, **kwargs) -> TaskContext:
|
||||||
"""
|
"""
|
||||||
Creates and returns a new TaskContext
|
Creates and returns a new TaskContext
|
||||||
object
|
object. All positional and keyword arguments
|
||||||
|
are passed to TaskContext's constructor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return TaskContext()
|
return TaskContext(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def clock() -> float:
|
def clock() -> float:
|
||||||
|
|
|
@ -77,28 +77,39 @@ class Task:
|
||||||
paused_when: float = 0.0
|
paused_when: float = 0.0
|
||||||
# The next deadline, in terms of the absolute clock of the loop, associated to the task
|
# The next deadline, in terms of the absolute clock of the loop, associated to the task
|
||||||
next_deadline: float = 0.0
|
next_deadline: float = 0.0
|
||||||
|
# Is this task within a context? This is needed to fix a bug that would occur when
|
||||||
|
# the event loop tries to raise the exception caused by first task that kicked the
|
||||||
|
# loop even if that context already ignored said error
|
||||||
|
context: "TaskContext" = None
|
||||||
|
# We propagate exception only at the first call to wait()
|
||||||
|
propagate: bool = True
|
||||||
|
|
||||||
def run(self, what: Any | None = None):
|
def run(self, what: Any | None = None):
|
||||||
"""
|
"""
|
||||||
Simple abstraction layer over a coroutine's send method
|
Simple abstraction layer over a coroutine's send method.
|
||||||
|
Does nothing if the task has already exited
|
||||||
|
|
||||||
:param what: The object that has to be sent to the coroutine,
|
:param what: The object that has to be sent to the coroutine,
|
||||||
defaults to None
|
defaults to None
|
||||||
:type what: Any, optional
|
:type what: Any, optional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.done():
|
||||||
|
return
|
||||||
return self.coroutine.send(what)
|
return self.coroutine.send(what)
|
||||||
|
|
||||||
def throw(self, err: BaseException):
|
def throw(self, err: BaseException):
|
||||||
"""
|
"""
|
||||||
Simple abstraction layer over a coroutine's throw method
|
Simple abstraction layer over a coroutine's throw method.
|
||||||
|
Does nothing if the task has already exited
|
||||||
|
|
||||||
:param err: The exception that has to be raised inside
|
:param err: The exception that has to be raised inside
|
||||||
the task
|
the task
|
||||||
:type err: BaseException
|
:type err: BaseException
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.exc = err
|
if self.done():
|
||||||
|
return
|
||||||
return self.coroutine.throw(err)
|
return self.coroutine.throw(err)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
|
@ -125,9 +136,5 @@ class Task:
|
||||||
Task destructor
|
Task destructor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
|
||||||
self.coroutine.close()
|
|
||||||
except RuntimeError:
|
|
||||||
pass # TODO: This is kinda bad
|
|
||||||
if self.last_io:
|
if self.last_io:
|
||||||
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")
|
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")
|
||||||
|
|
|
@ -9,7 +9,7 @@ async def main(children: list[tuple[str, int]]):
|
||||||
print("[main] Spawning children")
|
print("[main] Spawning children")
|
||||||
for name, delay in children:
|
for name, delay in children:
|
||||||
await ctx.spawn(child, name, delay)
|
await ctx.spawn(child, name, delay)
|
||||||
print(f"[main] Spawned {len(ctx.tasks)} children")
|
print("[main] Children spawned")
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
import aiosched
|
||||||
|
from catch import child
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
async def main(children: list[tuple[str, int]]):
|
||||||
|
async with aiosched.with_context(silent=True) as ctx:
|
||||||
|
print("[main] Spawning children")
|
||||||
|
for name, delay in children:
|
||||||
|
await ctx.spawn(child, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
before = aiosched.clock()
|
||||||
|
if ctx.exc:
|
||||||
|
print(f"[main] Child raised an exception -> {type(ctx.exc).__name__}: {ctx.exc}")
|
||||||
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
|
|
@ -8,10 +8,10 @@ async def main(children: list[tuple[str, int]]):
|
||||||
async with aiosched.with_context() as ctx:
|
async with aiosched.with_context() as ctx:
|
||||||
for name, delay in children:
|
for name, delay in children:
|
||||||
await ctx.spawn(child, name, delay)
|
await ctx.spawn(child, name, delay)
|
||||||
print(f"[main] Spawned {len(ctx.tasks)} children")
|
print("[main] Children spawned")
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
|
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
|
|
@ -0,0 +1,26 @@
|
||||||
|
import aiosched
|
||||||
|
from catch import child as errorer
|
||||||
|
from wait import child as successful
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
|
||||||
|
before = aiosched.clock()
|
||||||
|
async with aiosched.with_context() as ctx:
|
||||||
|
print("[main] Spawning children in first context")
|
||||||
|
for name, delay in children_outer:
|
||||||
|
await ctx.spawn(successful, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
# An exception in an outer context cancels everything
|
||||||
|
# inside it, but an exception in an inner context does
|
||||||
|
# not affect outer ones
|
||||||
|
async with aiosched.with_context() as ctx2:
|
||||||
|
print("[main] Spawning children in second context")
|
||||||
|
for name, delay in children_inner:
|
||||||
|
await ctx2.spawn(errorer, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)
|
|
@ -0,0 +1,26 @@
|
||||||
|
import aiosched
|
||||||
|
from catch import child
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This crashes 1 second later than it should be
|
||||||
|
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
|
||||||
|
try:
|
||||||
|
async with aiosched.with_context() as ctx:
|
||||||
|
before = aiosched.clock()
|
||||||
|
print("[main] Spawning children in first context")
|
||||||
|
for name, delay in children_outer:
|
||||||
|
await ctx.spawn(child, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
async with aiosched.with_context() as ctx2:
|
||||||
|
print("[main] Spawning children in second context")
|
||||||
|
for name, delay in children_inner:
|
||||||
|
await ctx2.spawn(child, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
except BaseException as err:
|
||||||
|
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
||||||
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)
|
|
@ -0,0 +1,22 @@
|
||||||
|
import aiosched
|
||||||
|
from wait import child
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
|
||||||
|
async with aiosched.with_context() as ctx:
|
||||||
|
before = aiosched.clock()
|
||||||
|
print("[main] Spawning children in first context")
|
||||||
|
for name, delay in children_outer:
|
||||||
|
await ctx.spawn(child, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
async with aiosched.with_context() as ctx2:
|
||||||
|
print("[main] Spawning children in second context")
|
||||||
|
for name, delay in children_inner:
|
||||||
|
await ctx2.spawn(child, name, delay)
|
||||||
|
print("[main] Children spawned")
|
||||||
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)
|
Reference in New Issue