Ported old aiosched sockets and streams and added related tests

This commit is contained in:
Mattia Giambirtone 2023-06-06 11:31:30 +02:00 committed by nocturn9x
parent 3ea159c858
commit 26a43d5f84
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
13 changed files with 924 additions and 62 deletions

View File

@ -11,6 +11,8 @@ from structio.core import task
from structio.core.task import Task, TaskState
from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
from structio.io import socket
from structio.io.socket import AsyncSocket
from structio.io.files import (
open_file,
wrap_file,

View File

@ -50,11 +50,15 @@ class AsyncResource(ABC):
async def close(self):
return NotImplemented
@abstractmethod
def fileno(self):
return NotImplemented
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
class StreamWriter(AsyncResource):
class StreamWriter(AsyncResource, ABC):
"""
Interface for writing binary data to
a byte stream
@ -84,7 +88,7 @@ class StreamWriter(AsyncResource):
return NotImplemented
class StreamReader(AsyncResource):
class StreamReader(AsyncResource, ABC):
"""
Interface for reading binary data from
a byte stream. The stream implements the
@ -124,13 +128,21 @@ class StreamReader(AsyncResource):
return NotImplemented
class Stream(StreamReader, StreamWriter):
class Stream(StreamReader, StreamWriter, ABC):
"""
A generic, asynchronous, readable/writable binary stream
"""
@abstractmethod
async def flush(self):
"""
Flushes the underlying resource asynchronously
"""
class WriteCloseableStream(Stream):
return NotImplemented
class WriteCloseableStream(Stream, ABC):
"""
Extension to the Stream class that allows
shutting down the write end of the stream
@ -149,7 +161,7 @@ class WriteCloseableStream(Stream):
"""
class ChannelReader(AsyncResource):
class ChannelReader(AsyncResource, ABC):
"""
Interface for reading data from a
channel
@ -178,8 +190,11 @@ class ChannelReader(AsyncResource):
read from the channel
"""
def fileno(self):
return None
class ChannelWriter(AsyncResource):
class ChannelWriter(AsyncResource, ABC):
"""
Interface for writing data to a
channel
@ -201,12 +216,18 @@ class ChannelWriter(AsyncResource):
to write to the channel
"""
def fileno(self):
return None
class Channel(ChannelWriter, ChannelReader):
class Channel(ChannelWriter, ChannelReader, ABC):
"""
A generic, two-way channel
"""
def fileno(self):
return None
class BaseDebugger(ABC):
"""
@ -407,19 +428,21 @@ class BaseIOManager(ABC):
return NotImplemented
@abstractmethod
def request_read(self, rsc: AsyncResource):
def request_read(self, rsc: AsyncResource, task: Task):
"""
"Requests" a read operation on the given
resource to the I/O manager from the current task
resource to the I/O manager from the given
task
"""
return NotImplemented
@abstractmethod
def request_write(self, rsc: AsyncResource):
def request_write(self, rsc: AsyncResource, task: Task):
"""
"Requests" a write operation on the given
resource to the I/O manager from the current task
resource to the I/O manager from the given
task
"""
return NotImplemented
@ -427,8 +450,9 @@ class BaseIOManager(ABC):
@abstractmethod
def pending(self):
"""
Returns a boolean value that indicates whether
there's any I/O registered in the manager
Returns whether there's any tasks waiting
to read from/write to a resource registered
in the manager
"""
return NotImplemented
@ -455,6 +479,19 @@ class BaseIOManager(ABC):
return NotImplemented
@abstractmethod
def get_reader(self, rsc: AsyncResource):
"""
Returns the task reading from the given
resource, if any (None otherwise)
"""
@abstractmethod
def get_writer(self, rsc: AsyncResource):
"""
Returns the task writing to the given
resource, if any (None otherwise)
"""
class SignalManager(ABC):
"""
@ -505,13 +542,47 @@ class BaseKernel(ABC):
self.current_scope: "TaskScope" = None
self.tools: list[BaseDebugger] = tools or []
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
self.running: bool = False
self.io_manager = io_manager
self.signal_managers = signal_managers
self.entry_point: Task | None = None
# Pool for system tasks
self.pool: "TaskPool" = None
@abstractmethod
def wait_readable(self, resource: AsyncResource):
"""
Schedule the given resource for reading from
the current task
"""
return NotImplemented
@abstractmethod
def wait_writable(self, resource: AsyncResource):
"""
Schedule the given resource for reading from
the current task
"""
return NotImplemented
@abstractmethod
def release_resource(self, resource: AsyncResource):
"""
Releases the given resource from the scheduler
"""
return NotImplemented
@abstractmethod
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
"""
Notifies the event loop that a given resource
is about to be closed and can be unscheduled
"""
return NotImplemented
@abstractmethod
def cancel_task(self, task: Task):
"""

View File

@ -7,12 +7,13 @@ from structio.abc import (
BaseDebugger,
BaseIOManager,
SignalManager,
AsyncResource
)
from structio.core.context import TaskPool, TaskScope
from structio.core.task import Task, TaskState
from structio.util.ki import CTRLC_PROTECTION_ENABLED
from structio.core.time.queue import TimeQueue
from structio.exceptions import StructIOException, Cancelled, TimedOut
from structio.exceptions import StructIOException, Cancelled, TimedOut, ResourceClosed, ResourceBroken
from collections import deque
from typing import Callable, Coroutine, Any
from functools import partial
@ -62,6 +63,26 @@ class FIFOKernel(BaseKernel):
]
)
def wait_readable(self, resource: AsyncResource):
self.io_manager.request_read(resource, self.current_task)
def wait_writable(self, resource: AsyncResource):
self.io_manager.request_write(resource, self.current_task)
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
if not broken:
exc = ResourceClosed("stream has been closed")
else:
exc = ResourceBroken("stream might be corrupted")
owner = owner or self.current_task
reader = self.io_manager.get_reader(resource)
writer = self.io_manager.get_writer(resource)
if reader is not owner:
self.throw(reader, exc)
if writer is not owner:
self.throw(writer, exc)
self.reschedule_running()
def get_closest_deadline_owner(self):
return self.paused.peek()
@ -77,18 +98,21 @@ class FIFOKernel(BaseKernel):
# We really can't afford to have our internals explode,
# sorry!
warnings.warn(
f"Exception during debugging event delivery ({evt_name!r}): {type(e).__name__} -> {e}",
f"Exception during debugging event delivery in {f!r} ({evt_name!r}): {type(e).__name__} -> {e}",
)
traceback.print_tb(e.__traceback__)
# We disable the tool, so it can't raise at the next debugging
# event
self.tools.remove(tool)
def done(self):
if self.entry_point.done():
return True
if any([self.run_queue, self.paused, self.io_manager.pending()]):
return False
for scope in self.scopes:
if not scope.done():
return False
if self.entry_point.done():
return True
return True
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
@ -168,6 +192,8 @@ class FIFOKernel(BaseKernel):
return
if task.state == TaskState.PAUSED:
self.paused.discard(task)
elif task.state == TaskState.IO:
self.io_manager.release_task(task)
self.handle_errors(partial(task.coroutine.throw, err), task)
def reschedule(self, task: Task):
@ -185,6 +211,7 @@ class FIFOKernel(BaseKernel):
self.run_queue.appendleft(self.current_task)
def schedule_point(self):
self.skip = True
self.reschedule_running()
def sleep(self, amount):
@ -233,7 +260,6 @@ class FIFOKernel(BaseKernel):
while not self.done():
if self.run_queue and not self.skip:
self.handle_errors(self.step)
self.running = False
self.skip = False
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
@ -248,7 +274,7 @@ class FIFOKernel(BaseKernel):
Reschedules the currently running task
"""
self.run_queue.append(self.current_task)
self.reschedule(self.current_task)
def handle_errors(self, func: Callable, task: Task | None = None):
"""
@ -296,6 +322,10 @@ class FIFOKernel(BaseKernel):
self.event("on_exception_raised", task)
self.on_error(task)
def release_resource(self, resource: AsyncResource):
self.io_manager.release(resource)
self.reschedule_running()
def release(self, task: Task):
"""
Releases the timeouts and associated

View File

@ -1,14 +1,17 @@
from collections import defaultdict
from structio.abc import BaseIOManager, BaseKernel
from structio.abc import BaseIOManager, BaseKernel, AsyncResource
from structio.core.context import Task
from structio.core.run import current_loop, current_task
from structio.core.run import current_loop
import select
class SimpleIOManager(BaseIOManager):
"""
A simple, cross-platform, select()-based
I/O manager
I/O manager. This class is only meant to
be used as a default fallback and is quite
inefficient and slower compared to more ad-hoc
alternatives such as epoll or kqueue (it should
work on most platforms though)
"""
def __init__(self):
@ -17,21 +20,19 @@ class SimpleIOManager(BaseIOManager):
"""
# Maps resources to tasks
self.readers = {}
self.writers = {}
# This allows us to have a bidirectional mapping:
# we know both which tasks are using which resources
# and which resources are used by which tasks,
# without having to go through too many hoops and jumps.
self.tasks: dict[Task, list] = defaultdict(list)
self.readers: dict[AsyncResource, Task] = {}
self.writers: dict[AsyncResource, Task] = {}
def pending(self):
# We don't return bool(self.resources) because there is
# no pending I/O to do if no tasks are waiting to read or
# write, even if there's dangling resources around!
return bool(self.readers or self.writers)
def _collect_readers(self) -> list:
def get_reader(self, rsc: AsyncResource):
return self.readers.get(rsc)
def get_writer(self, rsc: AsyncResource):
return self.writers.get(rsc)
def _collect_readers(self) -> list[int]:
"""
Collects all resources that need to be read from,
so we can select() on them later
@ -39,10 +40,10 @@ class SimpleIOManager(BaseIOManager):
result = []
for resource in self.readers:
result.append(resource)
result.append(resource.fileno())
return result
def _collect_writers(self) -> list:
def _collect_writers(self) -> list[int]:
"""
Collects all resources that need to be written to,
so we can select() on them later
@ -50,34 +51,43 @@ class SimpleIOManager(BaseIOManager):
result = []
for resource in self.writers:
result.append(resource)
result.append(resource.fileno())
return result
def wait_io(self):
kernel: BaseKernel = current_loop()
deadline = kernel.get_closest_deadline()
if deadline == float("inf"):
deadline = 0
readable, writable, _ = select.select(
self._collect_readers(),
self._collect_writers(),
[],
kernel.get_closest_deadline(),
deadline,
)
for read_ready in readable:
kernel.reschedule(self.readers[read_ready])
for resource, task in self.readers.items():
if resource.fileno() == read_ready:
kernel.reschedule(task)
for write_ready in writable:
kernel.reschedule(self.writers[write_ready])
for resource, task in self.writers.items():
if resource.fileno() == write_ready:
kernel.reschedule(task)
def request_read(self, rsc):
task = current_task()
def request_read(self, rsc: AsyncResource, task: Task):
self.readers[rsc] = task
def request_write(self, rsc):
task = current_task()
def request_write(self, rsc: AsyncResource, task: Task):
self.writers[rsc] = task
def release(self, resource):
def release(self, resource: AsyncResource):
self.readers.pop(resource, None)
self.writers.pop(resource, None)
def release_task(self, task: Task):
for resource in self.tasks[task]:
self.release(resource)
for resource, owner in self.readers.copy().items():
if owner == task:
self.readers.pop(resource)
for resource, owner in self.writers.copy().items():
if owner == task:
self.writers.pop(resource)

View File

@ -68,3 +68,19 @@ async def checkpoint():
await check_cancelled()
await schedule_point()
async def wait_readable(rsc):
return await syscall("wait_readable", rsc)
async def wait_writable(rsc):
return await syscall("wait_writable", rsc)
async def closing(rsc):
return await syscall("notify_closing", rsc)
async def release(rsc):
return await syscall("release_resource", rsc)

View File

@ -84,11 +84,10 @@ class TimeQueue:
def get_closest_deadline(self) -> float:
"""
Returns the closest deadline that is meant to expire
or raises IndexError if the queue is empty
"""
if not self:
raise IndexError("TimeQueue is empty")
return float("inf")
return self.container[0][0]
def __iter__(self):

View File

@ -38,3 +38,11 @@ class ResourceBusy(StructIOException):
Raised when an attempt is made to use an
asynchronous resource that is currently busy
"""
class ResourceBroken(StructIOException):
"""
Raised when an asynchronous resource gets
corrupted and is no longer usable
"""

View File

@ -0,0 +1,152 @@
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
import os
import warnings
from structio.core.syscalls import checkpoint, wait_readable, wait_writable, closing, release
from structio.exceptions import ResourceClosed
from structio.abc import AsyncResource
try:
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
WantRead = (BlockingIOError, SSLWantReadError, InterruptedError)
WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError)
except ImportError:
WantWrite = (BlockingIOError, InterruptedError)
WantRead = (BlockingIOError, InterruptedError)
SSLSocket = None
class FdWrapper:
"""
A simple wrapper around a file descriptor that
allows the event loop to perform an optimization
regarding I/O event registration safely. This is
because while integer file descriptors can be reused
by the operating system, instances of this class will
not (hence if the event loop keeps around a dead instance
of an FdWrapper, it at least won't accidentally register
a new file with that same file descriptor). A bonus is
that this also allows us to always assume that we can call
fileno() on all objects registered in our selector, regardless
of whether the wrapped fd is an int or something else entirely
"""
__slots__ = ("fd", )
def __init__(self, fd):
self.fd = fd
def fileno(self):
return self.fd
# Can be converted to an int
def __int__(self):
return self.fd
def __repr__(self):
return f"<fd={self.fd!r}>"
class AsyncStream(AsyncResource):
"""
A generic asynchronous stream over
a file descriptor. Functionality
is OS-dependent
"""
def __init__(
self,
fd: int,
open_fd: bool = True,
close_on_context_exit: bool = True,
**kwargs,
):
self._fd = FdWrapper(fd)
self.fileobj = None
if open_fd:
self.fileobj = os.fdopen(int(self._fd), **kwargs)
os.set_blocking(int(self._fd), False)
# Do we close ourselves upon the end of a context manager?
self.close_on_context_exit = close_on_context_exit
async def read(self, size: int = -1):
"""
Reads up to size bytes from the
given stream. If size == -1, read
until EOF is reached
"""
while True:
try:
data = self.fileobj.read(size)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
async def write(self, data):
"""
Writes data to the stream.
Returns the number of bytes
written
"""
while True:
try:
data = self.fileobj.write(data)
await checkpoint()
return data
except WantWrite:
await wait_writable(self)
async def close(self):
"""
Closes the stream asynchronously
"""
if self.fileno() == -1:
raise ResourceClosed("I/O operation on closed stream")
self._fd = -1
await closing(self)
await release(self)
self.fileobj.close()
self.fileobj = None
await checkpoint()
def fileno(self):
"""
Wrapper socket method
"""
return int(self._fd)
async def __aenter__(self):
self.fileobj.__enter__()
return self
async def __aexit__(self, *args):
if self._fd != -1 and self.close_on_context_exit:
await self.close()
async def dup(self):
"""
Wrapper stream method
"""
return type(self)(os.dup(self.fileno()))
def __repr__(self):
return f"AsyncStream({self.fileobj})"
def __del__(self):
"""
Stream destructor. Do *not* call
this directly: stuff will break
"""
if self._fd != -1 and self.fileobj.fileno() != -1:
try:
os.set_blocking(self.fileno(), False)
os.close(self.fileno())
except OSError as e:
warnings.warn(
f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}"
)

View File

@ -44,7 +44,7 @@ _FILE_ASYNC_METHODS = {
}
class AsyncResourceWrapper(AsyncResource):
class AsyncFile(AsyncResource):
"""
Asynchronous wrapper around regular file-like objects.
Blocking operations are turned into async ones using threads.
@ -52,6 +52,9 @@ class AsyncResourceWrapper(AsyncResource):
and read/write methods
"""
def fileno(self):
return self.handle.fileno()
def __init__(self, f):
self._file = f
@ -108,7 +111,8 @@ class AsyncResourceWrapper(AsyncResource):
# This operation is non-cancellable, meaning it'll run
# no matter what our event loop has to say about it.
# After we're done, we'll obviously re-raise the cancellation
# if necessary
# if necessary. This ensures files are always closed even when
# the operation gets cancelled
await structio.thread.run_in_worker(self.handle.close)
# If we were cancelled, here is where we raise
await check_cancelled()
@ -123,7 +127,7 @@ async def open_file(
newline=None,
closefd=True,
opener=None,
) -> AsyncResourceWrapper:
) -> AsyncFile:
"""
Like io.open(), but async. Magic
"""
@ -135,13 +139,13 @@ async def open_file(
)
def wrap_file(file) -> AsyncResourceWrapper:
def wrap_file(file) -> AsyncFile:
"""
Wraps a file-like object into an async
wrapper
"""
return AsyncResourceWrapper(file)
return AsyncFile(file)
stdin = wrap_file(sys.stdin)

View File

@ -1,13 +1,344 @@
import structio
from structio.abc import AsyncResource
from structio.core.syscalls import check_cancelled
from structio.io import FdWrapper, WantRead, WantWrite, SSLSocket
from structio.exceptions import ResourceClosed, ResourceBroken
from structio.core.syscalls import wait_readable, wait_writable, checkpoint, closing, release
from functools import wraps
import socket as _socket
@wraps(_socket.socket)
def socket(*args, **kwargs):
return None # TODO
return AsyncSocket(_socket.socket(*args, **kwargs))
# TODO
class AsyncSocket(AsyncResource):
"""
Abstraction layer for asynchronous sockets
"""
def fileno(self):
return int(self._fd)
def __init__(
self,
sock: _socket.socket,
close_on_context_exit: bool = True,
do_handshake_on_connect: bool = True,
):
self._fd = FdWrapper(sock.fileno())
self.close_on_context_exit = close_on_context_exit
# Do we perform the TCP handshake automatically
# upon connection? This is mostly needed for SSL
# sockets
self.do_handshake_on_connect = do_handshake_on_connect
self.socket = sock
self.socket.setblocking(False)
# A socket that isn't connected doesn't
# need to be closed
self.needs_closing: bool = False
async def receive(self, max_size: int, flags: int = 0) -> bytes:
"""
Receives up to max_size bytes from a socket asynchronously
"""
assert max_size >= 1, "max_size must be >= 1"
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
data = self.socket.recv(max_size, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
"""
Receives exactly size bytes from a socket asynchronously.
"""
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
buf = bytearray(size)
pos = 0
while pos < size:
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
if n == 0:
raise ResourceBroken("incomplete read detected")
pos += n
return bytes(buf)
async def connect(self, address):
"""
Wrapper socket method
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
self.socket.connect(address)
if self.do_handshake_on_connect:
await self.do_handshake()
await checkpoint()
break
except WantWrite:
await wait_writable(self)
self.needs_closing = True
async def close(self):
"""
Wrapper socket method
"""
if self.needs_closing:
self.socket.close()
await checkpoint()
async def accept(self):
"""
Accepts the socket, completing the 3-step TCP handshake asynchronously
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
remote, addr = self.socket.accept()
await checkpoint()
return type(self)(remote), addr
except WantRead:
await wait_readable(self)
async def send_all(self, data: bytes, flags: int = 0):
"""
Sends all data inside the buffer asynchronously until it is empty
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
sent_no = 0
while data:
try:
sent_no = self.socket.send(data, flags)
await checkpoint()
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
data = data[sent_no:]
async def shutdown(self, how):
"""
Wrapper socket method
"""
if self.fileno() == -1:
raise ResourceClosed("I/O operation on closed socket")
if self.socket:
self.socket.shutdown(how)
await checkpoint()
async def bind(self, addr: tuple):
"""
Binds the socket to an address
:param addr: The address, port tuple to bind to
:type addr: tuple
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.socket.bind(addr)
await checkpoint()
async def listen(self, backlog: int):
"""
Starts listening with the given backlog
:param backlog: The socket's backlog
:type backlog: int
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.socket.listen(backlog)
await checkpoint()
# Yes, I stole these from Curio because I could not be
# arsed to write a bunch of uninteresting simple socket
# methods from scratch, deal with it.
def settimeout(self, seconds):
"""
Wrapper socket method
"""
raise RuntimeError("Use with_timeout() to set a timeout")
def gettimeout(self):
"""
Wrapper socket method
"""
return None
def dup(self):
"""
Wrapper socket method
"""
return type(self)(self.socket.dup(), self.do_handshake_on_connect)
async def do_handshake(self):
"""
Wrapper socket method
"""
if not hasattr(self.socket, "do_handshake"):
return
while True:
try:
self.socket: SSLSocket # Silences pycharm warnings
self.socket.do_handshake()
await checkpoint()
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
async def recvfrom(self, buffersize, flags=0):
"""
Wrapper socket method
"""
while True:
try:
return self.socket.recvfrom(buffersize, flags)
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
async def recv_into(self, buffer, nbytes=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recv_into(buffer, nbytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recvfrom_into(buffer, bytes, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
except WantWrite:
await wait_writable(self)
async def sendto(self, bytes, flags_or_address, address=None):
"""
Wrapper socket method
"""
if address:
flags = flags_or_address
else:
address = flags_or_address
flags = 0
while True:
try:
data = self.socket.sendto(bytes, flags, address)
await checkpoint()
return data
except WantWrite:
await wait_writable(self)
except WantRead:
await wait_readable(self)
async def getpeername(self):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.getpeername()
await checkpoint()
return data
except WantWrite:
await wait_writable(self)
except WantRead:
await wait_readable(self)
async def getsockname(self):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.getpeername()
await checkpoint()
return data
except WantWrite:
await wait_writable(self)
except WantRead:
await wait_readable(self)
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
await checkpoint()
return data
except WantRead:
await wait_readable(self)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
"""
Wrapper socket method
"""
while True:
try:
data = self.socket.sendmsg(buffers, ancdata, flags, address)
await checkpoint()
return data
except WantRead:
await wait_writable(self)
def __repr__(self):
return f"AsyncSocket({self.socket})"

View File

@ -1,10 +1,39 @@
# Module inspired by subprocess which allows for asynchronous
# multiprocessing
import os
import subprocess
from subprocess import (
CalledProcessError,
CompletedProcess,
SubprocessError,
STDOUT,
DEVNULL,
PIPE
)
class Process:
class Popen:
"""
An asynchronous process
Wrapper around subprocess.Popen, but async
"""
# TODO
def __init__(self, *args, **kwargs):
"""
Public object constructor
"""
if "universal_newlines" in kwargs:
# Not sure why? But everyone else is doing it so :shrug:
raise RuntimeError("universal_newlines is not supported")
if stdin := kwargs.get("stdin"):
# Curio mentions stuff breaking if the child process
# is passed a stdin fd that is set to non-blocking mode
if hasattr(os, "set_blocking"):
os.set_blocking(stdin.fileno(), True)
# Delegate to Popen's constructor
self._process = subprocess.Popen(*args, **kwargs)
def __getattr__(self, item):
# Delegate to internal process object
return getattr(self._process, item)

136
tests/chatroom_server.py Normal file
View File

@ -0,0 +1,136 @@
import structio
import logging
import sys
# An asynchronous chatroom
clients: dict[structio.socket.AsyncSocket, list[str, str]] = {}
names: set[str] = set()
async def event_handler(q: structio.Queue):
"""
Reads data submitted onto the queue
"""
try:
logging.info("Event handler spawned")
while True:
msg, payload = await q.get()
logging.info(f"Caught event {msg!r} with the following payload: {payload}")
except Exception as e:
logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
except structio.exceptions.Cancelled:
logging.warning(f"Cancellation detected, message handler shutting down")
# Propagate the cancellation
raise
async def serve(bind_address: tuple):
"""
Serves asynchronously forever (or until Ctrl+C ;))
:param bind_address: The address to bind the server to, represented as a tuple
(address, port) where address is a string and port is an integer
"""
sock = structio.socket.socket()
queue = structio.Queue()
await sock.bind(bind_address)
await sock.listen(5)
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with structio.create_pool() as pool:
pool.spawn(event_handler, queue)
async with sock:
while True:
try:
conn, address_tuple = await sock.accept()
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
await queue.put(("connect", clients[conn]))
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, queue)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue):
"""
Handles a single client connection
:param sock: The AsyncSocket object connected to the client
"""
address = clients[sock][1]
name = ""
async with sock: # Closes the socket automatically
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
cond = True
while cond:
while not name.endswith("\n"):
name = (await sock.receive(64)).decode()
if name == "":
cond = False
break
name = name.rstrip("\n")
if name not in names:
names.add(name)
clients[sock][0] = name
break
else:
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
await q.put(("join", (address, name)))
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
while True:
await sock.send_all(b"> ")
data = await sock.receive(1024)
if not data:
break
decoded = data.decode().rstrip("\n")
if decoded.startswith("/"):
logging.info(f"{name} issued server command {decoded}")
await q.put(("cmd", (name, decoded[1:])))
match decoded[1:]:
case "bye":
await sock.send_all(b"Bye!\n")
break
case _:
await sock.send_all(b"Unknown command\n")
else:
await q.put(("msg", (name, data)))
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if not data.endswith(b"\n"):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
await q.put(("leave", name))
logging.info(f"Connection from {address} closed")
logging.info(f"{name} has left the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
clients.pop(sock)
names.discard(name)
logging.info("Handler shutting down")
if __name__ == "__main__":
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
logging.basicConfig(
level=20,
format="[%(levelname)s] %(asctime)s %(message)s",
datefmt="%d/%m/%Y %p",
)
try:
structio.run(serve, ("0.0.0.0", port))
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")
else:
logging.error(f"Exiting due to a {type(error).__name__}: {error}")

74
tests/echo_server.py Normal file
View File

@ -0,0 +1,74 @@
import sys
import logging
import structio
# A test to check for asynchronous I/O
async def serve(bind_address: tuple):
"""
Serves asynchronously forever
:param bind_address: The address to bind the server to represented as a tuple
(address, port) where address is a string and port is an integer
"""
sock = structio.socket.socket()
await sock.bind(bind_address)
await sock.listen(5)
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with structio.create_pool() as ctx:
async with sock:
while True:
try:
conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await ctx.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: structio.socket.AsyncSocket, client_address: tuple):
"""
Handles a single client connection
:param sock: The AsyncSocket object connected to the client
:param client_address: The client's address represented as a tuple
(address, port) where address is a string and port is an integer
:type client_address: tuple
"""
address = f"{client_address[0]}:{client_address[1]}"
async with sock: # Closes the socket automatically
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
while True:
await sock.send_all(b"-> ")
data = await sock.receive(1024)
if not data:
break
elif data == b"exit\n":
await sock.send_all(b"I'm dead dude\n")
raise TypeError("Oh, no, I'm gonna die!")
elif data == b"fatal\n":
await sock.send_all(b"What a dick\n")
raise KeyboardInterrupt("He told me to do it!")
logging.info(f"Got: {data!r} from {address}")
await sock.send_all(b"Got: " + data)
logging.info(f"Echoed back {data!r} to {address}")
logging.info(f"Connection from {address} closed")
if __name__ == "__main__":
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
logging.basicConfig(
level=20,
format="[%(levelname)s] %(asctime)s %(message)s",
datefmt="%d/%m/%Y %H:%M:%S %p",
)
try:
structio.run(serve, ("localhost", port))
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")
else:
logging.error(f"Exiting due to a {type(error).__name__}: {error}")