Work on context exception handling

This commit is contained in:
Nocturn9x 2022-10-19 11:31:45 +02:00
parent ca1e8a157b
commit 6690263b55
12 changed files with 427 additions and 87 deletions

View File

@ -16,31 +16,43 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.task import Task
from aiosched.internals.syscalls import spawn, wait, cancel
from aiosched.errors import Cancelled
from aiosched.internals.syscalls import spawn, wait, cancel, set_context, close_context, join
from typing import Any, Coroutine, Callable
class TaskContext(Task):
"""
An asynchronous task context that automatically waits
for all tasks spawned within it. A TaskContext object
behaves like a task and is handled as a single unit
inside the event loop
An asynchronous context manager that automatically waits
for all tasks spawned within it and cancels itself when
an exception occurs. A TaskContext object behaves like
a regular task and the event loop treats it like a single
unit rather than a collection of tasks (in fact, the event
loop doesn't even know whether the current task is a task
context or not, which is by design). TaskContexts can be
nested and will cancel inner ones if an exception is raised
inside them
"""
def __init__(self) -> None:
def __init__(self, silent: bool = False, gather: bool = True) -> None:
"""
Object constructor
"""
# All the tasks that belong to this context. This
# includes any inner contexts contained within this
# one
self.tasks: list[Task | "TaskContext"] = []
# All the tasks that belong to this context
self.tasks: list[Task] = []
# Whether we have been cancelled or not
self.cancelled: bool = False
super().__init__(f"TaskContext object at {hex(id(self))}", None)
# The context's entry point (needed to forward run() calls and the like)
self.entry_point: Task | TaskContext | None = None
# Do we ignore exceptions?
self.silent: bool = silent
# Do we gather multiple exceptions from
# children tasks?
self.gather: bool = gather
# Do we wrap any other task contexts?
self.inner: TaskContext | None = None
async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
) -> Task:
@ -49,6 +61,8 @@ class TaskContext(Task):
"""
task = await spawn(func, *args, **kwargs)
task.context = self
await join(task)
self.tasks.append(task)
return task
@ -57,44 +71,136 @@ class TaskContext(Task):
Implements the asynchronous context manager interface
"""
await set_context(self)
return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
"""
Implements the asynchronous context manager interface, waiting
for all the tasks spawned inside the context
for all the tasks spawned inside the context and handling
exceptions
"""
for task in self.tasks:
# This forces the interpreter to stop at the
# end of the block and wait for all
# children to exit
try:
try:
for task in self.tasks:
# This forces the interpreter to stop at the
# end of the block and wait for all
# children to exit
if task is self.entry_point:
continue
await wait(task)
self.tasks.remove(task)
except BaseException:
self.tasks.remove(task)
await self.cancel()
raise
except BaseException as exc:
await self.cancel(False)
self.exc = exc
if not self.silent:
raise self.exc
if self.inner:
for task in self.inner.tasks:
try:
await wait(task)
except BaseException:
await self.inner.cancel(False)
self.inner.propagate = False
await close_context()
async def cancel(self):
# Task method wrappers
async def cancel(self, propagate: bool = True):
"""
Cancels the entire context, iterating over all
of its tasks and cancelling them
of its tasks (which includes inner contexts)
and cancelling them
"""
if self.inner:
await self.inner.cancel(propagate)
for task in self.tasks:
if task is self.entry_point:
continue
await cancel(task)
self.cancelled = True
self.tasks = []
await close_context()
self.propagate = False
if propagate:
if isinstance(self.entry_point, TaskContext):
self.entry_point: TaskContext
await self.entry_point.cancel()
else:
await cancel(self.entry_point)
def done(self) -> bool:
"""
Returns True if all the tasks inside the
context have exited, False otherwise
Returns whether all the tasks inside the
context have exited
"""
return all([task.done() for task in self.tasks])
for task in self.tasks:
if task is self.entry_point:
continue
if not task.done():
return False
if not isinstance(self.entry_point, TaskContext) and not self.entry_point.done():
return False
if self.inner:
return self.inner.done()
@property
def state(self) -> int:
return self.entry_point.state
@state.setter
def state(self, state: int):
self.entry_point.state = state
@property
def result(self) -> Any:
return self.entry_point.result
@result.setter
def result(self, result: Any):
self.entry_point.result = result
@property
def exc(self) -> BaseException:
return self.entry_point.exc
@exc.setter
def exc(self, exc: BaseException):
self.entry_point.exc = exc
@property
def propagate(self) -> bool:
return self.entry_point.propagate
@propagate.setter
def propagate(self, val: bool):
self.entry_point.propagate = val
@property
def name(self):
return self.entry_point.name
def throw(self, err: BaseException):
for task in self.tasks:
try:
task.throw(err)
except err:
continue
self.entry_point.throw(err)
@property
def joiners(self) -> set[Task]:
return self.entry_point.joiners
@joiners.setter
def joiners(self, joiners: set[Task]):
self.entry_point.joiners = joiners
def __hash__(self):
return self.entry_point.__hash__()
def run(self, what: Any | None = None):
return self.entry_point.run(what)
def __del__(self):
"""
@ -109,4 +215,13 @@ class TaskContext(Task):
Implements repr(self)
"""
return f"TaskContext({self.tasks})"
result = "TaskContext(["
for i, task in enumerate(self.tasks):
if task is self.entry_point:
result += repr(self.entry_point)
else:
result += repr(task)
if i < len(self.tasks) - 1:
result += ", "
result += "])"
return result

View File

@ -15,6 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import traceback
from aiosched.task import Task
@ -67,3 +68,37 @@ class Cancelled(BaseException):
"""
task: Task
class ErrorStack(SchedulerError):
"""
This exception wraps multiple exceptions and
shows each individual traceback of them when
printed. This is to ensure that no exception is
lost even if 2 or more tasks raise at the
same time or during cancellation of other
tasks
"""
def __init__(self, errors: list[BaseException]):
"""
Object constructor
"""
super().__init__()
self.errors = errors
def __str__(self):
"""
Returns str(self)
"""
tracebacks = ""
for i, err in enumerate(self.errors):
if i not in (1, len(self.errors)):
tracebacks += (
f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n"
)
else:
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}"
return f"Multiple errors occurred:\n{tracebacks}"

View File

@ -103,43 +103,71 @@ async def suspend():
await syscall("suspend")
async def current_task() -> Task:
"""
Returns the currently running
task object
"""
return await syscall("get_current_task")
async def join(task: Task):
"""
Tells the event loop that the current task
wants to wait on the given one, but without
waiting for its completion
"""
await syscall("join", task)
async def wait(task: Task) -> Any | None:
"""
Waits for the completion of a
given task and returns its
return value. Can be called
multiple times by multiple tasks.
Raises an error if the task has
completed already. Please note that
exceptions are propagated, too
Returns immediately if the task has
completed already, but exceptions are
propagated only once
:param task: The task to wait for
:type task: :class: Task
:returns: The task's return value, if any
"""
if task.done():
raise SchedulerError(f"task {task.name!r} has completed already")
current = await current_task()
if task is current:
raise SchedulerError("a task cannot join itself")
await syscall("wait", task)
if task.exc:
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
task.propagate = False
raise task.exc
return task.result
async def cancel(task: Task):
async def cancel(task: Task, block: bool = False):
"""
Cancels the given task. Note that
cancellations may not happen immediately
if the task is blocked in an uninterruptible
state
state. If block equals False, the default,
this function returns immediately, otherwise
it waits for the task to receive the cancellation
:param task: The task to wait for
:type task: :class: Task
:param block: Whether to wait for the task to be
actually cancelled or not, defaults to False
:type block: bool, optional
"""
await syscall("cancel", task)
if task.state != TaskState.CANCELLED:
raise SchedulerError(f"task {task.name!r} ignored cancellation")
if block:
await wait(task)
if not task.state == TaskState.CANCELLED:
raise SchedulerError(f"task {task.name!r} ignored cancellation")
async def closing(stream):
@ -170,3 +198,19 @@ async def wait_writable(stream):
"""
await syscall("perform_io", stream, EVENT_WRITE)
async def set_context(ctx):
"""
Sets the current task context
"""
await syscall("set_context", ctx)
async def close_context():
"""
Closes the current task context
"""
await syscall("close_context")

View File

@ -24,6 +24,7 @@ from aiosched.internals.queues import TimeQueue
from aiosched.util.debugging import BaseDebugger
from typing import Callable, Any, Coroutine
from aiosched.errors import InternalError, ResourceBusy, Cancelled, ResourceClosed, ResourceBroken
from aiosched.context import TaskContext
from selectors import DefaultSelector, BaseSelector
@ -77,6 +78,8 @@ class FIFOKernel:
self.data: dict[Task, Any] = {}
# The currently running task
self.current_task: Task | None = None
# The loop's entry point
self.entry_point: Task | None = None
def __repr__(self):
"""
@ -118,7 +121,7 @@ class FIFOKernel:
"""
if not self.done() and not force:
raise InternalError("cannot shut down a running event loop")
self.current_task.throw(InternalError("cannot shut down a running event loop"))
for task in self.all():
self.cancel(task)
@ -183,10 +186,7 @@ class FIFOKernel:
Reschedules the currently running task
"""
if self.current_task:
self.run_ready.append(self.current_task)
else:
raise InternalError("aiosched is not running")
self.run_ready.append(self.current_task)
def suspend(self):
"""
@ -215,6 +215,8 @@ class FIFOKernel:
while self.current_task.done():
# We need to make sure we don't try to execute
# exited tasks that are on the running queue
if not self.run_ready:
return # No more tasks to run!
self.current_task = self.run_ready.popleft()
self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here
@ -227,18 +229,15 @@ class FIFOKernel:
else:
# Run a single step with the calculation (i.e. until a yield
# somewhere)
method, args, kwargs = self.current_task.run(
self.data.get(self.current_task)
)
self.data.pop(self.current_task, None)
if not hasattr(self, method) and not callable(getattr(self, method)):
method, args, kwargs = self.current_task.run(self.data.pop(self.current_task, None))
if not hasattr(self, method) or not callable(getattr(self, method)):
# This if block is meant to be triggered by other async
# libraries, which most likely have different trap names and behaviors
# compared to us. If you get this exception, and you're 100% sure you're
# not mixing async primitives from other libraries, then it's a bug!
raise InternalError(
self.current_task.throw(InternalError(
"Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?"
) from None
))
# Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args, **kwargs)
self.debugger.after_task_step(self.current_task)
@ -286,16 +285,18 @@ class FIFOKernel:
Starts the event loop from a synchronous context
"""
entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
self.run_ready.append(entry_point)
self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
self.run_ready.append(self.entry_point)
self.debugger.on_start()
try:
self.run()
finally:
self.debugger.on_exit()
if entry_point.exc:
raise entry_point.exc
return entry_point.result
if self.entry_point.exc and self.entry_point.context is None and self.entry_point.propagate:
# Contexts already manage exceptions for us,
# no need to raise it manually
raise self.entry_point.exc
return self.entry_point.result
def io_release(self, resource):
"""
@ -344,23 +345,17 @@ class FIFOKernel:
def cancel(self, task: Task):
"""
Schedules the task to be cancelled later
or does so straight away if it is safe to do so
Attempts to cancel the given task or
schedules cancellation for later if
it fails
"""
self.reschedule_running()
match task.state:
case TaskState.IO:
self.io_release_task(task)
case TaskState.PAUSED:
self.paused.discard(task)
case TaskState.INIT, TaskState.CANCELLED, TaskState.CRASHED:
return
self.io_release_task(task)
self.paused.discard(task)
self.handle_task_run(partial(task.throw, Cancelled(task)), task)
if task.state == TaskState.CANCELLED:
self.debugger.after_cancel(task)
else:
if task.state != TaskState.CANCELLED:
task.pending_cancellation = True
self.reschedule_running()
def handle_task_run(self, func: Callable, task: Task | None = None):
"""
@ -398,6 +393,7 @@ class FIFOKernel:
task = task or self.current_task
task.state = TaskState.CANCELLED
task.pending_cancellation = False
self.debugger.after_cancel(task)
self.wait(task)
except BaseException as err:
# Any other exception is caught here
@ -425,13 +421,24 @@ class FIFOKernel:
def wait(self, task: Task):
"""
Makes the current task wait for completion of the given one
by only rescheduling it once the given task has finished
executing
"""
if task.done():
self.paused.discard(task)
self.io_release_task(task)
self.run_ready.extend(task.joiners)
task.joiners = {}
else:
task.joiners.add(self.current_task)
def join(self, task: Task):
"""
Tells the event loop that the current task
wants to wait on the given one, but without
actually waiting for its completion
"""
task.joiners.add(self.current_task)
self.reschedule_running()
def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
"""
@ -444,6 +451,44 @@ class FIFOKernel:
self.run_ready.append(task)
self.reschedule_running()
def set_context(self, ctx: TaskContext):
"""
Sets the current task context. This is
implemented as simply wrapping the current
task inside the context and replacing the
Task object with the TaskContext one
"""
ctx.entry_point = self.current_task
if isinstance(self.current_task, TaskContext):
self.current_task.inner = ctx
else:
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
self.current_task = ctx
self.reschedule_running()
def close_context(self):
"""
Closes the context associated with the current
task
"""
ctx: TaskContext = self.current_task
task = ctx.entry_point
task.context = None
self.current_task = task
self.reschedule_running()
def get_current_task(self):
"""
Returns the current task to an asynchronous
caller
"""
self.data[self.current_task] = self.current_task
self.reschedule_running()
def perform_io(self, resource, evt_type: int):
"""
Registers the given resource inside our selector to perform I/O multiplexing
@ -495,4 +540,4 @@ class FIFOKernel:
# If we get here, two tasks are trying to read or write on the same resource at the same time
raise ResourceBusy(
"The given resource is being read from/written to from another task"
) from None
)

View File

@ -88,13 +88,14 @@ def run(
get_event_loop().start(func, *args, **kwargs)
def with_context() -> TaskContext:
def with_context(*args, **kwargs) -> TaskContext:
"""
Creates and returns a new TaskContext
object
object. All positional and keyword arguments
are passed to TaskContext's constructor
"""
return TaskContext()
return TaskContext(*args, **kwargs)
def clock() -> float:

View File

@ -77,28 +77,39 @@ class Task:
paused_when: float = 0.0
# The next deadline, in terms of the absolute clock of the loop, associated to the task
next_deadline: float = 0.0
# Is this task within a context? This is needed to fix a bug that would occur when
# the event loop tries to raise the exception caused by first task that kicked the
# loop even if that context already ignored said error
context: "TaskContext" = None
# We propagate exception only at the first call to wait()
propagate: bool = True
def run(self, what: Any | None = None):
"""
Simple abstraction layer over a coroutine's send method
Simple abstraction layer over a coroutine's send method.
Does nothing if the task has already exited
:param what: The object that has to be sent to the coroutine,
defaults to None
:type what: Any, optional
"""
if self.done():
return
return self.coroutine.send(what)
def throw(self, err: BaseException):
"""
Simple abstraction layer over a coroutine's throw method
Simple abstraction layer over a coroutine's throw method.
Does nothing if the task has already exited
:param err: The exception that has to be raised inside
the task
:type err: BaseException
"""
self.exc = err
if self.done():
return
return self.coroutine.throw(err)
def __hash__(self):
@ -125,9 +136,5 @@ class Task:
Task destructor
"""
try:
self.coroutine.close()
except RuntimeError:
pass # TODO: This is kinda bad
if self.last_io:
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")

View File

@ -9,7 +9,7 @@ async def main(children: list[tuple[str, int]]):
print("[main] Spawning children")
for name, delay in children:
await ctx.spawn(child, name, delay)
print(f"[main] Spawned {len(ctx.tasks)} children")
print("[main] Children spawned")
before = aiosched.clock()
except BaseException as err:
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")

View File

@ -0,0 +1,19 @@
import aiosched
from catch import child
from debugger import Debugger
async def main(children: list[tuple[str, int]]):
async with aiosched.with_context(silent=True) as ctx:
print("[main] Spawning children")
for name, delay in children:
await ctx.spawn(child, name, delay)
print("[main] Children spawned")
before = aiosched.clock()
if ctx.exc:
print(f"[main] Child raised an exception -> {type(ctx.exc).__name__}: {ctx.exc}")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)

View File

@ -8,10 +8,10 @@ async def main(children: list[tuple[str, int]]):
async with aiosched.with_context() as ctx:
for name, delay in children:
await ctx.spawn(child, name, delay)
print(f"[main] Spawned {len(ctx.tasks)} children")
print("[main] Children spawned")
before = aiosched.clock()
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)

View File

@ -0,0 +1,26 @@
import aiosched
from catch import child as errorer
from wait import child as successful
from debugger import Debugger
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
before = aiosched.clock()
async with aiosched.with_context() as ctx:
print("[main] Spawning children in first context")
for name, delay in children_outer:
await ctx.spawn(successful, name, delay)
print("[main] Children spawned")
# An exception in an outer context cancels everything
# inside it, but an exception in an inner context does
# not affect outer ones
async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context")
for name, delay in children_inner:
await ctx2.spawn(errorer, name, delay)
print("[main] Children spawned")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)

View File

@ -0,0 +1,26 @@
import aiosched
from catch import child
from debugger import Debugger
# TODO: This crashes 1 second later than it should be
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
try:
async with aiosched.with_context() as ctx:
before = aiosched.clock()
print("[main] Spawning children in first context")
for name, delay in children_outer:
await ctx.spawn(child, name, delay)
print("[main] Children spawned")
async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context")
for name, delay in children_inner:
await ctx2.spawn(child, name, delay)
print("[main] Children spawned")
except BaseException as err:
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)

View File

@ -0,0 +1,22 @@
import aiosched
from wait import child
from debugger import Debugger
async def main(children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]):
async with aiosched.with_context() as ctx:
before = aiosched.clock()
print("[main] Spawning children in first context")
for name, delay in children_outer:
await ctx.spawn(child, name, delay)
print("[main] Children spawned")
async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context")
for name, delay in children_inner:
await ctx2.spawn(child, name, delay)
print("[main] Children spawned")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2)], [("third", 3), ("fourth", 4)], debugger=None)