Various bug fixes and simplifications, added multiprocessing support and new tests
This commit is contained in:
parent
f9e56cffc4
commit
156b3c6fc8
|
@ -23,7 +23,7 @@ from structio.io.files import (
|
|||
ainput,
|
||||
)
|
||||
from structio.core.run import current_loop, current_task
|
||||
from structio import thread
|
||||
from structio import thread, parallel
|
||||
from structio.path import Path
|
||||
|
||||
|
||||
|
@ -145,4 +145,5 @@ __all__ = [
|
|||
"current_loop",
|
||||
"current_task",
|
||||
"Path",
|
||||
"parallel"
|
||||
]
|
||||
|
|
|
@ -52,10 +52,6 @@ class AsyncResource(ABC):
|
|||
async def close(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def fileno(self):
|
||||
return NotImplemented
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
|
@ -153,9 +149,6 @@ class ChannelReader(AsyncResource, ABC):
|
|||
read from the channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
|
||||
class ChannelWriter(AsyncResource, ABC):
|
||||
"""
|
||||
|
@ -179,18 +172,12 @@ class ChannelWriter(AsyncResource, ABC):
|
|||
to write to the channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
|
||||
class Channel(ChannelWriter, ChannelReader, ABC):
|
||||
"""
|
||||
A generic, two-way channel
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return None
|
||||
|
||||
|
||||
class BaseDebugger(ABC):
|
||||
"""
|
||||
|
@ -391,7 +378,7 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_read(self, rsc: AsyncResource, task: Task):
|
||||
def request_read(self, rsc, task: Task):
|
||||
"""
|
||||
"Requests" a read operation on the given
|
||||
resource to the I/O manager from the given
|
||||
|
@ -401,7 +388,7 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_write(self, rsc: AsyncResource, task: Task):
|
||||
def request_write(self, rsc, task: Task):
|
||||
"""
|
||||
"Requests" a write operation on the given
|
||||
resource to the I/O manager from the given
|
||||
|
@ -421,7 +408,7 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def release(self, resource: AsyncResource):
|
||||
def release(self, resource):
|
||||
"""
|
||||
Releases the given async resource from the
|
||||
manager. Note that the resource is *not*
|
||||
|
@ -443,19 +430,20 @@ class BaseIOManager(ABC):
|
|||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_reader(self, rsc: AsyncResource):
|
||||
def get_reader(self, rsc):
|
||||
"""
|
||||
Returns the task reading from the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_writer(self, rsc: AsyncResource):
|
||||
def get_writer(self, rsc):
|
||||
"""
|
||||
Returns the task writing to the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
|
||||
class SignalManager(ABC):
|
||||
"""
|
||||
A signal manager
|
||||
|
|
|
@ -125,8 +125,8 @@ class TaskPool:
|
|||
try:
|
||||
if exc_val:
|
||||
await checkpoint()
|
||||
raise exc_val
|
||||
else:
|
||||
raise exc_val.with_traceback(exc_tb)
|
||||
elif not self.done():
|
||||
await suspend()
|
||||
except Cancelled as e:
|
||||
self.error = e
|
||||
|
|
|
@ -7,8 +7,8 @@ from structio.abc import (
|
|||
BaseDebugger,
|
||||
BaseIOManager,
|
||||
SignalManager,
|
||||
AsyncResource
|
||||
)
|
||||
from structio.io import FdWrapper
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||
|
@ -63,13 +63,13 @@ class FIFOKernel(BaseKernel):
|
|||
]
|
||||
)
|
||||
|
||||
def wait_readable(self, resource: AsyncResource):
|
||||
def wait_readable(self, resource: FdWrapper):
|
||||
self.io_manager.request_read(resource, self.current_task)
|
||||
|
||||
def wait_writable(self, resource: AsyncResource):
|
||||
def wait_writable(self, resource: FdWrapper):
|
||||
self.io_manager.request_write(resource, self.current_task)
|
||||
|
||||
def notify_closing(self, resource: AsyncResource, broken: bool = False, owner: Task | None = None):
|
||||
def notify_closing(self, resource: FdWrapper, broken: bool = False, owner: Task | None = None):
|
||||
if not broken:
|
||||
exc = ResourceClosed("stream has been closed")
|
||||
else:
|
||||
|
@ -77,9 +77,9 @@ class FIFOKernel(BaseKernel):
|
|||
owner = owner or self.current_task
|
||||
reader = self.io_manager.get_reader(resource)
|
||||
writer = self.io_manager.get_writer(resource)
|
||||
if reader is not owner:
|
||||
if reader and reader is not owner:
|
||||
self.throw(reader, exc)
|
||||
if writer is not owner:
|
||||
if writer and writer is not owner:
|
||||
self.throw(writer, exc)
|
||||
self.reschedule_running()
|
||||
|
||||
|
@ -191,6 +191,8 @@ class FIFOKernel(BaseKernel):
|
|||
def throw(self, task: Task, err: BaseException):
|
||||
if task.done():
|
||||
return
|
||||
if self.current_scope.shielded:
|
||||
return
|
||||
if task.state == TaskState.PAUSED:
|
||||
self.paused.discard(task)
|
||||
elif task.state == TaskState.IO:
|
||||
|
@ -198,7 +200,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||
|
||||
def reschedule(self, task: Task):
|
||||
if task.done():
|
||||
if task.done() or task in self.run_queue:
|
||||
return
|
||||
self.run_queue.append(task)
|
||||
|
||||
|
@ -298,7 +300,7 @@ class FIFOKernel(BaseKernel):
|
|||
# most of this code below is just useful for internal/debugging purposes
|
||||
task.state = TaskState.FINISHED
|
||||
task.result = ret.value
|
||||
self.on_success(self.current_task)
|
||||
self.on_success(task)
|
||||
except Cancelled:
|
||||
# When a task needs to be cancelled, we try to do it gracefully first:
|
||||
# if the task is paused in either I/O or sleeping, that's perfect.
|
||||
|
@ -323,7 +325,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.event("on_exception_raised", task)
|
||||
self.on_error(task)
|
||||
|
||||
def release_resource(self, resource: AsyncResource):
|
||||
def release_resource(self, resource: FdWrapper):
|
||||
self.io_manager.release(resource)
|
||||
self.reschedule_running()
|
||||
|
||||
|
@ -401,7 +403,10 @@ class FIFOKernel(BaseKernel):
|
|||
inner = scope.inner
|
||||
if inner and not inner.shielded:
|
||||
self.cancel_scope(inner)
|
||||
for task in scope.tasks:
|
||||
for task in scope.tasks.copy():
|
||||
# We make a copy of the list because we
|
||||
# need to make sure that tasks aren't
|
||||
# removed out from under us
|
||||
self.cancel_task(task)
|
||||
if scope.done():
|
||||
self.reschedule(scope.owner)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from structio.abc import BaseIOManager, BaseKernel, AsyncResource
|
||||
from structio.abc import BaseIOManager, BaseKernel
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.core.run import current_loop, current_task
|
||||
from structio.io import FdWrapper
|
||||
import select
|
||||
|
||||
|
||||
|
@ -20,16 +21,16 @@ class SimpleIOManager(BaseIOManager):
|
|||
"""
|
||||
|
||||
# Maps resources to tasks
|
||||
self.readers: dict[AsyncResource, Task] = {}
|
||||
self.writers: dict[AsyncResource, Task] = {}
|
||||
self.readers: dict[FdWrapper, Task] = {}
|
||||
self.writers: dict[FdWrapper, Task] = {}
|
||||
|
||||
def pending(self):
|
||||
def pending(self) -> bool:
|
||||
return bool(self.readers or self.writers)
|
||||
|
||||
def get_reader(self, rsc: AsyncResource):
|
||||
def get_reader(self, rsc: FdWrapper):
|
||||
return self.readers.get(rsc)
|
||||
|
||||
def get_writer(self, rsc: AsyncResource):
|
||||
def get_writer(self, rsc: FdWrapper):
|
||||
return self.writers.get(rsc)
|
||||
|
||||
def _collect_readers(self) -> list[int]:
|
||||
|
@ -74,15 +75,15 @@ class SimpleIOManager(BaseIOManager):
|
|||
if resource.fileno() == write_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
|
||||
def request_read(self, rsc: AsyncResource, task: Task):
|
||||
def request_read(self, rsc: FdWrapper, task: Task):
|
||||
current_task().state = TaskState.IO
|
||||
self.readers[rsc] = task
|
||||
|
||||
def request_write(self, rsc: AsyncResource, task: Task):
|
||||
def request_write(self, rsc: FdWrapper, task: Task):
|
||||
current_task().state = TaskState.IO
|
||||
self.writers[rsc] = task
|
||||
|
||||
def release(self, resource: AsyncResource):
|
||||
def release(self, resource: FdWrapper):
|
||||
self.readers.pop(resource, None)
|
||||
self.writers.pop(resource, None)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
|
||||
import io
|
||||
import os
|
||||
import warnings
|
||||
from structio.core.syscalls import checkpoint, wait_readable, wait_writable, closing, release
|
||||
from structio.exceptions import ResourceClosed
|
||||
from structio.abc import AsyncResource
|
||||
|
@ -48,54 +48,62 @@ class FdWrapper:
|
|||
class AsyncStream(AsyncResource):
|
||||
"""
|
||||
A generic asynchronous stream over
|
||||
a file descriptor. Functionality
|
||||
is OS-dependent
|
||||
a file-like object, with buffering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd: int,
|
||||
open_fd: bool = True,
|
||||
close_on_context_exit: bool = True,
|
||||
**kwargs,
|
||||
fileobj
|
||||
):
|
||||
self._fd = FdWrapper(fd)
|
||||
self.fileobj = None
|
||||
if open_fd:
|
||||
self.fileobj = os.fdopen(int(self._fd), **kwargs)
|
||||
os.set_blocking(int(self._fd), False)
|
||||
# Do we close ourselves upon the end of a context manager?
|
||||
self.close_on_context_exit = close_on_context_exit
|
||||
self.fileobj = fileobj
|
||||
self._fd = FdWrapper(self.fileobj.fileno())
|
||||
self._buf = bytearray()
|
||||
|
||||
async def _read(self, size: int = -1) -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def write(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def read(self, size: int = -1):
|
||||
"""
|
||||
Reads up to size bytes from the
|
||||
given stream. If size == -1, read
|
||||
until EOF is reached
|
||||
as much as possible
|
||||
"""
|
||||
|
||||
if size < 0 and size < -1:
|
||||
raise ValueError("size must be -1 or a positive integer")
|
||||
if size == -1:
|
||||
size = len(self._buf)
|
||||
buf = self._buf
|
||||
if not buf:
|
||||
return await self._read(size)
|
||||
if len(buf) <= size:
|
||||
data = bytes(buf)
|
||||
buf.clear()
|
||||
else:
|
||||
data = bytes(buf[:size])
|
||||
del buf[:size]
|
||||
return data
|
||||
|
||||
# Yes I stole this from curio. Sue me.
|
||||
async def readall(self):
|
||||
chunks = []
|
||||
maxread = 65536
|
||||
if self._buf:
|
||||
chunks.append(bytes(self._buf))
|
||||
self._buf.clear()
|
||||
while True:
|
||||
try:
|
||||
data = self.fileobj.read(size)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
chunk = await self.read(maxread)
|
||||
if not chunk:
|
||||
return b''.join(chunks)
|
||||
chunks.append(chunk)
|
||||
if len(chunk) == maxread:
|
||||
maxread *= 2
|
||||
|
||||
async def write(self, data):
|
||||
"""
|
||||
Writes data to the stream.
|
||||
Returns the number of bytes
|
||||
written
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = self.fileobj.write(data)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
async def flush(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
|
@ -103,12 +111,13 @@ class AsyncStream(AsyncResource):
|
|||
"""
|
||||
|
||||
if self.fileno() == -1:
|
||||
raise ResourceClosed("I/O operation on closed stream")
|
||||
self._fd = -1
|
||||
await closing(self)
|
||||
await release(self)
|
||||
return
|
||||
await self.flush()
|
||||
await closing(self._fd)
|
||||
await release(self._fd)
|
||||
self.fileobj.close()
|
||||
self.fileobj = None
|
||||
self._fd = -1
|
||||
await checkpoint()
|
||||
|
||||
def fileno(self):
|
||||
|
@ -119,34 +128,63 @@ class AsyncStream(AsyncResource):
|
|||
return int(self._fd)
|
||||
|
||||
async def __aenter__(self):
|
||||
self.fileobj.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self._fd != -1 and self.close_on_context_exit:
|
||||
if self.fileno() != -1:
|
||||
await self.close()
|
||||
|
||||
async def dup(self):
|
||||
"""
|
||||
Wrapper stream method
|
||||
"""
|
||||
|
||||
return type(self)(os.dup(self.fileno()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncStream({self.fileobj})"
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Stream destructor. Do *not* call
|
||||
this directly: stuff will break
|
||||
"""
|
||||
|
||||
if self._fd != -1 and self.fileobj.fileno() != -1:
|
||||
class FileStream(AsyncStream):
|
||||
"""
|
||||
A stream wrapper around a binary file-like object.
|
||||
The underlying file descriptor is put into non-blocking
|
||||
mode
|
||||
"""
|
||||
|
||||
async def _read(self, size: int = -1) -> bytes:
|
||||
while True:
|
||||
try:
|
||||
os.set_blocking(self.fileno(), False)
|
||||
os.close(self.fileno())
|
||||
except OSError as e:
|
||||
warnings.warn(
|
||||
f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}"
|
||||
)
|
||||
data = self.fileobj.read(size)
|
||||
if data is None:
|
||||
# Files in non-blocking mode don't always
|
||||
# raise a blocking I/O exception and can
|
||||
# return None instead, so we account for
|
||||
# that here
|
||||
raise BlockingIOError()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def write(self, data):
|
||||
# We use a memory view so that
|
||||
# slicing doesn't copy any memory
|
||||
mem = memoryview(data)
|
||||
while mem:
|
||||
try:
|
||||
written = self.fileobj.write(data)
|
||||
if written is None:
|
||||
raise BlockingIOError()
|
||||
mem = mem[written:]
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def flush(self):
|
||||
if self.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
return self.fileobj.flush()
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
def __init__(self, fileobj):
|
||||
if isinstance(fileobj, io.TextIOBase):
|
||||
raise TypeError("only binary mode files can be streamed")
|
||||
super().__init__(fileobj)
|
||||
os.set_blocking(self.fileno(), False)
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import io
|
||||
import os
|
||||
import sys
|
||||
import structio
|
||||
from functools import partial
|
||||
from structio.abc import AsyncResource, Stream
|
||||
from structio.core.syscalls import check_cancelled, wait_writable, wait_readable, checkpoint
|
||||
from structio.io import WantRead, WantWrite
|
||||
from structio.exceptions import ResourceClosed
|
||||
from structio.abc import AsyncResource
|
||||
from structio.core.syscalls import check_cancelled
|
||||
|
||||
# Stolen from Trio
|
||||
_FILE_SYNC_ATTRS = {
|
||||
|
@ -132,7 +129,7 @@ async def open_file(
|
|||
opener=None,
|
||||
) -> AsyncFile:
|
||||
"""
|
||||
Like io.open(), but async. Magic
|
||||
Like io.open(), but async
|
||||
"""
|
||||
|
||||
return wrap_file(
|
||||
|
@ -161,7 +158,7 @@ async def aprint(*args, sep=" ", end="\n", file=stdout, flush=False):
|
|||
Like print(), but asynchronous
|
||||
"""
|
||||
|
||||
await file.write(f"{sep.join(args)}{end}")
|
||||
await file.write(f"{sep.join(map(str, args))}{end}")
|
||||
if flush:
|
||||
await file.flush()
|
||||
|
||||
|
|
|
@ -51,9 +51,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
|
@ -85,7 +85,7 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
break
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
self.needs_closing = True
|
||||
|
||||
async def close(self):
|
||||
|
@ -110,7 +110,7 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return type(self)(remote), addr
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def send_all(self, data: bytes, flags: int = 0):
|
||||
"""
|
||||
|
@ -125,9 +125,9 @@ class AsyncSocket(AsyncResource):
|
|||
sent_no = self.socket.send(data, flags)
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
data = data[sent_no:]
|
||||
|
||||
async def shutdown(self, how):
|
||||
|
@ -203,9 +203,9 @@ class AsyncSocket(AsyncResource):
|
|||
self.socket.do_handshake()
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recvfrom(self, buffersize, flags=0):
|
||||
"""
|
||||
|
@ -216,9 +216,9 @@ class AsyncSocket(AsyncResource):
|
|||
try:
|
||||
return self.socket.recvfrom(buffersize, flags)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recv_into(self, buffer, nbytes=0, flags=0):
|
||||
"""
|
||||
|
@ -231,9 +231,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||
"""
|
||||
|
@ -246,9 +246,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def sendto(self, bytes, flags_or_address, address=None):
|
||||
"""
|
||||
|
@ -266,9 +266,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def getpeername(self):
|
||||
"""
|
||||
|
@ -281,9 +281,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def getsockname(self):
|
||||
"""
|
||||
|
@ -296,9 +296,9 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
|
||||
"""
|
||||
|
@ -311,7 +311,7 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||
"""
|
||||
|
@ -324,7 +324,7 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self)
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||
"""
|
||||
|
@ -337,7 +337,7 @@ class AsyncSocket(AsyncResource):
|
|||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_writable(self)
|
||||
await wait_writable(self._fd)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncSocket({self.socket})"
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
# Module inspired by subprocess which allows for asynchronous
|
||||
# multiprocessing
|
||||
import os
|
||||
import structio
|
||||
import subprocess
|
||||
from subprocess import (
|
||||
CalledProcessError,
|
||||
CompletedProcess,
|
||||
SubprocessError,
|
||||
STDOUT,
|
||||
DEVNULL,
|
||||
PIPE
|
||||
)
|
||||
from structio.io import FileStream
|
||||
|
||||
|
||||
class Popen:
|
||||
|
@ -18,24 +18,91 @@ class Popen:
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
if "universal_newlines" in kwargs:
|
||||
# Not sure why? But everyone else is doing it so :shrug:
|
||||
raise RuntimeError("universal_newlines is not supported")
|
||||
if stdin := kwargs.get("stdin"):
|
||||
# Curio mentions stuff breaking if the child process
|
||||
# is passed a stdin fd that is set to non-blocking mode
|
||||
if hasattr(os, "set_blocking"):
|
||||
if stdin not in {PIPE, DEVNULL}:
|
||||
# Curio mentions stuff breaking if the child process
|
||||
# is passed a stdin fd that is set to non-blocking mode
|
||||
os.set_blocking(stdin.fileno(), True)
|
||||
# Delegate to Popen's constructor
|
||||
self._process = subprocess.Popen(*args, **kwargs)
|
||||
self._process: subprocess.Popen = subprocess.Popen(*args, **kwargs)
|
||||
self.stdin = None
|
||||
self.stdout = None
|
||||
self.stderr = None
|
||||
if self._process.stdin:
|
||||
self.stdin = None
|
||||
self.stdin = FileStream(self._process.stdin)
|
||||
if self._process.stdout:
|
||||
self.stdout = FileStream(self._process.stdout)
|
||||
if self._process.stderr:
|
||||
self.stderr = FileStream(self._process.stderr)
|
||||
|
||||
async def wait(self):
|
||||
status = self._process.poll()
|
||||
if status is None:
|
||||
status = await structio.thread.run_in_worker(self._process.wait, cancellable=True)
|
||||
return status
|
||||
|
||||
async def communicate(self, input=b"") -> tuple[bytes, bytes]:
|
||||
async with structio.create_pool() as pool:
|
||||
stdout = pool.spawn(self.stdout.readall) if self.stdout else None
|
||||
stderr = pool.spawn(self.stderr.readall) if self.stderr else None
|
||||
if input:
|
||||
await self.stdin.write(input)
|
||||
await self.stdin.close()
|
||||
# Awaiting a task object waits for its completion and
|
||||
# returns its return value!
|
||||
out = b""
|
||||
err = b""
|
||||
if stdout:
|
||||
out = await stdout
|
||||
if stderr:
|
||||
err = await stderr
|
||||
return out, err
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self.stdin:
|
||||
await self.stdin.close()
|
||||
if self.stdout:
|
||||
await self.stdout.close()
|
||||
if self.stderr:
|
||||
await self.stderr.close()
|
||||
await self.wait()
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Delegate to internal process object
|
||||
return getattr(self._process, item)
|
||||
|
||||
|
||||
async def run(args, *, stdin=None, input=None, stdout=None, stderr=None, shell=False, check=False):
|
||||
"""
|
||||
Async version of subprocess.run()
|
||||
"""
|
||||
|
||||
if input:
|
||||
stdin = subprocess.PIPE
|
||||
async with Popen(args, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell) as process:
|
||||
try:
|
||||
stdout, stderr = await process.communicate(input)
|
||||
except:
|
||||
process.kill()
|
||||
raise
|
||||
|
||||
status = process.poll()
|
||||
if check and status:
|
||||
raise CalledProcessError(status, process.args, output=stdout, stderr=stderr)
|
||||
return CompletedProcess(process.args, status, stdout, stderr)
|
||||
|
||||
|
||||
async def check_output(args, *, stdin=None, stderr=None, shell=False, input=None):
|
||||
"""
|
||||
Async version of subprocess.check_output
|
||||
"""
|
||||
|
||||
out = await run(args, stdout=PIPE, stdin=stdin, stderr=stderr, shell=shell,
|
||||
check=True, input=input)
|
||||
return out.stdout
|
||||
|
|
|
@ -211,7 +211,7 @@ class MemoryReceiveChannel(ChannelReader):
|
|||
|
||||
class MemoryChannel(Channel, MemorySendChannel, MemoryReceiveChannel):
|
||||
"""
|
||||
An in-memory two-way channel between
|
||||
An in-memory, two-way channel between
|
||||
tasks with optional buffering
|
||||
"""
|
||||
|
||||
|
@ -243,11 +243,18 @@ class Semaphore:
|
|||
assert initial_size <= max_size
|
||||
self.max_size = max_size
|
||||
# We use an unbuffered memory channel to pause
|
||||
# as necessary, kind like socket.set_wakeup_fd
|
||||
# or something? Anyways I think it's pretty nifty
|
||||
# as necessary, kinda like socket.set_wakeup_fd
|
||||
# or something? Anyway I think it's pretty nifty
|
||||
self.channel: MemoryChannel = MemoryChannel(0)
|
||||
self._counter: int = initial_size
|
||||
|
||||
def __repr__(self):
|
||||
return f"<structio.Semaphore max_size={self.max_size} size={self._counter}>"
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._counter
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self):
|
||||
"""
|
||||
|
@ -264,9 +271,8 @@ class Semaphore:
|
|||
@enable_ki_protection
|
||||
async def release(self):
|
||||
"""
|
||||
Releases the semaphore if it was previously
|
||||
acquired by the caller. Raises RuntimeError
|
||||
if the semaphore is not acquired by anyone
|
||||
Releases a slot in the semaphore. Raises RuntimeError
|
||||
if there are no occupied slots to release
|
||||
"""
|
||||
|
||||
if self._counter == self.max_size:
|
||||
|
@ -289,13 +295,33 @@ class Semaphore:
|
|||
|
||||
class Lock:
|
||||
"""
|
||||
An asynchronous single-owner task lock
|
||||
An asynchronous single-owner task lock. Unlike
|
||||
the lock in threading.Thread, only the lock's
|
||||
owner can release it
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.owner: Task | None = None
|
||||
self._owner: Task | None = None
|
||||
self._sem: Semaphore = Semaphore(1)
|
||||
|
||||
@property
|
||||
def owner(self) -> Task | None:
|
||||
"""
|
||||
Returns the current owner of the lock,
|
||||
or None if the lock is not being held
|
||||
"""
|
||||
|
||||
return self._owner
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""
|
||||
Returns whether the lock is currently
|
||||
held
|
||||
"""
|
||||
|
||||
return self._sem.size == 0
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self):
|
||||
"""
|
||||
|
@ -304,7 +330,7 @@ class Lock:
|
|||
"""
|
||||
|
||||
await self._sem.acquire()
|
||||
self.owner = current_task()
|
||||
self._owner = current_task()
|
||||
|
||||
@enable_ki_protection
|
||||
async def release(self):
|
||||
|
@ -320,7 +346,7 @@ class Lock:
|
|||
raise RuntimeError("lock is not acquired")
|
||||
if current_task() is not self.owner:
|
||||
raise RuntimeError("lock can only be released by the owner")
|
||||
self.owner = None
|
||||
self._owner = None
|
||||
await self._sem.release()
|
||||
|
||||
@enable_ki_protection
|
||||
|
@ -335,7 +361,7 @@ class Lock:
|
|||
|
||||
class RLock(Lock):
|
||||
"""
|
||||
An asynchronous single-owner recursive lock.
|
||||
An asynchronous, single-owner recursive lock.
|
||||
Recursive locks have the property that their
|
||||
acquire() method can be called multiple times
|
||||
by the owner without deadlocking: each call
|
||||
|
@ -357,13 +383,32 @@ class RLock(Lock):
|
|||
await checkpoint()
|
||||
self._acquire_count += 1
|
||||
|
||||
@property
|
||||
def acquire_count(self) -> int:
|
||||
"""
|
||||
Returns the number of times acquire()
|
||||
was called by the owner (note that it
|
||||
may be zero if the lock is not being
|
||||
held)
|
||||
"""
|
||||
|
||||
return self._acquire_count
|
||||
|
||||
@enable_ki_protection
|
||||
async def release(self):
|
||||
self._acquire_count -= 1
|
||||
if self._acquire_count == 0:
|
||||
# I hate the repetition, but it's the
|
||||
# only way to make sure that a task can't
|
||||
# decrement the counter of a lock it does
|
||||
# not own
|
||||
current = current_task()
|
||||
if self.owner != current:
|
||||
await super().release()
|
||||
else:
|
||||
await checkpoint()
|
||||
self._acquire_count -= 1
|
||||
if self._acquire_count == 0:
|
||||
await super().release()
|
||||
else:
|
||||
await checkpoint()
|
||||
|
||||
|
||||
_events: dict[str, list[Callable[[Any, Any], Coroutine[Any, Any, Any]]]] = defaultdict(list)
|
||||
|
|
|
@ -319,15 +319,7 @@ async def run_in_worker(
|
|||
# we run out of slots and proceed once
|
||||
# we have more
|
||||
async with _storage.max_workers:
|
||||
# We do a little magic trick and inject the "async thread" as a
|
||||
# task in the current task pool (keep in mind structio is always
|
||||
# within some task pool, even if you don't see one explicitly. The
|
||||
# event loop has its own secret "root" task pool which is a parent to all
|
||||
# others and where the call to structio.run() as well as any other system
|
||||
# task run)
|
||||
return await current_loop().current_pool.spawn(
|
||||
_spawn_supervised_thread, sync_func, cancellable, *args
|
||||
)
|
||||
return await _spawn_supervised_thread(sync_func, cancellable, *args)
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def main():
|
||||
async with structio.create_pool():
|
||||
pass
|
||||
print("[main] Done")
|
||||
|
||||
|
||||
structio.run(main)
|
|
@ -0,0 +1,20 @@
|
|||
import structio
|
||||
import subprocess
|
||||
import shlex
|
||||
|
||||
|
||||
async def main(data: str):
|
||||
cmd = shlex.split("python3 -c 'print(input())'")
|
||||
data = data.encode(errors="ignore")
|
||||
# This will print data to stdout
|
||||
await structio.parallel.run(cmd, input=data)
|
||||
# Other option
|
||||
out = await structio.parallel.check_output(cmd, input=data)
|
||||
print(out.rstrip(b"\n") == data)
|
||||
# Other, other option :D
|
||||
process = structio.parallel.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
out, _ = await process.communicate(data)
|
||||
print(out.rstrip(b"\n") == data)
|
||||
|
||||
|
||||
structio.run(main, "owo")
|
|
@ -12,12 +12,14 @@ def fake_async_sleeper(n):
|
|||
print(f"[thread] Using old boring time.sleep :(")
|
||||
time.sleep(n)
|
||||
print(f"[thread] Slept for {time.time() - t:.2f} seconds")
|
||||
return n
|
||||
|
||||
|
||||
async def main(n):
|
||||
print(f"[main] Spawning worker thread, exiting in {n} seconds")
|
||||
t = structio.clock()
|
||||
await structio.thread.run_in_worker(fake_async_sleeper, n)
|
||||
d = await structio.thread.run_in_worker(fake_async_sleeper, n)
|
||||
assert d == n
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue