391 lines
14 KiB
Python
391 lines
14 KiB
Python
import traceback
|
|
import warnings
|
|
from types import FrameType
|
|
from structio.abc import (
|
|
BaseKernel,
|
|
BaseClock,
|
|
BaseDebugger,
|
|
BaseIOManager,
|
|
SignalManager,
|
|
)
|
|
from structio.core.context import TaskPool, TaskScope
|
|
from structio.core.task import Task, TaskState
|
|
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
|
from structio.core.time.queue import TimeQueue
|
|
from structio.exceptions import StructIOException, Cancelled, TimedOut
|
|
from collections import deque
|
|
from typing import Callable, Coroutine, Any
|
|
from functools import partial
|
|
import signal
|
|
|
|
|
|
class FIFOKernel(BaseKernel):
    """
    An asynchronous event loop implementation
    with a FIFO scheduling policy.

    Ready tasks live in a double-ended queue: they are
    appended when (re)scheduled and popped from the left
    when stepped, so they generally run in the order in
    which they became ready
    """
|
|
|
|
def __init__(
    self,
    clock: BaseClock,
    io_manager: BaseIOManager,
    signal_managers: list[SignalManager],
    tools: list[BaseDebugger] | None = None,
    restrict_ki_to_checkpoints: bool = False,
):
    """
    Initializes the kernel's scheduling state. All arguments
    are forwarded unchanged to the base kernel's constructor
    """

    super().__init__(
        clock, io_manager, signal_managers, tools, restrict_ki_to_checkpoints
    )
    # When set, run() does not step a task in its current iteration
    self.skip: bool = False
    # Set by signal_notify() when SIGINT is received
    self._sigint_handled: bool = False
    self._closing = False
    # Tasks waiting for their turn to run
    self.run_queue: deque[Task] = deque()
    # Values to be sent into tasks the next time they're stepped
    self.data: dict[Task, Any] = {}
    # Sleeping tasks, ordered by deadline
    self.paused: TimeQueue = TimeQueue(self.clock)
    # The root pool and its scope are where everything starts from
    root_pool = TaskPool()
    root_scope = root_pool.scope
    root_scope.shielded = False
    self.pool = root_pool
    self.current_pool = root_pool
    self.current_scope = root_scope
    # Every task scope we handle, outermost first
    self.scopes: list[TaskScope] = [root_scope]
|
|
|
|
def get_closest_deadline(self):
    """
    Returns whichever comes first between the current
    scope's actual timeout and the closest deadline
    in the queue of paused tasks
    """

    scope_deadline = self.current_scope.get_actual_timeout()
    sleep_deadline = self.paused.get_closest_deadline()
    return min(scope_deadline, sleep_deadline)
|
|
|
|
def get_closest_deadline_owner(self):
    """
    Returns whatever sits at the front of the queue
    of paused tasks, without removing it
    """

    front = self.paused.peek()
    return front
|
|
|
|
def event(self, evt_name: str, *args):
    """
    Fires the debugging event with the given name on every
    registered tool that implements it, passing along the
    given arguments. Names that BaseDebugger does not define
    only produce a warning
    """

    if not hasattr(BaseDebugger, evt_name):
        warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
        return
    for tool in self.tools:
        handler = getattr(tool, evt_name, None)
        if not handler:
            continue
        try:
            handler(*args)
        except BaseException as e:
            # A misbehaving debugging tool must not be able to take
            # the loop down with it, so we just report the failure
            warnings.warn(
                f"Exception during debugging event delivery ({evt_name!r}): {type(e).__name__} -> {e}",
            )
            traceback.print_tb(e.__traceback__)
|
|
|
|
def done(self):
    """
    Returns whether the loop has nothing left to do: that's the
    case when the entry point is done or when there are no ready
    tasks, no paused tasks, no pending I/O and all scopes are done
    """

    if self.entry_point.done():
        return True
    outstanding = (self.run_queue, self.paused, self.io_manager.pending())
    if any(outstanding):
        return False
    return all(scope.done() for scope in self.scopes)
|
|
|
|
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
    """
    Creates a task out of the given async callable, registers it
    in the current pool's scope and schedules it to run

    :param func: The async callable to run. It is called right
        away with the given positional arguments to create the
        task's coroutine
    :param args: Positional arguments for func
    :return: The newly created task
    """

    # Not every callable has a __name__ (functools.partial objects and
    # callable instances don't), so we look it up defensively instead of
    # crashing before the repr() fallback can kick in
    name = getattr(func, "__name__", None) or repr(func)
    task = Task(name, func(*args), self.current_pool)
    # We inject our magic secret variable into the coroutine's stack frame, so
    # we can look it up later
    task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
    task.pool.scope.tasks.append(task)
    self.run_queue.append(task)
    self.event("on_task_spawn")
    return task
|
|
|
|
def spawn_system_task(
    self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
):
    """
    Like spawn(), but the task is always created in the kernel's
    root pool (regardless of the current one) and Ctrl+C protection
    is turned on in its coroutine's frame

    :param func: The async callable to run. It is called right
        away with the given positional arguments to create the
        task's coroutine
    :param args: Positional arguments for func
    :return: The newly created task
    """

    # Not every callable has a __name__ (functools.partial objects and
    # callable instances don't), so we look it up defensively instead of
    # crashing before the repr() fallback can kick in
    name = getattr(func, "__name__", None) or repr(func)
    task = Task(name, func(*args), self.pool)
    task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, True)
    task.pool.scope.tasks.append(task)
    self.run_queue.append(task)
    self.event("on_task_spawn")
    return task
|
|
|
|
def signal_notify(self, sig: int, frame: FrameType):
    """
    Records that the given signal was received. Only SIGINT
    is tracked: any other signal is ignored
    """

    if sig == signal.SIGINT:
        self._sigint_handled = True
|
|
|
|
def step(self):
    """
    Run a single task step (i.e. until an "await" to our
    primitives somewhere).

    The next task that isn't done is popped from the run queue and
    resumed: normally by sending it the data stored for it (or None),
    but by throwing Cancelled into it if it has a pending cancellation
    or KeyboardInterrupt if a SIGINT was recorded. The task is expected
    to yield a (method, args, kwargs) triple naming the kernel method
    to call on its behalf. Exceptions raised by the task are not caught
    here
    """

    self.current_task = self.run_queue.popleft()
    while self.current_task.done():
        if not self.run_queue:
            return
        self.current_task = self.run_queue.popleft()
    runner = partial(
        self.current_task.coroutine.send, self.data.pop(self.current_task, None)
    )
    if self.current_task.pending_cancellation:
        # This has to replace the runner we actually call, or the
        # cancellation would never be re-delivered
        runner = partial(self.current_task.coroutine.throw, Cancelled())
    elif self._sigint_handled:
        self._sigint_handled = False
        runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
    self.event("before_task_step", self.current_task)
    method, args, kwargs = runner()
    if not callable(getattr(self, method, None)):
        # This if block is meant to be triggered by other async
        # libraries, which most likely have different method names and behaviors
        # compared to us. If you get this exception, and you're 100% sure you're
        # not mixing async primitives from other libraries, then it's a bug!
        self.throw(
            self.current_task,
            StructIOException(
                "Uh oh! Something bad just happened: did you try to mix "
                "primitives from other async libraries?"
            ),
        )
        # The error has been delivered to the task: trying to carry out the
        # request anyway would just raise a second, unrelated exception
        return
    # Sneaky method call, thanks to David Beazley for this ;)
    getattr(self, method)(*args, **kwargs)
    self.event("after_task_step", self.current_task)
|
|
|
|
def throw(self, task: Task, err: BaseException):
    """
    Throws the given exception into the given task, unless the
    task is already done. A paused task is first removed from
    the queue of paused tasks. Whatever comes out of the task's
    coroutine is dealt with by handle_errors()
    """

    if not task.done():
        if task.state == TaskState.PAUSED:
            self.paused.discard(task)
        self.handle_errors(partial(task.coroutine.throw, err), task)
|
|
|
|
def reschedule(self, task: Task):
    """
    Puts the given task at the back of the run queue,
    unless it is already done
    """

    if not task.done():
        self.run_queue.append(task)
|
|
|
|
def check_cancelled(self):
    """
    Checkpoint for the current task: a recorded SIGINT is thrown
    into the entry point as a KeyboardInterrupt, a pending
    cancellation is delivered to the current task and, if neither
    applies, the current task goes right back to the front of the
    run queue
    """

    if self._sigint_handled:
        self.throw(self.entry_point, KeyboardInterrupt())
        return
    caller = self.current_task
    if caller.pending_cancellation:
        self.cancel_task(caller)
        return
    # Nothing to deliver: the caller gets to run again immediately
    self.run_queue.appendleft(caller)
|
|
|
|
def schedule_point(self):
    """
    Sends the current task to the back of the run
    queue so that other ready tasks get to run
    """

    self.reschedule_running()
|
|
|
|
def sleep(self, amount):
    """
    Suspends the current task and parks it in the queue of
    paused tasks for the given amount of time, as measured by
    our clock. An amount that is not greater than zero makes
    this a plain checkpoint instead
    """

    # Marks the task as paused and records when that happened
    self.suspend()
    if not amount > 0:
        # Nothing to wait for: the task is rescheduled right away and
        # the loop is told not to step a task in this iteration
        self.skip = True
        self.reschedule_running()
        return
    self.event("before_sleep", self.current_task, amount)
    self.current_task.next_deadline = self.clock.deadline(amount)
    self.paused.put(self.current_task, amount)
|
|
|
|
def check_scopes(self):
    """
    Goes through all the scopes we handle and throws a TimedOut
    exception (carrying the offending scope in its "scope"
    attribute) into the owner of each one whose actual timeout
    has been reached
    """

    for candidate in self.scopes:
        expired = candidate.get_actual_timeout() <= self.clock.current_time()
        if not expired:
            continue
        timeout = TimedOut("timed out")
        timeout.scope = candidate
        self.throw(candidate.owner, timeout)
|
|
|
|
def wakeup(self):
    """
    Reschedules every paused task whose deadline has been
    reached, starting from the front of the paused queue and
    stopping at the first task that still has to wait
    """

    while self.paused:
        if not self.paused.peek().next_deadline <= self.clock.current_time():
            break
        woken, _ = self.paused.get()
        woken.next_deadline = 0
        # NOTE(review): this is "paused_when - now", which is not positive
        # if the clock never goes backwards: confirm the sign is intended
        delta = woken.paused_when - self.clock.current_time()
        self.event("after_sleep", woken, delta)
        self.reschedule(woken)
|
|
|
|
def run(self):
    """
    This is the actual "loop" part
    of the "event loop": it keeps stepping tasks, waiting
    for I/O, waking up sleepers and checking scope timeouts
    until done() reports there's nothing left, then calls close()
    """

    while not self.done():
        # sleep() sets self.skip for amounts that aren't greater than zero:
        # when it's set we don't step a task in this iteration (the flag is
        # cleared right below)
        if self.run_queue and not self.skip:
            # step() goes through handle_errors() so that a task returning,
            # getting cancelled or crashing is processed by the kernel
            self.handle_errors(self.step)
            self.running = False
        self.skip = False
        # NOTE(review): the flag is not cleared here (in this class only step()
        # resets it), so this fires on every iteration until the entry point
        # is done. Confirm that's intended
        if self._sigint_handled and not self.restrict_ki_to_checkpoints:
            self.throw(self.entry_point, KeyboardInterrupt())
        if self.io_manager.pending():
            self.io_manager.wait_io()
        self.wakeup()
        self.check_scopes()
    self.close()
|
|
|
|
def reschedule_running(self):
    """
    Appends the current task to the back of the run
    queue (without checking whether it is done)
    """

    running = self.current_task
    self.run_queue.append(running)
|
|
|
|
def handle_errors(self, func: Callable, task: Task | None = None):
    """
    Convenience method for handling various exceptions
    from tasks: calls func and translates what it raises
    into the matching task state and callback

    :param func: The callable to run, with no arguments (usually
        step() or a coroutine's throw method with its argument
        already bound)
    :param task: The task that func acts upon. When not given
        (or falsy), whatever the current task is after func has
        run is used instead
    """

    try:
        func()
    except StopIteration as ret:
        # We re-define it because we call step() with
        # this method and that changes the current task
        task = task or self.current_task
        # At the end of the day, coroutines are generator functions with
        # some tricky behaviors, and this is one of them. When a coroutine
        # hits a return statement (either explicit or implicit), it raises
        # a StopIteration exception, which has an attribute named value that
        # represents the return value of the coroutine, if it has one. Of course
        # this exception is not an error, and we should happily keep going after it:
        # most of this code below is just useful for internal/debugging purposes
        task.state = TaskState.FINISHED
        task.result = ret.value
        # It's the task that finished that needs to be processed, which is not
        # necessarily the current one (for example when we're called by throw())
        self.on_success(task)
    except Cancelled:
        # When a task needs to be cancelled, we try to do it gracefully first:
        # if the task is paused in either I/O or sleeping, that's perfect.
        # But we also need to cancel a task if it was not sleeping or waiting on
        # any I/O because it could never do so (therefore blocking everything
        # forever). So, when cancellation can't be done right away, we schedule
        # it for the next execution step of the task. We will also make sure
        # to re-raise cancellations at every checkpoint until the task lets the
        # exception propagate into us, because we *really* want the task to be
        # cancelled
        task = task or self.current_task
        task.state = TaskState.CANCELLED
        task.pending_cancellation = False
        self.event("after_cancel")
        self.on_cancel(task)
    except (Exception, KeyboardInterrupt) as err:
        # Any other exception is caught here
        task = task or self.current_task
        task.exc = err
        err.scope = task.pool.scope
        task.state = TaskState.CRASHED
        # NOTE(review): on_error() fires this same event again, with the
        # exception as an extra argument: confirm which form tools expect
        self.event("on_exception_raised", task)
        self.on_error(task)
|
|
|
|
def release(self, task: Task):
    """
    Lets go of what the kernel holds on behalf of the given
    task: it is released from the I/O manager and discarded
    from the queue of paused tasks
    """

    io_manager, sleepers = self.io_manager, self.paused
    io_manager.release_task(task)
    sleepers.discard(task)
|
|
|
|
def on_success(self, task: Task):
    """
    Called when the given task has returned: its waiters are
    rescheduled (and so is the entry point of its pool, if the
    pool is done), the "on_task_exit" event is fired and the
    task is released from the I/O manager
    """

    for waiter in task.waiters:
        self.reschedule(waiter)
    pool = task.pool
    if pool.done():
        self.reschedule(pool.entry_point)
    task.waiters.clear()
    self.event("on_task_exit", task)
    self.io_manager.release_task(task)
|
|
|
|
def on_error(self, task: Task):
    """
    Called when the given task has crashed: the
    "on_exception_raised" event is fired, the task's waiters
    are rescheduled, its exception is thrown into the owner
    of its pool's scope and its resources are released
    """

    error = task.exc
    self.event("on_exception_raised", task, error)
    for waiter in task.waiters:
        self.reschedule(waiter)
    scope_owner = task.pool.scope.owner
    self.throw(scope_owner, task.exc)
    task.waiters.clear()
    self.release(task)
|
|
|
|
def on_cancel(self, task: Task):
    """
    Called when the given task has exited because of a
    cancellation: its waiters are rescheduled, its resources
    are released and, if its pool is now done, the pool's
    entry point is rescheduled too
    """

    for waiter in task.waiters:
        self.reschedule(waiter)
    task.waiters.clear()
    self.release(task)
    pool = task.pool
    if pool.done():
        self.reschedule(pool.entry_point)
|
|
|
|
def init_scope(self, scope: TaskScope):
    """
    Enters the given scope: it becomes owned by the current
    task, gets linked as the inner scope of the one that was
    current so far and becomes the current scope itself
    """

    parent = self.current_scope
    scope.owner = self.current_task
    scope.outer = parent
    parent.inner = scope
    self.current_scope = scope
    self.scopes.append(scope)
|
|
|
|
def close_scope(self, scope: TaskScope):
    """
    Exits the given scope: its outer scope becomes the current
    one again, its owner is rescheduled and the most recently
    added scope is dropped from the list of scopes we handle
    """

    self.current_scope = scope.outer
    self.reschedule(scope.owner)
    # Scopes are assumed to be closed in the opposite order they
    # were entered in, so the one to drop is always the last
    del self.scopes[-1]
|
|
|
|
def cancel_task(self, task: Task):
    """
    Attempts to cancel the given task by throwing a Cancelled
    exception (tagged with the scope of the task's pool) into
    it. If the task is not in the cancelled state afterwards,
    it is marked as having a pending cancellation. Tasks that
    are already done are left alone
    """

    if task.done():
        return
    cancellation = Cancelled()
    cancellation.scope = task.pool.scope
    self.throw(task, cancellation)
    cancelled = task.state == TaskState.CANCELLED
    if not cancelled:
        task.pending_cancellation = True
|
|
|
|
def cancel_scope(self, scope: TaskScope):
    """
    Cancels the given scope: it is flagged as having had a
    cancellation attempt, its inner scope (if any, and if not
    shielded) is cancelled first, then all of its tasks are.
    If the scope is done afterwards, its owner is rescheduled
    """

    scope.attempted_cancel = True
    child = scope.inner
    if child and not child.shielded:
        self.cancel_scope(child)
    for member in scope.tasks:
        self.cancel_task(member)
    if scope.done():
        self.reschedule(scope.owner)
|
|
|
|
def init_pool(self, pool: TaskPool):
    """
    Enters the given pool: the current task becomes its entry
    point, it gets linked as the inner pool of the one that
    was current so far and becomes the current pool itself
    """

    parent = self.current_pool
    pool.outer = parent
    pool.entry_point = self.current_task
    parent.inner = pool
    self.current_pool = pool
|
|
|
|
def close_pool(self, pool: TaskPool):
    """
    Exits the given pool by making its outer
    pool the current one again
    """

    enclosing = pool.outer
    self.current_pool = enclosing
|
|
|
|
def suspend(self):
    """
    Marks the current task as paused and records the
    time (according to our clock) at which that happened.
    The task is not rescheduled
    """

    suspended = self.current_task
    suspended.state = TaskState.PAUSED
    suspended.paused_when = self.clock.current_time()
|
|
|
|
def setup(self):
|
|
for manager in self.signal_managers:
|
|
manager.install()
|
|
|
|
def teardown(self):
|
|
for manager in self.signal_managers:
|
|
manager.uninstall()
|