Simplify and improve Ctrl+C delivery. System task status is now inherited for children of tasks spawned via spawn_system_task
This commit is contained in:
parent
bfd494a2d7
commit
723efc91fe
174
structio/abc.py
174
structio/abc.py
|
@ -20,23 +20,23 @@ class BaseClock(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def start(self):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def setup(self):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def current_time(self):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def deadline(self, deadline):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SchedulingPolicy(ABC):
|
||||
|
@ -56,6 +56,8 @@ class SchedulingPolicy(ABC):
|
|||
policy knows about this task
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def has_next_task(self) -> bool:
|
||||
"""
|
||||
|
@ -78,9 +80,20 @@ class SchedulingPolicy(ABC):
|
|||
def peek_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the first paused task in the queue,
|
||||
if there is any, but doesn't consume it
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def peek_next_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the first task that is ready to run,
|
||||
if there is any, but doesn't remove it
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_paused_task(self) -> Task | None:
|
||||
"""
|
||||
|
@ -88,6 +101,8 @@ class SchedulingPolicy(ABC):
|
|||
if it exists
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def schedule(self, task: Task, front: bool = False):
|
||||
"""
|
||||
|
@ -95,7 +110,7 @@ class SchedulingPolicy(ABC):
|
|||
the task will be the next one to be scheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def pause(self, task: Task):
|
||||
|
@ -103,7 +118,7 @@ class SchedulingPolicy(ABC):
|
|||
Pauses the given task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def discard(self, task: Task):
|
||||
|
@ -111,7 +126,7 @@ class SchedulingPolicy(ABC):
|
|||
Discards the given task from the policy
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_next_task(self) -> Task | None:
|
||||
|
@ -129,7 +144,7 @@ class SchedulingPolicy(ABC):
|
|||
Returns the closest deadline to be satisfied
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncResource(ABC):
|
||||
|
@ -145,7 +160,7 @@ class AsyncResource(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
@ -159,7 +174,7 @@ class StreamWriter(AsyncResource, ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def write(self, _data):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamReader(AsyncResource, ABC):
|
||||
|
@ -172,7 +187,7 @@ class StreamReader(AsyncResource, ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def _read(self, _size: int = -1):
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Stream(StreamReader, StreamWriter, ABC):
|
||||
|
@ -193,7 +208,7 @@ class Stream(StreamReader, StreamWriter, ABC):
|
|||
Flushes the underlying resource asynchronously
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteCloseableStream(Stream, ABC):
|
||||
|
@ -228,7 +243,7 @@ class ChannelReader(AsyncResource, ABC):
|
|||
possibly blocking
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def __aiter__(self):
|
||||
"""
|
||||
|
@ -274,7 +289,7 @@ class ChannelWriter(AsyncResource, ABC):
|
|||
possibly blocking
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def writers(self):
|
||||
|
@ -301,7 +316,7 @@ class BaseDebugger(ABC):
|
|||
loop starts executing
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_exit(self):
|
||||
"""
|
||||
|
@ -309,7 +324,7 @@ class BaseDebugger(ABC):
|
|||
loop exits entirely (all tasks completed)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_task_spawn(self, task: Task):
|
||||
"""
|
||||
|
@ -320,7 +335,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_task_exit(self, task: Task):
|
||||
"""
|
||||
|
@ -330,7 +345,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def before_task_step(self, task: Task):
|
||||
"""
|
||||
|
@ -341,7 +356,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def after_task_step(self, task: Task):
|
||||
"""
|
||||
|
@ -352,7 +367,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def before_sleep(self, task: Task, seconds: float):
|
||||
"""
|
||||
|
@ -366,7 +381,7 @@ class BaseDebugger(ABC):
|
|||
:type seconds: int
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def after_sleep(self, task: Task, seconds: float):
|
||||
"""
|
||||
|
@ -380,7 +395,7 @@ class BaseDebugger(ABC):
|
|||
:type seconds: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def before_io(self, timeout: float):
|
||||
"""
|
||||
|
@ -393,7 +408,7 @@ class BaseDebugger(ABC):
|
|||
:type timeout: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def after_io(self, timeout: float):
|
||||
"""
|
||||
|
@ -406,7 +421,7 @@ class BaseDebugger(ABC):
|
|||
:type timeout: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def before_cancel(self, task: Task):
|
||||
"""
|
||||
|
@ -417,7 +432,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def after_cancel(self, task: Task) -> object:
|
||||
"""
|
||||
|
@ -428,7 +443,7 @@ class BaseDebugger(ABC):
|
|||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_exception_raised(self, task: Task, exc: BaseException):
|
||||
"""
|
||||
|
@ -441,7 +456,7 @@ class BaseDebugger(ABC):
|
|||
:type exc: BaseException
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_io_schedule(self, stream, event: str):
|
||||
"""
|
||||
|
@ -450,7 +465,7 @@ class BaseDebugger(ABC):
|
|||
event loop
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def on_io_unschedule(self, stream):
|
||||
"""
|
||||
|
@ -458,7 +473,7 @@ class BaseDebugger(ABC):
|
|||
is unregistered from the loop
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIOManager(ABC):
|
||||
|
@ -473,7 +488,7 @@ class BaseIOManager(ABC):
|
|||
when data is ready to be read/written
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def request_read(self, rsc, task: Task):
|
||||
|
@ -483,7 +498,7 @@ class BaseIOManager(ABC):
|
|||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def request_write(self, rsc, task: Task):
|
||||
|
@ -493,7 +508,7 @@ class BaseIOManager(ABC):
|
|||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def pending(self):
|
||||
|
@ -503,7 +518,7 @@ class BaseIOManager(ABC):
|
|||
in the manager
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def release(self, resource):
|
||||
|
@ -513,7 +528,7 @@ class BaseIOManager(ABC):
|
|||
closed!
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def release_task(self, task: Task):
|
||||
|
@ -525,7 +540,7 @@ class BaseIOManager(ABC):
|
|||
not unschedule it for those as well
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_reader(self, rsc):
|
||||
|
@ -534,7 +549,7 @@ class BaseIOManager(ABC):
|
|||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_writer(self, rsc):
|
||||
|
@ -543,7 +558,7 @@ class BaseIOManager(ABC):
|
|||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_readers(self) -> tuple["structio.io.FdWrapper", Task]:
|
||||
|
@ -552,7 +567,7 @@ class BaseIOManager(ABC):
|
|||
by the manager for read events
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_writers(self) -> tuple["structio.io.FdWrapper", Task]:
|
||||
|
@ -561,7 +576,7 @@ class BaseIOManager(ABC):
|
|||
by the manager for write events
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
|
@ -583,7 +598,7 @@ class SignalManager(ABC):
|
|||
Installs the signal handler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def uninstall(self):
|
||||
|
@ -591,7 +606,7 @@ class SignalManager(ABC):
|
|||
Uninstalls the signal handler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseKernel(ABC):
|
||||
|
@ -610,8 +625,8 @@ class BaseKernel(ABC):
|
|||
):
|
||||
self.clock = clock
|
||||
self.current_task: Task | None = None
|
||||
self.current_pool: type["structio.TaskPool"] | None = None
|
||||
self.current_scope: type["structio.TaskScope"] | None = None
|
||||
self.current_pool: "structio.TaskPool" = None # noqa
|
||||
self.current_scope: structio.TaskScope = None # noqa
|
||||
self.tools: list[BaseDebugger] = tools or []
|
||||
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
||||
self.io_manager = io_manager
|
||||
|
@ -619,7 +634,7 @@ class BaseKernel(ABC):
|
|||
self.entry_point: Task | None = None
|
||||
self.policy = policy
|
||||
# Pool for system tasks
|
||||
self.pool: type["structio.TaskPool"] | None = None
|
||||
self.pool: "structio.TaskPool" = None # noqa
|
||||
|
||||
def get_system_pool(self) -> "structio.TaskPool":
|
||||
"""
|
||||
|
@ -642,7 +657,7 @@ class BaseKernel(ABC):
|
|||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def wait_writable(self, resource: AsyncResource):
|
||||
|
@ -651,7 +666,7 @@ class BaseKernel(ABC):
|
|||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def release_resource(self, resource: AsyncResource):
|
||||
|
@ -659,7 +674,7 @@ class BaseKernel(ABC):
|
|||
Releases the given resource from the scheduler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def notify_closing(
|
||||
|
@ -670,7 +685,7 @@ class BaseKernel(ABC):
|
|||
is about to be closed and can be unscheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def cancel_task(self, task: Task):
|
||||
|
@ -678,7 +693,7 @@ class BaseKernel(ABC):
|
|||
Cancels the given task individually
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def signal_notify(self, sig: int, frame: FrameType):
|
||||
|
@ -687,21 +702,25 @@ class BaseKernel(ABC):
|
|||
received
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args,
|
||||
ki_protected: bool = False,
|
||||
pool: "structio.TaskPool" = None,
|
||||
system_task: bool = False,
|
||||
entry_point: bool = False) -> Task:
|
||||
"""
|
||||
Readies a task for execution. All positional arguments are passed
|
||||
to the given coroutine (for keyword arguments, use `functools.partial`)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def spawn_system_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
):
|
||||
) -> Task:
|
||||
"""
|
||||
Spawns a system task. System tasks run in a special internal
|
||||
task pool and begin execution in a scope with Ctrl+C protection
|
||||
|
@ -713,7 +732,7 @@ class BaseKernel(ABC):
|
|||
used
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_closest_deadline(self) -> Any:
|
||||
|
@ -721,7 +740,7 @@ class BaseKernel(ABC):
|
|||
Returns the closest deadline to be satisfied
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def setup(self):
|
||||
|
@ -746,7 +765,7 @@ class BaseKernel(ABC):
|
|||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reschedule(self, task: Task):
|
||||
|
@ -755,7 +774,7 @@ class BaseKernel(ABC):
|
|||
execution
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def event(self, evt_name, *args):
|
||||
|
@ -764,7 +783,7 @@ class BaseKernel(ABC):
|
|||
in the event loop
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
|
@ -773,7 +792,7 @@ class BaseKernel(ABC):
|
|||
of the "event loop"
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sleep(self, amount):
|
||||
|
@ -782,7 +801,7 @@ class BaseKernel(ABC):
|
|||
time as defined by our current clock
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def suspend(self):
|
||||
|
@ -790,7 +809,7 @@ class BaseKernel(ABC):
|
|||
Suspends the current task until it is rescheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def init_scope(self, scope):
|
||||
|
@ -799,7 +818,7 @@ class BaseKernel(ABC):
|
|||
TaskScope.__enter__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close_scope(self, scope):
|
||||
|
@ -808,7 +827,7 @@ class BaseKernel(ABC):
|
|||
TaskScope.__exit__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def init_pool(self, pool):
|
||||
|
@ -817,7 +836,7 @@ class BaseKernel(ABC):
|
|||
TaskPool.__aenter__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close_pool(self, pool):
|
||||
|
@ -826,7 +845,7 @@ class BaseKernel(ABC):
|
|||
TaskPool.__aexit__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def cancel_scope(self, scope):
|
||||
|
@ -834,7 +853,7 @@ class BaseKernel(ABC):
|
|||
Cancels the given scope
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def start(self, entry_point: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
"""
|
||||
|
@ -847,9 +866,9 @@ class BaseKernel(ABC):
|
|||
|
||||
self.setup()
|
||||
self.event("on_start")
|
||||
self.pool: "structio.TaskPool"
|
||||
self.current_pool = self.pool
|
||||
self.entry_point = self.spawn(entry_point, *args)
|
||||
self.entry_point = self.spawn(entry_point, *args, entry_point=True)
|
||||
assert not self.entry_point.is_system_task
|
||||
self.current_pool.scope.owner = self.entry_point
|
||||
self.entry_point.pool = self.current_pool
|
||||
self.current_pool.entry_point = self.entry_point
|
||||
|
@ -864,13 +883,24 @@ class BaseKernel(ABC):
|
|||
self.event("on_exit")
|
||||
return self.entry_point.result
|
||||
|
||||
@abstractmethod
|
||||
def raise_ki(self, task: Task | None = None):
|
||||
"""
|
||||
Raises a KeyboardInterrupt exception into a
|
||||
task: If one is passed explicitly, the exception
|
||||
is thrown there, otherwise a suitable task is
|
||||
awakened and thrown into
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the loop has work to do
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self, force: bool = False):
|
||||
"""
|
||||
|
@ -904,7 +934,7 @@ class BaseKernel(ABC):
|
|||
unique identifier that can be used to unregister the shutdown task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def remove_shutdown_task(self, ident: Any) -> bool:
|
||||
|
@ -913,4 +943,4 @@ class BaseKernel(ABC):
|
|||
Returns whether a task was actually removed
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -135,6 +135,7 @@ class DefaultKernel(BaseKernel):
|
|||
ki_protected: bool = False,
|
||||
pool: TaskPool = None,
|
||||
system_task: bool = False,
|
||||
entry_point: bool = False
|
||||
):
|
||||
if isinstance(func, partial):
|
||||
name = func.func.__name__ or repr(func.func)
|
||||
|
@ -145,6 +146,8 @@ class DefaultKernel(BaseKernel):
|
|||
pool = self.pool
|
||||
else:
|
||||
pool = self.current_pool
|
||||
if pool is self.pool and not entry_point:
|
||||
system_task = True
|
||||
task = Task(name, func(*args), pool.scope, pool, is_system_task=system_task)
|
||||
pool.scope.tasks.append(task)
|
||||
# We inject our magic secret variable into the coroutine's stack frame, so
|
||||
|
@ -196,7 +199,7 @@ class DefaultKernel(BaseKernel):
|
|||
elif self._sigint_handled and not critical_section(
|
||||
self.current_task.coroutine.cr_frame
|
||||
):
|
||||
self._raise_ki(self.current_task)
|
||||
self.raise_ki(self.current_task)
|
||||
return
|
||||
self.event("before_task_step", self.current_task)
|
||||
self.current_task.state = TaskState.RUNNING
|
||||
|
@ -235,7 +238,7 @@ class DefaultKernel(BaseKernel):
|
|||
|
||||
def check_cancelled(self, schedule: bool = True):
|
||||
if self._sigint_handled:
|
||||
self._raise_ki()
|
||||
self.raise_ki()
|
||||
elif self.current_task.pending_cancellation:
|
||||
self.current_task: Task
|
||||
self.cancel_task(self.current_task)
|
||||
|
@ -300,7 +303,7 @@ class DefaultKernel(BaseKernel):
|
|||
throw KeyboardInterrupt into
|
||||
"""
|
||||
|
||||
if self.policy.has_next_task():
|
||||
if self.policy.has_next_task() and not self.policy.peek_next_task().is_system_task:
|
||||
return self.policy.get_next_task()
|
||||
if self.policy.has_paused_task():
|
||||
return self.policy.get_paused_task()
|
||||
|
@ -308,7 +311,7 @@ class DefaultKernel(BaseKernel):
|
|||
return self.entry_point
|
||||
raise StructIOException("unable to find a task to throw KeyboardInterrupt into")
|
||||
|
||||
def _raise_ki(self, task: Task | None = None):
|
||||
def raise_ki(self, task: Task | None = None):
|
||||
"""
|
||||
Raises a KeyboardInterrupt exception into a
|
||||
task: If one is passed explicitly, the exception
|
||||
|
@ -317,7 +320,10 @@ class DefaultKernel(BaseKernel):
|
|||
"""
|
||||
|
||||
self._sigint_handled = False
|
||||
self.throw(task or self._pick_ki_task(), KeyboardInterrupt())
|
||||
task = task or self._pick_ki_task()
|
||||
self.throw(task, KeyboardInterrupt())
|
||||
if task.done():
|
||||
self.close()
|
||||
|
||||
def _tick(self):
|
||||
"""
|
||||
|
@ -325,7 +331,7 @@ class DefaultKernel(BaseKernel):
|
|||
"""
|
||||
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self._raise_ki()
|
||||
self.raise_ki()
|
||||
self.wakeup()
|
||||
self.check_scopes()
|
||||
self.step()
|
||||
|
@ -356,7 +362,8 @@ class DefaultKernel(BaseKernel):
|
|||
assert self.pool.scope.attempted_cancel
|
||||
assert self.pool.scope.cancelled
|
||||
assert self.pool.done()
|
||||
# Reset some stuff
|
||||
# Reset some stuff. Should probably initialize
|
||||
# a new task scope, but I'm lazy and this works
|
||||
self.pool.scope.attempted_cancel = False
|
||||
self.pool.scope.cancelled = False
|
||||
if self.entry_point.state == TaskState.FINISHED:
|
||||
|
@ -543,6 +550,8 @@ class DefaultKernel(BaseKernel):
|
|||
self.reschedule(self.current_task)
|
||||
else:
|
||||
self.cancel_task(task)
|
||||
if scope.done():
|
||||
scope.cancelled = True
|
||||
|
||||
def init_pool(self, pool: TaskPool):
|
||||
pool.outer = self.current_pool
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from structio.abc import SignalManager
|
||||
from structio.util.ki import currently_protected
|
||||
from structio.signals import set_signal_handler
|
||||
from structio.core.run import current_loop
|
||||
from structio.core.run import current_loop, current_task
|
||||
from types import FrameType
|
||||
import warnings
|
||||
import signal
|
||||
|
@ -20,8 +20,7 @@ class SigIntManager(SignalManager):
|
|||
if currently_protected():
|
||||
current_loop().signal_notify(sig, frame)
|
||||
else:
|
||||
current_loop().reschedule(current_loop().entry_point)
|
||||
current_loop().throw(current_loop().entry_point, KeyboardInterrupt())
|
||||
current_loop().raise_ki()
|
||||
|
||||
def install(self):
|
||||
if signal.getsignal(signal.SIGINT) != signal.default_int_handler:
|
||||
|
|
|
@ -32,19 +32,16 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
return self.run_queue.popleft()
|
||||
|
||||
def peek_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the first paused task in the queue,
|
||||
if there is any, but doesn't remove it
|
||||
"""
|
||||
|
||||
if not self.has_paused_task():
|
||||
return
|
||||
return self.paused.peek()
|
||||
|
||||
def get_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Dequeues the first paused task in the queue,
|
||||
if it exists
|
||||
"""
|
||||
def peek_next_task(self) -> Task | None:
|
||||
if not self.has_next_task():
|
||||
return
|
||||
return self.run_queue[0]
|
||||
|
||||
def get_paused_task(self) -> Task | None:
|
||||
if not self.paused:
|
||||
return None
|
||||
return self.paused.get()[0]
|
||||
|
|
|
@ -68,12 +68,14 @@ async def main_python():
|
|||
# run the given target function
|
||||
p = structio.parallel.PythonProcess(target=foo)
|
||||
p.start()
|
||||
await p.wait_started()
|
||||
await p.wait()
|
||||
print("[main] Pyhon process test complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
structio.run(main_simple, "owo")
|
||||
#structio.run(main_simple, "owo")
|
||||
structio.run(main_limiter)
|
||||
structio.run(main_python)
|
||||
#structio.run(main_python)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue