Various fixes for exception handling in contexts

This commit is contained in:
Nocturn9x 2022-11-02 09:28:04 +01:00
parent ce1583e9c2
commit ff2acf298f
10 changed files with 56 additions and 57 deletions

View File

@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join
import aiosched.task import aiosched.task
import aiosched.errors import aiosched.errors
import aiosched.context import aiosched.context

View File

@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from functools import partial
from aiosched.task import Task from aiosched.task import Task
from aiosched.errors import Cancelled from aiosched.errors import Cancelled
from aiosched.internals.syscalls import ( from aiosched.internals.syscalls import (
@ -57,11 +59,9 @@ class TaskContext(Task):
# Do we gather multiple exceptions from # Do we gather multiple exceptions from
# children tasks? # children tasks?
self.gather: bool = gather self.gather: bool = gather
# Do we wrap any other task contexts?
self.inner: TaskContext | None = None
async def spawn( async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
) -> Task: ) -> Task:
""" """
Spawns a child task Spawns a child task
@ -69,8 +69,8 @@ class TaskContext(Task):
task = await spawn(func, *args, **kwargs) task = await spawn(func, *args, **kwargs)
task.context = self task.context = self
await join(task)
self.tasks.append(task) self.tasks.append(task)
await join(task)
return task return task
async def __aenter__(self): async def __aenter__(self):
@ -101,14 +101,9 @@ class TaskContext(Task):
self.exc = exc self.exc = exc
if not self.silent: if not self.silent:
raise self.exc raise self.exc
if self.inner: finally:
for task in self.inner.tasks: await close_context(self)
try: self.entry_point.propagate = True
await wait(task)
except BaseException:
await self.inner.cancel(False)
self.inner.propagate = False
await close_context()
# Task method wrappers # Task method wrappers
@ -119,21 +114,22 @@ class TaskContext(Task):
and cancelling them and cancelling them
""" """
if self.inner:
await self.inner.cancel(propagate)
for task in self.tasks: for task in self.tasks:
if task is self.entry_point: if task is self.entry_point:
continue continue
await cancel(task) if isinstance(task, Task):
await cancel(task)
else:
task: TaskContext
await task.cancel(propagate)
self.cancelled = True self.cancelled = True
await close_context()
self.propagate = False self.propagate = False
if propagate: if propagate:
if isinstance(self.entry_point, TaskContext): if isinstance(self.entry_point, Task):
self.entry_point: TaskContext
await self.entry_point.cancel()
else:
await cancel(self.entry_point) await cancel(self.entry_point)
else:
self.entry_point: TaskContext
await self.entry_point.cancel(propagate)
def done(self) -> bool: def done(self) -> bool:
""" """
@ -142,17 +138,9 @@ class TaskContext(Task):
""" """
for task in self.tasks: for task in self.tasks:
if task is self.entry_point:
continue
if not task.done(): if not task.done():
return False return False
if ( return True
not isinstance(self.entry_point, TaskContext)
and not self.entry_point.done()
):
return False
if self.inner:
return self.inner.done()
@property @property
def state(self) -> int: def state(self) -> int:
@ -206,6 +194,10 @@ class TaskContext(Task):
def joiners(self, joiners: set[Task]): def joiners(self, joiners: set[Task]):
self.entry_point.joiners = joiners self.entry_point.joiners = joiners
@property
def coroutine(self):
return self.entry_point.coroutine
def __hash__(self): def __hash__(self):
return self.entry_point.__hash__() return self.entry_point.__hash__()

View File

@ -139,7 +139,8 @@ async def wait(task: Task) -> Any | None:
multiple times by multiple tasks. multiple times by multiple tasks.
Returns immediately if the task has Returns immediately if the task has
completed already, but exceptions are completed already, but exceptions are
propagated only once propagated only once. Returns the task's
return value, if it has one
:param task: The task to wait for :param task: The task to wait for
:type task: :class: Task :type task: :class: Task
@ -149,6 +150,10 @@ async def wait(task: Task) -> Any | None:
current = await current_task() current = await current_task()
if task is current: if task is current:
raise SchedulerError("a task cannot join itself") raise SchedulerError("a task cannot join itself")
if current not in task.joiners:
# Luckily we use a set, so this has O(1)
# complexity
await join(task) # Waiting implies joining!
await syscall("wait", task) await syscall("wait", task)
if task.exc and task.state != TaskState.CANCELLED and task.propagate: if task.exc and task.state != TaskState.CANCELLED and task.propagate:
task.propagate = False task.propagate = False
@ -226,9 +231,9 @@ async def set_context(ctx):
await syscall("set_context", ctx) await syscall("set_context", ctx)
async def close_context(): async def close_context(ctx):
""" """
Closes the current task context Closes the current task context
""" """
await syscall("close_context") await syscall("close_context", ctx)

View File

@ -233,6 +233,10 @@ class FIFOKernel:
if not self.run_ready: if not self.run_ready:
return # No more tasks to run! return # No more tasks to run!
self.current_task = self.run_ready.popleft() self.current_task = self.run_ready.popleft()
# We nullify the exception object just in case the
# entry point raised and caught an error so that
# self.start() doesn't raise it again at the end
self.current_task.exc = None
self.debugger.before_task_step(self.current_task) self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here # Some debugging and internal chatter here
self.current_task.state = TaskState.RUN self.current_task.state = TaskState.RUN
@ -477,25 +481,24 @@ class FIFOKernel:
Sets the current task context. This is Sets the current task context. This is
implemented as simply wrapping the current implemented as simply wrapping the current
task inside the context and replacing the task inside the context and replacing the
Task object with the TaskContext one Task object with the TaskContext one. This
may also wrap another task context into a
new one, but the loop doesn't need to care
about that: the API is designed exactly for
this
""" """
ctx.entry_point = self.current_task ctx.entry_point = self.current_task
if isinstance(self.current_task, TaskContext): ctx.tasks.append(ctx.entry_point)
self.current_task.inner = ctx self.current_task.context = ctx
else:
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
self.current_task = ctx self.current_task = ctx
self.reschedule_running() self.reschedule_running()
def close_context(self): def close_context(self, ctx: TaskContext):
""" """
Closes the context associated with the current Closes the given context
task
""" """
ctx: TaskContext = self.current_task
task = ctx.entry_point task = ctx.entry_point
task.context = None task.context = None
self.current_task = task self.current_task = task

View File

@ -20,7 +20,6 @@ from enum import Enum, auto
from typing import Coroutine, Any from typing import Coroutine, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
class TaskState(Enum): class TaskState(Enum):
""" """
An enumeration of task states An enumeration of task states

View File

@ -1,5 +1,5 @@
import aiosched import aiosched
from raw_catch import child from raw_catch import child_raises
from debugger import Debugger from debugger import Debugger
@ -8,7 +8,7 @@ async def main(children: list[tuple[str, int]]):
async with aiosched.with_context() as ctx: async with aiosched.with_context() as ctx:
print("[main] Spawning children") print("[main] Spawning children")
for name, delay in children: for name, delay in children:
await ctx.spawn(child, name, delay) await ctx.spawn(child_raises, name, delay)
print("[main] Children spawned") print("[main] Children spawned")
before = aiosched.clock() before = aiosched.clock()
except BaseException as err: except BaseException as err:

View File

@ -14,4 +14,4 @@ async def main(children: list[tuple[str, int]]):
if __name__ == "__main__": if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=Debugger()) aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)

View File

@ -1,5 +1,5 @@
import aiosched import aiosched
from raw_catch import child as errorer from raw_catch import child_raises
from raw_wait import child as successful from raw_wait import child as successful
from debugger import Debugger from debugger import Debugger
@ -19,7 +19,7 @@ async def main(
async with aiosched.with_context() as ctx2: async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context") print("[main] Spawning children in second context")
for name, delay in children_inner: for name, delay in children_inner:
await ctx2.spawn(errorer, name, delay) await ctx2.spawn(child_raises, name, delay)
print("[main] Children spawned") print("[main] Children spawned")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds") print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
@ -27,7 +27,7 @@ async def main(
if __name__ == "__main__": if __name__ == "__main__":
aiosched.run( aiosched.run(
main, main,
[("first", 1), ("second", 2)], [("first", 1), ("third", 3)],
[("third", 3), ("fourth", 4)], [("second", 2), ("fourth", 4)],
debugger=None, debugger=None,
) )

View File

@ -1,5 +1,5 @@
import aiosched import aiosched
from raw_catch import child from raw_catch import child_raises
from debugger import Debugger from debugger import Debugger
@ -12,12 +12,12 @@ async def main(
before = aiosched.clock() before = aiosched.clock()
print("[main] Spawning children in first context") print("[main] Spawning children in first context")
for name, delay in children_outer: for name, delay in children_outer:
await ctx.spawn(child, name, delay) await ctx.spawn(child_raises, name, delay)
print("[main] Children spawned") print("[main] Children spawned")
async with aiosched.with_context() as ctx2: async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context") print("[main] Spawning children in second context")
for name, delay in children_inner: for name, delay in children_inner:
await ctx2.spawn(child, name, delay) await ctx2.spawn(child_raises, name, delay)
print("[main] Children spawned") print("[main] Children spawned")
except BaseException as err: except BaseException as err:
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}") print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")

View File

@ -2,17 +2,17 @@ import aiosched
from debugger import Debugger from debugger import Debugger
async def child(name: str, n: int): async def child_raises(name: str, n: int):
before = aiosched.clock() before = aiosched.clock()
print(f"[child {name}] Sleeping for {n} seconds") print(f"[child {name}] Sleeping for {n} seconds")
await aiosched.sleep(n) await aiosched.sleep(n)
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds") print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds, raising now!")
raise TypeError("waa") raise TypeError("waa")
async def main(n: int): async def main(n: int):
print("[main] Spawning child") print("[main] Spawning child")
task = await aiosched.spawn(child, "raise", n) task = await aiosched.spawn(child_raises, "raise", n)
print("[main] Waiting for child") print("[main] Waiting for child")
before = aiosched.clock() before = aiosched.clock()
try: try: