146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
import ssl
|
|
import typing
|
|
|
|
import anyio
|
|
|
|
from .._exceptions import (
|
|
ConnectError,
|
|
ConnectTimeout,
|
|
ReadError,
|
|
ReadTimeout,
|
|
WriteError,
|
|
WriteTimeout,
|
|
map_exceptions,
|
|
)
|
|
from .._utils import is_socket_readable
|
|
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
|
|
|
|
|
class AnyIOStream(AsyncNetworkStream):
    """Network stream that delegates all I/O to a wrapped anyio byte stream."""

    def __init__(self, stream: anyio.abc.ByteStream) -> None:
        """Keep a reference to the anyio stream that performs the actual I/O."""
        self._stream = stream

    async def read(
        self, max_bytes: int, timeout: typing.Optional[float] = None
    ) -> bytes:
        """Receive at most ``max_bytes`` bytes from the wrapped stream.

        The receive runs under ``anyio.fail_after(timeout)``. ``TimeoutError``
        is translated to ``ReadTimeout`` and anyio's broken/closed resource
        errors to ``ReadError`` through ``map_exceptions``. An
        ``anyio.EndOfStream`` is reported as an empty ``bytes`` object rather
        than as an exception.
        """
        translations = {
            TimeoutError: ReadTimeout,
            anyio.BrokenResourceError: ReadError,
            anyio.ClosedResourceError: ReadError,
        }
        with map_exceptions(translations), anyio.fail_after(timeout):
            try:
                return await self._stream.receive(max_bytes=max_bytes)
            except anyio.EndOfStream:  # pragma: nocover
                return b""

    async def write(
        self, buffer: bytes, timeout: typing.Optional[float] = None
    ) -> None:
        """Send ``buffer`` on the wrapped stream; a falsy buffer is a no-op.

        The send runs under ``anyio.fail_after(timeout)``. ``TimeoutError`` is
        translated to ``WriteTimeout`` and anyio's broken/closed resource
        errors to ``WriteError`` through ``map_exceptions``.
        """
        if not buffer:
            return

        translations = {
            TimeoutError: WriteTimeout,
            anyio.BrokenResourceError: WriteError,
            anyio.ClosedResourceError: WriteError,
        }
        with map_exceptions(translations), anyio.fail_after(timeout):
            await self._stream.send(item=buffer)

    async def aclose(self) -> None:
        """Close the wrapped stream."""
        await self._stream.aclose()

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Wrap the current stream in client-side TLS and return a new stream.

        The handshake runs under ``anyio.fail_after(timeout)``. If it raises
        any ``Exception``, this stream is closed before the error is re-raised.
        ``TimeoutError`` is translated to ``ConnectTimeout`` and
        ``anyio.BrokenResourceError`` to ``ConnectError`` through
        ``map_exceptions``.
        """
        translations = {
            TimeoutError: ConnectTimeout,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(translations):
            try:
                with anyio.fail_after(timeout):
                    tls_stream = await anyio.streams.tls.TLSStream.wrap(
                        self._stream,
                        ssl_context=ssl_context,
                        hostname=server_hostname,
                        standard_compatible=False,
                        server_side=False,
                    )
            except Exception as error:  # pragma: nocover
                # Do not leave the underlying connection open after a failed
                # handshake.
                await self.aclose()
                raise error
        return AnyIOStream(tls_stream)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return a piece of connection metadata, or ``None`` if unavailable.

        Recognised keys are ``"ssl_object"``, ``"client_addr"``,
        ``"server_addr"``, ``"socket"`` and ``"is_readable"``. Any other key
        yields ``None``, as does a recognised key whose typed attribute the
        wrapped stream does not provide.
        """
        if info == "ssl_object":
            attribute = anyio.streams.tls.TLSAttribute.ssl_object
        elif info == "client_addr":
            attribute = anyio.abc.SocketAttribute.local_address
        elif info == "server_addr":
            attribute = anyio.abc.SocketAttribute.remote_address
        elif info == "socket":
            attribute = anyio.abc.SocketAttribute.raw_socket
        elif info == "is_readable":
            # Readability is derived from the raw socket rather than being a
            # typed attribute of its own.
            raw_sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
            return is_socket_readable(raw_sock)
        else:
            return None
        return self._stream.extra(attribute, None)
|
|
|
|
|
|
class AnyIOBackend(AsyncNetworkBackend):
    """Network backend that opens connections through anyio."""

    @staticmethod
    async def _apply_socket_options(
        stream: anyio.abc.ByteStream,
        socket_options: typing.Iterable[SOCKET_OPTION],
    ) -> None:
        """Apply each option tuple to the raw socket behind ``stream``.

        The raw socket is obtained through the public typed attribute
        ``SocketAttribute.raw_socket`` and is only looked up when there is at
        least one option to apply. If the lookup or a ``setsockopt`` call
        fails, the stream is closed before the error propagates so that the
        freshly opened connection is not leaked.
        """
        options = tuple(socket_options)
        if not options:
            return
        try:
            raw_socket = stream.extra(anyio.abc.SocketAttribute.raw_socket)
            for option in options:
                raw_socket.setsockopt(*option)  # pragma: no cover
        except Exception:  # pragma: no cover
            # Close without waiting: the surrounding timeout scope may
            # already have expired.
            await anyio.aclose_forcefully(stream)
            raise

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Open a TCP connection and return it wrapped in an ``AnyIOStream``.

        Connecting and applying ``socket_options`` both run under
        ``anyio.fail_after(timeout)``. ``TimeoutError`` is translated to
        ``ConnectTimeout``; ``OSError`` and ``anyio.BrokenResourceError`` are
        translated to ``ConnectError`` through ``map_exceptions``.
        ``local_address`` is passed to anyio as the local host to bind to.
        """
        if socket_options is None:
            socket_options = []  # pragma: no cover
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_tcp(
                    remote_host=host,
                    remote_port=port,
                    local_host=local_address,
                )
                # By default TCP sockets opened in `asyncio` include TCP_NODELAY.
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:  # pragma: nocover
        """Open a UNIX domain socket connection wrapped in an ``AnyIOStream``.

        Connecting and applying ``socket_options`` both run under
        ``anyio.fail_after(timeout)``. ``TimeoutError`` is translated to
        ``ConnectTimeout``; ``OSError`` and ``anyio.BrokenResourceError`` are
        translated to ``ConnectError`` through ``map_exceptions``.
        """
        if socket_options is None:
            socket_options = []
        exc_map = {
            TimeoutError: ConnectTimeout,
            OSError: ConnectError,
            anyio.BrokenResourceError: ConnectError,
        }
        with map_exceptions(exc_map):
            with anyio.fail_after(timeout):
                stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
                await self._apply_socket_options(stream, socket_options)
        return AnyIOStream(stream)

    async def sleep(self, seconds: float) -> None:
        """Suspend the current task for ``seconds`` seconds via anyio."""
        await anyio.sleep(seconds)  # pragma: nocover
|