import ssl
import typing

import structio

from .base import AsyncNetworkStream, AsyncNetworkBackend, SOCKET_OPTION
from .._exceptions import (
    ConnectTimeout,
    ReadError,
    ReadTimeout,
    WriteError,
    WriteTimeout,
    ConnectError,
    map_exceptions,
    ExceptionMapping,
    PoolTimeout,
)


def _timeout_or_inf(timeout: typing.Optional[float]) -> float:
    """Return *timeout*, or ``float("inf")`` when it is ``None`` (no timeout)."""
    return float("inf") if timeout is None else timeout


class StructioStream(AsyncNetworkStream):
    """
    Structio-compatible async stream for httpx.

    Wraps a ``structio.AsyncSocket`` and translates structio's exceptions
    into the transport exceptions imported from ``_exceptions``.
    """

    def __init__(self, stream: structio.AsyncSocket) -> None:
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """
        Read up to *max_bytes* bytes from the socket.

        :param max_bytes: value passed to the socket's ``receive``.
        :param timeout: seconds to wait; ``None`` waits forever.
        :raises ReadTimeout: if the timeout expires.
        :raises ReadError: if the socket is closed, busy or broken.
        """
        exc_map: ExceptionMapping = {
            structio.TimedOut: ReadTimeout,
            structio.ResourceClosed: ReadError,
            structio.ResourceBusy: ReadError,
            structio.ResourceBroken: ReadError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(_timeout_or_inf(timeout)):
                data: bytes = await self._stream.receive(max_bytes)
        return data

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """
        Send all of *buffer*; an empty buffer is a no-op.

        :param timeout: seconds to wait; ``None`` waits forever.
        :raises WriteTimeout: if the timeout expires.
        :raises WriteError: if the socket is closed, busy or broken.
        """
        if not buffer:
            return
        exc_map: ExceptionMapping = {
            structio.TimedOut: WriteTimeout,
            structio.ResourceClosed: WriteError,
            structio.ResourceBusy: WriteError,
            structio.ResourceBroken: WriteError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(_timeout_or_inf(timeout)):
                await self._stream.send_all(data=buffer)

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """
        Upgrade this stream to TLS in place and return it.

        If wrapping the socket fails for any reason, the stream is closed
        before the (mapped) error propagates.

        :raises ConnectTimeout: if the timeout expires.
        :raises ConnectError: if the socket breaks or the TLS handshake
            fails with ``ssl.SSLError``.
        """
        exc_map: ExceptionMapping = {
            structio.TimedOut: ConnectTimeout,
            structio.ResourceBroken: ConnectError,
            # Certificate / handshake failures surface as ssl.SSLError;
            # report them as ConnectError like the other connect failures.
            ssl.SSLError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with structio.with_timeout(_timeout_or_inf(timeout)):
                    self._stream = await structio.socket.wrap_socket_with_ssl(
                        self._stream,
                        context=ssl_context,
                        server_hostname=server_hostname,
                    )
            except Exception:  # pragma: nocover
                # Don't leak the plain socket when the upgrade fails.
                await self.aclose()
                raise
        return self

    async def aclose(self) -> None:
        """Close the underlying structio socket."""
        return await self._stream.close()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return transport details for the given *info* key.

        Supported keys: ``"ssl_object"``, ``"client_addr"``,
        ``"server_addr"``, ``"socket"`` and ``"is_readable"``.
        Any other key (or ``"ssl_object"`` on a non-TLS stream) yields ``None``.
        """
        if info == "ssl_object" and hasattr(self._stream, "_sslobj"):
            return self._stream._sslobj
        if info == "client_addr":
            return self._stream.socket.getsockname()
        if info == "server_addr":
            return self._stream.socket.getpeername()
        if info == "socket":
            return self._stream
        if info == "is_readable":
            return self._stream.is_readable()
        return None


class StructioBackend(AsyncNetworkBackend):
    """Network backend that opens connections using structio."""

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """
        Open a TCP connection to *host*:*port* and wrap it in a StructioStream.

        Each entry of *socket_options* is applied with ``setsockopt(*option)``.
        If applying an option fails, the newly connected socket is closed
        before the error propagates (mapped to ``ConnectError`` when it is an
        ``OSError``).

        :raises ConnectTimeout: if the timeout expires while connecting.
        :raises ConnectError: on connection or socket-option failures.
        """
        if socket_options is None:
            socket_options = []
        exc_map: ExceptionMapping = {
            structio.TimedOut: ConnectTimeout,
            structio.ResourceBusy: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(_timeout_or_inf(timeout)):
                stream: structio.AsyncSocket = await structio.socket.connect_tcp_socket(
                    host=host, port=port, source_address=local_address
                )
            # setsockopt is called synchronously, so it needs no deadline.
            try:
                for option in socket_options:
                    stream.setsockopt(*option)
            except Exception:
                # Previously the connected socket leaked here.
                await stream.close()
                raise
        return StructioStream(stream)

    # TODO: connect_unix_socket

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for *seconds* using structio."""
        await structio.sleep(seconds)