Added patched httpcore to make the test run properly
This commit is contained in:
parent
9b6735b924
commit
3f48c74346
|
@ -0,0 +1,139 @@
|
|||
from ._api import request, stream
|
||||
from ._async import (
|
||||
AsyncConnectionInterface,
|
||||
AsyncConnectionPool,
|
||||
AsyncHTTP2Connection,
|
||||
AsyncHTTP11Connection,
|
||||
AsyncHTTPConnection,
|
||||
AsyncHTTPProxy,
|
||||
AsyncSOCKSProxy,
|
||||
)
|
||||
from ._backends.base import (
|
||||
SOCKET_OPTION,
|
||||
AsyncNetworkBackend,
|
||||
AsyncNetworkStream,
|
||||
NetworkBackend,
|
||||
NetworkStream,
|
||||
)
|
||||
from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream
|
||||
from ._backends.sync import SyncBackend
|
||||
from ._exceptions import (
|
||||
ConnectError,
|
||||
ConnectionNotAvailable,
|
||||
ConnectTimeout,
|
||||
LocalProtocolError,
|
||||
NetworkError,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ProxyError,
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
RemoteProtocolError,
|
||||
TimeoutException,
|
||||
UnsupportedProtocol,
|
||||
WriteError,
|
||||
WriteTimeout,
|
||||
)
|
||||
from ._models import URL, Origin, Request, Response
|
||||
from ._ssl import default_ssl_context
|
||||
from ._sync import (
|
||||
ConnectionInterface,
|
||||
ConnectionPool,
|
||||
HTTP2Connection,
|
||||
HTTP11Connection,
|
||||
HTTPConnection,
|
||||
HTTPProxy,
|
||||
SOCKSProxy,
|
||||
)
|
||||
|
||||
# The 'httpcore.AnyIOBackend' class is conditional on 'anyio' being installed.
|
||||
try:
|
||||
from ._backends.anyio import AnyIOBackend
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class AnyIOBackend: # type: ignore
|
||||
def __init__(self, *args, **kwargs): # type: ignore
|
||||
msg = (
|
||||
"Attempted to use 'httpcore.AnyIOBackend' but 'anyio' is not installed."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
# The 'httpcore.TrioBackend' class is conditional on 'trio' being installed.
|
||||
try:
|
||||
from ._backends.trio import TrioBackend
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class TrioBackend: # type: ignore
|
||||
def __init__(self, *args, **kwargs): # type: ignore
|
||||
msg = "Attempted to use 'httpcore.TrioBackend' but 'trio' is not installed."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# top-level requests
|
||||
"request",
|
||||
"stream",
|
||||
# models
|
||||
"Origin",
|
||||
"URL",
|
||||
"Request",
|
||||
"Response",
|
||||
# async
|
||||
"AsyncHTTPConnection",
|
||||
"AsyncConnectionPool",
|
||||
"AsyncHTTPProxy",
|
||||
"AsyncHTTP11Connection",
|
||||
"AsyncHTTP2Connection",
|
||||
"AsyncConnectionInterface",
|
||||
"AsyncSOCKSProxy",
|
||||
# sync
|
||||
"HTTPConnection",
|
||||
"ConnectionPool",
|
||||
"HTTPProxy",
|
||||
"HTTP11Connection",
|
||||
"HTTP2Connection",
|
||||
"ConnectionInterface",
|
||||
"SOCKSProxy",
|
||||
# network backends, implementations
|
||||
"SyncBackend",
|
||||
"AnyIOBackend",
|
||||
"TrioBackend",
|
||||
# network backends, mock implementations
|
||||
"AsyncMockBackend",
|
||||
"AsyncMockStream",
|
||||
"MockBackend",
|
||||
"MockStream",
|
||||
# network backends, interface
|
||||
"AsyncNetworkStream",
|
||||
"AsyncNetworkBackend",
|
||||
"NetworkStream",
|
||||
"NetworkBackend",
|
||||
# util
|
||||
"default_ssl_context",
|
||||
"SOCKET_OPTION",
|
||||
# exceptions
|
||||
"ConnectionNotAvailable",
|
||||
"ProxyError",
|
||||
"ProtocolError",
|
||||
"LocalProtocolError",
|
||||
"RemoteProtocolError",
|
||||
"UnsupportedProtocol",
|
||||
"TimeoutException",
|
||||
"PoolTimeout",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
"WriteTimeout",
|
||||
"NetworkError",
|
||||
"ConnectError",
|
||||
"ReadError",
|
||||
"WriteError",
|
||||
]
|
||||
|
||||
__version__ = "0.17.3"
|
||||
|
||||
|
||||
__locals = locals()
|
||||
for __name in __all__:
|
||||
if not __name.startswith("__"):
|
||||
setattr(__locals[__name], "__module__", "httpcore") # noqa
|
|
@ -0,0 +1,92 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import Iterator, Optional, Union
|
||||
|
||||
from ._models import URL, Extensions, HeaderTypes, Response
|
||||
from ._sync.connection_pool import ConnectionPool
|
||||
|
||||
|
||||
def request(
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Sends an HTTP request, returning the response.
|
||||
|
||||
```
|
||||
response = httpcore.request("GET", "https://www.example.com/")
|
||||
```
|
||||
|
||||
Arguments:
|
||||
method: The HTTP method for the request. Typically one of `"GET"`,
|
||||
`"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`.
|
||||
url: The URL of the HTTP request. Either as an instance of `httpcore.URL`,
|
||||
or as str/bytes.
|
||||
headers: The HTTP request headers. Either as a dictionary of str/bytes,
|
||||
or as a list of two-tuples of str/bytes.
|
||||
content: The content of the request body. Either as bytes,
|
||||
or as a bytes iterator.
|
||||
extensions: A dictionary of optional extra information included on the request.
|
||||
Possible keys include `"timeout"`.
|
||||
|
||||
Returns:
|
||||
An instance of `httpcore.Response`.
|
||||
"""
|
||||
with ConnectionPool() as pool:
|
||||
return pool.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stream(
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> Iterator[Response]:
|
||||
"""
|
||||
Sends an HTTP request, returning the response within a content manager.
|
||||
|
||||
```
|
||||
with httpcore.stream("GET", "https://www.example.com/") as response:
|
||||
...
|
||||
```
|
||||
|
||||
When using the `stream()` function, the body of the response will not be
|
||||
automatically read. If you want to access the response body you should
|
||||
either use `content = response.read()`, or `for chunk in response.iter_content()`.
|
||||
|
||||
Arguments:
|
||||
method: The HTTP method for the request. Typically one of `"GET"`,
|
||||
`"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`.
|
||||
url: The URL of the HTTP request. Either as an instance of `httpcore.URL`,
|
||||
or as str/bytes.
|
||||
headers: The HTTP request headers. Either as a dictionary of str/bytes,
|
||||
or as a list of two-tuples of str/bytes.
|
||||
content: The content of the request body. Either as bytes,
|
||||
or as a bytes iterator.
|
||||
extensions: A dictionary of optional extra information included on the request.
|
||||
Possible keys include `"timeout"`.
|
||||
|
||||
Returns:
|
||||
An instance of `httpcore.Response`.
|
||||
"""
|
||||
with ConnectionPool() as pool:
|
||||
with pool.stream(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
) as response:
|
||||
yield response
|
|
@ -0,0 +1,39 @@
|
|||
from .connection import AsyncHTTPConnection
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .http_proxy import AsyncHTTPProxy
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
try:
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class AsyncHTTP2Connection: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use http2 support, but the `h2` package is not "
|
||||
"installed. Use 'pip install httpcore[http2]'."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from .socks_proxy import AsyncSOCKSProxy
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class AsyncSOCKSProxy: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use SOCKS support, but the `socksio` package is not "
|
||||
"installed. Use 'pip install httpcore[socks]'."
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AsyncHTTPConnection",
|
||||
"AsyncConnectionPool",
|
||||
"AsyncHTTPProxy",
|
||||
"AsyncHTTP11Connection",
|
||||
"AsyncHTTP2Connection",
|
||||
"AsyncConnectionInterface",
|
||||
"AsyncSOCKSProxy",
|
||||
]
|
|
@ -0,0 +1,215 @@
|
|||
import itertools
|
||||
import logging
|
||||
import ssl
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Iterator, Optional, Type
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
||||
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
|
||||
from .._models import Origin, Request, Response
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.connection")
|
||||
|
||||
|
||||
def exponential_backoff(factor: float) -> Iterator[float]:
|
||||
yield 0
|
||||
for n in itertools.count(2):
|
||||
yield factor * (2 ** (n - 2))
|
||||
|
||||
|
||||
class AsyncHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[AsyncNetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._network_backend: AsyncNetworkBackend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connection: Optional[AsyncConnectionInterface] = None
|
||||
self._connect_failed: bool = False
|
||||
self._request_lock = AsyncLock()
|
||||
self._socket_options = socket_options
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
|
||||
)
|
||||
|
||||
async with self._request_lock:
|
||||
if self._connection is None:
|
||||
try:
|
||||
stream = await self._connect(request)
|
||||
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
elif not self._connection.is_available():
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
async def _connect(self, request: Request) -> AsyncNetworkStream:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
sni_hostname = request.extensions.get("sni_hostname", None)
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
retries_left = self._retries
|
||||
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self._uds is None:
|
||||
kwargs = {
|
||||
"host": self._origin.host.decode("ascii"),
|
||||
"port": self._origin.port,
|
||||
"local_address": self._local_address,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
async with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
else:
|
||||
kwargs = {
|
||||
"path": self._uds,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
async with Trace(
|
||||
"connect_unix_socket", logger, request, kwargs
|
||||
) as trace:
|
||||
stream = await self._network_backend.connect_unix_socket(
|
||||
**kwargs
|
||||
)
|
||||
trace.return_value = stream
|
||||
|
||||
if self._origin.scheme == b"https":
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": sni_hostname
|
||||
or self._origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
return stream
|
||||
except (ConnectError, ConnectTimeout):
|
||||
if retries_left <= 0:
|
||||
raise
|
||||
retries_left -= 1
|
||||
delay = next(delays)
|
||||
async with Trace("retry", logger, request, kwargs) as trace:
|
||||
await self._network_backend.sleep(delay)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._connection is not None:
|
||||
async with Trace("close", logger, None, {}):
|
||||
await self._connection.aclose()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None:
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None:
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> "AsyncHTTPConnection":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
|
@ -0,0 +1,356 @@
|
|||
import ssl
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
|
||||
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
|
||||
from .connection import AsyncHTTPConnection
|
||||
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
|
||||
|
||||
|
||||
class RequestStatus:
|
||||
def __init__(self, request: Request):
|
||||
self.request = request
|
||||
self.connection: Optional[AsyncConnectionInterface] = None
|
||||
self._connection_acquired = AsyncEvent()
|
||||
|
||||
def set_connection(self, connection: AsyncConnectionInterface) -> None:
|
||||
assert self.connection is None
|
||||
self.connection = connection
|
||||
self._connection_acquired.set()
|
||||
|
||||
def unset_connection(self) -> None:
|
||||
assert self.connection is not None
|
||||
self.connection = None
|
||||
self._connection_acquired = AsyncEvent()
|
||||
|
||||
async def wait_for_connection(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> AsyncConnectionInterface:
|
||||
if self.connection is None:
|
||||
await self._connection_acquired.wait(timeout=timeout)
|
||||
assert self.connection is not None
|
||||
return self.connection
|
||||
|
||||
|
||||
class AsyncConnectionPool(AsyncRequestInterface):
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
max_connections: Optional[int] = 10,
|
||||
max_keepalive_connections: Optional[int] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[AsyncNetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish a
|
||||
connection.
|
||||
local_address: Local address to connect from. Can also be used to connect
|
||||
using a particular address family. Using `local_address="0.0.0.0"`
|
||||
will connect using an `AF_INET` address (IPv4), while using
|
||||
`local_address="::"` will connect using an `AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
socket_options: Socket options that have to be included
|
||||
in the TCP socket when the connection was established.
|
||||
"""
|
||||
self._ssl_context = ssl_context
|
||||
|
||||
self._max_connections = (
|
||||
sys.maxsize if max_connections is None else max_connections
|
||||
)
|
||||
self._max_keepalive_connections = (
|
||||
sys.maxsize
|
||||
if max_keepalive_connections is None
|
||||
else max_keepalive_connections
|
||||
)
|
||||
self._max_keepalive_connections = min(
|
||||
self._max_connections, self._max_keepalive_connections
|
||||
)
|
||||
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._pool: List[AsyncConnectionInterface] = []
|
||||
self._requests: List[RequestStatus] = []
|
||||
self._pool_lock = AsyncLock()
|
||||
self._network_backend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._socket_options = socket_options
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
return AsyncHTTPConnection(
|
||||
origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
retries=self._retries,
|
||||
local_address=self._local_address,
|
||||
uds=self._uds,
|
||||
network_backend=self._network_backend,
|
||||
socket_options=self._socket_options,
|
||||
)
|
||||
|
||||
@property
|
||||
def connections(self) -> List[AsyncConnectionInterface]:
|
||||
"""
|
||||
Return a list of the connections currently in the pool.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
>>> pool.connections
|
||||
[
|
||||
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
|
||||
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> ,
|
||||
<AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
|
||||
]
|
||||
```
|
||||
"""
|
||||
return list(self._pool)
|
||||
|
||||
async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
|
||||
"""
|
||||
Attempt to provide a connection that can handle the given origin.
|
||||
"""
|
||||
origin = status.request.url.origin
|
||||
|
||||
# If there are queued requests in front of us, then don't acquire a
|
||||
# connection. We handle requests strictly in order.
|
||||
waiting = [s for s in self._requests if s.connection is None]
|
||||
if waiting and waiting[0] is not status:
|
||||
return False
|
||||
|
||||
# Reuse an existing connection if one is currently available.
|
||||
for idx, connection in enumerate(self._pool):
|
||||
if connection.can_handle_request(origin) and connection.is_available():
|
||||
self._pool.pop(idx)
|
||||
self._pool.insert(0, connection)
|
||||
status.set_connection(connection)
|
||||
return True
|
||||
|
||||
# If the pool is currently full, attempt to close one idle connection.
|
||||
if len(self._pool) >= self._max_connections:
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.is_idle():
|
||||
await connection.aclose()
|
||||
self._pool.pop(idx)
|
||||
break
|
||||
|
||||
# If the pool is still full, then we cannot acquire a connection.
|
||||
if len(self._pool) >= self._max_connections:
|
||||
return False
|
||||
|
||||
# Otherwise create a new connection.
|
||||
connection = self.create_connection(origin)
|
||||
self._pool.insert(0, connection)
|
||||
status.set_connection(connection)
|
||||
return True
|
||||
|
||||
async def _close_expired_connections(self) -> None:
|
||||
"""
|
||||
Clean up the connection pool by closing off any connections that have expired.
|
||||
"""
|
||||
# Close any connections that have expired their keep-alive time.
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.has_expired():
|
||||
await connection.aclose()
|
||||
self._pool.pop(idx)
|
||||
|
||||
# If the pool size exceeds the maximum number of allowed keep-alive connections,
|
||||
# then close off idle connections as required.
|
||||
pool_size = len(self._pool)
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.is_idle() and pool_size > self._max_keepalive_connections:
|
||||
await connection.aclose()
|
||||
self._pool.pop(idx)
|
||||
pool_size -= 1
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
"""
|
||||
Send an HTTP request, and return an HTTP response.
|
||||
|
||||
This is the core implementation that is called into by `.request()` or `.stream()`.
|
||||
"""
|
||||
scheme = request.url.scheme.decode()
|
||||
if scheme == "":
|
||||
raise UnsupportedProtocol(
|
||||
"Request URL is missing an 'http://' or 'https://' protocol."
|
||||
)
|
||||
if scheme not in ("http", "https", "ws", "wss"):
|
||||
raise UnsupportedProtocol(
|
||||
f"Request URL has an unsupported protocol '{scheme}://'."
|
||||
)
|
||||
|
||||
status = RequestStatus(request)
|
||||
|
||||
async with self._pool_lock:
|
||||
self._requests.append(status)
|
||||
await self._close_expired_connections()
|
||||
await self._attempt_to_acquire_connection(status)
|
||||
|
||||
while True:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("pool", None)
|
||||
try:
|
||||
connection = await status.wait_for_connection(timeout=timeout)
|
||||
except BaseException as exc:
|
||||
# If we timeout here, or if the task is cancelled, then make
|
||||
# sure to remove the request from the queue before bubbling
|
||||
# up the exception.
|
||||
async with self._pool_lock:
|
||||
# Ensure only remove when task exists.
|
||||
if status in self._requests:
|
||||
self._requests.remove(status)
|
||||
raise exc
|
||||
|
||||
try:
|
||||
response = await connection.handle_async_request(request)
|
||||
except ConnectionNotAvailable:
|
||||
# The ConnectionNotAvailable exception is a special case, that
|
||||
# indicates we need to retry the request on a new connection.
|
||||
#
|
||||
# The most common case where this can occur is when multiple
|
||||
# requests are queued waiting for a single connection, which
|
||||
# might end up as an HTTP/2 connection, but which actually ends
|
||||
# up as HTTP/1.1.
|
||||
async with self._pool_lock:
|
||||
# Maintain our position in the request queue, but reset the
|
||||
# status so that the request becomes queued again.
|
||||
status.unset_connection()
|
||||
await self._attempt_to_acquire_connection(status)
|
||||
except BaseException as exc:
|
||||
with AsyncShieldCancellation():
|
||||
await self.response_closed(status)
|
||||
raise exc
|
||||
else:
|
||||
break
|
||||
|
||||
# When we return the response, we wrap the stream in a special class
|
||||
# that handles notifying the connection pool once the response
|
||||
# has been released.
|
||||
assert isinstance(response.stream, AsyncIterable)
|
||||
return Response(
|
||||
status=response.status,
|
||||
headers=response.headers,
|
||||
content=ConnectionPoolByteStream(response.stream, self, status),
|
||||
extensions=response.extensions,
|
||||
)
|
||||
|
||||
async def response_closed(self, status: RequestStatus) -> None:
|
||||
"""
|
||||
This method acts as a callback once the request/response cycle is complete.
|
||||
|
||||
It is called into from the `ConnectionPoolByteStream.aclose()` method.
|
||||
"""
|
||||
assert status.connection is not None
|
||||
connection = status.connection
|
||||
|
||||
async with self._pool_lock:
|
||||
# Update the state of the connection pool.
|
||||
if status in self._requests:
|
||||
self._requests.remove(status)
|
||||
|
||||
if connection.is_closed() and connection in self._pool:
|
||||
self._pool.remove(connection)
|
||||
|
||||
# Since we've had a response closed, it's possible we'll now be able
|
||||
# to service one or more requests that are currently pending.
|
||||
for status in self._requests:
|
||||
if status.connection is None:
|
||||
acquired = await self._attempt_to_acquire_connection(status)
|
||||
# If we could not acquire a connection for a queued request
|
||||
# then we don't need to check anymore requests that are
|
||||
# queued later behind it.
|
||||
if not acquired:
|
||||
break
|
||||
|
||||
# Housekeeping.
|
||||
await self._close_expired_connections()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""
|
||||
Close any connections in the pool.
|
||||
"""
|
||||
async with self._pool_lock:
|
||||
for connection in self._pool:
|
||||
await connection.aclose()
|
||||
self._pool = []
|
||||
self._requests = []
|
||||
|
||||
async def __aenter__(self) -> "AsyncConnectionPool":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class ConnectionPoolByteStream:
|
||||
"""
|
||||
A wrapper around the response byte stream, that additionally handles
|
||||
notifying the connection pool when the response has been closed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: AsyncIterable[bytes],
|
||||
pool: AsyncConnectionPool,
|
||||
status: RequestStatus,
|
||||
) -> None:
|
||||
self._stream = stream
|
||||
self._pool = pool
|
||||
self._status = status
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
async for part in self._stream:
|
||||
yield part
|
||||
|
||||
async def aclose(self) -> None:
|
||||
try:
|
||||
if hasattr(self._stream, "aclose"):
|
||||
await self._stream.aclose()
|
||||
finally:
|
||||
with AsyncShieldCancellation():
|
||||
await self._pool.response_closed(self._status)
|
|
@ -0,0 +1,331 @@
|
|||
import enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import h11
|
||||
|
||||
from .._backends.base import AsyncNetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
map_exceptions,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import AsyncLock, AsyncShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http11")
|
||||
|
||||
|
||||
# A subset of `h11.Event` types supported by `_send_event`
|
||||
H11SendEvent = Union[
|
||||
h11.Request,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
]
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
NEW = 0
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class AsyncHTTP11Connection(AsyncConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: AsyncNetworkStream,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: Optional[float] = keepalive_expiry
|
||||
self._expire_at: Optional[float] = None
|
||||
self._state = HTTPConnectionState.NEW
|
||||
self._state_lock = AsyncLock()
|
||||
self._request_count = 0
|
||||
self._h11_state = h11.Connection(
|
||||
our_role=h11.CLIENT,
|
||||
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
self._expire_at = None
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request}
|
||||
async with Trace("send_request_headers", logger, request, kwargs) as trace:
|
||||
await self._send_request_headers(**kwargs)
|
||||
async with Trace("send_request_body", logger, request, kwargs) as trace:
|
||||
await self._send_request_body(**kwargs)
|
||||
async with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
(
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
) = await self._receive_response_headers(**kwargs)
|
||||
trace.return_value = (
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP11ConnectionByteStream(self, request),
|
||||
extensions={
|
||||
"http_version": http_version,
|
||||
"reason_phrase": reason_phrase,
|
||||
"network_stream": self._network_stream,
|
||||
},
|
||||
)
|
||||
except BaseException as exc:
|
||||
with AsyncShieldCancellation():
|
||||
async with Trace("response_closed", logger, request) as trace:
|
||||
await self._response_closed()
|
||||
raise exc
|
||||
|
||||
# Sending the request...
|
||||
|
||||
async def _send_request_headers(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
|
||||
event = h11.Request(
|
||||
method=request.method,
|
||||
target=request.url.target,
|
||||
headers=request.headers,
|
||||
)
|
||||
await self._send_event(event, timeout=timeout)
|
||||
|
||||
async def _send_request_body(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
assert isinstance(request.stream, AsyncIterable)
|
||||
async for chunk in request.stream:
|
||||
event = h11.Data(data=chunk)
|
||||
await self._send_event(event, timeout=timeout)
|
||||
|
||||
await self._send_event(h11.EndOfMessage(), timeout=timeout)
|
||||
|
||||
async def _send_event(
|
||||
self, event: h11.Event, timeout: Optional[float] = None
|
||||
) -> None:
|
||||
bytes_to_send = self._h11_state.send(event)
|
||||
if bytes_to_send is not None:
|
||||
await self._network_stream.write(bytes_to_send, timeout=timeout)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
async def _receive_response_headers(
|
||||
self, request: Request
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Response):
|
||||
break
|
||||
if (
|
||||
isinstance(event, h11.InformationalResponse)
|
||||
and event.status_code == 101
|
||||
):
|
||||
break
|
||||
|
||||
http_version = b"HTTP/" + event.http_version
|
||||
|
||||
# h11 version 0.11+ supports a `raw_items` interface to get the
|
||||
# raw header casing, rather than the enforced lowercase headers.
|
||||
headers = event.headers.raw_items()
|
||||
|
||||
return http_version, event.status_code, event.reason, headers
|
||||
|
||||
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Data):
|
||||
yield bytes(event.data)
|
||||
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
|
||||
break
|
||||
|
||||
async def _receive_event(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> Union[h11.Event, Type[h11.PAUSED]]:
|
||||
while True:
|
||||
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
|
||||
event = self._h11_state.next_event()
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
data = await self._network_stream.read(
|
||||
self.READ_NUM_BYTES, timeout=timeout
|
||||
)
|
||||
|
||||
# If we feed this case through h11 we'll raise an exception like:
|
||||
#
|
||||
# httpcore.RemoteProtocolError: can't handle event type
|
||||
# ConnectionClosed when role=SERVER and state=SEND_RESPONSE
|
||||
#
|
||||
# Which is accurate, but not very informative from an end-user
|
||||
# perspective. Instead we handle this case distinctly and treat
|
||||
# it as a ConnectError.
|
||||
if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
|
||||
msg = "Server disconnected without sending a response."
|
||||
raise RemoteProtocolError(msg)
|
||||
|
||||
self._h11_state.receive_data(data)
|
||||
else:
|
||||
# mypy fails to narrow the type in the above if statement above
|
||||
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
|
||||
|
||||
async def _response_closed(self) -> None:
|
||||
async with self._state_lock:
|
||||
if (
|
||||
self._h11_state.our_state is h11.DONE
|
||||
and self._h11_state.their_state is h11.DONE
|
||||
):
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._h11_state.start_next_cycle()
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
else:
|
||||
await self.aclose()
|
||||
|
||||
# Once the connection is no longer required...
|
||||
|
||||
async def aclose(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
await self._network_stream.aclose()
|
||||
|
||||
# The AsyncConnectionInterface methods provide information about the state of
|
||||
# the connection, allowing for a connection pooling implementation to
|
||||
# determine when to reuse and when to close the connection...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# Note that HTTP/1.1 connections in the "NEW" state are not treated as
|
||||
# being "available". The control flow which created the connection will
|
||||
# be able to send an outgoing request, but the connection will not be
|
||||
# acquired from the connection pool for any other request.
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
keepalive_expired = self._expire_at is not None and now > self._expire_at
|
||||
|
||||
# If the HTTP connection is idle but the socket is readable, then the
|
||||
# only valid state is that the socket is about to return b"", indicating
|
||||
# a server-initiated disconnect.
|
||||
server_disconnected = (
|
||||
self._state == HTTPConnectionState.IDLE
|
||||
and self._network_stream.get_extra_info("is_readable")
|
||||
)
|
||||
|
||||
return keepalive_expired or server_disconnected
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/1.1, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> "AsyncHTTP11Connection":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class HTTP11ConnectionByteStream:
|
||||
def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._closed = False
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
kwargs = {"request": self._request}
|
||||
try:
|
||||
async with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
async for chunk in self._connection._receive_response_body(**kwargs):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
async with Trace("response_closed", logger, self._request):
|
||||
await self._connection._response_closed()
|
|
@ -0,0 +1,589 @@
|
|||
import enum
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
import typing
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import h2.settings
|
||||
|
||||
from .._backends.base import AsyncNetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http2")
|
||||
|
||||
|
||||
def has_body_headers(request: Request) -> bool:
|
||||
return any(
|
||||
k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
|
||||
for k, v in request.headers
|
||||
)
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class AsyncHTTP2Connection(AsyncConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: AsyncNetworkStream,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
):
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
|
||||
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._expire_at: typing.Optional[float] = None
|
||||
self._request_count = 0
|
||||
self._init_lock = AsyncLock()
|
||||
self._state_lock = AsyncLock()
|
||||
self._read_lock = AsyncLock()
|
||||
self._write_lock = AsyncLock()
|
||||
self._sent_connection_init = False
|
||||
self._used_all_stream_ids = False
|
||||
self._connection_error = False
|
||||
|
||||
# Mapping from stream ID to response stream events.
|
||||
self._events: typing.Dict[
|
||||
int,
|
||||
typing.Union[
|
||||
h2.events.ResponseReceived,
|
||||
h2.events.DataReceived,
|
||||
h2.events.StreamEnded,
|
||||
h2.events.StreamReset,
|
||||
],
|
||||
] = {}
|
||||
|
||||
# Connection terminated events are stored as state since
|
||||
# we need to handle them for all streams.
|
||||
self._connection_terminated: typing.Optional[
|
||||
h2.events.ConnectionTerminated
|
||||
] = None
|
||||
|
||||
self._read_exception: typing.Optional[Exception] = None
|
||||
self._write_exception: typing.Optional[Exception] = None
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
# This cannot occur in normal operation, since the connection pool
|
||||
# will only send requests on connections that handle them.
|
||||
# It's in place simply for resilience as a guard against incorrect
|
||||
# usage, for anyone working directly with httpcore connections.
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._expire_at = None
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
async with self._init_lock:
|
||||
if not self._sent_connection_init:
|
||||
try:
|
||||
kwargs = {"request": request}
|
||||
async with Trace("send_connection_init", logger, request, kwargs):
|
||||
await self._send_connection_init(**kwargs)
|
||||
except BaseException as exc:
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
self._sent_connection_init = True
|
||||
|
||||
# Initially start with just 1 until the remote server provides
|
||||
# its max_concurrent_streams value
|
||||
self._max_streams = 1
|
||||
|
||||
local_settings_max_streams = (
|
||||
self._h2_state.local_settings.max_concurrent_streams
|
||||
)
|
||||
self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)
|
||||
|
||||
for _ in range(local_settings_max_streams - self._max_streams):
|
||||
await self._max_streams_semaphore.acquire()
|
||||
|
||||
await self._max_streams_semaphore.acquire()
|
||||
|
||||
try:
|
||||
stream_id = self._h2_state.get_next_available_stream_id()
|
||||
self._events[stream_id] = []
|
||||
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
|
||||
self._used_all_stream_ids = True
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request, "stream_id": stream_id}
|
||||
async with Trace("send_request_headers", logger, request, kwargs):
|
||||
await self._send_request_headers(request=request, stream_id=stream_id)
|
||||
async with Trace("send_request_body", logger, request, kwargs):
|
||||
await self._send_request_body(request=request, stream_id=stream_id)
|
||||
async with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
status, headers = await self._receive_response(
|
||||
request=request, stream_id=stream_id
|
||||
)
|
||||
trace.return_value = (status, headers)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
|
||||
extensions={
|
||||
"http_version": b"HTTP/2",
|
||||
"network_stream": self._network_stream,
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
except BaseException as exc: # noqa: PIE786
|
||||
with AsyncShieldCancellation():
|
||||
kwargs = {"stream_id": stream_id}
|
||||
async with Trace("response_closed", logger, request, kwargs):
|
||||
await self._response_closed(stream_id=stream_id)
|
||||
|
||||
if isinstance(exc, h2.exceptions.ProtocolError):
|
||||
# One case where h2 can raise a protocol error is when a
|
||||
# closed frame has been seen by the state machine.
|
||||
#
|
||||
# This happens when one stream is reading, and encounters
|
||||
# a GOAWAY event. Other flows of control may then raise
|
||||
# a protocol error at any point they interact with the 'h2_state'.
|
||||
#
|
||||
# In this case we'll have stored the event, and should raise
|
||||
# it as a RemoteProtocolError.
|
||||
if self._connection_terminated: # pragma: nocover
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
# If h2 raises a protocol error in some other state then we
|
||||
# must somehow have made a protocol violation.
|
||||
raise LocalProtocolError(exc) # pragma: nocover
|
||||
|
||||
raise exc
|
||||
|
||||
async def _send_connection_init(self, request: Request) -> None:
|
||||
"""
|
||||
The HTTP/2 connection requires some initial setup before we can start
|
||||
using individual request/response streams on it.
|
||||
"""
|
||||
# Need to set these manually here instead of manipulating via
|
||||
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
|
||||
# frames in addition to sending the undesired defaults.
|
||||
self._h2_state.local_settings = h2.settings.Settings(
|
||||
client=True,
|
||||
initial_values={
|
||||
# Disable PUSH_PROMISE frames from the server since we don't do anything
|
||||
# with them for now. Maybe when we support caching?
|
||||
h2.settings.SettingCodes.ENABLE_PUSH: 0,
|
||||
# These two are taken from h2 for safe defaults
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
|
||||
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
|
||||
},
|
||||
)
|
||||
|
||||
# Some websites (*cough* Yahoo *cough*) balk at this setting being
|
||||
# present in the initial handshake since it's not defined in the original
|
||||
# RFC despite the RFC mandating ignoring settings you don't know about.
|
||||
del self._h2_state.local_settings[
|
||||
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
|
||||
]
|
||||
|
||||
self._h2_state.initiate_connection()
|
||||
self._h2_state.increment_flow_control_window(2**24)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
# Sending the request...
|
||||
|
||||
async def _send_request_headers(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send the request headers to a given stream ID.
|
||||
"""
|
||||
end_stream = not has_body_headers(request)
|
||||
|
||||
# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
|
||||
# In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
|
||||
# HTTP/1.1 style headers, and map them appropriately if we end up on
|
||||
# an HTTP/2 connection.
|
||||
authority = [v for k, v in request.headers if k.lower() == b"host"][0]
|
||||
|
||||
headers = [
|
||||
(b":method", request.method),
|
||||
(b":authority", authority),
|
||||
(b":scheme", request.url.scheme),
|
||||
(b":path", request.url.target),
|
||||
] + [
|
||||
(k.lower(), v)
|
||||
for k, v in request.headers
|
||||
if k.lower()
|
||||
not in (
|
||||
b"host",
|
||||
b"transfer-encoding",
|
||||
)
|
||||
]
|
||||
|
||||
self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
|
||||
self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _send_request_body(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Iterate over the request body sending it to a given stream ID.
|
||||
"""
|
||||
if not has_body_headers(request):
|
||||
return
|
||||
|
||||
assert isinstance(request.stream, typing.AsyncIterable)
|
||||
async for data in request.stream:
|
||||
await self._send_stream_data(request, stream_id, data)
|
||||
await self._send_end_stream(request, stream_id)
|
||||
|
||||
async def _send_stream_data(
|
||||
self, request: Request, stream_id: int, data: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Send a single chunk of data in one or more data frames.
|
||||
"""
|
||||
while data:
|
||||
max_flow = await self._wait_for_outgoing_flow(request, stream_id)
|
||||
chunk_size = min(len(data), max_flow)
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
self._h2_state.send_data(stream_id, chunk)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _send_end_stream(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
|
||||
"""
|
||||
self._h2_state.end_stream(stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
async def _receive_response(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Return the response status code and headers for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode("ascii", errors="ignore"))
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
async def _receive_response_body(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
Iterator that returns the bytes of the response body for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
amount = event.flow_controlled_length
|
||||
self._h2_state.acknowledge_received_data(amount, stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
yield event.data
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
break
|
||||
|
||||
async def _receive_stream_event(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.Union[
|
||||
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
|
||||
]:
|
||||
"""
|
||||
Return the next available event for a given stream ID.
|
||||
|
||||
Will read more data from the network if required.
|
||||
"""
|
||||
while not self._events.get(stream_id):
|
||||
await self._receive_events(request, stream_id)
|
||||
event = self._events[stream_id].pop(0)
|
||||
if isinstance(event, h2.events.StreamReset):
|
||||
raise RemoteProtocolError(event)
|
||||
return event
|
||||
|
||||
async def _receive_events(
|
||||
self, request: Request, stream_id: typing.Optional[int] = None
|
||||
) -> None:
|
||||
"""
|
||||
Read some data from the network until we see one or more events
|
||||
for a given stream ID.
|
||||
"""
|
||||
async with self._read_lock:
|
||||
if self._connection_terminated is not None:
|
||||
last_stream_id = self._connection_terminated.last_stream_id
|
||||
if stream_id and last_stream_id and stream_id > last_stream_id:
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
|
||||
# This conditional is a bit icky. We don't want to block reading if we've
|
||||
# actually got an event to return for a given stream. We need to do that
|
||||
# check *within* the atomic read lock. Though it also need to be optional,
|
||||
# because when we call it from `_wait_for_outgoing_flow` we *do* want to
|
||||
# block until we've available flow control, event when we have events
|
||||
# pending for the stream ID we're attempting to send on.
|
||||
if stream_id is None or not self._events.get(stream_id):
|
||||
events = await self._read_incoming_data(request)
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
async with Trace(
|
||||
"receive_remote_settings", logger, request
|
||||
) as trace:
|
||||
await self._receive_remote_settings_change(event)
|
||||
trace.return_value = event
|
||||
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
h2.events.ResponseReceived,
|
||||
h2.events.DataReceived,
|
||||
h2.events.StreamEnded,
|
||||
h2.events.StreamReset,
|
||||
),
|
||||
):
|
||||
if event.stream_id in self._events:
|
||||
self._events[event.stream_id].append(event)
|
||||
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
self._connection_terminated = event
|
||||
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
|
||||
max_concurrent_streams = event.changed_settings.get(
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
|
||||
)
|
||||
if max_concurrent_streams:
|
||||
new_max_streams = min(
|
||||
max_concurrent_streams.new_value,
|
||||
self._h2_state.local_settings.max_concurrent_streams,
|
||||
)
|
||||
if new_max_streams and new_max_streams != self._max_streams:
|
||||
while new_max_streams > self._max_streams:
|
||||
await self._max_streams_semaphore.release()
|
||||
self._max_streams += 1
|
||||
while new_max_streams < self._max_streams:
|
||||
await self._max_streams_semaphore.acquire()
|
||||
self._max_streams -= 1
|
||||
|
||||
async def _response_closed(self, stream_id: int) -> None:
|
||||
await self._max_streams_semaphore.release()
|
||||
del self._events[stream_id]
|
||||
async with self._state_lock:
|
||||
if self._connection_terminated and not self._events:
|
||||
await self.aclose()
|
||||
|
||||
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
if self._used_all_stream_ids: # pragma: nocover
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._h2_state.close_connection()
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
await self._network_stream.aclose()
|
||||
|
||||
# Wrappers around network read/write operations...
|
||||
|
||||
async def _read_incoming_data(
|
||||
self, request: Request
|
||||
) -> typing.List[h2.events.Event]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
if self._read_exception is not None:
|
||||
raise self._read_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
|
||||
if data == b"":
|
||||
raise RemoteProtocolError("Server disconnected")
|
||||
except Exception as exc:
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future reads.
|
||||
# (For example, this means that a single read timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._read_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
|
||||
|
||||
return events
|
||||
|
||||
async def _write_outgoing_data(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
async with self._write_lock:
|
||||
data_to_send = self._h2_state.data_to_send()
|
||||
|
||||
if self._write_exception is not None:
|
||||
raise self._write_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
await self._network_stream.write(data_to_send, timeout)
|
||||
except Exception as exc: # pragma: nocover
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future write.
|
||||
# (For example, this means that a single write timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._write_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
# Flow control...
|
||||
|
||||
async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
|
||||
"""
|
||||
Returns the maximum allowable outgoing flow for a given stream.
|
||||
|
||||
If the allowable flow is zero, then waits on the network until
|
||||
WindowUpdated frames have increased the flow rate.
|
||||
https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
"""
|
||||
local_flow: int = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size: int = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
while flow == 0:
|
||||
await self._receive_events(request)
|
||||
local_flow = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
return flow
|
||||
|
||||
# Interface for connection pooling...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return (
|
||||
self._state != HTTPConnectionState.CLOSED
|
||||
and not self._connection_error
|
||||
and not self._used_all_stream_ids
|
||||
and not (
|
||||
self._h2_state.state_machine.state
|
||||
== h2.connection.ConnectionState.CLOSED
|
||||
)
|
||||
)
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
return self._expire_at is not None and now > self._expire_at
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/2, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> "AsyncHTTP2Connection":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[types.TracebackType] = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class HTTP2ConnectionByteStream:
|
||||
def __init__(
|
||||
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
|
||||
) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._stream_id = stream_id
|
||||
self._closed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
kwargs = {"request": self._request, "stream_id": self._stream_id}
|
||||
try:
|
||||
async with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
async for chunk in self._connection._receive_response_body(
|
||||
request=self._request, stream_id=self._stream_id
|
||||
):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
kwargs = {"stream_id": self._stream_id}
|
||||
async with Trace("response_closed", logger, self._request, kwargs):
|
||||
await self._connection._response_closed(stream_id=self._stream_id)
|
|
@ -0,0 +1,350 @@
|
|||
import logging
|
||||
import ssl
|
||||
from base64 import b64encode
|
||||
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
|
||||
from .._exceptions import ProxyError
|
||||
from .._models import (
|
||||
URL,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
)
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .connection import AsyncHTTPConnection
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.proxy")
|
||||
|
||||
|
||||
def merge_headers(
|
||||
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
) -> List[Tuple[bytes, bytes]]:
|
||||
"""
|
||||
Append default_headers and override_headers, de-duplicating if a key exists
|
||||
in both cases.
|
||||
"""
|
||||
default_headers = [] if default_headers is None else list(default_headers)
|
||||
override_headers = [] if override_headers is None else list(override_headers)
|
||||
has_override = set(key.lower() for key, value in override_headers)
|
||||
default_headers = [
|
||||
(key, value)
|
||||
for key, value in default_headers
|
||||
if key.lower() not in has_override
|
||||
]
|
||||
return default_headers + override_headers
|
||||
|
||||
|
||||
def build_auth_header(username: bytes, password: bytes) -> bytes:
|
||||
userpass = username + b":" + password
|
||||
return b"Basic " + b64encode(userpass)
|
||||
|
||||
|
||||
class AsyncHTTPProxy(AsyncConnectionPool):
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Union[URL, bytes, str],
|
||||
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
|
||||
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
max_connections: Optional[int] = 10,
|
||||
max_keepalive_connections: Optional[int] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[AsyncNetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
proxy_auth: Any proxy authentication as a two-tuple of
|
||||
(username, password). May be either bytes or ascii-only str.
|
||||
proxy_headers: Any HTTP headers to use for the proxy requests.
|
||||
For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
local_address=local_address,
|
||||
uds=uds,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
if proxy_auth is not None:
|
||||
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
|
||||
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
|
||||
authorization = build_auth_header(username, password)
|
||||
self._proxy_headers = [
|
||||
(b"Proxy-Authorization", authorization)
|
||||
] + self._proxy_headers
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
if origin.scheme == b"http":
|
||||
return AsyncForwardHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
return AsyncTunnelHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class AsyncForwardHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
network_backend: Optional[AsyncNetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._connection = AsyncHTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._remote_origin = remote_origin
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
headers = merge_headers(self._proxy_headers, request.headers)
|
||||
url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=bytes(request.url),
|
||||
)
|
||||
proxy_request = Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=request.stream,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
return await self._connection.handle_async_request(proxy_request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
|
||||
class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: Optional[AsyncNetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._connect_lock = AsyncLock()
|
||||
self._connected = False
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
async with self._connect_lock:
|
||||
if not self._connected:
|
||||
target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)
|
||||
|
||||
connect_url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=target,
|
||||
)
|
||||
connect_headers = merge_headers(
|
||||
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
|
||||
)
|
||||
connect_request = Request(
|
||||
method=b"CONNECT",
|
||||
url=connect_url,
|
||||
headers=connect_headers,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
connect_response = await self._connection.handle_async_request(
|
||||
connect_request
|
||||
)
|
||||
|
||||
if connect_response.status < 200 or connect_response.status > 299:
|
||||
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
|
||||
reason_str = reason_bytes.decode("ascii", errors="ignore")
|
||||
msg = "%d %s" % (connect_response.status, reason_str)
|
||||
await self._connection.aclose()
|
||||
raise ProxyError(msg)
|
||||
|
||||
stream = connect_response.extensions["network_stream"]
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
|
@ -0,0 +1,135 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
|
||||
from .._models import (
|
||||
URL,
|
||||
Extensions,
|
||||
HeaderTypes,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
include_request_headers,
|
||||
)
|
||||
|
||||
|
||||
class AsyncRequestInterface:
|
||||
async def request(
|
||||
self,
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, AsyncIterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> Response:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = await self.handle_async_request(request)
|
||||
try:
|
||||
await response.aread()
|
||||
finally:
|
||||
await response.aclose()
|
||||
return response
|
||||
|
||||
@asynccontextmanager
|
||||
async def stream(
|
||||
self,
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, AsyncIterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> AsyncIterator[Response]:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = await self.handle_async_request(request)
|
||||
try:
|
||||
yield response
|
||||
finally:
|
||||
await response.aclose()
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
|
||||
class AsyncConnectionInterface(AsyncRequestInterface):
|
||||
async def aclose(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def info(self) -> str:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently able to accept an
|
||||
outgoing request.
|
||||
|
||||
An HTTP/1.1 connection will only be available if it is currently idle.
|
||||
|
||||
An HTTP/2 connection will be available so long as the stream ID space is
|
||||
not yet exhausted, and the connection is not in an error state.
|
||||
|
||||
While the connection is being established we may not yet know if it is going
|
||||
to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
|
||||
treated as being available, but might ultimately raise `NewConnectionRequired`
|
||||
required exceptions if multiple requests are attempted over a connection
|
||||
that ends up being established as HTTP/1.1.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is in a state where it should be closed.
|
||||
|
||||
This either means that the connection is idle and it has passed the
|
||||
expiry time on its keep-alive, or that server has sent an EOF.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently idle.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection has been closed.
|
||||
|
||||
Used when a response is closed to determine if the connection may be
|
||||
returned to the connection pool or not.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
|
@ -0,0 +1,340 @@
|
|||
import logging
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from socksio import socks5
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
|
||||
from .._exceptions import ConnectionNotAvailable, ProxyError
|
||||
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.socks")
|
||||
|
||||
|
||||
AUTH_METHODS = {
|
||||
b"\x00": "NO AUTHENTICATION REQUIRED",
|
||||
b"\x01": "GSSAPI",
|
||||
b"\x02": "USERNAME/PASSWORD",
|
||||
b"\xff": "NO ACCEPTABLE METHODS",
|
||||
}
|
||||
|
||||
REPLY_CODES = {
|
||||
b"\x00": "Succeeded",
|
||||
b"\x01": "General SOCKS server failure",
|
||||
b"\x02": "Connection not allowed by ruleset",
|
||||
b"\x03": "Network unreachable",
|
||||
b"\x04": "Host unreachable",
|
||||
b"\x05": "Connection refused",
|
||||
b"\x06": "TTL expired",
|
||||
b"\x07": "Command not supported",
|
||||
b"\x08": "Address type not supported",
|
||||
}
|
||||
|
||||
|
||||
async def _init_socks5_connection(
|
||||
stream: AsyncNetworkStream,
|
||||
*,
|
||||
host: bytes,
|
||||
port: int,
|
||||
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
|
||||
) -> None:
|
||||
conn = socks5.SOCKS5Connection()
|
||||
|
||||
# Auth method request
|
||||
auth_method = (
|
||||
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
|
||||
if auth is None
|
||||
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
|
||||
)
|
||||
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Auth method response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5AuthReply)
|
||||
if response.method != auth_method:
|
||||
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
|
||||
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
|
||||
raise ProxyError(
|
||||
f"Requested {requested} from proxy server, but got {responded}."
|
||||
)
|
||||
|
||||
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
|
||||
# Username/password request
|
||||
assert auth is not None
|
||||
username, password = auth
|
||||
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Username/password response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
|
||||
if not response.success:
|
||||
raise ProxyError("Invalid username/password")
|
||||
|
||||
# Connect request
|
||||
conn.send(
|
||||
socks5.SOCKS5CommandRequest.from_address(
|
||||
socks5.SOCKS5Command.CONNECT, (host, port)
|
||||
)
|
||||
)
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Connect response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5Reply)
|
||||
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
|
||||
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
|
||||
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
|
||||
|
||||
|
||||
class AsyncSOCKSProxy(AsyncConnectionPool):
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: typing.Union[URL, bytes, str],
|
||||
proxy_auth: typing.Optional[
|
||||
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
|
||||
] = None,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
max_connections: typing.Optional[int] = 10,
|
||||
max_keepalive_connections: typing.Optional[int] = None,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
network_backend: typing.Optional[AsyncNetworkBackend] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
)
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
if proxy_auth is not None:
|
||||
username, password = proxy_auth
|
||||
username_bytes = enforce_bytes(username, name="proxy_auth")
|
||||
password_bytes = enforce_bytes(password, name="proxy_auth")
|
||||
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
|
||||
username_bytes,
|
||||
password_bytes,
|
||||
)
|
||||
else:
|
||||
self._proxy_auth = None
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
return AsyncSocks5Connection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
remote_origin=origin,
|
||||
proxy_auth=self._proxy_auth,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: typing.Optional[AsyncNetworkBackend] = None,
|
||||
) -> None:
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._proxy_auth = proxy_auth
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
|
||||
self._network_backend: AsyncNetworkBackend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connect_lock = AsyncLock()
|
||||
self._connection: typing.Optional[AsyncConnectionInterface] = None
|
||||
self._connect_failed = False
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
async with self._connect_lock:
|
||||
if self._connection is None:
|
||||
try:
|
||||
# Connect to the proxy
|
||||
kwargs = {
|
||||
"host": self._proxy_origin.host.decode("ascii"),
|
||||
"port": self._proxy_origin.port,
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Connect to the remote host using socks5
|
||||
kwargs = {
|
||||
"stream": stream,
|
||||
"host": self._remote_origin.host.decode("ascii"),
|
||||
"port": self._remote_origin.port,
|
||||
"auth": self._proxy_auth,
|
||||
}
|
||||
with Trace(
|
||||
"setup_socks5_connection", logger, request, kwargs
|
||||
) as trace:
|
||||
await _init_socks5_connection(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
if self._remote_origin.scheme == b"https":
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = (
|
||||
["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
)
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (
|
||||
self._http2 and not self._http1
|
||||
): # pragma: nocover
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
elif not self._connection.is_available(): # pragma: nocover
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._connection is not None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._remote_origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
|
@ -0,0 +1,145 @@
|
|||
import ssl
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
|
||||
from .._exceptions import (
|
||||
ConnectError,
|
||||
ConnectTimeout,
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
WriteError,
|
||||
WriteTimeout,
|
||||
map_exceptions,
|
||||
)
|
||||
from .._utils import is_socket_readable
|
||||
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
||||
|
||||
|
||||
class AnyIOStream(AsyncNetworkStream):
|
||||
def __init__(self, stream: anyio.abc.ByteStream) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def read(
|
||||
self, max_bytes: int, timeout: typing.Optional[float] = None
|
||||
) -> bytes:
|
||||
exc_map = {
|
||||
TimeoutError: ReadTimeout,
|
||||
anyio.BrokenResourceError: ReadError,
|
||||
anyio.ClosedResourceError: ReadError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with anyio.fail_after(timeout):
|
||||
try:
|
||||
return await self._stream.receive(max_bytes=max_bytes)
|
||||
except anyio.EndOfStream: # pragma: nocover
|
||||
return b""
|
||||
|
||||
async def write(
|
||||
self, buffer: bytes, timeout: typing.Optional[float] = None
|
||||
) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
|
||||
exc_map = {
|
||||
TimeoutError: WriteTimeout,
|
||||
anyio.BrokenResourceError: WriteError,
|
||||
anyio.ClosedResourceError: WriteError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with anyio.fail_after(timeout):
|
||||
await self._stream.send(item=buffer)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
exc_map = {
|
||||
TimeoutError: ConnectTimeout,
|
||||
anyio.BrokenResourceError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
ssl_stream = await anyio.streams.tls.TLSStream.wrap(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
hostname=server_hostname,
|
||||
standard_compatible=False,
|
||||
server_side=False,
|
||||
)
|
||||
except Exception as exc: # pragma: nocover
|
||||
await self.aclose()
|
||||
raise exc
|
||||
return AnyIOStream(ssl_stream)
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
if info == "ssl_object":
|
||||
return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
|
||||
if info == "client_addr":
|
||||
return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
|
||||
if info == "server_addr":
|
||||
return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
|
||||
if info == "socket":
|
||||
return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
|
||||
if info == "is_readable":
|
||||
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
|
||||
return is_socket_readable(sock)
|
||||
return None
|
||||
|
||||
|
||||
class AnyIOBackend(AsyncNetworkBackend):
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
if socket_options is None:
|
||||
socket_options = [] # pragma: no cover
|
||||
exc_map = {
|
||||
TimeoutError: ConnectTimeout,
|
||||
OSError: ConnectError,
|
||||
anyio.BrokenResourceError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with anyio.fail_after(timeout):
|
||||
stream: anyio.abc.ByteStream = await anyio.connect_tcp(
|
||||
remote_host=host,
|
||||
remote_port=port,
|
||||
local_host=local_address,
|
||||
)
|
||||
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
|
||||
for option in socket_options:
|
||||
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
|
||||
return AnyIOStream(stream)
|
||||
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream: # pragma: nocover
|
||||
if socket_options is None:
|
||||
socket_options = []
|
||||
exc_map = {
|
||||
TimeoutError: ConnectTimeout,
|
||||
OSError: ConnectError,
|
||||
anyio.BrokenResourceError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with anyio.fail_after(timeout):
|
||||
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
|
||||
for option in socket_options:
|
||||
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
|
||||
return AnyIOStream(stream)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await anyio.sleep(seconds) # pragma: nocover
|
|
@ -0,0 +1,56 @@
|
|||
import typing
|
||||
from typing import Optional
|
||||
|
||||
import sniffio
|
||||
|
||||
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
||||
|
||||
|
||||
class AutoBackend(AsyncNetworkBackend):
|
||||
async def _init_backend(self) -> None:
|
||||
if not (hasattr(self, "_backend")):
|
||||
backend = sniffio.current_async_library()
|
||||
if backend == "trio":
|
||||
from .trio import TrioBackend
|
||||
|
||||
self._backend: AsyncNetworkBackend = TrioBackend()
|
||||
elif backend == "structured-io":
|
||||
from .structio import StructioBackend
|
||||
|
||||
self._backend: AsyncNetworkBackend = StructioBackend()
|
||||
else:
|
||||
from .anyio import AnyIOBackend
|
||||
|
||||
self._backend = AnyIOBackend()
|
||||
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: Optional[float] = None,
|
||||
local_address: Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
await self._init_backend()
|
||||
return await self._backend.connect_tcp(
|
||||
host,
|
||||
port,
|
||||
timeout=timeout,
|
||||
local_address=local_address,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream: # pragma: nocover
|
||||
await self._init_backend()
|
||||
return await self._backend.connect_unix_socket(
|
||||
path, timeout=timeout, socket_options=socket_options
|
||||
)
|
||||
|
||||
async def sleep(self, seconds: float) -> None: # pragma: nocover
|
||||
await self._init_backend()
|
||||
return await self._backend.sleep(seconds)
|
|
@ -0,0 +1,103 @@
|
|||
import ssl
|
||||
import time
|
||||
import typing
|
||||
|
||||
SOCKET_OPTION = typing.Union[
|
||||
typing.Tuple[int, int, int],
|
||||
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
|
||||
typing.Tuple[int, int, None, int],
|
||||
]
|
||||
|
||||
|
||||
class NetworkStream:
|
||||
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> "NetworkStream":
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
return None # pragma: nocover
|
||||
|
||||
|
||||
class NetworkBackend:
|
||||
def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def sleep(self, seconds: float) -> None:
|
||||
time.sleep(seconds) # pragma: nocover
|
||||
|
||||
|
||||
class AsyncNetworkStream:
|
||||
async def read(
|
||||
self, max_bytes: int, timeout: typing.Optional[float] = None
|
||||
) -> bytes:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def write(
|
||||
self, buffer: bytes, timeout: typing.Optional[float] = None
|
||||
) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> "AsyncNetworkStream":
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
return None # pragma: nocover
|
||||
|
||||
|
||||
class AsyncNetworkBackend:
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
|
@ -0,0 +1,142 @@
|
|||
import ssl
|
||||
import typing
|
||||
from typing import Optional
|
||||
|
||||
from .._exceptions import ReadError
|
||||
from .base import (
|
||||
SOCKET_OPTION,
|
||||
AsyncNetworkBackend,
|
||||
AsyncNetworkStream,
|
||||
NetworkBackend,
|
||||
NetworkStream,
|
||||
)
|
||||
|
||||
|
||||
class MockSSLObject:
|
||||
def __init__(self, http2: bool):
|
||||
self._http2 = http2
|
||||
|
||||
def selected_alpn_protocol(self) -> str:
|
||||
return "h2" if self._http2 else "http/1.1"
|
||||
|
||||
|
||||
class MockStream(NetworkStream):
|
||||
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
|
||||
self._buffer = buffer
|
||||
self._http2 = http2
|
||||
self._closed = False
|
||||
|
||||
def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
|
||||
if self._closed:
|
||||
raise ReadError("Connection closed")
|
||||
if not self._buffer:
|
||||
return b""
|
||||
return self._buffer.pop(0)
|
||||
|
||||
def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> NetworkStream:
|
||||
return self
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<httpcore.MockStream>"
|
||||
|
||||
|
||||
class MockBackend(NetworkBackend):
|
||||
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
|
||||
self._buffer = buffer
|
||||
self._http2 = http2
|
||||
|
||||
def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: Optional[float] = None,
|
||||
local_address: Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream:
|
||||
return MockStream(list(self._buffer), http2=self._http2)
|
||||
|
||||
def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream:
|
||||
return MockStream(list(self._buffer), http2=self._http2)
|
||||
|
||||
def sleep(self, seconds: float) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncMockStream(AsyncNetworkStream):
|
||||
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
|
||||
self._buffer = buffer
|
||||
self._http2 = http2
|
||||
self._closed = False
|
||||
|
||||
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
|
||||
if self._closed:
|
||||
raise ReadError("Connection closed")
|
||||
if not self._buffer:
|
||||
return b""
|
||||
return self._buffer.pop(0)
|
||||
|
||||
async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
|
||||
pass
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return self
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<httpcore.AsyncMockStream>"
|
||||
|
||||
|
||||
class AsyncMockBackend(AsyncNetworkBackend):
|
||||
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
|
||||
self._buffer = buffer
|
||||
self._http2 = http2
|
||||
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: Optional[float] = None,
|
||||
local_address: Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return AsyncMockStream(list(self._buffer), http2=self._http2)
|
||||
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return AsyncMockStream(list(self._buffer), http2=self._http2)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
pass
|
|
@ -0,0 +1,118 @@
|
|||
from .base import AsyncNetworkStream, AsyncNetworkBackend, SOCKET_OPTION
|
||||
from .._exceptions import (ConnectTimeout, ReadError, ReadTimeout,
|
||||
WriteError, WriteTimeout, ConnectError,
|
||||
map_exceptions, ExceptionMapping, PoolTimeout)
|
||||
|
||||
import structio
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
|
||||
class StructioStream(AsyncNetworkStream):
|
||||
"""
|
||||
Structio-compatible async stream for
|
||||
httpx
|
||||
"""
|
||||
|
||||
def __init__(self, stream: structio.AsyncSocket):
|
||||
self._stream = stream
|
||||
|
||||
async def read(
|
||||
self, max_bytes: int, timeout: typing.Optional[float] = None
|
||||
) -> bytes:
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
structio.TimedOut: ReadTimeout,
|
||||
structio.ResourceClosed: ReadError,
|
||||
structio.ResourceBusy: ReadError,
|
||||
structio.ResourceBroken: ReadError
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with structio.with_timeout(timeout_or_inf):
|
||||
data: bytes = await self._stream.receive(max_bytes)
|
||||
return data
|
||||
|
||||
async def write(
|
||||
self, buffer: bytes, timeout: typing.Optional[float] = None
|
||||
) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
structio.TimedOut: WriteTimeout,
|
||||
structio.ResourceClosed: WriteError,
|
||||
structio.ResourceBusy: WriteError,
|
||||
structio.ResourceBroken: WriteError
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with structio.with_timeout(timeout_or_inf):
|
||||
await self._stream.send_all(data=buffer)
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
structio.TimedOut: ConnectTimeout,
|
||||
structio.ResourceBroken: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
try:
|
||||
with structio.with_timeout(timeout_or_inf):
|
||||
self._stream = await structio.socket.wrap_socket_with_ssl(self._stream, context=ssl_context, server_hostname=server_hostname)
|
||||
except Exception as exc: # pragma: nocover
|
||||
await self.aclose()
|
||||
raise exc
|
||||
return self
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return await self._stream.close()
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
if info == "ssl_object" and hasattr(self._stream, "_sslobj"):
|
||||
return self._stream._sslobj
|
||||
if info == "client_addr":
|
||||
return self._stream.socket.getsockname()
|
||||
if info == "server_addr":
|
||||
return self._stream.socket.getpeername()
|
||||
if info == "socket":
|
||||
return self._stream
|
||||
if info == "is_readable":
|
||||
return self._stream.is_readable()
|
||||
return None
|
||||
|
||||
|
||||
class StructioBackend(AsyncNetworkBackend):
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
if socket_options is None:
|
||||
socket_options = []
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
structio.TimedOut: ConnectTimeout,
|
||||
structio.ResourceBusy: ConnectError,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with structio.with_timeout(timeout_or_inf):
|
||||
stream: structio.AsyncSocket = await structio.socket.connect_tcp_socket(
|
||||
host=host, port=port, source_address=local_address
|
||||
)
|
||||
for option in socket_options:
|
||||
stream.setsockopt(*option)
|
||||
return StructioStream(stream)
|
||||
|
||||
# TODO: connect_unix_socket
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await structio.sleep(seconds)
|
|
@ -0,0 +1,133 @@
|
|||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from .._exceptions import (
|
||||
ConnectError,
|
||||
ConnectTimeout,
|
||||
ExceptionMapping,
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
WriteError,
|
||||
WriteTimeout,
|
||||
map_exceptions,
|
||||
)
|
||||
from .._utils import is_socket_readable
|
||||
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
|
||||
|
||||
|
||||
class SyncStream(NetworkStream):
|
||||
def __init__(self, sock: socket.socket) -> None:
|
||||
self._sock = sock
|
||||
|
||||
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
|
||||
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
|
||||
with map_exceptions(exc_map):
|
||||
self._sock.settimeout(timeout)
|
||||
return self._sock.recv(max_bytes)
|
||||
|
||||
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
|
||||
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
|
||||
with map_exceptions(exc_map):
|
||||
while buffer:
|
||||
self._sock.settimeout(timeout)
|
||||
n = self._sock.send(buffer)
|
||||
buffer = buffer[n:]
|
||||
|
||||
def close(self) -> None:
|
||||
self._sock.close()
|
||||
|
||||
def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> NetworkStream:
|
||||
exc_map: ExceptionMapping = {
|
||||
socket.timeout: ConnectTimeout,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
try:
|
||||
self._sock.settimeout(timeout)
|
||||
sock = ssl_context.wrap_socket(
|
||||
self._sock, server_hostname=server_hostname
|
||||
)
|
||||
except Exception as exc: # pragma: nocover
|
||||
self.close()
|
||||
raise exc
|
||||
return SyncStream(sock)
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
|
||||
return self._sock._sslobj # type: ignore
|
||||
if info == "client_addr":
|
||||
return self._sock.getsockname()
|
||||
if info == "server_addr":
|
||||
return self._sock.getpeername()
|
||||
if info == "socket":
|
||||
return self._sock
|
||||
if info == "is_readable":
|
||||
return is_socket_readable(self._sock)
|
||||
return None
|
||||
|
||||
|
||||
class SyncBackend(NetworkBackend):
|
||||
def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream:
|
||||
# Note that we automatically include `TCP_NODELAY`
|
||||
# in addition to any other custom socket options.
|
||||
if socket_options is None:
|
||||
socket_options = [] # pragma: no cover
|
||||
address = (host, port)
|
||||
source_address = None if local_address is None else (local_address, 0)
|
||||
exc_map: ExceptionMapping = {
|
||||
socket.timeout: ConnectTimeout,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
|
||||
with map_exceptions(exc_map):
|
||||
sock = socket.create_connection(
|
||||
address,
|
||||
timeout,
|
||||
source_address=source_address,
|
||||
)
|
||||
for option in socket_options:
|
||||
sock.setsockopt(*option) # pragma: no cover
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
return SyncStream(sock)
|
||||
|
||||
def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> NetworkStream: # pragma: nocover
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError(
|
||||
"Attempted to connect to a UNIX socket on a Windows system."
|
||||
)
|
||||
if socket_options is None:
|
||||
socket_options = []
|
||||
|
||||
exc_map: ExceptionMapping = {
|
||||
socket.timeout: ConnectTimeout,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
for option in socket_options:
|
||||
sock.setsockopt(*option)
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(path)
|
||||
return SyncStream(sock)
|
|
@ -0,0 +1,161 @@
|
|||
import ssl
|
||||
import typing
|
||||
|
||||
import trio
|
||||
|
||||
from .._exceptions import (
|
||||
ConnectError,
|
||||
ConnectTimeout,
|
||||
ExceptionMapping,
|
||||
ReadError,
|
||||
ReadTimeout,
|
||||
WriteError,
|
||||
WriteTimeout,
|
||||
map_exceptions,
|
||||
)
|
||||
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
||||
|
||||
|
||||
class TrioStream(AsyncNetworkStream):
|
||||
def __init__(self, stream: trio.abc.Stream) -> None:
|
||||
self._stream = stream
|
||||
|
||||
async def read(
|
||||
self, max_bytes: int, timeout: typing.Optional[float] = None
|
||||
) -> bytes:
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
trio.TooSlowError: ReadTimeout,
|
||||
trio.BrokenResourceError: ReadError,
|
||||
trio.ClosedResourceError: ReadError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
|
||||
return data
|
||||
|
||||
async def write(
|
||||
self, buffer: bytes, timeout: typing.Optional[float] = None
|
||||
) -> None:
|
||||
if not buffer:
|
||||
return
|
||||
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
trio.TooSlowError: WriteTimeout,
|
||||
trio.BrokenResourceError: WriteError,
|
||||
trio.ClosedResourceError: WriteError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
await self._stream.send_all(data=buffer)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: typing.Optional[str] = None,
|
||||
timeout: typing.Optional[float] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
trio.TooSlowError: ConnectTimeout,
|
||||
trio.BrokenResourceError: ConnectError,
|
||||
}
|
||||
ssl_stream = trio.SSLStream(
|
||||
self._stream,
|
||||
ssl_context=ssl_context,
|
||||
server_hostname=server_hostname,
|
||||
https_compatible=True,
|
||||
server_side=False,
|
||||
)
|
||||
with map_exceptions(exc_map):
|
||||
try:
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
await ssl_stream.do_handshake()
|
||||
except Exception as exc: # pragma: nocover
|
||||
await self.aclose()
|
||||
raise exc
|
||||
return TrioStream(ssl_stream)
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
|
||||
# Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
|
||||
# Tracked at https://github.com/python-trio/trio/issues/542
|
||||
return self._stream._ssl_object # type: ignore[attr-defined]
|
||||
if info == "client_addr":
|
||||
return self._get_socket_stream().socket.getsockname()
|
||||
if info == "server_addr":
|
||||
return self._get_socket_stream().socket.getpeername()
|
||||
if info == "socket":
|
||||
stream = self._stream
|
||||
while isinstance(stream, trio.SSLStream):
|
||||
stream = stream.transport_stream
|
||||
assert isinstance(stream, trio.SocketStream)
|
||||
return stream.socket
|
||||
if info == "is_readable":
|
||||
socket = self.get_extra_info("socket")
|
||||
return socket.is_readable()
|
||||
return None
|
||||
|
||||
def _get_socket_stream(self) -> trio.SocketStream:
|
||||
stream = self._stream
|
||||
while isinstance(stream, trio.SSLStream):
|
||||
stream = stream.transport_stream
|
||||
assert isinstance(stream, trio.SocketStream)
|
||||
return stream
|
||||
|
||||
|
||||
class TrioBackend(AsyncNetworkBackend):
|
||||
async def connect_tcp(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: typing.Optional[float] = None,
|
||||
local_address: typing.Optional[str] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream:
|
||||
# By default for TCP sockets, trio enables TCP_NODELAY.
|
||||
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
|
||||
if socket_options is None:
|
||||
socket_options = [] # pragma: no cover
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
trio.TooSlowError: ConnectTimeout,
|
||||
trio.BrokenResourceError: ConnectError,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
stream: trio.abc.Stream = await trio.open_tcp_stream(
|
||||
host=host, port=port, local_address=local_address
|
||||
)
|
||||
for option in socket_options:
|
||||
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
|
||||
return TrioStream(stream)
|
||||
|
||||
async def connect_unix_socket(
|
||||
self,
|
||||
path: str,
|
||||
timeout: typing.Optional[float] = None,
|
||||
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
|
||||
) -> AsyncNetworkStream: # pragma: nocover
|
||||
if socket_options is None:
|
||||
socket_options = []
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
exc_map: ExceptionMapping = {
|
||||
trio.TooSlowError: ConnectTimeout,
|
||||
trio.BrokenResourceError: ConnectError,
|
||||
OSError: ConnectError,
|
||||
}
|
||||
with map_exceptions(exc_map):
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
stream: trio.abc.Stream = await trio.open_unix_socket(path)
|
||||
for option in socket_options:
|
||||
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
|
||||
return TrioStream(stream)
|
||||
|
||||
async def sleep(self, seconds: float) -> None:
|
||||
await trio.sleep(seconds) # pragma: nocover
|
|
@ -0,0 +1,81 @@
|
|||
import contextlib
|
||||
from typing import Iterator, Mapping, Type
|
||||
|
||||
ExceptionMapping = Mapping[Type[Exception], Type[Exception]]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def map_exceptions(map: ExceptionMapping) -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except Exception as exc: # noqa: PIE786
|
||||
for from_exc, to_exc in map.items():
|
||||
if isinstance(exc, from_exc):
|
||||
raise to_exc(exc) from exc
|
||||
raise # pragma: nocover
|
||||
|
||||
|
||||
class ConnectionNotAvailable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedProtocol(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RemoteProtocolError(ProtocolError):
|
||||
pass
|
||||
|
||||
|
||||
class LocalProtocolError(ProtocolError):
|
||||
pass
|
||||
|
||||
|
||||
# Timeout errors
|
||||
|
||||
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PoolTimeout(TimeoutException):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectTimeout(TimeoutException):
|
||||
pass
|
||||
|
||||
|
||||
class ReadTimeout(TimeoutException):
|
||||
pass
|
||||
|
||||
|
||||
class WriteTimeout(TimeoutException):
|
||||
pass
|
||||
|
||||
|
||||
# Network errors
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectError(NetworkError):
|
||||
pass
|
||||
|
||||
|
||||
class ReadError(NetworkError):
|
||||
pass
|
||||
|
||||
|
||||
class WriteError(NetworkError):
|
||||
pass
|
|
@ -0,0 +1,483 @@
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Functions for typechecking...
|
||||
|
||||
|
||||
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
|
||||
HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None]
|
||||
|
||||
Extensions = Mapping[str, Any]
|
||||
|
||||
|
||||
def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
|
||||
"""
|
||||
Any arguments that are ultimately represented as bytes can be specified
|
||||
either as bytes or as strings.
|
||||
|
||||
However we enforce that any string arguments must only contain characters in
|
||||
the plain ASCII range. chr(0)...chr(127). If you need to use characters
|
||||
outside that range then be precise, and use a byte-wise argument.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return value.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
raise TypeError(f"{name} strings may not include unicode characters.")
|
||||
elif isinstance(value, bytes):
|
||||
return value
|
||||
|
||||
seen_type = type(value).__name__
|
||||
raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
|
||||
|
||||
|
||||
def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
|
||||
"""
|
||||
Type check for URL parameters.
|
||||
"""
|
||||
if isinstance(value, (bytes, str)):
|
||||
return URL(value)
|
||||
elif isinstance(value, URL):
|
||||
return value
|
||||
|
||||
seen_type = type(value).__name__
|
||||
raise TypeError(f"{name} must be a URL, bytes, or str, but got {seen_type}.")
|
||||
|
||||
|
||||
def enforce_headers(
|
||||
value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str
|
||||
) -> List[Tuple[bytes, bytes]]:
|
||||
"""
|
||||
Convienence function that ensure all items in request or response headers
|
||||
are either bytes or strings in the plain ASCII range.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
elif isinstance(value, Mapping):
|
||||
return [
|
||||
(
|
||||
enforce_bytes(k, name="header name"),
|
||||
enforce_bytes(v, name="header value"),
|
||||
)
|
||||
for k, v in value.items()
|
||||
]
|
||||
elif isinstance(value, Sequence):
|
||||
return [
|
||||
(
|
||||
enforce_bytes(k, name="header name"),
|
||||
enforce_bytes(v, name="header value"),
|
||||
)
|
||||
for k, v in value
|
||||
]
|
||||
|
||||
seen_type = type(value).__name__
|
||||
raise TypeError(
|
||||
f"{name} must be a mapping or sequence of two-tuples, but got {seen_type}."
|
||||
)
|
||||
|
||||
|
||||
def enforce_stream(
|
||||
value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str
|
||||
) -> Union[Iterable[bytes], AsyncIterable[bytes]]:
|
||||
if value is None:
|
||||
return ByteStream(b"")
|
||||
elif isinstance(value, bytes):
|
||||
return ByteStream(value)
|
||||
return value
|
||||
|
||||
|
||||
# * https://tools.ietf.org/html/rfc3986#section-3.2.3
|
||||
# * https://url.spec.whatwg.org/#url-miscellaneous
|
||||
# * https://url.spec.whatwg.org/#scheme-state
|
||||
DEFAULT_PORTS = {
|
||||
b"ftp": 21,
|
||||
b"http": 80,
|
||||
b"https": 443,
|
||||
b"ws": 80,
|
||||
b"wss": 443,
|
||||
}
|
||||
|
||||
|
||||
def include_request_headers(
|
||||
headers: List[Tuple[bytes, bytes]],
|
||||
*,
|
||||
url: "URL",
|
||||
content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]],
|
||||
) -> List[Tuple[bytes, bytes]]:
|
||||
headers_set = set(k.lower() for k, v in headers)
|
||||
|
||||
if b"host" not in headers_set:
|
||||
default_port = DEFAULT_PORTS.get(url.scheme)
|
||||
if url.port is None or url.port == default_port:
|
||||
header_value = url.host
|
||||
else:
|
||||
header_value = b"%b:%d" % (url.host, url.port)
|
||||
headers = [(b"Host", header_value)] + headers
|
||||
|
||||
if (
|
||||
content is not None
|
||||
and b"content-length" not in headers_set
|
||||
and b"transfer-encoding" not in headers_set
|
||||
):
|
||||
if isinstance(content, bytes):
|
||||
content_length = str(len(content)).encode("ascii")
|
||||
headers += [(b"Content-Length", content_length)]
|
||||
else:
|
||||
headers += [(b"Transfer-Encoding", b"chunked")] # pragma: nocover
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
# Interfaces for byte streams...
|
||||
|
||||
|
||||
class ByteStream:
|
||||
"""
|
||||
A container for non-streaming content, and that supports both sync and async
|
||||
stream iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, content: bytes) -> None:
|
||||
self._content = content
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
yield self._content
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
yield self._content
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{len(self._content)} bytes]>"
|
||||
|
||||
|
||||
class Origin:
|
||||
def __init__(self, scheme: bytes, host: bytes, port: int) -> None:
|
||||
self.scheme = scheme
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, Origin)
|
||||
and self.scheme == other.scheme
|
||||
and self.host == other.host
|
||||
and self.port == other.port
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
scheme = self.scheme.decode("ascii")
|
||||
host = self.host.decode("ascii")
|
||||
port = str(self.port)
|
||||
return f"{scheme}://{host}:{port}"
|
||||
|
||||
|
||||
class URL:
|
||||
"""
|
||||
Represents the URL against which an HTTP request may be made.
|
||||
|
||||
The URL may either be specified as a plain string, for convienence:
|
||||
|
||||
```python
|
||||
url = httpcore.URL("https://www.example.com/")
|
||||
```
|
||||
|
||||
Or be constructed with explicitily pre-parsed components:
|
||||
|
||||
```python
|
||||
url = httpcore.URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/')
|
||||
```
|
||||
|
||||
Using this second more explicit style allows integrations that are using
|
||||
`httpcore` to pass through URLs that have already been parsed in order to use
|
||||
libraries such as `rfc-3986` rather than relying on the stdlib. It also ensures
|
||||
that URL parsing is treated identically at both the networking level and at any
|
||||
higher layers of abstraction.
|
||||
|
||||
The four components are important here, as they allow the URL to be precisely
|
||||
specified in a pre-parsed format. They also allow certain types of request to
|
||||
be created that could not otherwise be expressed.
|
||||
|
||||
For example, an HTTP request to `http://www.example.com/` forwarded via a proxy
|
||||
at `http://localhost:8080`...
|
||||
|
||||
```python
|
||||
# Constructs an HTTP request with a complete URL as the target:
|
||||
# GET https://www.example.com/ HTTP/1.1
|
||||
url = httpcore.URL(
|
||||
scheme=b'http',
|
||||
host=b'localhost',
|
||||
port=8080,
|
||||
target=b'https://www.example.com/'
|
||||
)
|
||||
request = httpcore.Request(
|
||||
method="GET",
|
||||
url=url
|
||||
)
|
||||
```
|
||||
|
||||
Another example is constructing an `OPTIONS *` request...
|
||||
|
||||
```python
|
||||
# Constructs an 'OPTIONS *' HTTP request:
|
||||
# OPTIONS * HTTP/1.1
|
||||
url = httpcore.URL(scheme=b'https', host=b'www.example.com', target=b'*')
|
||||
request = httpcore.Request(method="OPTIONS", url=url)
|
||||
```
|
||||
|
||||
This kind of request is not possible to formulate with a URL string,
|
||||
because the `/` delimiter is always used to demark the target from the
|
||||
host/port portion of the URL.
|
||||
|
||||
For convenience, string-like arguments may be specified either as strings or
|
||||
as bytes. However, once a request is being issue over-the-wire, the URL
|
||||
components are always ultimately required to be a bytewise representation.
|
||||
|
||||
In order to avoid any ambiguity over character encodings, when strings are used
|
||||
as arguments, they must be strictly limited to the ASCII range `chr(0)`-`chr(127)`.
|
||||
If you require a bytewise representation that is outside this range you must
|
||||
handle the character encoding directly, and pass a bytes instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: Union[bytes, str] = "",
|
||||
*,
|
||||
scheme: Union[bytes, str] = b"",
|
||||
host: Union[bytes, str] = b"",
|
||||
port: Optional[int] = None,
|
||||
target: Union[bytes, str] = b"",
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
url: The complete URL as a string or bytes.
|
||||
scheme: The URL scheme as a string or bytes.
|
||||
Typically either `"http"` or `"https"`.
|
||||
host: The URL host as a string or bytes. Such as `"www.example.com"`.
|
||||
port: The port to connect to. Either an integer or `None`.
|
||||
target: The target of the HTTP request. Such as `"/items?search=red"`.
|
||||
"""
|
||||
if url:
|
||||
parsed = urlparse(enforce_bytes(url, name="url"))
|
||||
self.scheme = parsed.scheme
|
||||
self.host = parsed.hostname or b""
|
||||
self.port = parsed.port
|
||||
self.target = (parsed.path or b"/") + (
|
||||
b"?" + parsed.query if parsed.query else b""
|
||||
)
|
||||
else:
|
||||
self.scheme = enforce_bytes(scheme, name="scheme")
|
||||
self.host = enforce_bytes(host, name="host")
|
||||
self.port = port
|
||||
self.target = enforce_bytes(target, name="target")
|
||||
|
||||
@property
|
||||
def origin(self) -> Origin:
|
||||
default_port = {
|
||||
b"http": 80,
|
||||
b"https": 443,
|
||||
b"ws": 80,
|
||||
b"wss": 443,
|
||||
b"socks5": 1080,
|
||||
}[self.scheme]
|
||||
return Origin(
|
||||
scheme=self.scheme, host=self.host, port=self.port or default_port
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, URL)
|
||||
and other.scheme == self.scheme
|
||||
and other.host == self.host
|
||||
and other.port == self.port
|
||||
and other.target == self.target
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
if self.port is None:
|
||||
return b"%b://%b%b" % (self.scheme, self.host, self.target)
|
||||
return b"%b://%b:%d%b" % (self.scheme, self.host, self.port, self.target)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}(scheme={self.scheme!r}, "
|
||||
f"host={self.host!r}, port={self.port!r}, target={self.target!r})"
|
||||
)
|
||||
|
||||
|
||||
class Request:
|
||||
"""
|
||||
An HTTP request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
method: The HTTP request method, either as a string or bytes.
|
||||
For example: `GET`.
|
||||
url: The request URL, either as a `URL` instance, or as a string or bytes.
|
||||
For example: `"https://www.example.com".`
|
||||
headers: The HTTP request headers.
|
||||
content: The content of the response body.
|
||||
extensions: A dictionary of optional extra information included on
|
||||
the request. Possible keys include `"timeout"`, and `"trace"`.
|
||||
"""
|
||||
self.method: bytes = enforce_bytes(method, name="method")
|
||||
self.url: URL = enforce_url(url, name="url")
|
||||
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
|
||||
headers, name="headers"
|
||||
)
|
||||
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
|
||||
content, name="content"
|
||||
)
|
||||
self.extensions = {} if extensions is None else extensions
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.method!r}]>"
|
||||
|
||||
|
||||
class Response:
|
||||
"""
|
||||
An HTTP response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: int,
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
status: The HTTP status code of the response. For example `200`.
|
||||
headers: The HTTP response headers.
|
||||
content: The content of the response body.
|
||||
extensions: A dictionary of optional extra information included on
|
||||
the responseself.Possible keys include `"http_version"`,
|
||||
`"reason_phrase"`, and `"network_stream"`.
|
||||
"""
|
||||
self.status: int = status
|
||||
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
|
||||
headers, name="headers"
|
||||
)
|
||||
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
|
||||
content, name="content"
|
||||
)
|
||||
self.extensions = {} if extensions is None else extensions
|
||||
|
||||
self._stream_consumed = False
|
||||
|
||||
@property
|
||||
def content(self) -> bytes:
|
||||
if not hasattr(self, "_content"):
|
||||
if isinstance(self.stream, Iterable):
|
||||
raise RuntimeError(
|
||||
"Attempted to access 'response.content' on a streaming response. "
|
||||
"Call 'response.read()' first."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Attempted to access 'response.content' on a streaming response. "
|
||||
"Call 'await response.aread()' first."
|
||||
)
|
||||
return self._content
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.status}]>"
|
||||
|
||||
# Sync interface...
|
||||
|
||||
def read(self) -> bytes:
|
||||
if not isinstance(self.stream, Iterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to read an asynchronous response using 'response.read()'. "
|
||||
"You should use 'await response.aread()' instead."
|
||||
)
|
||||
if not hasattr(self, "_content"):
|
||||
self._content = b"".join([part for part in self.iter_stream()])
|
||||
return self._content
|
||||
|
||||
def iter_stream(self) -> Iterator[bytes]:
|
||||
if not isinstance(self.stream, Iterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to stream an asynchronous response using 'for ... in "
|
||||
"response.iter_stream()'. "
|
||||
"You should use 'async for ... in response.aiter_stream()' instead."
|
||||
)
|
||||
if self._stream_consumed:
|
||||
raise RuntimeError(
|
||||
"Attempted to call 'for ... in response.iter_stream()' more than once."
|
||||
)
|
||||
self._stream_consumed = True
|
||||
for chunk in self.stream:
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
if not isinstance(self.stream, Iterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to close an asynchronous response using 'response.close()'. "
|
||||
"You should use 'await response.aclose()' instead."
|
||||
)
|
||||
if hasattr(self.stream, "close"):
|
||||
self.stream.close()
|
||||
|
||||
# Async interface...
|
||||
|
||||
async def aread(self) -> bytes:
|
||||
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to read an synchronous response using "
|
||||
"'await response.aread()'. "
|
||||
"You should use 'response.read()' instead."
|
||||
)
|
||||
if not hasattr(self, "_content"):
|
||||
self._content = b"".join([part async for part in self.aiter_stream()])
|
||||
return self._content
|
||||
|
||||
async def aiter_stream(self) -> AsyncIterator[bytes]:
|
||||
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to stream an synchronous response using 'async for ... in "
|
||||
"response.aiter_stream()'. "
|
||||
"You should use 'for ... in response.iter_stream()' instead."
|
||||
)
|
||||
if self._stream_consumed:
|
||||
raise RuntimeError(
|
||||
"Attempted to call 'async for ... in response.aiter_stream()' "
|
||||
"more than once."
|
||||
)
|
||||
self._stream_consumed = True
|
||||
async for chunk in self.stream:
|
||||
yield chunk
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Attempted to close a synchronous response using "
|
||||
"'await response.aclose()'. "
|
||||
"You should use 'response.close()' instead."
|
||||
)
|
||||
if hasattr(self.stream, "aclose"):
|
||||
await self.stream.aclose()
|
|
@ -0,0 +1,8 @@
|
|||
import ssl
|
||||
import certifi
|
||||
|
||||
|
||||
def default_ssl_context() -> ssl.SSLContext:
|
||||
context = ssl.create_default_context()
|
||||
context.load_verify_locations(certifi.where())
|
||||
return context
|
|
@ -0,0 +1,39 @@
|
|||
from .connection import HTTPConnection
|
||||
from .connection_pool import ConnectionPool
|
||||
from .http11 import HTTP11Connection
|
||||
from .http_proxy import HTTPProxy
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
try:
|
||||
from .http2 import HTTP2Connection
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class HTTP2Connection: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use http2 support, but the `h2` package is not "
|
||||
"installed. Use 'pip install httpcore[http2]'."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from .socks_proxy import SOCKSProxy
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class SOCKSProxy: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use SOCKS support, but the `socksio` package is not "
|
||||
"installed. Use 'pip install httpcore[socks]'."
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HTTPConnection",
|
||||
"ConnectionPool",
|
||||
"HTTPProxy",
|
||||
"HTTP11Connection",
|
||||
"HTTP2Connection",
|
||||
"ConnectionInterface",
|
||||
"SOCKSProxy",
|
||||
]
|
|
@ -0,0 +1,215 @@
|
|||
import itertools
|
||||
import logging
|
||||
import ssl
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Iterator, Optional, Type
|
||||
|
||||
from .._backends.sync import SyncBackend
|
||||
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
|
||||
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
|
||||
from .._models import Origin, Request, Response
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import Lock
|
||||
from .._trace import Trace
|
||||
from .http11 import HTTP11Connection
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.connection")
|
||||
|
||||
|
||||
def exponential_backoff(factor: float) -> Iterator[float]:
|
||||
yield 0
|
||||
for n in itertools.count(2):
|
||||
yield factor * (2 ** (n - 2))
|
||||
|
||||
|
||||
class HTTPConnection(ConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[NetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._network_backend: NetworkBackend = (
|
||||
SyncBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connection: Optional[ConnectionInterface] = None
|
||||
self._connect_failed: bool = False
|
||||
self._request_lock = Lock()
|
||||
self._socket_options = socket_options
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
|
||||
)
|
||||
|
||||
with self._request_lock:
|
||||
if self._connection is None:
|
||||
try:
|
||||
stream = self._connect(request)
|
||||
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import HTTP2Connection
|
||||
|
||||
self._connection = HTTP2Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = HTTP11Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
elif not self._connection.is_available():
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
return self._connection.handle_request(request)
|
||||
|
||||
def _connect(self, request: Request) -> NetworkStream:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
sni_hostname = request.extensions.get("sni_hostname", None)
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
retries_left = self._retries
|
||||
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self._uds is None:
|
||||
kwargs = {
|
||||
"host": self._origin.host.decode("ascii"),
|
||||
"port": self._origin.port,
|
||||
"local_address": self._local_address,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
else:
|
||||
kwargs = {
|
||||
"path": self._uds,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
with Trace(
|
||||
"connect_unix_socket", logger, request, kwargs
|
||||
) as trace:
|
||||
stream = self._network_backend.connect_unix_socket(
|
||||
**kwargs
|
||||
)
|
||||
trace.return_value = stream
|
||||
|
||||
if self._origin.scheme == b"https":
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": sni_hostname
|
||||
or self._origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
return stream
|
||||
except (ConnectError, ConnectTimeout):
|
||||
if retries_left <= 0:
|
||||
raise
|
||||
retries_left -= 1
|
||||
delay = next(delays)
|
||||
with Trace("retry", logger, request, kwargs) as trace:
|
||||
self._network_backend.sleep(delay)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def close(self) -> None:
|
||||
if self._connection is not None:
|
||||
with Trace("close", logger, None, {}):
|
||||
self._connection.close()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None:
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None:
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
def __enter__(self) -> "HTTPConnection":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
self.close()
|
|
@ -0,0 +1,356 @@
|
|||
import ssl
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Iterable, Iterator, Iterable, List, Optional, Type
|
||||
|
||||
from .._backends.sync import SyncBackend
|
||||
from .._backends.base import SOCKET_OPTION, NetworkBackend
|
||||
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import Event, Lock, ShieldCancellation
|
||||
from .connection import HTTPConnection
|
||||
from .interfaces import ConnectionInterface, RequestInterface
|
||||
|
||||
|
||||
class RequestStatus:
|
||||
def __init__(self, request: Request):
|
||||
self.request = request
|
||||
self.connection: Optional[ConnectionInterface] = None
|
||||
self._connection_acquired = Event()
|
||||
|
||||
def set_connection(self, connection: ConnectionInterface) -> None:
|
||||
assert self.connection is None
|
||||
self.connection = connection
|
||||
self._connection_acquired.set()
|
||||
|
||||
def unset_connection(self) -> None:
|
||||
assert self.connection is not None
|
||||
self.connection = None
|
||||
self._connection_acquired = Event()
|
||||
|
||||
def wait_for_connection(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> ConnectionInterface:
|
||||
if self.connection is None:
|
||||
self._connection_acquired.wait(timeout=timeout)
|
||||
assert self.connection is not None
|
||||
return self.connection
|
||||
|
||||
|
||||
class ConnectionPool(RequestInterface):
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
max_connections: Optional[int] = 10,
|
||||
max_keepalive_connections: Optional[int] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[NetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish a
|
||||
connection.
|
||||
local_address: Local address to connect from. Can also be used to connect
|
||||
using a particular address family. Using `local_address="0.0.0.0"`
|
||||
will connect using an `AF_INET` address (IPv4), while using
|
||||
`local_address="::"` will connect using an `AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
socket_options: Socket options that have to be included
|
||||
in the TCP socket when the connection was established.
|
||||
"""
|
||||
self._ssl_context = ssl_context
|
||||
|
||||
self._max_connections = (
|
||||
sys.maxsize if max_connections is None else max_connections
|
||||
)
|
||||
self._max_keepalive_connections = (
|
||||
sys.maxsize
|
||||
if max_keepalive_connections is None
|
||||
else max_keepalive_connections
|
||||
)
|
||||
self._max_keepalive_connections = min(
|
||||
self._max_connections, self._max_keepalive_connections
|
||||
)
|
||||
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._pool: List[ConnectionInterface] = []
|
||||
self._requests: List[RequestStatus] = []
|
||||
self._pool_lock = Lock()
|
||||
self._network_backend = (
|
||||
SyncBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._socket_options = socket_options
|
||||
|
||||
def create_connection(self, origin: Origin) -> ConnectionInterface:
|
||||
return HTTPConnection(
|
||||
origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
retries=self._retries,
|
||||
local_address=self._local_address,
|
||||
uds=self._uds,
|
||||
network_backend=self._network_backend,
|
||||
socket_options=self._socket_options,
|
||||
)
|
||||
|
||||
@property
|
||||
def connections(self) -> List[ConnectionInterface]:
|
||||
"""
|
||||
Return a list of the connections currently in the pool.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
>>> pool.connections
|
||||
[
|
||||
<HTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
|
||||
<HTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> ,
|
||||
<HTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
|
||||
]
|
||||
```
|
||||
"""
|
||||
return list(self._pool)
|
||||
|
||||
def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
|
||||
"""
|
||||
Attempt to provide a connection that can handle the given origin.
|
||||
"""
|
||||
origin = status.request.url.origin
|
||||
|
||||
# If there are queued requests in front of us, then don't acquire a
|
||||
# connection. We handle requests strictly in order.
|
||||
waiting = [s for s in self._requests if s.connection is None]
|
||||
if waiting and waiting[0] is not status:
|
||||
return False
|
||||
|
||||
# Reuse an existing connection if one is currently available.
|
||||
for idx, connection in enumerate(self._pool):
|
||||
if connection.can_handle_request(origin) and connection.is_available():
|
||||
self._pool.pop(idx)
|
||||
self._pool.insert(0, connection)
|
||||
status.set_connection(connection)
|
||||
return True
|
||||
|
||||
# If the pool is currently full, attempt to close one idle connection.
|
||||
if len(self._pool) >= self._max_connections:
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.is_idle():
|
||||
connection.close()
|
||||
self._pool.pop(idx)
|
||||
break
|
||||
|
||||
# If the pool is still full, then we cannot acquire a connection.
|
||||
if len(self._pool) >= self._max_connections:
|
||||
return False
|
||||
|
||||
# Otherwise create a new connection.
|
||||
connection = self.create_connection(origin)
|
||||
self._pool.insert(0, connection)
|
||||
status.set_connection(connection)
|
||||
return True
|
||||
|
||||
def _close_expired_connections(self) -> None:
|
||||
"""
|
||||
Clean up the connection pool by closing off any connections that have expired.
|
||||
"""
|
||||
# Close any connections that have expired their keep-alive time.
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.has_expired():
|
||||
connection.close()
|
||||
self._pool.pop(idx)
|
||||
|
||||
# If the pool size exceeds the maximum number of allowed keep-alive connections,
|
||||
# then close off idle connections as required.
|
||||
pool_size = len(self._pool)
|
||||
for idx, connection in reversed(list(enumerate(self._pool))):
|
||||
if connection.is_idle() and pool_size > self._max_keepalive_connections:
|
||||
connection.close()
|
||||
self._pool.pop(idx)
|
||||
pool_size -= 1
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
"""
|
||||
Send an HTTP request, and return an HTTP response.
|
||||
|
||||
This is the core implementation that is called into by `.request()` or `.stream()`.
|
||||
"""
|
||||
scheme = request.url.scheme.decode()
|
||||
if scheme == "":
|
||||
raise UnsupportedProtocol(
|
||||
"Request URL is missing an 'http://' or 'https://' protocol."
|
||||
)
|
||||
if scheme not in ("http", "https", "ws", "wss"):
|
||||
raise UnsupportedProtocol(
|
||||
f"Request URL has an unsupported protocol '{scheme}://'."
|
||||
)
|
||||
|
||||
status = RequestStatus(request)
|
||||
|
||||
with self._pool_lock:
|
||||
self._requests.append(status)
|
||||
self._close_expired_connections()
|
||||
self._attempt_to_acquire_connection(status)
|
||||
|
||||
while True:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("pool", None)
|
||||
try:
|
||||
connection = status.wait_for_connection(timeout=timeout)
|
||||
except BaseException as exc:
|
||||
# If we timeout here, or if the task is cancelled, then make
|
||||
# sure to remove the request from the queue before bubbling
|
||||
# up the exception.
|
||||
with self._pool_lock:
|
||||
# Ensure only remove when task exists.
|
||||
if status in self._requests:
|
||||
self._requests.remove(status)
|
||||
raise exc
|
||||
|
||||
try:
|
||||
response = connection.handle_request(request)
|
||||
except ConnectionNotAvailable:
|
||||
# The ConnectionNotAvailable exception is a special case, that
|
||||
# indicates we need to retry the request on a new connection.
|
||||
#
|
||||
# The most common case where this can occur is when multiple
|
||||
# requests are queued waiting for a single connection, which
|
||||
# might end up as an HTTP/2 connection, but which actually ends
|
||||
# up as HTTP/1.1.
|
||||
with self._pool_lock:
|
||||
# Maintain our position in the request queue, but reset the
|
||||
# status so that the request becomes queued again.
|
||||
status.unset_connection()
|
||||
self._attempt_to_acquire_connection(status)
|
||||
except BaseException as exc:
|
||||
with ShieldCancellation():
|
||||
self.response_closed(status)
|
||||
raise exc
|
||||
else:
|
||||
break
|
||||
|
||||
# When we return the response, we wrap the stream in a special class
|
||||
# that handles notifying the connection pool once the response
|
||||
# has been released.
|
||||
assert isinstance(response.stream, Iterable)
|
||||
return Response(
|
||||
status=response.status,
|
||||
headers=response.headers,
|
||||
content=ConnectionPoolByteStream(response.stream, self, status),
|
||||
extensions=response.extensions,
|
||||
)
|
||||
|
||||
def response_closed(self, status: RequestStatus) -> None:
|
||||
"""
|
||||
This method acts as a callback once the request/response cycle is complete.
|
||||
|
||||
It is called into from the `ConnectionPoolByteStream.close()` method.
|
||||
"""
|
||||
assert status.connection is not None
|
||||
connection = status.connection
|
||||
|
||||
with self._pool_lock:
|
||||
# Update the state of the connection pool.
|
||||
if status in self._requests:
|
||||
self._requests.remove(status)
|
||||
|
||||
if connection.is_closed() and connection in self._pool:
|
||||
self._pool.remove(connection)
|
||||
|
||||
# Since we've had a response closed, it's possible we'll now be able
|
||||
# to service one or more requests that are currently pending.
|
||||
for status in self._requests:
|
||||
if status.connection is None:
|
||||
acquired = self._attempt_to_acquire_connection(status)
|
||||
# If we could not acquire a connection for a queued request
|
||||
# then we don't need to check anymore requests that are
|
||||
# queued later behind it.
|
||||
if not acquired:
|
||||
break
|
||||
|
||||
# Housekeeping.
|
||||
self._close_expired_connections()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close any connections in the pool.
|
||||
"""
|
||||
with self._pool_lock:
|
||||
for connection in self._pool:
|
||||
connection.close()
|
||||
self._pool = []
|
||||
self._requests = []
|
||||
|
||||
def __enter__(self) -> "ConnectionPool":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class ConnectionPoolByteStream:
|
||||
"""
|
||||
A wrapper around the response byte stream, that additionally handles
|
||||
notifying the connection pool when the response has been closed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: Iterable[bytes],
|
||||
pool: ConnectionPool,
|
||||
status: RequestStatus,
|
||||
) -> None:
|
||||
self._stream = stream
|
||||
self._pool = pool
|
||||
self._status = status
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
for part in self._stream:
|
||||
yield part
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
if hasattr(self._stream, "close"):
|
||||
self._stream.close()
|
||||
finally:
|
||||
with ShieldCancellation():
|
||||
self._pool.response_closed(self._status)
|
|
@ -0,0 +1,331 @@
|
|||
import enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import h11
|
||||
|
||||
from .._backends.base import NetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
map_exceptions,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import Lock, ShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http11")
|
||||
|
||||
|
||||
# A subset of `h11.Event` types supported by `_send_event`
|
||||
H11SendEvent = Union[
|
||||
h11.Request,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
]
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
NEW = 0
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class HTTP11Connection(ConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: NetworkStream,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: Optional[float] = keepalive_expiry
|
||||
self._expire_at: Optional[float] = None
|
||||
self._state = HTTPConnectionState.NEW
|
||||
self._state_lock = Lock()
|
||||
self._request_count = 0
|
||||
self._h11_state = h11.Connection(
|
||||
our_role=h11.CLIENT,
|
||||
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
self._expire_at = None
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request}
|
||||
with Trace("send_request_headers", logger, request, kwargs) as trace:
|
||||
self._send_request_headers(**kwargs)
|
||||
with Trace("send_request_body", logger, request, kwargs) as trace:
|
||||
self._send_request_body(**kwargs)
|
||||
with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
(
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
) = self._receive_response_headers(**kwargs)
|
||||
trace.return_value = (
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP11ConnectionByteStream(self, request),
|
||||
extensions={
|
||||
"http_version": http_version,
|
||||
"reason_phrase": reason_phrase,
|
||||
"network_stream": self._network_stream,
|
||||
},
|
||||
)
|
||||
except BaseException as exc:
|
||||
with ShieldCancellation():
|
||||
with Trace("response_closed", logger, request) as trace:
|
||||
self._response_closed()
|
||||
raise exc
|
||||
|
||||
# Sending the request...
|
||||
|
||||
def _send_request_headers(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
|
||||
event = h11.Request(
|
||||
method=request.method,
|
||||
target=request.url.target,
|
||||
headers=request.headers,
|
||||
)
|
||||
self._send_event(event, timeout=timeout)
|
||||
|
||||
def _send_request_body(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
assert isinstance(request.stream, Iterable)
|
||||
for chunk in request.stream:
|
||||
event = h11.Data(data=chunk)
|
||||
self._send_event(event, timeout=timeout)
|
||||
|
||||
self._send_event(h11.EndOfMessage(), timeout=timeout)
|
||||
|
||||
def _send_event(
|
||||
self, event: h11.Event, timeout: Optional[float] = None
|
||||
) -> None:
|
||||
bytes_to_send = self._h11_state.send(event)
|
||||
if bytes_to_send is not None:
|
||||
self._network_stream.write(bytes_to_send, timeout=timeout)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
def _receive_response_headers(
|
||||
self, request: Request
|
||||
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Response):
|
||||
break
|
||||
if (
|
||||
isinstance(event, h11.InformationalResponse)
|
||||
and event.status_code == 101
|
||||
):
|
||||
break
|
||||
|
||||
http_version = b"HTTP/" + event.http_version
|
||||
|
||||
# h11 version 0.11+ supports a `raw_items` interface to get the
|
||||
# raw header casing, rather than the enforced lowercase headers.
|
||||
headers = event.headers.raw_items()
|
||||
|
||||
return http_version, event.status_code, event.reason, headers
|
||||
|
||||
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Data):
|
||||
yield bytes(event.data)
|
||||
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
|
||||
break
|
||||
|
||||
def _receive_event(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> Union[h11.Event, Type[h11.PAUSED]]:
|
||||
while True:
|
||||
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
|
||||
event = self._h11_state.next_event()
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
data = self._network_stream.read(
|
||||
self.READ_NUM_BYTES, timeout=timeout
|
||||
)
|
||||
|
||||
# If we feed this case through h11 we'll raise an exception like:
|
||||
#
|
||||
# httpcore.RemoteProtocolError: can't handle event type
|
||||
# ConnectionClosed when role=SERVER and state=SEND_RESPONSE
|
||||
#
|
||||
# Which is accurate, but not very informative from an end-user
|
||||
# perspective. Instead we handle this case distinctly and treat
|
||||
# it as a ConnectError.
|
||||
if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
|
||||
msg = "Server disconnected without sending a response."
|
||||
raise RemoteProtocolError(msg)
|
||||
|
||||
self._h11_state.receive_data(data)
|
||||
else:
|
||||
# mypy fails to narrow the type in the above if statement above
|
||||
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
|
||||
|
||||
def _response_closed(self) -> None:
|
||||
with self._state_lock:
|
||||
if (
|
||||
self._h11_state.our_state is h11.DONE
|
||||
and self._h11_state.their_state is h11.DONE
|
||||
):
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._h11_state.start_next_cycle()
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
else:
|
||||
self.close()
|
||||
|
||||
# Once the connection is no longer required...
|
||||
|
||||
def close(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
self._network_stream.close()
|
||||
|
||||
# The ConnectionInterface methods provide information about the state of
|
||||
# the connection, allowing for a connection pooling implementation to
|
||||
# determine when to reuse and when to close the connection...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# Note that HTTP/1.1 connections in the "NEW" state are not treated as
|
||||
# being "available". The control flow which created the connection will
|
||||
# be able to send an outgoing request, but the connection will not be
|
||||
# acquired from the connection pool for any other request.
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
keepalive_expired = self._expire_at is not None and now > self._expire_at
|
||||
|
||||
# If the HTTP connection is idle but the socket is readable, then the
|
||||
# only valid state is that the socket is about to return b"", indicating
|
||||
# a server-initiated disconnect.
|
||||
server_disconnected = (
|
||||
self._state == HTTPConnectionState.IDLE
|
||||
and self._network_stream.get_extra_info("is_readable")
|
||||
)
|
||||
|
||||
return keepalive_expired or server_disconnected
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/1.1, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
def __enter__(self) -> "HTTP11Connection":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class HTTP11ConnectionByteStream:
|
||||
def __init__(self, connection: HTTP11Connection, request: Request) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._closed = False
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
kwargs = {"request": self._request}
|
||||
try:
|
||||
with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
for chunk in self._connection._receive_response_body(**kwargs):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with ShieldCancellation():
|
||||
self.close()
|
||||
raise exc
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
with Trace("response_closed", logger, self._request):
|
||||
self._connection._response_closed()
|
|
@ -0,0 +1,589 @@
|
|||
import enum
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
import typing
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import h2.settings
|
||||
|
||||
from .._backends.base import NetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import Lock, Semaphore, ShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http2")
|
||||
|
||||
|
||||
def has_body_headers(request: Request) -> bool:
|
||||
return any(
|
||||
k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
|
||||
for k, v in request.headers
|
||||
)
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class HTTP2Connection(ConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: NetworkStream,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
):
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
|
||||
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._expire_at: typing.Optional[float] = None
|
||||
self._request_count = 0
|
||||
self._init_lock = Lock()
|
||||
self._state_lock = Lock()
|
||||
self._read_lock = Lock()
|
||||
self._write_lock = Lock()
|
||||
self._sent_connection_init = False
|
||||
self._used_all_stream_ids = False
|
||||
self._connection_error = False
|
||||
|
||||
# Mapping from stream ID to response stream events.
|
||||
self._events: typing.Dict[
|
||||
int,
|
||||
typing.Union[
|
||||
h2.events.ResponseReceived,
|
||||
h2.events.DataReceived,
|
||||
h2.events.StreamEnded,
|
||||
h2.events.StreamReset,
|
||||
],
|
||||
] = {}
|
||||
|
||||
# Connection terminated events are stored as state since
|
||||
# we need to handle them for all streams.
|
||||
self._connection_terminated: typing.Optional[
|
||||
h2.events.ConnectionTerminated
|
||||
] = None
|
||||
|
||||
self._read_exception: typing.Optional[Exception] = None
|
||||
self._write_exception: typing.Optional[Exception] = None
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
# This cannot occur in normal operation, since the connection pool
|
||||
# will only send requests on connections that handle them.
|
||||
# It's in place simply for resilience as a guard against incorrect
|
||||
# usage, for anyone working directly with httpcore connections.
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._expire_at = None
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
with self._init_lock:
|
||||
if not self._sent_connection_init:
|
||||
try:
|
||||
kwargs = {"request": request}
|
||||
with Trace("send_connection_init", logger, request, kwargs):
|
||||
self._send_connection_init(**kwargs)
|
||||
except BaseException as exc:
|
||||
with ShieldCancellation():
|
||||
self.close()
|
||||
raise exc
|
||||
|
||||
self._sent_connection_init = True
|
||||
|
||||
# Initially start with just 1 until the remote server provides
|
||||
# its max_concurrent_streams value
|
||||
self._max_streams = 1
|
||||
|
||||
local_settings_max_streams = (
|
||||
self._h2_state.local_settings.max_concurrent_streams
|
||||
)
|
||||
self._max_streams_semaphore = Semaphore(local_settings_max_streams)
|
||||
|
||||
for _ in range(local_settings_max_streams - self._max_streams):
|
||||
self._max_streams_semaphore.acquire()
|
||||
|
||||
self._max_streams_semaphore.acquire()
|
||||
|
||||
try:
|
||||
stream_id = self._h2_state.get_next_available_stream_id()
|
||||
self._events[stream_id] = []
|
||||
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
|
||||
self._used_all_stream_ids = True
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request, "stream_id": stream_id}
|
||||
with Trace("send_request_headers", logger, request, kwargs):
|
||||
self._send_request_headers(request=request, stream_id=stream_id)
|
||||
with Trace("send_request_body", logger, request, kwargs):
|
||||
self._send_request_body(request=request, stream_id=stream_id)
|
||||
with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
status, headers = self._receive_response(
|
||||
request=request, stream_id=stream_id
|
||||
)
|
||||
trace.return_value = (status, headers)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
|
||||
extensions={
|
||||
"http_version": b"HTTP/2",
|
||||
"network_stream": self._network_stream,
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
except BaseException as exc: # noqa: PIE786
|
||||
with ShieldCancellation():
|
||||
kwargs = {"stream_id": stream_id}
|
||||
with Trace("response_closed", logger, request, kwargs):
|
||||
self._response_closed(stream_id=stream_id)
|
||||
|
||||
if isinstance(exc, h2.exceptions.ProtocolError):
|
||||
# One case where h2 can raise a protocol error is when a
|
||||
# closed frame has been seen by the state machine.
|
||||
#
|
||||
# This happens when one stream is reading, and encounters
|
||||
# a GOAWAY event. Other flows of control may then raise
|
||||
# a protocol error at any point they interact with the 'h2_state'.
|
||||
#
|
||||
# In this case we'll have stored the event, and should raise
|
||||
# it as a RemoteProtocolError.
|
||||
if self._connection_terminated: # pragma: nocover
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
# If h2 raises a protocol error in some other state then we
|
||||
# must somehow have made a protocol violation.
|
||||
raise LocalProtocolError(exc) # pragma: nocover
|
||||
|
||||
raise exc
|
||||
|
||||
def _send_connection_init(self, request: Request) -> None:
|
||||
"""
|
||||
The HTTP/2 connection requires some initial setup before we can start
|
||||
using individual request/response streams on it.
|
||||
"""
|
||||
# Need to set these manually here instead of manipulating via
|
||||
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
|
||||
# frames in addition to sending the undesired defaults.
|
||||
self._h2_state.local_settings = h2.settings.Settings(
|
||||
client=True,
|
||||
initial_values={
|
||||
# Disable PUSH_PROMISE frames from the server since we don't do anything
|
||||
# with them for now. Maybe when we support caching?
|
||||
h2.settings.SettingCodes.ENABLE_PUSH: 0,
|
||||
# These two are taken from h2 for safe defaults
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
|
||||
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
|
||||
},
|
||||
)
|
||||
|
||||
# Some websites (*cough* Yahoo *cough*) balk at this setting being
|
||||
# present in the initial handshake since it's not defined in the original
|
||||
# RFC despite the RFC mandating ignoring settings you don't know about.
|
||||
del self._h2_state.local_settings[
|
||||
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
|
||||
]
|
||||
|
||||
self._h2_state.initiate_connection()
|
||||
self._h2_state.increment_flow_control_window(2**24)
|
||||
self._write_outgoing_data(request)
|
||||
|
||||
# Sending the request...
|
||||
|
||||
def _send_request_headers(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send the request headers to a given stream ID.
|
||||
"""
|
||||
end_stream = not has_body_headers(request)
|
||||
|
||||
# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
|
||||
# In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
|
||||
# HTTP/1.1 style headers, and map them appropriately if we end up on
|
||||
# an HTTP/2 connection.
|
||||
authority = [v for k, v in request.headers if k.lower() == b"host"][0]
|
||||
|
||||
headers = [
|
||||
(b":method", request.method),
|
||||
(b":authority", authority),
|
||||
(b":scheme", request.url.scheme),
|
||||
(b":path", request.url.target),
|
||||
] + [
|
||||
(k.lower(), v)
|
||||
for k, v in request.headers
|
||||
if k.lower()
|
||||
not in (
|
||||
b"host",
|
||||
b"transfer-encoding",
|
||||
)
|
||||
]
|
||||
|
||||
self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
|
||||
self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
|
||||
self._write_outgoing_data(request)
|
||||
|
||||
def _send_request_body(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Iterate over the request body sending it to a given stream ID.
|
||||
"""
|
||||
if not has_body_headers(request):
|
||||
return
|
||||
|
||||
assert isinstance(request.stream, typing.Iterable)
|
||||
for data in request.stream:
|
||||
self._send_stream_data(request, stream_id, data)
|
||||
self._send_end_stream(request, stream_id)
|
||||
|
||||
def _send_stream_data(
|
||||
self, request: Request, stream_id: int, data: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Send a single chunk of data in one or more data frames.
|
||||
"""
|
||||
while data:
|
||||
max_flow = self._wait_for_outgoing_flow(request, stream_id)
|
||||
chunk_size = min(len(data), max_flow)
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
self._h2_state.send_data(stream_id, chunk)
|
||||
self._write_outgoing_data(request)
|
||||
|
||||
def _send_end_stream(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
|
||||
"""
|
||||
self._h2_state.end_stream(stream_id)
|
||||
self._write_outgoing_data(request)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
def _receive_response(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Return the response status code and headers for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode("ascii", errors="ignore"))
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
def _receive_response_body(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.Iterator[bytes]:
|
||||
"""
|
||||
Iterator that returns the bytes of the response body for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
amount = event.flow_controlled_length
|
||||
self._h2_state.acknowledge_received_data(amount, stream_id)
|
||||
self._write_outgoing_data(request)
|
||||
yield event.data
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
break
|
||||
|
||||
def _receive_stream_event(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.Union[
|
||||
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
|
||||
]:
|
||||
"""
|
||||
Return the next available event for a given stream ID.
|
||||
|
||||
Will read more data from the network if required.
|
||||
"""
|
||||
while not self._events.get(stream_id):
|
||||
self._receive_events(request, stream_id)
|
||||
event = self._events[stream_id].pop(0)
|
||||
if isinstance(event, h2.events.StreamReset):
|
||||
raise RemoteProtocolError(event)
|
||||
return event
|
||||
|
||||
def _receive_events(
|
||||
self, request: Request, stream_id: typing.Optional[int] = None
|
||||
) -> None:
|
||||
"""
|
||||
Read some data from the network until we see one or more events
|
||||
for a given stream ID.
|
||||
"""
|
||||
with self._read_lock:
|
||||
if self._connection_terminated is not None:
|
||||
last_stream_id = self._connection_terminated.last_stream_id
|
||||
if stream_id and last_stream_id and stream_id > last_stream_id:
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
|
||||
# This conditional is a bit icky. We don't want to block reading if we've
|
||||
# actually got an event to return for a given stream. We need to do that
|
||||
# check *within* the atomic read lock. Though it also need to be optional,
|
||||
# because when we call it from `_wait_for_outgoing_flow` we *do* want to
|
||||
# block until we've available flow control, event when we have events
|
||||
# pending for the stream ID we're attempting to send on.
|
||||
if stream_id is None or not self._events.get(stream_id):
|
||||
events = self._read_incoming_data(request)
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
with Trace(
|
||||
"receive_remote_settings", logger, request
|
||||
) as trace:
|
||||
self._receive_remote_settings_change(event)
|
||||
trace.return_value = event
|
||||
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
h2.events.ResponseReceived,
|
||||
h2.events.DataReceived,
|
||||
h2.events.StreamEnded,
|
||||
h2.events.StreamReset,
|
||||
),
|
||||
):
|
||||
if event.stream_id in self._events:
|
||||
self._events[event.stream_id].append(event)
|
||||
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
self._connection_terminated = event
|
||||
|
||||
self._write_outgoing_data(request)
|
||||
|
||||
def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
|
||||
max_concurrent_streams = event.changed_settings.get(
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
|
||||
)
|
||||
if max_concurrent_streams:
|
||||
new_max_streams = min(
|
||||
max_concurrent_streams.new_value,
|
||||
self._h2_state.local_settings.max_concurrent_streams,
|
||||
)
|
||||
if new_max_streams and new_max_streams != self._max_streams:
|
||||
while new_max_streams > self._max_streams:
|
||||
self._max_streams_semaphore.release()
|
||||
self._max_streams += 1
|
||||
while new_max_streams < self._max_streams:
|
||||
self._max_streams_semaphore.acquire()
|
||||
self._max_streams -= 1
|
||||
|
||||
def _response_closed(self, stream_id: int) -> None:
|
||||
self._max_streams_semaphore.release()
|
||||
del self._events[stream_id]
|
||||
with self._state_lock:
|
||||
if self._connection_terminated and not self._events:
|
||||
self.close()
|
||||
|
||||
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
if self._used_all_stream_ids: # pragma: nocover
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._h2_state.close_connection()
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
self._network_stream.close()
|
||||
|
||||
# Wrappers around network read/write operations...
|
||||
|
||||
def _read_incoming_data(
|
||||
self, request: Request
|
||||
) -> typing.List[h2.events.Event]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
if self._read_exception is not None:
|
||||
raise self._read_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
data = self._network_stream.read(self.READ_NUM_BYTES, timeout)
|
||||
if data == b"":
|
||||
raise RemoteProtocolError("Server disconnected")
|
||||
except Exception as exc:
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future reads.
|
||||
# (For example, this means that a single read timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._read_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
|
||||
|
||||
return events
|
||||
|
||||
def _write_outgoing_data(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
with self._write_lock:
|
||||
data_to_send = self._h2_state.data_to_send()
|
||||
|
||||
if self._write_exception is not None:
|
||||
raise self._write_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
self._network_stream.write(data_to_send, timeout)
|
||||
except Exception as exc: # pragma: nocover
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future write.
|
||||
# (For example, this means that a single write timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._write_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
# Flow control...
|
||||
|
||||
def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
|
||||
"""
|
||||
Returns the maximum allowable outgoing flow for a given stream.
|
||||
|
||||
If the allowable flow is zero, then waits on the network until
|
||||
WindowUpdated frames have increased the flow rate.
|
||||
https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
"""
|
||||
local_flow: int = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size: int = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
while flow == 0:
|
||||
self._receive_events(request)
|
||||
local_flow = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
return flow
|
||||
|
||||
# Interface for connection pooling...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return (
|
||||
self._state != HTTPConnectionState.CLOSED
|
||||
and not self._connection_error
|
||||
and not self._used_all_stream_ids
|
||||
and not (
|
||||
self._h2_state.state_machine.state
|
||||
== h2.connection.ConnectionState.CLOSED
|
||||
)
|
||||
)
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
return self._expire_at is not None and now > self._expire_at
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/2, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
def __enter__(self) -> "HTTP2Connection":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]] = None,
|
||||
exc_value: typing.Optional[BaseException] = None,
|
||||
traceback: typing.Optional[types.TracebackType] = None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class HTTP2ConnectionByteStream:
|
||||
def __init__(
|
||||
self, connection: HTTP2Connection, request: Request, stream_id: int
|
||||
) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._stream_id = stream_id
|
||||
self._closed = False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
kwargs = {"request": self._request, "stream_id": self._stream_id}
|
||||
try:
|
||||
with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
for chunk in self._connection._receive_response_body(
|
||||
request=self._request, stream_id=self._stream_id
|
||||
):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with ShieldCancellation():
|
||||
self.close()
|
||||
raise exc
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
kwargs = {"stream_id": self._stream_id}
|
||||
with Trace("response_closed", logger, self._request, kwargs):
|
||||
self._connection._response_closed(stream_id=self._stream_id)
|
|
@ -0,0 +1,350 @@
|
|||
import logging
|
||||
import ssl
|
||||
from base64 import b64encode
|
||||
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
from .._backends.base import SOCKET_OPTION, NetworkBackend
|
||||
from .._exceptions import ProxyError
|
||||
from .._models import (
|
||||
URL,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
)
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import Lock
|
||||
from .._trace import Trace
|
||||
from .connection import HTTPConnection
|
||||
from .connection_pool import ConnectionPool
|
||||
from .http11 import HTTP11Connection
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.proxy")
|
||||
|
||||
|
||||
def merge_headers(
|
||||
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
) -> List[Tuple[bytes, bytes]]:
|
||||
"""
|
||||
Append default_headers and override_headers, de-duplicating if a key exists
|
||||
in both cases.
|
||||
"""
|
||||
default_headers = [] if default_headers is None else list(default_headers)
|
||||
override_headers = [] if override_headers is None else list(override_headers)
|
||||
has_override = set(key.lower() for key, value in override_headers)
|
||||
default_headers = [
|
||||
(key, value)
|
||||
for key, value in default_headers
|
||||
if key.lower() not in has_override
|
||||
]
|
||||
return default_headers + override_headers
|
||||
|
||||
|
||||
def build_auth_header(username: bytes, password: bytes) -> bytes:
|
||||
userpass = username + b":" + password
|
||||
return b"Basic " + b64encode(userpass)
|
||||
|
||||
|
||||
class HTTPProxy(ConnectionPool):
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Union[URL, bytes, str],
|
||||
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
|
||||
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
max_connections: Optional[int] = 10,
|
||||
max_keepalive_connections: Optional[int] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: Optional[str] = None,
|
||||
uds: Optional[str] = None,
|
||||
network_backend: Optional[NetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
proxy_auth: Any proxy authentication as a two-tuple of
|
||||
(username, password). May be either bytes or ascii-only str.
|
||||
proxy_headers: Any HTTP headers to use for the proxy requests.
|
||||
For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
local_address=local_address,
|
||||
uds=uds,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
if proxy_auth is not None:
|
||||
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
|
||||
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
|
||||
authorization = build_auth_header(username, password)
|
||||
self._proxy_headers = [
|
||||
(b"Proxy-Authorization", authorization)
|
||||
] + self._proxy_headers
|
||||
|
||||
def create_connection(self, origin: Origin) -> ConnectionInterface:
|
||||
if origin.scheme == b"http":
|
||||
return ForwardHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
return TunnelHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class ForwardHTTPConnection(ConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
network_backend: Optional[NetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._connection = HTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._remote_origin = remote_origin
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
headers = merge_headers(self._proxy_headers, request.headers)
|
||||
url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=bytes(request.url),
|
||||
)
|
||||
proxy_request = Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=request.stream,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
return self._connection.handle_request(proxy_request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
def close(self) -> None:
|
||||
self._connection.close()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
|
||||
class TunnelHTTPConnection(ConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
|
||||
keepalive_expiry: Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: Optional[NetworkBackend] = None,
|
||||
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
|
||||
) -> None:
|
||||
self._connection: ConnectionInterface = HTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._connect_lock = Lock()
|
||||
self._connected = False
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
with self._connect_lock:
|
||||
if not self._connected:
|
||||
target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)
|
||||
|
||||
connect_url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=target,
|
||||
)
|
||||
connect_headers = merge_headers(
|
||||
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
|
||||
)
|
||||
connect_request = Request(
|
||||
method=b"CONNECT",
|
||||
url=connect_url,
|
||||
headers=connect_headers,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
connect_response = self._connection.handle_request(
|
||||
connect_request
|
||||
)
|
||||
|
||||
if connect_response.status < 200 or connect_response.status > 299:
|
||||
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
|
||||
reason_str = reason_bytes.decode("ascii", errors="ignore")
|
||||
msg = "%d %s" % (connect_response.status, reason_str)
|
||||
self._connection.close()
|
||||
raise ProxyError(msg)
|
||||
|
||||
stream = connect_response.extensions["network_stream"]
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import HTTP2Connection
|
||||
|
||||
self._connection = HTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = HTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
return self._connection.handle_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
def close(self) -> None:
|
||||
self._connection.close()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
|
@ -0,0 +1,135 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import Iterator, Optional, Union
|
||||
|
||||
from .._models import (
|
||||
URL,
|
||||
Extensions,
|
||||
HeaderTypes,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
include_request_headers,
|
||||
)
|
||||
|
||||
|
||||
class RequestInterface:
|
||||
def request(
|
||||
self,
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> Response:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = self.handle_request(request)
|
||||
try:
|
||||
response.read()
|
||||
finally:
|
||||
response.close()
|
||||
return response
|
||||
|
||||
@contextmanager
|
||||
def stream(
|
||||
self,
|
||||
method: Union[bytes, str],
|
||||
url: Union[URL, bytes, str],
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: Union[bytes, Iterator[bytes], None] = None,
|
||||
extensions: Optional[Extensions] = None,
|
||||
) -> Iterator[Response]:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = self.handle_request(request)
|
||||
try:
|
||||
yield response
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
|
||||
class ConnectionInterface(RequestInterface):
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def info(self) -> str:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently able to accept an
|
||||
outgoing request.
|
||||
|
||||
An HTTP/1.1 connection will only be available if it is currently idle.
|
||||
|
||||
An HTTP/2 connection will be available so long as the stream ID space is
|
||||
not yet exhausted, and the connection is not in an error state.
|
||||
|
||||
While the connection is being established we may not yet know if it is going
|
||||
to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
|
||||
treated as being available, but might ultimately raise `NewConnectionRequired`
|
||||
required exceptions if multiple requests are attempted over a connection
|
||||
that ends up being established as HTTP/1.1.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is in a state where it should be closed.
|
||||
|
||||
This either means that the connection is idle and it has passed the
|
||||
expiry time on its keep-alive, or that server has sent an EOF.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently idle.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection has been closed.
|
||||
|
||||
Used when a response is closed to determine if the connection may be
|
||||
returned to the connection pool or not.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
|
@ -0,0 +1,340 @@
|
|||
import logging
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from socksio import socks5
|
||||
|
||||
from .._backends.sync import SyncBackend
|
||||
from .._backends.base import NetworkBackend, NetworkStream
|
||||
from .._exceptions import ConnectionNotAvailable, ProxyError
|
||||
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import Lock
|
||||
from .._trace import Trace
|
||||
from .connection_pool import ConnectionPool
|
||||
from .http11 import HTTP11Connection
|
||||
from .interfaces import ConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.socks")
|
||||
|
||||
|
||||
AUTH_METHODS = {
|
||||
b"\x00": "NO AUTHENTICATION REQUIRED",
|
||||
b"\x01": "GSSAPI",
|
||||
b"\x02": "USERNAME/PASSWORD",
|
||||
b"\xff": "NO ACCEPTABLE METHODS",
|
||||
}
|
||||
|
||||
REPLY_CODES = {
|
||||
b"\x00": "Succeeded",
|
||||
b"\x01": "General SOCKS server failure",
|
||||
b"\x02": "Connection not allowed by ruleset",
|
||||
b"\x03": "Network unreachable",
|
||||
b"\x04": "Host unreachable",
|
||||
b"\x05": "Connection refused",
|
||||
b"\x06": "TTL expired",
|
||||
b"\x07": "Command not supported",
|
||||
b"\x08": "Address type not supported",
|
||||
}
|
||||
|
||||
|
||||
def _init_socks5_connection(
|
||||
stream: NetworkStream,
|
||||
*,
|
||||
host: bytes,
|
||||
port: int,
|
||||
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
|
||||
) -> None:
|
||||
conn = socks5.SOCKS5Connection()
|
||||
|
||||
# Auth method request
|
||||
auth_method = (
|
||||
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
|
||||
if auth is None
|
||||
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
|
||||
)
|
||||
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
stream.write(outgoing_bytes)
|
||||
|
||||
# Auth method response
|
||||
incoming_bytes = stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5AuthReply)
|
||||
if response.method != auth_method:
|
||||
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
|
||||
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
|
||||
raise ProxyError(
|
||||
f"Requested {requested} from proxy server, but got {responded}."
|
||||
)
|
||||
|
||||
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
|
||||
# Username/password request
|
||||
assert auth is not None
|
||||
username, password = auth
|
||||
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
stream.write(outgoing_bytes)
|
||||
|
||||
# Username/password response
|
||||
incoming_bytes = stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
|
||||
if not response.success:
|
||||
raise ProxyError("Invalid username/password")
|
||||
|
||||
# Connect request
|
||||
conn.send(
|
||||
socks5.SOCKS5CommandRequest.from_address(
|
||||
socks5.SOCKS5Command.CONNECT, (host, port)
|
||||
)
|
||||
)
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
stream.write(outgoing_bytes)
|
||||
|
||||
# Connect response
|
||||
incoming_bytes = stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socks5.SOCKS5Reply)
|
||||
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
|
||||
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
|
||||
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
|
||||
|
||||
|
||||
class SOCKSProxy(ConnectionPool):
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: typing.Union[URL, bytes, str],
|
||||
proxy_auth: typing.Optional[
|
||||
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
|
||||
] = None,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
max_connections: typing.Optional[int] = 10,
|
||||
max_keepalive_connections: typing.Optional[int] = None,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
network_backend: typing.Optional[NetworkBackend] = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
)
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
if proxy_auth is not None:
|
||||
username, password = proxy_auth
|
||||
username_bytes = enforce_bytes(username, name="proxy_auth")
|
||||
password_bytes = enforce_bytes(password, name="proxy_auth")
|
||||
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
|
||||
username_bytes,
|
||||
password_bytes,
|
||||
)
|
||||
else:
|
||||
self._proxy_auth = None
|
||||
|
||||
def create_connection(self, origin: Origin) -> ConnectionInterface:
|
||||
return Socks5Connection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
remote_origin=origin,
|
||||
proxy_auth=self._proxy_auth,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class Socks5Connection(ConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
|
||||
ssl_context: typing.Optional[ssl.SSLContext] = None,
|
||||
keepalive_expiry: typing.Optional[float] = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: typing.Optional[NetworkBackend] = None,
|
||||
) -> None:
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._proxy_auth = proxy_auth
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
|
||||
self._network_backend: NetworkBackend = (
|
||||
SyncBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connect_lock = Lock()
|
||||
self._connection: typing.Optional[ConnectionInterface] = None
|
||||
self._connect_failed = False
|
||||
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
with self._connect_lock:
|
||||
if self._connection is None:
|
||||
try:
|
||||
# Connect to the proxy
|
||||
kwargs = {
|
||||
"host": self._proxy_origin.host.decode("ascii"),
|
||||
"port": self._proxy_origin.port,
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Connect to the remote host using socks5
|
||||
kwargs = {
|
||||
"stream": stream,
|
||||
"host": self._remote_origin.host.decode("ascii"),
|
||||
"port": self._remote_origin.port,
|
||||
"auth": self._proxy_auth,
|
||||
}
|
||||
with Trace(
|
||||
"setup_socks5_connection", logger, request, kwargs
|
||||
) as trace:
|
||||
_init_socks5_connection(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
if self._remote_origin.scheme == b"https":
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = (
|
||||
["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
)
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (
|
||||
self._http2 and not self._http1
|
||||
): # pragma: nocover
|
||||
from .http2 import HTTP2Connection
|
||||
|
||||
self._connection = HTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = HTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
elif not self._connection.is_available(): # pragma: nocover
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
return self._connection.handle_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
def close(self) -> None:
|
||||
if self._connection is not None:
|
||||
self._connection.close()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._remote_origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
|
@ -0,0 +1,333 @@
|
|||
import threading
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type
|
||||
|
||||
import sniffio
|
||||
|
||||
from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
|
||||
|
||||
# Our async synchronization primatives use either 'anyio' or 'trio' depending
|
||||
# on if they're running under asyncio or trio.
|
||||
|
||||
try:
|
||||
import trio
|
||||
except ImportError: # pragma: nocover
|
||||
trio = None # type: ignore
|
||||
|
||||
try:
|
||||
import anyio
|
||||
except ImportError: # pragma: nocover
|
||||
anyio = None # type: ignore
|
||||
|
||||
try:
|
||||
import structio
|
||||
except ImportError: # pragma: nocover
|
||||
structio = None # type: ignore
|
||||
|
||||
|
||||
class AsyncLock:
|
||||
def __init__(self) -> None:
|
||||
self._backend = ""
|
||||
|
||||
def setup(self) -> None:
|
||||
"""
|
||||
Detect if we're running under 'asyncio' or 'trio' and create
|
||||
a lock with the correct implementation.
|
||||
"""
|
||||
self._backend = sniffio.current_async_library()
|
||||
if self._backend == "trio":
|
||||
if trio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under trio, requires the 'trio' package to be installed."
|
||||
)
|
||||
self._trio_lock = trio.Lock()
|
||||
elif self._backend == "structured-io":
|
||||
if structio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under structio requires the 'structio' package to be installed."
|
||||
)
|
||||
self._structio_lock = structio.Lock()
|
||||
else:
|
||||
if anyio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under asyncio requires the 'anyio' package to be installed."
|
||||
)
|
||||
self._anyio_lock = anyio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "AsyncLock":
|
||||
if not self._backend:
|
||||
self.setup()
|
||||
|
||||
if self._backend == "trio":
|
||||
await self._trio_lock.acquire()
|
||||
elif self._backend == "structured-io":
|
||||
await self._structio_lock.acquire()
|
||||
else:
|
||||
await self._anyio_lock.acquire()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
if self._backend == "trio":
|
||||
self._trio_lock.release()
|
||||
elif self._backend == "structured-io":
|
||||
await self._structio_lock.release()
|
||||
else:
|
||||
self._anyio_lock.release()
|
||||
|
||||
|
||||
class AsyncEvent:
|
||||
def __init__(self) -> None:
|
||||
self._backend = ""
|
||||
|
||||
def setup(self) -> None:
|
||||
"""
|
||||
Detect if we're running under 'asyncio' or 'trio' and create
|
||||
a lock with the correct implementation.
|
||||
"""
|
||||
self._backend = sniffio.current_async_library()
|
||||
if self._backend == "trio":
|
||||
if trio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under trio requires the 'trio' package to be installed."
|
||||
)
|
||||
self._trio_event = trio.Event()
|
||||
elif self._backend == "structured-io":
|
||||
if structio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under structio requires the 'structio' package to be installed."
|
||||
)
|
||||
self._structio_event = structio.Event()
|
||||
else:
|
||||
if anyio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under asyncio requires the 'anyio' package to be installed."
|
||||
)
|
||||
self._anyio_event = anyio.Event()
|
||||
|
||||
def set(self) -> None:
|
||||
if not self._backend:
|
||||
self.setup()
|
||||
|
||||
if self._backend == "trio":
|
||||
self._trio_event.set()
|
||||
elif self._backend == "structured-io":
|
||||
self._structio_event.set()
|
||||
else:
|
||||
self._anyio_event.set()
|
||||
|
||||
async def wait(self, timeout: Optional[float] = None) -> None:
|
||||
if not self._backend:
|
||||
self.setup()
|
||||
|
||||
if self._backend == "trio":
|
||||
if trio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under trio requires the 'trio' package to be installed."
|
||||
)
|
||||
|
||||
trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
with map_exceptions(trio_exc_map):
|
||||
with trio.fail_after(timeout_or_inf):
|
||||
await self._trio_event.wait()
|
||||
elif self._backend == "structured-io":
|
||||
if structio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under structio requires the 'structio' package to be installed."
|
||||
)
|
||||
|
||||
structio_exc_map: ExceptionMapping = {structio.TimedOut: PoolTimeout}
|
||||
timeout_or_inf = float("inf") if timeout is None else timeout
|
||||
with map_exceptions(trio_exc_map):
|
||||
with structio.with_timeout(timeout_or_inf):
|
||||
await self._structio_event.wait()
|
||||
else:
|
||||
if anyio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under asyncio requires the 'anyio' package to be installed."
|
||||
)
|
||||
|
||||
anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
|
||||
with map_exceptions(anyio_exc_map):
|
||||
with anyio.fail_after(timeout):
|
||||
await self._anyio_event.wait()
|
||||
|
||||
|
||||
class AsyncSemaphore:
|
||||
def __init__(self, bound: int) -> None:
|
||||
self._bound = bound
|
||||
self._backend = ""
|
||||
|
||||
def setup(self) -> None:
|
||||
"""
|
||||
Detect if we're running under 'asyncio' or 'trio' and create
|
||||
a semaphore with the correct implementation.
|
||||
"""
|
||||
self._backend = sniffio.current_async_library()
|
||||
if self._backend == "trio":
|
||||
if trio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under trio requires the 'trio' package to be installed."
|
||||
)
|
||||
|
||||
self._trio_semaphore = trio.Semaphore(
|
||||
initial_value=self._bound, max_value=self._bound
|
||||
)
|
||||
elif self._backend == "structured-io":
|
||||
if structio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under structio requires the 'structio' package to be installed."
|
||||
)
|
||||
self._structio_semaphore = structio.Semaphore(initial_size=self._bound, max_size=self._bound)
|
||||
else:
|
||||
if anyio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under asyncio requires the 'anyio' package to be installed."
|
||||
)
|
||||
|
||||
self._anyio_semaphore = anyio.Semaphore(
|
||||
initial_value=self._bound, max_value=self._bound
|
||||
)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if not self._backend:
|
||||
self.setup()
|
||||
|
||||
if self._backend == "trio":
|
||||
await self._trio_semaphore.acquire()
|
||||
elif self._backend == "structured-io":
|
||||
await self._structio_semaphore.acquire()
|
||||
else:
|
||||
await self._anyio_semaphore.acquire()
|
||||
|
||||
async def release(self) -> None:
|
||||
if self._backend == "trio":
|
||||
self._trio_semaphore.release()
|
||||
elif self._backend == "structured-io":
|
||||
await self._structio_semaphore.release()
|
||||
else:
|
||||
self._anyio_semaphore.release()
|
||||
|
||||
|
||||
class AsyncShieldCancellation:
|
||||
# For certain portions of our codebase where we're dealing with
|
||||
# closing connections during exception handling we want to shield
|
||||
# the operation from being cancelled.
|
||||
#
|
||||
# with AsyncShieldCancellation():
|
||||
# ... # clean-up operations, shielded from cancellation.
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Detect if we're running under 'asyncio' or 'trio' and create
|
||||
a shielded scope with the correct implementation.
|
||||
"""
|
||||
self._backend = sniffio.current_async_library()
|
||||
|
||||
if self._backend == "trio":
|
||||
if trio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under trio requires the 'trio' package to be installed."
|
||||
)
|
||||
|
||||
self._trio_shield = trio.CancelScope(shield=True)
|
||||
elif self._backend == "structured-io":
|
||||
if structio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under structio requires the 'structio' package to be installed."
|
||||
)
|
||||
self._structio_shield = structio.TaskScope(shielded=True)
|
||||
else:
|
||||
if anyio is None: # pragma: nocover
|
||||
raise RuntimeError(
|
||||
"Running under asyncio requires the 'anyio' package to be installed."
|
||||
)
|
||||
|
||||
self._anyio_shield = anyio.CancelScope(shield=True)
|
||||
|
||||
def __enter__(self) -> "AsyncShieldCancellation":
|
||||
if self._backend == "trio":
|
||||
self._trio_shield.__enter__()
|
||||
elif self._backend == "structured-io":
|
||||
self._structio_shield.__enter__()
|
||||
else:
|
||||
self._anyio_shield.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
if self._backend == "trio":
|
||||
self._trio_shield.__exit__(exc_type, exc_value, traceback)
|
||||
elif self._backend == "structured-io":
|
||||
self._structio_shield.__exit__(exc_type, exc_value, traceback)
|
||||
else:
|
||||
self._anyio_shield.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
# Our thread-based synchronization primitives...
|
||||
|
||||
|
||||
class Lock:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def __enter__(self) -> "Lock":
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
class Event:
|
||||
def __init__(self) -> None:
|
||||
self._event = threading.Event()
|
||||
|
||||
def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> None:
|
||||
if not self._event.wait(timeout=timeout):
|
||||
raise PoolTimeout() # pragma: nocover
|
||||
|
||||
|
||||
class Semaphore:
|
||||
def __init__(self, bound: int) -> None:
|
||||
self._semaphore = threading.Semaphore(value=bound)
|
||||
|
||||
def acquire(self) -> None:
|
||||
self._semaphore.acquire()
|
||||
|
||||
def release(self) -> None:
|
||||
self._semaphore.release()
|
||||
|
||||
|
||||
class ShieldCancellation:
|
||||
# Thread-synchronous codebases don't support cancellation semantics.
|
||||
# We have this class because we need to mirror the async and sync
|
||||
# cases within our package, but it's just a no-op.
|
||||
def __enter__(self) -> "ShieldCancellation":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
pass
|
|
@ -0,0 +1,105 @@
|
|||
import inspect
|
||||
import logging
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from ._models import Request
|
||||
|
||||
|
||||
class Trace:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
logger: logging.Logger,
|
||||
request: Optional[Request] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.logger = logger
|
||||
self.trace_extension = (
|
||||
None if request is None else request.extensions.get("trace")
|
||||
)
|
||||
self.debug = self.logger.isEnabledFor(logging.DEBUG)
|
||||
self.kwargs = kwargs or {}
|
||||
self.return_value: Any = None
|
||||
self.should_trace = self.debug or self.trace_extension is not None
|
||||
self.prefix = self.logger.name.split(".")[-1]
|
||||
|
||||
def trace(self, name: str, info: Dict[str, Any]) -> None:
|
||||
if self.trace_extension is not None:
|
||||
prefix_and_name = f"{self.prefix}.{name}"
|
||||
ret = self.trace_extension(prefix_and_name, info)
|
||||
if inspect.iscoroutine(ret): # pragma: no cover
|
||||
raise TypeError(
|
||||
"If you are using a synchronous interface, "
|
||||
"the callback of the `trace` extension should "
|
||||
"be a normal function instead of an asynchronous function."
|
||||
)
|
||||
|
||||
if self.debug:
|
||||
if not info or "return_value" in info and info["return_value"] is None:
|
||||
message = name
|
||||
else:
|
||||
args = " ".join([f"{key}={value!r}" for key, value in info.items()])
|
||||
message = f"{name} {args}"
|
||||
self.logger.debug(message)
|
||||
|
||||
def __enter__(self) -> "Trace":
|
||||
if self.should_trace:
|
||||
info = self.kwargs
|
||||
self.trace(f"{self.name}.started", info)
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
if self.should_trace:
|
||||
if exc_value is None:
|
||||
info = {"return_value": self.return_value}
|
||||
self.trace(f"{self.name}.complete", info)
|
||||
else:
|
||||
info = {"exception": exc_value}
|
||||
self.trace(f"{self.name}.failed", info)
|
||||
|
||||
async def atrace(self, name: str, info: Dict[str, Any]) -> None:
|
||||
if self.trace_extension is not None:
|
||||
prefix_and_name = f"{self.prefix}.{name}"
|
||||
coro = self.trace_extension(prefix_and_name, info)
|
||||
if not inspect.iscoroutine(coro): # pragma: no cover
|
||||
raise TypeError(
|
||||
"If you're using an asynchronous interface, "
|
||||
"the callback of the `trace` extension should "
|
||||
"be an asynchronous function rather than a normal function."
|
||||
)
|
||||
await coro
|
||||
|
||||
if self.debug:
|
||||
if not info or "return_value" in info and info["return_value"] is None:
|
||||
message = name
|
||||
else:
|
||||
args = " ".join([f"{key}={value!r}" for key, value in info.items()])
|
||||
message = f"{name} {args}"
|
||||
self.logger.debug(message)
|
||||
|
||||
async def __aenter__(self) -> "Trace":
|
||||
if self.should_trace:
|
||||
info = self.kwargs
|
||||
await self.atrace(f"{self.name}.started", info)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]] = None,
|
||||
exc_value: Optional[BaseException] = None,
|
||||
traceback: Optional[TracebackType] = None,
|
||||
) -> None:
|
||||
if self.should_trace:
|
||||
if exc_value is None:
|
||||
info = {"return_value": self.return_value}
|
||||
await self.atrace(f"{self.name}.complete", info)
|
||||
else:
|
||||
info = {"exception": exc_value}
|
||||
await self.atrace(f"{self.name}.failed", info)
|
|
@ -0,0 +1,36 @@
|
|||
import select
|
||||
import socket
|
||||
import sys
|
||||
import typing
|
||||
|
||||
|
||||
def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
|
||||
"""
|
||||
Return whether a socket, as identifed by its file descriptor, is readable.
|
||||
"A socket is readable" means that the read buffer isn't empty, i.e. that calling
|
||||
.recv() on it would immediately return some data.
|
||||
"""
|
||||
# NOTE: we want check for readability without actually attempting to read, because
|
||||
# we don't want to block forever if it's not readable.
|
||||
|
||||
# In the case that the socket no longer exists, or cannot return a file
|
||||
# descriptor, we treat it as being readable, as if it the next read operation
|
||||
# on it is ready to return the terminating `b""`.
|
||||
sock_fd = None if sock is None else sock.fileno()
|
||||
if sock_fd is None or sock_fd < 0: # pragma: nocover
|
||||
return True
|
||||
|
||||
# The implementation below was stolen from:
|
||||
# https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478
|
||||
# See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316
|
||||
|
||||
# Use select.select on Windows, and when poll is unavailable and select.poll
|
||||
# everywhere else. (E.g. When eventlet is in use. See #327)
|
||||
if (
|
||||
sys.platform == "win32" or getattr(select, "poll", None) is None
|
||||
): # pragma: nocover
|
||||
rready, _, _ = select.select([sock_fd], [], [], 0)
|
||||
return bool(rready)
|
||||
p = select.poll()
|
||||
p.register(sock_fd, select.POLLIN)
|
||||
return bool(p.poll(0))
|
|
@ -1,10 +1,12 @@
|
|||
import httpcore
|
||||
import structio
|
||||
|
||||
|
||||
async def main():
|
||||
# Note: this test only works because we have our own version of httpcore that
|
||||
# implements a structio-compatible backend. It's just an example anyway
|
||||
pool = httpcore.AsyncConnectionPool()
|
||||
print(await pool.request("GET", "http://example.com"))
|
||||
# TODO: SSL support
|
||||
|
||||
|
||||
structio.run(main)
|
||||
|
|
Loading…
Reference in New Issue