# structio/structio/core/kernels/fifo.py (359 lines, 13 KiB, Python)

import traceback
import warnings
from types import FrameType
from structio.abc import BaseKernel, BaseClock, BaseDebugger, BaseIOManager, SignalManager
from structio.core.context import TaskPool, TaskScope
from structio.core.task import Task, TaskState
from structio.util.ki import CTRLC_PROTECTION_ENABLED
from structio.core.time.queue import TimeQueue
from structio.exceptions import StructIOException, Cancelled, TimedOut
from collections import deque
from typing import Callable, Coroutine, Any
from functools import partial
import signal
class FIFOKernel(BaseKernel):
"""
An asynchronous event loop implementation
with a FIFO scheduling policy
"""
def __init__(self, clock: BaseClock, io_manager: BaseIOManager,
signal_managers: list[SignalManager],
tools: list[BaseDebugger] | None = None,
restrict_ki_to_checkpoints: bool = False):
super().__init__(clock, io_manager, signal_managers, tools, restrict_ki_to_checkpoints)
self.skip: bool = False
# Tasks that are ready to run
self.run_queue: deque[Task] = deque()
# Data to send back to tasks
self.data: dict[Task, Any] = {}
# Have we handled SIGINT?
self._sigint_handled: bool = False
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue(self.clock)
# All task scopes we handle
self.scopes: list[TaskScope] = []
self.pool = TaskPool()
self.current_pool = self.pool
self.current_scope = self.current_pool.scope
self.current_scope.shielded = False
self.scopes.append(self.current_scope)
self._closing = False
def get_closest_deadline(self):
return min([self.current_scope.get_actual_timeout(), self.paused.get_closest_deadline()])
def get_closest_deadline_owner(self):
return self.paused.peek()
def event(self, evt_name: str, *args):
if not hasattr(BaseDebugger, evt_name):
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
return
for tool in self.tools:
if f := getattr(tool, evt_name, None):
try:
f(*args)
except BaseException as e:
# We really can't afford to have our internals explode,
# sorry!
warnings.warn(f"Exception during debugging event delivery ({evt_name!r}): {type(e).__name__} -> {e}",
)
traceback.print_tb(e.__traceback__)
def done(self):
if self.entry_point.done():
return True
if any([self.run_queue, self.paused, self.io_manager.pending()]):
return False
for scope in self.scopes:
if not scope.done():
return False
return True
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
task = Task(func.__name__ or repr(func), func(*args), self.current_pool)
# We inject our magic secret variable into the coroutine's stack frame, so
# we can look it up later
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
task.pool.scope.tasks.append(task)
self.run_queue.append(task)
self.event("on_task_spawn")
return task
def spawn_system_task(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
task = Task(func.__name__ or repr(func), func(*args), self.pool)
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, True)
task.pool.scope.tasks.append(task)
self.run_queue.append(task)
self.event("on_task_spawn")
return task
def signal_notify(self, sig: int, frame: FrameType):
match sig:
case signal.SIGINT:
self._sigint_handled = True
case _:
pass
def step(self):
"""
Run a single task step (i.e. until an "await" to our
primitives somewhere)
"""
self.current_task = self.run_queue.popleft()
while self.current_task.done():
if not self.run_queue:
return
self.current_task = self.run_queue.popleft()
runner = partial(self.current_task.coroutine.send, self.data.pop(self.current_task, None))
if self.current_task.pending_cancellation:
_runner = partial(self.current_task.coroutine.throw, Cancelled())
elif self._sigint_handled:
self._sigint_handled = False
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
self.event("before_task_step", self.current_task)
method, args, kwargs = runner()
if not callable(getattr(self, method, None)):
# This if block is meant to be triggered by other async
# libraries, which most likely have different method names and behaviors
# compared to us. If you get this exception, and you're 100% sure you're
# not mixing async primitives from other libraries, then it's a bug!
self.throw(self.current_task, StructIOException("Uh oh! Something bad just happened: did you try to mix "
"primitives from other async libraries?"))
# Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args, **kwargs)
self.event("after_task_step", self.current_task)
def throw(self, task: Task, err: BaseException):
if task.done():
return
if task.state == TaskState.PAUSED:
self.paused.discard(task)
self.handle_errors(partial(task.coroutine.throw, err), task)
def reschedule(self, task: Task):
if task.done():
return
self.run_queue.append(task)
def check_cancelled(self):
if self._sigint_handled:
self.throw(self.entry_point, KeyboardInterrupt())
elif self.current_task.pending_cancellation:
self.cancel_task(self.current_task)
else:
# We reschedule the caller immediately!
self.run_queue.appendleft(self.current_task)
def schedule_point(self):
self.reschedule_running()
def sleep(self, amount):
"""
Puts the current task to sleep for the given amount of
time as defined by our current clock
"""
# Just to avoid code duplication, you know
self.suspend()
if amount > 0:
self.event("before_sleep", self.current_task, amount)
self.current_task.next_deadline = self.current_task.paused_when + amount
self.paused.put(self.current_task, amount)
else:
# If sleep is called with 0 as argument,
# then it's just a checkpoint!
self.skip = True
self.reschedule_running()
def check_scopes(self):
for scope in self.scopes:
if scope.get_actual_timeout() <= self.clock.current_time():
error = TimedOut("timed out")
error.scope = scope
self.throw(scope.owner, error)
def wakeup(self):
while self.paused and self.paused.peek().next_deadline <= self.clock.current_time():
task, _ = self.paused.get()
task.next_deadline = 0
self.event("after_sleep", task, task.paused_when - self.clock.current_time())
self.reschedule(task)
def run(self):
"""
This is the actual "loop" part
of the "event loop"
"""
while not self.done():
if self.run_queue and not self.skip:
self.handle_errors(self.step)
self.running = False
self.skip = False
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
if self.io_manager.pending():
self.io_manager.wait_io()
self.wakeup()
self.check_scopes()
self.close()
def reschedule_running(self):
"""
Reschedules the currently running task
"""
self.run_queue.append(self.current_task)
def handle_errors(self, func: Callable, task: Task | None = None):
"""
Convenience method for handling various exceptions
from tasks
"""
try:
func()
except StopIteration as ret:
# We re-define it because we call step() with
# this method and that changes the current task
task = task or self.current_task
# At the end of the day, coroutines are generator functions with
# some tricky behaviors, and this is one of them. When a coroutine
# hits a return statement (either explicit or implicit), it raises
# a StopIteration exception, which has an attribute named value that
# represents the return value of the coroutine, if it has one. Of course
# this exception is not an error, and we should happily keep going after it:
# most of this code below is just useful for internal/debugging purposes
task.state = TaskState.FINISHED
task.result = ret.value
self.on_success(self.current_task)
except Cancelled:
# When a task needs to be cancelled, we try to do it gracefully first:
# if the task is paused in either I/O or sleeping, that's perfect.
# But we also need to cancel a task if it was not sleeping or waiting on
# any I/O because it could never do so (therefore blocking everything
# forever). So, when cancellation can't be done right away, we schedule
# it for the next execution step of the task. We will also make sure
# to re-raise cancellations at every checkpoint until the task lets the
# exception propagate into us, because we *really* want the task to be
# cancelled
task = task or self.current_task
task.state = TaskState.CANCELLED
task.pending_cancellation = False
self.event("after_cancel")
self.on_cancel(task)
except (Exception, KeyboardInterrupt) as err:
# Any other exception is caught here
task = task or self.current_task
task.exc = err
err.scope = task.pool.scope
task.state = TaskState.CRASHED
self.event("on_exception_raised", task)
self.on_error(task)
def release(self, task: Task):
"""
Releases the timeouts and associated
I/O resourced that the given task owns
"""
self.io_manager.release_task(task)
self.paused.discard(task)
def on_success(self, task: Task):
"""
The given task has exited gracefully: hooray!
"""
# TODO: Anything else?
task.pool: TaskPool
for waiter in task.waiters:
self.reschedule(waiter)
if task.pool.done():
self.reschedule(task.pool.entry_point)
task.waiters.clear()
self.event("on_task_exit", task)
self.io_manager.release_task(task)
def on_error(self, task: Task):
"""
The given task raised an exception
"""
self.event("on_exception_raised", task, task.exc)
for waiter in task.waiters:
self.reschedule(waiter)
self.throw(task.pool.scope.owner, task.exc)
task.waiters.clear()
self.release(task)
def on_cancel(self, task: Task):
"""
The given task crashed because of a
cancellation exception
"""
for waiter in task.waiters:
self.reschedule(waiter)
task.waiters.clear()
self.release(task)
if task.pool.done():
self.reschedule(task.pool.entry_point)
def init_scope(self, scope: TaskScope):
scope.owner = self.current_task
self.current_scope.inner = scope
scope.outer = self.current_scope
self.current_scope = scope
self.scopes.append(scope)
def close_scope(self, scope: TaskScope):
self.current_scope = scope.outer
self.reschedule(scope.owner)
self.scopes.pop()
def cancel_task(self, task: Task):
if task.done():
return
err = Cancelled()
err.scope = task.pool.scope
self.throw(task, err)
if task.state != TaskState.CANCELLED:
task.pending_cancellation = True
def cancel_scope(self, scope: TaskScope):
scope.attempted_cancel = True
inner = scope.inner
if inner and not inner.shielded:
self.cancel_scope(inner)
for task in scope.tasks:
self.cancel_task(task)
if scope.done():
self.reschedule(scope.owner)
def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool
pool.entry_point = self.current_task
self.current_pool.inner = pool
self.current_pool = pool
def close_pool(self, pool: TaskPool):
self.current_pool = pool.outer
def suspend(self):
self.current_task.state = TaskState.PAUSED
self.current_task.paused_when = self.clock.current_time()
def setup(self):
for manager in self.signal_managers:
manager.install()
def teardown(self):
for manager in self.signal_managers:
manager.uninstall()