Fixes and refactoring

This commit is contained in:
Mattia Giambirtone 2024-03-29 15:16:05 +01:00
parent e014460138
commit 760e25bff7
71 changed files with 90 additions and 138 deletions

View File

@ -2,7 +2,7 @@ import structio
import random
async def waiter(ch: structio.ChannelReader):
async def waiter(ch: structio.ReadableChannel):
print("[waiter] Waiter is alive!")
while True:
print("[waiter] Awaiting events")

View File

@ -2,7 +2,7 @@ import structio
from typing import Any
async def reader(ch: structio.ChannelReader):
async def reader(ch: structio.ReadableChannel):
print("[reader] Reader is alive!")
async with ch:
while True:
@ -16,7 +16,7 @@ async def reader(ch: structio.ChannelReader):
print("[reader] Done!")
async def writer(ch: structio.ChannelWriter, objects: list[Any]):
async def writer(ch: structio.WritableChannel, objects: list[Any]):
print("[writer] Writer is alive!")
async with ch:
for obj in objects:

View File

@ -2,6 +2,9 @@ import structio
import subprocess
import shlex
from structio import Task
# In the interest of compatibility, structio.parallel
# tries to be compatible with the subprocess module. You
# can even pass the constants such as DEVNULL, PIPE, etc.

View File

@ -5,7 +5,7 @@ from structio.core.policies.fifo import FIFOPolicy
from structio.core.managers.io.simple import SimpleIOManager
from structio.core.managers.signals.sigint import SigIntManager
from structio.core.time.clock import DefaultClock
from structio.core.syscalls import sleep, suspend as _suspend
from structio.core.syscalls import sleep, suspend
from structio.core.context import TaskPool, TaskScope
from structio.exceptions import (
Cancelled,
@ -25,7 +25,7 @@ from structio.sync import (
Lock,
RLock,
)
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
from structio.abc import Channel, Stream, ReadableChannel, WritableChannel
from structio.io import socket
from structio.io.socket import AsyncSocket
from structio.io.files import (
@ -83,27 +83,6 @@ def clock():
return _run.current_loop().clock.current_time()
async def _join(self: Task):
if self.done():
return self.result
await _suspend()
assert self.done()
if self.state == TaskState.CRASHED:
raise self.exc
return self.result
def _cancel(self: Task):
_run.current_loop().cancel_task(self)
task._joiner = _join
_cancel.__name__ = Task.cancel.__name__
_cancel.__doc__ = Task.cancel.__doc__
Task.cancel = _cancel
__all__ = [
"run",
"sleep",
@ -117,8 +96,8 @@ __all__ = [
"MemoryChannel",
"Channel",
"Stream",
"ChannelReader",
"ChannelWriter",
"ReadableChannel",
"WritableChannel",
"Semaphore",
"TimedOut",
"Task",

View File

@ -9,7 +9,7 @@ from structio.exceptions import StructIOException
from typing import Callable, Any, Coroutine
class BaseClock(ABC):
class Clock(ABC):
"""
Abstract base clock class
"""
@ -165,7 +165,7 @@ class AsyncResource(ABC):
await self.close()
class StreamWriter(AsyncResource, ABC):
class WritableStream(AsyncResource, ABC):
"""
Interface for writing binary data to
a byte stream
@ -176,7 +176,7 @@ class StreamWriter(AsyncResource, ABC):
raise NotImplementedError
class StreamReader(AsyncResource, ABC):
class ReadableStream(AsyncResource, ABC):
"""
Interface for reading binary data from
a byte stream. The stream implements the
@ -189,7 +189,7 @@ class StreamReader(AsyncResource, ABC):
raise NotImplementedError
class Stream(StreamReader, StreamWriter, ABC):
class Stream(ReadableStream, WritableStream, ABC):
"""
A generic, asynchronous, readable/writable binary stream
"""
@ -229,7 +229,7 @@ class WriteCloseableStream(Stream, ABC):
"""
class ChannelReader(AsyncResource, ABC):
class ReadableChannel(AsyncResource, ABC):
"""
Interface for reading data from a
channel
@ -275,7 +275,7 @@ class ChannelReader(AsyncResource, ABC):
"""
class ChannelWriter(AsyncResource, ABC):
class WritableChannel(AsyncResource, ABC):
"""
Interface for writing data to a
channel
@ -298,13 +298,13 @@ class ChannelWriter(AsyncResource, ABC):
"""
class Channel(ChannelWriter, ChannelReader, ABC):
class Channel(WritableChannel, ReadableChannel, ABC):
"""
A generic, two-way channel
"""
class BaseDebugger(ABC):
class Debugger(ABC):
"""
The base for all debugger objects
"""
@ -315,7 +315,7 @@ class BaseDebugger(ABC):
loop starts executing
"""
raise NotImplementedError
return NotImplemented
def on_exit(self):
"""
@ -323,7 +323,7 @@ class BaseDebugger(ABC):
loop exits entirely (all tasks completed)
"""
raise NotImplementedError
return NotImplemented
def on_task_spawn(self, task: Task):
"""
@ -334,7 +334,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def on_task_exit(self, task: Task):
"""
@ -344,7 +344,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def before_task_step(self, task: Task):
"""
@ -355,7 +355,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def after_task_step(self, task: Task):
"""
@ -366,7 +366,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def before_sleep(self, task: Task, seconds: float):
"""
@ -380,7 +380,7 @@ class BaseDebugger(ABC):
:type seconds: int
"""
raise NotImplementedError
return NotImplemented
def after_sleep(self, task: Task, seconds: float):
"""
@ -394,7 +394,7 @@ class BaseDebugger(ABC):
:type seconds: float
"""
raise NotImplementedError
return NotImplemented
def before_io(self, timeout: float):
"""
@ -407,7 +407,7 @@ class BaseDebugger(ABC):
:type timeout: float
"""
raise NotImplementedError
return NotImplemented
def after_io(self, timeout: float):
"""
@ -420,7 +420,7 @@ class BaseDebugger(ABC):
:type timeout: float
"""
raise NotImplementedError
return NotImplemented
def before_cancel(self, task: Task):
"""
@ -431,7 +431,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def after_cancel(self, task: Task) -> object:
"""
@ -442,7 +442,7 @@ class BaseDebugger(ABC):
:type task: :class: structio.objects.Task
"""
raise NotImplementedError
return NotImplemented
def on_exception_raised(self, task: Task, exc: BaseException):
"""
@ -455,27 +455,10 @@ class BaseDebugger(ABC):
:type exc: BaseException
"""
raise NotImplementedError
def on_io_schedule(self, stream, event: str):
"""
This method is called whenever an
I/O resource is registered in the
event loop
"""
raise NotImplementedError
def on_io_unschedule(self, stream):
"""
This method is called whenever an I/O resource
is unregistered from the loop
"""
raise NotImplementedError
return NotImplemented
class BaseIOManager(ABC):
class IOManager(ABC):
"""
Base class for all I/O managers
"""
@ -608,7 +591,7 @@ class SignalManager(ABC):
raise NotImplementedError
class BaseKernel(ABC):
class Kernel(ABC):
"""
Abstract kernel base class
"""
@ -616,28 +599,28 @@ class BaseKernel(ABC):
def __init__(
self,
policy: SchedulingPolicy,
clock: BaseClock,
io_manager: BaseIOManager,
clock: Clock,
io_manager: IOManager,
signal_managers: list[SignalManager],
tools: list[BaseDebugger] | None = None,
tools: list[Debugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
):
if not issubclass(clock.__class__, BaseClock):
if not issubclass(clock.__class__, Clock):
raise TypeError(
f"clock must be a subclass of {BaseClock.__module__}.{BaseClock.__qualname__}, not {type(clock)}"
f"clock must be a subclass of {Clock.__module__}.{Clock.__qualname__}, not {type(clock)}"
)
if not issubclass(policy.__class__, SchedulingPolicy):
raise TypeError(
f"policy must be a subclass of {SchedulingPolicy.__module__}.{SchedulingPolicy.__qualname__}, not {type(policy)}"
)
if not issubclass(io_manager.__class__, BaseIOManager):
if not issubclass(io_manager.__class__, IOManager):
raise TypeError(
f"io_manager must be a subclass of {BaseIOManager.__module__}.{BaseIOManager.__qualname__}, not {type(io_manager)}"
f"io_manager must be a subclass of {IOManager.__module__}.{IOManager.__qualname__}, not {type(io_manager)}"
)
for tool in tools or []:
if not issubclass(tool.__class__, BaseDebugger):
if not issubclass(tool.__class__, Debugger):
raise TypeError(
f"tools must be a subclass of {BaseDebugger.__module__}.{BaseDebugger.__qualname__}, not {type(tool)}"
f"tools must be a subclass of {Debugger.__module__}.{Debugger.__qualname__}, not {type(tool)}"
)
for mgr in signal_managers or []:
if not issubclass(mgr.__class__, SignalManager):
@ -648,7 +631,7 @@ class BaseKernel(ABC):
self.current_task: Task | None = None
self.current_pool: "structio.TaskPool" = None # noqa
self.current_scope: structio.TaskScope = None # noqa
self.tools: list[BaseDebugger] = tools or []
self.tools: list[Debugger] = tools or []
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
self.io_manager = io_manager
self.signal_managers = signal_managers

View File

@ -2,10 +2,10 @@ import traceback
import warnings
from types import FrameType
from structio.abc import (
BaseKernel,
BaseClock,
BaseDebugger,
BaseIOManager,
Kernel,
Clock,
Debugger,
IOManager,
SignalManager,
SchedulingPolicy,
)
@ -28,7 +28,7 @@ import signal
import sniffio
class DefaultKernel(BaseKernel):
class DefaultKernel(Kernel):
"""
An asynchronous event loop implementation
supporting generic scheduling policies
@ -37,10 +37,10 @@ class DefaultKernel(BaseKernel):
def __init__(
self,
policy: SchedulingPolicy,
clock: BaseClock,
io_manager: BaseIOManager,
clock: Clock,
io_manager: IOManager,
signal_managers: list[SignalManager],
tools: list[BaseDebugger] | None = None,
tools: list[Debugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
):
super().__init__(
@ -95,7 +95,7 @@ class DefaultKernel(BaseKernel):
self.reschedule_running()
def event(self, evt_name: str, *args):
if not callable(getattr(BaseDebugger, evt_name, None)):
if not callable(getattr(Debugger, evt_name, None)):
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
return
for tool in self.tools:
@ -322,7 +322,6 @@ class DefaultKernel(BaseKernel):
self.throw(task or self._pick_ki_task(), KeyboardInterrupt())
self._sigint_handled = False
def _tick(self):
"""
Runs a single event loop tick

View File

@ -1,14 +1,14 @@
import warnings
import structio
from structio.abc import BaseIOManager, BaseKernel
from structio.abc import IOManager, Kernel
from structio.core.task import Task, TaskState
from structio.core.run import current_loop, current_task
from structio.io import FdWrapper
import select
class SimpleIOManager(BaseIOManager):
class SimpleIOManager(IOManager):
"""
A simple, cross-platform, select()-based
I/O manager. This class is only meant to
@ -63,7 +63,7 @@ class SimpleIOManager(BaseIOManager):
def wait_io(self):
self._check_closed()
kernel: BaseKernel = current_loop()
kernel: Kernel = current_loop()
current_time = kernel.clock.current_time()
deadline = kernel.get_closest_deadline()
# FIXME: This delay seems to help throttle the calls

View File

@ -5,11 +5,11 @@ import functools
# I *really* hate fork()
from multiprocessing_utils import local
from structio.abc import (
BaseKernel,
BaseDebugger,
BaseClock,
Kernel,
Debugger,
Clock,
SignalManager,
BaseIOManager,
IOManager,
SchedulingPolicy,
)
from structio.exceptions import StructIOException
@ -19,7 +19,7 @@ from typing import Callable, Any, Coroutine
_RUN = local()
def current_loop() -> BaseKernel:
def current_loop() -> Kernel:
"""
Returns the current event loop in the calling
thread. Raises a StructIOException if no async
@ -39,7 +39,7 @@ def current_task() -> Task:
return current_loop().current_task
def new_event_loop(kernel: BaseKernel):
def new_event_loop(kernel: Kernel):
"""
Initializes a new event loop using the
given kernel implementation. Cannot be
@ -63,10 +63,10 @@ def run(
*args,
kernel: type | None = None,
policy: SchedulingPolicy| None = None,
io_manager: BaseIOManager| None = None,
io_manager: IOManager | None = None,
signal_managers: list[SignalManager] | None = None,
clock: BaseClock | None = None,
tools: list[BaseDebugger] | None = None,
clock: Clock | None = None,
tools: list[Debugger] | None = None,
restrict_ki_to_checkpoints: bool = False,
):
"""
@ -84,9 +84,9 @@ def run(
io_manager = structio.core.managers.io.simple.SimpleIOManager()
if clock is None:
clock = structio.core.time.clock.DefaultClock()
if not issubclass(kernel, BaseKernel):
if not issubclass(kernel, Kernel):
raise TypeError(
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}, not {type(kernel)}"
f"kernel must be a subclass of {Kernel.__module__}.{Kernel.__qualname__}, not {type(kernel)}"
)
signal_managers = signal_managers or []
sigint_manager = structio.core.managers.signals.sigint.SigIntManager()

View File

@ -2,6 +2,7 @@ from enum import Enum, auto
from dataclasses import dataclass, field
from typing import Coroutine, Any, Callable
from itertools import count
import structio
_counter = count()
@ -32,7 +33,7 @@ class Task:
# task
coroutine: Coroutine = field(repr=False)
# The task's scope
scope: "TaskScope"
scope: "TaskScope" = field(repr=False)
# The task's pool
pool: "TaskPool" = field(repr=False)
# The state of the task
@ -68,19 +69,11 @@ class Task:
Implements hash(self)
"""
return self.coroutine.__hash__()
# These are patched later at import time!
def __await__(self):
"""
Wait for the task to complete and return/raise appropriately (returns when cancelled)
"""
return _joiner(self).__await__()
return hash(self.coroutine)
def cancel(self):
"""
Cancels the given task
"""
return NotImplemented
structio.current_loop().cancel_task(self)

View File

@ -1,9 +1,9 @@
import random
from timeit import default_timer
from structio.abc import BaseClock
from structio.abc import Clock
class DefaultClock(BaseClock):
class DefaultClock(Clock):
def __init__(self):
super().__init__()
# We add a large random offset to our timer value

View File

@ -175,21 +175,16 @@ class Process:
return status
async def communicate(self, input=b"") -> tuple[bytes, bytes]:
async with structio.create_pool() as pool:
stdout = pool.spawn(self.stdout.readall) if self.stdout else None
stderr = pool.spawn(self.stderr.readall) if self.stderr else None
if input:
await self.stdin.write(input)
await self.stdin.close()
# Awaiting a task object waits for its completion and
# returns its return value!
out = b""
err = b""
if stdout:
out = await stdout
if stderr:
err = await stderr
return out, err
if input:
await self.stdin.write(input)
await self.stdin.close()
out = b""
err = b""
if self.stdout:
out = await self.stdout.readall()
if self.stderr:
err = await self.stderr.readall()
return out, err
async def __aenter__(self):
self.start()

View File

@ -3,7 +3,7 @@ import structio
from structio.core.syscalls import suspend, checkpoint
from structio.exceptions import ResourceClosed, WouldBlock
from structio.core.run import current_task, current_loop
from structio.abc import ChannelReader, ChannelWriter, Channel
from structio.abc import ReadableChannel, WritableChannel, Channel
from structio.util.ki import enable_ki_protection
from structio.util.misc import ThereCanBeOnlyOne
from structio.core.task import Task
@ -221,7 +221,7 @@ class PriorityQueue(Queue):
heappush(self.container, item)
class MemoryReceiveChannel(ChannelReader):
class MemoryReceiveChannel(ReadableChannel):
"""
An in-memory one-way channel to read
data
@ -249,7 +249,7 @@ class MemoryReceiveChannel(ChannelReader):
return len(self._buffer.getters)
class MemorySendChannel(ChannelWriter):
class MemorySendChannel(WritableChannel):
"""
An in-memory one-way channel to send
data

View File

@ -6,7 +6,7 @@ from functools import partial
import structio
import threading
from collections import deque
from structio.abc import BaseKernel
from structio.abc import Kernel
from structio.core.run import current_loop
from typing import Callable, Any, Coroutine
from structio.core.syscalls import checkpoint
@ -68,7 +68,7 @@ class AsyncThreadEvent(Event):
# will call current_loop(), and we may have been
# called from an async thread that doesn't have a
# loop
loop: BaseKernel = _storage.parent_loop
loop: Kernel = _storage.parent_loop
for task in self._tasks:
loop.reschedule(task)
# Awakes all threads
@ -164,7 +164,7 @@ class AsyncThreadQueue(Queue):
def _threaded_runner(
f,
parent_loop: BaseKernel,
parent_loop: Kernel,
rq: AsyncThreadQueue,
rsq: AsyncThreadQueue,
evt: AsyncThreadEvent,

View File

@ -1,8 +1,8 @@
from structio.abc import BaseDebugger
from structio.abc import Debugger
from structio.core.task import Task
class SimpleDebugger(BaseDebugger):
class SimpleDebugger(Debugger):
"""
A simple debugger for structio
"""