Simplify and improve Ctrl+C delivery. System task status is now inherited for children of tasks spawned via spawn_system_task

This commit is contained in:
Mattia Giambirtone 2024-03-25 20:57:01 +01:00
parent bfd494a2d7
commit 723efc91fe
5 changed files with 131 additions and 94 deletions

View File

@ -20,23 +20,23 @@ class BaseClock(ABC):
@abstractmethod
def start(self):
return NotImplemented
raise NotImplementedError
@abstractmethod
def setup(self):
return NotImplemented
raise NotImplementedError
@abstractmethod
def teardown(self):
return NotImplemented
raise NotImplementedError
@abstractmethod
def current_time(self):
return NotImplemented
raise NotImplementedError
@abstractmethod
def deadline(self, deadline):
return NotImplemented
raise NotImplementedError
class SchedulingPolicy(ABC):
@ -56,6 +56,8 @@ class SchedulingPolicy(ABC):
policy knows about this task
"""
raise NotImplementedError
@abstractmethod
def has_next_task(self) -> bool:
"""
@ -78,9 +80,20 @@ class SchedulingPolicy(ABC):
def peek_paused_task(self) -> Task | None:
"""
Returns the first paused task in the queue,
if there is any, but doesn't consume it
"""
raise NotImplementedError
@abstractmethod
def peek_next_task(self) -> Task | None:
"""
Returns the first task that is ready to run,
if there is any, but doesn't remove it
"""
raise NotImplementedError
@abstractmethod
def get_paused_task(self) -> Task | None:
"""
@ -88,6 +101,8 @@ class SchedulingPolicy(ABC):
if it exists
"""
raise NotImplementedError
@abstractmethod
def schedule(self, task: Task, front: bool = False):
"""
@ -95,7 +110,7 @@ class SchedulingPolicy(ABC):
the task will be the next one to be scheduled
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def pause(self, task: Task):
@ -103,7 +118,7 @@ class SchedulingPolicy(ABC):
Pauses the given task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def discard(self, task: Task):
@ -111,7 +126,7 @@ class SchedulingPolicy(ABC):
Discards the given task from the policy
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_next_task(self) -> Task | None:
@ -129,7 +144,7 @@ class SchedulingPolicy(ABC):
Returns the closest deadline to be satisfied
"""
return NotImplemented
raise NotImplementedError
class AsyncResource(ABC):
@ -145,7 +160,7 @@ class AsyncResource(ABC):
@abstractmethod
async def close(self):
return NotImplemented
raise NotImplementedError
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
@ -159,7 +174,7 @@ class StreamWriter(AsyncResource, ABC):
@abstractmethod
async def write(self, _data):
return NotImplemented
raise NotImplementedError
class StreamReader(AsyncResource, ABC):
@ -172,7 +187,7 @@ class StreamReader(AsyncResource, ABC):
@abstractmethod
async def _read(self, _size: int = -1):
return NotImplemented
raise NotImplementedError
class Stream(StreamReader, StreamWriter, ABC):
@ -193,7 +208,7 @@ class Stream(StreamReader, StreamWriter, ABC):
Flushes the underlying resource asynchronously
"""
return NotImplemented
raise NotImplementedError
class WriteCloseableStream(Stream, ABC):
@ -228,7 +243,7 @@ class ChannelReader(AsyncResource, ABC):
possibly blocking
"""
return NotImplemented
raise NotImplementedError
def __aiter__(self):
"""
@ -274,7 +289,7 @@ class ChannelWriter(AsyncResource, ABC):
possibly blocking
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def writers(self):
@ -301,7 +316,7 @@ class BaseDebugger(ABC):
loop starts executing
"""
return NotImplemented
raise NotImplementedError
def on_exit(self):
"""
@ -309,7 +324,7 @@ class BaseDebugger(ABC):
loop exits entirely (all tasks completed)
"""
return NotImplemented
raise NotImplementedError
def on_task_spawn(self, task: Task):
"""
@ -320,7 +335,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def on_task_exit(self, task: Task):
"""
@ -330,7 +345,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def before_task_step(self, task: Task):
"""
@ -341,7 +356,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def after_task_step(self, task: Task):
"""
@ -352,7 +367,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def before_sleep(self, task: Task, seconds: float):
"""
@ -366,7 +381,7 @@ class BaseDebugger(ABC):
:type seconds: int
"""
return NotImplemented
raise NotImplementedError
def after_sleep(self, task: Task, seconds: float):
"""
@ -380,7 +395,7 @@ class BaseDebugger(ABC):
:type seconds: float
"""
return NotImplemented
raise NotImplementedError
def before_io(self, timeout: float):
"""
@ -393,7 +408,7 @@ class BaseDebugger(ABC):
:type timeout: float
"""
return NotImplemented
raise NotImplementedError
def after_io(self, timeout: float):
"""
@ -406,7 +421,7 @@ class BaseDebugger(ABC):
:type timeout: float
"""
return NotImplemented
raise NotImplementedError
def before_cancel(self, task: Task):
"""
@ -417,7 +432,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def after_cancel(self, task: Task) -> object:
"""
@ -428,7 +443,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
return NotImplemented
raise NotImplementedError
def on_exception_raised(self, task: Task, exc: BaseException):
"""
@ -441,7 +456,7 @@ class BaseDebugger(ABC):
:type exc: BaseException
"""
return NotImplemented
raise NotImplementedError
def on_io_schedule(self, stream, event: str):
"""
@ -450,7 +465,7 @@ class BaseDebugger(ABC):
event loop
"""
return NotImplemented
raise NotImplementedError
def on_io_unschedule(self, stream):
"""
@ -458,7 +473,7 @@ class BaseDebugger(ABC):
is unregistered from the loop
"""
return NotImplemented
raise NotImplementedError
class BaseIOManager(ABC):
@ -473,7 +488,7 @@ class BaseIOManager(ABC):
when data is ready to be read/written
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def request_read(self, rsc, task: Task):
@ -483,7 +498,7 @@ class BaseIOManager(ABC):
task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def request_write(self, rsc, task: Task):
@ -493,7 +508,7 @@ class BaseIOManager(ABC):
task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def pending(self):
@ -503,7 +518,7 @@ class BaseIOManager(ABC):
in the manager
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def release(self, resource):
@ -513,7 +528,7 @@ class BaseIOManager(ABC):
closed!
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def release_task(self, task: Task):
@ -525,7 +540,7 @@ class BaseIOManager(ABC):
not unschedule it for those as well
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_reader(self, rsc):
@ -534,7 +549,7 @@ class BaseIOManager(ABC):
resource, if any (None otherwise)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_writer(self, rsc):
@ -543,7 +558,7 @@ class BaseIOManager(ABC):
resource, if any (None otherwise)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_readers(self) -> tuple["structio.io.FdWrapper", Task]:
@ -552,7 +567,7 @@ class BaseIOManager(ABC):
by the manager for read events
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_writers(self) -> tuple["structio.io.FdWrapper", Task]:
@ -561,7 +576,7 @@ class BaseIOManager(ABC):
by the manager for write events
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def close(self):
@ -583,7 +598,7 @@ class SignalManager(ABC):
Installs the signal handler
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def uninstall(self):
@ -591,7 +606,7 @@ class SignalManager(ABC):
Uninstalls the signal handler
"""
return NotImplemented
raise NotImplementedError
class BaseKernel(ABC):
@ -610,8 +625,8 @@ class BaseKernel(ABC):
):
self.clock = clock
self.current_task: Task | None = None
self.current_pool: type["structio.TaskPool"] | None = None
self.current_scope: type["structio.TaskScope"] | None = None
self.current_pool: "structio.TaskPool" = None # noqa
self.current_scope: structio.TaskScope = None # noqa
self.tools: list[BaseDebugger] = tools or []
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
self.io_manager = io_manager
@ -619,7 +634,7 @@ class BaseKernel(ABC):
self.entry_point: Task | None = None
self.policy = policy
# Pool for system tasks
self.pool: type["structio.TaskPool"] | None = None
self.pool: "structio.TaskPool" = None # noqa
def get_system_pool(self) -> "structio.TaskPool":
"""
@ -642,7 +657,7 @@ class BaseKernel(ABC):
the current task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def wait_writable(self, resource: AsyncResource):
@ -651,7 +666,7 @@ class BaseKernel(ABC):
the current task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def release_resource(self, resource: AsyncResource):
@ -659,7 +674,7 @@ class BaseKernel(ABC):
Releases the given resource from the scheduler
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def notify_closing(
@ -670,7 +685,7 @@ class BaseKernel(ABC):
is about to be closed and can be unscheduled
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def cancel_task(self, task: Task):
@ -678,7 +693,7 @@ class BaseKernel(ABC):
Cancels the given task individually
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def signal_notify(self, sig: int, frame: FrameType):
@ -687,21 +702,25 @@ class BaseKernel(ABC):
received
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args,
ki_protected: bool = False,
pool: "structio.TaskPool" = None,
system_task: bool = False,
entry_point: bool = False) -> Task:
"""
Readies a task for execution. All positional arguments are passed
to the given coroutine (for keyword arguments, use `functools.partial`)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def spawn_system_task(
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
):
) -> Task:
"""
Spawns a system task. System tasks run in a special internal
task pool and begin execution in a scope with Ctrl+C protection
@ -713,7 +732,7 @@ class BaseKernel(ABC):
used
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def get_closest_deadline(self) -> Any:
@ -721,7 +740,7 @@ class BaseKernel(ABC):
Returns the closest deadline to be satisfied
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def setup(self):
@ -746,7 +765,7 @@ class BaseKernel(ABC):
task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def reschedule(self, task: Task):
@ -755,7 +774,7 @@ class BaseKernel(ABC):
execution
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def event(self, evt_name, *args):
@ -764,7 +783,7 @@ class BaseKernel(ABC):
in the event loop
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def run(self):
@ -773,7 +792,7 @@ class BaseKernel(ABC):
of the "event loop"
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def sleep(self, amount):
@ -782,7 +801,7 @@ class BaseKernel(ABC):
time as defined by our current clock
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def suspend(self):
@ -790,7 +809,7 @@ class BaseKernel(ABC):
Suspends the current task until it is rescheduled
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def init_scope(self, scope):
@ -799,7 +818,7 @@ class BaseKernel(ABC):
TaskScope.__enter__)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def close_scope(self, scope):
@ -808,7 +827,7 @@ class BaseKernel(ABC):
TaskScope.__exit__)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def init_pool(self, pool):
@ -817,7 +836,7 @@ class BaseKernel(ABC):
TaskPool.__aenter__)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def close_pool(self, pool):
@ -826,7 +845,7 @@ class BaseKernel(ABC):
TaskPool.__aexit__)
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def cancel_scope(self, scope):
@ -834,7 +853,7 @@ class BaseKernel(ABC):
Cancels the given scope
"""
return NotImplemented
raise NotImplementedError
def start(self, entry_point: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
"""
@ -847,9 +866,9 @@ class BaseKernel(ABC):
self.setup()
self.event("on_start")
self.pool: "structio.TaskPool"
self.current_pool = self.pool
self.entry_point = self.spawn(entry_point, *args)
self.entry_point = self.spawn(entry_point, *args, entry_point=True)
assert not self.entry_point.is_system_task
self.current_pool.scope.owner = self.entry_point
self.entry_point.pool = self.current_pool
self.current_pool.entry_point = self.entry_point
@ -864,13 +883,24 @@ class BaseKernel(ABC):
self.event("on_exit")
return self.entry_point.result
@abstractmethod
def raise_ki(self, task: Task | None = None):
"""
Raises a KeyboardInterrupt exception into a
task: If one is passed explicitly, the exception
is thrown there, otherwise a suitable task is
awakened and thrown into
"""
raise NotImplementedError
@abstractmethod
def done(self):
"""
Returns whether the loop has work to do
"""
return NotImplemented
raise NotImplementedError
def close(self, force: bool = False):
"""
@ -904,7 +934,7 @@ class BaseKernel(ABC):
unique identifier that can be used to unregister the shutdown task
"""
return NotImplemented
raise NotImplementedError
@abstractmethod
def remove_shutdown_task(self, ident: Any) -> bool:
@ -913,4 +943,4 @@ class BaseKernel(ABC):
Returns whether a task was actually removed
"""
return NotImplemented
raise NotImplementedError

View File

@ -135,6 +135,7 @@ class DefaultKernel(BaseKernel):
ki_protected: bool = False,
pool: TaskPool = None,
system_task: bool = False,
entry_point: bool = False
):
if isinstance(func, partial):
name = func.func.__name__ or repr(func.func)
@ -145,6 +146,8 @@ class DefaultKernel(BaseKernel):
pool = self.pool
else:
pool = self.current_pool
if pool is self.pool and not entry_point:
system_task = True
task = Task(name, func(*args), pool.scope, pool, is_system_task=system_task)
pool.scope.tasks.append(task)
# We inject our magic secret variable into the coroutine's stack frame, so
@ -196,7 +199,7 @@ class DefaultKernel(BaseKernel):
elif self._sigint_handled and not critical_section(
self.current_task.coroutine.cr_frame
):
self._raise_ki(self.current_task)
self.raise_ki(self.current_task)
return
self.event("before_task_step", self.current_task)
self.current_task.state = TaskState.RUNNING
@ -235,7 +238,7 @@ class DefaultKernel(BaseKernel):
def check_cancelled(self, schedule: bool = True):
if self._sigint_handled:
self._raise_ki()
self.raise_ki()
elif self.current_task.pending_cancellation:
self.current_task: Task
self.cancel_task(self.current_task)
@ -300,7 +303,7 @@ class DefaultKernel(BaseKernel):
throw KeyboardInterrupt into
"""
if self.policy.has_next_task():
if self.policy.has_next_task() and not self.policy.peek_next_task().is_system_task:
return self.policy.get_next_task()
if self.policy.has_paused_task():
return self.policy.get_paused_task()
@ -308,7 +311,7 @@ class DefaultKernel(BaseKernel):
return self.entry_point
raise StructIOException("unable to find a task to throw KeyboardInterrupt into")
def _raise_ki(self, task: Task | None = None):
def raise_ki(self, task: Task | None = None):
"""
Raises a KeyboardInterrupt exception into a
task: If one is passed explicitly, the exception
@ -317,7 +320,10 @@ class DefaultKernel(BaseKernel):
"""
self._sigint_handled = False
self.throw(task or self._pick_ki_task(), KeyboardInterrupt())
task = task or self._pick_ki_task()
self.throw(task, KeyboardInterrupt())
if task.done():
self.close()
def _tick(self):
"""
@ -325,7 +331,7 @@ class DefaultKernel(BaseKernel):
"""
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self._raise_ki()
self.raise_ki()
self.wakeup()
self.check_scopes()
self.step()
@ -356,7 +362,8 @@ class DefaultKernel(BaseKernel):
assert self.pool.scope.attempted_cancel
assert self.pool.scope.cancelled
assert self.pool.done()
# Reset some stuff
# Reset some stuff. Should probably initialize
# a new task scope, but I'm lazy and this works
self.pool.scope.attempted_cancel = False
self.pool.scope.cancelled = False
if self.entry_point.state == TaskState.FINISHED:
@ -543,6 +550,8 @@ class DefaultKernel(BaseKernel):
self.reschedule(self.current_task)
else:
self.cancel_task(task)
if scope.done():
scope.cancelled = True
def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool

View File

@ -1,7 +1,7 @@
from structio.abc import SignalManager
from structio.util.ki import currently_protected
from structio.signals import set_signal_handler
from structio.core.run import current_loop
from structio.core.run import current_loop, current_task
from types import FrameType
import warnings
import signal
@ -20,8 +20,7 @@ class SigIntManager(SignalManager):
if currently_protected():
current_loop().signal_notify(sig, frame)
else:
current_loop().reschedule(current_loop().entry_point)
current_loop().throw(current_loop().entry_point, KeyboardInterrupt())
current_loop().raise_ki()
def install(self):
if signal.getsignal(signal.SIGINT) != signal.default_int_handler:

View File

@ -32,19 +32,16 @@ class FIFOPolicy(SchedulingPolicy):
return self.run_queue.popleft()
def peek_paused_task(self) -> Task | None:
"""
Returns the first paused task in the queue,
if there is any, but doesn't remove it
"""
if not self.has_paused_task():
return
return self.paused.peek()
def get_paused_task(self) -> Task | None:
"""
Dequeues the first paused task in the queue,
if it exists
"""
def peek_next_task(self) -> Task | None:
if not self.has_next_task():
return
return self.run_queue[0]
def get_paused_task(self) -> Task | None:
if not self.paused:
return None
return self.paused.get()[0]

View File

@ -68,12 +68,14 @@ async def main_python():
# run the given target function
p = structio.parallel.PythonProcess(target=foo)
p.start()
await p.wait_started()
await p.wait()
print("[main] Pyhon process test complete")
if __name__ == "__main__":
structio.run(main_simple, "owo")
#structio.run(main_simple, "owo")
structio.run(main_limiter)
structio.run(main_python)
#structio.run(main_python)