Compare commits
4 Commits
360be2750d
...
060d61dc32
Author | SHA1 | Date |
---|---|---|
Mattia Giambirtone | 060d61dc32 | |
Mattia Giambirtone | 5071399431 | |
Mattia Giambirtone | c85e3037e4 | |
Mattia Giambirtone | bc5e0f167f |
|
@ -1,6 +1,7 @@
|
|||
from structio.core import run as _run
|
||||
from typing import Coroutine, Any, Callable
|
||||
from structio.core.kernels.fifo import FIFOKernel
|
||||
from structio.core.kernel import DefaultKernel
|
||||
from structio.core.policies.fifo import FIFOPolicy
|
||||
from structio.core.managers.io.simple import SimpleIOManager
|
||||
from structio.core.managers.signals.sigint import SigIntManager
|
||||
from structio.core.time.clock import DefaultClock
|
||||
|
@ -51,11 +52,11 @@ def run(
|
|||
restrict_ki_to_checkpoints: bool = False,
|
||||
tools: list | None = None,
|
||||
):
|
||||
result = None
|
||||
try:
|
||||
result = _run.run(
|
||||
func,
|
||||
FIFOKernel,
|
||||
DefaultKernel,
|
||||
FIFOPolicy(),
|
||||
SimpleIOManager(),
|
||||
[SigIntManager()],
|
||||
DefaultClock(),
|
||||
|
@ -65,8 +66,8 @@ def run(
|
|||
)
|
||||
finally:
|
||||
# Bunch of cleanup
|
||||
_signals._sig_handlers.clear()
|
||||
_signals._sig_data.clear()
|
||||
_signals._sig_handlers.clear() # noqa
|
||||
_signals._sig_data.clear() # noqa
|
||||
return result
|
||||
|
||||
|
||||
|
|
145
structio/abc.py
145
structio/abc.py
|
@ -39,6 +39,99 @@ class BaseClock(ABC):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
class SchedulingPolicy(ABC):
|
||||
"""
|
||||
A generic scheduling policy. This is what
|
||||
controls the way tasks are scheduled in the
|
||||
event loop
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_scheduled(self, task: Task) -> bool:
|
||||
"""
|
||||
Returns whether the given task is
|
||||
scheduled to run. This doesn't
|
||||
necessarily mean that the task will
|
||||
actually get executed, just that the
|
||||
policy knows about this task
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def has_next_task(self) -> bool:
|
||||
"""
|
||||
Returns whether the policy has a next
|
||||
candidate task to run
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def has_paused_task(self) -> bool:
|
||||
"""
|
||||
Returns whether the policy has any paused
|
||||
tasks waiting to be rescheduled
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def peek_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the first paused task in the queue,
|
||||
if there is any, but doesn't remove it
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Dequeues the first paused task in the queue,
|
||||
if it exists
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def schedule(self, task: Task, front: bool = False):
|
||||
"""
|
||||
Schedules a task for execution. If front is true,
|
||||
the task will be the next one to be scheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def pause(self, task: Task):
|
||||
"""
|
||||
Pauses the given task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def discard(self, task: Task):
|
||||
"""
|
||||
Discards the given task from the policy
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_next_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the next runnable task. None
|
||||
may returned if no runnable tasks are
|
||||
available
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_closest_deadline(self) -> Any:
|
||||
"""
|
||||
Returns the closest deadline to be satisfied
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class AsyncResource(ABC):
|
||||
"""
|
||||
A generic asynchronous resource which needs to
|
||||
|
@ -65,7 +158,7 @@ class StreamWriter(AsyncResource, ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data):
|
||||
async def write(self, _data):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
@ -78,7 +171,7 @@ class StreamReader(AsyncResource, ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _read(self, size: int = -1):
|
||||
async def _read(self, _size: int = -1):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
@ -255,7 +348,7 @@ class BaseDebugger(ABC):
|
|||
This method is called right after
|
||||
calling a task's run() method
|
||||
|
||||
:param task: The Task that has ran
|
||||
:param task: The Task that has run
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
|
@ -374,7 +467,7 @@ class BaseIOManager(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait_io(self, current_time):
|
||||
def wait_io(self):
|
||||
"""
|
||||
Waits for I/O and reschedules tasks
|
||||
when data is ready to be read/written
|
||||
|
@ -453,7 +546,7 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_readers(self) -> tuple["FdWrapper", Task]:
|
||||
def get_readers(self) -> tuple["structio.io.FdWrapper", Task]:
|
||||
"""
|
||||
Returns all I/O resources currently watched
|
||||
by the manager for read events
|
||||
|
@ -462,7 +555,7 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_writers(self) -> tuple["FdWrapper", Task]:
|
||||
def get_writers(self) -> tuple["structio.io.FdWrapper", Task]:
|
||||
"""
|
||||
Returns all I/O resources currently watched
|
||||
by the manager for write events
|
||||
|
@ -508,6 +601,7 @@ class BaseKernel(ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SchedulingPolicy,
|
||||
clock: BaseClock,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
|
@ -516,15 +610,30 @@ class BaseKernel(ABC):
|
|||
):
|
||||
self.clock = clock
|
||||
self.current_task: Task | None = None
|
||||
self.current_pool: "TaskPool" = None
|
||||
self.current_scope: "TaskScope" = None
|
||||
self.current_pool: type["structio.TaskPool"] | None = None
|
||||
self.current_scope: type["structio.TaskScope"] | None = None
|
||||
self.tools: list[BaseDebugger] = tools or []
|
||||
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
||||
self.io_manager = io_manager
|
||||
self.signal_managers = signal_managers
|
||||
self.entry_point: Task | None = None
|
||||
self.policy = policy
|
||||
# Pool for system tasks
|
||||
self.pool: "TaskPool" = None
|
||||
self.pool: type["structio.TaskPool"] | None = None
|
||||
|
||||
def get_system_pool(self) -> "structio.TaskPool":
|
||||
"""
|
||||
Returns the kernel's "system" pool, where tasks
|
||||
spawned via spawn_system_task() as well as the
|
||||
entry point are implicitly run into. This is meant
|
||||
to be used as an internal method for structio's
|
||||
scheduling policy implementations
|
||||
"""
|
||||
|
||||
if self.pool is None:
|
||||
raise StructIOException("broken state: system pool is None")
|
||||
self.pool: "structio.TaskPool"
|
||||
return self.pool
|
||||
|
||||
@abstractmethod
|
||||
def wait_readable(self, resource: AsyncResource):
|
||||
|
@ -595,16 +704,19 @@ class BaseKernel(ABC):
|
|||
):
|
||||
"""
|
||||
Spawns a system task. System tasks run in a special internal
|
||||
task pool and begin execution in a scope shielded by cancellations
|
||||
and with Ctrl+C protection enabled. Please note that if a system
|
||||
tasks raises an exception, all tasks are cancelled and the exception
|
||||
is propagated into the loop's entry point
|
||||
task pool and begin execution in a scope with Ctrl+C protection
|
||||
enabled. Please note that if a system tasks raises an exception,
|
||||
all tasks are cancelled and a StructIOException is propagated into the
|
||||
loop's entry point. System tasks are guaranteed to always run at least
|
||||
one task step regardless of the state of the entry point and are cancelled
|
||||
automatically when the entry point exits (unless a shielded TaskScope is
|
||||
used
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_closest_deadline(self):
|
||||
def get_closest_deadline(self) -> Any:
|
||||
"""
|
||||
Returns the closest deadline to be satisfied
|
||||
"""
|
||||
|
@ -735,6 +847,7 @@ class BaseKernel(ABC):
|
|||
|
||||
self.setup()
|
||||
self.event("on_start")
|
||||
self.pool: "structio.TaskPool"
|
||||
self.current_pool = self.pool
|
||||
self.entry_point = self.spawn(entry_point, *args)
|
||||
self.current_pool.scope.owner = self.entry_point
|
||||
|
@ -778,7 +891,9 @@ class BaseKernel(ABC):
|
|||
raise StructIOException("the event loop is running")
|
||||
|
||||
@abstractmethod
|
||||
def add_shutdown_task(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args) -> Any:
|
||||
def add_shutdown_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
) -> Any:
|
||||
"""
|
||||
Registers a task to be run right before the event loop shuts
|
||||
down. The task is spawned as a system task when (and if) the main
|
||||
|
|
|
@ -151,6 +151,7 @@ class TaskPool:
|
|||
raise exc_val.with_traceback(exc_tb)
|
||||
elif not self.done():
|
||||
await suspend()
|
||||
assert self.done()
|
||||
else:
|
||||
await checkpoint()
|
||||
except Cancelled as e:
|
||||
|
@ -160,6 +161,7 @@ class TaskPool:
|
|||
self.error = e
|
||||
self.scope.cancel()
|
||||
finally:
|
||||
self.scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
current_loop().close_pool(self)
|
||||
self._closed = True
|
||||
if self.error:
|
||||
|
@ -167,7 +169,10 @@ class TaskPool:
|
|||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task pool has finished executing
|
||||
Returns whether the task pool's internal
|
||||
task scope has finished executing. Note
|
||||
that this does not take the scope's owner
|
||||
into account!
|
||||
"""
|
||||
|
||||
return self.scope.done()
|
||||
|
@ -183,4 +188,7 @@ class TaskPool:
|
|||
executing until it is picked by the scheduler later on
|
||||
"""
|
||||
|
||||
if self._closed:
|
||||
raise StructIOException("task pool is closed")
|
||||
|
||||
return current_loop().spawn(func, *args)
|
||||
|
|
|
@ -7,12 +7,12 @@ from structio.abc import (
|
|||
BaseDebugger,
|
||||
BaseIOManager,
|
||||
SignalManager,
|
||||
SchedulingPolicy,
|
||||
)
|
||||
from structio.io import FdWrapper
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED, critical_section
|
||||
from structio.core.time.queue import TimeQueue
|
||||
from structio.exceptions import (
|
||||
StructIOException,
|
||||
Cancelled,
|
||||
|
@ -28,14 +28,15 @@ import signal
|
|||
import sniffio
|
||||
|
||||
|
||||
class FIFOKernel(BaseKernel):
|
||||
class DefaultKernel(BaseKernel):
|
||||
"""
|
||||
An asynchronous event loop implementation
|
||||
with a FIFO scheduling policy
|
||||
supporting generic scheduling policies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SchedulingPolicy,
|
||||
clock: BaseClock,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
|
@ -43,45 +44,34 @@ class FIFOKernel(BaseKernel):
|
|||
restrict_ki_to_checkpoints: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
clock, io_manager, signal_managers, tools, restrict_ki_to_checkpoints
|
||||
policy,
|
||||
clock,
|
||||
io_manager,
|
||||
signal_managers,
|
||||
tools,
|
||||
restrict_ki_to_checkpoints,
|
||||
)
|
||||
# Tasks that are ready to run
|
||||
self.run_queue: deque[Task] = deque()
|
||||
self.shutdown_tasks: deque[tuple[int, Callable[[Any, Any], Coroutine[Any, Any, Any]], list[Any]]] = deque()
|
||||
self.shutdown_tasks: deque[
|
||||
tuple[Any, Callable[[Any, Any], Coroutine[Any, Any, Any]], tuple[Any, ...]]
|
||||
] = deque()
|
||||
self._shutdown_task_ident = count(0)
|
||||
# Data to send back to tasks
|
||||
self.data: dict[Task, Any] = {}
|
||||
# Have we handled SIGINT?
|
||||
self._sigint_handled: bool = False
|
||||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
self.pool = TaskPool()
|
||||
self.current_scope = self.pool.scope
|
||||
self._shutting_down = False
|
||||
|
||||
def get_closest_deadline(self):
|
||||
if self.run_queue:
|
||||
# We absolutely cannot block while other
|
||||
# tasks have things to do!
|
||||
return self.clock.current_time()
|
||||
deadlines = []
|
||||
for scope in self.pool.scope.children:
|
||||
deadlines.append(scope.get_effective_deadline()[0])
|
||||
if not deadlines:
|
||||
deadlines.append(float("inf"))
|
||||
return min(
|
||||
[
|
||||
min(deadlines),
|
||||
self.paused.get_closest_deadline(),
|
||||
]
|
||||
)
|
||||
def get_closest_deadline(self) -> Any:
|
||||
return self.policy.get_closest_deadline()
|
||||
|
||||
def wait_readable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.current_task: Task
|
||||
self.io_manager.request_read(resource, self.current_task)
|
||||
|
||||
def wait_writable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.current_task: Task
|
||||
self.io_manager.request_write(resource, self.current_task)
|
||||
|
||||
def notify_closing(
|
||||
|
@ -122,7 +112,7 @@ class FIFOKernel(BaseKernel):
|
|||
def done(self):
|
||||
if self.entry_point.done() and not self._shutting_down:
|
||||
return True
|
||||
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
||||
if any([self.policy.has_next_task(), self.policy.has_paused_task(), self.io_manager.pending()]):
|
||||
return False
|
||||
if not self.pool.done():
|
||||
return False
|
||||
|
@ -152,7 +142,7 @@ class FIFOKernel(BaseKernel):
|
|||
task.coroutine.cr_frame.f_locals.setdefault(
|
||||
CTRLC_PROTECTION_ENABLED, ki_protected
|
||||
)
|
||||
self.run_queue.append(task)
|
||||
self.policy.schedule(task)
|
||||
self.event("on_task_spawn", task)
|
||||
return task
|
||||
|
||||
|
@ -166,7 +156,7 @@ class FIFOKernel(BaseKernel):
|
|||
case signal.SIGINT:
|
||||
self._sigint_handled = True
|
||||
# Poke the event loop with a stick ;)
|
||||
self.run_queue.append(self.entry_point)
|
||||
self.policy.schedule(self.entry_point)
|
||||
case _:
|
||||
pass
|
||||
|
||||
|
@ -176,11 +166,15 @@ class FIFOKernel(BaseKernel):
|
|||
primitives somewhere)
|
||||
"""
|
||||
|
||||
self.current_task = self.run_queue.popleft()
|
||||
if not self.policy.has_next_task():
|
||||
return
|
||||
self.current_task = self.policy.get_next_task()
|
||||
self.current_task: Task
|
||||
while self.current_task.done():
|
||||
if not self.run_queue:
|
||||
if not self.policy.has_next_task():
|
||||
return
|
||||
self.current_task = self.run_queue.popleft()
|
||||
self.current_task = self.policy.get_next_task()
|
||||
self.current_task: Task
|
||||
if self.current_task.done():
|
||||
return
|
||||
runner = partial(
|
||||
|
@ -189,7 +183,9 @@ class FIFOKernel(BaseKernel):
|
|||
if self.current_task.pending_cancellation:
|
||||
self.cancel_task(self.current_task)
|
||||
return
|
||||
elif self._sigint_handled and not critical_section(self.current_task.coroutine.cr_frame):
|
||||
elif self._sigint_handled and not critical_section(
|
||||
self.current_task.coroutine.cr_frame
|
||||
):
|
||||
self._sigint_handled = False
|
||||
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
|
||||
self.event("before_task_step", self.current_task)
|
||||
|
@ -197,11 +193,10 @@ class FIFOKernel(BaseKernel):
|
|||
self.current_task.paused_when = 0
|
||||
self.current_pool = self.current_task.pool
|
||||
self.current_scope = self.current_task.scope
|
||||
data = self.handle_errors(runner, self.current_task)
|
||||
data = self.handle(runner, self.current_task)
|
||||
if data is not None:
|
||||
method, args, kwargs = data
|
||||
self.current_task.state = TaskState.PAUSED
|
||||
self.current_task.paused_when = self.clock.current_time()
|
||||
self.suspend()
|
||||
if not callable(getattr(self, method, None)):
|
||||
# This if block is meant to be triggered by other async
|
||||
# libraries, which most likely have different method names and behaviors
|
||||
|
@ -221,21 +216,23 @@ class FIFOKernel(BaseKernel):
|
|||
def throw(self, task: Task, err: BaseException):
|
||||
if task.done():
|
||||
return
|
||||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||
self.handle(partial(task.coroutine.throw, err), task)
|
||||
|
||||
def reschedule(self, task: Task):
|
||||
if task.done():
|
||||
return
|
||||
self.run_queue.append(task)
|
||||
self.policy.schedule(task)
|
||||
|
||||
def check_cancelled(self, schedule: bool = True):
|
||||
if self._sigint_handled:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
elif self.current_task.pending_cancellation:
|
||||
self.current_task: Task
|
||||
self.cancel_task(self.current_task)
|
||||
elif schedule:
|
||||
self.current_task: Task
|
||||
# We reschedule the caller immediately!
|
||||
self.run_queue.appendleft(self.current_task)
|
||||
self.policy.schedule(self.current_task, front=True)
|
||||
|
||||
def schedule_point(self):
|
||||
self.reschedule_running()
|
||||
|
@ -251,7 +248,8 @@ class FIFOKernel(BaseKernel):
|
|||
if amount > 0:
|
||||
self.event("before_sleep", self.current_task, amount)
|
||||
self.current_task.next_deadline = self.clock.deadline(amount)
|
||||
self.paused.put(self.current_task, self.clock.deadline(amount))
|
||||
self.current_task: Task
|
||||
self.policy.pause(self.current_task)
|
||||
else:
|
||||
# If sleep is called with 0 as argument,
|
||||
# then it's just a checkpoint!
|
||||
|
@ -268,15 +266,17 @@ class FIFOKernel(BaseKernel):
|
|||
scope.timed_out = True
|
||||
error = TimedOut("timed out")
|
||||
error.scope = scope
|
||||
scope.cancel()
|
||||
self.throw(scope.owner, error)
|
||||
self.reschedule(scope.owner)
|
||||
if not self.policy.is_scheduled(scope.owner):
|
||||
self.reschedule(scope.owner)
|
||||
|
||||
def wakeup(self):
|
||||
while (
|
||||
self.paused
|
||||
and self.paused.peek().next_deadline <= self.clock.current_time()
|
||||
self.policy.has_paused_task()
|
||||
and self.policy.peek_paused_task().next_deadline <= self.clock.current_time()
|
||||
):
|
||||
task, _ = self.paused.get()
|
||||
task = self.policy.get_paused_task()
|
||||
task.next_deadline = 0
|
||||
self.event(
|
||||
"after_sleep", task, task.paused_when - self.clock.current_time()
|
||||
|
@ -286,12 +286,11 @@ class FIFOKernel(BaseKernel):
|
|||
def _tick(self):
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
if self.run_queue:
|
||||
self.step()
|
||||
self.wakeup()
|
||||
self.check_scopes()
|
||||
self.step()
|
||||
if self.io_manager.pending():
|
||||
self.io_manager.wait_io(self.clock.current_time())
|
||||
self.io_manager.wait_io()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
|
@ -303,15 +302,20 @@ class FIFOKernel(BaseKernel):
|
|||
self._tick()
|
||||
self._shutting_down = True
|
||||
# Ensure all system tasks have a chance to spin up
|
||||
while any(task.state == TaskState.INIT for task in self.pool.scope.tasks):
|
||||
while any(task.is_system_task and task.state == TaskState.INIT for task in self.pool.scope.tasks):
|
||||
self._tick()
|
||||
# Cancel the system pool and wait for cancellation
|
||||
# to be delivered
|
||||
self.pool.scope.cancel()
|
||||
while not self.done():
|
||||
self._tick()
|
||||
# Sanity checking
|
||||
assert self.pool.scope.attempted_cancel
|
||||
assert self.pool.scope.cancelled
|
||||
assert self.pool.done()
|
||||
# Reset some stuff
|
||||
self.pool.scope.attempted_cancel = False
|
||||
self.pool.scope.cancelled = False
|
||||
if self.entry_point.state == TaskState.FINISHED:
|
||||
while True:
|
||||
# Spawn all the shutdown tasks that are currently registered
|
||||
|
@ -336,9 +340,10 @@ class FIFOKernel(BaseKernel):
|
|||
Reschedules the currently running task
|
||||
"""
|
||||
|
||||
self.current_task: Task
|
||||
self.reschedule(self.current_task)
|
||||
|
||||
def handle_errors(self, func: Callable, task: Task):
|
||||
def handle(self, func: Callable, task: Task):
|
||||
"""
|
||||
Convenience method for handling various exceptions
|
||||
from tasks
|
||||
|
@ -391,7 +396,7 @@ class FIFOKernel(BaseKernel):
|
|||
"""
|
||||
|
||||
self.io_manager.release_task(task)
|
||||
self.paused.discard(task)
|
||||
self.policy.discard(task)
|
||||
|
||||
def on_success(self, task: Task):
|
||||
"""
|
||||
|
@ -442,8 +447,6 @@ class FIFOKernel(BaseKernel):
|
|||
self.release(task)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
if self.current_task is not self.current_scope.owner:
|
||||
self.current_scope.tasks.remove(self.current_task)
|
||||
self.current_task.scope = scope
|
||||
scope.deadline = self.clock.deadline(scope.timeout)
|
||||
scope.owner = self.current_task
|
||||
|
@ -459,7 +462,7 @@ class FIFOKernel(BaseKernel):
|
|||
def cancel_task(self, task: Task):
|
||||
if task.done():
|
||||
return
|
||||
if task.state in [TaskState.RUNNING]:
|
||||
if task.state == TaskState.RUNNING:
|
||||
# Can't cancel a task while it's
|
||||
# running, will raise ValueError
|
||||
# if we try, so we defer it for later
|
||||
|
@ -469,7 +472,7 @@ class FIFOKernel(BaseKernel):
|
|||
err.scope = task.scope
|
||||
self.throw(task, err)
|
||||
if task.state != TaskState.CANCELLED:
|
||||
# Task is stubborn. But so are we,
|
||||
# Task is stubborn, but so are we,
|
||||
# so we'll redeliver the cancellation
|
||||
# every time said task tries to call any
|
||||
# event loop primitive
|
||||
|
@ -516,7 +519,9 @@ class FIFOKernel(BaseKernel):
|
|||
for manager in self.signal_managers:
|
||||
manager.uninstall()
|
||||
|
||||
def add_shutdown_task(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args) -> int:
|
||||
def add_shutdown_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
) -> int:
|
||||
ident = next(self._shutdown_task_ident)
|
||||
self.shutdown_tasks.append((ident, func, args))
|
||||
return ident
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import warnings
|
||||
|
||||
import structio
|
||||
|
@ -68,16 +67,18 @@ class SimpleIOManager(BaseIOManager):
|
|||
if self._closed:
|
||||
raise structio.exceptions.ResourceClosed("the I/O manager is closed")
|
||||
|
||||
def wait_io(self, current_time):
|
||||
def wait_io(self):
|
||||
self._check_closed()
|
||||
kernel: BaseKernel = current_loop()
|
||||
current_time = kernel.clock.current_time()
|
||||
deadline = kernel.get_closest_deadline()
|
||||
if deadline == float("inf"):
|
||||
deadline = 0
|
||||
elif deadline > 0:
|
||||
deadline -= current_time
|
||||
deadline = max(0, deadline)
|
||||
current = kernel.clock.current_time()
|
||||
# FIXME: This delay seems to help throttle the calls
|
||||
# to this method. Should we be calling it this often?
|
||||
deadline = max(0.01, deadline)
|
||||
readers = self._collect_readers()
|
||||
writers = self._collect_writers()
|
||||
kernel.event("before_io", deadline)
|
||||
|
@ -87,22 +88,24 @@ class SimpleIOManager(BaseIOManager):
|
|||
writers + readers,
|
||||
deadline,
|
||||
)
|
||||
kernel.event("after_io", kernel.clock.current_time() - current)
|
||||
kernel.event("after_io", kernel.clock.current_time() - current_time)
|
||||
# On Windows, a successful connection is marked
|
||||
# as an exceptional event rather than a write
|
||||
# one
|
||||
writable.extend(exceptional)
|
||||
del exceptional
|
||||
for read_ready in readable:
|
||||
for resource, task in self.readers.copy().items():
|
||||
if resource.fileno() == read_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.readers.pop(resource)
|
||||
wrapper = FdWrapper(read_ready)
|
||||
task = self.readers[wrapper]
|
||||
kernel.reschedule(task)
|
||||
# We don't want to listen for read events on
|
||||
# this resource anymore, so we release it
|
||||
self.release(wrapper)
|
||||
for write_ready in writable:
|
||||
for resource, task in self.writers.copy().items():
|
||||
if resource.fileno() == write_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.writers.pop(resource)
|
||||
wrapper = FdWrapper(write_ready)
|
||||
task = self.writers[wrapper]
|
||||
kernel.reschedule(task)
|
||||
self.release(wrapper)
|
||||
|
||||
def request_read(self, rsc: FdWrapper, task: Task):
|
||||
self._check_closed()
|
||||
|
@ -143,9 +146,13 @@ class SimpleIOManager(BaseIOManager):
|
|||
def close(self):
|
||||
self._check_closed()
|
||||
for reader in self.readers:
|
||||
warnings.warn(f"I/O manager was closed with scheduled write event for {reader}")
|
||||
warnings.warn(
|
||||
f"I/O manager was closed with scheduled write event for {reader}"
|
||||
)
|
||||
for writer in self.writers:
|
||||
warnings.warn(f"I/O manager was closed with scheduled write event for {writer}")
|
||||
warnings.warn(
|
||||
f"I/O manager was closed with scheduled write event for {writer}"
|
||||
)
|
||||
self.readers = {}
|
||||
self.writers = {}
|
||||
self._closed = True
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
from structio.core.run import current_loop
|
||||
from structio.abc import SchedulingPolicy
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.core.time.queue import TimeQueue
|
||||
from collections import deque
|
||||
|
||||
|
||||
class FIFOPolicy(SchedulingPolicy):
|
||||
"""
|
||||
A First-in, First-out scheduling policy
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Tasks that are ready to run
|
||||
self.run_queue: deque[Task] = deque()
|
||||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
|
||||
def is_scheduled(self, task: Task) -> bool:
|
||||
# TODO: This should be fine, make sure of it
|
||||
return task.state == TaskState.READY
|
||||
|
||||
def has_next_task(self) -> bool:
|
||||
return bool(self.run_queue)
|
||||
|
||||
def has_paused_task(self) -> bool:
|
||||
return bool(self.paused)
|
||||
|
||||
def get_next_task(self) -> Task | None:
|
||||
if not self.run_queue:
|
||||
return None
|
||||
return self.run_queue.popleft()
|
||||
|
||||
def peek_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Returns the first paused task in the queue,
|
||||
if there is any, but doesn't remove it
|
||||
"""
|
||||
|
||||
return self.paused.peek()
|
||||
|
||||
def get_paused_task(self) -> Task | None:
|
||||
"""
|
||||
Dequeues the first paused task in the queue,
|
||||
if it exists
|
||||
"""
|
||||
|
||||
if not self.paused:
|
||||
return None
|
||||
return self.paused.get()[0]
|
||||
|
||||
def schedule(self, task: Task, front: bool = False):
|
||||
task.state = TaskState.READY
|
||||
if front:
|
||||
self.run_queue.append(task)
|
||||
else:
|
||||
self.run_queue.append(task)
|
||||
|
||||
def pause(self, task: Task):
|
||||
task.state = TaskState.PAUSED
|
||||
self.paused.put(task, task.next_deadline)
|
||||
|
||||
def discard(self, task: Task):
|
||||
self.paused.discard(task)
|
||||
|
||||
def get_closest_deadline(self):
|
||||
if self.run_queue:
|
||||
# We absolutely cannot block while other
|
||||
# tasks have things to do!
|
||||
return current_loop().clock.current_time()
|
||||
deadlines = []
|
||||
for scope in current_loop().get_system_pool().scope.children:
|
||||
deadlines.append(scope.get_effective_deadline()[0])
|
||||
if not deadlines:
|
||||
deadlines.append(float("inf"))
|
||||
return min(
|
||||
[
|
||||
min(deadlines),
|
||||
self.paused.get_closest_deadline(),
|
||||
]
|
||||
)
|
|
@ -1,6 +1,7 @@
|
|||
import inspect
|
||||
import structio
|
||||
import functools
|
||||
|
||||
# I *really* hate fork()
|
||||
from multiprocessing_utils import local
|
||||
from structio.abc import (
|
||||
|
@ -8,7 +9,7 @@ from structio.abc import (
|
|||
BaseDebugger,
|
||||
BaseClock,
|
||||
SignalManager,
|
||||
BaseIOManager,
|
||||
BaseIOManager, SchedulingPolicy,
|
||||
)
|
||||
from structio.exceptions import StructIOException
|
||||
from structio.core.task import Task
|
||||
|
@ -59,6 +60,7 @@ def new_event_loop(kernel: BaseKernel):
|
|||
def run(
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
kernel: type,
|
||||
policy: SchedulingPolicy,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
clock: BaseClock,
|
||||
|
@ -94,6 +96,7 @@ def run(
|
|||
waker.set_wakeup_fd()
|
||||
new_event_loop(
|
||||
kernel(
|
||||
policy=policy,
|
||||
clock=clock,
|
||||
restrict_ki_to_checkpoints=restrict_ki_to_checkpoints,
|
||||
io_manager=io_manager,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Coroutine, Any, Callable
|
||||
|
@ -9,6 +10,7 @@ _counter = count()
|
|||
class TaskState(Enum):
|
||||
INIT: int = auto()
|
||||
RUNNING: int = auto()
|
||||
READY: int = auto()
|
||||
PAUSED: int = auto()
|
||||
FINISHED: int = auto()
|
||||
CRASHED: int = auto()
|
||||
|
|
|
@ -45,10 +45,18 @@ class FdWrapper:
|
|||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
def __hash__(self):
|
||||
return self.fd.__hash__()
|
||||
|
||||
# Can be converted to an int
|
||||
def __int__(self):
|
||||
return self.fd
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, FdWrapper):
|
||||
return False
|
||||
return self.fileno() == other.fileno()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<fd={self.fd!r}>"
|
||||
|
||||
|
|
|
@ -372,7 +372,9 @@ class AsyncSocket(AsyncResource):
|
|||
while pos < size:
|
||||
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
|
||||
if n == 0:
|
||||
raise ResourceBroken("incomplete read detected")
|
||||
raise ResourceBroken(
|
||||
"incomplete read detected: is the remote end gone?"
|
||||
)
|
||||
pos += n
|
||||
return bytes(buf)
|
||||
|
||||
|
@ -383,19 +385,14 @@ class AsyncSocket(AsyncResource):
|
|||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
await checkpoint()
|
||||
with self.write_lock, self.read_lock:
|
||||
connected = False
|
||||
while not connected:
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
connected = True
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Module inspired by subprocess which allows for asynchronous
|
||||
multiprocessing
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
|
@ -52,7 +53,9 @@ class Process:
|
|||
self.stderr = None
|
||||
self.returncode = None
|
||||
self.pid = -1
|
||||
self.shutdown_handlers: list[tuple[int, bool, Callable[[Any, Any], Coroutine[Any, Any, Any]], args]] = []
|
||||
self.shutdown_handlers: list[
|
||||
tuple[int, bool, Callable[[Any, Any], Coroutine[Any, Any, Any]], args]
|
||||
] = []
|
||||
self._handler_id = count()
|
||||
self._taskid = None
|
||||
self._started = structio.Event()
|
||||
|
@ -67,11 +70,17 @@ class Process:
|
|||
)
|
||||
|
||||
async def _run_shutdown_handlers(self, before_wait: bool = False):
|
||||
for _, _, coro, args in filter(lambda h: h[1] is before_wait, self.shutdown_handlers):
|
||||
for _, _, coro, args in filter(
|
||||
lambda h: h[1] is before_wait, self.shutdown_handlers
|
||||
):
|
||||
await coro(*args)
|
||||
|
||||
def add_shutdown_handler(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args,
|
||||
before_wait: bool = False) -> int:
|
||||
def add_shutdown_handler(
|
||||
self,
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
*args,
|
||||
before_wait: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Registers a coroutine to be executed either right after, or right before wait() is called.
|
||||
Shutdown handlers are executed one at a time in the order in which they are registered. All
|
||||
|
@ -219,7 +228,9 @@ async def run(
|
|||
raise
|
||||
|
||||
if check and process.returncode:
|
||||
raise CalledProcessError(process.returncode, process.args, output=stdout, stderr=stderr)
|
||||
raise CalledProcessError(
|
||||
process.returncode, process.args, output=stdout, stderr=stderr
|
||||
)
|
||||
return CompletedProcess(process.args, process.returncode, stdout, stderr)
|
||||
|
||||
|
||||
|
@ -346,7 +357,9 @@ class PythonProcess:
|
|||
await self._sock.bind(("127.0.0.1", 0))
|
||||
await self._sock.listen(1)
|
||||
addr, port = self._sock.getsockname()
|
||||
self.process = Process([sys.executable, "-m", "structio.util.child_process", addr, str(port)])
|
||||
self.process = Process(
|
||||
[sys.executable, "-m", "structio.util.child_process", addr, str(port)]
|
||||
)
|
||||
# If we didn't close the socket before calling wait(), we'd deadlock waiting for the
|
||||
# process to exit while the process waits for us to send them a message
|
||||
self.process.add_shutdown_handler(self.close, before_wait=True)
|
||||
|
@ -364,8 +377,10 @@ class PythonProcess:
|
|||
except StructIOException as e:
|
||||
raise StructIOException("unable to get ACK from remote process") from e
|
||||
if payload["msg"] != "ACK":
|
||||
raise StructIOException(f"invalid message type {payload['msg']!r} received from process (expecting "
|
||||
f"'ACK'): {payload}")
|
||||
raise StructIOException(
|
||||
f"invalid message type {payload['msg']!r} received from process (expecting "
|
||||
f"'ACK'): {payload}"
|
||||
)
|
||||
|
||||
async def send_sos(self):
|
||||
"""
|
||||
|
@ -414,7 +429,9 @@ class PythonProcess:
|
|||
size, *_ = struct.unpack("Q", data)
|
||||
message = msgpack.unpackb(await self._remote.receive_exactly(size))
|
||||
if not message["ok"]:
|
||||
raise StructIOException(f"got error response from remote process: {message}")
|
||||
raise StructIOException(
|
||||
f"got error response from remote process: {message}"
|
||||
)
|
||||
return message
|
||||
|
||||
def start(self):
|
||||
|
@ -435,5 +452,3 @@ class PythonProcess:
|
|||
# is likely to be None until _do_setup runs to
|
||||
# completion
|
||||
return await self._started.wait()
|
||||
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ async def signal_watcher(sock: AsyncSocket):
|
|||
# memory problems because their code is receiving thousands of signals and
|
||||
# the event loop is not handling them fast enough (right?)
|
||||
await sock.receive(1)
|
||||
async for (sig, frame) in _sig_data:
|
||||
async for sig, frame in _sig_data:
|
||||
if _sig_handlers[sig]:
|
||||
try:
|
||||
await _sig_handlers[sig](sig, frame)
|
||||
|
|
|
@ -230,14 +230,12 @@ class MemoryReceiveChannel(ChannelReader):
|
|||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._read_lock = ThereCanBeOnlyOne("another task is reading from this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def receive(self):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._read_lock:
|
||||
return await self._buffer.get()
|
||||
return await self._buffer.get()
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
|
@ -260,14 +258,12 @@ class MemorySendChannel(ChannelWriter):
|
|||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._write_lock = ThereCanBeOnlyOne("another task is writing to this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def send(self, item):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._write_lock:
|
||||
return await self._buffer.put(item)
|
||||
return await self._buffer.put(item)
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Helper module to spawn asynchronous Python processes via
|
||||
structio
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
|
@ -68,7 +69,14 @@ async def dispatch(sock: structio.AsyncSocket, message: dict):
|
|||
sys.exit(0)
|
||||
case _:
|
||||
# IDK: I don't know (means the command is unknown)
|
||||
await send_message(sock, {"ok": False, "msg": "IDK", "error": f"unknown message type {message['msg']!r}"})
|
||||
await send_message(
|
||||
sock,
|
||||
{
|
||||
"ok": False,
|
||||
"msg": "IDK",
|
||||
"error": f"unknown message type {message['msg']!r}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def main(addr: tuple[str, int]):
|
||||
|
@ -79,5 +87,6 @@ async def main(addr: tuple[str, int]):
|
|||
while True:
|
||||
await dispatch(socket, await receive_message(socket))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
structio.run(main, (sys.argv[1], int(sys.argv[2])))
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
Utility module to look up objects to be called in a Python subprocess spawned
|
||||
by structio. Inspired by https://pikers.dev/goodboy/tractor/src/branch/mv_to_new_trio_py3.11/tractor/msg/ptr.py#L53
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from inspect import ismethod, isfunction, isbuiltin, getmodule, getmodulename
|
||||
|
@ -29,7 +30,7 @@ def get_real_module_name(obj) -> tuple[bool, str]:
|
|||
mod_suffix = getmodulename(mod_obj.__file__)
|
||||
|
||||
# join parent to child with a .
|
||||
module = '.'.join(filter(bool, [mod_obj.__package__, mod_suffix]))
|
||||
module = ".".join(filter(bool, [mod_obj.__package__, mod_suffix]))
|
||||
|
||||
if mod_obj.__package__ is None:
|
||||
in_package = False
|
||||
|
@ -57,7 +58,9 @@ class ObjectReference:
|
|||
|
||||
def __init__(self, obj):
|
||||
if ismethod(obj) or islambda(obj):
|
||||
raise ValueError("bound methods and lambdas cannot be passed to a remote process")
|
||||
raise ValueError(
|
||||
"bound methods and lambdas cannot be passed to a remote process"
|
||||
)
|
||||
self.obj = obj
|
||||
self.in_package = True
|
||||
self._make_ref()
|
||||
|
|
|
@ -15,6 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
|
|
@ -6,13 +6,15 @@ _print = print
|
|||
|
||||
|
||||
def print(*args, **kwargs):
|
||||
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
|
||||
_print(*args, **kwargs)
|
||||
_print(f"[{time.strftime('%H:%M:%S')}]", *args, **kwargs)
|
||||
|
||||
|
||||
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False):
|
||||
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False, secure: bool = False):
|
||||
print(f"Attempting a connection to {host}:{port} {'in keep-alive mode' if keepalive else ''}")
|
||||
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
|
||||
if secure:
|
||||
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
|
||||
else:
|
||||
socket = await structio.socket.connect_tcp_socket(host, port)
|
||||
buffer = b""
|
||||
print("Connected")
|
||||
# Ensures the code below doesn't run for more than 5 seconds
|
||||
|
@ -41,7 +43,7 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
|
|||
print("Received empty stream, closing connection")
|
||||
break
|
||||
if buffer:
|
||||
data = buffer.decode().split("\r\n")
|
||||
data = buffer.decode(errors="ignore").split("\r\n")
|
||||
print(
|
||||
f"HTTP Response below {'(might be incomplete)' if scope.timed_out else ''}:"
|
||||
)
|
||||
|
@ -63,6 +65,6 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
|
|||
_print("Done!")
|
||||
|
||||
|
||||
structio.run(test, "google.com", 443, 256)
|
||||
structio.run(test, "google.com", 80, 256)
|
||||
# With keep-alive on, our timeout will kick in
|
||||
structio.run(test, "google.com", 443, 256, True)
|
||||
structio.run(test, "google.com", 80, 256, True)
|
|
@ -63,6 +63,6 @@ async def main_child(x: float):
|
|||
|
||||
|
||||
# Should take about 5 seconds
|
||||
structio.run(main_simple, 5, 2, 2)
|
||||
structio.run(main_nested, 5, 2, 2)
|
||||
#structio.run(main_simple, 5, 2, 2)
|
||||
#structio.run(main_nested, 5, 2, 2)
|
||||
structio.run(main_child, 2)
|
||||
|
|
Loading…
Reference in New Issue