Added many fixes for exception propagation and SIGINT handling

This commit is contained in:
Nocturn9x 2023-04-22 12:26:37 +02:00
parent e730f7f27a
commit 509b555628
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
24 changed files with 172 additions and 258 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
"""
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint, join
import aiosched.util
import aiosched.task
import aiosched.errors
import aiosched.context
@ -41,4 +42,5 @@ __all__ = [
"checkpoint",
"NetworkChannel",
"socket",
"util"
]

View File

@ -15,29 +15,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from aiosched.task import Task, TaskState
from aiosched.task import Task
from aiosched.internals.syscalls import (
spawn,
wait,
cancel,
set_context,
close_context,
join
join,
current_task
)
from typing import Any, Coroutine, Callable
class TaskContext(Task):
class TaskContext:
"""
An asynchronous context manager that automatically waits
for all tasks spawned within it and cancels itself when
an exception occurs. A TaskContext object behaves like
a regular task and the event loop treats it like a single
unit rather than a collection of tasks (in fact, the event
loop doesn't even know, nor care about, whether the current
task is a task context or not, which is by design). Contexts
can be nested and will cancel inner ones if an exception is
raised inside them
an exception occurs. Contexts can be nested and will
cancel inner ones if an exception is raised inside them
"""
def __init__(self, silent: bool = False, gather: bool = True, timeout: int | float = 0.0) -> None:
@ -59,6 +53,8 @@ class TaskContext(Task):
# For how long do we allow tasks inside us
# to run?
self.timeout: int | float = timeout # TODO: Implement
# Have we crashed?
self.error: BaseException | None = None
async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
@ -78,7 +74,7 @@ class TaskContext(Task):
Implements the asynchronous context manager interface
"""
await set_context(self)
self.entry_point = await current_task()
return self
def __eq__(self, other):
@ -114,12 +110,13 @@ class TaskContext(Task):
await wait(task)
except BaseException as exc:
await self.cancel(False)
self.exc = exc
self.error = exc
finally:
await close_context(self)
self.entry_point.propagate = True
if self.exc and not self.silent:
raise self.exc
if self.silent:
return
if self.entry_point.exc:
raise self.entry_point.exc
# Task method wrappers
@ -139,7 +136,6 @@ class TaskContext(Task):
task: TaskContext
await task.cancel(propagate)
self.cancelled = True
self.propagate = False
if propagate:
if isinstance(self.entry_point, Task):
await cancel(self.entry_point)
@ -158,66 +154,6 @@ class TaskContext(Task):
return False
return True
@property
def state(self) -> int:
return self.entry_point.state
@state.setter
def state(self, state: int):
self.entry_point.state = state
@property
def result(self) -> Any:
return self.entry_point.result
@result.setter
def result(self, result: Any):
self.entry_point.result = result
@property
def exc(self) -> BaseException:
return self.entry_point.exc
@exc.setter
def exc(self, exc: BaseException):
self.entry_point.exc = exc
@property
def propagate(self) -> bool:
return self.entry_point.propagate
@propagate.setter
def propagate(self, val: bool):
self.entry_point.propagate = val
@property
def name(self):
return self.entry_point.name
def throw(self, err: BaseException):
for task in self.tasks:
if task is self.entry_point:
continue
try:
task.exc = err
task.state = TaskState.CRASHED
task.throw(err)
finally:
continue
self.entry_point.throw(err)
@property
def joiners(self) -> set[Task]:
return self.entry_point.joiners
@joiners.setter
def joiners(self, joiners: set[Task]):
self.entry_point.joiners = joiners
@property
def coroutine(self):
return self.entry_point.coroutine
def __hash__(self):
return self.entry_point.__hash__()
@ -239,10 +175,7 @@ class TaskContext(Task):
result = "TaskContext(["
for i, task in enumerate(self.tasks):
if task is self.entry_point:
result += repr(self.entry_point)
else:
result += repr(task)
result += repr(task)
if i < len(self.tasks) - 1:
result += ", "
result += "])"

View File

@ -149,16 +149,11 @@ async def wait(task: Task) -> Any | None:
:returns: The task's return value, if any
"""
current = await current_task()
if task == current:
if task == await current_task():
# We don't do an "x is y" check because
# tasks and task contexts can compare equal
# despite having different memory addresses
raise SchedulerError("a task cannot join itself")
if current not in task.joiners:
# Luckily we use a set, so this has O(1)
# complexity on average
await join(task) # Waiting implies joining!
await syscall("wait", task)
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
# The task raised an error that wasn't directly caused by a cancellation:
@ -228,19 +223,3 @@ async def io_release(stream):
"""
await syscall("io_release", stream)
async def set_context(ctx):
"""
Sets the current task context
"""
await syscall("set_context", ctx)
async def close_context(ctx):
"""
Closes the current task context
"""
await syscall("close_context", ctx)

View File

@ -32,7 +32,6 @@ from aiosched.errors import (
ResourceClosed,
ResourceBroken,
)
from aiosched.context import TaskContext
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
@ -121,6 +120,11 @@ class FIFOKernel:
"""
self._sigint_handled = True
# We reschedule the current task
# immediately no matter what it's
# doing so that we process the
# exception immediately
self.reschedule_running()
def done(self) -> bool:
"""
@ -280,23 +284,23 @@ class FIFOKernel:
# there are no more runnable tasks
return
self.current_task = self.run_ready.popleft()
# We nullify the exception object just in case the
# entry point raised and caught an error so that
# self.start() doesn't raise it again at the end
self.current_task.exc = None
self._running = True
# Some debugging and internal chatter here
self.current_task.state = TaskState.RUN
self.current_task.steps += 1
if self._sigint_handled:
self._sigint_handled = False
self.reschedule_running()
self.current_task.throw(KeyboardInterrupt())
self.join(self.current_task)
elif self.current_task.pending_cancellation:
# We perform the deferred cancellation
# if it was previously scheduled
self.cancel(self.current_task)
elif exc := self.current_task.pending_exception:
self.current_task.pending_exception = None
self.reschedule_running()
self.current_task.throw(exc)
else:
# Some debugging and internal chatter here
self.current_task.steps += 1
self.current_task.state = TaskState.RUN
self.debugger.before_task_step(self.current_task)
# Run a single step with the calculation (i.e. until a yield
# somewhere)
@ -351,6 +355,8 @@ class FIFOKernel:
task = next(iter(next(iter(self.selector.get_map().values())).data.values()))
elif self.paused:
task, *_ = self.paused.get()
else:
task = self.current_task
self.run_ready.append(task)
self.handle_errors(self.run_task_step)
elif not self.run_ready:
@ -382,7 +388,6 @@ class FIFOKernel:
try:
self.run()
finally:
self.debugger.on_exit()
signal.signal(signal.SIGINT, old)
if (
self.entry_point.exc
@ -393,6 +398,7 @@ class FIFOKernel:
# no need to raise it manually. If a context
# is not used, *then* we can raise the error
raise self.entry_point.exc
self.debugger.on_exit()
return self.entry_point.result
def io_release(self, resource):
@ -415,9 +421,12 @@ class FIFOKernel:
for each I/O resource the given task owns
"""
for key in filter(
lambda k: task in k.data.values(), dict(self.selector.get_map()).values()
):
for key in dict(self.selector.get_map()).values():
if task not in key.data.values():
continue
if len(key.data.values()) == 2:
if key.data.values()[0] != task or key.data.values[1] != task:
continue
self.notify_closing(key.fileobj, broken=True)
self.selector.unregister(key.fileobj)
task.last_io = ()
@ -457,8 +466,7 @@ class FIFOKernel:
if task is not self.current_task:
# We don't want to raise an error inside
# the task that's trying to close the stream!
for t in k.data:
self.handle_errors(partial(t.throw, exc), k.data)
self.handle_errors(partial(task.throw, exc), task)
self.reschedule_running()
def cancel(self, task: Task):
@ -471,8 +479,9 @@ class FIFOKernel:
self.handle_errors(partial(task.throw, Cancelled(task)), task)
if task.state != TaskState.CANCELLED:
task.pending_cancellation = True
self.io_release_task(task)
self.paused.discard(task)
else:
self.io_release_task(task)
self.paused.discard(task)
self.reschedule_running()
def handle_errors(self, func: Callable, task: Task | None = None):
@ -520,8 +529,6 @@ class FIFOKernel:
task.state = TaskState.CRASHED
self.debugger.on_exception_raised(task, err)
self.wait(task)
if isinstance(err, KeyboardInterrupt):
raise
def sleep(self, seconds: int | float):
"""
@ -546,16 +553,21 @@ class FIFOKernel:
executing
"""
if task != self.current_task:
task.joiners.add(self.current_task)
if task.done():
self.paused.discard(task)
self.io_release_task(task)
self.run_ready.extend(task.joiners)
for joiner in task.joiners:
joiner.pending_exception = task.exc
def join(self, task: Task):
"""
Tells the event loop that the current task
wants to wait on the given one, but without
actually waiting for its completion
actually waiting for its completion. This is
an internal method and should not be used outside
the kernel machinery
"""
task.joiners.add(self.current_task)
@ -573,36 +585,6 @@ class FIFOKernel:
self.reschedule_running()
self.debugger.on_task_spawn(task)
def set_context(self, ctx: TaskContext):
"""
Sets the current task context. This is
implemented as simply wrapping the current
task inside the context and replacing the
Task object with the TaskContext one. This
may also wrap another task context into a
new one, but the loop doesn't need to care
about that: the API is designed exactly for
this
"""
ctx.entry_point = self.current_task
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
self.current_task = ctx
self.debugger.on_context_creation(ctx)
self.reschedule_running()
def close_context(self, ctx: TaskContext):
"""
Closes the given context
"""
self.debugger.on_context_exit(ctx)
task = ctx.entry_point
task.context = None
self.current_task = task
self.reschedule_running()
def get_current_task(self):
"""
Returns the current task to an asynchronous

View File

@ -80,9 +80,11 @@ class Task:
# Is this task within a context? This is needed to fix a bug that would occur when
# the event loop tries to raise the exception caused by first task that kicked the
# loop even if that context already ignored said error
context: "TaskContext" = None
context: "TaskContext" = field(default=None, repr=False)
# We propagate exception only at the first call to wait()
propagate: bool = True
# Do we have any exceptions pending?
pending_exception: Exception | None = None
def run(self, what: Any | None = None):
"""

View File

@ -0,0 +1,5 @@
from aiosched.util import debugging
__all__ = ["debugging",
]

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from aiosched.task import Task
from aiosched.context import TaskContext
from selectors import EVENT_READ, EVENT_WRITE
class BaseDebugger(ABC):
@ -242,3 +243,74 @@ class BaseDebugger(ABC):
"""
return NotImplemented
class SimpleDebugger(BaseDebugger):
"""
A simple debugger for aiosched
"""
def on_start(self):
print("## Started running")
def on_exit(self):
print("## Finished running")
def on_task_schedule(self, task, delay: int):
print(
f">> A task named '{task.name}' was scheduled to run in {delay:.2f} seconds"
)
def on_task_spawn(self, task):
print(f">> A task named '{task.name}' was spawned")
def on_task_exit(self, task):
print(f"<< Task '{task.name}' exited")
def before_task_step(self, task):
print(f"-> About to run a step for '{task.name}'")
def after_task_step(self, task):
print(f"<- Ran a step for '{task.name}'")
def before_sleep(self, task, seconds):
print(f"# About to put '{task.name}' to sleep for {seconds:.2f} seconds")
def after_sleep(self, task, seconds):
print(f"# Task '{task.name}' slept for {seconds:.2f} seconds")
def before_io(self, timeout):
if timeout is None:
timeout = float("inf")
print(f"!! About to check for I/O for up to {timeout:.2f} seconds")
def after_io(self, timeout):
print(f"!! Done I/O check (waited for {timeout:.2f} seconds)")
def before_cancel(self, task):
print(f"// About to cancel '{task.name}'")
def after_cancel(self, task):
print(f"// Cancelled '{task.name}'")
def on_exception_raised(self, task, exc):
print(f"== '{task.name}' raised {repr(exc)}")
def on_context_creation(self, ctx):
print(f"=> A new context was created by {ctx.entry_point.name!r}")
def on_context_exit(self, ctx):
print(f"=> A context was closed by {ctx.entry_point.name}")
def on_io_schedule(self, stream, event: int):
evt = ""
if event == EVENT_READ:
evt = "reading"
elif event == EVENT_WRITE:
evt = "writing"
elif event == EVENT_WRITE | EVENT_READ:
evt = "reading or writing"
print(f"|| Stream {stream!r} was scheduled for {evt}")
def on_io_unschedule(self, stream):
print(f"|| Stream {stream!r} was unscheduled")

View File

@ -1,6 +1,5 @@
import random
import aiosched
from debugger import Debugger
async def child(name: str, n: int):

View File

@ -2,8 +2,6 @@ import aiosched
import logging
import sys
from debugger import Debugger
# An asynchronous chatroom
clients: dict[aiosched.socket.AsyncSocket, list[str, str]] = {}
@ -46,10 +44,14 @@ async def handler(sock: aiosched.socket.AsyncSocket):
name = ""
async with sock: # Closes the socket automatically
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
while True:
cond = True
while cond:
while not name.endswith("\n"):
name = (await sock.receive(64)).decode()
name = name[:-1]
if name == "":
cond = False
break
name = name.rstrip("\n")
if name not in names:
names.add(name)
clients[sock][0] = name
@ -66,14 +68,24 @@ async def handler(sock: aiosched.socket.AsyncSocket):
data = await sock.receive(1024)
if not data:
break
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if not data.endswith(b"\n"):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
decoded = data.decode().rstrip("\n")
if decoded.startswith("/"):
logging.info(f"{name} issued server command {decoded}")
match decoded[1:]:
case "bye":
await sock.send_all(b"Bye!\n")
break
case _:
await sock.send_all(b"Unknown command\n")
else:
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if not data.endswith(b"\n"):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
logging.info(f"Connection from {address} closed")
clients.pop(sock)
names.discard(name)
@ -88,7 +100,7 @@ if __name__ == "__main__":
datefmt="%d/%m/%Y %p",
)
try:
aiosched.run(serve, ("0.0.0.0", port), debugger=())
aiosched.run(serve, ("0.0.0.0", port), debugger=None)
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")

View File

@ -1,6 +1,5 @@
import aiosched
from raw_catch import child_raises
from debugger import Debugger
async def main(children: list[tuple[str, int]]):

View File

@ -1,6 +1,6 @@
import aiosched
from raw_catch import child
from debugger import Debugger
async def main(children: list[tuple[str, int]]):

View File

@ -1,6 +1,6 @@
import aiosched
from raw_wait import child
from debugger import Debugger
async def main(children: list[tuple[str, int]]):

View File

@ -1,73 +0,0 @@
from aiosched.util.debugging import BaseDebugger
from selectors import EVENT_READ, EVENT_WRITE
class Debugger(BaseDebugger):
"""
A simple debugger for aiosched
"""
def on_start(self):
print("## Started running")
def on_exit(self):
print("## Finished running")
def on_task_schedule(self, task, delay: int):
print(
f">> A task named '{task.name}' was scheduled to run in {delay:.2f} seconds"
)
def on_task_spawn(self, task):
print(f">> A task named '{task.name}' was spawned")
def on_task_exit(self, task):
print(f"<< Task '{task.name}' exited")
def before_task_step(self, task):
print(f"-> About to run a step for '{task.name}'")
def after_task_step(self, task):
print(f"<- Ran a step for '{task.name}'")
def before_sleep(self, task, seconds):
print(f"# About to put '{task.name}' to sleep for {seconds:.2f} seconds")
def after_sleep(self, task, seconds):
print(f"# Task '{task.name}' slept for {seconds:.2f} seconds")
def before_io(self, timeout):
if timeout is None:
timeout = float("inf")
print(f"!! About to check for I/O for up to {timeout:.2f} seconds")
def after_io(self, timeout):
print(f"!! Done I/O check (waited for {timeout:.2f} seconds)")
def before_cancel(self, task):
print(f"// About to cancel '{task.name}'")
def after_cancel(self, task):
print(f"// Cancelled '{task.name}'")
def on_exception_raised(self, task, exc):
print(f"== '{task.name}' raised {repr(exc)}")
def on_context_creation(self, ctx):
print(f"=> A new context was created by {ctx.entry_point.name!r}")
def on_context_exit(self, ctx):
print(f"=> A context was closed by {ctx.entry_point.name}")
def on_io_schedule(self, stream, event: int):
evt = ""
if event == EVENT_READ:
evt = "reading"
elif event == EVENT_WRITE:
evt = "writing"
elif event == EVENT_WRITE | EVENT_READ:
evt = "reading or writing"
print(f"|| Stream {stream!r} was scheduled for {evt}")
def on_io_unschedule(self, stream):
print(f"|| Stream {stream!r} was unscheduled")

View File

@ -1,7 +1,7 @@
import sys
import logging
import aiosched
from debugger import Debugger
# A test to check for asynchronous I/O
@ -51,6 +51,9 @@ async def handler(sock: aiosched.socket.AsyncSocket, client_address: tuple):
elif data == b"exit\n":
await sock.send_all(b"I'm dead dude\n")
raise TypeError("Oh, no, I'm gonna die!")
elif data == b"fatal\n":
await sock.send_all(b"What a dick\n")
raise KeyboardInterrupt("He told me to do it!")
logging.info(f"Got: {data!r} from {address}")
await sock.send_all(b"Got: " + data)
logging.info(f"Echoed back {data!r} to {address}")
@ -65,7 +68,7 @@ if __name__ == "__main__":
datefmt="%d/%m/%Y %H:%M:%S %p",
)
try:
aiosched.run(serve, ("localhost", port), debugger=())
aiosched.run(serve, ("localhost", port), debugger=None)
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")

View File

@ -1,4 +1,4 @@
from debugger import Debugger
import aiosched

View File

@ -1,5 +1,5 @@
import aiosched
from debugger import Debugger
async def sender(c: aiosched.MemoryChannel, n: int):

View File

@ -1,7 +1,7 @@
import aiosched
from raw_catch import child_raises
from raw_wait import child as successful
from debugger import Debugger
async def main(

View File

@ -1,6 +1,6 @@
import aiosched
from raw_catch import child_raises
from debugger import Debugger
# TODO: This crashes 1 second later than it should be

View File

@ -1,6 +1,6 @@
import aiosched
from raw_wait import child
from debugger import Debugger
async def main(

View File

@ -1,5 +1,5 @@
import aiosched
from debugger import Debugger
async def producer(c: aiosched.NetworkChannel, n: int):

View File

@ -1,5 +1,5 @@
import aiosched
from debugger import Debugger
async def producer(q: aiosched.Queue, n: int):

View File

@ -1,5 +1,5 @@
import aiosched
from debugger import Debugger
async def child_raises(name: str, n: int):

View File

@ -1,5 +1,4 @@
import aiosched
from debugger import Debugger
async def child(name: str, n: int):

View File

@ -1,4 +1,4 @@
from debugger import Debugger
import aiosched
import socket as sock
import ssl