This repository has been archived on 2023-05-12. You can view files and clone it, but cannot push or open issues or pull requests.
aiosched/aiosched/kernel.py

768 lines
30 KiB
Python

"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import itertools
import signal
import warnings
from collections import deque
from functools import partial
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
from timeit import default_timer
from types import FrameType
from typing import Callable, Any, Coroutine, Iterator

from aiosched.task import Task, TaskState
from aiosched.internals.queue import TimeQueue
from aiosched.util.debugging import BaseDebugger
from aiosched.errors import (
    InternalError,
    ResourceBusy,
    Cancelled,
    ResourceClosed,
    ResourceBroken,
)
from aiosched.context import TaskPool, TaskScope
from aiosched.util.sigint import CTRLC_PROTECTION_ENABLED, currently_protected, enable_ki_protection


class FIFOKernel:
    """
    An asynchronous event loop implementation with a FIFO
    scheduling policy.

    :param clock: The function used to keep track of time. Defaults to timeit.default_timer
    :param debugger: A subclass of aiosched.util.BaseDebugger or None if no debugging output is desired
    :type debugger: :class: aiosched.util.debugging.BaseDebugger, optional
    :param selector: The selector to use for I/O multiplexing. Defaults to a new
        selectors.DefaultSelector, created separately for every kernel
    :type selector: :class: selectors.BaseSelector, optional
    """

    def __init__(
        self,
        clock: Callable[[], float] = default_timer,
        debugger: BaseDebugger | None = None,
        selector: BaseSelector | None = None,
    ):
        """
        Public constructor
        """

        self.clock = clock
        if debugger and not isinstance(debugger, BaseDebugger):
            raise InternalError(
                "The debugger must be a subclass of aiosched.util.debugging.BaseDebugger"
            )
        # The debugger object. If it is none we create a dummy object whose
        # attributes are all no-op functions returning None (accepting any
        # positional or keyword arguments), to avoid lots of
        # "if self.debugger" clauses
        self.debugger = (
            debugger
            or type(
                "DumbDebugger",
                (object,),
                {"__getattr__": lambda *_: lambda *_, **__: None},
            )()
        )
        # Abstraction layer over low-level OS primitives for asynchronous
        # I/O. DefaultSelector() is not used as the parameter's default value
        # because defaults are evaluated only once: every kernel would then
        # share the very same selector, which close() closes
        self.selector: BaseSelector = (
            selector if selector is not None else DefaultSelector()
        )
        # Tasks that are ready to run
        self.run_ready: deque[Task] = deque()
        # Tasks that are paused and waiting
        # for some deadline to expire
        self.paused: TimeQueue = TimeQueue(self.clock)
        # Data that is to be sent back to coroutines
        self.data: dict[Task, Any] = {}
        # The currently running task
        self.current_task: Task | None = None
        # The loop's entry point
        self.entry_point: Task | None = None
        # Did we receive a Ctrl+C?
        self._sigint_handled: bool = False
        # Are we executing any task code?
        self._running: bool = False
        # The current task pool we're in
        self.current_pool: TaskPool | None = None
        # The current task scope we're in
        self.current_scope: TaskScope | None = None

    def __repr__(self):
        """
        Returns repr(self)
        """

        # A tuple (rather than a set) keeps the field order stable across runs
        fields = (
            "debugger",
            "run_ready",
            "selector",
            "clock",
            "data",
            "paused",
            "current_task",
        )
        data = ", ".join(f"{name}={getattr(self, name)!s}" for name in fields)
        return f"{type(self).__name__}({data})"

    def _sigint_handler(self, _sig: int, _frame: FrameType):
        """
        Handles SIGINT. If the interrupted code is not protected
        against KeyboardInterrupt, the exception is thrown into the
        loop's entry point right away. Otherwise, delivery is deferred
        by setting a flag that run() checks later
        """

        if not currently_protected():
            self.run_ready.appendleft(self.entry_point)
            self.entry_point.pending_exception = KeyboardInterrupt()
            self.handle_errors(self.run_task_step)
        else:
            self._sigint_handled = True

    def done(self) -> bool:
        """
        Returns whether the loop has no more work
        to do
        """

        if self.paused or self.run_ready:
            # There's tasks sleeping and/or on the
            # ready queue!
            return False
        if self.get_active_io_count():
            # We don't just check whether self.selector.get_map() is non-empty,
            # because we don't want to just know if there's any resources we're
            # waiting on, but if there's at least one non-terminated task that
            # owns a resource we're waiting on. This avoids issues such as the
            # event loop never exiting if the user forgets to close a socket,
            # for example
            return False
        if self.current_task:
            return self.current_task.done()
        return True

    def close(self, force: bool = False):
        """
        Closes the event loop. If force equals False (the default)
        and the loop still has work to do, an InternalError is thrown
        into the current task (or raised directly if there is none)
        and the loop is left untouched. Otherwise, all tasks are
        cancelled and the selector is closed
        """

        if not self.done() and not force:
            error = InternalError("cannot shut down a running event loop")
            if self.current_task is None:
                raise error
            self.current_task.throw(error)
            # throw() only returns if the task swallowed the error: we
            # must still refuse to shut the loop down in that case
            return
        for task in self.all(copy=True):
            self.cancel(task)
        self.selector.close()

    def all(self, copy: bool = False) -> Iterator[Task]:
        """
        Yields all the tasks the event loop is keeping track of:
        sleeping tasks first, then the ready queue, then the tasks
        owning a resource registered in our selector (a task may
        appear more than once). Sleeping and I/O-bound tasks are
        snapshotted before the first task is yielded. The ready
        queue is copied only if copy equals True; otherwise it is
        iterated in place and must not be mutated while this
        generator is being consumed.
        This is an internal undocumented method
        """

        sleeping = (
            [task for _, _, task, _ in self.paused.container] if self.paused else []
        )
        ready = self.run_ready.copy() if copy else self.run_ready
        waiting_io = [
            task
            for key in (self.selector.get_map() or {}).values()
            for task in key.data.values()
        ]
        yield from itertools.chain(sleeping, ready, waiting_io)

    def wait_io(self):
        """
        Waits for I/O and schedules tasks when their
        associated resource is ready to be used
        """

        self._running = False
        before_time = self.clock()  # Used for the debugger
        timeout = 0.0
        if self.run_ready:
            # If there is work to do immediately (tasks to run) we
            # can't wait.
            # TODO: This could cause I/O starvation in highly concurrent
            # environments: maybe a more convoluted scheduling strategy
            # where I/O timeouts can only be skipped n times before a
            # mandatory x-second timeout occurs is needed? It should of
            # course take deadlines into account so that timeouts are
            # always delivered in a timely manner and tasks awake from
            # sleeping at the right moment
            timeout = 0.0
        elif self.paused:
            # If there are asleep tasks or deadlines, wait until the closest date
            timeout = self.paused.get_closest_deadline() - self.clock()
        # NOTE(review): with no ready and no sleeping tasks the timeout stays
        # 0.0, so the selector is only polled and run() keeps spinning until
        # some I/O arrives. A blocking timeout would avoid that, but might
        # also delay a deferred Ctrl+C (see run()) — confirm before changing
        self.debugger.before_io(timeout)
        # Tasks already put on the ready queue during this call: a task that
        # registered for both reading and writing on the same resource must
        # not be scheduled (and thus resumed) twice
        scheduled = set()
        # Get resources that are ready and schedule their tasks
        for key, ready_events in self.selector.select(timeout):
            # key.data maps an event type (EVENT_READ/EVENT_WRITE) to the
            # task that registered for it
            for event, task in key.data.items():
                if not event & ready_events or task in scheduled:
                    # Only wake up tasks whose event actually fired
                    # (think one reader and one writer on the same socket)
                    continue
                # Since we don't unschedule I/O
                # resources after every operation,
                # we may hold on to a socket while
                # its owner task is sleeping on some
                # synchronization primitive: if we
                # rescheduled it now, it would cause
                # all sort of nonsense! So we only
                # schedule tasks waiting for I/O to happen
                if task.state == TaskState.IO:
                    self.run_ready.append(task)  # Resource ready? Schedule its task
                    scheduled.add(task)
        self.debugger.after_io(self.clock() - before_time)

    def awake_tasks(self):
        """
        Reschedules paused tasks if their deadline
        has elapsed
        """

        self._running = False
        while self.paused and self.paused.get_closest_deadline() <= self.clock():
            # Reschedules tasks when their deadline has elapsed
            task, _ = self.paused.get()
            slept = self.clock() - task.paused_when
            self.run_ready.append(task)
            task.paused_when = 0
            task.next_deadline = 0
            self.debugger.after_sleep(task, slept)

    def reschedule_running(self):
        """
        Reschedules the currently running task
        """

        self.run_ready.append(self.current_task)

    def schedule(self, task: Task):
        """
        Schedules a task that was previously
        suspended, then reschedules the calling
        (currently running) task as well
        """

        self.run_ready.append(task)
        self.reschedule_running()

    def suspend(self):
        """
        Suspends execution of the current task. This is basically
        a do-nothing method, since it will not reschedule the task
        before returning. The task will stay suspended until
        something else outside the loop reschedules it (possibly
        never)
        """

        self.current_task.state = TaskState.PAUSED

    def set_context(self, ctx: TaskPool):
        """
        Sets the current task context, nesting it inside
        the current pool (if any)
        """

        self.debugger.on_context_creation(ctx)
        self.current_task.context = ctx
        if self.current_pool is None:
            self.current_pool = ctx
        else:
            self.current_pool.inner = ctx
            ctx.outer = self.current_pool
            self.current_pool = ctx
        self.reschedule_running()

    def close_context(self, ctx: TaskPool):
        """
        Closes the given context, making its
        outer pool the current one again
        """

        ctx.inner = None
        self.debugger.on_context_exit(ctx)
        ctx.entry_point.context = None
        self.current_pool = ctx.outer
        self.reschedule_running()

    def set_scope(self, scope: TaskScope):
        """
        Sets the current task scope, nesting it inside
        the current scope (if any)
        """

        if self.current_scope is None:
            self.current_scope = scope
        else:
            self.current_scope.inner = scope
            scope.outer = self.current_scope
            self.current_scope = scope
        self.reschedule_running()

    def close_scope(self, scope: TaskScope):
        """
        Closes the given scope, making its
        outer scope the current one again
        """

        scope.inner = None
        self.current_scope = scope.outer
        self.reschedule_running()

    def get_current_scope(self):
        """
        Sends the current task scope back to the
        running task and reschedules it
        """

        self.data[self.current_task] = self.current_scope
        self.reschedule_running()

    def run_task_step(self):
        """
        Runs a single step for the current task.
        A step ends when the task awaits any of
        our primitives or async methods.
        Note that this method does NOT catch any
        errors arising from tasks, nor does it take
        StopIteration or Cancelled exceptions into
        account
        """

        # Sets the currently running task
        self.current_task = self.run_ready.popleft()
        while self.current_task.done():
            # We make sure not to schedule
            # any terminated tasks. Might want
            # to eventually get rid of this code,
            # but for now it does the job
            if not self.run_ready:
                # We'll let run() handle the I/O
                # or the shutdown if necessary, as
                # there are no more runnable tasks
                return
            self.current_task = self.run_ready.popleft()
        self._running = True
        _runner = self.current_task.run
        _data = [self.data.pop(self.current_task, None)]
        if exc := self.current_task.pending_exception:
            # A pending exception takes priority over any data
            # we were supposed to send back to the task
            self.current_task.pending_exception = None
            _runner = partial(self.current_task.throw, exc)
            _data = []
        # Some debugging and internal chatter here
        self.current_task.steps += 1
        self.current_task.state = TaskState.RUN
        self.debugger.before_task_step(self.current_task)
        # Run a single step with the calculation (i.e. until a yield
        # somewhere). Tasks are expected to yield (method, args, kwargs)
        method, args, kwargs = _runner(*_data)
        if not hasattr(self, method) or not callable(getattr(self, method)):
            # This if block is meant to be triggered by other async
            # libraries, which most likely have different method names and behaviors
            # compared to us. If you get this exception, and you're 100% sure you're
            # not mixing async primitives from other libraries, then it's a bug!
            # NOTE(review): throw() only returns if the task catches the error;
            # execution then reaches the getattr() call below, which raises
            # AttributeError for the unknown method
            self.current_task.throw(
                InternalError(
                    "Uh oh! Something bad just happened: did you try to mix primitives from other async libraries?"
                )
            )
        # Sneaky method call, thanks to David Beazley for this ;)
        getattr(self, method)(*args, **kwargs)
        self.debugger.after_task_step(self.current_task)

    def setup(self):
        """
        Configures the event loop by installing our SIGINT
        handler, unless a custom one is already in place
        """

        if signal.getsignal(signal.SIGINT) is signal.default_int_handler:
            signal.signal(signal.SIGINT, self._sigint_handler)
        else:
            warnings.warn("aiosched detected a custom signal handler for SIGINT and it won't touch it, but"
                          " keep in mind that Ctrl+C is likely to break!")

    def teardown(self):
        """
        Undoes any modification made by setup()
        """

        # Every access to self._sigint_handler creates a new bound method
        # object, so an identity (is) check never matches the handler that
        # setup() installed. Bound methods compare equal (==) when they wrap
        # the same function and the same instance
        if signal.getsignal(signal.SIGINT) == self._sigint_handler:
            signal.signal(signal.SIGINT, signal.default_int_handler)

    def run(self):
        """
        The event loop's runner function. This method drives
        execution for the entire framework and orchestrates I/O,
        events, sleeping, cancellations and deadlines, but the
        actual functionality for all of that is implemented in
        object wrappers. This keeps the size of this module to
        a minimum while allowing anyone to replace it with their
        own, as long as the system calls required by higher-level
        object wrappers are implemented. If you want to add features
        to the library, don't add them here, but take inspiration
        from the current API (i.e. not depending on any implementation
        detail from the loop aside from system calls)
        """

        while True:
            if self.done():
                # If we're done, which means there are
                # both no paused tasks and no running tasks, we
                # simply tear us down and return to self.start
                self.close()
                break
            elif self._sigint_handled:
                # We got Ctrl+C-ed while not running a task! Rather than
                # picking an arbitrary task to deliver the exception to
                # (which is technically allowed, but hard to test properly
                # in different contexts), we always deliver it to the loop's
                # entry point: it is a legitimate place for KeyboardInterrupt
                # to pop up, and keeping the critical path simple and
                # predictable has priority. trio is similarly lazy here
                # NOTE(review): if the entry point has already finished,
                # run_task_step() skips it and the interrupt is not delivered
                self._sigint_handled = False
                task = self.entry_point
                task.pending_exception = KeyboardInterrupt()
                self.run_ready.appendleft(task)
                self.handle_errors(self.run_task_step)
            elif not self.run_ready:
                # If there are no actively running tasks, we start by
                # checking for I/O. This method will wait for I/O until
                # the closest deadline to avoid starving sleeping tasks
                # or missing deadlines
                # NOTE(review): with nothing registered, nothing sleeping
                # and a current task that is not done (e.g. suspended
                # forever), neither branch below runs and this loop spins
                if self.selector.get_map():
                    self.wait_io()
                if self.paused:
                    # Next we check for deadlines
                    self.awake_tasks()
            else:
                # Otherwise, while there are tasks ready to run, we run them!
                self.handle_errors(self.run_task_step)

    @enable_ki_protection
    def start(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
    ) -> Any:
        """
        Starts the event loop from a synchronous context, using
        func(*args, **kwargs) as the loop's entry point, and returns
        the entry point's result. If the entry point crashed outside
        of any task context and its error should propagate, the
        error is raised instead
        """

        self.setup()
        try:
            # Everything after setup() lives inside the try block, so that
            # teardown() restores the SIGINT handler even if creating the
            # entry point fails (e.g. func is called with the wrong arguments)
            name = getattr(func, "__name__", None) or str(func)
            self.entry_point = Task(name, func(*args, **kwargs))
            self.run_ready.append(self.entry_point)
            self.debugger.on_start()
            self.run()
        finally:
            self.teardown()
        if (
            self.entry_point.exc
            and self.entry_point.context is None
            and self.entry_point.propagate
        ):
            # Contexts already manage exceptions for us,
            # no need to raise it manually. If a context
            # is not used, *then* we can raise the error
            raise self.entry_point.exc
        self.debugger.on_exit()
        return self.entry_point.result

    def io_release(self, resource, internal: bool = False):
        """
        Releases the given resource from our
        selector

        :param resource: The resource to be released
        :param internal: If True, the call comes from the kernel
            itself and the running task is not rescheduled
        """

        if resource in self.selector.get_map():
            self.selector.unregister(resource)
            self.debugger.on_io_unschedule(resource)
        if self.current_task.last_io and resource is self.current_task.last_io[1]:
            # The resource is no longer watched: forget it, or the next
            # perform_io() call on it would wrongly assume it still is
            self.current_task.last_io = None
        if not internal:
            self.reschedule_running()

    def io_release_task(self, task: Task):
        """
        Calls self.io_release in a loop
        for each I/O resource the given task owns
        """

        # Iterate over a snapshot, since we unregister resources as we go
        for key in dict(self.selector.get_map() or {}).values():
            if task not in key.data.values():
                continue
            if len(key.data.values()) == 2:
                a, b = key.data.values()
                if a is not task or b is not task:
                    # The resource is shared with another task:
                    # leave it alone
                    continue
            # NOTE(review): notify_closing() also reschedules the running
            # task; when called from cancel() that task may end up on the
            # ready queue more than once — confirm this is intended
            self.notify_closing(key.fileobj, broken=True, owner=task)
            self.io_release(key.fileobj, internal=True)
        task.last_io = None

    def get_active_io_count(self) -> int:
        """
        Returns the number of streams that are currently
        being used by any active task
        """

        return sum(
            not task.done()
            for key in (self.selector.get_map() or {}).values()
            for task in key.data.values()
        )

    def notify_closing(self, stream, broken: bool = False, owner: Task | None = None):
        """
        Notifies paused tasks that a stream
        is about to be closed. The stream
        itself is not touched and must be
        closed by the caller

        :param stream: The stream that is about to be closed
        :param broken: If True, waiting tasks get ResourceBroken
            instead of ResourceClosed
        :param owner: The task closing the stream, which is spared
            from the exception. Defaults to the current task
        """

        if not broken:
            exc = ResourceClosed("stream has been closed")
        else:
            exc = ResourceBroken("stream might be corrupted")
        owner = owner or self.current_task
        # Snapshot the matching keys (and their tasks) before throwing into
        # anyone, since handling the exception may alter the selector. A
        # closed selector's map may be None, hence the fallback
        affected = [
            key
            for key in (self.selector.get_map() or {}).values()
            if key.fileobj == stream
        ]
        for key in affected:
            for task in list(key.data.values()):
                if task is not owner:
                    # We don't want to raise an error inside
                    # the task that's trying to close the stream!
                    self.handle_errors(partial(task.throw, exc), task)
        self.reschedule_running()

    def cancel(self, task: Task):
        """
        Attempts to cancel the given task or
        schedules cancellation for later if
        it fails
        """

        self.reschedule_running()
        self.paused.discard(task)
        self.io_release_task(task)
        # Put the task at the front of the queue so the step run below picks
        # it first, unless it already terminated because of the throw
        self.run_ready.appendleft(task)
        self.handle_errors(partial(task.throw, Cancelled()), task)
        # NOTE(review): if the task terminated above, run_task_step() skips it
        # and runs the next ready task instead (possibly the canceller)
        self.handle_errors(self.run_task_step)
        if task.state != TaskState.CANCELLED:
            # The task survived: re-arm the cancellation for its next step
            task.pending_exception = Cancelled()

    def throw(self, task, error):
        """
        Throws the given exception into the given task
        by running one of its steps right away, then
        reschedules the running task
        """

        task.pending_exception = error
        self.run_ready.appendleft(task)
        self.handle_errors(self.run_task_step)
        self.reschedule_running()

    def handle_errors(self, func: Callable, task: Task | None = None):
        """
        Convenience method for handling various exceptions
        from tasks. Calls func() and, if it raises, records the
        outcome (result, cancellation or crash) on the affected
        task, releases the I/O resources it owns and wakes up
        whoever is waiting on it

        :param func: The callable to run, usually run_task_step or
            a partial of Task.throw
        :param task: The task func() runs, if known. When None,
            self.current_task at the time the exception is handled is used
        """

        try:
            func()
        except StopIteration as ret:
            # We re-define it because we call run_task_step
            # with this method and that changes the current
            # task
            task = task or self.current_task
            # At the end of the day, coroutines are generator functions with
            # some tricky behaviors, and this is one of them. When a coroutine
            # hits a return statement (either explicit or implicit), it raises
            # a StopIteration exception, which has an attribute named value that
            # represents the return value of the coroutine, if it has one. Of course
            # this exception is not an error, and we should happily keep going after it:
            # most of this code below is just useful for internal/debugging purposes
            task.state = TaskState.FINISHED
            task.result = ret.value
            # Release the resources of the task that actually terminated,
            # which is not necessarily the current one (e.g. when cancel()
            # or notify_closing() pass an explicit task)
            self.io_release_task(task)
            self.wait(task)
        except Cancelled:
            # When a task needs to be cancelled, aiosched tries to do it gracefully
            # first: if the task is paused in either I/O or sleeping, that's perfect.
            # But we also need to cancel a task if it was not sleeping or waiting on
            # any I/O because it could never do so (therefore blocking everything
            # forever). So, when cancellation can't be done right away, we schedule
            # it for the next execution step of the task. aiosched will also make sure
            # to re-raise cancellations at every checkpoint until the task lets the
            # exception propagate into us, because we *really* want the task to be
            # cancelled
            task = task or self.current_task
            task.state = TaskState.CANCELLED
            task.pending_cancellation = False
            self.io_release_task(task)
            self.debugger.after_cancel(task)
            self.wait(task)
        except (Exception, KeyboardInterrupt) as err:
            # Any other exception is caught here
            task = task or self.current_task
            task.exc = err
            task.state = TaskState.CRASHED
            self.io_release_task(task)
            self.debugger.on_exception_raised(task, err)
            self.wait(task)

    def sleep(self, seconds: int | float):
        """
        Puts the current task to sleep for a given amount of seconds
        """

        if seconds:
            self.debugger.before_sleep(self.current_task, seconds)
            self.paused.put(self.current_task, seconds)
        else:
            # When we're called with a timeout of 0, this method acts as a checkpoint
            # that allows aiosched to kick in and do its job without pausing the task's
            # execution for too long. It is recommended to put a couple of checkpoints
            # like these in your code if you see degraded concurrent performance in parts
            # of your code that block the loop
            self.reschedule_running()

    def wait(self, task: Task):
        """
        Makes the current task wait for completion of the given one
        by only rescheduling it once the given task has finished
        executing
        """

        if task is not self.current_task:
            task.joiners.add(self.current_task)
        if task.done():
            # NOTE(review): joiners are never removed here, so calling wait()
            # again on an already finished task reschedules every previous
            # joiner too — confirm whether Task clears them elsewhere
            self.run_ready.extend(task.joiners)

    def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
        """
        Spawns a task from a coroutine function. All positional and keyword arguments
        besides the coroutine function itself are passed to the newly created coroutine.
        The new Task object is sent back to the spawning task
        """

        name = getattr(func, "__name__", None) or repr(func)
        task = Task(name, func(*args, **kwargs))
        # We inject our magic secret variable into the coroutine's stack frame, so
        # we can look it up later
        task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
        task.scope = self.current_scope
        if self.current_pool:
            task.context = self.current_pool
            self.current_pool.tasks.append(task)
        self.data[self.current_task] = task
        self.run_ready.append(task)
        self.reschedule_running()
        self.debugger.on_task_spawn(task)

    def get_current_task(self):
        """
        Returns the current task to an asynchronous
        caller
        """

        self.data[self.current_task] = self.current_task
        self.reschedule_running()

    def perform_io(self, resource, evt_type: int):
        """
        Registers the given resource inside our selector to perform I/O multiplexing

        :param resource: The resource on which a read or write operation
            has to be performed
        :param evt_type: The type of event to perform on the given
            socket, either selectors.EVENT_READ or selectors.EVENT_WRITE
        :type evt_type: int
        """

        task = self.current_task
        task.state = TaskState.IO
        # Since most of the time tasks will perform multiple
        # I/O operations on a given resource, unregistering them
        # every time isn't a sensible approach. A quick and
        # easy optimization to address this problem is to
        # store the last I/O operation that the task performed,
        # together with the resource itself, inside the task
        # object. If the task then tries to perform the same
        # operation on the same resource again, this method then
        # returns immediately as the resource is already being watched
        # by the selector. If the resource is the same, but the
        # event type has changed, then we modify the resource's
        # associated event. In every other case (first I/O operation,
        # or a resource different from the last one) we register it
        last_io = task.last_io
        if last_io:
            if last_io == (evt_type, resource):
                # Selector is already listening for that event on
                # this resource
                return
            if last_io[1] == resource:
                # If the event to listen for has changed we add it to the
                # ones we're listening for. dict.update() returns None, so
                # the new mapping is built separately and handed to modify()
                key = self.selector.get_key(resource)
                self.selector.modify(
                    resource, evt_type | key.events, {**key.data, evt_type: task}
                )
                task.last_io = (evt_type, resource)
                self.debugger.on_io_schedule(resource, evt_type)
                return
        # The task has either registered a new resource or is doing
        # I/O for the first time
        task.last_io = evt_type, resource
        try:
            self.selector.register(resource, evt_type, {evt_type: task})
            self.debugger.on_io_schedule(resource, evt_type)
        except KeyError:
            # The stream is already being used
            key = self.selector.get_key(resource)
            # NOTE(review): key.events may combine EVENT_READ | EVENT_WRITE, in
            # which case this lookup raises KeyError — confirm intended check
            if key.data[key.events] == task:
                # If the task that registered the stream
                # changed their mind on what they want
                # to do with it, who are we to deny their
                # request?
                self.selector.modify(
                    resource,
                    key.events | evt_type,
                    {EVENT_READ: task, EVENT_WRITE: task},
                )
                self.debugger.on_io_schedule(resource, evt_type)
            elif key.events != evt_type:
                # We also modify the event in
                # our selector so that one task can read
                # off a given stream while another one is
                # writing to it
                self.selector.modify(
                    resource,
                    key.events | evt_type,
                    {evt_type: task, key.events: list(key.data.values())[0]},
                )
            else:
                # One task reading and one writing on the same
                # resource is fine (think producer-consumer),
                # but having two tasks reading/writing at the
                # same time can't lead to anything good, better
                # disallow it. We also forget the resource: otherwise a
                # retry of the same operation would hit the early return
                # above and wait for an event that was never registered
                task.last_io = None
                task.throw(
                    ResourceBusy("The resource is being read from/written to by another task")
                )