# File: structio/structio/core/kernels/fifo.py
# (source listing metadata: 478 lines, 17 KiB, Python)

import traceback
import warnings
from types import FrameType
from structio.abc import (
BaseKernel,
BaseClock,
BaseDebugger,
BaseIOManager,
SignalManager,
)
from structio.io import FdWrapper
from structio.core.context import TaskPool, TaskScope
from structio.core.task import Task, TaskState
from structio.util.ki import CTRLC_PROTECTION_ENABLED
from structio.core.time.queue import TimeQueue
from structio.exceptions import (
StructIOException,
Cancelled,
TimedOut,
ResourceClosed,
ResourceBroken,
)
from collections import deque
from typing import Callable, Coroutine, Any
from functools import partial
import signal
import sniffio
class FIFOKernel(BaseKernel):
"""
An asynchronous event loop implementation
with a FIFO scheduling policy
"""
def __init__(
self,
clock: BaseClock,
io_manager: BaseIOManager,
signal_managers: list[SignalManager],
tools: list[BaseDebugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
):
super().__init__(
clock, io_manager, signal_managers, tools, restrict_ki_to_checkpoints
)
# Tasks that are ready to run
self.run_queue: deque[Task] = deque()
# Data to send back to tasks
self.data: dict[Task, Any] = {}
# Have we handled SIGINT?
self._sigint_handled: bool = False
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue()
self.pool = TaskPool()
self.pool.scope.shielded = True
self.current_scope = self.pool.scope
self.current_scope.shielded = False
def get_closest_deadline(self):
if self.run_queue:
# We absolutely cannot block while other
# tasks have things to do!
return self.clock.current_time()
deadlines = []
for scope in self.pool.scope.children:
deadlines.append(scope.get_effective_deadline()[0])
if not deadlines:
deadlines.append(float("inf"))
return min(
[
min(deadlines),
self.paused.get_closest_deadline(),
]
)
def wait_readable(self, resource: FdWrapper):
self.current_task.state = TaskState.IO
self.io_manager.request_read(resource, self.current_task)
def wait_writable(self, resource: FdWrapper):
self.current_task.state = TaskState.IO
self.io_manager.request_write(resource, self.current_task)
def notify_closing(
self, resource: FdWrapper, broken: bool = False, owner: Task | None = None
):
if not broken:
exc = ResourceClosed("stream has been closed")
else:
exc = ResourceBroken("stream might be corrupted")
owner = owner or self.current_task
reader = self.io_manager.get_reader(resource)
writer = self.io_manager.get_writer(resource)
if reader and reader is not owner:
self.throw(reader, exc)
if writer and writer is not owner:
self.throw(writer, exc)
self.reschedule_running()
def event(self, evt_name: str, *args):
if not hasattr(BaseDebugger, evt_name):
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
return
for tool in self.tools:
if f := getattr(tool, evt_name, None):
try:
f(*args)
except BaseException as e:
# We really can't afford to have our internals explode,
# sorry!
warnings.warn(
f"Exception during debugging event delivery in {f!r} ({evt_name!r}): {type(e).__name__} -> {e}",
)
traceback.print_tb(e.__traceback__)
# We disable the tool, so it can't raise at the next debugging
# event
self.tools.remove(tool)
def done(self):
if self.entry_point.done():
return True
if any([self.run_queue, self.paused, self.io_manager.pending()]):
return False
if not self.pool.done():
return False
return True
def spawn(
self,
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
*args,
ki_protected: bool = False,
pool: TaskPool = None,
):
if isinstance(func, partial):
name = func.func.__name__ or repr(func.func)
else:
name = func.__name__ or repr(func)
if pool is None:
pool = self.current_pool
task = Task(name, func(*args), pool)
# We inject our magic secret variable into the coroutine's stack frame, so
# we can look it up later
task.coroutine.cr_frame.f_locals.setdefault(
CTRLC_PROTECTION_ENABLED, ki_protected
)
self.run_queue.append(task)
self.event("on_task_spawn")
return task
def spawn_system_task(
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
):
return self.spawn(func, *args, ki_protected=True, pool=self.pool)
def signal_notify(self, sig: int, frame: FrameType):
match sig:
case signal.SIGINT:
self._sigint_handled = True
# Poke the event loop with a stick ;)
self.run_queue.append(self.entry_point)
case _:
pass
def step(self):
"""
Run a single task step (i.e. until an "await" to our
primitives somewhere)
"""
self.current_task = self.run_queue.popleft()
while self.current_task.done():
if not self.run_queue:
return
self.current_task = self.run_queue.popleft()
runner = partial(
self.current_task.coroutine.send, self.data.pop(self.current_task, None)
)
if self.current_task.pending_cancellation:
runner = partial(self.current_task.coroutine.throw, Cancelled())
elif self._sigint_handled:
self._sigint_handled = False
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
self.event("before_task_step", self.current_task)
self.current_task.state = TaskState.RUNNING
self.current_task.paused_when = 0
self.current_pool = self.current_task.pool
self.current_scope = self.current_pool.scope
method, args, kwargs = runner()
self.current_task.state = TaskState.PAUSED
self.current_task.paused_when = self.clock.current_time()
if not callable(getattr(self, method, None)):
# This if block is meant to be triggered by other async
# libraries, which most likely have different method names and behaviors
# compared to us. If you get this exception, and you're 100% sure you're
# not mixing async primitives from other libraries, then it's a bug!
self.throw(
self.current_task,
StructIOException(
"Uh oh! Something bad just happened: did you try to mix "
"primitives from other async libraries?"
),
)
# Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args, **kwargs)
self.event("after_task_step", self.current_task)
def throw(self, task: Task, err: BaseException):
if task.done():
return
self.release(task)
self.handle_errors(partial(task.coroutine.throw, err), task)
def reschedule(self, task: Task):
if task.done():
return
self.run_queue.append(task)
def check_cancelled(self):
if self._sigint_handled:
self.throw(self.entry_point, KeyboardInterrupt())
elif self.current_task.pending_cancellation:
self.cancel_task(self.current_task)
else:
# We reschedule the caller immediately!
self.run_queue.appendleft(self.current_task)
def schedule_point(self):
self.reschedule_running()
def sleep(self, amount):
"""
Puts the current task to sleep for the given amount of
time as defined by our current clock
"""
# Just to avoid code duplication, you know
self.suspend()
if amount > 0:
self.event("before_sleep", self.current_task, amount)
self.current_task.next_deadline = self.clock.deadline(amount)
self.paused.put(self.current_task, self.clock.deadline(amount))
else:
# If sleep is called with 0 as argument,
# then it's just a checkpoint!
self.schedule_point()
self.check_cancelled()
def check_scopes(self):
expired = set()
for scope in self.pool.scope.children:
deadline, actual = scope.get_effective_deadline()
if deadline <= self.clock.current_time() and not actual.timed_out:
expired.add(actual)
for scope in expired:
scope.timed_out = True
error = TimedOut("timed out")
error.scope = scope
self.throw(scope.owner, error)
self.reschedule(scope.owner)
def wakeup(self):
while (
self.paused
and self.paused.peek().next_deadline <= self.clock.current_time()
):
task, _ = self.paused.get()
task.next_deadline = 0
self.event(
"after_sleep", task, task.paused_when - self.clock.current_time()
)
self.reschedule(task)
def run(self):
"""
This is the actual "loop" part
of the "event loop"
"""
while not self.done():
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
if self.run_queue:
self.handle_errors(self.step)
self.wakeup()
self.check_scopes()
if self.io_manager.pending():
self.io_manager.wait_io(self.clock.current_time())
self.close()
def reschedule_running(self):
"""
Reschedules the currently running task
"""
self.reschedule(self.current_task)
def handle_errors(self, func: Callable, task: Task | None = None):
"""
Convenience method for handling various exceptions
from tasks
"""
old_name, sniffio.thread_local.name = sniffio.thread_local.name, "structured-io"
try:
func()
except StopIteration as ret:
# We re-define it because we call step() with
# this method and that changes the current task
task = task or self.current_task
# At the end of the day, coroutines are generator functions with
# some tricky behaviors, and this is one of them. When a coroutine
# hits a return statement (either explicit or implicit), it raises
# a StopIteration exception, which has an attribute named value that
# represents the return value of the coroutine, if it has one. Of course
# this exception is not an error, and we should happily keep going after it:
# most of this code below is just useful for internal/debugging purposes
task.state = TaskState.FINISHED
task.result = ret.value
self.on_success(task)
self.event("on_task_exit", task)
except Cancelled:
# When a task needs to be cancelled, we try to do it gracefully first:
# if the task is paused in either I/O or sleeping, that's perfect.
# But we also need to cancel a task if it was not sleeping or waiting on
# any I/O because it could never do so (therefore blocking everything
# forever). So, when cancellation can't be done right away, we schedule
# it for the next execution step of the task. We will also make sure
# to re-raise cancellations at every checkpoint until the task lets the
# exception propagate into us, because we *really* want the task to be
# cancelled
task = task or self.current_task
task.state = TaskState.CANCELLED
task.pending_cancellation = False
self.on_cancel(task)
self.event("after_cancel", task)
except (Exception, KeyboardInterrupt) as err:
# Any other exception is caught here
task = task or self.current_task
task.exc = err
err.scope = task.pool.scope
task.state = TaskState.CRASHED
self.on_error(task)
self.event("on_exception_raised", task)
finally:
sniffio.thread_local.name = old_name
def release_resource(self, resource: FdWrapper):
self.io_manager.release(resource)
self.reschedule_running()
def release(self, task: Task):
"""
Releases the timeouts and associated
I/O resourced that the given task owns
"""
self.io_manager.release_task(task)
self.paused.discard(task)
def on_success(self, task: Task):
"""
The given task has exited gracefully: hooray!
"""
assert task.state == TaskState.FINISHED
# Walk up the scope tree and reschedule all necessary
# tasks
scope = task.pool.scope
while scope.done() and scope is not self.pool.scope:
if scope.done():
self.reschedule(scope.owner)
scope = scope.outer
self.event("on_task_exit", task)
self.release(task)
def on_error(self, task: Task):
"""
The given task raised an exception
"""
assert task.state == TaskState.CRASHED
self.event("on_exception_raised", task, task.exc)
scope = task.pool.scope
if task is not scope.owner:
self.reschedule(scope.owner)
self.throw(scope.owner, task.exc)
self.release(task)
def on_cancel(self, task: Task):
"""
The given task crashed because of a
cancellation exception
"""
assert task.state == TaskState.CANCELLED
self.event("after_cancel", task)
self.release(task)
def init_scope(self, scope: TaskScope):
scope.deadline = self.clock.deadline(scope.timeout)
scope.owner = self.current_task
self.current_scope.inner.append(scope)
scope.outer = self.current_scope
self.current_scope = scope
def close_scope(self, scope: TaskScope):
self.current_scope = scope.outer
self.current_scope.inner = []
def cancel_task(self, task: Task):
if task.done():
return
if task.state == TaskState.RUNNING:
# Can't cancel a task while it's
# running, will raise ValueError
# if we try. We defer it for later
task.pending_cancellation = True
return
err = Cancelled()
err.scope = task.pool.scope
self.throw(task, err)
if task.state != TaskState.CANCELLED:
# Task is stubborn. But so are we,
# so we'll redeliver the cancellation
# every time said task tries to call any
# event loop primitive
task.pending_cancellation = True
def cancel_scope(self, scope: TaskScope):
scope.attempted_cancel = True
# We can't just immediately cancel the
# current task because this method is
# called synchronously by TaskScope.cancel(),
# so there is nowhere to throw an exception
# to
if self.current_task in scope.tasks and self.current_task is not scope.owner:
self.current_task.pending_cancellation = True
for child in filter(lambda c: not c.shielded, scope.children):
self.cancel_scope(child)
for task in scope.tasks:
if task is self.current_task:
continue
self.cancel_task(task)
if (
scope is not self.current_task.pool.scope
and scope.owner is not self.current_task
and scope.owner is not self.entry_point
):
# Handles the case where the current task calls
# cancel() for a scope which it doesn't own, which
# is an entirely reasonable thing to do
self.cancel_task(scope.owner)
if scope.done():
scope.cancelled = True
def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool
pool.entry_point = self.current_task
self.current_pool = pool
def close_pool(self, pool: TaskPool):
self.current_pool = pool.outer
self.close_scope(pool.scope)
def suspend(self):
self.current_task.state = TaskState.PAUSED
self.current_task.paused_when = self.clock.current_time()
def setup(self):
for manager in self.signal_managers:
manager.install()
def teardown(self):
for manager in self.signal_managers:
manager.uninstall()