205 lines
5.6 KiB
Python
205 lines
5.6 KiB
Python
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
|
|
import io
|
|
import os
|
|
from structio.core.syscalls import (
|
|
checkpoint,
|
|
wait_readable,
|
|
wait_writable,
|
|
closing,
|
|
release,
|
|
)
|
|
from structio.exceptions import ResourceClosed
|
|
from structio.abc import AsyncResource
|
|
|
|
# Exception groups meaning "the operation could not complete right now".
# When the ssl module is importable, its want-read/want-write errors are
# included alongside the plain OS-level ones; otherwise SSLSocket is None.
try:
    from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
except ImportError:
    # No ssl support available: fall back to the OS-level signals only
    SSLSocket = None
    WantRead = (BlockingIOError, InterruptedError)
    WantWrite = (BlockingIOError, InterruptedError)
else:
    WantRead = (BlockingIOError, SSLWantReadError, InterruptedError)
    WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError)
|
|
|
|
|
|
class FdWrapper:
    """
    Lightweight holder for a file descriptor.

    Integer file descriptors may be recycled by the operating
    system, whereas instances of this class are not. Registering
    wrappers instead of raw descriptors therefore lets the event
    loop optimize I/O event registration safely: a stale wrapper
    that is still lying around cannot be mistaken for a new file
    that happens to reuse the same descriptor. As a bonus, every
    object registered in the selector is guaranteed to expose
    fileno(), whether the wrapped fd is an int or something else
    """

    __slots__ = ("fd",)

    def __init__(self, fd):
        self.fd = fd

    def __repr__(self):
        return f"<fd={self.fd!r}>"

    def __eq__(self, other):
        # Only another wrapper can compare equal, and only when
        # both report the same descriptor
        return isinstance(other, FdWrapper) and self.fileno() == other.fileno()

    def __hash__(self):
        # Hash exactly like the wrapped descriptor
        return self.fd.__hash__()

    def __int__(self):
        # Allows int(wrapper) to yield the wrapped descriptor
        return self.fd

    def fileno(self):
        """
        Returns the wrapped file descriptor
        """

        return self.fd
|
|
|
|
|
|
class AsyncStream(AsyncResource):
    """
    A generic asynchronous stream over
    a file-like object, with buffering.

    Subclasses are expected to override _read() and
    write(), which raise NotImplementedError here
    """

    def __init__(self, fileobj):
        """
        Wraps the given file-like object. Its fileno() is
        captured once, inside an FdWrapper, and an empty
        internal read buffer is created
        """

        self.fileobj = fileobj
        self._fd = FdWrapper(self.fileobj.fileno())
        self._buf = bytearray()

    async def _read(self, size: int = -1) -> bytes:
        """
        Low-level read hook: must be provided by subclasses
        """

        raise NotImplementedError()

    async def write(self, data):
        """
        Write hook: must be provided by subclasses
        """

        raise NotImplementedError()

    async def read(self, size: int = -1):
        """
        Reads up to size bytes from the
        given stream. If size == -1, read
        as much as possible.

        Buffered data is served first: when the internal
        buffer is not empty, at most size bytes (or all of
        it if size == -1) are taken from the buffer and
        _read() is not called. When the buffer is empty the
        request is forwarded to _read() unchanged

        :raises ValueError: if size is smaller than -1
        """

        if size < -1:
            raise ValueError("size must be -1 or a positive integer")
        buf = self._buf
        if not buf:
            # Forward the caller's size as-is. Replacing -1 with
            # len(buf), which is 0 here, would turn "read as much
            # as possible" into a zero-byte read
            return await self._read(size)
        if size == -1 or len(buf) <= size:
            # The whole buffer fits in the request: drain it
            data = bytes(buf)
            buf.clear()
        else:
            # Hand out only the requested prefix, keep the rest
            data = bytes(buf[:size])
            del buf[:size]
        return data

    # Yes I stole this from curio. Sue me.
    async def readall(self):
        """
        Reads until read() returns an empty chunk and
        returns everything that was read, joined together
        as a single bytes object
        """

        chunks = []
        maxread = 65536
        if self._buf:
            chunks.append(bytes(self._buf))
            self._buf.clear()
        while True:
            chunk = await self.read(maxread)
            if not chunk:
                return b"".join(chunks)
            chunks.append(chunk)
            if len(chunk) == maxread:
                # The request was filled completely, so ask
                # for twice as much next time
                maxread *= 2

    async def flush(self):
        """
        Does nothing in the base class
        """

        pass

    async def close(self):
        """
        Closes the stream asynchronously. Does nothing
        if the stream has already been closed
        """

        if self.fileno() == -1:
            return
        await self.flush()
        # NOTE(review): closing() and release() come from
        # structio.core.syscalls; presumably they tell the event
        # loop the descriptor is going away and drop its
        # registration -- confirm against that module
        await closing(self._fd)
        await release(self._fd)
        self.fileobj.close()
        self.fileobj = None
        # From now on fileno() reports -1, which is how
        # the rest of the class detects a closed stream
        self._fd = -1
        await checkpoint()

    def fileno(self):
        """
        Returns the stream's file descriptor as an
        int, or -1 once the stream has been closed
        """

        return int(self._fd)

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        if self.fileno() != -1:
            await self.close()

    def __repr__(self):
        return f"AsyncStream({self.fileobj})"
|
|
|
|
|
|
class FileStream(AsyncStream):
    """
    A stream wrapper around a binary file-like object.
    The underlying file descriptor is put into non-blocking
    mode (on platforms where os.set_blocking is available)
    """

    def __init__(self, fileobj):
        """
        Wraps the given binary file-like object

        :raises TypeError: if fileobj is a text-mode file
        """

        if isinstance(fileobj, io.TextIOBase):
            raise TypeError("only binary mode files can be streamed")
        super().__init__(fileobj)
        if hasattr(os, "set_blocking"):
            os.set_blocking(self.fileno(), False)

    async def _read(self, size: int = -1) -> bytes:
        """
        Reads from the underlying file, waiting for it to
        become readable whenever the read would block
        """

        while True:
            try:
                data = self.fileobj.read(size)
                if data is None:
                    # Files in non-blocking mode don't always
                    # raise a blocking I/O exception and can
                    # return None instead, so we account for
                    # that here
                    raise BlockingIOError()
                return data
            except WantRead:
                await wait_readable(self._fd)

    async def write(self, data):
        """
        Writes all of data to the underlying file, waiting
        for it to become writable whenever the write would
        block. Partial writes are resumed from where they
        stopped
        """

        # We use a memory view so that
        # slicing doesn't copy any memory
        mem = memoryview(data)
        while mem:
            try:
                # Write only what is still pending. Passing the
                # original data again would re-send its beginning
                # after every partial write
                written = self.fileobj.write(mem)
                if written is None:
                    # Same convention as in _read(): None means
                    # the operation would have blocked
                    raise BlockingIOError()
                mem = mem[written:]
            except WantWrite:
                await wait_writable(self._fd)

    async def flush(self):
        """
        Flushes the underlying file, waiting for it to become
        writable or readable as needed. Does nothing if the
        stream has already been closed
        """

        if self.fileno() == -1:
            return
        while True:
            try:
                return self.fileobj.flush()
            except WantWrite:
                await wait_writable(self._fd)
            except WantRead:
                await wait_readable(self._fd)
|