Minor refactoring and cleanup. Added LIFO policy

This commit is contained in:
Mattia Giambirtone 2024-03-26 11:54:28 +01:00
parent 9e1301322a
commit 1ba76ecdee
15 changed files with 132 additions and 84 deletions

View File

@ -1,6 +1,6 @@
from structio.core import run as _run
from typing import Coroutine, Any, Callable
from structio.core.kernel import DefaultKernel
from structio.core.run import run
from structio.core import kernel
from structio.core.policies.fifo import FIFOPolicy
from structio.core.managers.io.simple import SimpleIOManager
from structio.core.managers.signals.sigint import SigIntManager
@ -46,34 +46,6 @@ from structio import signals as _signals
from structio import util
def run(
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
*args,
restrict_ki_to_checkpoints: bool = False,
tools: list | None = None,
):
try:
result = _run.run(
func,
DefaultKernel,
FIFOPolicy(),
SimpleIOManager(),
[SigIntManager()],
DefaultClock(),
tools,
restrict_ki_to_checkpoints,
*args,
)
finally:
# Bunch of cleanup
_signals._sig_handlers.clear() # noqa
_signals._sig_data.clear() # noqa
return result
run.__doc__ = _run.run.__doc__
def create_pool() -> TaskPool:
"""
Creates a new task pool

View File

@ -104,10 +104,9 @@ class SchedulingPolicy(ABC):
raise NotImplementedError
@abstractmethod
def schedule(self, task: Task, front: bool = False):
def schedule(self, task: Task):
"""
Schedules a task for execution. If front is true,
the task will be the next one to be scheduled
Schedules a task for execution
"""
raise NotImplementedError
@ -623,6 +622,28 @@ class BaseKernel(ABC):
tools: list[BaseDebugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
):
if not issubclass(clock.__class__, BaseClock):
raise TypeError(
f"clock must be a subclass of {BaseClock.__module__}.{BaseClock.__qualname__}, not {type(clock)}"
)
if not issubclass(policy.__class__, SchedulingPolicy):
raise TypeError(
f"policy must be a subclass of {SchedulingPolicy.__module__}.{SchedulingPolicy.__qualname__}, not {type(policy)}"
)
if not issubclass(io_manager.__class__, BaseIOManager):
raise TypeError(
f"io_manager must be a subclass of {BaseIOManager.__module__}.{BaseIOManager.__qualname__}, not {type(io_manager)}"
)
for tool in tools or []:
if not issubclass(tool.__class__, BaseDebugger):
raise TypeError(
f"tools must be a subclass of {BaseDebugger.__module__}.{BaseDebugger.__qualname__}, not {type(tool)}"
)
for mgr in signal_managers or []:
if not issubclass(mgr.__class__, SignalManager):
raise TypeError(
f"signal manager must be a subclass of {SignalManager.__module__}.{SignalManager.__qualname__}, not {type(mgr)}"
)
self.clock = clock
self.current_task: Task | None = None
self.current_pool: "structio.TaskPool" = None # noqa

View File

@ -0,0 +1,8 @@
from . import managers, policies, time
__all__ = [
"managers",
"policies",
"time"
]

View File

@ -71,7 +71,7 @@ class TaskScope:
queue = TimeQueue()
if self.shielded:
return float("inf"), self
times = queue.put(self, self.deadline)
queue.put(self, self.deadline)
for child in self.children:
if child.shielded:
return float("inf"), self
@ -91,8 +91,6 @@ class TaskScope:
if exc_val and isinstance(exc_val, structio.TimedOut):
if exc_val.scope is self:
return self.silent
return True
return False
# Just a recursive helper
def _get_children(self, lst=None):

View File

@ -12,7 +12,7 @@ from structio.abc import (
from structio.io import FdWrapper
from structio.core.context import TaskPool, TaskScope
from structio.core.task import Task, TaskState
from structio.util.ki import CTRLC_PROTECTION_ENABLED, critical_section, currently_protected
from structio.util.ki import CTRLC_PROTECTION_ENABLED
from structio.exceptions import (
StructIOException,
Cancelled,
@ -95,10 +95,11 @@ class DefaultKernel(BaseKernel):
self.reschedule_running()
def event(self, evt_name: str, *args):
if not hasattr(BaseDebugger, evt_name):
if not callable(getattr(BaseDebugger, evt_name, None)):
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
return
for tool in self.tools:
# Not all tools must implement all debugging events!
if f := getattr(tool, evt_name, None):
try:
f(*args)
@ -244,7 +245,7 @@ class DefaultKernel(BaseKernel):
elif schedule:
self.current_task: Task
# We reschedule the caller immediately!
self.policy.schedule(self.current_task, front=True)
self.policy.schedule(self.current_task)
def schedule_point(self):
self.reschedule_running()
@ -318,8 +319,9 @@ class DefaultKernel(BaseKernel):
awakened and thrown into
"""
self._sigint_handled = False
self.throw(task or self._pick_ki_task(), KeyboardInterrupt())
self._sigint_handled = False
def _tick(self):
"""

View File

@ -0,0 +1,7 @@
from . import io, signals
__all__ = [
"io",
"signals"
]

View File

@ -0,0 +1,5 @@
from . import simple
__all__ = [
"simple"
]

View File

@ -47,10 +47,7 @@ class SimpleIOManager(BaseIOManager):
so we can select() on them later
"""
result = []
for reader in self.readers:
result.append(reader.fileno())
return result
return [reader.fileno() for reader in self.readers]
def _collect_writers(self) -> list[int]:
"""
@ -58,10 +55,7 @@ class SimpleIOManager(BaseIOManager):
so we can select() on them later
"""
result = []
for writer in self.writers:
result.append(writer.fileno())
return result
return [writer.fileno() for writer in self.writers]
def _check_closed(self):
if self._closed:
@ -72,13 +66,13 @@ class SimpleIOManager(BaseIOManager):
kernel: BaseKernel = current_loop()
current_time = kernel.clock.current_time()
deadline = kernel.get_closest_deadline()
if deadline == float("inf"):
deadline = 0
elif deadline > 0:
deadline -= current_time
# FIXME: This delay seems to help throttle the calls
# to this method. Should we be calling it this often?
deadline = max(0.01, deadline)
if deadline == float("inf"):
deadline = 0.01
elif deadline > 0:
deadline -= current_time
deadline = max(0, deadline)
readers = self._collect_readers()
writers = self._collect_writers()
kernel.event("before_io", deadline)
@ -118,8 +112,7 @@ class SimpleIOManager(BaseIOManager):
self.writers[rsc] = task
def release(self, resource: FdWrapper):
if self._closed:
return
self._check_closed()
self.readers.pop(resource, None)
self.writers.pop(resource, None)

View File

@ -0,0 +1,5 @@
from . import fifo, lifo
__all__ = ["lifo",
"fifo"]

View File

@ -16,8 +16,8 @@ class FIFOPolicy(SchedulingPolicy):
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue()
# noinspection PyMethodMayBeStatic
def is_scheduled(self, task: Task) -> bool:
# TODO: This should be fine, make sure of it
return task.state == TaskState.READY
def has_next_task(self) -> bool:
@ -46,14 +46,11 @@ class FIFOPolicy(SchedulingPolicy):
return None
return self.paused.get()[0]
def schedule(self, task: Task, front: bool = False):
def schedule(self, task: Task):
if self.is_scheduled(task):
return
task.state = TaskState.READY
if front:
self.run_queue.append(task)
else:
self.run_queue.append(task)
self.run_queue.append(task)
def pause(self, task: Task):
task.state = TaskState.PAUSED
@ -63,7 +60,7 @@ class FIFOPolicy(SchedulingPolicy):
self.paused.discard(task)
def get_closest_deadline(self):
if self.run_queue:
if self.has_next_task():
# We absolutely cannot block while other
# tasks have things to do!
return current_loop().clock.current_time()
@ -77,4 +74,4 @@ class FIFOPolicy(SchedulingPolicy):
min(deadlines),
self.paused.get_closest_deadline(),
]
)
)

View File

@ -0,0 +1,24 @@
from structio.core.task import Task
from structio.core.policies.fifo import FIFOPolicy
class LIFOPolicy(FIFOPolicy):
"""
A last-in, first out scheduling policy
"""
def __init__(self):
super().__init__()
# We redefine it as a list because we
# have a different queueing policy
self.run_queue: list[Task] = []
def get_next_task(self) -> Task | None:
if not self.run_queue:
return None
return self.run_queue.pop()
def peek_next_task(self) -> Task | None:
if not self.has_next_task():
return
return self.run_queue[-1]

View File

@ -60,14 +60,14 @@ def new_event_loop(kernel: BaseKernel):
def run(
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
kernel: type,
policy: SchedulingPolicy,
io_manager: BaseIOManager,
signal_managers: list[SignalManager],
clock: BaseClock,
*args,
kernel: type | None = None,
policy: SchedulingPolicy| None = None,
io_manager: BaseIOManager| None = None,
signal_managers: list[SignalManager] | None = None,
clock: BaseClock | None = None,
tools: list[BaseDebugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
*args,
):
"""
Starts the event loop from a synchronous entry point. All
@ -76,10 +76,22 @@ def run(
using functools.partial()
"""
if kernel is None:
kernel = structio.kernel.DefaultKernel
if policy is None:
policy = structio.core.policies.fifo.FIFOPolicy()
if io_manager is None:
io_manager = structio.core.managers.io.simple.SimpleIOManager()
if clock is None:
clock = structio.core.time.clock.DefaultClock()
if not issubclass(kernel, BaseKernel):
raise TypeError(
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}!"
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}, not {type(kernel)}"
)
signal_managers = signal_managers or []
sigint_manager = structio.core.managers.signals.sigint.SigIntManager()
if sigint_manager not in signal_managers:
signal_managers.append(sigint_manager)
check = func
if isinstance(func, functools.partial):
check = func.func
@ -92,8 +104,8 @@ def run(
raise StructIOException(
"structio.run() requires an async function as its first argument!"
)
# Used to wake up the signal watcher when signals arrive
waker = structio.util.wakeup_fd.WakeupFd()
watcher = structio.signals.signal_watcher
waker.set_wakeup_fd()
new_event_loop(
kernel(
@ -105,5 +117,15 @@ def run(
tools=tools,
)
)
current_loop().spawn_system_task(watcher, waker.reader)
return current_loop().start(func, *args)
current_loop().spawn_system_task(structio.signals.signal_watcher, waker.reader)
try:
return current_loop().start(func, *args)
finally:
# Bunch of cleanup to ensure signals across different runs
# do not mix up
# noinspection PyProtectedMember
structio.signals._sig_handlers.clear()
# noinspection PyProtectedMember
structio.signals._sig_data.clear()

View File

@ -0,0 +1,3 @@
from . import queue, clock
__all__ = ["queue", "clock"]

View File

@ -1,9 +0,0 @@
from structio.abc import BaseDebugger
class SimpleDebugger(BaseDebugger):
def on_start(self):
print(">> Started")
def on_exit(self):
print(f"<< Stopped")

View File

@ -266,7 +266,7 @@ async def connect_tcp_socket(
# a timeout or something, but it may be something worth
# investigating
await sock.close()
except BaseException as e:
except Exception as e:
# Again, we shouldn't be ignoring
# errors willy-nilly like that, but
# hey beta software am I right?
@ -389,7 +389,7 @@ class AsyncSocket(AsyncResource):
self.socket.connect(address)
except WantWrite:
await wait_writable(self._fd)
if self.do_handshake_on_connect:
if self.do_handshake_on_connect and hasattr(self.socket, "do_handshake"):
await self.do_handshake()
async def close(self):