diff --git a/structio/__init__.py b/structio/__init__.py index 3392059..bbb1fe0 100644 --- a/structio/__init__.py +++ b/structio/__init__.py @@ -9,7 +9,17 @@ from structio.core.context import TaskPool, TaskScope from structio.exceptions import Cancelled, TimedOut, ResourceClosed from structio.core import task from structio.core.task import Task, TaskState -from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event +from structio.sync import ( + Event, + Queue, + MemoryChannel, + Semaphore, + Lock, + RLock, + emit, + on_event, + register_event, +) from structio.abc import Channel, Stream, ChannelReader, ChannelWriter from structio.io import socket from structio.io.socket import AsyncSocket @@ -25,6 +35,7 @@ from structio.io.files import ( from structio.core.run import current_loop, current_task from structio import thread, parallel from structio.path import Path +from structio.signals import set_signal_handler, get_signal_handler def run( @@ -147,5 +158,7 @@ __all__ = [ "current_loop", "current_task", "Path", - "parallel" + "parallel", + "get_signal_handler", + "set_signal_handler", ] diff --git a/structio/abc.py b/structio/abc.py index 4629dfc..1cff6f6 100644 --- a/structio/abc.py +++ b/structio/abc.py @@ -519,7 +519,9 @@ class BaseKernel(ABC): return NotImplemented @abstractmethod - def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None): + def notify_closing( + self, resource: AsyncResource, broken: bool = False, owner: Task | None = None + ): """ Notifies the event loop that a given resource is about to be closed and can be unscheduled diff --git a/structio/core/kernels/fifo.py b/structio/core/kernels/fifo.py index aa13f67..83f1d30 100644 --- a/structio/core/kernels/fifo.py +++ b/structio/core/kernels/fifo.py @@ -13,7 +13,13 @@ from structio.core.context import TaskPool, TaskScope from structio.core.task import Task, TaskState from structio.util.ki import CTRLC_PROTECTION_ENABLED from structio.core.time.queue import TimeQueue -from structio.exceptions import StructIOException, Cancelled, TimedOut, ResourceClosed, ResourceBroken +from structio.exceptions import ( + StructIOException, + Cancelled, + TimedOut, + ResourceClosed, + ResourceBroken, +) from collections import deque from typing import Callable, Coroutine, Any from functools import partial @@ -76,7 +82,9 @@ class FIFOKernel(BaseKernel): self.current_task.state = TaskState.IO self.io_manager.request_write(resource, self.current_task) - def notify_closing(self, resource: FdWrapper, broken: bool = False, owner: Task | None = None): + def notify_closing( + self, resource: FdWrapper, broken: bool = False, owner: Task | None = None + ): if not broken: exc = ResourceClosed("stream has been closed") else: @@ -434,7 +442,11 @@ class FIFOKernel(BaseKernel): if task is self.current_task: continue self.cancel_task(task) - if scope is not self.current_task.pool.scope and scope.owner is not self.current_task: + if ( + scope is not self.current_task.pool.scope + and scope.owner is not self.current_task + and scope.owner is not self.entry_point + ): # Handles the case where the current task calls # cancel() for a scope which it doesn't own, which # is an entirely reasonable thing to do diff --git a/structio/core/managers/signals/sigint.py b/structio/core/managers/signals/sigint.py index c0b0985..2c41947 100644 --- a/structio/core/managers/signals/sigint.py +++ b/structio/core/managers/signals/sigint.py @@ -1,55 +1,10 @@ from structio.abc import SignalManager from structio.util.ki import currently_protected +from structio.signals import set_signal_handler from structio.core.run import current_loop -from structio.io.socket import AsyncSocket -from structio.thread import AsyncThreadQueue from types import FrameType import warnings import signal -import socket - - -# TODO: This can (and should) be refactored to work with any signal -# so that users can tap into this machinery and handle "asynchronous -# signals" (kind of). Something similar to trio.open_signal_receiver, -# but maybe not quite as restrictive (i.e. it might be a good idea to -# just let users set an "async signal handler" instead of using an iterator -# which temporarily blocks all signals that we want to catch) - -_sig_data = AsyncThreadQueue(float("inf")) - - -async def signal_watcher(sock: AsyncSocket): - while True: - # Even though we use set_wakeup_fd (which makes sure - # our I/O manager is signal-aware and exits cleanly - # when they arrive), it turns out that actually using the - # data Python sends over our socket is trickier than it - # sounds at first. That is because if we receive a bunch - # of signals and the socket buffer gets filled, we are going - # to lose all signals after that. Python can raise a warning - # about this, but it's 1) Not ideal, we're still losing signals, - # which is bad if we can do better and 2) It can be confusing, - # because now we're leaking details about the way signals are - # implemented, and that sucks too; So instead, we use set_wakeup_fd - # merely as a notification mechanism to wake up our watcher and - # register a custom signal handler that stores all the information - # about incoming signals in an unbuffered queue (which means that even - # if the socket's buffer gets filled, we are still going to deliver all - # signals when we do our first call to read()). I'm a little uneasy about - # using an unbounded queue, but realistically I doubt that one would face - # memory problems because their code is receiving thousands of signals and - # the event loop is not handling them fast enough (right?) - await sock.receive(1) - while _sig_data: - sig, frame = await _sig_data.get() - match sig: - case signal.SIGINT: - if currently_protected(): - current_loop().signal_notify(sig, frame) - else: - current_loop().reschedule(current_loop().entry_point) - current_loop().throw(current_loop().entry_point, KeyboardInterrupt()) class SigIntManager(SignalManager): @@ -59,14 +14,14 @@ class SigIntManager(SignalManager): def __init__(self): self.installed = False - self.reader, self.writer = socket.socketpair() @staticmethod - def _handle(sig: int, frame: FrameType): - # Submit signal info to our asynchronous - # watcher. This call never blocks because - # _sig_data is unbounded - _sig_data.put_sync((sig, frame)) + async def _handle(sig: int, frame: FrameType): + if currently_protected(): + current_loop().signal_notify(sig, frame) + else: + current_loop().reschedule(current_loop().entry_point) + current_loop().throw(current_loop().entry_point, KeyboardInterrupt()) def install(self): if signal.getsignal(signal.SIGINT) != signal.default_int_handler: @@ -75,17 +30,7 @@ class SigIntManager(SignalManager): f" this is likely to break KeyboardInterrupt delivery!" ) return - loop = current_loop() - signal.signal(signal.SIGINT, self._handle) - # This allows us to semi-cleanly handle a Ctrl+C - # (or better, any signal) even when we're blocked - # in select() or similar (we register the reading - # end of the pair into the event loop so that as - # soon as a signal arrives, our watcher is scheduled) - self.writer.setblocking(False) - signal.set_wakeup_fd(self.writer.fileno()) - sock = AsyncSocket(self.reader) - loop.spawn(signal_watcher, sock) + set_signal_handler(signal.SIGINT, self._handle) self.installed = True def uninstall(self): diff --git a/structio/exceptions.py b/structio/exceptions.py index d14aacc..8729441 100644 --- a/structio/exceptions.py +++ b/structio/exceptions.py @@ -45,4 +45,3 @@ class ResourceBroken(StructIOException): Raised when an asynchronous resource gets corrupted and is no longer usable """ - diff --git a/structio/io/__init__.py b/structio/io/__init__.py index 31b24fc..12228a0 100644 --- a/structio/io/__init__.py +++ b/structio/io/__init__.py @@ -1,11 +1,19 @@ # This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104 import io import os -from structio.core.syscalls import checkpoint, wait_readable, wait_writable, closing, release +from structio.core.syscalls import ( + checkpoint, + wait_readable, + wait_writable, + closing, + release, +) from structio.exceptions import ResourceClosed from structio.abc import AsyncResource + try: from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket + WantRead = (BlockingIOError, SSLWantReadError, InterruptedError) WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError) except ImportError: @@ -29,7 +37,7 @@ class FdWrapper: of whether the wrapped fd is an int or something else entirely """ - __slots__ = ("fd", ) + __slots__ = ("fd",) def __init__(self, fd): self.fd = fd @@ -51,10 +59,7 @@ class AsyncStream(AsyncResource): a file-like object, with buffering """ - def __init__( - self, - fileobj - ): + def __init__(self, fileobj): self.fileobj = fileobj self._fd = FdWrapper(self.fileobj.fileno()) self._buf = bytearray() @@ -97,7 +102,7 @@ class AsyncStream(AsyncResource): while True: chunk = await self.read(maxread) if not chunk: - return b''.join(chunks) + return b"".join(chunks) chunks.append(chunk) if len(chunk) == maxread: maxread *= 2 diff --git a/structio/io/socket.py b/structio/io/socket.py index 2ee9d01..79c85e9 100644 --- a/structio/io/socket.py +++ b/structio/io/socket.py @@ -5,9 +5,16 @@ from structio.abc import AsyncResource from structio.io import FdWrapper, WantRead, WantWrite, SSLSocket from structio.thread import run_in_worker from structio.exceptions import ResourceClosed, ResourceBroken -from structio.core.syscalls import wait_readable, wait_writable, checkpoint, closing, release +from structio.core.syscalls import ( + wait_readable, + wait_writable, + checkpoint, + closing, + release, +) from functools import wraps import socket as _socket + try: import ssl as _ssl except ImportError: @@ -20,12 +27,18 @@ def socket(*args, **kwargs): @wraps(_socket.fromfd) -async def fromfd(fd: Any, family: _socket.AddressFamily | int, - type: _socket.SocketKind | int, proto: int = 0) -> "AsyncSocket": +async def fromfd( + fd: Any, + family: _socket.AddressFamily | int, + type: _socket.SocketKind | int, + proto: int = 0, +) -> "AsyncSocket": return AsyncSocket(_socket.fromfd(fd, family, type, proto)) -async def wrap_socket_with_ssl(sock, *args, context, do_handshake_on_connect=True, **kwargs): +async def wrap_socket_with_ssl( + sock, *args, context, do_handshake_on_connect=True, **kwargs +): """ Wraps a regular unencrypted socket or a structio async socket into a TLS-capable asynchronous socket. All positional and keyword arguments @@ -59,21 +72,28 @@ async def wrap_socket_with_ssl(sock, *args, context, do_handshake_on_connect=Tru # Wrappers of the socket module + @wraps(_socket.socketpair) -def socketpair(family=None, type=_socket.SOCK_STREAM, proto=0) -> tuple["AsyncSocket", "AsyncSocket"]: +def socketpair( + family=None, type=_socket.SOCK_STREAM, proto=0 +) -> tuple["AsyncSocket", "AsyncSocket"]: a, b = _socket.socketpair(family, type, proto) return AsyncSocket(a), AsyncSocket(b) @wraps(_socket.getaddrinfo) -async def getaddrinfo(host: bytearray | bytes | str | None, - port: str | int | None, - family: int = 0, - type: int = 0, - proto: int = 0, - flags: int = 0): - return await run_in_worker(_socket.getaddrinfo, host, port, family, type, proto, flags, - cancellable=True) +async def getaddrinfo( + host: bytearray | bytes | str | None, + port: str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +): + return await run_in_worker( + _socket.getaddrinfo, host, port, family, type, proto, flags, cancellable=True + ) + @wraps(_socket.getfqdn) async def getfqdn(name: str) -> str: @@ -81,7 +101,9 @@ async def getfqdn(name: str) -> str: @wraps(_socket.getnameinfo) -async def getnameinfo(sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int) -> tuple[str, str]: +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int +) -> tuple[str, str]: return await run_in_worker(_socket.getnameinfo, sockaddr, flags, cancellable=True) @@ -110,9 +132,13 @@ async def gethostbyname_ex(hostname: str) -> tuple[str, list[str], list[str]]: CONNECT_DELAY: float = 0.250 -async def connect_socket(host: str | bytes, port: int, *, - source_address=None, - happy_eyeballs_delay: float = CONNECT_DELAY) -> "AsyncSocket": +async def connect_socket( + host: str | bytes, + port: int, + *, + source_address=None, + happy_eyeballs_delay: float = CONNECT_DELAY, +) -> "AsyncSocket": """ Resolve the given (non-numeric) host and attempt to connect to it, at the chosen port. Connection attempts are made according to the "Happy eyeballs" algorithm as per RFC @@ -181,7 +207,9 @@ async def connect_socket(host: str | bytes, port: int, *, await attempt_sock.bind((source_address, 0)) except OSError: # Almost hit the 120 character line, phew... - raise OSError(f"Source addr {source_address!r} is incompatible with remote addr {addr!r}") + raise OSError( + f"Source addr {source_address!r} is incompatible with remote addr {addr!r}" + ) await attempt_sock.connect(addr) # Hooray! Connection was successful. Record the socket # and cancel the rest of the attempts (either future or @@ -223,27 +251,39 @@ async def connect_socket(host: str | bytes, port: int, *, # Again, we shouldn't be ignoring # errors willy-nilly like that, but # hey beta software am I right? - warnings.warn(f"Failed to close {sock!r} in call to connect_socket -> {type(e).__name__}: {e}") + warnings.warn( + f"Failed to close {sock!r} in call to connect_socket -> {type(e).__name__}: {e}" + ) if not successful: # All connection attempts failed raise OSError(f"connecting to {host}:{port} failed") return successful -async def connect_ssl_socket(host: str | bytes, port: int, *, - ssl_context=None, - source_address=None, - happy_eyeballs_delay: float = CONNECT_DELAY) -> "AsyncSocket": +async def connect_ssl_socket( + host: str | bytes, + port: int, + *, + ssl_context=None, + source_address=None, + happy_eyeballs_delay: float = CONNECT_DELAY, +) -> "AsyncSocket": """ Convenience wrapper over connect_socket with SSL/TLS functionality """ if not _ssl: raise RuntimeError("SSL is not supported on the current platform") - return await wrap_socket_with_ssl(await connect_socket(host, port, source_address=source_address, - happy_eyeballs_delay=happy_eyeballs_delay), - context=ssl_context, - server_hostname=host) + return await wrap_socket_with_ssl( + await connect_socket( + host, + port, + source_address=source_address, + happy_eyeballs_delay=happy_eyeballs_delay, + ), + context=ssl_context, + server_hostname=host, + ) class AsyncSocket(AsyncResource): @@ -588,4 +628,3 @@ class AsyncSocket(AsyncResource): def __repr__(self): return f"AsyncSocket({self.socket})" - diff --git a/structio/parallel.py b/structio/parallel.py index d022e35..0b97aeb 100644 --- a/structio/parallel.py +++ b/structio/parallel.py @@ -3,12 +3,7 @@ import os import structio import subprocess -from subprocess import ( - CalledProcessError, - CompletedProcess, - DEVNULL, - PIPE -) +from subprocess import CalledProcessError, CompletedProcess, DEVNULL, PIPE from structio.io import FileStream @@ -41,7 +36,9 @@ class Popen: async def wait(self): status = self._process.poll() if status is None: - status = await structio.thread.run_in_worker(self._process.wait, cancellable=True) + status = await structio.thread.run_in_worker( + self._process.wait, cancellable=True + ) return status async def communicate(self, input=b"") -> tuple[bytes, bytes]: @@ -78,14 +75,18 @@ class Popen: return getattr(self._process, item) -async def run(args, *, stdin=None, input=None, stdout=None, stderr=None, shell=False, check=False): +async def run( + args, *, stdin=None, input=None, stdout=None, stderr=None, shell=False, check=False +): """ Async version of subprocess.run() """ if input: stdin = subprocess.PIPE - async with Popen(args, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell) as process: + async with Popen( + args, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell + ) as process: try: stdout, stderr = await process.communicate(input) except: @@ -103,6 +104,13 @@ async def check_output(args, *, stdin=None, stderr=None, shell=False, input=None Async version of subprocess.check_output """ - out = await run(args, stdout=PIPE, stdin=stdin, stderr=stderr, shell=shell, - check=True, input=input) + out = await run( + args, + stdout=PIPE, + stdin=stdin, + stderr=stderr, + shell=shell, + check=True, + input=input, + ) return out.stdout diff --git a/structio/signals.py b/structio/signals.py new file mode 100644 index 0000000..b5338eb --- /dev/null +++ b/structio/signals.py @@ -0,0 +1,86 @@ +# Signal handling module +import signal +from collections import defaultdict +from types import FrameType + +from structio.io.socket import AsyncSocket, socketpair +from typing import Callable, Any, Coroutine +from structio.thread import AsyncThreadQueue +from structio.core.task import Task +from structio.core.run import current_loop + + +_sig_data = AsyncThreadQueue(float("inf")) +_sig_handlers: dict[ + signal.Signals, Callable[[Any, Any], Coroutine[Any, Any, Any]] | None +] = defaultdict(lambda: None) +_watcher: Task | None = None +_reader, _writer = socketpair() + + +def _handle(sig: int, frame: FrameType): + _sig_data.put_sync((sig, frame)) + + +def get_signal_handler( + sig: int, +) -> Callable[[Any, Any], Coroutine[Any, Any, Any]] | None: + """ + Returns the currently installed async signal handler for the + given signal or None if it is not set + """ + + return _sig_handlers[signal.Signals(sig)] + + +def set_signal_handler( + sig: int, handler: Callable[[Any, Any], Coroutine[Any, Any, Any]] +) -> Callable[[Any, Any], Coroutine[Any, Any, Any]] | None: + """ + Sets the given coroutine to handle the given signal asynchronously. The + previous async signal handler is returned if any was set, otherwise + None is returned + """ + + global _watcher + if not _watcher: + signal.set_wakeup_fd(_writer.fileno()) + _watcher = current_loop().spawn_system_task(signal_watcher, _reader) + # Raises an appropriate error + sig = signal.Signals(sig) + match sig: + case signal.SIGKILL | signal.SIGSTOP: + raise ValueError(f"signal {sig!r} does not support custom handlers") + case _: + prev = _sig_handlers[sig] + signal.signal(sig, _handle) + _sig_handlers[sig] = handler + return prev + + +async def signal_watcher(sock: AsyncSocket): + while True: + # Even though we use set_wakeup_fd (which makes sure + # our I/O manager is signal-aware and exits cleanly + # when they arrive), it turns out that actually using the + # data Python sends over our socket is trickier than it + # sounds at first. That is because if we receive a bunch + # of signals and the socket buffer gets filled, we are going + # to lose all signals after that. Python can raise a warning + # about this, but it's 1) Not ideal, we're still losing signals, + # which is bad if we can do better and 2) It can be confusing, + # because now we're leaking details about the way signals are + # implemented, and that sucks too; So instead, we use set_wakeup_fd + # merely as a notification mechanism to wake up our watcher and + # register a custom signal handler that stores all the information + # about incoming signals in an unbuffered queue (which means that even + # if the socket's buffer gets filled, we are still going to deliver all + # signals when we do our first call to read()). I'm a little uneasy about + # using an unbounded queue, but realistically I doubt that one would face + # memory problems because their code is receiving thousands of signals and + # the event loop is not handling them fast enough (right?) + await sock.receive(1) + while _sig_data: + sig, frame = await _sig_data.get() + if _sig_handlers[sig]: + await _sig_handlers[sig](sig, frame) diff --git a/structio/sync.py b/structio/sync.py index 7fa5cb0..5c7bc45 100644 --- a/structio/sync.py +++ b/structio/sync.py @@ -233,7 +233,6 @@ class Semaphore: """ def __init__(self, max_size: int, initial_size: int | None = None): - if initial_size is None: initial_size = max_size assert initial_size <= max_size @@ -407,7 +406,9 @@ class RLock(Lock): await checkpoint() -_events: dict[str, list[Callable[[Any, Any], Coroutine[Any, Any, Any]]]] = defaultdict(list) +_events: dict[str, list[Callable[[Any, Any], Coroutine[Any, Any, Any]]]] = defaultdict( + list +) async def emit(evt: str, *args, **kwargs): @@ -459,6 +460,7 @@ def on_event(evt: str): @wraps def wrapper(*args, **kwargs): f(*args, **kwargs) + register_event(evt, f) return wrapper diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index 331a440..dd58274 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -19,7 +19,9 @@ async def event_handler(q: structio.Queue): msg, payload = await q.get() logging.info(f"Caught event {msg!r} with the following payload: {payload}") except Exception as e: - logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}") + logging.error( + f"An exception occurred in the message handler -> {type(e).__name__}: {e}" + ) except structio.exceptions.Cancelled: logging.warning(f"Cancellation detected, message handler shutting down") # Propagate the cancellation @@ -50,7 +52,9 @@ async def serve(bind_address: tuple): pool.spawn(handler, conn, queue) except Exception as err: # Because exceptions just *work* - logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") + logging.info( + f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}" + ) async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue): @@ -62,7 +66,9 @@ async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue): address = clients[sock][1] name = "" async with sock: # Closes the socket automatically - await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ") + await sock.send_all( + b"Welcome to the chatroom pal, may you tell me your name?\n> " + ) cond = True while cond: while not name.endswith("\n"): @@ -76,7 +82,9 @@ async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue): clients[sock][0] = name break else: - await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ") + await sock.send_all( + b"Sorry, but that name is already taken. Try again!\n> " + ) await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode()) await q.put(("join", (address, name))) logging.info(f"{name} has joined the chatroom ({address}), informing clients") @@ -103,10 +111,14 @@ async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue): logging.info(f"Got: {data!r} from {address}") for i, client_sock in enumerate(clients): if client_sock != sock and clients[client_sock][0]: - logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}") + logging.info( + f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}" + ) if not data.endswith(b"\n"): data += b"\n" - await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode()) + await client_sock.send_all( + f"[{name}] ({address}): {data.decode()}> ".encode() + ) logging.info(f"Sent {data!r} to {i} clients") await q.put(("leave", name)) logging.info(f"Connection from {address} closed") @@ -133,4 +145,3 @@ if __name__ == "__main__": logging.info("Ctrl+C detected, exiting") else: logging.error(f"Exiting due to a {type(error).__name__}: {error}") - diff --git a/tests/echo_server.py b/tests/echo_server.py index 91cbd3d..3581692 100644 --- a/tests/echo_server.py +++ b/tests/echo_server.py @@ -26,7 +26,9 @@ async def serve(bind_address: tuple): await ctx.spawn(handler, conn, address_tuple) except Exception as err: # Because exceptions just *work* - logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") + logging.info( + f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}" + ) async def handler(sock: structio.socket.AsyncSocket, client_address: tuple): @@ -40,7 +42,9 @@ async def handler(sock: structio.socket.AsyncSocket, client_address: tuple): address = f"{client_address[0]}:{client_address[1]}" async with sock: # Closes the socket automatically - await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n") + await sock.send_all( + b"Welcome to the server pal, feel free to send me something!\n" + ) while True: await sock.send_all(b"-> ") data = await sock.receive(1024) diff --git a/tests/events.py b/tests/events.py index 8dec5aa..cf1d9be 100644 --- a/tests/events.py +++ b/tests/events.py @@ -49,8 +49,11 @@ async def main_async_thread(i): # Of course, threaded events work both ways: coroutines and threads # can set/wait on them from either side. Isn't that neat? + def thread_worker_2(n, ev: structio.thread.AsyncThreadEvent): - print(f"[worker] Worker thread spawned, sleeping {n} seconds before setting the event") + print( + f"[worker] Worker thread spawned, sleeping {n} seconds before setting the event" + ) time.sleep(n) print("[worker] Setting the event") ev.set() diff --git a/tests/files.py b/tests/files.py index f115b09..6276365 100644 --- a/tests/files.py +++ b/tests/files.py @@ -7,8 +7,12 @@ from structio import aprint async def main_2(data: bytes): t = structio.clock() await aprint("[main] Using low level os module") - async with await structio.open_file(os.path.join(tempfile.gettempdir(), "structio_test.txt"), "wb+") as f: - await aprint(f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes") + async with await structio.open_file( + os.path.join(tempfile.gettempdir(), "structio_test.txt"), "wb+" + ) as f: + await aprint( + f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes" + ) await f.write(data) await f.seek(0) assert await f.read(len(data)) == data @@ -24,7 +28,9 @@ async def main_3(data: bytes): await aprint("[main] Using high level pathlib wrapper") path = structio.Path(tempfile.gettempdir()) / "structio_test.txt" async with await path.open("wb+") as f: - await aprint(f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes") + await aprint( + f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes" + ) await f.write(data) await f.seek(0) assert await f.read(len(data)) == data @@ -40,4 +46,3 @@ payload = b"a" * MB * 100 # Write 100MiB of data (too much?) structio.run(main_2, payload) structio.run(main_3, payload) - diff --git a/tests/https_test.py b/tests/https_test.py index c5f5c69..0734a26 100644 --- a/tests/https_test.py +++ b/tests/https_test.py @@ -30,7 +30,9 @@ async def test(host: str, port: int, bufsize: int = 4096): # We purposely do NOT check for the end of the response (\r\n) so that when the # connection is in keep-alive mode we hang and let our timeout expire the whole # block - print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})") + print( + f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})" + ) data = await socket.receive(bufsize) if data: print(f"Received {len(data)} bytes") @@ -40,7 +42,9 @@ async def test(host: str, port: int, bufsize: int = 4096): break if buffer: data = buffer.decode().split("\r\n") - print(f"HTTP Response below {'(might be incomplete)' if scope.cancelled else ''}:") + print( + f"HTTP Response below {'(might be incomplete)' if scope.cancelled else ''}:" + ) _print(f"Response: {data[0]}") _print("Headers:") content = False diff --git a/tests/ki_test.py b/tests/ki_test.py index eb03260..b781ac2 100644 --- a/tests/ki_test.py +++ b/tests/ki_test.py @@ -8,7 +8,9 @@ async def child(n: int): await structio.sleep(n) except structio.Cancelled: slept = structio.clock() - i - print(f"Oh no, I've been cancelled! (was gonna sleep {n - slept:.2f} more seconds)") + print( + f"Oh no, I've been cancelled! (was gonna sleep {n - slept:.2f} more seconds)" + ) raise print(f"Slept for {structio.clock() - i:.2f} seconds!") diff --git a/tests/memory_channel.py b/tests/memory_channel.py index b971a59..a34f8a2 100644 --- a/tests/memory_channel.py +++ b/tests/memory_channel.py @@ -33,7 +33,7 @@ async def writer(ch: structio.ChannelWriter, objects: list[Any]): async def main(objects: list[Any]): print("[main] Parent is alive") # We construct a new memory channel... - channel = structio.MemoryChannel(1) # 1 is the size of the internal buffer + channel = structio.MemoryChannel(1) # 1 is the size of the internal buffer async with structio.create_pool() as pool: # ... and dispatch the two ends to different # tasks. Isn't this neat? @@ -44,4 +44,3 @@ async def main(objects: list[Any]): structio.run(main, [1, 2, 3, 4]) - diff --git a/tests/nested_pool_inner_raises.py b/tests/nested_pool_inner_raises.py index a5c3aa9..49af80a 100644 --- a/tests/nested_pool_inner_raises.py +++ b/tests/nested_pool_inner_raises.py @@ -13,7 +13,9 @@ async def failing(name: str, n): before = structio.clock() print(f"[child {name}] Sleeping for {n} seconds") await structio.sleep(n) - print(f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds, raising now!") + print( + f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds, raising now!" + ) raise TypeError("waa") @@ -61,6 +63,7 @@ async def main_nested( print("[main] TypeError caught!") print(f"[main] Children exited in {structio.clock() - before:.2f} seconds") + if __name__ == "__main__": structio.run( main, @@ -72,5 +75,3 @@ if __name__ == "__main__": [("first", 1), ("third", 3)], [("second", 2), ("fourth", 4)], ) - - diff --git a/tests/processes.py b/tests/processes.py index 5a1e210..942b72b 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -12,7 +12,9 @@ async def main(data: str): out = await structio.parallel.check_output(cmd, input=data) print(out.rstrip(b"\n") == data) # Other, other option :D - process = structio.parallel.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + process = structio.parallel.Popen( + cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) out, _ = await process.communicate(data) print(out.rstrip(b"\n") == data) diff --git a/tests/self_cancel.py b/tests/self_cancel.py index ddefe71..9a3ecde 100644 --- a/tests/self_cancel.py +++ b/tests/self_cancel.py @@ -7,7 +7,9 @@ async def sleeper(n): try: await structio.sleep(n) except structio.Cancelled: - print(f"[sleeper] Oh no, I've been cancelled! (was gonna sleep {structio.clock() - i:.2f} more seconds)") + print( + f"[sleeper] Oh no, I've been cancelled! (was gonna sleep {structio.clock() - i:.2f} more seconds)" + ) raise print("[sleeper] Woke up!") @@ -26,7 +28,9 @@ async def main_simple(n, o, p): async def main_nested(n, o, p): - print(f"[main] Parent is alive, spawning {o} children in two contexts sleeping {n} seconds each") + print( + f"[main] Parent is alive, spawning {o} children in two contexts sleeping {n} seconds each" + ) t = structio.clock() async with structio.create_pool() as p1: for i in range(o): @@ -57,6 +61,7 @@ async def main_child(x: float): await structio.sleep(x) print(f"[main] Done in {structio.clock() - t:.2f} seconds") + # Should take about 5 seconds structio.run(main_simple, 5, 2, 2) structio.run(main_nested, 5, 2, 2) diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 0000000..4bd4ef5 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,28 @@ +import structio +import signal +from types import FrameType + +ev = structio.Event() + + +async def handler(sig: signal.Signals, _frame: FrameType): + print(f"[handler] Handling signal {signal.Signals(sig).name!r}, waiting 1 second before setting event") + # Just to show off the async part + await structio.sleep(1) + ev.set() + + +async def main(n): + print("[main] Main is alive, setting signal handler") + assert structio.get_signal_handler(signal.SIGALRM) is None + structio.set_signal_handler(signal.SIGALRM, handler) + assert structio.get_signal_handler(signal.SIGALRM) is handler + print(f"[main] Signal handler set, calling signal.alarm({n})") + signal.alarm(n) + print("[main] Alarm scheduled, waiting on global event") + t = structio.clock() + await ev.wait() + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +structio.run(main, 5) diff --git a/tests/sliding_deadline.py b/tests/sliding_deadline.py index f89a7ac..c0aa5b1 100644 --- a/tests/sliding_deadline.py +++ b/tests/sliding_deadline.py @@ -30,4 +30,3 @@ async def main(n): structio.run(main, 7.5) - diff --git a/tests/task_handling.py b/tests/task_handling.py index fb2dec6..0202252 100644 --- a/tests/task_handling.py +++ b/tests/task_handling.py @@ -48,4 +48,3 @@ async def main_wait_failing(i): structio.run(main_cancel, 5) structio.run(main_wait_successful, 5) structio.run(main_wait_failing, 5) - diff --git a/tests/timeouts.py b/tests/timeouts.py index 31d3e85..9fc9b73 100644 --- a/tests/timeouts.py +++ b/tests/timeouts.py @@ -7,7 +7,9 @@ async def test_silent(i, j): with structio.skip_after(i) as scope: print(f"[test] Sleeping for {j} seconds") await structio.sleep(j) - print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.cancelled})") + print( + f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.cancelled})" + ) async def test_loud(i, j):