Various fixes for exception handling in contexts

This commit is contained in:
Nocturn9x 2022-11-02 09:28:04 +01:00
parent ce1583e9c2
commit ff2acf298f
10 changed files with 56 additions and 57 deletions

View File

@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join
import aiosched.task
import aiosched.errors
import aiosched.context

View File

@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from functools import partial
from aiosched.task import Task
from aiosched.errors import Cancelled
from aiosched.internals.syscalls import (
@ -57,11 +59,9 @@ class TaskContext(Task):
# Do we gather multiple exceptions from
# children tasks?
self.gather: bool = gather
# Do we wrap any other task contexts?
self.inner: TaskContext | None = None
async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
) -> Task:
"""
Spawns a child task
@ -69,8 +69,8 @@ class TaskContext(Task):
task = await spawn(func, *args, **kwargs)
task.context = self
await join(task)
self.tasks.append(task)
await join(task)
return task
async def __aenter__(self):
@ -101,14 +101,9 @@ class TaskContext(Task):
self.exc = exc
if not self.silent:
raise self.exc
if self.inner:
for task in self.inner.tasks:
try:
await wait(task)
except BaseException:
await self.inner.cancel(False)
self.inner.propagate = False
await close_context()
finally:
await close_context(self)
self.entry_point.propagate = True
# Task method wrappers
@ -119,21 +114,22 @@ class TaskContext(Task):
and cancelling them
"""
if self.inner:
await self.inner.cancel(propagate)
for task in self.tasks:
if task is self.entry_point:
continue
await cancel(task)
if isinstance(task, Task):
await cancel(task)
else:
task: TaskContext
await task.cancel(propagate)
self.cancelled = True
await close_context()
self.propagate = False
if propagate:
if isinstance(self.entry_point, TaskContext):
self.entry_point: TaskContext
await self.entry_point.cancel()
else:
if isinstance(self.entry_point, Task):
await cancel(self.entry_point)
else:
self.entry_point: TaskContext
await self.entry_point.cancel(propagate)
def done(self) -> bool:
"""
@ -142,17 +138,9 @@ class TaskContext(Task):
"""
for task in self.tasks:
if task is self.entry_point:
continue
if not task.done():
return False
if (
not isinstance(self.entry_point, TaskContext)
and not self.entry_point.done()
):
return False
if self.inner:
return self.inner.done()
return True
@property
def state(self) -> int:
@ -206,6 +194,10 @@ class TaskContext(Task):
def joiners(self, joiners: set[Task]):
self.entry_point.joiners = joiners
@property
def coroutine(self):
return self.entry_point.coroutine
def __hash__(self):
return self.entry_point.__hash__()

View File

@ -139,7 +139,8 @@ async def wait(task: Task) -> Any | None:
multiple times by multiple tasks.
Returns immediately if the task has
completed already, but exceptions are
propagated only once
propagated only once. Returns the task's
return value, if it has one
:param task: The task to wait for
:type task: :class: Task
@ -149,6 +150,10 @@ async def wait(task: Task) -> Any | None:
current = await current_task()
if task is current:
raise SchedulerError("a task cannot join itself")
if current not in task.joiners:
# Luckily we use a set, so this has O(1)
# complexity
await join(task) # Waiting implies joining!
await syscall("wait", task)
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
task.propagate = False
@ -226,9 +231,9 @@ async def set_context(ctx):
await syscall("set_context", ctx)
async def close_context():
async def close_context(ctx):
"""
Closes the current task context
"""
await syscall("close_context")
await syscall("close_context", ctx)

View File

@ -233,6 +233,10 @@ class FIFOKernel:
if not self.run_ready:
return # No more tasks to run!
self.current_task = self.run_ready.popleft()
# We nullify the exception object just in case the
# entry point raised and caught an error so that
# self.start() doesn't raise it again at the end
self.current_task.exc = None
self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here
self.current_task.state = TaskState.RUN
@ -477,25 +481,24 @@ class FIFOKernel:
Sets the current task context. This is
implemented as simply wrapping the current
task inside the context and replacing the
Task object with the TaskContext one
Task object with the TaskContext one. This
may also wrap another task context into a
new one, but the loop doesn't need to care
about that: the API is designed exactly for
this
"""
ctx.entry_point = self.current_task
if isinstance(self.current_task, TaskContext):
self.current_task.inner = ctx
else:
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
self.current_task = ctx
self.reschedule_running()
def close_context(self):
def close_context(self, ctx: TaskContext):
"""
Closes the context associated with the current
task
Closes the given context
"""
ctx: TaskContext = self.current_task
task = ctx.entry_point
task.context = None
self.current_task = task

View File

@ -20,7 +20,6 @@ from enum import Enum, auto
from typing import Coroutine, Any
from dataclasses import dataclass, field
class TaskState(Enum):
"""
An enumeration of task states

View File

@ -1,5 +1,5 @@
import aiosched
from raw_catch import child
from raw_catch import child_raises
from debugger import Debugger
@ -8,7 +8,7 @@ async def main(children: list[tuple[str, int]]):
async with aiosched.with_context() as ctx:
print("[main] Spawning children")
for name, delay in children:
await ctx.spawn(child, name, delay)
await ctx.spawn(child_raises, name, delay)
print("[main] Children spawned")
before = aiosched.clock()
except BaseException as err:

View File

@ -14,4 +14,4 @@ async def main(children: list[tuple[str, int]]):
if __name__ == "__main__":
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=Debugger())
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)

View File

@ -1,5 +1,5 @@
import aiosched
from raw_catch import child as errorer
from raw_catch import child_raises
from raw_wait import child as successful
from debugger import Debugger
@ -19,7 +19,7 @@ async def main(
async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context")
for name, delay in children_inner:
await ctx2.spawn(errorer, name, delay)
await ctx2.spawn(child_raises, name, delay)
print("[main] Children spawned")
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
@ -27,7 +27,7 @@ async def main(
if __name__ == "__main__":
aiosched.run(
main,
[("first", 1), ("second", 2)],
[("third", 3), ("fourth", 4)],
[("first", 1), ("third", 3)],
[("second", 2), ("fourth", 4)],
debugger=None,
)

View File

@ -1,5 +1,5 @@
import aiosched
from raw_catch import child
from raw_catch import child_raises
from debugger import Debugger
@ -12,12 +12,12 @@ async def main(
before = aiosched.clock()
print("[main] Spawning children in first context")
for name, delay in children_outer:
await ctx.spawn(child, name, delay)
await ctx.spawn(child_raises, name, delay)
print("[main] Children spawned")
async with aiosched.with_context() as ctx2:
print("[main] Spawning children in second context")
for name, delay in children_inner:
await ctx2.spawn(child, name, delay)
await ctx2.spawn(child_raises, name, delay)
print("[main] Children spawned")
except BaseException as err:
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")

View File

@ -2,17 +2,17 @@ import aiosched
from debugger import Debugger
async def child(name: str, n: int):
async def child_raises(name: str, n: int):
before = aiosched.clock()
print(f"[child {name}] Sleeping for {n} seconds")
await aiosched.sleep(n)
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds")
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds, raising now!")
raise TypeError("waa")
async def main(n: int):
print("[main] Spawning child")
task = await aiosched.spawn(child, "raise", n)
task = await aiosched.spawn(child_raises, "raise", n)
print("[main] Waiting for child")
before = aiosched.clock()
try: