119 lines
4.2 KiB
Python
119 lines
4.2 KiB
Python
from .base import AsyncNetworkStream, AsyncNetworkBackend, SOCKET_OPTION
|
|
from .._exceptions import (ConnectTimeout, ReadError, ReadTimeout,
|
|
WriteError, WriteTimeout, ConnectError,
|
|
map_exceptions, ExceptionMapping, PoolTimeout)
|
|
|
|
import structio
|
|
import ssl
|
|
import typing
|
|
|
|
|
|
class StructioStream(AsyncNetworkStream):
    """
    Structio-compatible async stream for httpx.

    Wraps a ``structio.AsyncSocket`` and maps structio's exceptions onto the
    httpcore exception types (ReadTimeout, WriteError, ConnectError, ...)
    raised by each operation below.
    """

    def __init__(self, stream: structio.AsyncSocket):
        # The wrapped socket; start_tls() replaces it with the TLS-wrapped one.
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """
        Receive at most ``max_bytes`` bytes from the socket.

        :param max_bytes: upper bound on the number of bytes requested
        :param timeout: seconds to wait; ``None`` means wait without limit
        :return: the bytes returned by the socket's ``receive`` call
        :raises ReadTimeout: if the timeout expires (structio.TimedOut)
        :raises ReadError: if the socket is closed, busy or broken
        """
        # Normalise None to +inf so one with_timeout() call covers both cases.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            structio.TimedOut: ReadTimeout,
            structio.ResourceClosed: ReadError,
            structio.ResourceBusy: ReadError,
            structio.ResourceBroken: ReadError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(timeout_or_inf):
                data: bytes = await self._stream.receive(max_bytes)
                return data

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """
        Send the whole of ``buffer`` over the socket.

        An empty buffer is a no-op and does not touch the socket.

        :param buffer: data to send
        :param timeout: seconds to wait; ``None`` means wait without limit
        :raises WriteTimeout: if the timeout expires (structio.TimedOut)
        :raises WriteError: if the socket is closed, busy or broken
        """
        if not buffer:
            return

        # Normalise None to +inf so one with_timeout() call covers both cases.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            structio.TimedOut: WriteTimeout,
            structio.ResourceClosed: WriteError,
            structio.ResourceBusy: WriteError,
            structio.ResourceBroken: WriteError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(timeout_or_inf):
                await self._stream.send_all(data=buffer)

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """
        Upgrade this stream to TLS in place and return it.

        The wrapped socket is replaced by the result of
        ``structio.socket.wrap_socket_with_ssl``. If wrapping raises any
        ``Exception``, the original socket is closed before the error
        propagates.

        :param ssl_context: context used to wrap the socket
        :param server_hostname: hostname passed through to the TLS wrapper
        :param timeout: seconds allowed for the upgrade; ``None`` means no limit
        :return: ``self``, now wrapping the TLS socket
        :raises ConnectTimeout: if the timeout expires (structio.TimedOut)
        :raises ConnectError: if the socket is broken (structio.ResourceBroken)
        """
        # Normalise None to +inf so one with_timeout() call covers both cases.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            structio.TimedOut: ConnectTimeout,
            structio.ResourceBroken: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                with structio.with_timeout(timeout_or_inf):
                    self._stream = await structio.socket.wrap_socket_with_ssl(
                        self._stream,
                        context=ssl_context,
                        server_hostname=server_hostname,
                    )
            except Exception:  # pragma: nocover
                # Don't leave the plain connection open after a failed upgrade.
                # A bare ``raise`` re-raises the original error with its
                # traceback intact.
                await self.aclose()
                raise
        return self

    async def aclose(self) -> None:
        """Close the wrapped socket."""
        await self._stream.close()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for an httpcore ``get_extra_info`` key.

        Handled keys:

        * ``"ssl_object"``: the wrapped socket's ``_sslobj``, if it has one
          (presumably only after start_tls; otherwise ``None``).
        * ``"client_addr"`` / ``"server_addr"``: local / peer address from the
          wrapped socket's ``.socket``.
        * ``"socket"``: the wrapped structio socket object.
        * ``"is_readable"``: result of the wrapped socket's ``is_readable()``.

        Any other key yields ``None``.
        """
        if info == "ssl_object" and hasattr(self._stream, "_sslobj"):
            return self._stream._sslobj
        if info == "client_addr":
            return self._stream.socket.getsockname()
        if info == "server_addr":
            return self._stream.socket.getpeername()
        if info == "socket":
            # NOTE(review): this returns the structio wrapper, not the raw
            # socket reached via ``.socket`` above; confirm callers expect that.
            return self._stream
        if info == "is_readable":
            return self._stream.is_readable()
        return None
|
|
|
|
|
|
class StructioBackend(AsyncNetworkBackend):
    """
    httpcore network backend that opens its connections with structio.
    """

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """
        Open a TCP connection to ``host``:``port`` and wrap it in a StructioStream.

        :param host: remote host to connect to
        :param port: remote port to connect to
        :param timeout: seconds allowed for the connection attempt; ``None``
            means no limit
        :param local_address: forwarded to structio as ``source_address``
        :param socket_options: tuples applied one by one via
            ``setsockopt(*option)`` once the connection is established
        :return: a StructioStream wrapping the connected socket
        :raises ConnectTimeout: if the timeout expires (structio.TimedOut)
        :raises ConnectError: on structio.ResourceBusy or any OSError,
            including a rejected socket option
        """
        if socket_options is None:
            socket_options = []
        # Normalise None to +inf so one with_timeout() call covers both cases.
        timeout_or_inf = float("inf") if timeout is None else timeout
        exc_map: ExceptionMapping = {
            structio.TimedOut: ConnectTimeout,
            structio.ResourceBusy: ConnectError,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            with structio.with_timeout(timeout_or_inf):
                stream: structio.AsyncSocket = await structio.socket.connect_tcp_socket(
                    host=host, port=port, source_address=local_address
                )
            # setsockopt is synchronous, so it does not need the connect
            # deadline. If an option is rejected, close the already-connected
            # socket instead of leaking it, then let the error be mapped above.
            try:
                for option in socket_options:
                    stream.setsockopt(*option)
            except Exception:
                await stream.close()
                raise
        return StructioStream(stream)

    # TODO: connect_unix_socket

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` using structio."""
        await structio.sleep(seconds)
|