Minor refactoring and cleanup. Added LIFO policy
This commit is contained in:
parent
9e1301322a
commit
1ba76ecdee
|
@ -1,6 +1,6 @@
|
|||
from structio.core import run as _run
|
||||
from typing import Coroutine, Any, Callable
|
||||
from structio.core.kernel import DefaultKernel
|
||||
from structio.core.run import run
|
||||
from structio.core import kernel
|
||||
from structio.core.policies.fifo import FIFOPolicy
|
||||
from structio.core.managers.io.simple import SimpleIOManager
|
||||
from structio.core.managers.signals.sigint import SigIntManager
|
||||
|
@ -46,34 +46,6 @@ from structio import signals as _signals
|
|||
from structio import util
|
||||
|
||||
|
||||
def run(
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
*args,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
tools: list | None = None,
|
||||
):
|
||||
try:
|
||||
result = _run.run(
|
||||
func,
|
||||
DefaultKernel,
|
||||
FIFOPolicy(),
|
||||
SimpleIOManager(),
|
||||
[SigIntManager()],
|
||||
DefaultClock(),
|
||||
tools,
|
||||
restrict_ki_to_checkpoints,
|
||||
*args,
|
||||
)
|
||||
finally:
|
||||
# Bunch of cleanup
|
||||
_signals._sig_handlers.clear() # noqa
|
||||
_signals._sig_data.clear() # noqa
|
||||
return result
|
||||
|
||||
|
||||
run.__doc__ = _run.run.__doc__
|
||||
|
||||
|
||||
def create_pool() -> TaskPool:
|
||||
"""
|
||||
Creates a new task pool
|
||||
|
|
|
@ -104,10 +104,9 @@ class SchedulingPolicy(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def schedule(self, task: Task, front: bool = False):
|
||||
def schedule(self, task: Task):
|
||||
"""
|
||||
Schedules a task for execution. If front is true,
|
||||
the task will be the next one to be scheduled
|
||||
Schedules a task for execution
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
@ -623,6 +622,28 @@ class BaseKernel(ABC):
|
|||
tools: list[BaseDebugger] | None = None,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
):
|
||||
if not issubclass(clock.__class__, BaseClock):
|
||||
raise TypeError(
|
||||
f"clock must be a subclass of {BaseClock.__module__}.{BaseClock.__qualname__}, not {type(clock)}"
|
||||
)
|
||||
if not issubclass(policy.__class__, SchedulingPolicy):
|
||||
raise TypeError(
|
||||
f"policy must be a subclass of {SchedulingPolicy.__module__}.{SchedulingPolicy.__qualname__}, not {type(policy)}"
|
||||
)
|
||||
if not issubclass(io_manager.__class__, BaseIOManager):
|
||||
raise TypeError(
|
||||
f"io_manager must be a subclass of {BaseIOManager.__module__}.{BaseIOManager.__qualname__}, not {type(io_manager)}"
|
||||
)
|
||||
for tool in tools or []:
|
||||
if not issubclass(tool.__class__, BaseDebugger):
|
||||
raise TypeError(
|
||||
f"tools must be a subclass of {BaseDebugger.__module__}.{BaseDebugger.__qualname__}, not {type(tool)}"
|
||||
)
|
||||
for mgr in signal_managers or []:
|
||||
if not issubclass(mgr.__class__, SignalManager):
|
||||
raise TypeError(
|
||||
f"signal manager must be a subclass of {SignalManager.__module__}.{SignalManager.__qualname__}, not {type(mgr)}"
|
||||
)
|
||||
self.clock = clock
|
||||
self.current_task: Task | None = None
|
||||
self.current_pool: "structio.TaskPool" = None # noqa
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
from . import managers, policies, time
|
||||
|
||||
|
||||
__all__ = [
|
||||
"managers",
|
||||
"policies",
|
||||
"time"
|
||||
]
|
|
@ -71,7 +71,7 @@ class TaskScope:
|
|||
queue = TimeQueue()
|
||||
if self.shielded:
|
||||
return float("inf"), self
|
||||
times = queue.put(self, self.deadline)
|
||||
queue.put(self, self.deadline)
|
||||
for child in self.children:
|
||||
if child.shielded:
|
||||
return float("inf"), self
|
||||
|
@ -91,8 +91,6 @@ class TaskScope:
|
|||
if exc_val and isinstance(exc_val, structio.TimedOut):
|
||||
if exc_val.scope is self:
|
||||
return self.silent
|
||||
return True
|
||||
return False
|
||||
|
||||
# Just a recursive helper
|
||||
def _get_children(self, lst=None):
|
||||
|
|
|
@ -12,7 +12,7 @@ from structio.abc import (
|
|||
from structio.io import FdWrapper
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED, critical_section, currently_protected
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||
from structio.exceptions import (
|
||||
StructIOException,
|
||||
Cancelled,
|
||||
|
@ -95,10 +95,11 @@ class DefaultKernel(BaseKernel):
|
|||
self.reschedule_running()
|
||||
|
||||
def event(self, evt_name: str, *args):
|
||||
if not hasattr(BaseDebugger, evt_name):
|
||||
if not callable(getattr(BaseDebugger, evt_name, None)):
|
||||
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
|
||||
return
|
||||
for tool in self.tools:
|
||||
# Not all tools must implement all debugging events!
|
||||
if f := getattr(tool, evt_name, None):
|
||||
try:
|
||||
f(*args)
|
||||
|
@ -244,7 +245,7 @@ class DefaultKernel(BaseKernel):
|
|||
elif schedule:
|
||||
self.current_task: Task
|
||||
# We reschedule the caller immediately!
|
||||
self.policy.schedule(self.current_task, front=True)
|
||||
self.policy.schedule(self.current_task)
|
||||
|
||||
def schedule_point(self):
|
||||
self.reschedule_running()
|
||||
|
@ -318,8 +319,9 @@ class DefaultKernel(BaseKernel):
|
|||
awakened and thrown into
|
||||
"""
|
||||
|
||||
self._sigint_handled = False
|
||||
self.throw(task or self._pick_ki_task(), KeyboardInterrupt())
|
||||
self._sigint_handled = False
|
||||
|
||||
|
||||
def _tick(self):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from . import io, signals
|
||||
|
||||
|
||||
__all__ = [
|
||||
"io",
|
||||
"signals"
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
from . import simple
|
||||
|
||||
__all__ = [
|
||||
"simple"
|
||||
]
|
|
@ -47,10 +47,7 @@ class SimpleIOManager(BaseIOManager):
|
|||
so we can select() on them later
|
||||
"""
|
||||
|
||||
result = []
|
||||
for reader in self.readers:
|
||||
result.append(reader.fileno())
|
||||
return result
|
||||
return [reader.fileno() for reader in self.readers]
|
||||
|
||||
def _collect_writers(self) -> list[int]:
|
||||
"""
|
||||
|
@ -58,10 +55,7 @@ class SimpleIOManager(BaseIOManager):
|
|||
so we can select() on them later
|
||||
"""
|
||||
|
||||
result = []
|
||||
for writer in self.writers:
|
||||
result.append(writer.fileno())
|
||||
return result
|
||||
return [writer.fileno() for writer in self.writers]
|
||||
|
||||
def _check_closed(self):
|
||||
if self._closed:
|
||||
|
@ -72,13 +66,13 @@ class SimpleIOManager(BaseIOManager):
|
|||
kernel: BaseKernel = current_loop()
|
||||
current_time = kernel.clock.current_time()
|
||||
deadline = kernel.get_closest_deadline()
|
||||
if deadline == float("inf"):
|
||||
deadline = 0
|
||||
elif deadline > 0:
|
||||
deadline -= current_time
|
||||
# FIXME: This delay seems to help throttle the calls
|
||||
# to this method. Should we be calling it this often?
|
||||
deadline = max(0.01, deadline)
|
||||
if deadline == float("inf"):
|
||||
deadline = 0.01
|
||||
elif deadline > 0:
|
||||
deadline -= current_time
|
||||
deadline = max(0, deadline)
|
||||
readers = self._collect_readers()
|
||||
writers = self._collect_writers()
|
||||
kernel.event("before_io", deadline)
|
||||
|
@ -118,8 +112,7 @@ class SimpleIOManager(BaseIOManager):
|
|||
self.writers[rsc] = task
|
||||
|
||||
def release(self, resource: FdWrapper):
|
||||
if self._closed:
|
||||
return
|
||||
self._check_closed()
|
||||
self.readers.pop(resource, None)
|
||||
self.writers.pop(resource, None)
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
from . import fifo, lifo
|
||||
|
||||
|
||||
__all__ = ["lifo",
|
||||
"fifo"]
|
|
@ -16,8 +16,8 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def is_scheduled(self, task: Task) -> bool:
|
||||
# TODO: This should be fine, make sure of it
|
||||
return task.state == TaskState.READY
|
||||
|
||||
def has_next_task(self) -> bool:
|
||||
|
@ -46,14 +46,11 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
return None
|
||||
return self.paused.get()[0]
|
||||
|
||||
def schedule(self, task: Task, front: bool = False):
|
||||
def schedule(self, task: Task):
|
||||
if self.is_scheduled(task):
|
||||
return
|
||||
task.state = TaskState.READY
|
||||
if front:
|
||||
self.run_queue.append(task)
|
||||
else:
|
||||
self.run_queue.append(task)
|
||||
self.run_queue.append(task)
|
||||
|
||||
def pause(self, task: Task):
|
||||
task.state = TaskState.PAUSED
|
||||
|
@ -63,7 +60,7 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
self.paused.discard(task)
|
||||
|
||||
def get_closest_deadline(self):
|
||||
if self.run_queue:
|
||||
if self.has_next_task():
|
||||
# We absolutely cannot block while other
|
||||
# tasks have things to do!
|
||||
return current_loop().clock.current_time()
|
||||
|
@ -77,4 +74,4 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
min(deadlines),
|
||||
self.paused.get_closest_deadline(),
|
||||
]
|
||||
)
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
from structio.core.task import Task
|
||||
from structio.core.policies.fifo import FIFOPolicy
|
||||
|
||||
|
||||
class LIFOPolicy(FIFOPolicy):
|
||||
"""
|
||||
A last-in, first out scheduling policy
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# We redefine it as a list because we
|
||||
# have a different queueing policy
|
||||
self.run_queue: list[Task] = []
|
||||
|
||||
def get_next_task(self) -> Task | None:
|
||||
if not self.run_queue:
|
||||
return None
|
||||
return self.run_queue.pop()
|
||||
|
||||
def peek_next_task(self) -> Task | None:
|
||||
if not self.has_next_task():
|
||||
return
|
||||
return self.run_queue[-1]
|
|
@ -60,14 +60,14 @@ def new_event_loop(kernel: BaseKernel):
|
|||
|
||||
def run(
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
kernel: type,
|
||||
policy: SchedulingPolicy,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
clock: BaseClock,
|
||||
*args,
|
||||
kernel: type | None = None,
|
||||
policy: SchedulingPolicy| None = None,
|
||||
io_manager: BaseIOManager| None = None,
|
||||
signal_managers: list[SignalManager] | None = None,
|
||||
clock: BaseClock | None = None,
|
||||
tools: list[BaseDebugger] | None = None,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
Starts the event loop from a synchronous entry point. All
|
||||
|
@ -76,10 +76,22 @@ def run(
|
|||
using functools.partial()
|
||||
"""
|
||||
|
||||
if kernel is None:
|
||||
kernel = structio.kernel.DefaultKernel
|
||||
if policy is None:
|
||||
policy = structio.core.policies.fifo.FIFOPolicy()
|
||||
if io_manager is None:
|
||||
io_manager = structio.core.managers.io.simple.SimpleIOManager()
|
||||
if clock is None:
|
||||
clock = structio.core.time.clock.DefaultClock()
|
||||
if not issubclass(kernel, BaseKernel):
|
||||
raise TypeError(
|
||||
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}!"
|
||||
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}, not {type(kernel)}"
|
||||
)
|
||||
signal_managers = signal_managers or []
|
||||
sigint_manager = structio.core.managers.signals.sigint.SigIntManager()
|
||||
if sigint_manager not in signal_managers:
|
||||
signal_managers.append(sigint_manager)
|
||||
check = func
|
||||
if isinstance(func, functools.partial):
|
||||
check = func.func
|
||||
|
@ -92,8 +104,8 @@ def run(
|
|||
raise StructIOException(
|
||||
"structio.run() requires an async function as its first argument!"
|
||||
)
|
||||
# Used to wake up the signal watcher when signals arrive
|
||||
waker = structio.util.wakeup_fd.WakeupFd()
|
||||
watcher = structio.signals.signal_watcher
|
||||
waker.set_wakeup_fd()
|
||||
new_event_loop(
|
||||
kernel(
|
||||
|
@ -105,5 +117,15 @@ def run(
|
|||
tools=tools,
|
||||
)
|
||||
)
|
||||
current_loop().spawn_system_task(watcher, waker.reader)
|
||||
return current_loop().start(func, *args)
|
||||
current_loop().spawn_system_task(structio.signals.signal_watcher, waker.reader)
|
||||
try:
|
||||
return current_loop().start(func, *args)
|
||||
finally:
|
||||
# Bunch of cleanup to ensure signals across different runs
|
||||
# do not mix up
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
structio.signals._sig_handlers.clear()
|
||||
# noinspection PyProtectedMember
|
||||
structio.signals._sig_data.clear()
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from . import queue, clock
|
||||
|
||||
__all__ = ["queue", "clock"]
|
|
@ -1,9 +0,0 @@
|
|||
from structio.abc import BaseDebugger
|
||||
|
||||
|
||||
class SimpleDebugger(BaseDebugger):
|
||||
def on_start(self):
|
||||
print(">> Started")
|
||||
|
||||
def on_exit(self):
|
||||
print(f"<< Stopped")
|
|
@ -266,7 +266,7 @@ async def connect_tcp_socket(
|
|||
# a timeout or something, but it may be something worth
|
||||
# investigating
|
||||
await sock.close()
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
# Again, we shouldn't be ignoring
|
||||
# errors willy-nilly like that, but
|
||||
# hey beta software am I right?
|
||||
|
@ -389,7 +389,7 @@ class AsyncSocket(AsyncResource):
|
|||
self.socket.connect(address)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
if self.do_handshake_on_connect:
|
||||
if self.do_handshake_on_connect and hasattr(self.socket, "do_handshake"):
|
||||
await self.do_handshake()
|
||||
|
||||
async def close(self):
|
||||
|
|
Loading…
Reference in New Issue