Added patched httpcore to make the test run properly

This commit is contained in:
Mattia Giambirtone 2024-03-10 19:53:58 +01:00
parent 9b6735b924
commit 3f48c74346
34 changed files with 6848 additions and 1 deletions

139
tests/httpcore/__init__.py Normal file
View File

@ -0,0 +1,139 @@
from ._api import request, stream
from ._async import (
AsyncConnectionInterface,
AsyncConnectionPool,
AsyncHTTP2Connection,
AsyncHTTP11Connection,
AsyncHTTPConnection,
AsyncHTTPProxy,
AsyncSOCKSProxy,
)
from ._backends.base import (
SOCKET_OPTION,
AsyncNetworkBackend,
AsyncNetworkStream,
NetworkBackend,
NetworkStream,
)
from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream
from ._backends.sync import SyncBackend
from ._exceptions import (
ConnectError,
ConnectionNotAvailable,
ConnectTimeout,
LocalProtocolError,
NetworkError,
PoolTimeout,
ProtocolError,
ProxyError,
ReadError,
ReadTimeout,
RemoteProtocolError,
TimeoutException,
UnsupportedProtocol,
WriteError,
WriteTimeout,
)
from ._models import URL, Origin, Request, Response
from ._ssl import default_ssl_context
from ._sync import (
ConnectionInterface,
ConnectionPool,
HTTP2Connection,
HTTP11Connection,
HTTPConnection,
HTTPProxy,
SOCKSProxy,
)
# The 'httpcore.AnyIOBackend' class is conditional on 'anyio' being installed.
try:
from ._backends.anyio import AnyIOBackend
except ImportError: # pragma: nocover
class AnyIOBackend: # type: ignore
def __init__(self, *args, **kwargs): # type: ignore
msg = (
"Attempted to use 'httpcore.AnyIOBackend' but 'anyio' is not installed."
)
raise RuntimeError(msg)
# The 'httpcore.TrioBackend' class is conditional on 'trio' being installed.
try:
from ._backends.trio import TrioBackend
except ImportError: # pragma: nocover
class TrioBackend: # type: ignore
def __init__(self, *args, **kwargs): # type: ignore
msg = "Attempted to use 'httpcore.TrioBackend' but 'trio' is not installed."
raise RuntimeError(msg)
__all__ = [
# top-level requests
"request",
"stream",
# models
"Origin",
"URL",
"Request",
"Response",
# async
"AsyncHTTPConnection",
"AsyncConnectionPool",
"AsyncHTTPProxy",
"AsyncHTTP11Connection",
"AsyncHTTP2Connection",
"AsyncConnectionInterface",
"AsyncSOCKSProxy",
# sync
"HTTPConnection",
"ConnectionPool",
"HTTPProxy",
"HTTP11Connection",
"HTTP2Connection",
"ConnectionInterface",
"SOCKSProxy",
# network backends, implementations
"SyncBackend",
"AnyIOBackend",
"TrioBackend",
# network backends, mock implementations
"AsyncMockBackend",
"AsyncMockStream",
"MockBackend",
"MockStream",
# network backends, interface
"AsyncNetworkStream",
"AsyncNetworkBackend",
"NetworkStream",
"NetworkBackend",
# util
"default_ssl_context",
"SOCKET_OPTION",
# exceptions
"ConnectionNotAvailable",
"ProxyError",
"ProtocolError",
"LocalProtocolError",
"RemoteProtocolError",
"UnsupportedProtocol",
"TimeoutException",
"PoolTimeout",
"ConnectTimeout",
"ReadTimeout",
"WriteTimeout",
"NetworkError",
"ConnectError",
"ReadError",
"WriteError",
]
__version__ = "0.17.3"
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
setattr(__locals[__name], "__module__", "httpcore") # noqa

92
tests/httpcore/_api.py Normal file
View File

@ -0,0 +1,92 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from ._models import URL, Extensions, HeaderTypes, Response
from ._sync.connection_pool import ConnectionPool
def request(
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Response:
"""
Sends an HTTP request, returning the response.
```
response = httpcore.request("GET", "https://www.example.com/")
```
Arguments:
method: The HTTP method for the request. Typically one of `"GET"`,
`"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`.
url: The URL of the HTTP request. Either as an instance of `httpcore.URL`,
or as str/bytes.
headers: The HTTP request headers. Either as a dictionary of str/bytes,
or as a list of two-tuples of str/bytes.
content: The content of the request body. Either as bytes,
or as a bytes iterator.
extensions: A dictionary of optional extra information included on the request.
Possible keys include `"timeout"`.
Returns:
An instance of `httpcore.Response`.
"""
with ConnectionPool() as pool:
return pool.request(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
)
@contextmanager
def stream(
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Iterator[Response]:
"""
Sends an HTTP request, returning the response within a content manager.
```
with httpcore.stream("GET", "https://www.example.com/") as response:
...
```
When using the `stream()` function, the body of the response will not be
automatically read. If you want to access the response body you should
either use `content = response.read()`, or `for chunk in response.iter_content()`.
Arguments:
method: The HTTP method for the request. Typically one of `"GET"`,
`"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`.
url: The URL of the HTTP request. Either as an instance of `httpcore.URL`,
or as str/bytes.
headers: The HTTP request headers. Either as a dictionary of str/bytes,
or as a list of two-tuples of str/bytes.
content: The content of the request body. Either as bytes,
or as a bytes iterator.
extensions: A dictionary of optional extra information included on the request.
Possible keys include `"timeout"`.
Returns:
An instance of `httpcore.Response`.
"""
with ConnectionPool() as pool:
with pool.stream(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
) as response:
yield response

View File

@ -0,0 +1,39 @@
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool
from .http11 import AsyncHTTP11Connection
from .http_proxy import AsyncHTTPProxy
from .interfaces import AsyncConnectionInterface
try:
from .http2 import AsyncHTTP2Connection
except ImportError: # pragma: nocover
class AsyncHTTP2Connection: # type: ignore
def __init__(self, *args, **kwargs) -> None: # type: ignore
raise RuntimeError(
"Attempted to use http2 support, but the `h2` package is not "
"installed. Use 'pip install httpcore[http2]'."
)
try:
from .socks_proxy import AsyncSOCKSProxy
except ImportError: # pragma: nocover
class AsyncSOCKSProxy: # type: ignore
def __init__(self, *args, **kwargs) -> None: # type: ignore
raise RuntimeError(
"Attempted to use SOCKS support, but the `socksio` package is not "
"installed. Use 'pip install httpcore[socks]'."
)
__all__ = [
"AsyncHTTPConnection",
"AsyncConnectionPool",
"AsyncHTTPProxy",
"AsyncHTTP11Connection",
"AsyncHTTP2Connection",
"AsyncConnectionInterface",
"AsyncSOCKSProxy",
]

View File

@ -0,0 +1,215 @@
import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
logger = logging.getLogger("httpcore.connection")
def exponential_backoff(factor: float) -> Iterator[float]:
yield 0
for n in itertools.count(2):
yield factor * (2 ** (n - 2))
class AsyncHTTPConnection(AsyncConnectionInterface):
def __init__(
self,
origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._retries = retries
self._local_address = local_address
self._uds = uds
self._network_backend: AsyncNetworkBackend = (
AutoBackend() if network_backend is None else network_backend
)
self._connection: Optional[AsyncConnectionInterface] = None
self._connect_failed: bool = False
self._request_lock = AsyncLock()
self._socket_options = socket_options
async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
)
async with self._request_lock:
if self._connection is None:
try:
stream = await self._connect(request)
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection
self._connection = AsyncHTTP2Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available():
raise ConnectionNotAvailable()
return await self._connection.handle_async_request(request)
async def _connect(self, request: Request) -> AsyncNetworkStream:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
while True:
try:
if self._uds is None:
kwargs = {
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"socket_options": self._socket_options,
}
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
"path": self._uds,
"timeout": timeout,
"socket_options": self._socket_options,
}
async with Trace(
"connect_unix_socket", logger, request, kwargs
) as trace:
stream = await self._network_backend.connect_unix_socket(
**kwargs
)
trace.return_value = stream
if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
return stream
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
raise
retries_left -= 1
delay = next(delays)
async with Trace("retry", logger, request, kwargs) as trace:
await self._network_backend.sleep(delay)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
async def aclose(self) -> None:
if self._connection is not None:
async with Trace("close", logger, None, {}):
await self._connection.aclose()
def is_available(self) -> bool:
if self._connection is None:
# If HTTP/2 support is enabled, and the resulting connection could
# end up as HTTP/2 then we should indicate the connection as being
# available to service multiple requests.
return (
self._http2
and (self._origin.scheme == b"https" or not self._http1)
and not self._connect_failed
)
return self._connection.is_available()
def has_expired(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.has_expired()
def is_idle(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.is_idle()
def is_closed(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.is_closed()
def info(self) -> str:
if self._connection is None:
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
return self._connection.info()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTPConnection":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
await self.aclose()

View File

@ -0,0 +1,356 @@
import ssl
import sys
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
class RequestStatus:
def __init__(self, request: Request):
self.request = request
self.connection: Optional[AsyncConnectionInterface] = None
self._connection_acquired = AsyncEvent()
def set_connection(self, connection: AsyncConnectionInterface) -> None:
assert self.connection is None
self.connection = connection
self._connection_acquired.set()
def unset_connection(self) -> None:
assert self.connection is not None
self.connection = None
self._connection_acquired = AsyncEvent()
async def wait_for_connection(
self, timeout: Optional[float] = None
) -> AsyncConnectionInterface:
if self.connection is None:
await self._connection_acquired.wait(timeout=timeout)
assert self.connection is not None
return self.connection
class AsyncConnectionPool(AsyncRequestInterface):
"""
A connection pool for making HTTP requests.
"""
def __init__(
self,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish a
connection.
local_address: Local address to connect from. Can also be used to connect
using a particular address family. Using `local_address="0.0.0.0"`
will connect using an `AF_INET` address (IPv4), while using
`local_address="::"` will connect using an `AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
socket_options: Socket options that have to be included
in the TCP socket when the connection was established.
"""
self._ssl_context = ssl_context
self._max_connections = (
sys.maxsize if max_connections is None else max_connections
)
self._max_keepalive_connections = (
sys.maxsize
if max_keepalive_connections is None
else max_keepalive_connections
)
self._max_keepalive_connections = min(
self._max_connections, self._max_keepalive_connections
)
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._retries = retries
self._local_address = local_address
self._uds = uds
self._pool: List[AsyncConnectionInterface] = []
self._requests: List[RequestStatus] = []
self._pool_lock = AsyncLock()
self._network_backend = (
AutoBackend() if network_backend is None else network_backend
)
self._socket_options = socket_options
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
return AsyncHTTPConnection(
origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
retries=self._retries,
local_address=self._local_address,
uds=self._uds,
network_backend=self._network_backend,
socket_options=self._socket_options,
)
@property
def connections(self) -> List[AsyncConnectionInterface]:
"""
Return a list of the connections currently in the pool.
For example:
```python
>>> pool.connections
[
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> ,
<AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
]
```
"""
return list(self._pool)
async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
"""
Attempt to provide a connection that can handle the given origin.
"""
origin = status.request.url.origin
# If there are queued requests in front of us, then don't acquire a
# connection. We handle requests strictly in order.
waiting = [s for s in self._requests if s.connection is None]
if waiting and waiting[0] is not status:
return False
# Reuse an existing connection if one is currently available.
for idx, connection in enumerate(self._pool):
if connection.can_handle_request(origin) and connection.is_available():
self._pool.pop(idx)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
# If the pool is currently full, attempt to close one idle connection.
if len(self._pool) >= self._max_connections:
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle():
await connection.aclose()
self._pool.pop(idx)
break
# If the pool is still full, then we cannot acquire a connection.
if len(self._pool) >= self._max_connections:
return False
# Otherwise create a new connection.
connection = self.create_connection(origin)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
async def _close_expired_connections(self) -> None:
"""
Clean up the connection pool by closing off any connections that have expired.
"""
# Close any connections that have expired their keep-alive time.
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.has_expired():
await connection.aclose()
self._pool.pop(idx)
# If the pool size exceeds the maximum number of allowed keep-alive connections,
# then close off idle connections as required.
pool_size = len(self._pool)
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle() and pool_size > self._max_keepalive_connections:
await connection.aclose()
self._pool.pop(idx)
pool_size -= 1
async def handle_async_request(self, request: Request) -> Response:
"""
Send an HTTP request, and return an HTTP response.
This is the core implementation that is called into by `.request()` or `.stream()`.
"""
scheme = request.url.scheme.decode()
if scheme == "":
raise UnsupportedProtocol(
"Request URL is missing an 'http://' or 'https://' protocol."
)
if scheme not in ("http", "https", "ws", "wss"):
raise UnsupportedProtocol(
f"Request URL has an unsupported protocol '{scheme}://'."
)
status = RequestStatus(request)
async with self._pool_lock:
self._requests.append(status)
await self._close_expired_connections()
await self._attempt_to_acquire_connection(status)
while True:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
try:
connection = await status.wait_for_connection(timeout=timeout)
except BaseException as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
async with self._pool_lock:
# Ensure only remove when task exists.
if status in self._requests:
self._requests.remove(status)
raise exc
try:
response = await connection.handle_async_request(request)
except ConnectionNotAvailable:
# The ConnectionNotAvailable exception is a special case, that
# indicates we need to retry the request on a new connection.
#
# The most common case where this can occur is when multiple
# requests are queued waiting for a single connection, which
# might end up as an HTTP/2 connection, but which actually ends
# up as HTTP/1.1.
async with self._pool_lock:
# Maintain our position in the request queue, but reset the
# status so that the request becomes queued again.
status.unset_connection()
await self._attempt_to_acquire_connection(status)
except BaseException as exc:
with AsyncShieldCancellation():
await self.response_closed(status)
raise exc
else:
break
# When we return the response, we wrap the stream in a special class
# that handles notifying the connection pool once the response
# has been released.
assert isinstance(response.stream, AsyncIterable)
return Response(
status=response.status,
headers=response.headers,
content=ConnectionPoolByteStream(response.stream, self, status),
extensions=response.extensions,
)
async def response_closed(self, status: RequestStatus) -> None:
"""
This method acts as a callback once the request/response cycle is complete.
It is called into from the `ConnectionPoolByteStream.aclose()` method.
"""
assert status.connection is not None
connection = status.connection
async with self._pool_lock:
# Update the state of the connection pool.
if status in self._requests:
self._requests.remove(status)
if connection.is_closed() and connection in self._pool:
self._pool.remove(connection)
# Since we've had a response closed, it's possible we'll now be able
# to service one or more requests that are currently pending.
for status in self._requests:
if status.connection is None:
acquired = await self._attempt_to_acquire_connection(status)
# If we could not acquire a connection for a queued request
# then we don't need to check anymore requests that are
# queued later behind it.
if not acquired:
break
# Housekeeping.
await self._close_expired_connections()
async def aclose(self) -> None:
"""
Close any connections in the pool.
"""
async with self._pool_lock:
for connection in self._pool:
await connection.aclose()
self._pool = []
self._requests = []
async def __aenter__(self) -> "AsyncConnectionPool":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
await self.aclose()
class ConnectionPoolByteStream:
"""
A wrapper around the response byte stream, that additionally handles
notifying the connection pool when the response has been closed.
"""
def __init__(
self,
stream: AsyncIterable[bytes],
pool: AsyncConnectionPool,
status: RequestStatus,
) -> None:
self._stream = stream
self._pool = pool
self._status = status
async def __aiter__(self) -> AsyncIterator[bytes]:
async for part in self._stream:
yield part
async def aclose(self) -> None:
try:
if hasattr(self._stream, "aclose"):
await self._stream.aclose()
finally:
with AsyncShieldCancellation():
await self._pool.response_closed(self._status)

View File

@ -0,0 +1,331 @@
import enum
import logging
import time
from types import TracebackType
from typing import (
AsyncIterable,
AsyncIterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import h11
from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface
logger = logging.getLogger("httpcore.http11")
# A subset of `h11.Event` types supported by `_send_event`
H11SendEvent = Union[
h11.Request,
h11.Data,
h11.EndOfMessage,
]
class HTTPConnectionState(enum.IntEnum):
NEW = 0
ACTIVE = 1
IDLE = 2
CLOSED = 3
class AsyncHTTP11Connection(AsyncConnectionInterface):
READ_NUM_BYTES = 64 * 1024
MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
def __init__(
self,
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: Optional[float] = None,
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: Optional[float] = keepalive_expiry
self._expire_at: Optional[float] = None
self._state = HTTPConnectionState.NEW
self._state_lock = AsyncLock()
self._request_count = 0
self._h11_state = h11.Connection(
our_role=h11.CLIENT,
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
)
async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection "
f"to {self._origin}"
)
async with self._state_lock:
if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
self._request_count += 1
self._state = HTTPConnectionState.ACTIVE
self._expire_at = None
else:
raise ConnectionNotAvailable()
try:
kwargs = {"request": request}
async with Trace("send_request_headers", logger, request, kwargs) as trace:
await self._send_request_headers(**kwargs)
async with Trace("send_request_body", logger, request, kwargs) as trace:
await self._send_request_body(**kwargs)
async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
(
http_version,
status,
reason_phrase,
headers,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
status,
reason_phrase,
headers,
)
return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
},
)
except BaseException as exc:
with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
raise exc
# Sending the request...
async def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
await self._send_event(event, timeout=timeout)
async def _send_request_body(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
assert isinstance(request.stream, AsyncIterable)
async for chunk in request.stream:
event = h11.Data(data=chunk)
await self._send_event(event, timeout=timeout)
await self._send_event(h11.EndOfMessage(), timeout=timeout)
async def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
) -> None:
bytes_to_send = self._h11_state.send(event)
if bytes_to_send is not None:
await self._network_stream.write(bytes_to_send, timeout=timeout)
# Receiving the response...
async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
while True:
event = await self._receive_event(timeout=timeout)
if isinstance(event, h11.Response):
break
if (
isinstance(event, h11.InformationalResponse)
and event.status_code == 101
):
break
http_version = b"HTTP/" + event.http_version
# h11 version 0.11+ supports a `raw_items` interface to get the
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()
return http_version, event.status_code, event.reason, headers
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
while True:
event = await self._receive_event(timeout=timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
break
async def _receive_event(
self, timeout: Optional[float] = None
) -> Union[h11.Event, Type[h11.PAUSED]]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
event = self._h11_state.next_event()
if event is h11.NEED_DATA:
data = await self._network_stream.read(
self.READ_NUM_BYTES, timeout=timeout
)
# If we feed this case through h11 we'll raise an exception like:
#
# httpcore.RemoteProtocolError: can't handle event type
# ConnectionClosed when role=SERVER and state=SEND_RESPONSE
#
# Which is accurate, but not very informative from an end-user
# perspective. Instead we handle this case distinctly and treat
# it as a ConnectError.
if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
msg = "Server disconnected without sending a response."
raise RemoteProtocolError(msg)
self._h11_state.receive_data(data)
else:
# mypy fails to narrow the type in the above if statement above
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
async def _response_closed(self) -> None:
async with self._state_lock:
if (
self._h11_state.our_state is h11.DONE
and self._h11_state.their_state is h11.DONE
):
self._state = HTTPConnectionState.IDLE
self._h11_state.start_next_cycle()
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry
else:
await self.aclose()
# Once the connection is no longer required...
async def aclose(self) -> None:
# Note that this method unilaterally closes the connection, and does
# not have any kind of locking in place around it.
self._state = HTTPConnectionState.CLOSED
await self._network_stream.aclose()
# The AsyncConnectionInterface methods provide information about the state of
# the connection, allowing for a connection pooling implementation to
# determine when to reuse and when to close the connection...
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
def is_available(self) -> bool:
# Note that HTTP/1.1 connections in the "NEW" state are not treated as
# being "available". The control flow which created the connection will
# be able to send an outgoing request, but the connection will not be
# acquired from the connection pool for any other request.
return self._state == HTTPConnectionState.IDLE
def has_expired(self) -> bool:
now = time.monotonic()
keepalive_expired = self._expire_at is not None and now > self._expire_at
# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
return keepalive_expired or server_disconnected
def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
def is_closed(self) -> bool:
return self._state == HTTPConnectionState.CLOSED
def info(self) -> str:
origin = str(self._origin)
return (
f"{origin!r}, HTTP/1.1, {self._state.name}, "
f"Request Count: {self._request_count}"
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
origin = str(self._origin)
return (
f"<{class_name} [{origin!r}, {self._state.name}, "
f"Request Count: {self._request_count}]>"
)
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTP11Connection":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
await self.aclose()
class HTTP11ConnectionByteStream:
def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False
async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
raise exc
async def aclose(self) -> None:
if not self._closed:
self._closed = True
async with Trace("response_closed", logger, self._request):
await self._connection._response_closed()

View File

@ -0,0 +1,589 @@
import enum
import logging
import time
import types
import typing
import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings
from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface
logger = logging.getLogger("httpcore.http2")
def has_body_headers(request: Request) -> bool:
return any(
k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
for k, v in request.headers
)
class HTTPConnectionState(enum.IntEnum):
ACTIVE = 1
IDLE = 2
CLOSED = 3
class AsyncHTTP2Connection(AsyncConnectionInterface):
READ_NUM_BYTES = 64 * 1024
CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
def __init__(
self,
origin: Origin,
stream: AsyncNetworkStream,
keepalive_expiry: typing.Optional[float] = None,
):
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
self._state = HTTPConnectionState.IDLE
self._expire_at: typing.Optional[float] = None
self._request_count = 0
self._init_lock = AsyncLock()
self._state_lock = AsyncLock()
self._read_lock = AsyncLock()
self._write_lock = AsyncLock()
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}
# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
# This cannot occur in normal operation, since the connection pool
# will only send requests on connections that handle them.
# It's in place simply for resilience as a guard against incorrect
# usage, for anyone working directly with httpcore connections.
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection "
f"to {self._origin}"
)
async with self._state_lock:
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
self._request_count += 1
self._expire_at = None
self._state = HTTPConnectionState.ACTIVE
else:
raise ConnectionNotAvailable()
async with self._init_lock:
if not self._sent_connection_init:
try:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
except BaseException as exc:
with AsyncShieldCancellation():
await self.aclose()
raise exc
self._sent_connection_init = True
# Initially start with just 1 until the remote server provides
# its max_concurrent_streams value
self._max_streams = 1
local_settings_max_streams = (
self._h2_state.local_settings.max_concurrent_streams
)
self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)
for _ in range(local_settings_max_streams - self._max_streams):
await self._max_streams_semaphore.acquire()
await self._max_streams_semaphore.acquire()
try:
stream_id = self._h2_state.get_next_available_stream_id()
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()
try:
kwargs = {"request": request, "stream_id": stream_id}
async with Trace("send_request_headers", logger, request, kwargs):
await self._send_request_headers(request=request, stream_id=stream_id)
async with Trace("send_request_body", logger, request, kwargs):
await self._send_request_body(request=request, stream_id=stream_id)
async with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
status, headers = await self._receive_response(
request=request, stream_id=stream_id
)
trace.return_value = (status, headers)
return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
)
except BaseException as exc: # noqa: PIE786
with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)
if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
raise exc
async def _send_connection_init(self, request: Request) -> None:
"""
The HTTP/2 connection requires some initial setup before we can start
using individual request/response streams on it.
"""
# Need to set these manually here instead of manipulating via
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
# frames in addition to sending the undesired defaults.
self._h2_state.local_settings = h2.settings.Settings(
client=True,
initial_values={
# Disable PUSH_PROMISE frames from the server since we don't do anything
# with them for now. Maybe when we support caching?
h2.settings.SettingCodes.ENABLE_PUSH: 0,
# These two are taken from h2 for safe defaults
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
},
)
# Some websites (*cough* Yahoo *cough*) balk at this setting being
# present in the initial handshake since it's not defined in the original
# RFC despite the RFC mandating ignoring settings you don't know about.
del self._h2_state.local_settings[
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
]
self._h2_state.initiate_connection()
self._h2_state.increment_flow_control_window(2**24)
await self._write_outgoing_data(request)
# Sending the request...
async def _send_request_headers(self, request: Request, stream_id: int) -> None:
"""
Send the request headers to a given stream ID.
"""
end_stream = not has_body_headers(request)
# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
# In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
# HTTP/1.1 style headers, and map them appropriately if we end up on
# an HTTP/2 connection.
authority = [v for k, v in request.headers if k.lower() == b"host"][0]
headers = [
(b":method", request.method),
(b":authority", authority),
(b":scheme", request.url.scheme),
(b":path", request.url.target),
] + [
(k.lower(), v)
for k, v in request.headers
if k.lower()
not in (
b"host",
b"transfer-encoding",
)
]
self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
await self._write_outgoing_data(request)
async def _send_request_body(self, request: Request, stream_id: int) -> None:
"""
Iterate over the request body sending it to a given stream ID.
"""
if not has_body_headers(request):
return
assert isinstance(request.stream, typing.AsyncIterable)
async for data in request.stream:
await self._send_stream_data(request, stream_id, data)
await self._send_end_stream(request, stream_id)
async def _send_stream_data(
self, request: Request, stream_id: int, data: bytes
) -> None:
"""
Send a single chunk of data in one or more data frames.
"""
while data:
max_flow = await self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
await self._write_outgoing_data(request)
async def _send_end_stream(self, request: Request, stream_id: int) -> None:
"""
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
"""
self._h2_state.end_stream(stream_id)
await self._write_outgoing_data(request)
# Receiving the response...
async def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
while True:
event = await self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.ResponseReceived):
break
status_code = 200
headers = []
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode("ascii", errors="ignore"))
elif not k.startswith(b":"):
headers.append((k, v))
return (status_code, headers)
async def _receive_response_body(
self, request: Request, stream_id: int
) -> typing.AsyncIterator[bytes]:
"""
Iterator that returns the bytes of the response body for a given stream ID.
"""
while True:
event = await self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
amount = event.flow_controlled_length
self._h2_state.acknowledge_received_data(amount, stream_id)
await self._write_outgoing_data(request)
yield event.data
elif isinstance(event, h2.events.StreamEnded):
break
async def _receive_stream_event(
self, request: Request, stream_id: int
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.
Will read more data from the network if required.
"""
while not self._events.get(stream_id):
await self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event
async def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
) -> None:
"""
Read some data from the network until we see one or more events
for a given stream ID.
"""
async with self._read_lock:
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if stream_id and last_stream_id and stream_id > last_stream_id:
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)
# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
# check *within* the atomic read lock. Though it also need to be optional,
# because when we call it from `_wait_for_outgoing_flow` we *do* want to
# block until we've available flow control, event when we have events
# pending for the stream ID we're attempting to send on.
if stream_id is None or not self._events.get(stream_id):
events = await self._read_incoming_data(request)
for event in events:
if isinstance(event, h2.events.RemoteSettingsChanged):
async with Trace(
"receive_remote_settings", logger, request
) as trace:
await self._receive_remote_settings_change(event)
trace.return_value = event
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
if event.stream_id in self._events:
self._events[event.stream_id].append(event)
elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event
await self._write_outgoing_data(request)
async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
max_concurrent_streams = event.changed_settings.get(
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
)
if max_concurrent_streams:
new_max_streams = min(
max_concurrent_streams.new_value,
self._h2_state.local_settings.max_concurrent_streams,
)
if new_max_streams and new_max_streams != self._max_streams:
while new_max_streams > self._max_streams:
await self._max_streams_semaphore.release()
self._max_streams += 1
while new_max_streams < self._max_streams:
await self._max_streams_semaphore.acquire()
self._max_streams -= 1
async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
del self._events[stream_id]
async with self._state_lock:
if self._connection_terminated and not self._events:
await self.aclose()
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry
if self._used_all_stream_ids: # pragma: nocover
await self.aclose()
async def aclose(self) -> None:
# Note that this method unilaterally closes the connection, and does
# not have any kind of locking in place around it.
self._h2_state.close_connection()
self._state = HTTPConnectionState.CLOSED
await self._network_stream.aclose()
# Wrappers around network read/write operations...
async def _read_incoming_data(
self, request: Request
) -> typing.List[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
if self._read_exception is not None:
raise self._read_exception # pragma: nocover
try:
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
if data == b"":
raise RemoteProtocolError("Server disconnected")
except Exception as exc:
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future reads.
# (For example, this means that a single read timeout or disconnect will
# immediately close all pending streams. Without requiring multiple
# sequential timeouts.)
# 2. Mark the connection as errored, so that we don't accept any other
# incoming requests.
self._read_exception = exc
self._connection_error = True
raise exc
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
return events
async def _write_outgoing_data(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
async with self._write_lock:
data_to_send = self._h2_state.data_to_send()
if self._write_exception is not None:
raise self._write_exception # pragma: nocover
try:
await self._network_stream.write(data_to_send, timeout)
except Exception as exc: # pragma: nocover
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future write.
# (For example, this means that a single write timeout or disconnect will
# immediately close all pending streams. Without requiring multiple
# sequential timeouts.)
# 2. Mark the connection as errored, so that we don't accept any other
# incoming requests.
self._write_exception = exc
self._connection_error = True
raise exc
# Flow control...
async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
"""
Returns the maximum allowable outgoing flow for a given stream.
If the allowable flow is zero, then waits on the network until
WindowUpdated frames have increased the flow rate.
https://tools.ietf.org/html/rfc7540#section-6.9
"""
local_flow: int = self._h2_state.local_flow_control_window(stream_id)
max_frame_size: int = self._h2_state.max_outbound_frame_size
flow = min(local_flow, max_frame_size)
while flow == 0:
await self._receive_events(request)
local_flow = self._h2_state.local_flow_control_window(stream_id)
max_frame_size = self._h2_state.max_outbound_frame_size
flow = min(local_flow, max_frame_size)
return flow
# Interface for connection pooling...
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
def is_available(self) -> bool:
return (
self._state != HTTPConnectionState.CLOSED
and not self._connection_error
and not self._used_all_stream_ids
and not (
self._h2_state.state_machine.state
== h2.connection.ConnectionState.CLOSED
)
)
def has_expired(self) -> bool:
now = time.monotonic()
return self._expire_at is not None and now > self._expire_at
def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
def is_closed(self) -> bool:
return self._state == HTTPConnectionState.CLOSED
def info(self) -> str:
origin = str(self._origin)
return (
f"{origin!r}, HTTP/2, {self._state.name}, "
f"Request Count: {self._request_count}"
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
origin = str(self._origin)
return (
f"<{class_name} [{origin!r}, {self._state.name}, "
f"Request Count: {self._request_count}]>"
)
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
async def __aenter__(self) -> "AsyncHTTP2Connection":
return self
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[types.TracebackType] = None,
) -> None:
await self.aclose()
class HTTP2ConnectionByteStream:
def __init__(
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
) -> None:
self._connection = connection
self._request = request
self._stream_id = stream_id
self._closed = False
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(
request=self._request, stream_id=self._stream_id
):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with AsyncShieldCancellation():
await self.aclose()
raise exc
async def aclose(self) -> None:
if not self._closed:
self._closed = True
kwargs = {"stream_id": self._stream_id}
async with Trace("response_closed", logger, self._request, kwargs):
await self._connection._response_closed(stream_id=self._stream_id)

View File

@ -0,0 +1,350 @@
import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ProxyError
from .._models import (
URL,
Origin,
Request,
Response,
enforce_bytes,
enforce_headers,
enforce_url,
)
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from .connection import AsyncHTTPConnection
from .connection_pool import AsyncConnectionPool
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
logger = logging.getLogger("httpcore.proxy")
def merge_headers(
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
"""
Append default_headers and override_headers, de-duplicating if a key exists
in both cases.
"""
default_headers = [] if default_headers is None else list(default_headers)
override_headers = [] if override_headers is None else list(override_headers)
has_override = set(key.lower() for key, value in override_headers)
default_headers = [
(key, value)
for key, value in default_headers
if key.lower() not in has_override
]
return default_headers + override_headers
def build_auth_header(username: bytes, password: bytes) -> bytes:
userpass = username + b":" + password
return b"Basic " + b64encode(userpass)
class AsyncHTTPProxy(AsyncConnectionPool):
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: Union[URL, bytes, str],
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
proxy_url: The URL to use when connecting to the proxy server.
For example `"http://127.0.0.1:8080/"`.
proxy_auth: Any proxy authentication as a two-tuple of
(username, password). May be either bytes or ascii-only str.
proxy_headers: Any HTTP headers to use for the proxy requests.
For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address
(IPv4), while using `local_address="::"` will connect using an
`AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
"""
super().__init__(
ssl_context=ssl_context,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
uds=uds,
socket_options=socket_options,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
if proxy_auth is not None:
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
authorization = build_auth_header(username, password)
self._proxy_headers = [
(b"Proxy-Authorization", authorization)
] + self._proxy_headers
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
if origin.scheme == b"http":
return AsyncForwardHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
)
return AsyncTunnelHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
class AsyncForwardHTTPConnection(AsyncConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
keepalive_expiry: Optional[float] = None,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._connection = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
)
self._proxy_origin = proxy_origin
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._remote_origin = remote_origin
async def handle_async_request(self, request: Request) -> Response:
headers = merge_headers(self._proxy_headers, request.headers)
url = URL(
scheme=self._proxy_origin.scheme,
host=self._proxy_origin.host,
port=self._proxy_origin.port,
target=bytes(request.url),
)
proxy_request = Request(
method=request.method,
url=url,
headers=headers,
content=request.stream,
extensions=request.extensions,
)
return await self._connection.handle_async_request(proxy_request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
async def aclose(self) -> None:
await self._connection.aclose()
def info(self) -> str:
return self._connection.info()
def is_available(self) -> bool:
return self._connection.is_available()
def has_expired(self) -> bool:
return self._connection.has_expired()
def is_idle(self) -> bool:
return self._connection.is_idle()
def is_closed(self) -> bool:
return self._connection.is_closed()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"
class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
network_backend: Optional[AsyncNetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
)
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = AsyncLock()
self._connected = False
async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
async with self._connect_lock:
if not self._connected:
target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)
connect_url = URL(
scheme=self._proxy_origin.scheme,
host=self._proxy_origin.host,
port=self._proxy_origin.port,
target=target,
)
connect_headers = merge_headers(
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
)
connect_request = Request(
method=b"CONNECT",
url=connect_url,
headers=connect_headers,
extensions=request.extensions,
)
connect_response = await self._connection.handle_async_request(
connect_request
)
if connect_response.status < 200 or connect_response.status > 299:
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
reason_str = reason_bytes.decode("ascii", errors="ignore")
msg = "%d %s" % (connect_response.status, reason_str)
await self._connection.aclose()
raise ProxyError(msg)
stream = connect_response.extensions["network_stream"]
# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import AsyncHTTP2Connection
self._connection = AsyncHTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
self._connected = True
return await self._connection.handle_async_request(request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
async def aclose(self) -> None:
await self._connection.aclose()
def info(self) -> str:
return self._connection.info()
def is_available(self) -> bool:
return self._connection.is_available()
def has_expired(self) -> bool:
return self._connection.has_expired()
def is_idle(self) -> bool:
return self._connection.is_idle()
def is_closed(self) -> bool:
return self._connection.is_closed()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"

View File

@ -0,0 +1,135 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional, Union
from .._models import (
URL,
Extensions,
HeaderTypes,
Origin,
Request,
Response,
enforce_bytes,
enforce_headers,
enforce_url,
include_request_headers,
)
class AsyncRequestInterface:
async def request(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, AsyncIterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Response:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")
headers = enforce_headers(headers, name="headers")
# Include Host header, and optionally Content-Length or Transfer-Encoding.
headers = include_request_headers(headers, url=url, content=content)
request = Request(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
)
response = await self.handle_async_request(request)
try:
await response.aread()
finally:
await response.aclose()
return response
@asynccontextmanager
async def stream(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, AsyncIterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> AsyncIterator[Response]:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")
headers = enforce_headers(headers, name="headers")
# Include Host header, and optionally Content-Length or Transfer-Encoding.
headers = include_request_headers(headers, url=url, content=content)
request = Request(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
)
response = await self.handle_async_request(request)
try:
yield response
finally:
await response.aclose()
async def handle_async_request(self, request: Request) -> Response:
raise NotImplementedError() # pragma: nocover
class AsyncConnectionInterface(AsyncRequestInterface):
async def aclose(self) -> None:
raise NotImplementedError() # pragma: nocover
def info(self) -> str:
raise NotImplementedError() # pragma: nocover
def can_handle_request(self, origin: Origin) -> bool:
raise NotImplementedError() # pragma: nocover
def is_available(self) -> bool:
"""
Return `True` if the connection is currently able to accept an
outgoing request.
An HTTP/1.1 connection will only be available if it is currently idle.
An HTTP/2 connection will be available so long as the stream ID space is
not yet exhausted, and the connection is not in an error state.
While the connection is being established we may not yet know if it is going
to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
treated as being available, but might ultimately raise `NewConnectionRequired`
required exceptions if multiple requests are attempted over a connection
that ends up being established as HTTP/1.1.
"""
raise NotImplementedError() # pragma: nocover
def has_expired(self) -> bool:
"""
Return `True` if the connection is in a state where it should be closed.
This either means that the connection is idle and it has passed the
expiry time on its keep-alive, or that server has sent an EOF.
"""
raise NotImplementedError() # pragma: nocover
def is_idle(self) -> bool:
"""
Return `True` if the connection is currently idle.
"""
raise NotImplementedError() # pragma: nocover
def is_closed(self) -> bool:
"""
Return `True` if the connection has been closed.
Used when a response is closed to determine if the connection may be
returned to the connection pool or not.
"""
raise NotImplementedError() # pragma: nocover

View File

@ -0,0 +1,340 @@
import logging
import ssl
import typing
from socksio import socks5
from .._backends.auto import AutoBackend
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectionNotAvailable, ProxyError
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import AsyncLock
from .._trace import Trace
from .connection_pool import AsyncConnectionPool
from .http11 import AsyncHTTP11Connection
from .interfaces import AsyncConnectionInterface
logger = logging.getLogger("httpcore.socks")
AUTH_METHODS = {
b"\x00": "NO AUTHENTICATION REQUIRED",
b"\x01": "GSSAPI",
b"\x02": "USERNAME/PASSWORD",
b"\xff": "NO ACCEPTABLE METHODS",
}
REPLY_CODES = {
b"\x00": "Succeeded",
b"\x01": "General SOCKS server failure",
b"\x02": "Connection not allowed by ruleset",
b"\x03": "Network unreachable",
b"\x04": "Host unreachable",
b"\x05": "Connection refused",
b"\x06": "TTL expired",
b"\x07": "Command not supported",
b"\x08": "Address type not supported",
}
async def _init_socks5_connection(
stream: AsyncNetworkStream,
*,
host: bytes,
port: int,
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
) -> None:
conn = socks5.SOCKS5Connection()
# Auth method request
auth_method = (
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
if auth is None
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
)
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
outgoing_bytes = conn.data_to_send()
await stream.write(outgoing_bytes)
# Auth method response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5AuthReply)
if response.method != auth_method:
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
raise ProxyError(
f"Requested {requested} from proxy server, but got {responded}."
)
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
# Username/password request
assert auth is not None
username, password = auth
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
outgoing_bytes = conn.data_to_send()
await stream.write(outgoing_bytes)
# Username/password response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
if not response.success:
raise ProxyError("Invalid username/password")
# Connect request
conn.send(
socks5.SOCKS5CommandRequest.from_address(
socks5.SOCKS5Command.CONNECT, (host, port)
)
)
outgoing_bytes = conn.data_to_send()
await stream.write(outgoing_bytes)
# Connect response
incoming_bytes = await stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5Reply)
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
class AsyncSOCKSProxy(AsyncConnectionPool):
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: typing.Union[URL, bytes, str],
proxy_auth: typing.Optional[
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
max_connections: typing.Optional[int] = 10,
max_keepalive_connections: typing.Optional[int] = None,
keepalive_expiry: typing.Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
network_backend: typing.Optional[AsyncNetworkBackend] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
proxy_url: The URL to use when connecting to the proxy server.
For example `"http://127.0.0.1:8080/"`.
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address
(IPv4), while using `local_address="::"` will connect using an
`AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
"""
super().__init__(
ssl_context=ssl_context,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
if proxy_auth is not None:
username, password = proxy_auth
username_bytes = enforce_bytes(username, name="proxy_auth")
password_bytes = enforce_bytes(password, name="proxy_auth")
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
username_bytes,
password_bytes,
)
else:
self._proxy_auth = None
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
return AsyncSocks5Connection(
proxy_origin=self._proxy_url.origin,
remote_origin=origin,
proxy_auth=self._proxy_auth,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
class AsyncSocks5Connection(AsyncConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
keepalive_expiry: typing.Optional[float] = None,
http1: bool = True,
http2: bool = False,
network_backend: typing.Optional[AsyncNetworkBackend] = None,
) -> None:
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._proxy_auth = proxy_auth
self._ssl_context = ssl_context
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._network_backend: AsyncNetworkBackend = (
AutoBackend() if network_backend is None else network_backend
)
self._connect_lock = AsyncLock()
self._connection: typing.Optional[AsyncConnectionInterface] = None
self._connect_failed = False
async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
async with self._connect_lock:
if self._connection is None:
try:
# Connect to the proxy
kwargs = {
"host": self._proxy_origin.host.decode("ascii"),
"port": self._proxy_origin.port,
"timeout": timeout,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
# Connect to the remote host using socks5
kwargs = {
"stream": stream,
"host": self._remote_origin.host.decode("ascii"),
"port": self._remote_origin.port,
"auth": self._proxy_auth,
}
with Trace(
"setup_socks5_connection", logger, request, kwargs
) as trace:
await _init_socks5_connection(**kwargs)
trace.return_value = stream
# Upgrade the stream to SSL
if self._remote_origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = (
["http/1.1", "h2"] if self._http2 else ["http/1.1"]
)
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream
# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (
self._http2 and not self._http1
): # pragma: nocover
from .http2 import AsyncHTTP2Connection
self._connection = AsyncHTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = AsyncHTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available(): # pragma: nocover
raise ConnectionNotAvailable()
return await self._connection.handle_async_request(request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
async def aclose(self) -> None:
if self._connection is not None:
await self._connection.aclose()
def is_available(self) -> bool:
if self._connection is None: # pragma: nocover
# If HTTP/2 support is enabled, and the resulting connection could
# end up as HTTP/2 then we should indicate the connection as being
# available to service multiple requests.
return (
self._http2
and (self._remote_origin.scheme == b"https" or not self._http1)
and not self._connect_failed
)
return self._connection.is_available()
def has_expired(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.has_expired()
def is_idle(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.is_idle()
def is_closed(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.is_closed()
def info(self) -> str:
if self._connection is None: # pragma: nocover
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
return self._connection.info()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"

View File

View File

@ -0,0 +1,145 @@
import ssl
import typing
import anyio
from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class AnyIOStream(AsyncNetworkStream):
def __init__(self, stream: anyio.abc.ByteStream) -> None:
self._stream = stream
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
exc_map = {
TimeoutError: ReadTimeout,
anyio.BrokenResourceError: ReadError,
anyio.ClosedResourceError: ReadError,
}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: nocover
return b""
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
if not buffer:
return
exc_map = {
TimeoutError: WriteTimeout,
anyio.BrokenResourceError: WriteError,
anyio.ClosedResourceError: WriteError,
}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
await self._stream.send(item=buffer)
async def aclose(self) -> None:
await self._stream.aclose()
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
exc_map = {
TimeoutError: ConnectTimeout,
anyio.BrokenResourceError: ConnectError,
}
with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
ssl_stream = await anyio.streams.tls.TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=server_hostname,
standard_compatible=False,
server_side=False,
)
except Exception as exc: # pragma: nocover
await self.aclose()
raise exc
return AnyIOStream(ssl_stream)
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object":
return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None)
if info == "client_addr":
return self._stream.extra(anyio.abc.SocketAttribute.local_address, None)
if info == "server_addr":
return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None)
if info == "socket":
return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
if info == "is_readable":
sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None)
return is_socket_readable(sock)
return None
class AnyIOBackend(AsyncNetworkBackend):
async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
if socket_options is None:
socket_options = [] # pragma: no cover
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
anyio.BrokenResourceError: ConnectError,
}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
stream: anyio.abc.ByteStream = await anyio.connect_tcp(
remote_host=host,
remote_port=port,
local_host=local_address,
)
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
return AnyIOStream(stream)
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
anyio.BrokenResourceError: ConnectError,
}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
return AnyIOStream(stream)
async def sleep(self, seconds: float) -> None:
await anyio.sleep(seconds) # pragma: nocover

View File

@ -0,0 +1,56 @@
import typing
from typing import Optional
import sniffio
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class AutoBackend(AsyncNetworkBackend):
async def _init_backend(self) -> None:
if not (hasattr(self, "_backend")):
backend = sniffio.current_async_library()
if backend == "trio":
from .trio import TrioBackend
self._backend: AsyncNetworkBackend = TrioBackend()
elif backend == "structured-io":
from .structio import StructioBackend
self._backend: AsyncNetworkBackend = StructioBackend()
else:
from .anyio import AnyIOBackend
self._backend = AnyIOBackend()
async def connect_tcp(
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
await self._init_backend()
return await self._backend.connect_tcp(
host,
port,
timeout=timeout,
local_address=local_address,
socket_options=socket_options,
)
async def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
await self._init_backend()
return await self._backend.connect_unix_socket(
path, timeout=timeout, socket_options=socket_options
)
async def sleep(self, seconds: float) -> None: # pragma: nocover
await self._init_backend()
return await self._backend.sleep(seconds)

View File

@ -0,0 +1,103 @@
import ssl
import time
import typing
SOCKET_OPTION = typing.Union[
typing.Tuple[int, int, int],
typing.Tuple[int, int, typing.Union[bytes, bytearray]],
typing.Tuple[int, int, None, int],
]
class NetworkStream:
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
raise NotImplementedError() # pragma: nocover
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
raise NotImplementedError() # pragma: nocover
def close(self) -> None:
raise NotImplementedError() # pragma: nocover
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "NetworkStream":
raise NotImplementedError() # pragma: nocover
def get_extra_info(self, info: str) -> typing.Any:
return None # pragma: nocover
class NetworkBackend:
def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
raise NotImplementedError() # pragma: nocover
def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
raise NotImplementedError() # pragma: nocover
def sleep(self, seconds: float) -> None:
time.sleep(seconds) # pragma: nocover
class AsyncNetworkStream:
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
raise NotImplementedError() # pragma: nocover
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
raise NotImplementedError() # pragma: nocover
async def aclose(self) -> None:
raise NotImplementedError() # pragma: nocover
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> "AsyncNetworkStream":
raise NotImplementedError() # pragma: nocover
def get_extra_info(self, info: str) -> typing.Any:
return None # pragma: nocover
class AsyncNetworkBackend:
async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
raise NotImplementedError() # pragma: nocover
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
raise NotImplementedError() # pragma: nocover
async def sleep(self, seconds: float) -> None:
raise NotImplementedError() # pragma: nocover

View File

@ -0,0 +1,142 @@
import ssl
import typing
from typing import Optional
from .._exceptions import ReadError
from .base import (
SOCKET_OPTION,
AsyncNetworkBackend,
AsyncNetworkStream,
NetworkBackend,
NetworkStream,
)
class MockSSLObject:
def __init__(self, http2: bool):
self._http2 = http2
def selected_alpn_protocol(self) -> str:
return "h2" if self._http2 else "http/1.1"
class MockStream(NetworkStream):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
self._closed = False
def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
if not self._buffer:
return b""
return self._buffer.pop(0)
def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
pass
def close(self) -> None:
self._closed = True
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> NetworkStream:
return self
def get_extra_info(self, info: str) -> typing.Any:
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None
def __repr__(self) -> str:
return "<httpcore.MockStream>"
class MockBackend(NetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
def connect_tcp(
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
return MockStream(list(self._buffer), http2=self._http2)
def sleep(self, seconds: float) -> None:
pass
class AsyncMockStream(AsyncNetworkStream):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
self._closed = False
async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._closed:
raise ReadError("Connection closed")
if not self._buffer:
return b""
return self._buffer.pop(0)
async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
pass
async def aclose(self) -> None:
self._closed = True
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> AsyncNetworkStream:
return self
def get_extra_info(self, info: str) -> typing.Any:
return MockSSLObject(http2=self._http2) if info == "ssl_object" else None
def __repr__(self) -> str:
return "<httpcore.AsyncMockStream>"
class AsyncMockBackend(AsyncNetworkBackend):
def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self._buffer = buffer
self._http2 = http2
async def connect_tcp(
self,
host: str,
port: int,
timeout: Optional[float] = None,
local_address: Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)
async def connect_unix_socket(
self,
path: str,
timeout: Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
return AsyncMockStream(list(self._buffer), http2=self._http2)
async def sleep(self, seconds: float) -> None:
pass

View File

@ -0,0 +1,118 @@
from .base import AsyncNetworkStream, AsyncNetworkBackend, SOCKET_OPTION
from .._exceptions import (ConnectTimeout, ReadError, ReadTimeout,
WriteError, WriteTimeout, ConnectError,
map_exceptions, ExceptionMapping, PoolTimeout)
import structio
import ssl
import typing
class StructioStream(AsyncNetworkStream):
"""
Structio-compatible async stream for
httpx
"""
def __init__(self, stream: structio.AsyncSocket):
self._stream = stream
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
structio.TimedOut: ReadTimeout,
structio.ResourceClosed: ReadError,
structio.ResourceBusy: ReadError,
structio.ResourceBroken: ReadError
}
with map_exceptions(exc_map):
with structio.with_timeout(timeout_or_inf):
data: bytes = await self._stream.receive(max_bytes)
return data
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
if not buffer:
return
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
structio.TimedOut: WriteTimeout,
structio.ResourceClosed: WriteError,
structio.ResourceBusy: WriteError,
structio.ResourceBroken: WriteError
}
with map_exceptions(exc_map):
with structio.with_timeout(timeout_or_inf):
await self._stream.send_all(data=buffer)
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
structio.TimedOut: ConnectTimeout,
structio.ResourceBroken: ConnectError,
}
with map_exceptions(exc_map):
try:
with structio.with_timeout(timeout_or_inf):
self._stream = await structio.socket.wrap_socket_with_ssl(self._stream, context=ssl_context, server_hostname=server_hostname)
except Exception as exc: # pragma: nocover
await self.aclose()
raise exc
return self
async def aclose(self) -> None:
return await self._stream.close()
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object" and hasattr(self._stream, "_sslobj"):
return self._stream._sslobj
if info == "client_addr":
return self._stream.socket.getsockname()
if info == "server_addr":
return self._stream.socket.getpeername()
if info == "socket":
return self._stream
if info == "is_readable":
return self._stream.is_readable()
return None
class StructioBackend(AsyncNetworkBackend):
async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
if socket_options is None:
socket_options = []
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
structio.TimedOut: ConnectTimeout,
structio.ResourceBusy: ConnectError,
OSError: ConnectError,
}
with map_exceptions(exc_map):
with structio.with_timeout(timeout_or_inf):
stream: structio.AsyncSocket = await structio.socket.connect_tcp_socket(
host=host, port=port, source_address=local_address
)
for option in socket_options:
stream.setsockopt(*option)
return StructioStream(stream)
# TODO: connect_unix_socket
async def sleep(self, seconds: float) -> None:
await structio.sleep(seconds)

View File

@ -0,0 +1,133 @@
import socket
import ssl
import sys
import typing
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
class SyncStream(NetworkStream):
def __init__(self, sock: socket.socket) -> None:
self._sock = sock
def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
self._sock.settimeout(timeout)
return self._sock.recv(max_bytes)
def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
if not buffer:
return
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
while buffer:
self._sock.settimeout(timeout)
n = self._sock.send(buffer)
buffer = buffer[n:]
def close(self) -> None:
self._sock.close()
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> NetworkStream:
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
)
except Exception as exc: # pragma: nocover
self.close()
raise exc
return SyncStream(sock)
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
return self._sock._sslobj # type: ignore
if info == "client_addr":
return self._sock.getsockname()
if info == "server_addr":
return self._sock.getpeername()
if info == "socket":
return self._sock
if info == "is_readable":
return is_socket_readable(self._sock)
return None
class SyncBackend(NetworkBackend):
def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream:
# Note that we automatically include `TCP_NODELAY`
# in addition to any other custom socket options.
if socket_options is None:
socket_options = [] # pragma: no cover
address = (host, port)
source_address = None if local_address is None else (local_address, 0)
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
sock = socket.create_connection(
address,
timeout,
source_address=source_address,
)
for option in socket_options:
sock.setsockopt(*option) # pragma: no cover
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return SyncStream(sock)
def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> NetworkStream: # pragma: nocover
if sys.platform == "win32":
raise RuntimeError(
"Attempted to connect to a UNIX socket on a Windows system."
)
if socket_options is None:
socket_options = []
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
for option in socket_options:
sock.setsockopt(*option)
sock.settimeout(timeout)
sock.connect(path)
return SyncStream(sock)

View File

@ -0,0 +1,161 @@
import ssl
import typing
import trio
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
class TrioStream(AsyncNetworkStream):
def __init__(self, stream: trio.abc.Stream) -> None:
self._stream = stream
async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
) -> bytes:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ReadTimeout,
trio.BrokenResourceError: ReadError,
trio.ClosedResourceError: ReadError,
}
with map_exceptions(exc_map):
with trio.fail_after(timeout_or_inf):
data: bytes = await self._stream.receive_some(max_bytes=max_bytes)
return data
async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
) -> None:
if not buffer:
return
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: WriteTimeout,
trio.BrokenResourceError: WriteError,
trio.ClosedResourceError: WriteError,
}
with map_exceptions(exc_map):
with trio.fail_after(timeout_or_inf):
await self._stream.send_all(data=buffer)
async def aclose(self) -> None:
await self._stream.aclose()
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ConnectTimeout,
trio.BrokenResourceError: ConnectError,
}
ssl_stream = trio.SSLStream(
self._stream,
ssl_context=ssl_context,
server_hostname=server_hostname,
https_compatible=True,
server_side=False,
)
with map_exceptions(exc_map):
try:
with trio.fail_after(timeout_or_inf):
await ssl_stream.do_handshake()
except Exception as exc: # pragma: nocover
await self.aclose()
raise exc
return TrioStream(ssl_stream)
def get_extra_info(self, info: str) -> typing.Any:
if info == "ssl_object" and isinstance(self._stream, trio.SSLStream):
# Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__.
# Tracked at https://github.com/python-trio/trio/issues/542
return self._stream._ssl_object # type: ignore[attr-defined]
if info == "client_addr":
return self._get_socket_stream().socket.getsockname()
if info == "server_addr":
return self._get_socket_stream().socket.getpeername()
if info == "socket":
stream = self._stream
while isinstance(stream, trio.SSLStream):
stream = stream.transport_stream
assert isinstance(stream, trio.SocketStream)
return stream.socket
if info == "is_readable":
socket = self.get_extra_info("socket")
return socket.is_readable()
return None
def _get_socket_stream(self) -> trio.SocketStream:
stream = self._stream
while isinstance(stream, trio.SSLStream):
stream = stream.transport_stream
assert isinstance(stream, trio.SocketStream)
return stream
class TrioBackend(AsyncNetworkBackend):
async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream:
# By default for TCP sockets, trio enables TCP_NODELAY.
# https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream
if socket_options is None:
socket_options = [] # pragma: no cover
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ConnectTimeout,
trio.BrokenResourceError: ConnectError,
OSError: ConnectError,
}
with map_exceptions(exc_map):
with trio.fail_after(timeout_or_inf):
stream: trio.abc.Stream = await trio.open_tcp_stream(
host=host, port=port, local_address=local_address
)
for option in socket_options:
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
return TrioStream(stream)
async def connect_unix_socket(
self,
path: str,
timeout: typing.Optional[float] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
timeout_or_inf = float("inf") if timeout is None else timeout
exc_map: ExceptionMapping = {
trio.TooSlowError: ConnectTimeout,
trio.BrokenResourceError: ConnectError,
OSError: ConnectError,
}
with map_exceptions(exc_map):
with trio.fail_after(timeout_or_inf):
stream: trio.abc.Stream = await trio.open_unix_socket(path)
for option in socket_options:
stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
return TrioStream(stream)
async def sleep(self, seconds: float) -> None:
await trio.sleep(seconds) # pragma: nocover

View File

@ -0,0 +1,81 @@
import contextlib
from typing import Iterator, Mapping, Type
ExceptionMapping = Mapping[Type[Exception], Type[Exception]]
@contextlib.contextmanager
def map_exceptions(map: ExceptionMapping) -> Iterator[None]:
try:
yield
except Exception as exc: # noqa: PIE786
for from_exc, to_exc in map.items():
if isinstance(exc, from_exc):
raise to_exc(exc) from exc
raise # pragma: nocover
class ConnectionNotAvailable(Exception):
pass
class ProxyError(Exception):
pass
class UnsupportedProtocol(Exception):
pass
class ProtocolError(Exception):
pass
class RemoteProtocolError(ProtocolError):
pass
class LocalProtocolError(ProtocolError):
pass
# Timeout errors
class TimeoutException(Exception):
pass
class PoolTimeout(TimeoutException):
pass
class ConnectTimeout(TimeoutException):
pass
class ReadTimeout(TimeoutException):
pass
class WriteTimeout(TimeoutException):
pass
# Network errors
class NetworkError(Exception):
pass
class ConnectError(NetworkError):
pass
class ReadError(NetworkError):
pass
class WriteError(NetworkError):
pass

483
tests/httpcore/_models.py Normal file
View File

@ -0,0 +1,483 @@
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from urllib.parse import urlparse
# Functions for typechecking...
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None]
Extensions = Mapping[str, Any]
def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
"""
Any arguments that are ultimately represented as bytes can be specified
either as bytes or as strings.
However we enforce that any string arguments must only contain characters in
the plain ASCII range. chr(0)...chr(127). If you need to use characters
outside that range then be precise, and use a byte-wise argument.
"""
if isinstance(value, str):
try:
return value.encode("ascii")
except UnicodeEncodeError:
raise TypeError(f"{name} strings may not include unicode characters.")
elif isinstance(value, bytes):
return value
seen_type = type(value).__name__
raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
"""
Type check for URL parameters.
"""
if isinstance(value, (bytes, str)):
return URL(value)
elif isinstance(value, URL):
return value
seen_type = type(value).__name__
raise TypeError(f"{name} must be a URL, bytes, or str, but got {seen_type}.")
def enforce_headers(
value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str
) -> List[Tuple[bytes, bytes]]:
"""
Convienence function that ensure all items in request or response headers
are either bytes or strings in the plain ASCII range.
"""
if value is None:
return []
elif isinstance(value, Mapping):
return [
(
enforce_bytes(k, name="header name"),
enforce_bytes(v, name="header value"),
)
for k, v in value.items()
]
elif isinstance(value, Sequence):
return [
(
enforce_bytes(k, name="header name"),
enforce_bytes(v, name="header value"),
)
for k, v in value
]
seen_type = type(value).__name__
raise TypeError(
f"{name} must be a mapping or sequence of two-tuples, but got {seen_type}."
)
def enforce_stream(
value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str
) -> Union[Iterable[bytes], AsyncIterable[bytes]]:
if value is None:
return ByteStream(b"")
elif isinstance(value, bytes):
return ByteStream(value)
return value
# * https://tools.ietf.org/html/rfc3986#section-3.2.3
# * https://url.spec.whatwg.org/#url-miscellaneous
# * https://url.spec.whatwg.org/#scheme-state
DEFAULT_PORTS = {
b"ftp": 21,
b"http": 80,
b"https": 443,
b"ws": 80,
b"wss": 443,
}
def include_request_headers(
headers: List[Tuple[bytes, bytes]],
*,
url: "URL",
content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]],
) -> List[Tuple[bytes, bytes]]:
headers_set = set(k.lower() for k, v in headers)
if b"host" not in headers_set:
default_port = DEFAULT_PORTS.get(url.scheme)
if url.port is None or url.port == default_port:
header_value = url.host
else:
header_value = b"%b:%d" % (url.host, url.port)
headers = [(b"Host", header_value)] + headers
if (
content is not None
and b"content-length" not in headers_set
and b"transfer-encoding" not in headers_set
):
if isinstance(content, bytes):
content_length = str(len(content)).encode("ascii")
headers += [(b"Content-Length", content_length)]
else:
headers += [(b"Transfer-Encoding", b"chunked")] # pragma: nocover
return headers
# Interfaces for byte streams...
class ByteStream:
"""
A container for non-streaming content, and that supports both sync and async
stream iteration.
"""
def __init__(self, content: bytes) -> None:
self._content = content
def __iter__(self) -> Iterator[bytes]:
yield self._content
async def __aiter__(self) -> AsyncIterator[bytes]:
yield self._content
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{len(self._content)} bytes]>"
class Origin:
def __init__(self, scheme: bytes, host: bytes, port: int) -> None:
self.scheme = scheme
self.host = host
self.port = port
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, Origin)
and self.scheme == other.scheme
and self.host == other.host
and self.port == other.port
)
def __str__(self) -> str:
scheme = self.scheme.decode("ascii")
host = self.host.decode("ascii")
port = str(self.port)
return f"{scheme}://{host}:{port}"
class URL:
"""
Represents the URL against which an HTTP request may be made.
The URL may either be specified as a plain string, for convienence:
```python
url = httpcore.URL("https://www.example.com/")
```
Or be constructed with explicitily pre-parsed components:
```python
url = httpcore.URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/')
```
Using this second more explicit style allows integrations that are using
`httpcore` to pass through URLs that have already been parsed in order to use
libraries such as `rfc-3986` rather than relying on the stdlib. It also ensures
that URL parsing is treated identically at both the networking level and at any
higher layers of abstraction.
The four components are important here, as they allow the URL to be precisely
specified in a pre-parsed format. They also allow certain types of request to
be created that could not otherwise be expressed.
For example, an HTTP request to `http://www.example.com/` forwarded via a proxy
at `http://localhost:8080`...
```python
# Constructs an HTTP request with a complete URL as the target:
# GET https://www.example.com/ HTTP/1.1
url = httpcore.URL(
scheme=b'http',
host=b'localhost',
port=8080,
target=b'https://www.example.com/'
)
request = httpcore.Request(
method="GET",
url=url
)
```
Another example is constructing an `OPTIONS *` request...
```python
# Constructs an 'OPTIONS *' HTTP request:
# OPTIONS * HTTP/1.1
url = httpcore.URL(scheme=b'https', host=b'www.example.com', target=b'*')
request = httpcore.Request(method="OPTIONS", url=url)
```
This kind of request is not possible to formulate with a URL string,
because the `/` delimiter is always used to demark the target from the
host/port portion of the URL.
For convenience, string-like arguments may be specified either as strings or
as bytes. However, once a request is being issue over-the-wire, the URL
components are always ultimately required to be a bytewise representation.
In order to avoid any ambiguity over character encodings, when strings are used
as arguments, they must be strictly limited to the ASCII range `chr(0)`-`chr(127)`.
If you require a bytewise representation that is outside this range you must
handle the character encoding directly, and pass a bytes instance.
"""
def __init__(
self,
url: Union[bytes, str] = "",
*,
scheme: Union[bytes, str] = b"",
host: Union[bytes, str] = b"",
port: Optional[int] = None,
target: Union[bytes, str] = b"",
) -> None:
"""
Parameters:
url: The complete URL as a string or bytes.
scheme: The URL scheme as a string or bytes.
Typically either `"http"` or `"https"`.
host: The URL host as a string or bytes. Such as `"www.example.com"`.
port: The port to connect to. Either an integer or `None`.
target: The target of the HTTP request. Such as `"/items?search=red"`.
"""
if url:
parsed = urlparse(enforce_bytes(url, name="url"))
self.scheme = parsed.scheme
self.host = parsed.hostname or b""
self.port = parsed.port
self.target = (parsed.path or b"/") + (
b"?" + parsed.query if parsed.query else b""
)
else:
self.scheme = enforce_bytes(scheme, name="scheme")
self.host = enforce_bytes(host, name="host")
self.port = port
self.target = enforce_bytes(target, name="target")
@property
def origin(self) -> Origin:
default_port = {
b"http": 80,
b"https": 443,
b"ws": 80,
b"wss": 443,
b"socks5": 1080,
}[self.scheme]
return Origin(
scheme=self.scheme, host=self.host, port=self.port or default_port
)
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, URL)
and other.scheme == self.scheme
and other.host == self.host
and other.port == self.port
and other.target == self.target
)
def __bytes__(self) -> bytes:
if self.port is None:
return b"%b://%b%b" % (self.scheme, self.host, self.target)
return b"%b://%b:%d%b" % (self.scheme, self.host, self.port, self.target)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(scheme={self.scheme!r}, "
f"host={self.host!r}, port={self.port!r}, target={self.target!r})"
)
class Request:
"""
An HTTP request.
"""
def __init__(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> None:
"""
Parameters:
method: The HTTP request method, either as a string or bytes.
For example: `GET`.
url: The request URL, either as a `URL` instance, or as a string or bytes.
For example: `"https://www.example.com".`
headers: The HTTP request headers.
content: The content of the response body.
extensions: A dictionary of optional extra information included on
the request. Possible keys include `"timeout"`, and `"trace"`.
"""
self.method: bytes = enforce_bytes(method, name="method")
self.url: URL = enforce_url(url, name="url")
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
headers, name="headers"
)
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
content, name="content"
)
self.extensions = {} if extensions is None else extensions
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.method!r}]>"
class Response:
"""
An HTTP response.
"""
def __init__(
self,
status: int,
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> None:
"""
Parameters:
status: The HTTP status code of the response. For example `200`.
headers: The HTTP response headers.
content: The content of the response body.
extensions: A dictionary of optional extra information included on
the responseself.Possible keys include `"http_version"`,
`"reason_phrase"`, and `"network_stream"`.
"""
self.status: int = status
self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
headers, name="headers"
)
self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
content, name="content"
)
self.extensions = {} if extensions is None else extensions
self._stream_consumed = False
@property
def content(self) -> bytes:
if not hasattr(self, "_content"):
if isinstance(self.stream, Iterable):
raise RuntimeError(
"Attempted to access 'response.content' on a streaming response. "
"Call 'response.read()' first."
)
else:
raise RuntimeError(
"Attempted to access 'response.content' on a streaming response. "
"Call 'await response.aread()' first."
)
return self._content
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.status}]>"
# Sync interface...
def read(self) -> bytes:
if not isinstance(self.stream, Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to read an asynchronous response using 'response.read()'. "
"You should use 'await response.aread()' instead."
)
if not hasattr(self, "_content"):
self._content = b"".join([part for part in self.iter_stream()])
return self._content
def iter_stream(self) -> Iterator[bytes]:
if not isinstance(self.stream, Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to stream an asynchronous response using 'for ... in "
"response.iter_stream()'. "
"You should use 'async for ... in response.aiter_stream()' instead."
)
if self._stream_consumed:
raise RuntimeError(
"Attempted to call 'for ... in response.iter_stream()' more than once."
)
self._stream_consumed = True
for chunk in self.stream:
yield chunk
def close(self) -> None:
if not isinstance(self.stream, Iterable): # pragma: nocover
raise RuntimeError(
"Attempted to close an asynchronous response using 'response.close()'. "
"You should use 'await response.aclose()' instead."
)
if hasattr(self.stream, "close"):
self.stream.close()
# Async interface...
async def aread(self) -> bytes:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to read an synchronous response using "
"'await response.aread()'. "
"You should use 'response.read()' instead."
)
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_stream()])
return self._content
async def aiter_stream(self) -> AsyncIterator[bytes]:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to stream an synchronous response using 'async for ... in "
"response.aiter_stream()'. "
"You should use 'for ... in response.iter_stream()' instead."
)
if self._stream_consumed:
raise RuntimeError(
"Attempted to call 'async for ... in response.aiter_stream()' "
"more than once."
)
self._stream_consumed = True
async for chunk in self.stream:
yield chunk
async def aclose(self) -> None:
if not isinstance(self.stream, AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to close a synchronous response using "
"'await response.aclose()'. "
"You should use 'response.close()' instead."
)
if hasattr(self.stream, "aclose"):
await self.stream.aclose()

8
tests/httpcore/_ssl.py Normal file
View File

@ -0,0 +1,8 @@
import ssl
import certifi
def default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
context.load_verify_locations(certifi.where())
return context

View File

@ -0,0 +1,39 @@
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
from .http11 import HTTP11Connection
from .http_proxy import HTTPProxy
from .interfaces import ConnectionInterface
try:
from .http2 import HTTP2Connection
except ImportError: # pragma: nocover
class HTTP2Connection: # type: ignore
def __init__(self, *args, **kwargs) -> None: # type: ignore
raise RuntimeError(
"Attempted to use http2 support, but the `h2` package is not "
"installed. Use 'pip install httpcore[http2]'."
)
try:
from .socks_proxy import SOCKSProxy
except ImportError: # pragma: nocover
class SOCKSProxy: # type: ignore
def __init__(self, *args, **kwargs) -> None: # type: ignore
raise RuntimeError(
"Attempted to use SOCKS support, but the `socksio` package is not "
"installed. Use 'pip install httpcore[socks]'."
)
__all__ = [
"HTTPConnection",
"ConnectionPool",
"HTTPProxy",
"HTTP11Connection",
"HTTP2Connection",
"ConnectionInterface",
"SOCKSProxy",
]

View File

@ -0,0 +1,215 @@
import itertools
import logging
import ssl
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type
from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
logger = logging.getLogger("httpcore.connection")
def exponential_backoff(factor: float) -> Iterator[float]:
yield 0
for n in itertools.count(2):
yield factor * (2 ** (n - 2))
class HTTPConnection(ConnectionInterface):
def __init__(
self,
origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._origin = origin
self._ssl_context = ssl_context
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._retries = retries
self._local_address = local_address
self._uds = uds
self._network_backend: NetworkBackend = (
SyncBackend() if network_backend is None else network_backend
)
self._connection: Optional[ConnectionInterface] = None
self._connect_failed: bool = False
self._request_lock = Lock()
self._socket_options = socket_options
def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
)
with self._request_lock:
if self._connection is None:
try:
stream = self._connect(request)
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import HTTP2Connection
self._connection = HTTP2Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = HTTP11Connection(
origin=self._origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available():
raise ConnectionNotAvailable()
return self._connection.handle_request(request)
def _connect(self, request: Request) -> NetworkStream:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
while True:
try:
if self._uds is None:
kwargs = {
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"socket_options": self._socket_options,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
"path": self._uds,
"timeout": timeout,
"socket_options": self._socket_options,
}
with Trace(
"connect_unix_socket", logger, request, kwargs
) as trace:
stream = self._network_backend.connect_unix_socket(
**kwargs
)
trace.return_value = stream
if self._origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
return stream
except (ConnectError, ConnectTimeout):
if retries_left <= 0:
raise
retries_left -= 1
delay = next(delays)
with Trace("retry", logger, request, kwargs) as trace:
self._network_backend.sleep(delay)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
def close(self) -> None:
if self._connection is not None:
with Trace("close", logger, None, {}):
self._connection.close()
def is_available(self) -> bool:
if self._connection is None:
# If HTTP/2 support is enabled, and the resulting connection could
# end up as HTTP/2 then we should indicate the connection as being
# available to service multiple requests.
return (
self._http2
and (self._origin.scheme == b"https" or not self._http1)
and not self._connect_failed
)
return self._connection.is_available()
def has_expired(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.has_expired()
def is_idle(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.is_idle()
def is_closed(self) -> bool:
if self._connection is None:
return self._connect_failed
return self._connection.is_closed()
def info(self) -> str:
if self._connection is None:
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
return self._connection.info()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTPConnection":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
self.close()

View File

@ -0,0 +1,356 @@
import ssl
import sys
from types import TracebackType
from typing import Iterable, Iterator, Iterable, List, Optional, Type
from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import Event, Lock, ShieldCancellation
from .connection import HTTPConnection
from .interfaces import ConnectionInterface, RequestInterface
class RequestStatus:
def __init__(self, request: Request):
self.request = request
self.connection: Optional[ConnectionInterface] = None
self._connection_acquired = Event()
def set_connection(self, connection: ConnectionInterface) -> None:
assert self.connection is None
self.connection = connection
self._connection_acquired.set()
def unset_connection(self) -> None:
assert self.connection is not None
self.connection = None
self._connection_acquired = Event()
def wait_for_connection(
self, timeout: Optional[float] = None
) -> ConnectionInterface:
if self.connection is None:
self._connection_acquired.wait(timeout=timeout)
assert self.connection is not None
return self.connection
class ConnectionPool(RequestInterface):
"""
A connection pool for making HTTP requests.
"""
def __init__(
self,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish a
connection.
local_address: Local address to connect from. Can also be used to connect
using a particular address family. Using `local_address="0.0.0.0"`
will connect using an `AF_INET` address (IPv4), while using
`local_address="::"` will connect using an `AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
socket_options: Socket options that have to be included
in the TCP socket when the connection was established.
"""
self._ssl_context = ssl_context
self._max_connections = (
sys.maxsize if max_connections is None else max_connections
)
self._max_keepalive_connections = (
sys.maxsize
if max_keepalive_connections is None
else max_keepalive_connections
)
self._max_keepalive_connections = min(
self._max_connections, self._max_keepalive_connections
)
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._retries = retries
self._local_address = local_address
self._uds = uds
self._pool: List[ConnectionInterface] = []
self._requests: List[RequestStatus] = []
self._pool_lock = Lock()
self._network_backend = (
SyncBackend() if network_backend is None else network_backend
)
self._socket_options = socket_options
def create_connection(self, origin: Origin) -> ConnectionInterface:
return HTTPConnection(
origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
retries=self._retries,
local_address=self._local_address,
uds=self._uds,
network_backend=self._network_backend,
socket_options=self._socket_options,
)
@property
def connections(self) -> List[ConnectionInterface]:
"""
Return a list of the connections currently in the pool.
For example:
```python
>>> pool.connections
[
<HTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
<HTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> ,
<HTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
]
```
"""
return list(self._pool)
def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
"""
Attempt to provide a connection that can handle the given origin.
"""
origin = status.request.url.origin
# If there are queued requests in front of us, then don't acquire a
# connection. We handle requests strictly in order.
waiting = [s for s in self._requests if s.connection is None]
if waiting and waiting[0] is not status:
return False
# Reuse an existing connection if one is currently available.
for idx, connection in enumerate(self._pool):
if connection.can_handle_request(origin) and connection.is_available():
self._pool.pop(idx)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
# If the pool is currently full, attempt to close one idle connection.
if len(self._pool) >= self._max_connections:
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle():
connection.close()
self._pool.pop(idx)
break
# If the pool is still full, then we cannot acquire a connection.
if len(self._pool) >= self._max_connections:
return False
# Otherwise create a new connection.
connection = self.create_connection(origin)
self._pool.insert(0, connection)
status.set_connection(connection)
return True
def _close_expired_connections(self) -> None:
"""
Clean up the connection pool by closing off any connections that have expired.
"""
# Close any connections that have expired their keep-alive time.
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.has_expired():
connection.close()
self._pool.pop(idx)
# If the pool size exceeds the maximum number of allowed keep-alive connections,
# then close off idle connections as required.
pool_size = len(self._pool)
for idx, connection in reversed(list(enumerate(self._pool))):
if connection.is_idle() and pool_size > self._max_keepalive_connections:
connection.close()
self._pool.pop(idx)
pool_size -= 1
def handle_request(self, request: Request) -> Response:
"""
Send an HTTP request, and return an HTTP response.
This is the core implementation that is called into by `.request()` or `.stream()`.
"""
scheme = request.url.scheme.decode()
if scheme == "":
raise UnsupportedProtocol(
"Request URL is missing an 'http://' or 'https://' protocol."
)
if scheme not in ("http", "https", "ws", "wss"):
raise UnsupportedProtocol(
f"Request URL has an unsupported protocol '{scheme}://'."
)
status = RequestStatus(request)
with self._pool_lock:
self._requests.append(status)
self._close_expired_connections()
self._attempt_to_acquire_connection(status)
while True:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
try:
connection = status.wait_for_connection(timeout=timeout)
except BaseException as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
with self._pool_lock:
# Ensure only remove when task exists.
if status in self._requests:
self._requests.remove(status)
raise exc
try:
response = connection.handle_request(request)
except ConnectionNotAvailable:
# The ConnectionNotAvailable exception is a special case, that
# indicates we need to retry the request on a new connection.
#
# The most common case where this can occur is when multiple
# requests are queued waiting for a single connection, which
# might end up as an HTTP/2 connection, but which actually ends
# up as HTTP/1.1.
with self._pool_lock:
# Maintain our position in the request queue, but reset the
# status so that the request becomes queued again.
status.unset_connection()
self._attempt_to_acquire_connection(status)
except BaseException as exc:
with ShieldCancellation():
self.response_closed(status)
raise exc
else:
break
# When we return the response, we wrap the stream in a special class
# that handles notifying the connection pool once the response
# has been released.
assert isinstance(response.stream, Iterable)
return Response(
status=response.status,
headers=response.headers,
content=ConnectionPoolByteStream(response.stream, self, status),
extensions=response.extensions,
)
def response_closed(self, status: RequestStatus) -> None:
"""
This method acts as a callback once the request/response cycle is complete.
It is called into from the `ConnectionPoolByteStream.close()` method.
"""
assert status.connection is not None
connection = status.connection
with self._pool_lock:
# Update the state of the connection pool.
if status in self._requests:
self._requests.remove(status)
if connection.is_closed() and connection in self._pool:
self._pool.remove(connection)
# Since we've had a response closed, it's possible we'll now be able
# to service one or more requests that are currently pending.
for status in self._requests:
if status.connection is None:
acquired = self._attempt_to_acquire_connection(status)
# If we could not acquire a connection for a queued request
# then we don't need to check anymore requests that are
# queued later behind it.
if not acquired:
break
# Housekeeping.
self._close_expired_connections()
def close(self) -> None:
"""
Close any connections in the pool.
"""
with self._pool_lock:
for connection in self._pool:
connection.close()
self._pool = []
self._requests = []
def __enter__(self) -> "ConnectionPool":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
self.close()
class ConnectionPoolByteStream:
"""
A wrapper around the response byte stream, that additionally handles
notifying the connection pool when the response has been closed.
"""
def __init__(
self,
stream: Iterable[bytes],
pool: ConnectionPool,
status: RequestStatus,
) -> None:
self._stream = stream
self._pool = pool
self._status = status
def __iter__(self) -> Iterator[bytes]:
for part in self._stream:
yield part
def close(self) -> None:
try:
if hasattr(self._stream, "close"):
self._stream.close()
finally:
with ShieldCancellation():
self._pool.response_closed(self._status)

View File

@ -0,0 +1,331 @@
import enum
import logging
import time
from types import TracebackType
from typing import (
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import h11
from .._backends.base import NetworkStream
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, ShieldCancellation
from .._trace import Trace
from .interfaces import ConnectionInterface
logger = logging.getLogger("httpcore.http11")
# A subset of `h11.Event` types supported by `_send_event`
H11SendEvent = Union[
h11.Request,
h11.Data,
h11.EndOfMessage,
]
class HTTPConnectionState(enum.IntEnum):
NEW = 0
ACTIVE = 1
IDLE = 2
CLOSED = 3
class HTTP11Connection(ConnectionInterface):
READ_NUM_BYTES = 64 * 1024
MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
def __init__(
self,
origin: Origin,
stream: NetworkStream,
keepalive_expiry: Optional[float] = None,
) -> None:
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: Optional[float] = keepalive_expiry
self._expire_at: Optional[float] = None
self._state = HTTPConnectionState.NEW
self._state_lock = Lock()
self._request_count = 0
self._h11_state = h11.Connection(
our_role=h11.CLIENT,
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
)
def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection "
f"to {self._origin}"
)
with self._state_lock:
if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
self._request_count += 1
self._state = HTTPConnectionState.ACTIVE
self._expire_at = None
else:
raise ConnectionNotAvailable()
try:
kwargs = {"request": request}
with Trace("send_request_headers", logger, request, kwargs) as trace:
self._send_request_headers(**kwargs)
with Trace("send_request_body", logger, request, kwargs) as trace:
self._send_request_body(**kwargs)
with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
(
http_version,
status,
reason_phrase,
headers,
) = self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
status,
reason_phrase,
headers,
)
return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
},
)
except BaseException as exc:
with ShieldCancellation():
with Trace("response_closed", logger, request) as trace:
self._response_closed()
raise exc
# Sending the request...
def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
self._send_event(event, timeout=timeout)
def _send_request_body(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
assert isinstance(request.stream, Iterable)
for chunk in request.stream:
event = h11.Data(data=chunk)
self._send_event(event, timeout=timeout)
self._send_event(h11.EndOfMessage(), timeout=timeout)
def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
) -> None:
bytes_to_send = self._h11_state.send(event)
if bytes_to_send is not None:
self._network_stream.write(bytes_to_send, timeout=timeout)
# Receiving the response...
def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
while True:
event = self._receive_event(timeout=timeout)
if isinstance(event, h11.Response):
break
if (
isinstance(event, h11.InformationalResponse)
and event.status_code == 101
):
break
http_version = b"HTTP/" + event.http_version
# h11 version 0.11+ supports a `raw_items` interface to get the
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()
return http_version, event.status_code, event.reason, headers
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
while True:
event = self._receive_event(timeout=timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
break
def _receive_event(
self, timeout: Optional[float] = None
) -> Union[h11.Event, Type[h11.PAUSED]]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
event = self._h11_state.next_event()
if event is h11.NEED_DATA:
data = self._network_stream.read(
self.READ_NUM_BYTES, timeout=timeout
)
# If we feed this case through h11 we'll raise an exception like:
#
# httpcore.RemoteProtocolError: can't handle event type
# ConnectionClosed when role=SERVER and state=SEND_RESPONSE
#
# Which is accurate, but not very informative from an end-user
# perspective. Instead we handle this case distinctly and treat
# it as a ConnectError.
if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
msg = "Server disconnected without sending a response."
raise RemoteProtocolError(msg)
self._h11_state.receive_data(data)
else:
# mypy fails to narrow the type in the above if statement above
return cast(Union[h11.Event, Type[h11.PAUSED]], event)
def _response_closed(self) -> None:
with self._state_lock:
if (
self._h11_state.our_state is h11.DONE
and self._h11_state.their_state is h11.DONE
):
self._state = HTTPConnectionState.IDLE
self._h11_state.start_next_cycle()
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry
else:
self.close()
# Once the connection is no longer required...
def close(self) -> None:
# Note that this method unilaterally closes the connection, and does
# not have any kind of locking in place around it.
self._state = HTTPConnectionState.CLOSED
self._network_stream.close()
# The ConnectionInterface methods provide information about the state of
# the connection, allowing for a connection pooling implementation to
# determine when to reuse and when to close the connection...
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
def is_available(self) -> bool:
# Note that HTTP/1.1 connections in the "NEW" state are not treated as
# being "available". The control flow which created the connection will
# be able to send an outgoing request, but the connection will not be
# acquired from the connection pool for any other request.
return self._state == HTTPConnectionState.IDLE
def has_expired(self) -> bool:
now = time.monotonic()
keepalive_expired = self._expire_at is not None and now > self._expire_at
# If the HTTP connection is idle but the socket is readable, then the
# only valid state is that the socket is about to return b"", indicating
# a server-initiated disconnect.
server_disconnected = (
self._state == HTTPConnectionState.IDLE
and self._network_stream.get_extra_info("is_readable")
)
return keepalive_expired or server_disconnected
def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
def is_closed(self) -> bool:
return self._state == HTTPConnectionState.CLOSED
def info(self) -> str:
origin = str(self._origin)
return (
f"{origin!r}, HTTP/1.1, {self._state.name}, "
f"Request Count: {self._request_count}"
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
origin = str(self._origin)
return (
f"<{class_name} [{origin!r}, {self._state.name}, "
f"Request Count: {self._request_count}]>"
)
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTP11Connection":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
self.close()
class HTTP11ConnectionByteStream:
def __init__(self, connection: HTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False
def __iter__(self) -> Iterator[bytes]:
kwargs = {"request": self._request}
try:
with Trace("receive_response_body", logger, self._request, kwargs):
for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with ShieldCancellation():
self.close()
raise exc
def close(self) -> None:
if not self._closed:
self._closed = True
with Trace("response_closed", logger, self._request):
self._connection._response_closed()

View File

@ -0,0 +1,589 @@
import enum
import logging
import time
import types
import typing
import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings
from .._backends.base import NetworkStream
from .._exceptions import (
ConnectionNotAvailable,
LocalProtocolError,
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, Semaphore, ShieldCancellation
from .._trace import Trace
from .interfaces import ConnectionInterface
logger = logging.getLogger("httpcore.http2")
def has_body_headers(request: Request) -> bool:
return any(
k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
for k, v in request.headers
)
class HTTPConnectionState(enum.IntEnum):
ACTIVE = 1
IDLE = 2
CLOSED = 3
class HTTP2Connection(ConnectionInterface):
READ_NUM_BYTES = 64 * 1024
CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
def __init__(
self,
origin: Origin,
stream: NetworkStream,
keepalive_expiry: typing.Optional[float] = None,
):
self._origin = origin
self._network_stream = stream
self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
self._state = HTTPConnectionState.IDLE
self._expire_at: typing.Optional[float] = None
self._request_count = 0
self._init_lock = Lock()
self._state_lock = Lock()
self._read_lock = Lock()
self._write_lock = Lock()
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}
# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None
self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
# This cannot occur in normal operation, since the connection pool
# will only send requests on connections that handle them.
# It's in place simply for resilience as a guard against incorrect
# usage, for anyone working directly with httpcore connections.
raise RuntimeError(
f"Attempted to send request to {request.url.origin} on connection "
f"to {self._origin}"
)
with self._state_lock:
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
self._request_count += 1
self._expire_at = None
self._state = HTTPConnectionState.ACTIVE
else:
raise ConnectionNotAvailable()
with self._init_lock:
if not self._sent_connection_init:
try:
kwargs = {"request": request}
with Trace("send_connection_init", logger, request, kwargs):
self._send_connection_init(**kwargs)
except BaseException as exc:
with ShieldCancellation():
self.close()
raise exc
self._sent_connection_init = True
# Initially start with just 1 until the remote server provides
# its max_concurrent_streams value
self._max_streams = 1
local_settings_max_streams = (
self._h2_state.local_settings.max_concurrent_streams
)
self._max_streams_semaphore = Semaphore(local_settings_max_streams)
for _ in range(local_settings_max_streams - self._max_streams):
self._max_streams_semaphore.acquire()
self._max_streams_semaphore.acquire()
try:
stream_id = self._h2_state.get_next_available_stream_id()
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()
try:
kwargs = {"request": request, "stream_id": stream_id}
with Trace("send_request_headers", logger, request, kwargs):
self._send_request_headers(request=request, stream_id=stream_id)
with Trace("send_request_body", logger, request, kwargs):
self._send_request_body(request=request, stream_id=stream_id)
with Trace(
"receive_response_headers", logger, request, kwargs
) as trace:
status, headers = self._receive_response(
request=request, stream_id=stream_id
)
trace.return_value = (status, headers)
return Response(
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={
"http_version": b"HTTP/2",
"network_stream": self._network_stream,
"stream_id": stream_id,
},
)
except BaseException as exc: # noqa: PIE786
with ShieldCancellation():
kwargs = {"stream_id": stream_id}
with Trace("response_closed", logger, request, kwargs):
self._response_closed(stream_id=stream_id)
if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
# closed frame has been seen by the state machine.
#
# This happens when one stream is reading, and encounters
# a GOAWAY event. Other flows of control may then raise
# a protocol error at any point they interact with the 'h2_state'.
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
raise exc
def _send_connection_init(self, request: Request) -> None:
"""
The HTTP/2 connection requires some initial setup before we can start
using individual request/response streams on it.
"""
# Need to set these manually here instead of manipulating via
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
# frames in addition to sending the undesired defaults.
self._h2_state.local_settings = h2.settings.Settings(
client=True,
initial_values={
# Disable PUSH_PROMISE frames from the server since we don't do anything
# with them for now. Maybe when we support caching?
h2.settings.SettingCodes.ENABLE_PUSH: 0,
# These two are taken from h2 for safe defaults
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
},
)
# Some websites (*cough* Yahoo *cough*) balk at this setting being
# present in the initial handshake since it's not defined in the original
# RFC despite the RFC mandating ignoring settings you don't know about.
del self._h2_state.local_settings[
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
]
self._h2_state.initiate_connection()
self._h2_state.increment_flow_control_window(2**24)
self._write_outgoing_data(request)
# Sending the request...
def _send_request_headers(self, request: Request, stream_id: int) -> None:
"""
Send the request headers to a given stream ID.
"""
end_stream = not has_body_headers(request)
# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
# In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
# HTTP/1.1 style headers, and map them appropriately if we end up on
# an HTTP/2 connection.
authority = [v for k, v in request.headers if k.lower() == b"host"][0]
headers = [
(b":method", request.method),
(b":authority", authority),
(b":scheme", request.url.scheme),
(b":path", request.url.target),
] + [
(k.lower(), v)
for k, v in request.headers
if k.lower()
not in (
b"host",
b"transfer-encoding",
)
]
self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
self._write_outgoing_data(request)
def _send_request_body(self, request: Request, stream_id: int) -> None:
"""
Iterate over the request body sending it to a given stream ID.
"""
if not has_body_headers(request):
return
assert isinstance(request.stream, typing.Iterable)
for data in request.stream:
self._send_stream_data(request, stream_id, data)
self._send_end_stream(request, stream_id)
def _send_stream_data(
self, request: Request, stream_id: int, data: bytes
) -> None:
"""
Send a single chunk of data in one or more data frames.
"""
while data:
max_flow = self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
self._write_outgoing_data(request)
def _send_end_stream(self, request: Request, stream_id: int) -> None:
"""
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
"""
self._h2_state.end_stream(stream_id)
self._write_outgoing_data(request)
# Receiving the response...
def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
while True:
event = self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.ResponseReceived):
break
status_code = 200
headers = []
for k, v in event.headers:
if k == b":status":
status_code = int(v.decode("ascii", errors="ignore"))
elif not k.startswith(b":"):
headers.append((k, v))
return (status_code, headers)
def _receive_response_body(
self, request: Request, stream_id: int
) -> typing.Iterator[bytes]:
"""
Iterator that returns the bytes of the response body for a given stream ID.
"""
while True:
event = self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
amount = event.flow_controlled_length
self._h2_state.acknowledge_received_data(amount, stream_id)
self._write_outgoing_data(request)
yield event.data
elif isinstance(event, h2.events.StreamEnded):
break
def _receive_stream_event(
self, request: Request, stream_id: int
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.
Will read more data from the network if required.
"""
while not self._events.get(stream_id):
self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event
def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
) -> None:
"""
Read some data from the network until we see one or more events
for a given stream ID.
"""
with self._read_lock:
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if stream_id and last_stream_id and stream_id > last_stream_id:
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)
# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
# check *within* the atomic read lock. Though it also need to be optional,
# because when we call it from `_wait_for_outgoing_flow` we *do* want to
# block until we've available flow control, event when we have events
# pending for the stream ID we're attempting to send on.
if stream_id is None or not self._events.get(stream_id):
events = self._read_incoming_data(request)
for event in events:
if isinstance(event, h2.events.RemoteSettingsChanged):
with Trace(
"receive_remote_settings", logger, request
) as trace:
self._receive_remote_settings_change(event)
trace.return_value = event
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
if event.stream_id in self._events:
self._events[event.stream_id].append(event)
elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event
self._write_outgoing_data(request)
def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
max_concurrent_streams = event.changed_settings.get(
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
)
if max_concurrent_streams:
new_max_streams = min(
max_concurrent_streams.new_value,
self._h2_state.local_settings.max_concurrent_streams,
)
if new_max_streams and new_max_streams != self._max_streams:
while new_max_streams > self._max_streams:
self._max_streams_semaphore.release()
self._max_streams += 1
while new_max_streams < self._max_streams:
self._max_streams_semaphore.acquire()
self._max_streams -= 1
def _response_closed(self, stream_id: int) -> None:
self._max_streams_semaphore.release()
del self._events[stream_id]
with self._state_lock:
if self._connection_terminated and not self._events:
self.close()
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
self._expire_at = now + self._keepalive_expiry
if self._used_all_stream_ids: # pragma: nocover
self.close()
def close(self) -> None:
# Note that this method unilaterally closes the connection, and does
# not have any kind of locking in place around it.
self._h2_state.close_connection()
self._state = HTTPConnectionState.CLOSED
self._network_stream.close()
# Wrappers around network read/write operations...
def _read_incoming_data(
self, request: Request
) -> typing.List[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
if self._read_exception is not None:
raise self._read_exception # pragma: nocover
try:
data = self._network_stream.read(self.READ_NUM_BYTES, timeout)
if data == b"":
raise RemoteProtocolError("Server disconnected")
except Exception as exc:
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future reads.
# (For example, this means that a single read timeout or disconnect will
# immediately close all pending streams. Without requiring multiple
# sequential timeouts.)
# 2. Mark the connection as errored, so that we don't accept any other
# incoming requests.
self._read_exception = exc
self._connection_error = True
raise exc
events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)
return events
def _write_outgoing_data(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
with self._write_lock:
data_to_send = self._h2_state.data_to_send()
if self._write_exception is not None:
raise self._write_exception # pragma: nocover
try:
self._network_stream.write(data_to_send, timeout)
except Exception as exc: # pragma: nocover
# If we get a network error we should:
#
# 1. Save the exception and just raise it immediately on any future write.
# (For example, this means that a single write timeout or disconnect will
# immediately close all pending streams. Without requiring multiple
# sequential timeouts.)
# 2. Mark the connection as errored, so that we don't accept any other
# incoming requests.
self._write_exception = exc
self._connection_error = True
raise exc
# Flow control...
def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
"""
Returns the maximum allowable outgoing flow for a given stream.
If the allowable flow is zero, then waits on the network until
WindowUpdated frames have increased the flow rate.
https://tools.ietf.org/html/rfc7540#section-6.9
"""
local_flow: int = self._h2_state.local_flow_control_window(stream_id)
max_frame_size: int = self._h2_state.max_outbound_frame_size
flow = min(local_flow, max_frame_size)
while flow == 0:
self._receive_events(request)
local_flow = self._h2_state.local_flow_control_window(stream_id)
max_frame_size = self._h2_state.max_outbound_frame_size
flow = min(local_flow, max_frame_size)
return flow
# Interface for connection pooling...
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._origin
def is_available(self) -> bool:
return (
self._state != HTTPConnectionState.CLOSED
and not self._connection_error
and not self._used_all_stream_ids
and not (
self._h2_state.state_machine.state
== h2.connection.ConnectionState.CLOSED
)
)
def has_expired(self) -> bool:
now = time.monotonic()
return self._expire_at is not None and now > self._expire_at
def is_idle(self) -> bool:
return self._state == HTTPConnectionState.IDLE
def is_closed(self) -> bool:
return self._state == HTTPConnectionState.CLOSED
def info(self) -> str:
origin = str(self._origin)
return (
f"{origin!r}, HTTP/2, {self._state.name}, "
f"Request Count: {self._request_count}"
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
origin = str(self._origin)
return (
f"<{class_name} [{origin!r}, {self._state.name}, "
f"Request Count: {self._request_count}]>"
)
# These context managers are not used in the standard flow, but are
# useful for testing or working with connection instances directly.
def __enter__(self) -> "HTTP2Connection":
return self
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[types.TracebackType] = None,
) -> None:
self.close()
class HTTP2ConnectionByteStream:
def __init__(
self, connection: HTTP2Connection, request: Request, stream_id: int
) -> None:
self._connection = connection
self._request = request
self._stream_id = stream_id
self._closed = False
def __iter__(self) -> typing.Iterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
try:
with Trace("receive_response_body", logger, self._request, kwargs):
for chunk in self._connection._receive_response_body(
request=self._request, stream_id=self._stream_id
):
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
with ShieldCancellation():
self.close()
raise exc
def close(self) -> None:
if not self._closed:
self._closed = True
kwargs = {"stream_id": self._stream_id}
with Trace("response_closed", logger, self._request, kwargs):
self._connection._response_closed(stream_id=self._stream_id)

View File

@ -0,0 +1,350 @@
import logging
import ssl
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ProxyError
from .._models import (
URL,
Origin,
Request,
Response,
enforce_bytes,
enforce_headers,
enforce_url,
)
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from .connection import HTTPConnection
from .connection_pool import ConnectionPool
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface
HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
logger = logging.getLogger("httpcore.proxy")
def merge_headers(
default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
) -> List[Tuple[bytes, bytes]]:
"""
Append default_headers and override_headers, de-duplicating if a key exists
in both cases.
"""
default_headers = [] if default_headers is None else list(default_headers)
override_headers = [] if override_headers is None else list(override_headers)
has_override = set(key.lower() for key, value in override_headers)
default_headers = [
(key, value)
for key, value in default_headers
if key.lower() not in has_override
]
return default_headers + override_headers
def build_auth_header(username: bytes, password: bytes) -> bytes:
userpass = username + b":" + password
return b"Basic " + b64encode(userpass)
class HTTPProxy(ConnectionPool):
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: Union[URL, bytes, str],
proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
ssl_context: Optional[ssl.SSLContext] = None,
max_connections: Optional[int] = 10,
max_keepalive_connections: Optional[int] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
local_address: Optional[str] = None,
uds: Optional[str] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
proxy_url: The URL to use when connecting to the proxy server.
For example `"http://127.0.0.1:8080/"`.
proxy_auth: Any proxy authentication as a two-tuple of
(username, password). May be either bytes or ascii-only str.
proxy_headers: Any HTTP headers to use for the proxy requests.
For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address
(IPv4), while using `local_address="::"` will connect using an
`AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
"""
super().__init__(
ssl_context=ssl_context,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
local_address=local_address,
uds=uds,
socket_options=socket_options,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
if proxy_auth is not None:
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
authorization = build_auth_header(username, password)
self._proxy_headers = [
(b"Proxy-Authorization", authorization)
] + self._proxy_headers
def create_connection(self, origin: Origin) -> ConnectionInterface:
if origin.scheme == b"http":
return ForwardHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
keepalive_expiry=self._keepalive_expiry,
network_backend=self._network_backend,
)
return TunnelHTTPConnection(
proxy_origin=self._proxy_url.origin,
proxy_headers=self._proxy_headers,
remote_origin=origin,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
class ForwardHTTPConnection(ConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None,
keepalive_expiry: Optional[float] = None,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._connection = HTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
)
self._proxy_origin = proxy_origin
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._remote_origin = remote_origin
def handle_request(self, request: Request) -> Response:
headers = merge_headers(self._proxy_headers, request.headers)
url = URL(
scheme=self._proxy_origin.scheme,
host=self._proxy_origin.host,
port=self._proxy_origin.port,
target=bytes(request.url),
)
proxy_request = Request(
method=request.method,
url=url,
headers=headers,
content=request.stream,
extensions=request.extensions,
)
return self._connection.handle_request(proxy_request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
def close(self) -> None:
self._connection.close()
def info(self) -> str:
return self._connection.info()
def is_available(self) -> bool:
return self._connection.is_available()
def has_expired(self) -> bool:
return self._connection.has_expired()
def is_idle(self) -> bool:
return self._connection.is_idle()
def is_closed(self) -> bool:
return self._connection.is_closed()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"
class TunnelHTTPConnection(ConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None,
keepalive_expiry: Optional[float] = None,
http1: bool = True,
http2: bool = False,
network_backend: Optional[NetworkBackend] = None,
socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
) -> None:
self._connection: ConnectionInterface = HTTPConnection(
origin=proxy_origin,
keepalive_expiry=keepalive_expiry,
network_backend=network_backend,
socket_options=socket_options,
)
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._ssl_context = ssl_context
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._connect_lock = Lock()
self._connected = False
def handle_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
with self._connect_lock:
if not self._connected:
target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)
connect_url = URL(
scheme=self._proxy_origin.scheme,
host=self._proxy_origin.host,
port=self._proxy_origin.port,
target=target,
)
connect_headers = merge_headers(
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
)
connect_request = Request(
method=b"CONNECT",
url=connect_url,
headers=connect_headers,
extensions=request.extensions,
)
connect_response = self._connection.handle_request(
connect_request
)
if connect_response.status < 200 or connect_response.status > 299:
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
reason_str = reason_bytes.decode("ascii", errors="ignore")
msg = "%d %s" % (connect_response.status, reason_str)
self._connection.close()
raise ProxyError(msg)
stream = connect_response.extensions["network_stream"]
# Upgrade the stream to SSL
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (self._http2 and not self._http1):
from .http2 import HTTP2Connection
self._connection = HTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
self._connected = True
return self._connection.handle_request(request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
def close(self) -> None:
self._connection.close()
def info(self) -> str:
return self._connection.info()
def is_available(self) -> bool:
return self._connection.is_available()
def has_expired(self) -> bool:
return self._connection.has_expired()
def is_idle(self) -> bool:
return self._connection.is_idle()
def is_closed(self) -> bool:
return self._connection.is_closed()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"

View File

@ -0,0 +1,135 @@
from contextlib import contextmanager
from typing import Iterator, Optional, Union
from .._models import (
URL,
Extensions,
HeaderTypes,
Origin,
Request,
Response,
enforce_bytes,
enforce_headers,
enforce_url,
include_request_headers,
)
class RequestInterface:
def request(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Response:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")
headers = enforce_headers(headers, name="headers")
# Include Host header, and optionally Content-Length or Transfer-Encoding.
headers = include_request_headers(headers, url=url, content=content)
request = Request(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
)
response = self.handle_request(request)
try:
response.read()
finally:
response.close()
return response
@contextmanager
def stream(
self,
method: Union[bytes, str],
url: Union[URL, bytes, str],
*,
headers: HeaderTypes = None,
content: Union[bytes, Iterator[bytes], None] = None,
extensions: Optional[Extensions] = None,
) -> Iterator[Response]:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")
headers = enforce_headers(headers, name="headers")
# Include Host header, and optionally Content-Length or Transfer-Encoding.
headers = include_request_headers(headers, url=url, content=content)
request = Request(
method=method,
url=url,
headers=headers,
content=content,
extensions=extensions,
)
response = self.handle_request(request)
try:
yield response
finally:
response.close()
def handle_request(self, request: Request) -> Response:
raise NotImplementedError() # pragma: nocover
class ConnectionInterface(RequestInterface):
def close(self) -> None:
raise NotImplementedError() # pragma: nocover
def info(self) -> str:
raise NotImplementedError() # pragma: nocover
def can_handle_request(self, origin: Origin) -> bool:
raise NotImplementedError() # pragma: nocover
def is_available(self) -> bool:
"""
Return `True` if the connection is currently able to accept an
outgoing request.
An HTTP/1.1 connection will only be available if it is currently idle.
An HTTP/2 connection will be available so long as the stream ID space is
not yet exhausted, and the connection is not in an error state.
While the connection is being established we may not yet know if it is going
to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
treated as being available, but might ultimately raise `NewConnectionRequired`
required exceptions if multiple requests are attempted over a connection
that ends up being established as HTTP/1.1.
"""
raise NotImplementedError() # pragma: nocover
def has_expired(self) -> bool:
"""
Return `True` if the connection is in a state where it should be closed.
This either means that the connection is idle and it has passed the
expiry time on its keep-alive, or that server has sent an EOF.
"""
raise NotImplementedError() # pragma: nocover
def is_idle(self) -> bool:
"""
Return `True` if the connection is currently idle.
"""
raise NotImplementedError() # pragma: nocover
def is_closed(self) -> bool:
"""
Return `True` if the connection has been closed.
Used when a response is closed to determine if the connection may be
returned to the connection pool or not.
"""
raise NotImplementedError() # pragma: nocover

View File

@ -0,0 +1,340 @@
import logging
import ssl
import typing
from socksio import socks5
from .._backends.sync import SyncBackend
from .._backends.base import NetworkBackend, NetworkStream
from .._exceptions import ConnectionNotAvailable, ProxyError
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
from .._ssl import default_ssl_context
from .._synchronization import Lock
from .._trace import Trace
from .connection_pool import ConnectionPool
from .http11 import HTTP11Connection
from .interfaces import ConnectionInterface
logger = logging.getLogger("httpcore.socks")
AUTH_METHODS = {
b"\x00": "NO AUTHENTICATION REQUIRED",
b"\x01": "GSSAPI",
b"\x02": "USERNAME/PASSWORD",
b"\xff": "NO ACCEPTABLE METHODS",
}
REPLY_CODES = {
b"\x00": "Succeeded",
b"\x01": "General SOCKS server failure",
b"\x02": "Connection not allowed by ruleset",
b"\x03": "Network unreachable",
b"\x04": "Host unreachable",
b"\x05": "Connection refused",
b"\x06": "TTL expired",
b"\x07": "Command not supported",
b"\x08": "Address type not supported",
}
def _init_socks5_connection(
stream: NetworkStream,
*,
host: bytes,
port: int,
auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
) -> None:
conn = socks5.SOCKS5Connection()
# Auth method request
auth_method = (
socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
if auth is None
else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
)
conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
outgoing_bytes = conn.data_to_send()
stream.write(outgoing_bytes)
# Auth method response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5AuthReply)
if response.method != auth_method:
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
raise ProxyError(
f"Requested {requested} from proxy server, but got {responded}."
)
if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
# Username/password request
assert auth is not None
username, password = auth
conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
outgoing_bytes = conn.data_to_send()
stream.write(outgoing_bytes)
# Username/password response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
if not response.success:
raise ProxyError("Invalid username/password")
# Connect request
conn.send(
socks5.SOCKS5CommandRequest.from_address(
socks5.SOCKS5Command.CONNECT, (host, port)
)
)
outgoing_bytes = conn.data_to_send()
stream.write(outgoing_bytes)
# Connect response
incoming_bytes = stream.read(max_bytes=4096)
response = conn.receive_data(incoming_bytes)
assert isinstance(response, socks5.SOCKS5Reply)
if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
class SOCKSProxy(ConnectionPool):
"""
A connection pool that sends requests via an HTTP proxy.
"""
def __init__(
self,
proxy_url: typing.Union[URL, bytes, str],
proxy_auth: typing.Optional[
typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
max_connections: typing.Optional[int] = 10,
max_keepalive_connections: typing.Optional[int] = None,
keepalive_expiry: typing.Optional[float] = None,
http1: bool = True,
http2: bool = False,
retries: int = 0,
network_backend: typing.Optional[NetworkBackend] = None,
) -> None:
"""
A connection pool for making HTTP requests.
Parameters:
proxy_url: The URL to use when connecting to the proxy server.
For example `"http://127.0.0.1:8080/"`.
ssl_context: An SSL context to use for verifying connections.
If not specified, the default `httpcore.default_ssl_context()`
will be used.
max_connections: The maximum number of concurrent HTTP connections that
the pool should allow. Any attempt to send a request on a pool that
would exceed this amount will block until a connection is available.
max_keepalive_connections: The maximum number of idle HTTP connections
that will be maintained in the pool.
keepalive_expiry: The duration in seconds that an idle HTTP connection
may be maintained for before being expired from the pool.
http1: A boolean indicating if HTTP/1.1 requests should be supported
by the connection pool. Defaults to True.
http2: A boolean indicating if HTTP/2 requests should be supported by
the connection pool. Defaults to False.
retries: The maximum number of retries when trying to establish
a connection.
local_address: Local address to connect from. Can also be used to
connect using a particular address family. Using
`local_address="0.0.0.0"` will connect using an `AF_INET` address
(IPv4), while using `local_address="::"` will connect using an
`AF_INET6` address (IPv6).
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
network_backend: A backend instance to use for handling network I/O.
"""
super().__init__(
ssl_context=ssl_context,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
keepalive_expiry=keepalive_expiry,
http1=http1,
http2=http2,
network_backend=network_backend,
retries=retries,
)
self._ssl_context = ssl_context
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
if proxy_auth is not None:
username, password = proxy_auth
username_bytes = enforce_bytes(username, name="proxy_auth")
password_bytes = enforce_bytes(password, name="proxy_auth")
self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = (
username_bytes,
password_bytes,
)
else:
self._proxy_auth = None
def create_connection(self, origin: Origin) -> ConnectionInterface:
return Socks5Connection(
proxy_origin=self._proxy_url.origin,
remote_origin=origin,
proxy_auth=self._proxy_auth,
ssl_context=self._ssl_context,
keepalive_expiry=self._keepalive_expiry,
http1=self._http1,
http2=self._http2,
network_backend=self._network_backend,
)
class Socks5Connection(ConnectionInterface):
def __init__(
self,
proxy_origin: Origin,
remote_origin: Origin,
proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
ssl_context: typing.Optional[ssl.SSLContext] = None,
keepalive_expiry: typing.Optional[float] = None,
http1: bool = True,
http2: bool = False,
network_backend: typing.Optional[NetworkBackend] = None,
) -> None:
self._proxy_origin = proxy_origin
self._remote_origin = remote_origin
self._proxy_auth = proxy_auth
self._ssl_context = ssl_context
self._keepalive_expiry = keepalive_expiry
self._http1 = http1
self._http2 = http2
self._network_backend: NetworkBackend = (
SyncBackend() if network_backend is None else network_backend
)
self._connect_lock = Lock()
self._connection: typing.Optional[ConnectionInterface] = None
self._connect_failed = False
def handle_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
with self._connect_lock:
if self._connection is None:
try:
# Connect to the proxy
kwargs = {
"host": self._proxy_origin.host.decode("ascii"),
"port": self._proxy_origin.port,
"timeout": timeout,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
# Connect to the remote host using socks5
kwargs = {
"stream": stream,
"host": self._remote_origin.host.decode("ascii"),
"port": self._remote_origin.port,
"auth": self._proxy_auth,
}
with Trace(
"setup_socks5_connection", logger, request, kwargs
) as trace:
_init_socks5_connection(**kwargs)
trace.return_value = stream
# Upgrade the stream to SSL
if self._remote_origin.scheme == b"https":
ssl_context = (
default_ssl_context()
if self._ssl_context is None
else self._ssl_context
)
alpn_protocols = (
["http/1.1", "h2"] if self._http2 else ["http/1.1"]
)
ssl_context.set_alpn_protocols(alpn_protocols)
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
stream = stream.start_tls(**kwargs)
trace.return_value = stream
# Determine if we should be using HTTP/1.1 or HTTP/2
ssl_object = stream.get_extra_info("ssl_object")
http2_negotiated = (
ssl_object is not None
and ssl_object.selected_alpn_protocol() == "h2"
)
# Create the HTTP/1.1 or HTTP/2 connection
if http2_negotiated or (
self._http2 and not self._http1
): # pragma: nocover
from .http2 import HTTP2Connection
self._connection = HTTP2Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
else:
self._connection = HTTP11Connection(
origin=self._remote_origin,
stream=stream,
keepalive_expiry=self._keepalive_expiry,
)
except Exception as exc:
self._connect_failed = True
raise exc
elif not self._connection.is_available(): # pragma: nocover
raise ConnectionNotAvailable()
return self._connection.handle_request(request)
def can_handle_request(self, origin: Origin) -> bool:
return origin == self._remote_origin
def close(self) -> None:
if self._connection is not None:
self._connection.close()
def is_available(self) -> bool:
if self._connection is None: # pragma: nocover
# If HTTP/2 support is enabled, and the resulting connection could
# end up as HTTP/2 then we should indicate the connection as being
# available to service multiple requests.
return (
self._http2
and (self._remote_origin.scheme == b"https" or not self._http1)
and not self._connect_failed
)
return self._connection.is_available()
def has_expired(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.has_expired()
def is_idle(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.is_idle()
def is_closed(self) -> bool:
if self._connection is None: # pragma: nocover
return self._connect_failed
return self._connection.is_closed()
def info(self) -> str:
if self._connection is None: # pragma: nocover
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
return self._connection.info()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} [{self.info()}]>"

View File

@ -0,0 +1,333 @@
import threading
from types import TracebackType
from typing import Optional, Type
import sniffio
from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
# Our async synchronization primatives use either 'anyio' or 'trio' depending
# on if they're running under asyncio or trio.
try:
import trio
except ImportError: # pragma: nocover
trio = None # type: ignore
try:
import anyio
except ImportError: # pragma: nocover
anyio = None # type: ignore
try:
import structio
except ImportError: # pragma: nocover
structio = None # type: ignore
class AsyncLock:
def __init__(self) -> None:
self._backend = ""
def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio, requires the 'trio' package to be installed."
)
self._trio_lock = trio.Lock()
elif self._backend == "structured-io":
if structio is None: # pragma: nocover
raise RuntimeError(
"Running under structio requires the 'structio' package to be installed."
)
self._structio_lock = structio.Lock()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
self._anyio_lock = anyio.Lock()
async def __aenter__(self) -> "AsyncLock":
if not self._backend:
self.setup()
if self._backend == "trio":
await self._trio_lock.acquire()
elif self._backend == "structured-io":
await self._structio_lock.acquire()
else:
await self._anyio_lock.acquire()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self._backend == "trio":
self._trio_lock.release()
elif self._backend == "structured-io":
await self._structio_lock.release()
else:
self._anyio_lock.release()
class AsyncEvent:
def __init__(self) -> None:
self._backend = ""
def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_event = trio.Event()
elif self._backend == "structured-io":
if structio is None: # pragma: nocover
raise RuntimeError(
"Running under structio requires the 'structio' package to be installed."
)
self._structio_event = structio.Event()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
self._anyio_event = anyio.Event()
def set(self) -> None:
if not self._backend:
self.setup()
if self._backend == "trio":
self._trio_event.set()
elif self._backend == "structured-io":
self._structio_event.set()
else:
self._anyio_event.set()
async def wait(self, timeout: Optional[float] = None) -> None:
if not self._backend:
self.setup()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
timeout_or_inf = float("inf") if timeout is None else timeout
with map_exceptions(trio_exc_map):
with trio.fail_after(timeout_or_inf):
await self._trio_event.wait()
elif self._backend == "structured-io":
if structio is None: # pragma: nocover
raise RuntimeError(
"Running under structio requires the 'structio' package to be installed."
)
structio_exc_map: ExceptionMapping = {structio.TimedOut: PoolTimeout}
timeout_or_inf = float("inf") if timeout is None else timeout
with map_exceptions(trio_exc_map):
with structio.with_timeout(timeout_or_inf):
await self._structio_event.wait()
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
with map_exceptions(anyio_exc_map):
with anyio.fail_after(timeout):
await self._anyio_event.wait()
class AsyncSemaphore:
def __init__(self, bound: int) -> None:
self._bound = bound
self._backend = ""
def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a semaphore with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_semaphore = trio.Semaphore(
initial_value=self._bound, max_value=self._bound
)
elif self._backend == "structured-io":
if structio is None: # pragma: nocover
raise RuntimeError(
"Running under structio requires the 'structio' package to be installed."
)
self._structio_semaphore = structio.Semaphore(initial_size=self._bound, max_size=self._bound)
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
self._anyio_semaphore = anyio.Semaphore(
initial_value=self._bound, max_value=self._bound
)
async def acquire(self) -> None:
if not self._backend:
self.setup()
if self._backend == "trio":
await self._trio_semaphore.acquire()
elif self._backend == "structured-io":
await self._structio_semaphore.acquire()
else:
await self._anyio_semaphore.acquire()
async def release(self) -> None:
if self._backend == "trio":
self._trio_semaphore.release()
elif self._backend == "structured-io":
await self._structio_semaphore.release()
else:
self._anyio_semaphore.release()
class AsyncShieldCancellation:
# For certain portions of our codebase where we're dealing with
# closing connections during exception handling we want to shield
# the operation from being cancelled.
#
# with AsyncShieldCancellation():
# ... # clean-up operations, shielded from cancellation.
def __init__(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a shielded scope with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)
self._trio_shield = trio.CancelScope(shield=True)
elif self._backend == "structured-io":
if structio is None: # pragma: nocover
raise RuntimeError(
"Running under structio requires the 'structio' package to be installed."
)
self._structio_shield = structio.TaskScope(shielded=True)
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
self._anyio_shield = anyio.CancelScope(shield=True)
def __enter__(self) -> "AsyncShieldCancellation":
if self._backend == "trio":
self._trio_shield.__enter__()
elif self._backend == "structured-io":
self._structio_shield.__enter__()
else:
self._anyio_shield.__enter__()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self._backend == "trio":
self._trio_shield.__exit__(exc_type, exc_value, traceback)
elif self._backend == "structured-io":
self._structio_shield.__exit__(exc_type, exc_value, traceback)
else:
self._anyio_shield.__exit__(exc_type, exc_value, traceback)
# Our thread-based synchronization primitives...
class Lock:
def __init__(self) -> None:
self._lock = threading.Lock()
def __enter__(self) -> "Lock":
self._lock.acquire()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
self._lock.release()
class Event:
def __init__(self) -> None:
self._event = threading.Event()
def set(self) -> None:
self._event.set()
def wait(self, timeout: Optional[float] = None) -> None:
if not self._event.wait(timeout=timeout):
raise PoolTimeout() # pragma: nocover
class Semaphore:
def __init__(self, bound: int) -> None:
self._semaphore = threading.Semaphore(value=bound)
def acquire(self) -> None:
self._semaphore.acquire()
def release(self) -> None:
self._semaphore.release()
class ShieldCancellation:
# Thread-synchronous codebases don't support cancellation semantics.
# We have this class because we need to mirror the async and sync
# cases within our package, but it's just a no-op.
def __enter__(self) -> "ShieldCancellation":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
pass

105
tests/httpcore/_trace.py Normal file
View File

@ -0,0 +1,105 @@
import inspect
import logging
from types import TracebackType
from typing import Any, Dict, Optional, Type
from ._models import Request
class Trace:
def __init__(
self,
name: str,
logger: logging.Logger,
request: Optional[Request] = None,
kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.name = name
self.logger = logger
self.trace_extension = (
None if request is None else request.extensions.get("trace")
)
self.debug = self.logger.isEnabledFor(logging.DEBUG)
self.kwargs = kwargs or {}
self.return_value: Any = None
self.should_trace = self.debug or self.trace_extension is not None
self.prefix = self.logger.name.split(".")[-1]
def trace(self, name: str, info: Dict[str, Any]) -> None:
if self.trace_extension is not None:
prefix_and_name = f"{self.prefix}.{name}"
ret = self.trace_extension(prefix_and_name, info)
if inspect.iscoroutine(ret): # pragma: no cover
raise TypeError(
"If you are using a synchronous interface, "
"the callback of the `trace` extension should "
"be a normal function instead of an asynchronous function."
)
if self.debug:
if not info or "return_value" in info and info["return_value"] is None:
message = name
else:
args = " ".join([f"{key}={value!r}" for key, value in info.items()])
message = f"{name} {args}"
self.logger.debug(message)
def __enter__(self) -> "Trace":
if self.should_trace:
info = self.kwargs
self.trace(f"{self.name}.started", info)
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self.should_trace:
if exc_value is None:
info = {"return_value": self.return_value}
self.trace(f"{self.name}.complete", info)
else:
info = {"exception": exc_value}
self.trace(f"{self.name}.failed", info)
async def atrace(self, name: str, info: Dict[str, Any]) -> None:
if self.trace_extension is not None:
prefix_and_name = f"{self.prefix}.{name}"
coro = self.trace_extension(prefix_and_name, info)
if not inspect.iscoroutine(coro): # pragma: no cover
raise TypeError(
"If you're using an asynchronous interface, "
"the callback of the `trace` extension should "
"be an asynchronous function rather than a normal function."
)
await coro
if self.debug:
if not info or "return_value" in info and info["return_value"] is None:
message = name
else:
args = " ".join([f"{key}={value!r}" for key, value in info.items()])
message = f"{name} {args}"
self.logger.debug(message)
async def __aenter__(self) -> "Trace":
if self.should_trace:
info = self.kwargs
await self.atrace(f"{self.name}.started", info)
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self.should_trace:
if exc_value is None:
info = {"return_value": self.return_value}
await self.atrace(f"{self.name}.complete", info)
else:
info = {"exception": exc_value}
await self.atrace(f"{self.name}.failed", info)

36
tests/httpcore/_utils.py Normal file
View File

@ -0,0 +1,36 @@
import select
import socket
import sys
import typing
def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
"""
Return whether a socket, as identifed by its file descriptor, is readable.
"A socket is readable" means that the read buffer isn't empty, i.e. that calling
.recv() on it would immediately return some data.
"""
# NOTE: we want check for readability without actually attempting to read, because
# we don't want to block forever if it's not readable.
# In the case that the socket no longer exists, or cannot return a file
# descriptor, we treat it as being readable, as if it the next read operation
# on it is ready to return the terminating `b""`.
sock_fd = None if sock is None else sock.fileno()
if sock_fd is None or sock_fd < 0: # pragma: nocover
return True
# The implementation below was stolen from:
# https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478
# See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316
# Use select.select on Windows, and when poll is unavailable and select.poll
# everywhere else. (E.g. When eventlet is in use. See #327)
if (
sys.platform == "win32" or getattr(select, "poll", None) is None
): # pragma: nocover
rready, _, _ = select.select([sock_fd], [], [], 0)
return bool(rready)
p = select.poll()
p.register(sock_fd, select.POLLIN)
return bool(p.poll(0))

0
tests/httpcore/py.typed Normal file
View File

View File

@ -1,10 +1,12 @@
import httpcore
import structio
async def main():
# Note: this test only works because we have our own version of httpcore that
# implements a structio-compatible backend. It's just an example anyway
pool = httpcore.AsyncConnectionPool()
print(await pool.request("GET", "http://example.com"))
# TODO: SSL support
structio.run(main)