import enum
import logging
import time
import types
import typing

import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
    ConnectionNotAvailable,
    LocalProtocolError,
    RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

logger = logging.getLogger("httpcore.http2")


def has_body_headers(request: Request) -> bool:
    """
    Return True if the request carries a `Content-Length` or
    `Transfer-Encoding` header (header names are matched case-insensitively).
    """
    return any(
        k.lower() in (b"content-length", b"transfer-encoding")
        for k, _ in request.headers
    )


class HTTPConnectionState(enum.IntEnum):
    """
    Coarse lifecycle state of an HTTP/2 connection.
    """

    ACTIVE = 1
    IDLE = 2
    CLOSED = 3


class AsyncHTTP2Connection(AsyncConnectionInterface):
    """
    A single HTTP/2 connection to one origin.

    Requests are sent as individual streams on top of an `h2` connection
    state machine, with reads and writes on the underlying network stream
    serialised through dedicated locks.
    """

    READ_NUM_BYTES = 64 * 1024
    CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)

    def __init__(
        self,
        origin: Origin,
        stream: AsyncNetworkStream,
        keepalive_expiry: typing.Optional[float] = None,
    ):
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
        self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
        self._state = HTTPConnectionState.IDLE
        self._expire_at: typing.Optional[float] = None
        self._request_count = 0
        self._init_lock = AsyncLock()
        self._state_lock = AsyncLock()
        self._read_lock = AsyncLock()
        self._write_lock = AsyncLock()
        self._sent_connection_init = False
        self._used_all_stream_ids = False
        self._connection_error = False

        # Mapping from stream ID to the list of pending response stream events.
        self._events: typing.Dict[
            int,
            typing.List[
                typing.Union[
                    h2.events.ResponseReceived,
                    h2.events.DataReceived,
                    h2.events.StreamEnded,
                    h2.events.StreamReset,
                ]
            ],
        ] = {}

        # Connection terminated events are stored as state since
        # we need to handle them for all streams.
        self._connection_terminated: typing.Optional[
            h2.events.ConnectionTerminated
        ] = None

        self._read_exception: typing.Optional[Exception] = None
        self._write_exception: typing.Optional[Exception] = None

    async def handle_async_request(self, request: Request) -> Response:
        """
        Send a request on a new stream and return the response.

        The response body is not read here; it is exposed lazily through
        an `HTTP2ConnectionByteStream`.

        Raises `RuntimeError` if the request origin does not match this
        connection, and `ConnectionNotAvailable` if the connection is closed
        or has no stream IDs left. `h2` protocol errors are re-raised as
        `RemoteProtocolError` or `LocalProtocolError`.
        """
        if not self.can_handle_request(request.url.origin):
            # This cannot occur in normal operation, since the connection pool
            # will only send requests on connections that handle them.
            # It's in place simply for resilience as a guard against incorrect
            # usage, for anyone working directly with httpcore connections.
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection "
                f"to {self._origin}"
            )

        async with self._state_lock:
            if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
                self._request_count += 1
                self._expire_at = None
                self._state = HTTPConnectionState.ACTIVE
            else:
                raise ConnectionNotAvailable()

        async with self._init_lock:
            if not self._sent_connection_init:
                try:
                    kwargs = {"request": request}
                    async with Trace("send_connection_init", logger, request, kwargs):
                        await self._send_connection_init(**kwargs)
                except BaseException as exc:
                    with AsyncShieldCancellation():
                        await self.aclose()
                    raise exc

                self._sent_connection_init = True

                # Initially start with just 1 until the remote server provides
                # its max_concurrent_streams value
                self._max_streams = 1

                local_settings_max_streams = (
                    self._h2_state.local_settings.max_concurrent_streams
                )
                self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)

                # Pre-acquire the surplus slots so that only `_max_streams`
                # of them are initially available.
                for _ in range(local_settings_max_streams - self._max_streams):
                    await self._max_streams_semaphore.acquire()

        await self._max_streams_semaphore.acquire()

        try:
            stream_id = self._h2_state.get_next_available_stream_id()
            self._events[stream_id] = []
        except h2.exceptions.NoAvailableStreamIDError:  # pragma: nocover
            self._used_all_stream_ids = True
            self._request_count -= 1
            # Hand back the stream slot acquired above. No stream was
            # registered, so `_response_closed` will never release it for us.
            await self._max_streams_semaphore.release()
            raise ConnectionNotAvailable()

        try:
            kwargs = {"request": request, "stream_id": stream_id}
            async with Trace("send_request_headers", logger, request, kwargs):
                await self._send_request_headers(request=request, stream_id=stream_id)
            async with Trace("send_request_body", logger, request, kwargs):
                await self._send_request_body(request=request, stream_id=stream_id)
            async with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
                status, headers = await self._receive_response(
                    request=request, stream_id=stream_id
                )
                trace.return_value = (status, headers)

            return Response(
                status=status,
                headers=headers,
                content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
                extensions={
                    "http_version": b"HTTP/2",
                    "network_stream": self._network_stream,
                    "stream_id": stream_id,
                },
            )
        except BaseException as exc:  # noqa: PIE786
            with AsyncShieldCancellation():
                kwargs = {"stream_id": stream_id}
                async with Trace("response_closed", logger, request, kwargs):
                    await self._response_closed(stream_id=stream_id)

            if isinstance(exc, h2.exceptions.ProtocolError):
                # One case where h2 can raise a protocol error is when a
                # closed frame has been seen by the state machine.
                #
                # This happens when one stream is reading, and encounters
                # a GOAWAY event. Other flows of control may then raise
                # a protocol error at any point they interact with the 'h2_state'.
                #
                # In this case we'll have stored the event, and should raise
                # it as a RemoteProtocolError.
                if self._connection_terminated:  # pragma: nocover
                    raise RemoteProtocolError(self._connection_terminated)
                # If h2 raises a protocol error in some other state then we
                # must somehow have made a protocol violation.
                raise LocalProtocolError(exc)  # pragma: nocover

            raise exc

    async def _send_connection_init(self, request: Request) -> None:
        """
        The HTTP/2 connection requires some initial setup before we can start
        using individual request/response streams on it.
        """
        # Need to set these manually here instead of manipulating via
        # __setitem__() otherwise the H2Connection will emit SettingsUpdate
        # frames in addition to sending the undesired defaults.
        self._h2_state.local_settings = h2.settings.Settings(
            client=True,
            initial_values={
                # Disable PUSH_PROMISE frames from the server since we don't do anything
                # with them for now. Maybe when we support caching?
                h2.settings.SettingCodes.ENABLE_PUSH: 0,
                # These two are taken from h2 for safe defaults
                h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
            },
        )

        # Some websites (*cough* Yahoo *cough*) balk at this setting being
        # present in the initial handshake since it's not defined in the original
        # RFC despite the RFC mandating ignoring settings you don't know about.
        del self._h2_state.local_settings[
            h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
        ]

        self._h2_state.initiate_connection()
        self._h2_state.increment_flow_control_window(2**24)
        await self._write_outgoing_data(request)

    # Sending the request...

    async def _send_request_headers(self, request: Request, stream_id: int) -> None:
        """
        Send the request headers to a given stream ID.
        """
        # Without body headers there is nothing further to send, so the
        # stream is ended together with the headers frame.
        end_stream = not has_body_headers(request)

        # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
        # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
        # HTTP/1.1 style headers, and map them appropriately if we end up on
        # an HTTP/2 connection.
        authority = [v for k, v in request.headers if k.lower() == b"host"][0]

        headers = [
            (b":method", request.method),
            (b":authority", authority),
            (b":scheme", request.url.scheme),
            (b":path", request.url.target),
        ] + [
            (k.lower(), v)
            for k, v in request.headers
            if k.lower()
            not in (
                b"host",
                b"transfer-encoding",
            )
        ]

        self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
        self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
        await self._write_outgoing_data(request)

    async def _send_request_body(self, request: Request, stream_id: int) -> None:
        """
        Iterate over the request body sending it to a given stream ID.
        """
        if not has_body_headers(request):
            return

        assert isinstance(request.stream, typing.AsyncIterable)
        async for data in request.stream:
            await self._send_stream_data(request, stream_id, data)
        await self._send_end_stream(request, stream_id)

    async def _send_stream_data(
        self, request: Request, stream_id: int, data: bytes
    ) -> None:
        """
        Send a single chunk of data in one or more data frames.
        """
        while data:
            # Each frame is capped by the currently available outgoing flow.
            max_flow = await self._wait_for_outgoing_flow(request, stream_id)
            chunk_size = min(len(data), max_flow)
            chunk, data = data[:chunk_size], data[chunk_size:]
            self._h2_state.send_data(stream_id, chunk)
            await self._write_outgoing_data(request)

    async def _send_end_stream(self, request: Request, stream_id: int) -> None:
        """
        Send an empty data frame on a given stream ID with the END_STREAM flag set.
        """
        self._h2_state.end_stream(stream_id)
        await self._write_outgoing_data(request)

    # Receiving the response...

    async def _receive_response(
        self, request: Request, stream_id: int
    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
        """
        Return the response status code and headers for a given stream ID.
        """
        # Discard stream events until the response headers arrive.
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.ResponseReceived):
                break

        status_code = 200
        headers = []
        for k, v in event.headers:
            if k == b":status":
                status_code = int(v.decode("ascii", errors="ignore"))
            elif not k.startswith(b":"):
                # Other pseudo-headers are dropped from the returned headers.
                headers.append((k, v))

        return (status_code, headers)

    async def _receive_response_body(
        self, request: Request, stream_id: int
    ) -> typing.AsyncIterator[bytes]:
        """
        Iterator that returns the bytes of the response body for a given stream ID.
        """
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.DataReceived):
                # Acknowledge the data so that the flow control window is
                # replenished, and flush any resulting frames.
                amount = event.flow_controlled_length
                self._h2_state.acknowledge_received_data(amount, stream_id)
                await self._write_outgoing_data(request)
                yield event.data
            elif isinstance(event, h2.events.StreamEnded):
                break

    async def _receive_stream_event(
        self, request: Request, stream_id: int
    ) -> typing.Union[
        h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
    ]:
        """
        Return the next available event for a given stream ID.

        Will read more data from the network if required.
        Raises `RemoteProtocolError` if the next event is a stream reset.
        """
        while not self._events.get(stream_id):
            await self._receive_events(request, stream_id)
        event = self._events[stream_id].pop(0)
        if isinstance(event, h2.events.StreamReset):
            raise RemoteProtocolError(event)
        return event

    async def _receive_events(
        self, request: Request, stream_id: typing.Optional[int] = None
    ) -> None:
        """
        Read some data from the network until we see one or more events
        for a given stream ID.
        """
        async with self._read_lock:
            if self._connection_terminated is not None:
                last_stream_id = self._connection_terminated.last_stream_id
                if stream_id and last_stream_id and stream_id > last_stream_id:
                    self._request_count -= 1
                    raise ConnectionNotAvailable()
                raise RemoteProtocolError(self._connection_terminated)

            # This conditional is a bit icky. We don't want to block reading if we've
            # actually got an event to return for a given stream. We need to do that
            # check *within* the atomic read lock. Though it also needs to be optional,
            # because when we call it from `_wait_for_outgoing_flow` we *do* want to
            # block until we've available flow control, even when we have events
            # pending for the stream ID we're attempting to send on.
            if stream_id is None or not self._events.get(stream_id):
                events = await self._read_incoming_data(request)
                for event in events:
                    if isinstance(event, h2.events.RemoteSettingsChanged):
                        async with Trace(
                            "receive_remote_settings", logger, request
                        ) as trace:
                            await self._receive_remote_settings_change(event)
                            trace.return_value = event

                    elif isinstance(
                        event,
                        (
                            h2.events.ResponseReceived,
                            h2.events.DataReceived,
                            h2.events.StreamEnded,
                            h2.events.StreamReset,
                        ),
                    ):
                        # Events for streams we are no longer tracking
                        # are silently dropped.
                        if event.stream_id in self._events:
                            self._events[event.stream_id].append(event)

                    elif isinstance(event, h2.events.ConnectionTerminated):
                        self._connection_terminated = event

        await self._write_outgoing_data(request)

    async def _receive_remote_settings_change(
        self, event: h2.events.RemoteSettingsChanged
    ) -> None:
        """
        Adjust the number of available stream slots when the remote peer
        changes its MAX_CONCURRENT_STREAMS setting.
        """
        max_concurrent_streams = event.changed_settings.get(
            h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
        )
        if max_concurrent_streams:
            # Never exceed the limit we advertised in our own settings.
            new_max_streams = min(
                max_concurrent_streams.new_value,
                self._h2_state.local_settings.max_concurrent_streams,
            )
            if new_max_streams and new_max_streams != self._max_streams:
                while new_max_streams > self._max_streams:
                    await self._max_streams_semaphore.release()
                    self._max_streams += 1
                while new_max_streams < self._max_streams:
                    await self._max_streams_semaphore.acquire()
                    self._max_streams -= 1

    async def _response_closed(self, stream_id: int) -> None:
        """
        Release the stream slot for `stream_id`, forget its events, and
        update the connection state once no streams remain.
        """
        await self._max_streams_semaphore.release()
        del self._events[stream_id]
        async with self._state_lock:
            if self._connection_terminated and not self._events:
                await self.aclose()

            elif self._state == HTTPConnectionState.ACTIVE and not self._events:
                self._state = HTTPConnectionState.IDLE
                if self._keepalive_expiry is not None:
                    now = time.monotonic()
                    self._expire_at = now + self._keepalive_expiry
                if self._used_all_stream_ids:  # pragma: nocover
                    await self.aclose()

    async def aclose(self) -> None:
        """
        Close the connection and the underlying network stream.
        """
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._h2_state.close_connection()
        self._state = HTTPConnectionState.CLOSED
        await self._network_stream.aclose()

    # Wrappers around network read/write operations...

    async def _read_incoming_data(
        self, request: Request
    ) -> typing.List[h2.events.Event]:
        """
        Read from the network stream and return the resulting `h2` events.

        Uses the "read" entry of the request's "timeout" extension, if any.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        if self._read_exception is not None:
            raise self._read_exception  # pragma: nocover

        try:
            data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
            if data == b"":
                raise RemoteProtocolError("Server disconnected")
        except Exception as exc:
            # If we get a network error we should:
            #
            # 1. Save the exception and just raise it immediately on any future reads.
            #    (For example, this means that a single read timeout or disconnect will
            #    immediately close all pending streams. Without requiring multiple
            #    sequential timeouts.)
            # 2. Mark the connection as errored, so that we don't accept any other
            #    incoming requests.
            self._read_exception = exc
            self._connection_error = True
            raise exc

        events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)

        return events

    async def _write_outgoing_data(self, request: Request) -> None:
        """
        Write any data pending in the `h2` state machine to the network stream.

        Uses the "write" entry of the request's "timeout" extension, if any.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        async with self._write_lock:
            data_to_send = self._h2_state.data_to_send()

            if self._write_exception is not None:
                raise self._write_exception  # pragma: nocover

            try:
                await self._network_stream.write(data_to_send, timeout)
            except Exception as exc:  # pragma: nocover
                # If we get a network error we should:
                #
                # 1. Save the exception and just raise it immediately on any future write.
                #    (For example, this means that a single write timeout or disconnect will
                #    immediately close all pending streams. Without requiring multiple
                #    sequential timeouts.)
                # 2. Mark the connection as errored, so that we don't accept any other
                #    incoming requests.
                self._write_exception = exc
                self._connection_error = True
                raise exc

    # Flow control...

    async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
        """
        Returns the maximum allowable outgoing flow for a given stream.

        If the allowable flow is zero, then waits on the network until
        WindowUpdated frames have increased the flow rate.
        https://tools.ietf.org/html/rfc7540#section-6.9
        """
        local_flow: int = self._h2_state.local_flow_control_window(stream_id)
        max_frame_size: int = self._h2_state.max_outbound_frame_size
        flow = min(local_flow, max_frame_size)
        while flow == 0:
            await self._receive_events(request)
            local_flow = self._h2_state.local_flow_control_window(stream_id)
            max_frame_size = self._h2_state.max_outbound_frame_size
            flow = min(local_flow, max_frame_size)
        return flow

    # Interface for connection pooling...

    def can_handle_request(self, origin: Origin) -> bool:
        """
        Return True if `origin` is the origin this connection was created for.
        """
        return origin == self._origin

    def is_available(self) -> bool:
        """
        Return True if the connection can accept new requests.
        """
        return (
            self._state != HTTPConnectionState.CLOSED
            and not self._connection_error
            and not self._used_all_stream_ids
            and not (
                self._h2_state.state_machine.state
                == h2.connection.ConnectionState.CLOSED
            )
        )

    def has_expired(self) -> bool:
        """
        Return True if a keep-alive expiry time is set and has passed.
        """
        now = time.monotonic()
        return self._expire_at is not None and now > self._expire_at

    def is_idle(self) -> bool:
        """
        Return True if the connection is in the IDLE state.
        """
        return self._state == HTTPConnectionState.IDLE

    def is_closed(self) -> bool:
        """
        Return True if the connection is in the CLOSED state.
        """
        return self._state == HTTPConnectionState.CLOSED

    def info(self) -> str:
        """
        Return a short human-readable description of the connection.
        """
        origin = str(self._origin)
        return (
            f"{origin!r}, HTTP/2, {self._state.name}, "
            f"Request Count: {self._request_count}"
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        origin = str(self._origin)
        return (
            f"<{class_name} [{origin!r}, {self._state.name}, "
            f"Request Count: {self._request_count}]>"
        )

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    async def __aenter__(self) -> "AsyncHTTP2Connection":
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[types.TracebackType] = None,
    ) -> None:
        await self.aclose()


class HTTP2ConnectionByteStream:
    """
    Async iterable over the response body of a single HTTP/2 stream.

    Closing it notifies the owning connection that the stream is finished.
    """

    def __init__(
        self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
    ) -> None:
        self._connection = connection
        self._request = request
        self._stream_id = stream_id
        self._closed = False

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        kwargs = {"request": self._request, "stream_id": self._stream_id}
        try:
            async with Trace("receive_response_body", logger, self._request, kwargs):
                async for chunk in self._connection._receive_response_body(
                    request=self._request, stream_id=self._stream_id
                ):
                    yield chunk
        except BaseException as exc:
            # If we get an exception while streaming the response,
            # we want to close the response (and possibly the connection)
            # before raising that exception.
            with AsyncShieldCancellation():
                await self.aclose()
            raise exc

    async def aclose(self) -> None:
        """
        Close the response stream. Only the first call has any effect.
        """
        if not self._closed:
            self._closed = True
            kwargs = {"stream_id": self._stream_id}
            async with Trace("response_closed", logger, self._request, kwargs):
                await self._connection._response_closed(stream_id=self._stream_id)