Various fixes for exception handling in contexts
This commit is contained in:
parent
ce1583e9c2
commit
ff2acf298f
|
@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
|
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
|
||||||
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
|
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join
|
||||||
import aiosched.task
|
import aiosched.task
|
||||||
import aiosched.errors
|
import aiosched.errors
|
||||||
import aiosched.context
|
import aiosched.context
|
||||||
|
|
|
@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from aiosched.task import Task
|
from aiosched.task import Task
|
||||||
from aiosched.errors import Cancelled
|
from aiosched.errors import Cancelled
|
||||||
from aiosched.internals.syscalls import (
|
from aiosched.internals.syscalls import (
|
||||||
|
@ -57,11 +59,9 @@ class TaskContext(Task):
|
||||||
# Do we gather multiple exceptions from
|
# Do we gather multiple exceptions from
|
||||||
# children tasks?
|
# children tasks?
|
||||||
self.gather: bool = gather
|
self.gather: bool = gather
|
||||||
# Do we wrap any other task contexts?
|
|
||||||
self.inner: TaskContext | None = None
|
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||||
) -> Task:
|
) -> Task:
|
||||||
"""
|
"""
|
||||||
Spawns a child task
|
Spawns a child task
|
||||||
|
@ -69,8 +69,8 @@ class TaskContext(Task):
|
||||||
|
|
||||||
task = await spawn(func, *args, **kwargs)
|
task = await spawn(func, *args, **kwargs)
|
||||||
task.context = self
|
task.context = self
|
||||||
await join(task)
|
|
||||||
self.tasks.append(task)
|
self.tasks.append(task)
|
||||||
|
await join(task)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
@ -101,14 +101,9 @@ class TaskContext(Task):
|
||||||
self.exc = exc
|
self.exc = exc
|
||||||
if not self.silent:
|
if not self.silent:
|
||||||
raise self.exc
|
raise self.exc
|
||||||
if self.inner:
|
finally:
|
||||||
for task in self.inner.tasks:
|
await close_context(self)
|
||||||
try:
|
self.entry_point.propagate = True
|
||||||
await wait(task)
|
|
||||||
except BaseException:
|
|
||||||
await self.inner.cancel(False)
|
|
||||||
self.inner.propagate = False
|
|
||||||
await close_context()
|
|
||||||
|
|
||||||
# Task method wrappers
|
# Task method wrappers
|
||||||
|
|
||||||
|
@ -119,21 +114,22 @@ class TaskContext(Task):
|
||||||
and cancelling them
|
and cancelling them
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.inner:
|
|
||||||
await self.inner.cancel(propagate)
|
|
||||||
for task in self.tasks:
|
for task in self.tasks:
|
||||||
if task is self.entry_point:
|
if task is self.entry_point:
|
||||||
continue
|
continue
|
||||||
await cancel(task)
|
if isinstance(task, Task):
|
||||||
|
await cancel(task)
|
||||||
|
else:
|
||||||
|
task: TaskContext
|
||||||
|
await task.cancel(propagate)
|
||||||
self.cancelled = True
|
self.cancelled = True
|
||||||
await close_context()
|
|
||||||
self.propagate = False
|
self.propagate = False
|
||||||
if propagate:
|
if propagate:
|
||||||
if isinstance(self.entry_point, TaskContext):
|
if isinstance(self.entry_point, Task):
|
||||||
self.entry_point: TaskContext
|
|
||||||
await self.entry_point.cancel()
|
|
||||||
else:
|
|
||||||
await cancel(self.entry_point)
|
await cancel(self.entry_point)
|
||||||
|
else:
|
||||||
|
self.entry_point: TaskContext
|
||||||
|
await self.entry_point.cancel(propagate)
|
||||||
|
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -142,17 +138,9 @@ class TaskContext(Task):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for task in self.tasks:
|
for task in self.tasks:
|
||||||
if task is self.entry_point:
|
|
||||||
continue
|
|
||||||
if not task.done():
|
if not task.done():
|
||||||
return False
|
return False
|
||||||
if (
|
return True
|
||||||
not isinstance(self.entry_point, TaskContext)
|
|
||||||
and not self.entry_point.done()
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
if self.inner:
|
|
||||||
return self.inner.done()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> int:
|
def state(self) -> int:
|
||||||
|
@ -206,6 +194,10 @@ class TaskContext(Task):
|
||||||
def joiners(self, joiners: set[Task]):
|
def joiners(self, joiners: set[Task]):
|
||||||
self.entry_point.joiners = joiners
|
self.entry_point.joiners = joiners
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coroutine(self):
|
||||||
|
return self.entry_point.coroutine
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return self.entry_point.__hash__()
|
return self.entry_point.__hash__()
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,8 @@ async def wait(task: Task) -> Any | None:
|
||||||
multiple times by multiple tasks.
|
multiple times by multiple tasks.
|
||||||
Returns immediately if the task has
|
Returns immediately if the task has
|
||||||
completed already, but exceptions are
|
completed already, but exceptions are
|
||||||
propagated only once
|
propagated only once. Returns the task's
|
||||||
|
return value, if it has one
|
||||||
|
|
||||||
:param task: The task to wait for
|
:param task: The task to wait for
|
||||||
:type task: :class: Task
|
:type task: :class: Task
|
||||||
|
@ -149,6 +150,10 @@ async def wait(task: Task) -> Any | None:
|
||||||
current = await current_task()
|
current = await current_task()
|
||||||
if task is current:
|
if task is current:
|
||||||
raise SchedulerError("a task cannot join itself")
|
raise SchedulerError("a task cannot join itself")
|
||||||
|
if current not in task.joiners:
|
||||||
|
# Luckily we use a set, so this has O(1)
|
||||||
|
# complexity
|
||||||
|
await join(task) # Waiting implies joining!
|
||||||
await syscall("wait", task)
|
await syscall("wait", task)
|
||||||
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
|
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
|
||||||
task.propagate = False
|
task.propagate = False
|
||||||
|
@ -226,9 +231,9 @@ async def set_context(ctx):
|
||||||
await syscall("set_context", ctx)
|
await syscall("set_context", ctx)
|
||||||
|
|
||||||
|
|
||||||
async def close_context():
|
async def close_context(ctx):
|
||||||
"""
|
"""
|
||||||
Closes the current task context
|
Closes the current task context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await syscall("close_context")
|
await syscall("close_context", ctx)
|
||||||
|
|
|
@ -233,6 +233,10 @@ class FIFOKernel:
|
||||||
if not self.run_ready:
|
if not self.run_ready:
|
||||||
return # No more tasks to run!
|
return # No more tasks to run!
|
||||||
self.current_task = self.run_ready.popleft()
|
self.current_task = self.run_ready.popleft()
|
||||||
|
# We nullify the exception object just in case the
|
||||||
|
# entry point raised and caught an error so that
|
||||||
|
# self.start() doesn't raise it again at the end
|
||||||
|
self.current_task.exc = None
|
||||||
self.debugger.before_task_step(self.current_task)
|
self.debugger.before_task_step(self.current_task)
|
||||||
# Some debugging and internal chatter here
|
# Some debugging and internal chatter here
|
||||||
self.current_task.state = TaskState.RUN
|
self.current_task.state = TaskState.RUN
|
||||||
|
@ -477,25 +481,24 @@ class FIFOKernel:
|
||||||
Sets the current task context. This is
|
Sets the current task context. This is
|
||||||
implemented as simply wrapping the current
|
implemented as simply wrapping the current
|
||||||
task inside the context and replacing the
|
task inside the context and replacing the
|
||||||
Task object with the TaskContext one
|
Task object with the TaskContext one. This
|
||||||
|
may also wrap another task context into a
|
||||||
|
new one, but the loop doesn't need to care
|
||||||
|
about that: the API is designed exactly for
|
||||||
|
this
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctx.entry_point = self.current_task
|
ctx.entry_point = self.current_task
|
||||||
if isinstance(self.current_task, TaskContext):
|
ctx.tasks.append(ctx.entry_point)
|
||||||
self.current_task.inner = ctx
|
self.current_task.context = ctx
|
||||||
else:
|
|
||||||
ctx.tasks.append(ctx.entry_point)
|
|
||||||
self.current_task.context = ctx
|
|
||||||
self.current_task = ctx
|
self.current_task = ctx
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def close_context(self):
|
def close_context(self, ctx: TaskContext):
|
||||||
"""
|
"""
|
||||||
Closes the context associated with the current
|
Closes the given context
|
||||||
task
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ctx: TaskContext = self.current_task
|
|
||||||
task = ctx.entry_point
|
task = ctx.entry_point
|
||||||
task.context = None
|
task.context = None
|
||||||
self.current_task = task
|
self.current_task = task
|
||||||
|
|
|
@ -20,7 +20,6 @@ from enum import Enum, auto
|
||||||
from typing import Coroutine, Any
|
from typing import Coroutine, Any
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
class TaskState(Enum):
|
class TaskState(Enum):
|
||||||
"""
|
"""
|
||||||
An enumeration of task states
|
An enumeration of task states
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from raw_catch import child
|
from raw_catch import child_raises
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ async def main(children: list[tuple[str, int]]):
|
||||||
async with aiosched.with_context() as ctx:
|
async with aiosched.with_context() as ctx:
|
||||||
print("[main] Spawning children")
|
print("[main] Spawning children")
|
||||||
for name, delay in children:
|
for name, delay in children:
|
||||||
await ctx.spawn(child, name, delay)
|
await ctx.spawn(child_raises, name, delay)
|
||||||
print("[main] Children spawned")
|
print("[main] Children spawned")
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
|
|
|
@ -14,4 +14,4 @@ async def main(children: list[tuple[str, int]]):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=Debugger())
|
aiosched.run(main, [("first", 1), ("second", 2), ("third", 3)], debugger=None)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from raw_catch import child as errorer
|
from raw_catch import child_raises
|
||||||
from raw_wait import child as successful
|
from raw_wait import child as successful
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ async def main(
|
||||||
async with aiosched.with_context() as ctx2:
|
async with aiosched.with_context() as ctx2:
|
||||||
print("[main] Spawning children in second context")
|
print("[main] Spawning children in second context")
|
||||||
for name, delay in children_inner:
|
for name, delay in children_inner:
|
||||||
await ctx2.spawn(errorer, name, delay)
|
await ctx2.spawn(child_raises, name, delay)
|
||||||
print("[main] Children spawned")
|
print("[main] Children spawned")
|
||||||
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
print(f"[main] Children exited in {aiosched.clock() - before:.2f} seconds")
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ async def main(
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
aiosched.run(
|
aiosched.run(
|
||||||
main,
|
main,
|
||||||
[("first", 1), ("second", 2)],
|
[("first", 1), ("third", 3)],
|
||||||
[("third", 3), ("fourth", 4)],
|
[("second", 2), ("fourth", 4)],
|
||||||
debugger=None,
|
debugger=None,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from raw_catch import child
|
from raw_catch import child_raises
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,12 +12,12 @@ async def main(
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
print("[main] Spawning children in first context")
|
print("[main] Spawning children in first context")
|
||||||
for name, delay in children_outer:
|
for name, delay in children_outer:
|
||||||
await ctx.spawn(child, name, delay)
|
await ctx.spawn(child_raises, name, delay)
|
||||||
print("[main] Children spawned")
|
print("[main] Children spawned")
|
||||||
async with aiosched.with_context() as ctx2:
|
async with aiosched.with_context() as ctx2:
|
||||||
print("[main] Spawning children in second context")
|
print("[main] Spawning children in second context")
|
||||||
for name, delay in children_inner:
|
for name, delay in children_inner:
|
||||||
await ctx2.spawn(child, name, delay)
|
await ctx2.spawn(child_raises, name, delay)
|
||||||
print("[main] Children spawned")
|
print("[main] Children spawned")
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
print(f"[main] Child raised an exception -> {type(err).__name__}: {err}")
|
||||||
|
|
|
@ -2,17 +2,17 @@ import aiosched
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
async def child(name: str, n: int):
|
async def child_raises(name: str, n: int):
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
print(f"[child {name}] Sleeping for {n} seconds")
|
print(f"[child {name}] Sleeping for {n} seconds")
|
||||||
await aiosched.sleep(n)
|
await aiosched.sleep(n)
|
||||||
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds")
|
print(f"[child {name}] Done! Slept for {aiosched.clock() - before:.2f} seconds, raising now!")
|
||||||
raise TypeError("waa")
|
raise TypeError("waa")
|
||||||
|
|
||||||
|
|
||||||
async def main(n: int):
|
async def main(n: int):
|
||||||
print("[main] Spawning child")
|
print("[main] Spawning child")
|
||||||
task = await aiosched.spawn(child, "raise", n)
|
task = await aiosched.spawn(child_raises, "raise", n)
|
||||||
print("[main] Waiting for child")
|
print("[main] Waiting for child")
|
||||||
before = aiosched.clock()
|
before = aiosched.clock()
|
||||||
try:
|
try:
|
||||||
|
|
Reference in New Issue