Ported old aiosched sockets and streams and added related tests
This commit is contained in:
parent
3ea159c858
commit
26a43d5f84
|
@ -11,6 +11,8 @@ from structio.core import task
|
|||
from structio.core.task import Task, TaskState
|
||||
from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event
|
||||
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
|
||||
from structio.io import socket
|
||||
from structio.io.socket import AsyncSocket
|
||||
from structio.io.files import (
|
||||
open_file,
|
||||
wrap_file,
|
||||
|
|
|
@ -50,11 +50,15 @@ class AsyncResource(ABC):
|
|||
async def close(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def fileno(self):
|
||||
return NotImplemented
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
|
||||
class StreamWriter(AsyncResource):
|
||||
class StreamWriter(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for writing binary data to
|
||||
a byte stream
|
||||
|
@ -84,7 +88,7 @@ class StreamWriter(AsyncResource):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
class StreamReader(AsyncResource):
|
||||
class StreamReader(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for reading binary data from
|
||||
a byte stream. The stream implements the
|
||||
|
@ -124,13 +128,21 @@ class StreamReader(AsyncResource):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
class Stream(StreamReader, StreamWriter):
|
||||
class Stream(StreamReader, StreamWriter, ABC):
|
||||
"""
|
||||
A generic, asynchronous, readable/writable binary stream
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self):
|
||||
"""
|
||||
Flushes the underlying resource asynchronously
|
||||
"""
|
||||
|
||||
class WriteCloseableStream(Stream):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class WriteCloseableStream(Stream, ABC):
|
||||
"""
|
||||
Extension to the Stream class that allows
|
||||
shutting down the write end of the stream
|
||||
|
@ -149,7 +161,7 @@ class WriteCloseableStream(Stream):
|
|||
"""
|
||||
|
||||
|
||||
class ChannelReader(AsyncResource):
|
||||
class ChannelReader(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for reading data from a
|
||||
channel
|
||||
|
@ -178,8 +190,11 @@ class ChannelReader(AsyncResource):
|
|||
read from the channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
class ChannelWriter(AsyncResource):
|
||||
|
||||
class ChannelWriter(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for writing data to a
|
||||
channel
|
||||
|
@ -201,12 +216,18 @@ class ChannelWriter(AsyncResource):
|
|||
to write to the channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
class Channel(ChannelWriter, ChannelReader):
|
||||
|
||||
class Channel(ChannelWriter, ChannelReader, ABC):
|
||||
"""
|
||||
A generic, two-way channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
|
||||
class BaseDebugger(ABC):
|
||||
"""
|
||||
|
@ -407,19 +428,21 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_read(self, rsc: AsyncResource):
|
||||
def request_read(self, rsc: AsyncResource, task: Task):
|
||||
"""
|
||||
"Requests" a read operation on the given
|
||||
resource to the I/O manager from the current task
|
||||
resource to the I/O manager from the given
|
||||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_write(self, rsc: AsyncResource):
|
||||
def request_write(self, rsc: AsyncResource, task: Task):
|
||||
"""
|
||||
"Requests" a write operation on the given
|
||||
resource to the I/O manager from the current task
|
||||
resource to the I/O manager from the given
|
||||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
@ -427,8 +450,9 @@ class BaseIOManager(ABC):
|
|||
@abstractmethod
|
||||
def pending(self):
|
||||
"""
|
||||
Returns a boolean value that indicates whether
|
||||
there's any I/O registered in the manager
|
||||
Returns whether there's any tasks waiting
|
||||
to read from/write to a resource registered
|
||||
in the manager
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
@ -455,6 +479,19 @@ class BaseIOManager(ABC):
|
|||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_reader(self, rsc: AsyncResource):
|
||||
"""
|
||||
Returns the task reading from the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_writer(self, rsc: AsyncResource):
|
||||
"""
|
||||
Returns the task writing to the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
class SignalManager(ABC):
|
||||
"""
|
||||
|
@ -505,13 +542,47 @@ class BaseKernel(ABC):
|
|||
self.current_scope: "TaskScope" = None
|
||||
self.tools: list[BaseDebugger] = tools or []
|
||||
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
||||
self.running: bool = False
|
||||
self.io_manager = io_manager
|
||||
self.signal_managers = signal_managers
|
||||
self.entry_point: Task | None = None
|
||||
# Pool for system tasks
|
||||
self.pool: "TaskPool" = None
|
||||
|
||||
@abstractmethod
|
||||
def wait_readable(self, resource: AsyncResource):
|
||||
"""
|
||||
Schedule the given resource for reading from
|
||||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def wait_writable(self, resource: AsyncResource):
|
||||
"""
|
||||
Schedule the given resource for reading from
|
||||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def release_resource(self, resource: AsyncResource):
|
||||
"""
|
||||
Releases the given resource from the scheduler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
|
||||
"""
|
||||
Notifies the event loop that a given resource
|
||||
is about to be closed and can be unscheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def cancel_task(self, task: Task):
|
||||
"""
|
||||
|
|
|
@ -7,12 +7,13 @@ from structio.abc import (
|
|||
BaseDebugger,
|
||||
BaseIOManager,
|
||||
SignalManager,
|
||||
AsyncResource
|
||||
)
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||
from structio.core.time.queue import TimeQueue
|
||||
from structio.exceptions import StructIOException, Cancelled, TimedOut
|
||||
from structio.exceptions import StructIOException, Cancelled, TimedOut, ResourceClosed, ResourceBroken
|
||||
from collections import deque
|
||||
from typing import Callable, Coroutine, Any
|
||||
from functools import partial
|
||||
|
@ -62,6 +63,26 @@ class FIFOKernel(BaseKernel):
|
|||
]
|
||||
)
|
||||
|
||||
def wait_readable(self, resource: AsyncResource):
|
||||
self.io_manager.request_read(resource, self.current_task)
|
||||
|
||||
def wait_writable(self, resource: AsyncResource):
|
||||
self.io_manager.request_write(resource, self.current_task)
|
||||
|
||||
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
|
||||
if not broken:
|
||||
exc = ResourceClosed("stream has been closed")
|
||||
else:
|
||||
exc = ResourceBroken("stream might be corrupted")
|
||||
owner = owner or self.current_task
|
||||
reader = self.io_manager.get_reader(resource)
|
||||
writer = self.io_manager.get_writer(resource)
|
||||
if reader is not owner:
|
||||
self.throw(reader, exc)
|
||||
if writer is not owner:
|
||||
self.throw(writer, exc)
|
||||
self.reschedule_running()
|
||||
|
||||
def get_closest_deadline_owner(self):
|
||||
return self.paused.peek()
|
||||
|
||||
|
@ -77,18 +98,21 @@ class FIFOKernel(BaseKernel):
|
|||
# We really can't afford to have our internals explode,
|
||||
# sorry!
|
||||
warnings.warn(
|
||||
f"Exception during debugging event delivery ({evt_name!r}): {type(e).__name__} -> {e}",
|
||||
f"Exception during debugging event delivery in {f!r} ({evt_name!r}): {type(e).__name__} -> {e}",
|
||||
)
|
||||
traceback.print_tb(e.__traceback__)
|
||||
# We disable the tool, so it can't raise at the next debugging
|
||||
# event
|
||||
self.tools.remove(tool)
|
||||
|
||||
def done(self):
|
||||
if self.entry_point.done():
|
||||
return True
|
||||
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
||||
return False
|
||||
for scope in self.scopes:
|
||||
if not scope.done():
|
||||
return False
|
||||
if self.entry_point.done():
|
||||
return True
|
||||
return True
|
||||
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
|
@ -168,6 +192,8 @@ class FIFOKernel(BaseKernel):
|
|||
return
|
||||
if task.state == TaskState.PAUSED:
|
||||
self.paused.discard(task)
|
||||
elif task.state == TaskState.IO:
|
||||
self.io_manager.release_task(task)
|
||||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||
|
||||
def reschedule(self, task: Task):
|
||||
|
@ -185,6 +211,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.run_queue.appendleft(self.current_task)
|
||||
|
||||
def schedule_point(self):
|
||||
self.skip = True
|
||||
self.reschedule_running()
|
||||
|
||||
def sleep(self, amount):
|
||||
|
@ -233,7 +260,6 @@ class FIFOKernel(BaseKernel):
|
|||
while not self.done():
|
||||
if self.run_queue and not self.skip:
|
||||
self.handle_errors(self.step)
|
||||
self.running = False
|
||||
self.skip = False
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
|
@ -248,7 +274,7 @@ class FIFOKernel(BaseKernel):
|
|||
Reschedules the currently running task
|
||||
"""
|
||||
|
||||
self.run_queue.append(self.current_task)
|
||||
self.reschedule(self.current_task)
|
||||
|
||||
def handle_errors(self, func: Callable, task: Task | None = None):
|
||||
"""
|
||||
|
@ -296,6 +322,10 @@ class FIFOKernel(BaseKernel):
|
|||
self.event("on_exception_raised", task)
|
||||
self.on_error(task)
|
||||
|
||||
def release_resource(self, resource: AsyncResource):
|
||||
self.io_manager.release(resource)
|
||||
self.reschedule_running()
|
||||
|
||||
def release(self, task: Task):
|
||||
"""
|
||||
Releases the timeouts and associated
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
from collections import defaultdict
|
||||
from structio.abc import BaseIOManager, BaseKernel
|
||||
from structio.abc import BaseIOManager, BaseKernel, AsyncResource
|
||||
from structio.core.context import Task
|
||||
from structio.core.run import current_loop, current_task
|
||||
from structio.core.run import current_loop
|
||||
import select
|
||||
|
||||
|
||||
class SimpleIOManager(BaseIOManager):
|
||||
"""
|
||||
A simple, cross-platform, select()-based
|
||||
I/O manager
|
||||
I/O manager. This class is only meant to
|
||||
be used as a default fallback and is quite
|
||||
inefficient and slower compared to more ad-hoc
|
||||
alternatives such as epoll or kqueue (it should
|
||||
work on most platforms though)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -17,21 +20,19 @@ class SimpleIOManager(BaseIOManager):
|
|||
"""
|
||||
|
||||
# Maps resources to tasks
|
||||
self.readers = {}
|
||||
self.writers = {}
|
||||
# This allows us to have a bidirectional mapping:
|
||||
# we know both which tasks are using which resources
|
||||
# and which resources are used by which tasks,
|
||||
# without having to go through too many hoops and jumps.
|
||||
self.tasks: dict[Task, list] = defaultdict(list)
|
||||
self.readers: dict[AsyncResource, Task] = {}
|
||||
self.writers: dict[AsyncResource, Task] = {}
|
||||
|
||||
def pending(self):
|
||||
# We don't return bool(self.resources) because there is
|
||||
# no pending I/O to do if no tasks are waiting to read or
|
||||
# write, even if there's dangling resources around!
|
||||
return bool(self.readers or self.writers)
|
||||
|
||||
def _collect_readers(self) -> list:
|
||||
def get_reader(self, rsc: AsyncResource):
|
||||
return self.readers.get(rsc)
|
||||
|
||||
def get_writer(self, rsc: AsyncResource):
|
||||
return self.writers.get(rsc)
|
||||
|
||||
def _collect_readers(self) -> list[int]:
|
||||
"""
|
||||
Collects all resources that need to be read from,
|
||||
so we can select() on them later
|
||||
|
@ -39,10 +40,10 @@ class SimpleIOManager(BaseIOManager):
|
|||
|
||||
result = []
|
||||
for resource in self.readers:
|
||||
result.append(resource)
|
||||
result.append(resource.fileno())
|
||||
return result
|
||||
|
||||
def _collect_writers(self) -> list:
|
||||
def _collect_writers(self) -> list[int]:
|
||||
"""
|
||||
Collects all resources that need to be written to,
|
||||
so we can select() on them later
|
||||
|
@ -50,34 +51,43 @@ class SimpleIOManager(BaseIOManager):
|
|||
|
||||
result = []
|
||||
for resource in self.writers:
|
||||
result.append(resource)
|
||||
result.append(resource.fileno())
|
||||
return result
|
||||
|
||||
def wait_io(self):
|
||||
kernel: BaseKernel = current_loop()
|
||||
deadline = kernel.get_closest_deadline()
|
||||
if deadline == float("inf"):
|
||||
deadline = 0
|
||||
readable, writable, _ = select.select(
|
||||
self._collect_readers(),
|
||||
self._collect_writers(),
|
||||
[],
|
||||
kernel.get_closest_deadline(),
|
||||
deadline,
|
||||
)
|
||||
for read_ready in readable:
|
||||
kernel.reschedule(self.readers[read_ready])
|
||||
for resource, task in self.readers.items():
|
||||
if resource.fileno() == read_ready:
|
||||
kernel.reschedule(task)
|
||||
for write_ready in writable:
|
||||
kernel.reschedule(self.writers[write_ready])
|
||||
for resource, task in self.writers.items():
|
||||
if resource.fileno() == write_ready:
|
||||
kernel.reschedule(task)
|
||||
|
||||
def request_read(self, rsc):
|
||||
task = current_task()
|
||||
def request_read(self, rsc: AsyncResource, task: Task):
|
||||
self.readers[rsc] = task
|
||||
|
||||
def request_write(self, rsc):
|
||||
task = current_task()
|
||||
def request_write(self, rsc: AsyncResource, task: Task):
|
||||
self.writers[rsc] = task
|
||||
|
||||
def release(self, resource):
|
||||
def release(self, resource: AsyncResource):
|
||||
self.readers.pop(resource, None)
|
||||
self.writers.pop(resource, None)
|
||||
|
||||
def release_task(self, task: Task):
|
||||
for resource in self.tasks[task]:
|
||||
self.release(resource)
|
||||
for resource, owner in self.readers.copy().items():
|
||||
if owner == task:
|
||||
self.readers.pop(resource)
|
||||
for resource, owner in self.writers.copy().items():
|
||||
if owner == task:
|
||||
self.writers.pop(resource)
|
||||
|
|
|
@ -68,3 +68,19 @@ async def checkpoint():
|
|||
|
||||
await check_cancelled()
|
||||
await schedule_point()
|
||||
|
||||
|
||||
async def wait_readable(rsc):
|
||||
return await syscall("wait_readable", rsc)
|
||||
|
||||
|
||||
async def wait_writable(rsc):
|
||||
return await syscall("wait_writable", rsc)
|
||||
|
||||
|
||||
async def closing(rsc):
|
||||
return await syscall("notify_closing", rsc)
|
||||
|
||||
|
||||
async def release(rsc):
|
||||
return await syscall("release_resource", rsc)
|
||||
|
|
|
@ -84,11 +84,10 @@ class TimeQueue:
|
|||
def get_closest_deadline(self) -> float:
|
||||
"""
|
||||
Returns the closest deadline that is meant to expire
|
||||
or raises IndexError if the queue is empty
|
||||
"""
|
||||
|
||||
if not self:
|
||||
raise IndexError("TimeQueue is empty")
|
||||
return float("inf")
|
||||
return self.container[0][0]
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -38,3 +38,11 @@ class ResourceBusy(StructIOException):
|
|||
Raised when an attempt is made to use an
|
||||
asynchronous resource that is currently busy
|
||||
"""
|
||||
|
||||
|
||||
class ResourceBroken(StructIOException):
|
||||
"""
|
||||
Raised when an asynchronous resource gets
|
||||
corrupted and is no longer usable
|
||||
"""
|
||||
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
|
||||
import os
|
||||
import warnings
|
||||
from structio.core.syscalls import checkpoint, wait_readable, wait_writable, closing, release
|
||||
from structio.exceptions import ResourceClosed
|
||||
from structio.abc import AsyncResource
|
||||
try:
|
||||
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
|
||||
WantRead = (BlockingIOError, SSLWantReadError, InterruptedError)
|
||||
WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError)
|
||||
except ImportError:
|
||||
WantWrite = (BlockingIOError, InterruptedError)
|
||||
WantRead = (BlockingIOError, InterruptedError)
|
||||
SSLSocket = None
|
||||
|
||||
|
||||
class FdWrapper:
|
||||
"""
|
||||
A simple wrapper around a file descriptor that
|
||||
allows the event loop to perform an optimization
|
||||
regarding I/O event registration safely. This is
|
||||
because while integer file descriptors can be reused
|
||||
by the operating system, instances of this class will
|
||||
not (hence if the event loop keeps around a dead instance
|
||||
of an FdWrapper, it at least won't accidentally register
|
||||
a new file with that same file descriptor). A bonus is
|
||||
that this also allows us to always assume that we can call
|
||||
fileno() on all objects registered in our selector, regardless
|
||||
of whether the wrapped fd is an int or something else entirely
|
||||
"""
|
||||
|
||||
__slots__ = ("fd", )
|
||||
|
||||
def __init__(self, fd):
|
||||
self.fd = fd
|
||||
|
||||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
# Can be converted to an int
|
||||
def __int__(self):
|
||||
return self.fd
|
||||
|
||||
def __repr__(self):
|
||||
return f"<fd={self.fd!r}>"
|
||||
|
||||
|
||||
class AsyncStream(AsyncResource):
|
||||
"""
|
||||
A generic asynchronous stream over
|
||||
a file descriptor. Functionality
|
||||
is OS-dependent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd: int,
|
||||
open_fd: bool = True,
|
||||
close_on_context_exit: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self._fd = FdWrapper(fd)
|
||||
self.fileobj = None
|
||||
if open_fd:
|
||||
self.fileobj = os.fdopen(int(self._fd), **kwargs)
|
||||
os.set_blocking(int(self._fd), False)
|
||||
# Do we close ourselves upon the end of a context manager?
|
||||
self.close_on_context_exit = close_on_context_exit
|
||||
|
||||
async def read(self, size: int = -1):
|
||||
"""
|
||||
Reads up to size bytes from the
|
||||
given stream. If size == -1, read
|
||||
until EOF is reached
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.fileobj.read(size)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def write(self, data):
|
||||
"""
|
||||
Writes data to the stream.
|
||||
Returns the number of bytes
|
||||
written
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.fileobj.write(data)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the stream asynchronously
|
||||
"""
|
||||
|
||||
if self.fileno() == -1:
|
||||
raise ResourceClosed("I/O operation on closed stream")
|
||||
self._fd = -1
|
||||
await closing(self)
|
||||
await release(self)
|
||||
self.fileobj.close()
|
||||
self.fileobj = None
|
||||
await checkpoint()
|
||||
|
||||
def fileno(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return int(self._fd)
|
||||
|
||||
async def __aenter__(self):
|
||||
self.fileobj.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self._fd != -1 and self.close_on_context_exit:
|
||||
await self.close()
|
||||
|
||||
async def dup(self):
|
||||
"""
|
||||
Wrapper stream method
|
||||
"""
|
||||
|
||||
return type(self)(os.dup(self.fileno()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncStream({self.fileobj})"
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Stream destructor. Do *not* call
|
||||
this directly: stuff will break
|
||||
"""
|
||||
|
||||
if self._fd != -1 and self.fileobj.fileno() != -1:
|
||||
try:
|
||||
os.set_blocking(self.fileno(), False)
|
||||
os.close(self.fileno())
|
||||
except OSError as e:
|
||||
warnings.warn(
|
||||
f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}"
|
||||
)
|
|
@ -44,7 +44,7 @@ _FILE_ASYNC_METHODS = {
|
|||
}
|
||||
|
||||
|
||||
class AsyncResourceWrapper(AsyncResource):
|
||||
class AsyncFile(AsyncResource):
|
||||
"""
|
||||
Asynchronous wrapper around regular file-like objects.
|
||||
Blocking operations are turned into async ones using threads.
|
||||
|
@ -52,6 +52,9 @@ class AsyncResourceWrapper(AsyncResource):
|
|||
and read/write methods
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return self.handle.fileno()
|
||||
|
||||
def __init__(self, f):
|
||||
self._file = f
|
||||
|
||||
|
@ -108,7 +111,8 @@ class AsyncResourceWrapper(AsyncResource):
|
|||
# This operation is non-cancellable, meaning it'll run
|
||||
# no matter what our event loop has to say about it.
|
||||
# After we're done, we'll obviously re-raise the cancellation
|
||||
# if necessary
|
||||
# if necessary. This ensures files are always closed even when
|
||||
# the operation gets cancelled
|
||||
await structio.thread.run_in_worker(self.handle.close)
|
||||
# If we were cancelled, here is where we raise
|
||||
await check_cancelled()
|
||||
|
@ -123,7 +127,7 @@ async def open_file(
|
|||
newline=None,
|
||||
closefd=True,
|
||||
opener=None,
|
||||
) -> AsyncResourceWrapper:
|
||||
) -> AsyncFile:
|
||||
"""
|
||||
Like io.open(), but async. Magic
|
||||
"""
|
||||
|
@ -135,13 +139,13 @@ async def open_file(
|
|||
)
|
||||
|
||||
|
||||
def wrap_file(file) -> AsyncResourceWrapper:
|
||||
def wrap_file(file) -> AsyncFile:
|
||||
"""
|
||||
Wraps a file-like object into an async
|
||||
wrapper
|
||||
"""
|
||||
|
||||
return AsyncResourceWrapper(file)
|
||||
return AsyncFile(file)
|
||||
|
||||
|
||||
stdin = wrap_file(sys.stdin)
|
||||
|
|
|
@ -1,13 +1,344 @@
|
|||
import structio
|
||||
from structio.abc import AsyncResource
|
||||
from structio.core.syscalls import check_cancelled
|
||||
from structio.io import FdWrapper, WantRead, WantWrite, SSLSocket
|
||||
from structio.exceptions import ResourceClosed, ResourceBroken
|
||||
from structio.core.syscalls import wait_readable, wait_writable, checkpoint, closing, release
|
||||
from functools import wraps
|
||||
import socket as _socket
|
||||
|
||||
|
||||
@wraps(_socket.socket)
|
||||
def socket(*args, **kwargs):
|
||||
return None # TODO
|
||||
return AsyncSocket(_socket.socket(*args, **kwargs))
|
||||
|
||||
# TODO
|
||||
|
||||
class AsyncSocket(AsyncResource):
|
||||
"""
|
||||
Abstraction layer for asynchronous sockets
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return int(self._fd)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sock: _socket.socket,
|
||||
close_on_context_exit: bool = True,
|
||||
do_handshake_on_connect: bool = True,
|
||||
):
|
||||
self._fd = FdWrapper(sock.fileno())
|
||||
self.close_on_context_exit = close_on_context_exit
|
||||
# Do we perform the TCP handshake automatically
|
||||
# upon connection? This is mostly needed for SSL
|
||||
# sockets
|
||||
self.do_handshake_on_connect = do_handshake_on_connect
|
||||
self.socket = sock
|
||||
self.socket.setblocking(False)
|
||||
# A socket that isn't connected doesn't
|
||||
# need to be closed
|
||||
self.needs_closing: bool = False
|
||||
|
||||
async def receive(self, max_size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
Receives up to max_size bytes from a socket asynchronously
|
||||
"""
|
||||
|
||||
assert max_size >= 1, "max_size must be >= 1"
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv(max_size, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
Receives exactly size bytes from a socket asynchronously.
|
||||
"""
|
||||
|
||||
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
|
||||
buf = bytearray(size)
|
||||
pos = 0
|
||||
while pos < size:
|
||||
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
|
||||
if n == 0:
|
||||
raise ResourceBroken("incomplete read detected")
|
||||
pos += n
|
||||
return bytes(buf)
|
||||
|
||||
async def connect(self, address):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
await checkpoint()
|
||||
break
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
self.needs_closing = True
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.needs_closing:
|
||||
self.socket.close()
|
||||
await checkpoint()
|
||||
|
||||
async def accept(self):
|
||||
"""
|
||||
Accepts the socket, completing the 3-step TCP handshake asynchronously
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
remote, addr = self.socket.accept()
|
||||
await checkpoint()
|
||||
return type(self)(remote), addr
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def send_all(self, data: bytes, flags: int = 0):
|
||||
"""
|
||||
Sends all data inside the buffer asynchronously until it is empty
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
sent_no = 0
|
||||
while data:
|
||||
try:
|
||||
sent_no = self.socket.send(data, flags)
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
data = data[sent_no:]
|
||||
|
||||
async def shutdown(self, how):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.fileno() == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
if self.socket:
|
||||
self.socket.shutdown(how)
|
||||
await checkpoint()
|
||||
|
||||
async def bind(self, addr: tuple):
|
||||
"""
|
||||
Binds the socket to an address
|
||||
:param addr: The address, port tuple to bind to
|
||||
:type addr: tuple
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.socket.bind(addr)
|
||||
await checkpoint()
|
||||
|
||||
async def listen(self, backlog: int):
|
||||
"""
|
||||
Starts listening with the given backlog
|
||||
:param backlog: The socket's backlog
|
||||
:type backlog: int
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.socket.listen(backlog)
|
||||
await checkpoint()
|
||||
|
||||
# Yes, I stole these from Curio because I could not be
|
||||
# arsed to write a bunch of uninteresting simple socket
|
||||
# methods from scratch, deal with it.
|
||||
|
||||
def settimeout(self, seconds):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
raise RuntimeError("Use with_timeout() to set a timeout")
|
||||
|
||||
def gettimeout(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
def dup(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return type(self)(self.socket.dup(), self.do_handshake_on_connect)
|
||||
|
||||
async def do_handshake(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if not hasattr(self.socket, "do_handshake"):
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
self.socket: SSLSocket # Silences pycharm warnings
|
||||
self.socket.do_handshake()
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def recvfrom(self, buffersize, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self.socket.recvfrom(buffersize, flags)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def recv_into(self, buffer, nbytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv_into(buffer, nbytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom_into(buffer, bytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
|
||||
async def sendto(self, bytes, flags_or_address, address=None):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if address:
|
||||
flags = flags_or_address
|
||||
else:
|
||||
address = flags_or_address
|
||||
flags = 0
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendto(bytes, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def getpeername(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.getpeername()
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def getsockname(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.getpeername()
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
|
||||
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendmsg(buffers, ancdata, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_writable(self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncSocket({self.socket})"
|
||||
|
||||
|
|
|
@ -1,10 +1,39 @@
|
|||
# Module inspired by subprocess which allows for asynchronous
|
||||
# multiprocessing
|
||||
import os
|
||||
import subprocess
|
||||
from subprocess import (
|
||||
CalledProcessError,
|
||||
CompletedProcess,
|
||||
SubprocessError,
|
||||
STDOUT,
|
||||
DEVNULL,
|
||||
PIPE
|
||||
)
|
||||
|
||||
|
||||
class Process:
|
||||
class Popen:
|
||||
"""
|
||||
An asynchronous process
|
||||
Wrapper around subprocess.Popen, but async
|
||||
"""
|
||||
|
||||
# TODO
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
if "universal_newlines" in kwargs:
|
||||
# Not sure why? But everyone else is doing it so :shrug:
|
||||
raise RuntimeError("universal_newlines is not supported")
|
||||
if stdin := kwargs.get("stdin"):
|
||||
# Curio mentions stuff breaking if the child process
|
||||
# is passed a stdin fd that is set to non-blocking mode
|
||||
if hasattr(os, "set_blocking"):
|
||||
os.set_blocking(stdin.fileno(), True)
|
||||
# Delegate to Popen's constructor
|
||||
self._process = subprocess.Popen(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Delegate to internal process object
|
||||
return getattr(self._process, item)
|
||||
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
import structio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# An asynchronous chatroom
|
||||
|
||||
clients: dict[structio.socket.AsyncSocket, list[str, str]] = {}
|
||||
names: set[str] = set()
|
||||
|
||||
|
||||
async def event_handler(q: structio.Queue):
|
||||
"""
|
||||
Reads data submitted onto the queue
|
||||
"""
|
||||
|
||||
try:
|
||||
logging.info("Event handler spawned")
|
||||
while True:
|
||||
msg, payload = await q.get()
|
||||
logging.info(f"Caught event {msg!r} with the following payload: {payload}")
|
||||
except Exception as e:
|
||||
logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
|
||||
except structio.exceptions.Cancelled:
|
||||
logging.warning(f"Cancellation detected, message handler shutting down")
|
||||
# Propagate the cancellation
|
||||
raise
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
"""
|
||||
Serves asynchronously forever (or until Ctrl+C ;))
|
||||
:param bind_address: The address to bind the server to, represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
"""
|
||||
|
||||
sock = structio.socket.socket()
|
||||
queue = structio.Queue()
|
||||
await sock.bind(bind_address)
|
||||
await sock.listen(5)
|
||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(event_handler, queue)
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||
await queue.put(("connect", clients[conn]))
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, queue)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue):
|
||||
"""
|
||||
Handles a single client connection
|
||||
:param sock: The AsyncSocket object connected to the client
|
||||
"""
|
||||
|
||||
address = clients[sock][1]
|
||||
name = ""
|
||||
async with sock: # Closes the socket automatically
|
||||
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
|
||||
cond = True
|
||||
while cond:
|
||||
while not name.endswith("\n"):
|
||||
name = (await sock.receive(64)).decode()
|
||||
if name == "":
|
||||
cond = False
|
||||
break
|
||||
name = name.rstrip("\n")
|
||||
if name not in names:
|
||||
names.add(name)
|
||||
clients[sock][0] = name
|
||||
break
|
||||
else:
|
||||
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
||||
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
||||
await q.put(("join", (address, name)))
|
||||
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
|
||||
while True:
|
||||
await sock.send_all(b"> ")
|
||||
data = await sock.receive(1024)
|
||||
if not data:
|
||||
break
|
||||
decoded = data.decode().rstrip("\n")
|
||||
if decoded.startswith("/"):
|
||||
logging.info(f"{name} issued server command {decoded}")
|
||||
await q.put(("cmd", (name, decoded[1:])))
|
||||
match decoded[1:]:
|
||||
case "bye":
|
||||
await sock.send_all(b"Bye!\n")
|
||||
break
|
||||
case _:
|
||||
await sock.send_all(b"Unknown command\n")
|
||||
else:
|
||||
await q.put(("msg", (name, data)))
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
|
||||
if not data.endswith(b"\n"):
|
||||
data += b"\n"
|
||||
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
||||
logging.info(f"Sent {data!r} to {i} clients")
|
||||
await q.put(("leave", name))
|
||||
logging.info(f"Connection from {address} closed")
|
||||
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
|
||||
clients.pop(sock)
|
||||
names.discard(name)
|
||||
logging.info("Handler shutting down")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
|
||||
logging.basicConfig(
|
||||
level=20,
|
||||
format="[%(levelname)s] %(asctime)s %(message)s",
|
||||
datefmt="%d/%m/%Y %p",
|
||||
)
|
||||
try:
|
||||
structio.run(serve, ("0.0.0.0", port))
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
import sys
|
||||
import logging
|
||||
import structio
|
||||
|
||||
|
||||
# A test to check for asynchronous I/O
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
"""
|
||||
Serves asynchronously forever
|
||||
:param bind_address: The address to bind the server to represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
"""
|
||||
|
||||
sock = structio.socket.socket()
|
||||
await sock.bind(bind_address)
|
||||
await sock.listen(5)
|
||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with structio.create_pool() as ctx:
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await ctx.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: structio.socket.AsyncSocket, client_address: tuple):
|
||||
"""
|
||||
Handles a single client connection
|
||||
:param sock: The AsyncSocket object connected to the client
|
||||
:param client_address: The client's address represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
:type client_address: tuple
|
||||
"""
|
||||
|
||||
address = f"{client_address[0]}:{client_address[1]}"
|
||||
async with sock: # Closes the socket automatically
|
||||
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
|
||||
while True:
|
||||
await sock.send_all(b"-> ")
|
||||
data = await sock.receive(1024)
|
||||
if not data:
|
||||
break
|
||||
elif data == b"exit\n":
|
||||
await sock.send_all(b"I'm dead dude\n")
|
||||
raise TypeError("Oh, no, I'm gonna die!")
|
||||
elif data == b"fatal\n":
|
||||
await sock.send_all(b"What a dick\n")
|
||||
raise KeyboardInterrupt("He told me to do it!")
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
await sock.send_all(b"Got: " + data)
|
||||
logging.info(f"Echoed back {data!r} to {address}")
|
||||
logging.info(f"Connection from {address} closed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
|
||||
logging.basicConfig(
|
||||
level=20,
|
||||
format="[%(levelname)s] %(asctime)s %(message)s",
|
||||
datefmt="%d/%m/%Y %H:%M:%S %p",
|
||||
)
|
||||
try:
|
||||
structio.run(serve, ("localhost", port))
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
Loading…
Reference in New Issue