Various fixes for exception handling in contexts
This commit is contained in:
parent
ce1583e9c2
commit
ff2acf298f
|
@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
|
||||
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
|
||||
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join
|
||||
import aiosched.task
|
||||
import aiosched.errors
|
||||
import aiosched.context
|
||||
|
|
|
@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from aiosched.task import Task
|
||||
from aiosched.errors import Cancelled
|
||||
from aiosched.internals.syscalls import (
|
||||
|
@ -57,11 +59,9 @@ class TaskContext(Task):
|
|||
# Do we gather multiple exceptions from
|
||||
# children tasks?
|
||||
self.gather: bool = gather
|
||||
# Do we wrap any other task contexts?
|
||||
self.inner: TaskContext | None = None
|
||||
|
||||
async def spawn(
|
||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||
) -> Task:
|
||||
"""
|
||||
Spawns a child task
|
||||
|
@ -69,8 +69,8 @@ class TaskContext(Task):
|
|||
|
||||
task = await spawn(func, *args, **kwargs)
|
||||
task.context = self
|
||||
await join(task)
|
||||
self.tasks.append(task)
|
||||
await join(task)
|
||||
return task
|
||||
|
||||
async def __aenter__(self):
|
||||
|
@ -101,14 +101,9 @@ class TaskContext(Task):
|
|||
self.exc = exc
|
||||
if not self.silent:
|
||||
raise self.exc
|
||||
if self.inner:
|
||||
for task in self.inner.tasks:
|
||||
try:
|
||||
await wait(task)
|
||||
except BaseException:
|
||||
await self.inner.cancel(False)
|
||||
self.inner.propagate = False
|
||||
await close_context()
|
||||
finally:
|
||||
await close_context(self)
|
||||
self.entry_point.propagate = True
|
||||
|
||||
# Task method wrappers
|
||||
|
||||
|
@ -119,21 +114,22 @@ class TaskContext(Task):
|
|||
and cancelling them
|
||||
"""
|
||||
|
||||
if self.inner:
|
||||
await self.inner.cancel(propagate)
|
||||
for task in self.tasks:
|
||||
if task is self.entry_point:
|
||||
continue
|
||||
await cancel(task)
|
||||
if isinstance(task, Task):
|
||||
await cancel(task)
|
||||
else:
|
||||
task: TaskContext
|
||||
await task.cancel(propagate)
|
||||
self.cancelled = True
|
||||
await close_context()
|
||||
self.propagate = False
|
||||
if propagate:
|
||||
if isinstance(self.entry_point, TaskContext):
|
||||
self.entry_point: TaskContext
|
||||
await self.entry_point.cancel()
|
||||
else:
|
||||
if isinstance(self.entry_point, Task):
|
||||
await cancel(self.entry_point)
|
||||
else:
|
||||
self.entry_point: TaskContext
|
||||
await self.entry_point.cancel(propagate)
|
||||
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
|
@ -142,17 +138,9 @@ class TaskContext(Task):
|
|||
"""
|
||||
|
||||
for task in self.tasks:
|
||||
if task is self.entry_point:
|
||||
continue
|
||||
if not task.done():
|
||||
return False
|
||||
if (
|
||||
not isinstance(self.entry_point, TaskContext)
|
||||
and not self.entry_point.done()
|
||||
):
|
||||
return False
|
||||
if self.inner:
|
||||
return self.inner.done()
|
||||
return True
|
||||
|
||||
@property
|
||||
def state(self) -> int:
|
||||
|
@ -206,6 +194,10 @@ class TaskContext(Task):
|
|||
def joiners(self, joiners: set[Task]):
|
||||
self.entry_point.joiners = joiners
|
||||
|
||||
@property
|
||||
def coroutine(self):
|
||||
return self.entry_point.coroutine
|
||||
|
||||
def __hash__(self):
|
||||
return self.entry_point.__hash__()
|
||||
|
||||
|
|
|
@ -139,7 +139,8 @@ async def wait(task: Task) -> Any | None:
|
|||
multiple times by multiple tasks.
|
||||
Returns immediately if the task has
|
||||
completed already, but exceptions are
|
||||
propagated only once
|
||||
propagated only once. Returns the task's
|
||||
return value, if it has one
|
||||
|
||||
:param task: The task to wait for
|
||||
:type task: :class: Task
|
||||
|
@ -149,6 +150,10 @@ async def wait(task: Task) -> Any | None:
|
|||
current = await current_task()
|
||||
if task is current:
|
||||
raise SchedulerError("a task cannot join itself")
|
||||
if current not in task.joiners:
|
||||
# Luckily we use a set, so this has O(1)
|
||||
# complexity
|
||||
await join(task) # Waiting implies joining!
|
||||
await syscall("wait", task)
|
||||
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
|
||||
task.propagate = False
|
||||
|
@ -226,9 +231,9 @@ async def set_context(ctx):
|
|||
await syscall("set_context", ctx)
|
||||
|
||||
|
||||
async def close_context():
|
||||
async def close_context(ctx):
|
||||
"""
|
||||
Closes the current task context
|
||||
"""
|
||||
|
||||
await syscall("close_context")
|
||||
await syscall("close_context", ctx)
|
||||
|
|
|
@ -233,6 +233,10 @@ class FIFOKernel:
|
|||
if not self.run_ready:
|
||||
return # No more tasks to run!
|
||||
self.current_task = self.run_ready.popleft()
|
||||
# We nullify the exception object just in case the
|
||||
# entry point raised and caught an error so that
|
||||
# self.start() doesn't raise it again at the end
|
||||
self.current_task.exc = None
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
# Some debugging and internal chatter here
|
||||
self.current_task.state = TaskState.RUN
|
||||
|
@ -477,25 +481,24 @@ class FIFOKernel:
|
|||
Sets the current task context. This is
|
||||
implemented as simply wrapping the current
|
||||
task inside the context and replacing the
|
||||
Task object with the TaskContext one
|
||||
Task object with the TaskContext one. This
|
||||
may also wrap another task context into a
|
||||
new one, but the loop doesn't need to care
|
||||
about that: the API is designed exactly for
|
||||
this
|
||||
"""
|
||||
|
||||
ctx.entry_point = self.current_task
|
||||
if isinstance(self.current_task, TaskContext):
|
||||
self.current_task.inner = ctx
|
||||
else:
|
||||
ctx.tasks.append(ctx.entry_point)
|
||||
self.current_task.context = ctx
|
||||
ctx.tasks.append(ctx.entry_point)
|
||||
self.current_task.context = ctx
|
||||
self.current_task = ctx
|
||||
self.reschedule_running()
|
||||
|
||||
def close_context(self):
|
||||
def close_context(self, ctx: TaskContext):
|
||||
"""
|
||||
Closes the context associated with the current
|
||||
task
|
||||
Closes the given context
|
||||
"""
|
||||
|
||||
ctx: TaskContext = self.current_task
|
||||
task = ctx.entry_point
|
||||
task.context = None
|
||||
self.current_task = task
|
||||
|
|
|
@ -20,7 +20,6 @@ from enum import Enum, auto
|
|||
from typing import Coroutine, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
class TaskState(Enum):
|
||||
"""
|
||||
An enumeration of task states
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from raw_catch import child
|
||||
from raw_catch import child_raises
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ async def main(children: list[tuple[str, int]]):
|
|||
async with aiosched.with_context() as ctx:
|
||||
print("[main] Spawning children")
|
||||
for name, delay in children:
|
||||
await ctx.spawn(child, name, delay)
|
||||
await ctx.spawn(child_raises, name, delay)
|
||||
print("[main] Children spawned")
|
||||
before = aiosched.clock()
|
||||
except BaseException as err:
|
||||
|
|
|
@ -14,4 +14,4 @@ async def main(children: list[tuple[str, int]]):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=Debugger())
|
||||
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from raw_catch import child as errorer
|
||||
from raw_catch import child_raises
|
||||
from raw_wait import child as successful
|
||||
from debugger import Debugger
|
||||
|
||||
|
@ -19,7 +19,7 @@ async def main(
|
|||
async with aiosched.with_context() as ctx2:
|
||||
print("[main] Spawning children in second context")
|
||||
for name, delay in children_inner:
|
||||
await ctx2.spawn(errorer, name, delay)
|
||||
await ctx2.spawn(child_raises, name, delay)
|
||||
print("[main] Children spawned")
|
||||
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||
|
||||
|
@ -27,7 +27,7 @@ async def main(
|
|||
if __name__ == "__main__":
|
||||
aiosched.run(
|
||||
main,
|
||||
[("first", 1), ("second", 2)],
|
||||
[("third", 3), ("fourth", 4)],
|
||||
[("first", 1), ("third", 3)],
|
||||
[("second", 2), ("fourth", 4)],
|
||||
debugger=None,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from raw_catch import child
|
||||
from raw_catch import child_raises
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
@ -12,12 +12,12 @@ async def main(
|
|||
before = aiosched.clock()
|
||||
print("[main] Spawning children in first context")
|
||||
for name, delay in children_outer:
|
||||
await ctx.spawn(child, name, delay)
|
||||
await ctx.spawn(child_raises, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with aiosched.with_context() as ctx2:
|
||||
print("[main] Spawning children in second context")
|
||||
for name, delay in children_inner:
|
||||
await ctx2.spawn(child, name, delay)
|
||||
await ctx2.spawn(child_raises, name, delay)
|
||||
print("[main] Children spawned")
|
||||
except BaseException as err:
|
||||
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
||||
|
|
|
@ -2,17 +2,17 @@ import aiosched
|
|||
from debugger import Debugger
|
||||
|
||||
|
||||
async def child(name: str, n: int):
|
||||
async def child_raises(name: str, n: int):
|
||||
before = aiosched.clock()
|
||||
print(f"[child {name}] Sleeping for {n} seconds")
|
||||
await aiosched.sleep(n)
|
||||
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds")
|
||||
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds, raising now!")
|
||||
raise TypeError("waa")
|
||||
|
||||
|
||||
async def main(n: int):
|
||||
print("[main] Spawning child")
|
||||
task = await aiosched.spawn(child, "raise", n)
|
||||
task = await aiosched.spawn(child_raises, "raise", n)
|
||||
print("[main] Waiting for child")
|
||||
before = aiosched.clock()
|
||||
try:
|
||||
|
|
Reference in New Issue