Ported old aiosched sockets and streams and added related tests
This commit is contained in:
parent
3ea159c858
commit
26a43d5f84
|
@ -11,6 +11,8 @@ from structio.core import task
|
||||||
from structio.core.task import Task, TaskState
|
from structio.core.task import Task, TaskState
|
||||||
from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event
|
from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event
|
||||||
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
|
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
|
||||||
|
from structio.io import socket
|
||||||
|
from structio.io.socket import AsyncSocket
|
||||||
from structio.io.files import (
|
from structio.io.files import (
|
||||||
open_file,
|
open_file,
|
||||||
wrap_file,
|
wrap_file,
|
||||||
|
|
|
@ -50,11 +50,15 @@ class AsyncResource(ABC):
|
||||||
async def close(self):
|
async def close(self):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fileno(self):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
class StreamWriter(AsyncResource):
|
class StreamWriter(AsyncResource, ABC):
|
||||||
"""
|
"""
|
||||||
Interface for writing binary data to
|
Interface for writing binary data to
|
||||||
a byte stream
|
a byte stream
|
||||||
|
@ -84,7 +88,7 @@ class StreamWriter(AsyncResource):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
class StreamReader(AsyncResource):
|
class StreamReader(AsyncResource, ABC):
|
||||||
"""
|
"""
|
||||||
Interface for reading binary data from
|
Interface for reading binary data from
|
||||||
a byte stream. The stream implements the
|
a byte stream. The stream implements the
|
||||||
|
@ -124,13 +128,21 @@ class StreamReader(AsyncResource):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
class Stream(StreamReader, StreamWriter):
|
class Stream(StreamReader, StreamWriter, ABC):
|
||||||
"""
|
"""
|
||||||
A generic, asynchronous, readable/writable binary stream
|
A generic, asynchronous, readable/writable binary stream
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def flush(self):
|
||||||
|
"""
|
||||||
|
Flushes the underlying resource asynchronously
|
||||||
|
"""
|
||||||
|
|
||||||
class WriteCloseableStream(Stream):
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class WriteCloseableStream(Stream, ABC):
|
||||||
"""
|
"""
|
||||||
Extension to the Stream class that allows
|
Extension to the Stream class that allows
|
||||||
shutting down the write end of the stream
|
shutting down the write end of the stream
|
||||||
|
@ -149,7 +161,7 @@ class WriteCloseableStream(Stream):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ChannelReader(AsyncResource):
|
class ChannelReader(AsyncResource, ABC):
|
||||||
"""
|
"""
|
||||||
Interface for reading data from a
|
Interface for reading data from a
|
||||||
channel
|
channel
|
||||||
|
@ -178,8 +190,11 @@ class ChannelReader(AsyncResource):
|
||||||
read from the channel
|
read from the channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return None
|
||||||
|
|
||||||
class ChannelWriter(AsyncResource):
|
|
||||||
|
class ChannelWriter(AsyncResource, ABC):
|
||||||
"""
|
"""
|
||||||
Interface for writing data to a
|
Interface for writing data to a
|
||||||
channel
|
channel
|
||||||
|
@ -201,12 +216,18 @@ class ChannelWriter(AsyncResource):
|
||||||
to write to the channel
|
to write to the channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return None
|
||||||
|
|
||||||
class Channel(ChannelWriter, ChannelReader):
|
|
||||||
|
class Channel(ChannelWriter, ChannelReader, ABC):
|
||||||
"""
|
"""
|
||||||
A generic, two-way channel
|
A generic, two-way channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class BaseDebugger(ABC):
|
class BaseDebugger(ABC):
|
||||||
"""
|
"""
|
||||||
|
@ -407,19 +428,21 @@ class BaseIOManager(ABC):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def request_read(self, rsc: AsyncResource):
|
def request_read(self, rsc: AsyncResource, task: Task):
|
||||||
"""
|
"""
|
||||||
"Requests" a read operation on the given
|
"Requests" a read operation on the given
|
||||||
resource to the I/O manager from the current task
|
resource to the I/O manager from the given
|
||||||
|
task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def request_write(self, rsc: AsyncResource):
|
def request_write(self, rsc: AsyncResource, task: Task):
|
||||||
"""
|
"""
|
||||||
"Requests" a write operation on the given
|
"Requests" a write operation on the given
|
||||||
resource to the I/O manager from the current task
|
resource to the I/O manager from the given
|
||||||
|
task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -427,8 +450,9 @@ class BaseIOManager(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def pending(self):
|
def pending(self):
|
||||||
"""
|
"""
|
||||||
Returns a boolean value that indicates whether
|
Returns whether there's any tasks waiting
|
||||||
there's any I/O registered in the manager
|
to read from/write to a resource registered
|
||||||
|
in the manager
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -455,6 +479,19 @@ class BaseIOManager(ABC):
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_reader(self, rsc: AsyncResource):
|
||||||
|
"""
|
||||||
|
Returns the task reading from the given
|
||||||
|
resource, if any (None otherwise)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_writer(self, rsc: AsyncResource):
|
||||||
|
"""
|
||||||
|
Returns the task writing to the given
|
||||||
|
resource, if any (None otherwise)
|
||||||
|
"""
|
||||||
|
|
||||||
class SignalManager(ABC):
|
class SignalManager(ABC):
|
||||||
"""
|
"""
|
||||||
|
@ -505,13 +542,47 @@ class BaseKernel(ABC):
|
||||||
self.current_scope: "TaskScope" = None
|
self.current_scope: "TaskScope" = None
|
||||||
self.tools: list[BaseDebugger] = tools or []
|
self.tools: list[BaseDebugger] = tools or []
|
||||||
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
||||||
self.running: bool = False
|
|
||||||
self.io_manager = io_manager
|
self.io_manager = io_manager
|
||||||
self.signal_managers = signal_managers
|
self.signal_managers = signal_managers
|
||||||
self.entry_point: Task | None = None
|
self.entry_point: Task | None = None
|
||||||
# Pool for system tasks
|
# Pool for system tasks
|
||||||
self.pool: "TaskPool" = None
|
self.pool: "TaskPool" = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def wait_readable(self, resource: AsyncResource):
|
||||||
|
"""
|
||||||
|
Schedule the given resource for reading from
|
||||||
|
the current task
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def wait_writable(self, resource: AsyncResource):
|
||||||
|
"""
|
||||||
|
Schedule the given resource for reading from
|
||||||
|
the current task
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def release_resource(self, resource: AsyncResource):
|
||||||
|
"""
|
||||||
|
Releases the given resource from the scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
|
||||||
|
"""
|
||||||
|
Notifies the event loop that a given resource
|
||||||
|
is about to be closed and can be unscheduled
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_task(self, task: Task):
|
def cancel_task(self, task: Task):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,12 +7,13 @@ from structio.abc import (
|
||||||
BaseDebugger,
|
BaseDebugger,
|
||||||
BaseIOManager,
|
BaseIOManager,
|
||||||
SignalManager,
|
SignalManager,
|
||||||
|
AsyncResource
|
||||||
)
|
)
|
||||||
from structio.core.context import TaskPool, TaskScope
|
from structio.core.context import TaskPool, TaskScope
|
||||||
from structio.core.task import Task, TaskState
|
from structio.core.task import Task, TaskState
|
||||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||||
from structio.core.time.queue import TimeQueue
|
from structio.core.time.queue import TimeQueue
|
||||||
from structio.exceptions import StructIOException, Cancelled, TimedOut
|
from structio.exceptions import StructIOException, Cancelled, TimedOut, ResourceClosed, ResourceBroken
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable, Coroutine, Any
|
from typing import Callable, Coroutine, Any
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -62,6 +63,26 @@ class FIFOKernel(BaseKernel):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def wait_readable(self, resource: AsyncResource):
|
||||||
|
self.io_manager.request_read(resource, self.current_task)
|
||||||
|
|
||||||
|
def wait_writable(self, resource: AsyncResource):
|
||||||
|
self.io_manager.request_write(resource, self.current_task)
|
||||||
|
|
||||||
|
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
|
||||||
|
if not broken:
|
||||||
|
exc = ResourceClosed("stream has been closed")
|
||||||
|
else:
|
||||||
|
exc = ResourceBroken("stream might be corrupted")
|
||||||
|
owner = owner or self.current_task
|
||||||
|
reader = self.io_manager.get_reader(resource)
|
||||||
|
writer = self.io_manager.get_writer(resource)
|
||||||
|
if reader is not owner:
|
||||||
|
self.throw(reader, exc)
|
||||||
|
if writer is not owner:
|
||||||
|
self.throw(writer, exc)
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def get_closest_deadline_owner(self):
|
def get_closest_deadline_owner(self):
|
||||||
return self.paused.peek()
|
return self.paused.peek()
|
||||||
|
|
||||||
|
@ -77,18 +98,21 @@ class FIFOKernel(BaseKernel):
|
||||||
# We really can't afford to have our internals explode,
|
# We really can't afford to have our internals explode,
|
||||||
# sorry!
|
# sorry!
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Exception during debugging event delivery ({evt_name!r}): {type(e).__name__} -> {e}",
|
f"Exception during debugging event delivery in {f!r} ({evt_name!r}): {type(e).__name__} -> {e}",
|
||||||
)
|
)
|
||||||
traceback.print_tb(e.__traceback__)
|
traceback.print_tb(e.__traceback__)
|
||||||
|
# We disable the tool, so it can't raise at the next debugging
|
||||||
|
# event
|
||||||
|
self.tools.remove(tool)
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
|
if self.entry_point.done():
|
||||||
|
return True
|
||||||
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
||||||
return False
|
return False
|
||||||
for scope in self.scopes:
|
for scope in self.scopes:
|
||||||
if not scope.done():
|
if not scope.done():
|
||||||
return False
|
return False
|
||||||
if self.entry_point.done():
|
|
||||||
return True
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||||
|
@ -168,6 +192,8 @@ class FIFOKernel(BaseKernel):
|
||||||
return
|
return
|
||||||
if task.state == TaskState.PAUSED:
|
if task.state == TaskState.PAUSED:
|
||||||
self.paused.discard(task)
|
self.paused.discard(task)
|
||||||
|
elif task.state == TaskState.IO:
|
||||||
|
self.io_manager.release_task(task)
|
||||||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||||
|
|
||||||
def reschedule(self, task: Task):
|
def reschedule(self, task: Task):
|
||||||
|
@ -185,6 +211,7 @@ class FIFOKernel(BaseKernel):
|
||||||
self.run_queue.appendleft(self.current_task)
|
self.run_queue.appendleft(self.current_task)
|
||||||
|
|
||||||
def schedule_point(self):
|
def schedule_point(self):
|
||||||
|
self.skip = True
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def sleep(self, amount):
|
def sleep(self, amount):
|
||||||
|
@ -233,7 +260,6 @@ class FIFOKernel(BaseKernel):
|
||||||
while not self.done():
|
while not self.done():
|
||||||
if self.run_queue and not self.skip:
|
if self.run_queue and not self.skip:
|
||||||
self.handle_errors(self.step)
|
self.handle_errors(self.step)
|
||||||
self.running = False
|
|
||||||
self.skip = False
|
self.skip = False
|
||||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||||
self.throw(self.entry_point, KeyboardInterrupt())
|
self.throw(self.entry_point, KeyboardInterrupt())
|
||||||
|
@ -248,7 +274,7 @@ class FIFOKernel(BaseKernel):
|
||||||
Reschedules the currently running task
|
Reschedules the currently running task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.run_queue.append(self.current_task)
|
self.reschedule(self.current_task)
|
||||||
|
|
||||||
def handle_errors(self, func: Callable, task: Task | None = None):
|
def handle_errors(self, func: Callable, task: Task | None = None):
|
||||||
"""
|
"""
|
||||||
|
@ -296,6 +322,10 @@ class FIFOKernel(BaseKernel):
|
||||||
self.event("on_exception_raised", task)
|
self.event("on_exception_raised", task)
|
||||||
self.on_error(task)
|
self.on_error(task)
|
||||||
|
|
||||||
|
def release_resource(self, resource: AsyncResource):
|
||||||
|
self.io_manager.release(resource)
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def release(self, task: Task):
|
def release(self, task: Task):
|
||||||
"""
|
"""
|
||||||
Releases the timeouts and associated
|
Releases the timeouts and associated
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
from collections import defaultdict
|
from structio.abc import BaseIOManager, BaseKernel, AsyncResource
|
||||||
from structio.abc import BaseIOManager, BaseKernel
|
|
||||||
from structio.core.context import Task
|
from structio.core.context import Task
|
||||||
from structio.core.run import current_loop, current_task
|
from structio.core.run import current_loop
|
||||||
import select
|
import select
|
||||||
|
|
||||||
|
|
||||||
class SimpleIOManager(BaseIOManager):
|
class SimpleIOManager(BaseIOManager):
|
||||||
"""
|
"""
|
||||||
A simple, cross-platform, select()-based
|
A simple, cross-platform, select()-based
|
||||||
I/O manager
|
I/O manager. This class is only meant to
|
||||||
|
be used as a default fallback and is quite
|
||||||
|
inefficient and slower compared to more ad-hoc
|
||||||
|
alternatives such as epoll or kqueue (it should
|
||||||
|
work on most platforms though)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -17,21 +20,19 @@ class SimpleIOManager(BaseIOManager):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Maps resources to tasks
|
# Maps resources to tasks
|
||||||
self.readers = {}
|
self.readers: dict[AsyncResource, Task] = {}
|
||||||
self.writers = {}
|
self.writers: dict[AsyncResource, Task] = {}
|
||||||
# This allows us to have a bidirectional mapping:
|
|
||||||
# we know both which tasks are using which resources
|
|
||||||
# and which resources are used by which tasks,
|
|
||||||
# without having to go through too many hoops and jumps.
|
|
||||||
self.tasks: dict[Task, list] = defaultdict(list)
|
|
||||||
|
|
||||||
def pending(self):
|
def pending(self):
|
||||||
# We don't return bool(self.resources) because there is
|
|
||||||
# no pending I/O to do if no tasks are waiting to read or
|
|
||||||
# write, even if there's dangling resources around!
|
|
||||||
return bool(self.readers or self.writers)
|
return bool(self.readers or self.writers)
|
||||||
|
|
||||||
def _collect_readers(self) -> list:
|
def get_reader(self, rsc: AsyncResource):
|
||||||
|
return self.readers.get(rsc)
|
||||||
|
|
||||||
|
def get_writer(self, rsc: AsyncResource):
|
||||||
|
return self.writers.get(rsc)
|
||||||
|
|
||||||
|
def _collect_readers(self) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Collects all resources that need to be read from,
|
Collects all resources that need to be read from,
|
||||||
so we can select() on them later
|
so we can select() on them later
|
||||||
|
@ -39,10 +40,10 @@ class SimpleIOManager(BaseIOManager):
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for resource in self.readers:
|
for resource in self.readers:
|
||||||
result.append(resource)
|
result.append(resource.fileno())
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _collect_writers(self) -> list:
|
def _collect_writers(self) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Collects all resources that need to be written to,
|
Collects all resources that need to be written to,
|
||||||
so we can select() on them later
|
so we can select() on them later
|
||||||
|
@ -50,34 +51,43 @@ class SimpleIOManager(BaseIOManager):
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for resource in self.writers:
|
for resource in self.writers:
|
||||||
result.append(resource)
|
result.append(resource.fileno())
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def wait_io(self):
|
def wait_io(self):
|
||||||
kernel: BaseKernel = current_loop()
|
kernel: BaseKernel = current_loop()
|
||||||
|
deadline = kernel.get_closest_deadline()
|
||||||
|
if deadline == float("inf"):
|
||||||
|
deadline = 0
|
||||||
readable, writable, _ = select.select(
|
readable, writable, _ = select.select(
|
||||||
self._collect_readers(),
|
self._collect_readers(),
|
||||||
self._collect_writers(),
|
self._collect_writers(),
|
||||||
[],
|
[],
|
||||||
kernel.get_closest_deadline(),
|
deadline,
|
||||||
)
|
)
|
||||||
for read_ready in readable:
|
for read_ready in readable:
|
||||||
kernel.reschedule(self.readers[read_ready])
|
for resource, task in self.readers.items():
|
||||||
|
if resource.fileno() == read_ready:
|
||||||
|
kernel.reschedule(task)
|
||||||
for write_ready in writable:
|
for write_ready in writable:
|
||||||
kernel.reschedule(self.writers[write_ready])
|
for resource, task in self.writers.items():
|
||||||
|
if resource.fileno() == write_ready:
|
||||||
|
kernel.reschedule(task)
|
||||||
|
|
||||||
def request_read(self, rsc):
|
def request_read(self, rsc: AsyncResource, task: Task):
|
||||||
task = current_task()
|
|
||||||
self.readers[rsc] = task
|
self.readers[rsc] = task
|
||||||
|
|
||||||
def request_write(self, rsc):
|
def request_write(self, rsc: AsyncResource, task: Task):
|
||||||
task = current_task()
|
|
||||||
self.writers[rsc] = task
|
self.writers[rsc] = task
|
||||||
|
|
||||||
def release(self, resource):
|
def release(self, resource: AsyncResource):
|
||||||
self.readers.pop(resource, None)
|
self.readers.pop(resource, None)
|
||||||
self.writers.pop(resource, None)
|
self.writers.pop(resource, None)
|
||||||
|
|
||||||
def release_task(self, task: Task):
|
def release_task(self, task: Task):
|
||||||
for resource in self.tasks[task]:
|
for resource, owner in self.readers.copy().items():
|
||||||
self.release(resource)
|
if owner == task:
|
||||||
|
self.readers.pop(resource)
|
||||||
|
for resource, owner in self.writers.copy().items():
|
||||||
|
if owner == task:
|
||||||
|
self.writers.pop(resource)
|
||||||
|
|
|
@ -68,3 +68,19 @@ async def checkpoint():
|
||||||
|
|
||||||
await check_cancelled()
|
await check_cancelled()
|
||||||
await schedule_point()
|
await schedule_point()
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_readable(rsc):
|
||||||
|
return await syscall("wait_readable", rsc)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_writable(rsc):
|
||||||
|
return await syscall("wait_writable", rsc)
|
||||||
|
|
||||||
|
|
||||||
|
async def closing(rsc):
|
||||||
|
return await syscall("notify_closing", rsc)
|
||||||
|
|
||||||
|
|
||||||
|
async def release(rsc):
|
||||||
|
return await syscall("release_resource", rsc)
|
||||||
|
|
|
@ -84,11 +84,10 @@ class TimeQueue:
|
||||||
def get_closest_deadline(self) -> float:
|
def get_closest_deadline(self) -> float:
|
||||||
"""
|
"""
|
||||||
Returns the closest deadline that is meant to expire
|
Returns the closest deadline that is meant to expire
|
||||||
or raises IndexError if the queue is empty
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self:
|
if not self:
|
||||||
raise IndexError("TimeQueue is empty")
|
return float("inf")
|
||||||
return self.container[0][0]
|
return self.container[0][0]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
|
@ -38,3 +38,11 @@ class ResourceBusy(StructIOException):
|
||||||
Raised when an attempt is made to use an
|
Raised when an attempt is made to use an
|
||||||
asynchronous resource that is currently busy
|
asynchronous resource that is currently busy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceBroken(StructIOException):
|
||||||
|
"""
|
||||||
|
Raised when an asynchronous resource gets
|
||||||
|
corrupted and is no longer usable
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from structio.core.syscalls import checkpoint, wait_readable, wait_writable, closing, release
|
||||||
|
from structio.exceptions import ResourceClosed
|
||||||
|
from structio.abc import AsyncResource
|
||||||
|
try:
|
||||||
|
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
|
||||||
|
WantRead = (BlockingIOError, SSLWantReadError, InterruptedError)
|
||||||
|
WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError)
|
||||||
|
except ImportError:
|
||||||
|
WantWrite = (BlockingIOError, InterruptedError)
|
||||||
|
WantRead = (BlockingIOError, InterruptedError)
|
||||||
|
SSLSocket = None
|
||||||
|
|
||||||
|
|
||||||
|
class FdWrapper:
|
||||||
|
"""
|
||||||
|
A simple wrapper around a file descriptor that
|
||||||
|
allows the event loop to perform an optimization
|
||||||
|
regarding I/O event registration safely. This is
|
||||||
|
because while integer file descriptors can be reused
|
||||||
|
by the operating system, instances of this class will
|
||||||
|
not (hence if the event loop keeps around a dead instance
|
||||||
|
of an FdWrapper, it at least won't accidentally register
|
||||||
|
a new file with that same file descriptor). A bonus is
|
||||||
|
that this also allows us to always assume that we can call
|
||||||
|
fileno() on all objects registered in our selector, regardless
|
||||||
|
of whether the wrapped fd is an int or something else entirely
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("fd", )
|
||||||
|
|
||||||
|
def __init__(self, fd):
|
||||||
|
self.fd = fd
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self.fd
|
||||||
|
|
||||||
|
# Can be converted to an int
|
||||||
|
def __int__(self):
|
||||||
|
return self.fd
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<fd={self.fd!r}>"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncStream(AsyncResource):
|
||||||
|
"""
|
||||||
|
A generic asynchronous stream over
|
||||||
|
a file descriptor. Functionality
|
||||||
|
is OS-dependent
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fd: int,
|
||||||
|
open_fd: bool = True,
|
||||||
|
close_on_context_exit: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self._fd = FdWrapper(fd)
|
||||||
|
self.fileobj = None
|
||||||
|
if open_fd:
|
||||||
|
self.fileobj = os.fdopen(int(self._fd), **kwargs)
|
||||||
|
os.set_blocking(int(self._fd), False)
|
||||||
|
# Do we close ourselves upon the end of a context manager?
|
||||||
|
self.close_on_context_exit = close_on_context_exit
|
||||||
|
|
||||||
|
async def read(self, size: int = -1):
|
||||||
|
"""
|
||||||
|
Reads up to size bytes from the
|
||||||
|
given stream. If size == -1, read
|
||||||
|
until EOF is reached
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = self.fileobj.read(size)
|
||||||
|
await checkpoint()
|
||||||
|
return data
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self)
|
||||||
|
|
||||||
|
async def write(self, data):
|
||||||
|
"""
|
||||||
|
Writes data to the stream.
|
||||||
|
Returns the number of bytes
|
||||||
|
written
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = self.fileobj.write(data)
|
||||||
|
await checkpoint()
|
||||||
|
return data
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
Closes the stream asynchronously
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.fileno() == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed stream")
|
||||||
|
self._fd = -1
|
||||||
|
await closing(self)
|
||||||
|
await release(self)
|
||||||
|
self.fileobj.close()
|
||||||
|
self.fileobj = None
|
||||||
|
await checkpoint()
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return int(self._fd)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.fileobj.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
if self._fd != -1 and self.close_on_context_exit:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def dup(self):
|
||||||
|
"""
|
||||||
|
Wrapper stream method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return type(self)(os.dup(self.fileno()))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"AsyncStream({self.fileobj})"
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
Stream destructor. Do *not* call
|
||||||
|
this directly: stuff will break
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd != -1 and self.fileobj.fileno() != -1:
|
||||||
|
try:
|
||||||
|
os.set_blocking(self.fileno(), False)
|
||||||
|
os.close(self.fileno())
|
||||||
|
except OSError as e:
|
||||||
|
warnings.warn(
|
||||||
|
f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}"
|
||||||
|
)
|
|
@ -44,7 +44,7 @@ _FILE_ASYNC_METHODS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AsyncResourceWrapper(AsyncResource):
|
class AsyncFile(AsyncResource):
|
||||||
"""
|
"""
|
||||||
Asynchronous wrapper around regular file-like objects.
|
Asynchronous wrapper around regular file-like objects.
|
||||||
Blocking operations are turned into async ones using threads.
|
Blocking operations are turned into async ones using threads.
|
||||||
|
@ -52,6 +52,9 @@ class AsyncResourceWrapper(AsyncResource):
|
||||||
and read/write methods
|
and read/write methods
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self.handle.fileno()
|
||||||
|
|
||||||
def __init__(self, f):
|
def __init__(self, f):
|
||||||
self._file = f
|
self._file = f
|
||||||
|
|
||||||
|
@ -108,7 +111,8 @@ class AsyncResourceWrapper(AsyncResource):
|
||||||
# This operation is non-cancellable, meaning it'll run
|
# This operation is non-cancellable, meaning it'll run
|
||||||
# no matter what our event loop has to say about it.
|
# no matter what our event loop has to say about it.
|
||||||
# After we're done, we'll obviously re-raise the cancellation
|
# After we're done, we'll obviously re-raise the cancellation
|
||||||
# if necessary
|
# if necessary. This ensures files are always closed even when
|
||||||
|
# the operation gets cancelled
|
||||||
await structio.thread.run_in_worker(self.handle.close)
|
await structio.thread.run_in_worker(self.handle.close)
|
||||||
# If we were cancelled, here is where we raise
|
# If we were cancelled, here is where we raise
|
||||||
await check_cancelled()
|
await check_cancelled()
|
||||||
|
@ -123,7 +127,7 @@ async def open_file(
|
||||||
newline=None,
|
newline=None,
|
||||||
closefd=True,
|
closefd=True,
|
||||||
opener=None,
|
opener=None,
|
||||||
) -> AsyncResourceWrapper:
|
) -> AsyncFile:
|
||||||
"""
|
"""
|
||||||
Like io.open(), but async. Magic
|
Like io.open(), but async. Magic
|
||||||
"""
|
"""
|
||||||
|
@ -135,13 +139,13 @@ async def open_file(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def wrap_file(file) -> AsyncResourceWrapper:
|
def wrap_file(file) -> AsyncFile:
|
||||||
"""
|
"""
|
||||||
Wraps a file-like object into an async
|
Wraps a file-like object into an async
|
||||||
wrapper
|
wrapper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return AsyncResourceWrapper(file)
|
return AsyncFile(file)
|
||||||
|
|
||||||
|
|
||||||
stdin = wrap_file(sys.stdin)
|
stdin = wrap_file(sys.stdin)
|
||||||
|
|
|
@ -1,13 +1,344 @@
|
||||||
import structio
|
|
||||||
from structio.abc import AsyncResource
|
from structio.abc import AsyncResource
|
||||||
from structio.core.syscalls import check_cancelled
|
from structio.io import FdWrapper, WantRead, WantWrite, SSLSocket
|
||||||
|
from structio.exceptions import ResourceClosed, ResourceBroken
|
||||||
|
from structio.core.syscalls import wait_readable, wait_writable, checkpoint, closing, release
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import socket as _socket
|
import socket as _socket
|
||||||
|
|
||||||
|
|
||||||
@wraps(_socket.socket)
|
@wraps(_socket.socket)
|
||||||
def socket(*args, **kwargs):
|
def socket(*args, **kwargs):
|
||||||
return None # TODO
|
return AsyncSocket(_socket.socket(*args, **kwargs))
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
class AsyncSocket(AsyncResource):
    """
    Abstraction layer for asynchronous sockets.

    Wraps a regular socket object (switched to non-blocking mode):
    every potentially blocking operation is retried after waiting,
    via the event loop, for the underlying file descriptor to become
    readable/writable
    """

    def fileno(self) -> int:
        """
        Returns the file descriptor of the underlying socket
        """

        return int(self._fd)

    def __init__(
        self,
        sock: _socket.socket,
        close_on_context_exit: bool = True,
        do_handshake_on_connect: bool = True,
    ):
        """
        Public object constructor

        :param sock: The synchronous socket being wrapped. It is
            set to non-blocking mode automatically
        :param close_on_context_exit: Whether to close the socket when
            the async context manager exits
        :param do_handshake_on_connect: Whether to perform the handshake
            automatically upon connection (mostly meaningful for SSL
            sockets)
        """

        self._fd = FdWrapper(sock.fileno())
        self.close_on_context_exit = close_on_context_exit
        # Do we perform the TCP handshake automatically
        # upon connection? This is mostly needed for SSL
        # sockets
        self.do_handshake_on_connect = do_handshake_on_connect
        self.socket = sock
        # A blocking socket would stall the entire event loop
        self.socket.setblocking(False)
        # A socket that isn't connected doesn't
        # need to be closed
        self.needs_closing: bool = False

    async def receive(self, max_size: int, flags: int = 0) -> bytes:
        """
        Receives up to max_size bytes from a socket asynchronously

        :raises ResourceClosed: If the socket has been closed
        """

        assert max_size >= 1, "max_size must be >= 1"
        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        while True:
            try:
                data = self.socket.recv(max_size, flags)
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                # Can happen with SSL sockets (renegotiation)
                await wait_writable(self)

    async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
        """
        Receives exactly size bytes from a socket asynchronously

        :raises ResourceBroken: If the peer disconnects before the
            requested amount of data could be read
        """

        # https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
        buf = bytearray(size)
        pos = 0
        while pos < size:
            n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
            if n == 0:
                raise ResourceBroken("incomplete read detected")
            pos += n
        return bytes(buf)

    async def connect(self, address):
        """
        Connects the socket to the given address asynchronously,
        performing the handshake as well if do_handshake_on_connect
        is set

        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        while True:
            try:
                self.socket.connect(address)
                if self.do_handshake_on_connect:
                    await self.do_handshake()
                await checkpoint()
                break
            except WantWrite:
                await wait_writable(self)
        # Only a connected socket needs closing
        self.needs_closing = True

    async def close(self):
        """
        Closes the socket asynchronously (a no-op if the socket
        was never connected)
        """

        if self.needs_closing:
            self.socket.close()
            await checkpoint()

    async def accept(self):
        """
        Accepts the socket, completing the 3-step TCP handshake
        asynchronously. Returns an (AsyncSocket, address) pair for
        the newly established connection

        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        while True:
            try:
                remote, addr = self.socket.accept()
                await checkpoint()
                return type(self)(remote), addr
            except WantRead:
                await wait_readable(self)

    async def send_all(self, data: bytes, flags: int = 0):
        """
        Sends all data inside the buffer asynchronously until it is empty

        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        while data:
            try:
                sent_no = self.socket.send(data, flags)
                await checkpoint()
                # BUGFIX: the buffer is only advanced after a successful
                # send. The old code sliced it unconditionally at the
                # bottom of the loop, so a WantRead/WantWrite raised
                # after a previous partial send silently dropped bytes
                data = data[sent_no:]
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                await wait_writable(self)

    async def shutdown(self, how):
        """
        Wrapper socket method

        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        if self.socket:
            self.socket.shutdown(how)
        await checkpoint()

    async def bind(self, addr: tuple):
        """
        Binds the socket to an address

        :param addr: The address, port tuple to bind to
        :type addr: tuple
        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        self.socket.bind(addr)
        await checkpoint()

    async def listen(self, backlog: int):
        """
        Starts listening with the given backlog

        :param backlog: The socket's backlog
        :type backlog: int
        :raises ResourceClosed: If the socket has been closed
        """

        if self._fd == -1:
            raise ResourceClosed("I/O operation on closed socket")
        self.socket.listen(backlog)
        await checkpoint()

    # Yes, I stole these from Curio because I could not be
    # arsed to write a bunch of uninteresting simple socket
    # methods from scratch, deal with it.

    def settimeout(self, seconds):
        """
        Wrapper socket method. Not supported: timeouts are
        handled by the scheduler, not by the socket
        """

        raise RuntimeError("Use with_timeout() to set a timeout")

    def gettimeout(self):
        """
        Wrapper socket method (always returns None, see settimeout)
        """

        return None

    def dup(self):
        """
        Wrapper socket method
        """

        # BUGFIX: the second positional parameter of __init__ is
        # close_on_context_exit, not do_handshake_on_connect; the old
        # code passed the handshake flag into the wrong slot
        return type(self)(
            self.socket.dup(), self.close_on_context_exit, self.do_handshake_on_connect
        )

    async def do_handshake(self):
        """
        Performs the SSL handshake asynchronously, if the underlying
        socket supports it (a no-op otherwise)
        """

        if not hasattr(self.socket, "do_handshake"):
            return
        while True:
            try:
                self.socket: SSLSocket  # Silences pycharm warnings
                self.socket.do_handshake()
                await checkpoint()
                # BUGFIX: without this break, a successful handshake
                # looped forever re-running do_handshake()
                break
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                await wait_writable(self)

    async def recvfrom(self, buffersize, flags=0):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.recvfrom(buffersize, flags)
                # Consistent with the other wrappers: always pass
                # through a checkpoint on the fast path
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                await wait_writable(self)

    async def recv_into(self, buffer, nbytes=0, flags=0):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.recv_into(buffer, nbytes, flags)
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                await wait_writable(self)

    async def recvfrom_into(self, buffer, bytes=0, flags=0):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.recvfrom_into(buffer, bytes, flags)
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)
            except WantWrite:
                await wait_writable(self)

    async def sendto(self, bytes, flags_or_address, address=None):
        """
        Wrapper socket method. Mirrors socket.sendto()'s dual
        signature: (data, address) or (data, flags, address)
        """

        if address:
            flags = flags_or_address
        else:
            address = flags_or_address
            flags = 0
        while True:
            try:
                data = self.socket.sendto(bytes, flags, address)
                await checkpoint()
                return data
            except WantWrite:
                await wait_writable(self)
            except WantRead:
                await wait_readable(self)

    async def getpeername(self):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.getpeername()
                await checkpoint()
                return data
            except WantWrite:
                await wait_writable(self)
            except WantRead:
                await wait_readable(self)

    async def getsockname(self):
        """
        Wrapper socket method
        """

        while True:
            try:
                # BUGFIX: this used to call getpeername(), returning
                # the remote address instead of the local one
                data = self.socket.getsockname()
                await checkpoint()
                return data
            except WantWrite:
                await wait_writable(self)
            except WantRead:
                await wait_readable(self)

    async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.recvmsg(bufsize, ancbufsize, flags)
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)

    async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
                await checkpoint()
                return data
            except WantRead:
                await wait_readable(self)

    async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
        """
        Wrapper socket method
        """

        while True:
            try:
                data = self.socket.sendmsg(buffers, ancdata, flags, address)
                await checkpoint()
                return data
            except WantWrite:
                # BUGFIX: the old code caught WantRead here but then
                # waited for *writability*; each condition now waits
                # for the matching readiness event
                await wait_writable(self)
            except WantRead:
                await wait_readable(self)

    def __repr__(self):
        return f"AsyncSocket({self.socket})"
||||||
|
|
|
@ -1,10 +1,39 @@
|
||||||
# Module inspired by subprocess which allows for asynchronous
|
# Module inspired by subprocess which allows for asynchronous
|
||||||
# multiprocessing
|
# multiprocessing
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from subprocess import (
|
||||||
|
CalledProcessError,
|
||||||
|
CompletedProcess,
|
||||||
|
SubprocessError,
|
||||||
|
STDOUT,
|
||||||
|
DEVNULL,
|
||||||
|
PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Popen:
    """
    Wrapper around subprocess.Popen, but async
    """

    def __init__(self, *args, **kwargs):
        """
        Public object constructor. Accepts the same arguments as
        subprocess.Popen, except universal_newlines

        :raises RuntimeError: If universal_newlines is passed
        """

        if "universal_newlines" in kwargs:
            # Not sure why? But everyone else is doing it so :shrug:
            raise RuntimeError("universal_newlines is not supported")
        stdin = kwargs.get("stdin")
        # BUGFIX: stdin may be subprocess.PIPE/DEVNULL (plain ints,
        # and -1 is truthy), which have no fileno(); only real
        # file-like objects get the blocking-mode treatment
        if stdin is not None and hasattr(stdin, "fileno"):
            # Curio mentions stuff breaking if the child process
            # is passed a stdin fd that is set to non-blocking mode
            if hasattr(os, "set_blocking"):
                os.set_blocking(stdin.fileno(), True)
        # Delegate to Popen's constructor
        self._process = subprocess.Popen(*args, **kwargs)

    def __getattr__(self, item):
        # Delegate all other attribute lookups
        # to the internal process object
        return getattr(self._process, item)
||||||
|
|
|
@ -0,0 +1,136 @@
|
||||||
|
import structio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# An asynchronous chatroom
|
||||||
|
|
||||||
|
clients: dict[structio.socket.AsyncSocket, list[str, str]] = {}
|
||||||
|
names: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
async def event_handler(q: structio.Queue):
    """
    Reads events submitted onto the queue and logs them

    :param q: The queue the server pushes events onto
    """

    try:
        logging.info("Event handler spawned")
        while True:
            msg, payload = await q.get()
            logging.info(f"Caught event {msg!r} with the following payload: {payload}")
    except structio.exceptions.Cancelled:
        # BUGFIX: the cancellation clause must come *before* the broad
        # Exception clause: if Cancelled inherits from Exception, the
        # old ordering swallowed it and broke task cancellation
        logging.warning("Cancellation detected, message handler shutting down")
        # Propagate the cancellation
        raise
    except Exception as e:
        logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
||||||
|
|
||||||
|
async def serve(bind_address: tuple):
    """
    Serves asynchronously forever (or until Ctrl+C ;))

    :param bind_address: The address to bind the server to, represented as a tuple
        (address, port) where address is a string and port is an integer
    """

    sock = structio.socket.socket()
    queue = structio.Queue()
    await sock.bind(bind_address)
    await sock.listen(5)
    logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
    # BUGFIX: if the very first accept() raises, address_tuple was
    # unbound in the except clause below, turning any error into a
    # NameError. A sentinel keeps the log line well-defined
    address_tuple = ("unknown", -1)
    async with structio.create_pool() as pool:
        pool.spawn(event_handler, queue)
        async with sock:
            while True:
                try:
                    conn, address_tuple = await sock.accept()
                    clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
                    await queue.put(("connect", clients[conn]))
                    logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
                    await pool.spawn(handler, conn, queue)
                except Exception as err:
                    # Because exceptions just *work*
                    logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
||||||
|
|
||||||
|
async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue):
    """
    Handles a single client connection

    :param sock: The AsyncSocket object connected to the client
    :param q: The event queue events are pushed onto
    """

    # clients[sock] is [name, "host:port"]; the name is filled
    # in once the client has picked one
    address = clients[sock][1]
    name = ""
    async with sock:  # Closes the socket automatically
        await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
        cond = True
        # Name-picking phase: keep prompting until the client sends
        # a newline-terminated name that nobody else has taken
        while cond:
            while not name.endswith("\n"):
                name = (await sock.receive(64)).decode()
                if name == "":
                    # Empty read: the client disconnected mid-prompt
                    cond = False
                    break
            name = name.rstrip("\n")
            if name not in names:
                names.add(name)
                clients[sock][0] = name
                break
            else:
                await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
        await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
        await q.put(("join", (address, name)))
        logging.info(f"{name} has joined the chatroom ({address}), informing clients")
        # Broadcast the join to every *named* client except ourselves
        for i, client_sock in enumerate(clients):
            if client_sock != sock and clients[client_sock][0]:
                await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
        # Chat phase: prompt, read, dispatch
        while True:
            await sock.send_all(b"> ")
            data = await sock.receive(1024)
            if not data:
                # Client disconnected
                break
            decoded = data.decode().rstrip("\n")
            if decoded.startswith("/"):
                # Lines starting with "/" are server commands
                logging.info(f"{name} issued server command {decoded}")
                await q.put(("cmd", (name, decoded[1:])))
                match decoded[1:]:
                    case "bye":
                        await sock.send_all(b"Bye!\n")
                        break
                    case _:
                        await sock.send_all(b"Unknown command\n")
            else:
                # Regular message: relay to every other named client
                await q.put(("msg", (name, data)))
                logging.info(f"Got: {data!r} from {address}")
                for i, client_sock in enumerate(clients):
                    if client_sock != sock and clients[client_sock][0]:
                        logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
                        if not data.endswith(b"\n"):
                            data += b"\n"
                        await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
                # NOTE(review): i is the last enumerate() index, not the
                # number of clients the message was sent to (and it is
                # unbound if clients is empty) — confirm intent
                logging.info(f"Sent {data!r} to {i} clients")
        # Teardown: announce departure and free the name/slot
        await q.put(("leave", name))
        logging.info(f"Connection from {address} closed")
        logging.info(f"{name} has left the chatroom ({address}), informing clients")
        for i, client_sock in enumerate(clients):
            if client_sock != sock and clients[client_sock][0]:
                await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
        clients.pop(sock)
        names.discard(name)
        logging.info("Handler shutting down")
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Port can be overridden from the command line; defaults to 1501
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
    logging.basicConfig(
        level=20,  # logging.INFO
        format="[%(levelname)s] %(asctime)s %(message)s",
        datefmt="%d/%m/%Y %p",
    )
    try:
        structio.run(serve, ("0.0.0.0", port))
    except (Exception, KeyboardInterrupt) as error:  # Exceptions propagate!
        if isinstance(error, KeyboardInterrupt):
            logging.info("Ctrl+C detected, exiting")
        else:
            logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import structio
|
||||||
|
|
||||||
|
|
||||||
|
# A test to check for asynchronous I/O
|
||||||
|
|
||||||
|
|
||||||
|
async def serve(bind_address: tuple):
    """
    Serves asynchronously forever

    :param bind_address: The address to bind the server to represented as a tuple
        (address, port) where address is a string and port is an integer
    """

    sock = structio.socket.socket()
    await sock.bind(bind_address)
    await sock.listen(5)
    logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
    # BUGFIX: if the very first accept() raises, address_tuple was
    # unbound in the except clause below, turning any error into a
    # NameError. A sentinel keeps the log line well-defined
    address_tuple = ("unknown", -1)
    async with structio.create_pool() as ctx:
        async with sock:
            while True:
                try:
                    conn, address_tuple = await sock.accept()
                    logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
                    await ctx.spawn(handler, conn, address_tuple)
                except Exception as err:
                    # Because exceptions just *work*
                    logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
||||||
|
|
||||||
|
async def handler(sock: structio.socket.AsyncSocket, client_address: tuple):
    """
    Handles a single client connection

    :param sock: The AsyncSocket object connected to the client
    :param client_address: The client's address represented as a tuple
        (address, port) where address is a string and port is an integer
    :type client_address: tuple
    """

    address = "{}:{}".format(*client_address)
    async with sock:  # Closes the socket automatically
        await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
        while True:
            await sock.send_all(b"-> ")
            data = await sock.receive(1024)
            # Guard-style dispatch on what the client sent us
            if not data:
                break
            if data == b"exit\n":
                await sock.send_all(b"I'm dead dude\n")
                raise TypeError("Oh, no, I'm gonna die!")
            if data == b"fatal\n":
                await sock.send_all(b"What a dick\n")
                raise KeyboardInterrupt("He told me to do it!")
            logging.info(f"Got: {data!r} from {address}")
            await sock.send_all(b"Got: " + data)
            logging.info(f"Echoed back {data!r} to {address}")
        logging.info(f"Connection from {address} closed")
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Port can be overridden from the command line; defaults to 1501
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
    logging.basicConfig(
        level=20,  # logging.INFO
        format="[%(levelname)s] %(asctime)s %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S %p",
    )
    try:
        structio.run(serve, ("localhost", port))
    except (Exception, KeyboardInterrupt) as error:  # Exceptions propagate!
        if isinstance(error, KeyboardInterrupt):
            logging.info("Ctrl+C detected, exiting")
        else:
            logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
Loading…
Reference in New Issue