Added Hyper-Internal-Service

This commit is contained in:
netkas 2020-12-25 15:14:57 -05:00
parent f7a5d5cda7
commit 7da75d9366
93 changed files with 38913 additions and 1 deletions

View File

@ -16,9 +16,13 @@ clean_nlpfr:
# Remove the build, dist and egg-info output of the `alg` package.
clean_alg:
	rm -rf alg/build alg/dist alg/alg.egg-info
# Remove the build, dist and egg-info output of the `hyper_internal_service` package.
clean_his:
	rm -rf hyper_internal_service/build hyper_internal_service/dist hyper_internal_service/hyper_internal_service.egg-info
# Remove the build output of every sub-package.
# $(MAKE) replaces the literal `make` so that command-line flags (-n, -k) and
# the -j jobserver are handed down to the sub-invocations.
clean:
	$(MAKE) clean_apt clean_stopwords clean_tokenizer clean_nlpfr
	$(MAKE) clean_dltc
	$(MAKE) clean_his
	$(MAKE) clean_alg
# ======================================================================================================================
@ -43,10 +47,14 @@ build_dltc:
cd dltc; python3 setup.py build; python3 setup.py sdist
# Build the `alg` package and create its source distribution.
# Only the `sdist` invocation is kept: `dist` is not a setuptools command.
# Steps are joined with && so a failing `cd` or build fails the target.
build_alg:
	cd alg && python3 setup.py build && python3 setup.py sdist
# Build the `hyper_internal_service` package and create its source distribution.
# Steps are joined with && so a failing `cd` or build fails the target instead
# of being masked by the exit status of the last command.
build_his:
	cd hyper_internal_service && python3 setup.py build && python3 setup.py sdist
# Build every sub-package, one sub-make per package.
# $(MAKE) replaces the literal `make` so that command-line flags and the
# jobserver are handed down to the sub-invocations.
build:
	$(MAKE) build_nlpfr
	$(MAKE) build_his
	$(MAKE) build_dltc
	$(MAKE) build_alg
@ -74,7 +82,22 @@ install_dltc:
# Install the `alg` package.
# `&&` makes sure setup.py is never run from the wrong directory if `cd` fails.
install_alg:
	cd alg && python3 setup.py install
# Install `hyper_internal_service` by delegating to the `install` target of its own Makefile.
# `$(MAKE) -C` replaces `cd ...; make` so flags and the jobserver propagate and
# a missing directory is reported as an error.
install_his:
	$(MAKE) -C hyper_internal_service install
# Install every sub-package, one sub-make per package.
# $(MAKE) replaces the literal `make` so that command-line flags and the
# jobserver are handed down to the sub-invocations.
install:
	$(MAKE) install_nlpfr
	$(MAKE) install_his
	$(MAKE) install_dltc
	$(MAKE) install_alg
# ======================================================================================================================
# Install Python 3 and bootstrap pip through the get-pip.py script.
# `-O get-pip.py` forces a fixed output name: without it a leftover file from a
# failed earlier run makes wget save `get-pip.py.1`, and the stale script would
# be executed instead of the fresh download.
system_prep_pip:
	apt -y install python3 python3-distutils wget curl
	wget -O get-pip.py https://bootstrap.pypa.io/get-pip.py
	python3 get-pip.py
	rm get-pip.py
# Install the gcc package through apt.
system_prep_gcc:
	apt -y install gcc

View File

@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <https://unlicense.org>

View File

@ -0,0 +1,13 @@
# Files added to the source distribution.
include README.md
include Makefile
graft hyper_internal_service
recursive-include vendor *
# NOTE(review): global-include takes file-name patterns, so
# `hyper_internal_service` is matched as a pattern here, next to `*.pyi` -
# confirm that is the intent.
global-include hyper_internal_service *.pyi
# Compiled and binary artefacts are left out of the source distribution.
global-exclude *.pyc
global-exclude *.pyd
global-exclude *.so
global-exclude *.lib
global-exclude *.dll
global-exclude *.a
global-exclude *.obj
# HTML files located directly in the package directory are left out as well.
exclude hyper_internal_service/*.html

View File

@ -0,0 +1,124 @@
# Some simple testing tasks (sorry, UNIX only).
# Cython sources of the package; `:=` expands the wildcard once at parse time
# instead of on every reference.
PYXS := $(wildcard hyper_internal_service/*.pyx)
# Paths handed to the lint tools (isort, flake8).
SRC := hyper_internal_service examples tests setup.py

# Default goal: run the test suite.
all: test
# Stamp file: install the Cython version pinned in cython_requirements.txt.
# The requirements file is a prerequisite so that changing the pin refreshes
# the stamp; `$<` and `$@` avoid repeating the file names.
.install-cython: cython_requirements.txt
	pip install -r $<
	touch $@
# Pattern rule: generate a C source ($@) from the matching .pyx file ($<) with
# the cython command-line tool, passing the package directory through -I.
hyper_internal_service/%.c: hyper_internal_service/%.pyx
	cython -3 -o $@ $< -I hyper_internal_service
# Aggregate target: requires the Cython stamp and one generated .c file per
# .pyx source listed in PYXS. It has no recipe and creates no file of its own.
cythonize: .install-cython $(PYXS:.pyx=.c)
# Stamp file: install the development requirements.
# - dev_requirements.txt and every file under requirements/ are prerequisites,
#   so editing any of them refreshes the stamp.
# - `cythonize` never produces a file of that name; as a normal prerequisite it
#   would make this stamp permanently out of date and re-run pip on every
#   invocation. As an order-only prerequisite (after `|`) it is still built
#   first but no longer forces a reinstall.
.install-deps: dev_requirements.txt $(shell find requirements -type f) | cythonize
	pip install -r dev_requirements.txt
	@touch $@
# Run isort recursively (-rc) over the paths listed in SRC.
isort:
	isort -rc $(SRC)
# `make flake` is the user-facing name; the work is done by the .flake stamp rule.
flake: .flake

# Stamp file: run flake8, the isort check and the CONTRIBUTORS.txt order check.
# Re-runs whenever any file under hyper_internal_service/ is newer than the stamp.
# An isort failure prints the diff and fails the rule (`false`).
# NOTE(review): an unsorted CONTRIBUTORS.txt only prints a message and does not
# fail the rule (there is no `false` in that branch) - confirm this is intended.
.flake: .install-deps $(shell find hyper_internal_service -type f)
	flake8 hyper_internal_service examples tests
	@if ! isort -c -rc hyper_internal_service tests examples; then \
		echo "Import sort errors, run 'make isort' to fix them!!!"; \
		isort --diff -rc hyper_internal_service tests examples; \
		false; \
	fi
	@if ! LC_ALL=C sort -c CONTRIBUTORS.txt; then \
		echo "CONTRIBUTORS.txt sort error"; \
	fi
	@touch .flake
# Run flake8 over the paths listed in SRC (no stamp file, always runs).
flake8:
	flake8 $(SRC)
# Type-check the package with mypy once the .flake stamp is up to date.
mypy: .flake
	mypy hyper_internal_service
# Check import ordering without modifying files; on failure print the diff that
# `make isort` would apply and fail the target.
isort-check:
	@if ! isort -rc --check-only $(SRC); then \
		echo "Import sort errors, run 'make isort' to fix them!!!"; \
		isort --diff -rc $(SRC); \
		false; \
	fi
# Run the repository's tools/check_changes.py script.
check_changes:
	./tools/check_changes.py
# Stamp file: dependencies installed, lint passed, change notes checked and
# mypy run.
# NOTE(review): `check_changes` and `mypy` never create files of those names,
# so this stamp is presumably rebuilt on every invocation - confirm.
.develop: .install-deps $(shell find hyper_internal_service -type f) .flake check_changes mypy
	# pip install -e .
	@touch .develop
# Run the test suite quietly (-q) once the .develop stamp is up to date.
test: .develop
	@pytest -q

# Same as `test`, but verbose (-v) and with output capturing disabled (-s).
vtest: .develop
	@pytest -s -v
# Three names for the same action: run tox.
cov cover coverage:
	tox

# Run pytest with an HTML coverage report and print the report location.
cov-dev: .develop
	@pytest --cov-report=html
	@echo "open file://`pwd`/htmlcov/index.html"

# Coverage run used by cov-dev-full; prints a banner, then runs pytest with an
# HTML coverage report.
cov-ci-run: .develop
	@echo "Regular run"
	@pytest --cov-report=html

# Run cov-ci-run and then print the report location.
cov-dev-full: cov-ci-run
	@echo "open file://`pwd`/htmlcov/index.html"
# Remove caches, editor/patch leftovers, coverage output, generated C sources,
# compiled extensions and the stamp files created by this Makefile.
#
# `find ... -exec rm ... {} +` replaces the former "rm `find ...`" form, whose
# unquoted command substitution split paths on whitespace. `-prune` stops find
# from descending into a __pycache__ directory that is about to be removed.
clean:
	@find . -name __pycache__ -prune -exec rm -rf {} +
	@find . -type f \( -name '*.py[co]' -o -name '*~' -o -name '.*~' \
		-o -name '@*' -o -name '#*#' -o -name '*.orig' -o -name '*.rej' \) \
		-exec rm -f {} +
	@rm -f .coverage
	@rm -rf htmlcov
	@rm -rf build
	@rm -rf cover
	@python setup.py clean
	@# Per extension module: Cython HTML report, generated C file, built binaries.
	@for mod in _frozenlist _http_parser _multidict _websocket _parser; do \
		rm -f hyper_internal_service/$$mod.html \
		      hyper_internal_service/$$mod.c \
		      hyper_internal_service/$$mod.*.so \
		      hyper_internal_service/$$mod.*.pyd; \
	done
	@rm -rf .tox
	@rm -f .develop
	@rm -f .flake
	@rm -f .install-deps
	@rm -rf hyper_internal_service.egg-info
# Upgrade pip itself, then install or upgrade everything listed in
# dev_requirements.txt.
install:
	@pip install -U 'pip'
	@pip install -Ur dev_requirements.txt
# Every command-style target is declared phony so that a file or directory of
# the same name (e.g. `install`, `cover`) can never make it look up to date.
# The names from the original list are kept unchanged.
.PHONY: all build flake test vtest cov clean doc mypy \
	cythonize isort flake8 isort-check check_changes \
	cover coverage cov-dev cov-ci-run cov-dev-full install

View File

@ -0,0 +1,80 @@
# Hyper Internal Service
Hyper Internal Service is an internal async HTTP client/server
that allows different internal components to communicate with
each other using various interchangeable data formats.
## Installation
```shell script
sudo -H make install
```
or
```shell script
python3 -m pip install -Ur dev_requirements.txt
python3 setup.py install
```
## Example Server
```python
from hyper_internal_service import web


async def handle(request):
    name = request.match_info.get("name", "Anonymous")
    text = "Hello, " + name
    return web.Response(text=text)


async def wshandle(request):
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    async for msg in ws:
        if msg.type == web.WSMsgType.TEXT:
            await ws.send_str("Hello, {}".format(msg.data))
        elif msg.type == web.WSMsgType.BINARY:
            await ws.send_bytes(msg.data)
        elif msg.type == web.WSMsgType.CLOSE:
            break

    return ws


app = web.Application()
app.add_routes([web.get("/", handle),
                web.get("/echo", wshandle),
                web.get("/{name}", handle)])

web.run_app(app)
```
## Example Client
```python
import asyncio

import hyper_internal_service


async def fetch(session):
    print('Query http://httpbin.org/get')
    async with session.get(
            'http://httpbin.org/get') as resp:
        print(resp.status)
        data = await resp.json()
        print(data)


async def go():
    async with hyper_internal_service.ClientSession() as session:
        await fetch(session)


loop = asyncio.get_event_loop()
loop.run_until_complete(go())
loop.close()
```

View File

@ -0,0 +1 @@
cython==0.29.18

View File

@ -0,0 +1,3 @@
-r requirements/ci.txt
-r requirements/towncrier.txt
cherry_picker==1.3.2; python_version>="3.6"

View File

@ -0,0 +1,227 @@
# Version of this package.
__version__ = '1.0.0'
# NOTE(review): presumably the version of the upstream code base this package
# was derived from - confirm and document the meaning of __base__.
__base__ = "3.6.2"
from typing import Tuple # noqa
from . import hdrs as hdrs
from .client import BaseConnector as BaseConnector
from .client import ClientConnectionError as ClientConnectionError
from .client import (
ClientConnectorCertificateError as ClientConnectorCertificateError,
)
from .client import ClientConnectorError as ClientConnectorError
from .client import ClientConnectorSSLError as ClientConnectorSSLError
from .client import ClientError as ClientError
from .client import ClientHttpProxyError as ClientHttpProxyError
from .client import ClientOSError as ClientOSError
from .client import ClientPayloadError as ClientPayloadError
from .client import ClientProxyConnectionError as ClientProxyConnectionError
from .client import ClientRequest as ClientRequest
from .client import ClientResponse as ClientResponse
from .client import ClientResponseError as ClientResponseError
from .client import ClientSession as ClientSession
from .client import ClientSSLError as ClientSSLError
from .client import ClientTimeout as ClientTimeout
from .client import ClientWebSocketResponse as ClientWebSocketResponse
from .client import ContentTypeError as ContentTypeError
from .client import Fingerprint as Fingerprint
from .client import InvalidURL as InvalidURL
from .client import NamedPipeConnector as NamedPipeConnector
from .client import RequestInfo as RequestInfo
from .client import ServerConnectionError as ServerConnectionError
from .client import ServerDisconnectedError as ServerDisconnectedError
from .client import ServerFingerprintMismatch as ServerFingerprintMismatch
from .client import ServerTimeoutError as ServerTimeoutError
from .client import TCPConnector as TCPConnector
from .client import TooManyRedirects as TooManyRedirects
from .client import UnixConnector as UnixConnector
from .client import WSServerHandshakeError as WSServerHandshakeError
from .client import request as request
from .cookiejar import CookieJar as CookieJar
from .cookiejar import DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth as BasicAuth
from .helpers import ChainMapProxy as ChainMapProxy
from .http import HttpVersion as HttpVersion
from .http import HttpVersion10 as HttpVersion10
from .http import HttpVersion11 as HttpVersion11
from .http import WebSocketError as WebSocketError
from .http import WSCloseCode as WSCloseCode
from .http import WSMessage as WSMessage
from .http import WSMsgType as WSMsgType
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
)
from .multipart import BadContentDispositionParam as BadContentDispositionParam
from .multipart import BodyPartReader as BodyPartReader
from .multipart import MultipartReader as MultipartReader
from .multipart import MultipartWriter as MultipartWriter
from .multipart import (
content_disposition_filename as content_disposition_filename,
)
from .multipart import parse_content_disposition as parse_content_disposition
from .payload import PAYLOAD_REGISTRY as PAYLOAD_REGISTRY
from .payload import AsyncIterablePayload as AsyncIterablePayload
from .payload import BufferedReaderPayload as BufferedReaderPayload
from .payload import BytesIOPayload as BytesIOPayload
from .payload import BytesPayload as BytesPayload
from .payload import IOBasePayload as IOBasePayload
from .payload import JsonPayload as JsonPayload
from .payload import Payload as Payload
from .payload import StringIOPayload as StringIOPayload
from .payload import StringPayload as StringPayload
from .payload import TextIOPayload as TextIOPayload
from .payload import get_payload as get_payload
from .payload import payload_type as payload_type
from .payload_streamer import streamer as streamer
from .resolver import AsyncResolver as AsyncResolver
from .resolver import DefaultResolver as DefaultResolver
from .resolver import ThreadedResolver as ThreadedResolver
from .signals import Signal as Signal
from .streams import EMPTY_PAYLOAD as EMPTY_PAYLOAD
from .streams import DataQueue as DataQueue
from .streams import EofStream as EofStream
from .streams import FlowControlDataQueue as FlowControlDataQueue
from .streams import StreamReader as StreamReader
from .tracing import TraceConfig as TraceConfig
from .tracing import (
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
)
from .tracing import (
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
)
from .tracing import (
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
)
from .tracing import (
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
)
from .tracing import (
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
)
from .tracing import TraceDnsCacheHitParams as TraceDnsCacheHitParams
from .tracing import TraceDnsCacheMissParams as TraceDnsCacheMissParams
from .tracing import (
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
)
from .tracing import (
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
)
from .tracing import TraceRequestChunkSentParams as TraceRequestChunkSentParams
from .tracing import TraceRequestEndParams as TraceRequestEndParams
from .tracing import TraceRequestExceptionParams as TraceRequestExceptionParams
from .tracing import TraceRequestRedirectParams as TraceRequestRedirectParams
from .tracing import TraceRequestStartParams as TraceRequestStartParams
from .tracing import (
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
# Public names re-exported by the package, grouped by the submodule that
# provides them. Optional worker classes are appended after this tuple when
# they can be imported.
__all__ = (
    'hdrs',
    # client
    'BaseConnector',
    'ClientConnectionError',
    'ClientConnectorCertificateError',
    'ClientConnectorError',
    'ClientConnectorSSLError',
    'ClientError',
    'ClientHttpProxyError',
    'ClientOSError',
    'ClientPayloadError',
    'ClientProxyConnectionError',
    'ClientResponse',
    'ClientRequest',
    'ClientResponseError',
    'ClientSSLError',
    'ClientSession',
    'ClientTimeout',
    'ClientWebSocketResponse',
    'ContentTypeError',
    'Fingerprint',
    'InvalidURL',
    'RequestInfo',
    'ServerConnectionError',
    'ServerDisconnectedError',
    'ServerFingerprintMismatch',
    'ServerTimeoutError',
    'TCPConnector',
    'TooManyRedirects',
    'UnixConnector',
    'NamedPipeConnector',
    'WSServerHandshakeError',
    'request',
    # cookiejar
    'CookieJar',
    'DummyCookieJar',
    # formdata
    'FormData',
    # helpers
    'BasicAuth',
    'ChainMapProxy',
    # http
    'HttpVersion',
    'HttpVersion10',
    'HttpVersion11',
    'WSMsgType',
    'WSCloseCode',
    'WSMessage',
    'WebSocketError',
    # multipart
    'BadContentDispositionHeader',
    'BadContentDispositionParam',
    'BodyPartReader',
    'MultipartReader',
    'MultipartWriter',
    'content_disposition_filename',
    'parse_content_disposition',
    # payload
    'AsyncIterablePayload',
    'BufferedReaderPayload',
    'BytesIOPayload',
    'BytesPayload',
    'IOBasePayload',
    'JsonPayload',
    'PAYLOAD_REGISTRY',
    'Payload',
    'StringIOPayload',
    'StringPayload',
    'TextIOPayload',
    'get_payload',
    'payload_type',
    # payload_streamer
    'streamer',
    # resolver
    'AsyncResolver',
    'DefaultResolver',
    'ThreadedResolver',
    # signals
    'Signal',
    # streams
    'DataQueue',
    'EMPTY_PAYLOAD',
    'EofStream',
    'FlowControlDataQueue',
    'StreamReader',
    # tracing
    'TraceConfig',
    'TraceConnectionCreateEndParams',
    'TraceConnectionCreateStartParams',
    'TraceConnectionQueuedEndParams',
    'TraceConnectionQueuedStartParams',
    'TraceConnectionReuseconnParams',
    'TraceDnsCacheHitParams',
    'TraceDnsCacheMissParams',
    'TraceDnsResolveHostEndParams',
    'TraceDnsResolveHostStartParams',
    'TraceRequestChunkSentParams',
    'TraceRequestEndParams',
    'TraceRequestExceptionParams',
    'TraceRequestRedirectParams',
    'TraceRequestStartParams',
    'TraceResponseChunkReceivedParams',
)  # type: Tuple[str, ...]
# The gunicorn worker classes are optional: they are exported only when the
# `.worker` module (and whatever it imports) is importable.
try:
    from .worker import GunicornUVLoopWebWorker, GunicornWebWorker  # noqa
except ImportError:  # pragma: no cover
    pass
else:
    __all__ += ('GunicornWebWorker', 'GunicornUVLoopWebWorker')

View File

@ -0,0 +1,140 @@
from libc.stdint cimport uint16_t, uint32_t, uint64_t
# Cython declarations for the vendored C http-parser header. Only the names
# declared here are visible to the Cython code that cimports this file.
cdef extern from "../vendor/http-parser/http_parser.h":
    # Callback that receives a span of data (pointer + length).
    # `except -1`: a return value of -1 tells Cython a Python exception is set.
    ctypedef int (*http_data_cb) (http_parser*,
                                  const char *at,
                                  size_t length) except -1

    # Notification callback that carries no data span.
    ctypedef int (*http_cb) (http_parser*) except -1

    # Parser state; `data` is an opaque pointer slot for the user of the parser.
    struct http_parser:
        unsigned int type
        unsigned int flags
        unsigned int state
        unsigned int header_state
        unsigned int index

        uint32_t nread
        uint64_t content_length

        unsigned short http_major
        unsigned short http_minor
        unsigned int status_code
        unsigned int method
        unsigned int http_errno

        unsigned int upgrade

        void *data

    # Table of callbacks consulted by http_parser_execute().
    struct http_parser_settings:
        http_cb      on_message_begin
        http_data_cb on_url
        http_data_cb on_status
        http_data_cb on_header_field
        http_data_cb on_header_value
        http_cb      on_headers_complete
        http_data_cb on_body
        http_cb      on_message_complete
        http_cb      on_chunk_header
        http_cb      on_chunk_complete

    # Parsing mode passed to http_parser_init().
    enum http_parser_type:
        HTTP_REQUEST,
        HTTP_RESPONSE,
        HTTP_BOTH

    # Error codes stored in http_parser.http_errno.
    enum http_errno:
        HPE_OK,
        HPE_CB_message_begin,
        HPE_CB_url,
        HPE_CB_header_field,
        HPE_CB_header_value,
        HPE_CB_headers_complete,
        HPE_CB_body,
        HPE_CB_message_complete,
        HPE_CB_status,
        HPE_CB_chunk_header,
        HPE_CB_chunk_complete,
        HPE_INVALID_EOF_STATE,
        HPE_HEADER_OVERFLOW,
        HPE_CLOSED_CONNECTION,
        HPE_INVALID_VERSION,
        HPE_INVALID_STATUS,
        HPE_INVALID_METHOD,
        HPE_INVALID_URL,
        HPE_INVALID_HOST,
        HPE_INVALID_PORT,
        HPE_INVALID_PATH,
        HPE_INVALID_QUERY_STRING,
        HPE_INVALID_FRAGMENT,
        HPE_LF_EXPECTED,
        HPE_INVALID_HEADER_TOKEN,
        HPE_INVALID_CONTENT_LENGTH,
        HPE_INVALID_CHUNK_SIZE,
        HPE_INVALID_CONSTANT,
        HPE_INVALID_INTERNAL_STATE,
        HPE_STRICT,
        HPE_PAUSED,
        HPE_UNKNOWN

    # Flag names for http_parser.flags.
    enum flags:
        F_CHUNKED,
        F_CONNECTION_KEEP_ALIVE,
        F_CONNECTION_CLOSE,
        F_CONNECTION_UPGRADE,
        F_TRAILING,
        F_UPGRADE,
        F_SKIPBODY,
        F_CONTENTLENGTH

    # Request methods.
    # NOTE(review): the actual numeric values come from the C header, not from
    # the order written here - confirm the list matches the vendored version.
    enum http_method:
        DELETE, GET, HEAD, POST, PUT, CONNECT, OPTIONS, TRACE, COPY,
        LOCK, MKCOL, MOVE, PROPFIND, PROPPATCH, SEARCH, UNLOCK, BIND,
        REBIND, UNBIND, ACL, REPORT, MKACTIVITY, CHECKOUT, MERGE,
        MSEARCH, NOTIFY, SUBSCRIBE, UNSUBSCRIBE, PATCH, PURGE, MKCALENDAR,
        LINK, UNLINK

    void http_parser_init(http_parser *parser, http_parser_type type)

    size_t http_parser_execute(http_parser *parser,
                               const http_parser_settings *settings,
                               const char *data,
                               size_t len)

    int http_should_keep_alive(const http_parser *parser)

    void http_parser_settings_init(http_parser_settings *settings)

    const char *http_errno_name(http_errno err)
    const char *http_errno_description(http_errno err)
    const char *http_method_str(http_method m)

    # URL Parser
    enum http_parser_url_fields:
        UF_SCHEMA   = 0,
        UF_HOST     = 1,
        UF_PORT     = 2,
        UF_PATH     = 3,
        UF_QUERY    = 4,
        UF_FRAGMENT = 5,
        UF_USERINFO = 6,
        UF_MAX      = 7

    # Offset and length of one URL component.
    struct http_parser_url_field_data:
        uint16_t off
        uint16_t len

    # Result of http_parser_parse_url(): one field_data slot per URL component.
    struct http_parser_url:
        uint16_t field_set
        uint16_t port
        http_parser_url_field_data[<int>UF_MAX] field_data

    void http_parser_url_init(http_parser_url *u)

    int http_parser_parse_url(const char *buf,
                              size_t buflen,
                              int is_connect,
                              http_parser_url *u)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,14 @@
/* Declaration of the header-name lookup routine; usable from C and C++. */
#ifndef _FIND_HEADERS_H
#define _FIND_HEADERS_H

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Look up the header name stored in `str` (`size` bytes).
 * NOTE(review): the Cython caller treats -1 as "not found" and any other
 * result as an index into its table of known headers - confirm against the
 * implementation, which is not part of this header.
 */
int find_header(const char *str, int size);

#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,2 @@
# Cython declaration of the C lookup routine from _find_header.h.
# The first parameter is declared `const char *` to match the C prototype;
# callers passing a plain `char *` remain valid.
cdef extern from "_find_header.h":
    int find_header(const char *, int)

View File

@ -0,0 +1,108 @@
from collections.abc import MutableSequence
cdef class FrozenList:
    """List wrapper that can be switched to read-only with ``freeze()``.

    Every mutating method raises ``RuntimeError`` once ``frozen`` is true;
    read access keeps working.
    """

    cdef readonly bint frozen
    cdef list _items

    def __init__(self, items=None):
        self.frozen = False
        # Always work on a private copy of the caller's iterable.
        self._items = [] if items is None else list(items)

    cdef object _check_frozen(self):
        # Guard shared by all mutating methods.
        if self.frozen:
            raise RuntimeError("Cannot modify frozen list.")

    cdef inline object _fast_len(self):
        return len(self._items)

    def freeze(self):
        """Make the list read-only."""
        self.frozen = True

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._check_frozen()
        self._items[index] = value

    def __delitem__(self, index):
        self._check_frozen()
        del self._items[index]

    def __len__(self):
        return self._fast_len()

    def __iter__(self):
        return iter(self._items)

    def __reversed__(self):
        return reversed(self._items)

    def __richcmp__(self, other, op):
        # Compare a plain-list copy of the content against `other`.
        if op == 0:  # <
            return list(self) < other
        elif op == 1:  # <=
            return list(self) <= other
        elif op == 2:  # ==
            return list(self) == other
        elif op == 3:  # !=
            return list(self) != other
        elif op == 4:  # >
            return list(self) > other
        elif op == 5:  # >=
            return list(self) >= other

    def insert(self, pos, item):
        self._check_frozen()
        self._items.insert(pos, item)

    def __contains__(self, item):
        return item in self._items

    def __iadd__(self, items):
        self._check_frozen()
        self._items.extend(list(items))
        return self

    def index(self, item):
        return self._items.index(item)

    def remove(self, item):
        self._check_frozen()
        self._items.remove(item)

    def clear(self):
        self._check_frozen()
        self._items.clear()

    def extend(self, items):
        self._check_frozen()
        self._items.extend(list(items))

    def reverse(self):
        self._check_frozen()
        self._items.reverse()

    def pop(self, index=-1):
        self._check_frozen()
        return self._items.pop(index)

    def append(self, item):
        self._check_frozen()
        return self._items.append(item)

    def count(self, item):
        return self._items.count(item)

    def __repr__(self):
        template = '<FrozenList(frozen={}, {!r})>'
        return template.format(self.frozen, self._items)
MutableSequence.register(FrozenList)

View File

@ -0,0 +1,85 @@
# The file is autogenerated from hyper_internal_service/hdrs.py
# Run ./tools/gen.py to update it after the origin changing.
from . import hdrs
cdef tuple headers = (
hdrs.ACCEPT,
hdrs.ACCEPT_CHARSET,
hdrs.ACCEPT_ENCODING,
hdrs.ACCEPT_LANGUAGE,
hdrs.ACCEPT_RANGES,
hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
hdrs.ACCESS_CONTROL_ALLOW_METHODS,
hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
hdrs.ACCESS_CONTROL_MAX_AGE,
hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
hdrs.ACCESS_CONTROL_REQUEST_METHOD,
hdrs.AGE,
hdrs.ALLOW,
hdrs.AUTHORIZATION,
hdrs.CACHE_CONTROL,
hdrs.CONNECTION,
hdrs.CONTENT_DISPOSITION,
hdrs.CONTENT_ENCODING,
hdrs.CONTENT_LANGUAGE,
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_MD5,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TRANSFER_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.COOKIE,
hdrs.DATE,
hdrs.DESTINATION,
hdrs.DIGEST,
hdrs.ETAG,
hdrs.EXPECT,
hdrs.EXPIRES,
hdrs.FORWARDED,
hdrs.FROM,
hdrs.HOST,
hdrs.IF_MATCH,
hdrs.IF_MODIFIED_SINCE,
hdrs.IF_NONE_MATCH,
hdrs.IF_RANGE,
hdrs.IF_UNMODIFIED_SINCE,
hdrs.KEEP_ALIVE,
hdrs.LAST_EVENT_ID,
hdrs.LAST_MODIFIED,
hdrs.LINK,
hdrs.LOCATION,
hdrs.MAX_FORWARDS,
hdrs.ORIGIN,
hdrs.PRAGMA,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.RANGE,
hdrs.REFERER,
hdrs.RETRY_AFTER,
hdrs.SEC_WEBSOCKET_ACCEPT,
hdrs.SEC_WEBSOCKET_EXTENSIONS,
hdrs.SEC_WEBSOCKET_KEY,
hdrs.SEC_WEBSOCKET_KEY1,
hdrs.SEC_WEBSOCKET_PROTOCOL,
hdrs.SEC_WEBSOCKET_VERSION,
hdrs.SERVER,
hdrs.SET_COOKIE,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.UPGRADE,
hdrs.URI,
hdrs.USER_AGENT,
hdrs.VARY,
hdrs.VIA,
hdrs.WANT_DIGEST,
hdrs.WARNING,
hdrs.WEBSOCKET,
hdrs.WWW_AUTHENTICATE,
hdrs.X_POWERED_BY,
hdrs.X_FORWARDED_FOR,
hdrs.X_FORWARDED_HOST,
hdrs.X_FORWARDED_PROTO,
)

View File

@ -0,0 +1,8 @@
from typing import Any
# Type stub for the `reify` descriptor implemented in the extension module.
class reify:
    # `wrapped` is the method being decorated.
    def __init__(self, wrapped: Any) -> None: ...
    # Attribute read on an instance (`inst`) or on the owning class.
    def __get__(self, inst: Any, owner: Any) -> Any: ...
    # Attribute assignment on an instance.
    def __set__(self, inst: Any, value: Any) -> None: ...

View File

@ -0,0 +1,35 @@
cdef class reify:
    """Caching replacement for ``@property``.

    The decorated method runs on first access; its result is stored in the
    instance's ``_cache`` dict under the method name and returned from there on
    later accesses. Assigning to the attribute raises ``AttributeError``.
    """

    cdef object wrapped
    cdef object name

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.name = wrapped.__name__

    @property
    def __doc__(self):
        return self.wrapped.__doc__

    def __get__(self, inst, owner):
        # Access through the class: hand back the descriptor itself.
        if inst is None:
            return self
        try:
            return inst._cache[self.name]
        except KeyError:
            # First access: compute the value, remember it, return it.
            result = self.wrapped(inst)
            inst._cache[self.name] = result
            return result

    def __set__(self, inst, value):
        raise AttributeError("reified property is read-only")

View File

@ -0,0 +1,858 @@
#cython: language_level=3
#
# Based on https://github.com/MagicStack/httptools
#
from __future__ import absolute_import, print_function
from cpython.mem cimport PyMem_Malloc, PyMem_Free
from libc.string cimport memcpy
from libc.limits cimport ULLONG_MAX
from cpython cimport (PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE,
Py_buffer, PyBytes_AsString, PyBytes_AsStringAndSize)
from multidict import (CIMultiDict as _CIMultiDict,
CIMultiDictProxy as _CIMultiDictProxy)
from yarl import URL as _URL
from hyper_internal_service import hdrs
from .http_exceptions import (
BadHttpMessage, BadStatusLine, InvalidHeader, LineTooLong, InvalidURLError,
PayloadEncodingError, ContentLengthError, TransferEncodingError)
from .http_writer import (HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
HttpVersion11 as _HttpVersion11)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .streams import (EMPTY_PAYLOAD as _EMPTY_PAYLOAD,
StreamReader as _StreamReader)
cimport cython
from hyper_internal_service cimport _cparser as cparser
include "_headers.pxi"
from hyper_internal_service cimport _find_header
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
int PyByteArray_Resize(object, Py_ssize_t) except -1
Py_ssize_t PyByteArray_Size(object) except -1
char* PyByteArray_AsString(object)
__all__ = ('HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
# Module-level `cdef object` aliases of the Python objects imported above.
# NOTE(review): presumably held as C-level globals to avoid repeated
# module-dict lookups in parser hot paths - confirm.
cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
cdef object StreamReader = _StreamReader
cdef object DeflateBuffer = _DeflateBuffer
cdef inline object extend(object buf, const char* at, size_t length):
    # Append `length` bytes starting at `at` to the bytearray `buf`:
    # grow the bytearray, then copy the new bytes behind the old content.
    cdef:
        Py_ssize_t old_size = PyByteArray_Size(buf)
        char* dest

    PyByteArray_Resize(buf, old_size + length)
    dest = PyByteArray_AsString(buf)
    memcpy(dest + old_size, at, length)
# Number of method names pre-computed below.
# NOTE(review): presumably the number of values in the C `http_method` enum of
# the vendored parser - confirm when the vendored version changes.
DEF METHODS_COUNT = 34;

# Method names indexed by enum value, decoded to `str` once at import time.
cdef list _http_method = []

for i in range(METHODS_COUNT):
    _http_method.append(
        cparser.http_method_str(<cparser.http_method> i).decode('ascii'))
cdef inline str http_method_str(int i):
    # Return the name of HTTP method number `i`, or "<unknown>" when `i` is
    # outside the pre-computed table.
    # The lower bound matters: a negative `i` would otherwise wrap around and
    # silently return a method name counted from the end of the list.
    if 0 <= i < METHODS_COUNT:
        return <str>_http_method[i]
    else:
        return "<unknown>"
cdef inline object find_header(bytes raw_header):
    # Map a raw header name to the matching entry of the `headers` tuple; names
    # the C lookup does not know are decoded to `str` and returned as-is.
    cdef:
        Py_ssize_t length
        char *data
        int pos

    PyBytes_AsStringAndSize(raw_header, &data, &length)
    pos = _find_header.find_header(data, length)
    if pos != -1:
        return headers[pos]
    return raw_header.decode('utf-8', 'surrogateescape')
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
    """Value object holding the parsed head of an HTTP request.

    All fields are read-only from Python; use ``_replace()`` to obtain a
    modified copy.
    """

    cdef readonly str method
    cdef readonly str path
    cdef readonly object version  # HttpVersion
    cdef readonly object headers  # CIMultiDict
    cdef readonly object raw_headers  # tuple
    cdef readonly object should_close
    cdef readonly object compression
    cdef readonly object upgrade
    cdef readonly object chunked
    cdef readonly object url  # yarl.URL

    def __init__(self, method, path, version, headers, raw_headers,
                 should_close, compression, upgrade, chunked, url):
        # Request line
        self.method = method
        self.path = path
        self.version = version
        # Headers, parsed and raw
        self.headers = headers
        self.raw_headers = raw_headers
        # Connection / body handling
        self.should_close = should_close
        self.compression = compression
        self.upgrade = upgrade
        self.chunked = chunked
        self.url = url

    def __repr__(self):
        fields = [
            ("method", self.method),
            ("path", self.path),
            ("version", self.version),
            ("headers", self.headers),
            ("raw_headers", self.raw_headers),
            ("should_close", self.should_close),
            ("compression", self.compression),
            ("upgrade", self.upgrade),
            ("chunked", self.chunked),
            ("url", self.url),
        ]
        rendered = [key + '=' + repr(item) for key, item in fields]
        return '<RawRequestMessage(' + ', '.join(rendered) + ')>'

    def _replace(self, **changes):
        """Return a copy with the fields named in ``changes`` overridden.

        Keys that are not field names are ignored.
        """
        cdef RawRequestMessage copy
        copy = _new_request_message(self.method,
                                    self.path,
                                    self.version,
                                    self.headers,
                                    self.raw_headers,
                                    self.should_close,
                                    self.compression,
                                    self.upgrade,
                                    self.chunked,
                                    self.url)
        # Attributes are C-level fields, so each one is assigned explicitly.
        if "method" in changes:
            copy.method = changes["method"]
        if "path" in changes:
            copy.path = changes["path"]
        if "version" in changes:
            copy.version = changes["version"]
        if "headers" in changes:
            copy.headers = changes["headers"]
        if "raw_headers" in changes:
            copy.raw_headers = changes["raw_headers"]
        if "should_close" in changes:
            copy.should_close = changes["should_close"]
        if "compression" in changes:
            copy.compression = changes["compression"]
        if "upgrade" in changes:
            copy.upgrade = changes["upgrade"]
        if "chunked" in changes:
            copy.chunked = changes["chunked"]
        if "url" in changes:
            copy.url = changes["url"]
        return copy
# C-level constructor for RawRequestMessage: allocates the object through
# __new__ (so __init__ is not called) and fills every field directly.
# `should_close`, `upgrade` and `chunked` are declared `bint`, so the values
# are coerced to C booleans before being stored.
cdef _new_request_message(str method,
                          str path,
                          object version,
                          object headers,
                          object raw_headers,
                          bint should_close,
                          object compression,
                          bint upgrade,
                          bint chunked,
                          object url):
    cdef RawRequestMessage ret
    ret = RawRequestMessage.__new__(RawRequestMessage)
    ret.method = method
    ret.path = path
    ret.version = version
    ret.headers = headers
    ret.raw_headers = raw_headers
    ret.should_close = should_close
    ret.compression = compression
    ret.upgrade = upgrade
    ret.chunked = chunked
    ret.url = url
    return ret
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawResponseMessage:
    """Value object holding the parsed head of an HTTP response.

    All fields are read-only from Python.
    """

    cdef readonly object version  # HttpVersion
    cdef readonly int code
    cdef readonly str reason
    cdef readonly object headers  # CIMultiDict
    cdef readonly object raw_headers  # tuple
    cdef readonly object should_close
    cdef readonly object compression
    cdef readonly object upgrade
    cdef readonly object chunked

    def __init__(self, version, code, reason, headers, raw_headers,
                 should_close, compression, upgrade, chunked):
        # Status line
        self.version = version
        self.code = code
        self.reason = reason
        # Headers, parsed and raw
        self.headers = headers
        self.raw_headers = raw_headers
        # Connection / body handling
        self.should_close = should_close
        self.compression = compression
        self.upgrade = upgrade
        self.chunked = chunked

    def __repr__(self):
        fields = [
            ("version", self.version),
            ("code", self.code),
            ("reason", self.reason),
            ("headers", self.headers),
            ("raw_headers", self.raw_headers),
            ("should_close", self.should_close),
            ("compression", self.compression),
            ("upgrade", self.upgrade),
            ("chunked", self.chunked),
        ]
        rendered = [key + '=' + repr(item) for key, item in fields]
        return '<RawResponseMessage(' + ', '.join(rendered) + ')>'
# C-level constructor for RawResponseMessage: allocates the object through
# __new__ (so __init__ is not called) and fills every field directly.
# `should_close`, `upgrade` and `chunked` are declared `bint`, so the values
# are coerced to C booleans before being stored.
cdef _new_response_message(object version,
                           int code,
                           str reason,
                           object headers,
                           object raw_headers,
                           bint should_close,
                           object compression,
                           bint upgrade,
                           bint chunked):
    cdef RawResponseMessage ret
    ret = RawResponseMessage.__new__(RawResponseMessage)
    ret.version = version
    ret.code = code
    ret.reason = reason
    ret.headers = headers
    ret.raw_headers = raw_headers
    ret.should_close = should_close
    ret.compression = compression
    ret.upgrade = upgrade
    ret.chunked = chunked
    return ret
@cython.internal
cdef class HttpParser:
cdef:
cparser.http_parser* _cparser
cparser.http_parser_settings* _csettings
bytearray _raw_name
bytearray _raw_value
bint _has_value
object _protocol
object _loop
object _timer
size_t _max_line_size
size_t _max_field_size
size_t _max_headers
bint _response_with_body
bint _read_until_eof
bint _started
object _url
bytearray _buf
str _path
str _reason
object _headers
list _raw_headers
bint _upgraded
list _messages
object _payload
bint _payload_error
object _payload_exception
object _last_error
bint _auto_decompress
str _content_encoding
Py_buffer py_buf
def __cinit__(self):
self._cparser = <cparser.http_parser*> \
PyMem_Malloc(sizeof(cparser.http_parser))
if self._cparser is NULL:
raise MemoryError()
self._csettings = <cparser.http_parser_settings*> \
PyMem_Malloc(sizeof(cparser.http_parser_settings))
if self._csettings is NULL:
raise MemoryError()
def __dealloc__(self):
    # Release the two structs obtained in __cinit__.
    PyMem_Free(self._cparser)
    PyMem_Free(self._csettings)
cdef _init(self, cparser.http_parser_type mode,
           object protocol, object loop, object timer=None,
           size_t max_line_size=8190, size_t max_headers=32768,
           size_t max_field_size=8190, payload_exception=None,
           bint response_with_body=True, bint read_until_eof=False,
           bint auto_decompress=True):
    # Initialise the C parser for `mode`, install the cb_on_* callbacks and
    # reset all Python-side state from the given options.

    # --- C parser state; `data` points back at this object so the
    # callbacks can recover it from the http_parser struct.
    cparser.http_parser_init(self._cparser, mode)
    self._cparser.data = <void*>self
    self._cparser.content_length = 0

    # --- callback table
    cparser.http_parser_settings_init(self._csettings)
    self._csettings.on_message_begin = cb_on_message_begin
    self._csettings.on_url = cb_on_url
    self._csettings.on_status = cb_on_status
    self._csettings.on_header_field = cb_on_header_field
    self._csettings.on_header_value = cb_on_header_value
    self._csettings.on_headers_complete = cb_on_headers_complete
    self._csettings.on_body = cb_on_body
    self._csettings.on_chunk_header = cb_on_chunk_header
    self._csettings.on_chunk_complete = cb_on_chunk_complete
    self._csettings.on_message_complete = cb_on_message_complete

    # --- collaborators
    self._protocol = protocol
    self._loop = loop
    self._timer = timer

    # --- options
    self._max_line_size = max_line_size
    self._max_headers = max_headers
    self._max_field_size = max_field_size
    self._payload_exception = payload_exception
    self._response_with_body = response_with_body
    self._read_until_eof = read_until_eof
    self._auto_decompress = auto_decompress

    # --- per-connection parsing state
    self._buf = bytearray()
    self._raw_name = bytearray()
    self._raw_value = bytearray()
    self._has_value = False
    self._messages = []
    self._payload = None
    self._payload_error = 0
    self._upgraded = False
    self._content_encoding = None
    self._last_error = None
cdef _process_header(self):
    # Commit the header accumulated in _raw_name/_raw_value (if any) to
    # both header collections, then reset the accumulators.
    if not self._raw_name:
        return

    name_bytes = bytes(self._raw_name)
    value_bytes = bytes(self._raw_value)

    header_name = find_header(name_bytes)
    header_value = value_bytes.decode('utf-8', 'surrogateescape')
    self._headers.add(header_name, header_value)

    # Remember Content-Encoding for _on_headers_complete.
    if header_name is CONTENT_ENCODING:
        self._content_encoding = header_value

    PyByteArray_Resize(self._raw_name, 0)
    PyByteArray_Resize(self._raw_value, 0)
    self._has_value = False
    self._raw_headers.append((name_bytes, value_bytes))
cdef _on_header_field(self, char* at, size_t length):
    # Append `length` bytes at `at` to the pending header name.
    cdef Py_ssize_t used
    cdef char *dest

    # Value bytes were seen since the last name chunk, so the previous
    # header is complete: commit it before starting a new name.
    if self._has_value:
        self._process_header()

    used = PyByteArray_Size(self._raw_name)
    PyByteArray_Resize(self._raw_name, used + length)
    dest = PyByteArray_AsString(self._raw_name)
    memcpy(dest + used, at, length)
cdef _on_header_value(self, char* at, size_t length):
    # Append `length` bytes at `at` to the pending header value and flag
    # that a value has been seen.
    cdef Py_ssize_t used
    cdef char *dest

    used = PyByteArray_Size(self._raw_value)
    PyByteArray_Resize(self._raw_value, used + length)
    dest = PyByteArray_AsString(self._raw_value)
    memcpy(dest + used, at, length)
    self._has_value = True
# Called when all headers of a message have been parsed: commits the last
# pending header, builds the request/response message object, selects the
# payload stream and queues (message, payload) on self._messages.
cdef _on_headers_complete(self):
# Commit the header still held in _raw_name/_raw_value.
self._process_header()
method = http_method_str(self._cparser.method)
should_close = not cparser.http_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(self._headers)
if upgrade or self._cparser.method == 5: # cparser.CONNECT:
self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
# Only gzip/deflate/br are reported as the message's compression; any other
# Content-Encoding value leaves `encoding` as None. The stored header value is
# consumed (reset to None) here.
encoding = None
enc = self._content_encoding
if enc is not None:
self._content_encoding = None
enc = enc.lower()
if enc in ('gzip', 'deflate', 'br'):
encoding = enc
# Build the message object matching the parser type.
if self._cparser.type == cparser.HTTP_REQUEST:
msg = _new_request_message(
method, self._path,
self.http_version(), headers, raw_headers,
should_close, encoding, upgrade, chunked, self._url)
else:
msg = _new_response_message(
self.http_version(), self._cparser.status_code, self._reason,
headers, raw_headers, should_close, encoding,
upgrade, chunked)
# A real payload stream is created when a positive content length is known,
# the body is chunked, the method is CONNECT, or the length is ULLONG_MAX
# and the parser was configured with read_until_eof; otherwise the shared
# EMPTY_PAYLOAD is used.
if (ULLONG_MAX > self._cparser.content_length > 0 or chunked or
self._cparser.method == 5 or # CONNECT: 5
(self._cparser.status_code >= 199 and
self._cparser.content_length == ULLONG_MAX and
self._read_until_eof)
):
payload = StreamReader(
self._protocol, timer=self._timer, loop=self._loop)
else:
payload = EMPTY_PAYLOAD
# self._payload is what the body callbacks feed; it may be wrapped in a
# DeflateBuffer, while `payload` is what is queued with the message.
self._payload = payload
if encoding is not None and self._auto_decompress:
self._payload = DeflateBuffer(payload, encoding)
if not self._response_with_body:
payload = EMPTY_PAYLOAD
self._messages.append((msg, payload))
cdef _on_message_complete(self):
    # End of message: signal EOF on the current payload and drop it.
    self._payload.feed_eof()
    self._payload = None

cdef _on_chunk_header(self):
    # Forward the start of an HTTP chunk to the current payload.
    self._payload.begin_http_chunk_receiving()

cdef _on_chunk_complete(self):
    # Forward the end of an HTTP chunk to the current payload.
    self._payload.end_http_chunk_receiving()

cdef object _on_status_complete(self):
    # No-op in the base class; HttpRequestParser and HttpResponseParser
    # override it.
    pass
cdef inline http_version(self):
    # Return the HTTP version seen by the C parser, reusing the shared
    # HttpVersion10 / HttpVersion11 objects for 1.0 and 1.1.
    cdef cparser.http_parser* state = self._cparser

    if state.http_major == 1 and state.http_minor == 0:
        return HttpVersion10
    if state.http_major == 1 and state.http_minor == 1:
        return HttpVersion11
    return HttpVersion(state.http_major, state.http_minor)
### Public API ###
# Signal end of input to the parser.
#
# With a payload in progress: raises TransferEncodingError for a chunked body,
# ContentLengthError when a Content-Length was announced, PayloadEncodingError
# when the C parser is in an error state, and otherwise feeds EOF to the
# payload. Without a payload but with a message started, the headers are
# completed with what has been received so far.
def feed_eof(self):
cdef bytes desc
if self._payload is not None:
if self._cparser.flags & cparser.F_CHUNKED:
raise TransferEncodingError(
"Not enough data for satisfy transfer length header.")
elif self._cparser.flags & cparser.F_CONTENTLENGTH:
raise ContentLengthError(
"Not enough data for satisfy content length header.")
elif self._cparser.http_errno != cparser.HPE_OK:
desc = cparser.http_errno_description(
<cparser.http_errno> self._cparser.http_errno)
raise PayloadEncodingError(desc.decode('latin-1'))
else:
self._payload.feed_eof()
elif self._started:
self._on_headers_complete()
# Return the message object of the most recently queued (message, payload)
# pair, if there is one.
if self._messages:
return self._messages[-1][0]
# Run the C parser over `data` (any object exporting a simple buffer).
#
# Returns a 3-tuple (messages, upgraded, tail): the (message, payload) pairs
# completed so far, whether the connection is upgraded, and - only when
# upgraded - the bytes of `data` after the first `nb` parsed ones.
# Raises the exception captured by a callback, or one built from the parser's
# errno, when parsing failed.
def feed_data(self, data):
cdef:
size_t data_len
size_t nb
# Expose the raw bytes of `data` to the C parser for the duration of the call.
PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
data_len = <size_t>self.py_buf.len
nb = cparser.http_parser_execute(
self._cparser,
self._csettings,
<char*>self.py_buf.buf,
data_len)
PyBuffer_Release(&self.py_buf)
# i am not sure about cparser.HPE_INVALID_METHOD,
# seems get err for valid request
# test_client_functional.py::test_post_data_with_bytesio_file
if (self._cparser.http_errno != cparser.HPE_OK and
(self._cparser.http_errno != cparser.HPE_INVALID_METHOD or
self._cparser.method == 0)):
# Errors raised while feeding the payload (cb_on_body sets _payload_error)
# were already delivered to the payload via set_exception.
if self._payload_error == 0:
# Prefer the exception captured inside a callback over a generic one.
if self._last_error is not None:
ex = self._last_error
self._last_error = None
else:
ex = parser_error_from_errno(
<cparser.http_errno> self._cparser.http_errno)
self._payload = None
raise ex
# Hand out the queued messages and start a fresh queue.
if self._messages:
messages = self._messages
self._messages = []
else:
messages = ()
if self._upgraded:
return messages, True, data[nb:]
else:
return messages, False, b''
def set_upgraded(self, val):
    # Override the "connection upgraded" flag that feed_data reports.
    self._upgraded = val
# Parser for HTTP requests: configures the base parser in HTTP_REQUEST mode
# and turns the accumulated request-target bytes into a path and URL.
cdef class HttpRequestParser(HttpParser):
def __init__(self, protocol, loop, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False):
# auto_decompress is not passed, so _init uses its default.
self._init(cparser.HTTP_REQUEST, protocol, loop, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof)
# Convert the bytes collected by cb_on_url into self._path / self._url and
# reset the buffer. Does nothing when the buffer is empty.
cdef object _on_status_complete(self):
cdef Py_buffer py_buf
if not self._buf:
return
self._path = self._buf.decode('utf-8', 'surrogateescape')
# For CONNECT (method 5) the target is passed to URL() as-is; every other
# method goes through _parse_url on the raw bytes.
if self._cparser.method == 5: # CONNECT
self._url = URL(self._path)
else:
PyObject_GetBuffer(self._buf, &py_buf, PyBUF_SIMPLE)
try:
self._url = _parse_url(<char*>py_buf.buf,
py_buf.len)
finally:
PyBuffer_Release(&py_buf)
PyByteArray_Resize(self._buf, 0)
cdef class HttpResponseParser(HttpParser):
    # Parser for HTTP responses: configures the base parser in
    # HTTP_RESPONSE mode and extracts the reason phrase.

    def __init__(self, protocol, loop, timer=None,
                 size_t max_line_size=8190, size_t max_headers=32768,
                 size_t max_field_size=8190, payload_exception=None,
                 bint response_with_body=True, bint read_until_eof=False,
                 bint auto_decompress=True):
        self._init(cparser.HTTP_RESPONSE, protocol, loop, timer,
                   max_line_size, max_headers, max_field_size,
                   payload_exception, response_with_body, read_until_eof,
                   auto_decompress)

    cdef object _on_status_complete(self):
        # Nothing buffered: keep an existing reason, defaulting to ''.
        if not self._buf:
            self._reason = self._reason or ''
            return

        # Decode the bytes collected by cb_on_status, then reset the buffer.
        self._reason = self._buf.decode('utf-8', 'surrogateescape')
        PyByteArray_Resize(self._buf, 0)
cdef int cb_on_message_begin(cparser.http_parser* parser) except -1:
    # Start of a new message: reset all per-message state on the owning
    # HttpParser (recovered from parser.data).
    cdef HttpParser owner = <HttpParser>parser.data

    owner._started = True
    owner._headers = CIMultiDict()
    owner._raw_headers = []
    PyByteArray_Resize(owner._buf, 0)
    owner._path = None
    owner._reason = None
    return 0
cdef int cb_on_url(cparser.http_parser* parser,
                   const char *at, size_t length) except -1:
    # Append a chunk of URL bytes to the owner's buffer. Any exception is
    # stored on the owner (for feed_data to re-raise) and reported to the
    # C parser as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    try:
        if length > owner._max_line_size:
            raise LineTooLong(
                'Status line is too long', owner._max_line_size, length)
        extend(owner._buf, at, length)
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef int cb_on_status(cparser.http_parser* parser,
                      const char *at, size_t length) except -1:
    # Append a chunk of status (reason phrase) bytes to the owner's buffer.
    # Any exception is stored on the owner for feed_data to re-raise and
    # reported to the C parser as -1.
    #
    # The unused local `cdef str reason` of the previous version was removed.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        if length > pyparser._max_line_size:
            raise LineTooLong(
                'Status line is too long', pyparser._max_line_size, length)
        extend(pyparser._buf, at, length)
    except BaseException as ex:
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_header_field(cparser.http_parser* parser,
                            const char *at, size_t length) except -1:
    # A chunk of a header name arrived: finish status-line handling, check
    # the name-size limit and append the bytes. Exceptions are stored on
    # the owner and reported to the C parser as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    cdef Py_ssize_t total
    try:
        owner._on_status_complete()
        total = len(owner._raw_name) + length
        if total > owner._max_field_size:
            raise LineTooLong(
                'Header name is too long', owner._max_field_size, total)
        owner._on_header_field(at, length)
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef int cb_on_header_value(cparser.http_parser* parser,
                            const char *at, size_t length) except -1:
    # A chunk of a header value arrived: check the value-size limit and
    # append the bytes. Exceptions are stored on the owner and reported to
    # the C parser as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    cdef Py_ssize_t total
    try:
        total = len(owner._raw_value) + length
        if total > owner._max_field_size:
            raise LineTooLong(
                'Header value is too long', owner._max_field_size, total)
        owner._on_header_value(at, length)
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef int cb_on_headers_complete(cparser.http_parser* parser) except -1:
    # End of the header section. Returns -1 on error (exception stored on
    # the owner), 2 for an upgrade or CONNECT (method 5) message, else 0.
    cdef HttpParser owner = <HttpParser>parser.data
    try:
        owner._on_status_complete()
        owner._on_headers_complete()
    except BaseException as err:
        owner._last_error = err
        return -1

    if owner._cparser.upgrade or owner._cparser.method == 5:  # CONNECT
        return 2
    return 0
cdef int cb_on_body(cparser.http_parser* parser,
                    const char *at, size_t length) except -1:
    # Feed a chunk of body bytes to the current payload. On failure the
    # exception is set on the payload itself (wrapped by the configured
    # payload-exception callable when there is one), _payload_error is
    # flagged and -1 is returned to the C parser.
    cdef HttpParser owner = <HttpParser>parser.data
    cdef bytes chunk = at[:length]
    try:
        owner._payload.feed_data(chunk, length)
    except BaseException as err:
        if owner._payload_exception is None:
            owner._payload.set_exception(err)
        else:
            owner._payload.set_exception(owner._payload_exception(str(err)))
        owner._payload_error = 1
        return -1
    return 0
cdef int cb_on_message_complete(cparser.http_parser* parser) except -1:
    # End of message: clear the "started" flag and finish the payload.
    # Exceptions are stored on the owner and reported as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    try:
        owner._started = False
        owner._on_message_complete()
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef int cb_on_chunk_header(cparser.http_parser* parser) except -1:
    # Start of an HTTP chunk: delegate to the owning parser. Exceptions are
    # stored on the owner and reported as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    try:
        owner._on_chunk_header()
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef int cb_on_chunk_complete(cparser.http_parser* parser) except -1:
    # End of an HTTP chunk: delegate to the owning parser. Exceptions are
    # stored on the owner and reported as -1.
    cdef HttpParser owner = <HttpParser>parser.data
    try:
        owner._on_chunk_complete()
    except BaseException as err:
        owner._last_error = err
        return -1
    return 0
cdef parser_error_from_errno(cparser.http_errno errno):
    # Map a C parser errno to an exception instance whose message is the
    # parser's own description of the error.
    cdef bytes desc = cparser.http_errno_description(errno)
    message = desc.decode('latin-1')

    # Failures reported by one of the callbacks.
    if errno in (cparser.HPE_CB_message_begin,
                 cparser.HPE_CB_url,
                 cparser.HPE_CB_header_field,
                 cparser.HPE_CB_header_value,
                 cparser.HPE_CB_headers_complete,
                 cparser.HPE_CB_body,
                 cparser.HPE_CB_message_complete,
                 cparser.HPE_CB_status,
                 cparser.HPE_CB_chunk_header,
                 cparser.HPE_CB_chunk_complete):
        return BadHttpMessage(message)

    if errno == cparser.HPE_INVALID_STATUS:
        return BadStatusLine(message)
    if errno == cparser.HPE_INVALID_METHOD:
        return BadStatusLine(message)
    if errno == cparser.HPE_INVALID_URL:
        return InvalidURLError(message)

    # Everything else is a generic malformed-message error.
    return BadHttpMessage(message)
def parse_url(url):
    # Python-callable wrapper around _parse_url: `url` must export a simple
    # buffer; the buffer is held only for the duration of the parse.
    cdef Py_buffer view
    cdef char* raw

    PyObject_GetBuffer(url, &view, PyBUF_SIMPLE)
    try:
        raw = <char*>view.buf
        return _parse_url(raw, view.len)
    finally:
        PyBuffer_Release(&view)
cdef _parse_url(char* buf_data, size_t length):
    # Split the first `length` bytes at `buf_data` into URL components with
    # the C parser and return the result of URL_build().
    #
    # Components that are absent become '' (schema, host, path, query,
    # fragment) or None (port, user, password). Text is decoded as UTF-8
    # with 'surrogateescape'.
    #
    # Raises MemoryError if the scratch struct cannot be allocated and
    # InvalidURLError if the C parser rejects the input.
    cdef:
        cparser.http_parser_url* parsed
        int res
        str schema = None
        str host = None
        object port = None
        str path = None
        str query = None
        str fragment = None
        str user = None
        str password = None
        str userinfo = None
        int off
        int ln

    parsed = <cparser.http_parser_url*> \
        PyMem_Malloc(sizeof(cparser.http_parser_url))
    if parsed is NULL:
        raise MemoryError()
    cparser.http_parser_url_init(parsed)
    try:
        res = cparser.http_parser_parse_url(buf_data, length, 0, parsed)

        if res != 0:
            # Slice by `length` instead of formatting the bare char*: the
            # bare pointer is converted up to the first NUL byte, which
            # ignores `length` and relies on the buffer being NUL-terminated.
            raise InvalidURLError(
                "invalid url {!r}".format(buf_data[:length]))

        if parsed.field_set & (1 << cparser.UF_SCHEMA):
            off = parsed.field_data[<int>cparser.UF_SCHEMA].off
            ln = parsed.field_data[<int>cparser.UF_SCHEMA].len
            schema = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
        else:
            schema = ''

        if parsed.field_set & (1 << cparser.UF_HOST):
            off = parsed.field_data[<int>cparser.UF_HOST].off
            ln = parsed.field_data[<int>cparser.UF_HOST].len
            host = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
        else:
            host = ''

        # The port is taken as the parser's numeric value, not from text.
        if parsed.field_set & (1 << cparser.UF_PORT):
            port = parsed.port

        if parsed.field_set & (1 << cparser.UF_PATH):
            off = parsed.field_data[<int>cparser.UF_PATH].off
            ln = parsed.field_data[<int>cparser.UF_PATH].len
            path = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
        else:
            path = ''

        if parsed.field_set & (1 << cparser.UF_QUERY):
            off = parsed.field_data[<int>cparser.UF_QUERY].off
            ln = parsed.field_data[<int>cparser.UF_QUERY].len
            query = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
        else:
            query = ''

        if parsed.field_set & (1 << cparser.UF_FRAGMENT):
            off = parsed.field_data[<int>cparser.UF_FRAGMENT].off
            ln = parsed.field_data[<int>cparser.UF_FRAGMENT].len
            fragment = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
        else:
            fragment = ''

        # "user[:password]" - password is '' when there is no ':'.
        if parsed.field_set & (1 << cparser.UF_USERINFO):
            off = parsed.field_data[<int>cparser.UF_USERINFO].off
            ln = parsed.field_data[<int>cparser.UF_USERINFO].len
            userinfo = buf_data[off:off+ln].decode('utf-8', 'surrogateescape')
            user, sep, password = userinfo.partition(':')

        return URL_build(scheme=schema,
                         user=user, password=password, host=host, port=port,
                         path=path, query=query, fragment=fragment)
    finally:
        PyMem_Free(parsed)

View File

@ -0,0 +1,152 @@
from libc.stdint cimport uint8_t, uint64_t
from libc.string cimport memcpy
from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.object cimport PyObject_Str
from multidict import istr
# Size of the module-level static buffer, and the step by which _write_byte
# grows a writer's buffer.
DEF BUF_SIZE = 16 * 1024 # 16KiB
# Static buffer every Writer starts with (see _init_writer).
cdef char BUFFER[BUF_SIZE]
# Local alias of multidict.istr, compared by identity in to_str.
cdef object _istr = istr
# ----------------- writer ---------------------------
# Growable output buffer: `buf` points either at BUFFER or at memory obtained
# from PyMem_Malloc/PyMem_Realloc, `size` is its capacity and `pos` the number
# of bytes written so far.
cdef struct Writer:
char *buf
Py_ssize_t size
Py_ssize_t pos
cdef inline void _init_writer(Writer* writer):
    # Point the writer at the static BUFFER with nothing written yet.
    writer.pos = 0
    writer.size = BUF_SIZE
    writer.buf = &BUFFER[0]


cdef inline void _release_writer(Writer* writer):
    # Free the buffer only if it is no longer the static BUFFER.
    if writer.buf == BUFFER:
        return
    PyMem_Free(writer.buf)
# Append one byte to the writer, growing its buffer by BUF_SIZE when full.
# Returns 0 on success; the allocation-failure paths call PyErr_NoMemory()
# and are written to return -1.
cdef inline int _write_byte(Writer* writer, uint8_t ch):
cdef char * buf
cdef Py_ssize_t size
if writer.pos == writer.size:
# reallocate
size = writer.size + BUF_SIZE
# The static BUFFER cannot be realloc'ed: the first growth allocates a new
# block and copies the existing content into it.
if writer.buf == BUFFER:
buf = <char*>PyMem_Malloc(size)
if buf == NULL:
PyErr_NoMemory()
return -1
memcpy(buf, writer.buf, writer.size)
else:
# Already on a PyMem block: grow it in place.
buf = <char*>PyMem_Realloc(writer.buf, size)
if buf == NULL:
PyErr_NoMemory()
return -1
writer.buf = buf
writer.size = size
writer.buf[writer.pos] = <char>ch
writer.pos += 1
return 0
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
    # Append the UTF-8 encoding of one code point to the writer.
    # Surrogates (U+D800..U+DFFF) and values above U+10FFFF are skipped.
    # Returns the result of the last _write_byte call, -1 as soon as one
    # of them fails, or 0 when the symbol is skipped.
    cdef uint64_t cp = <uint64_t> symbol

    # 1 byte: ASCII
    if cp < 0x80:
        return _write_byte(writer, <uint8_t>cp)

    # 2 bytes
    if cp < 0x800:
        if _write_byte(writer, <uint8_t>(0xc0 | (cp >> 6))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (cp & 0x3f)))

    # surrogate pair, ignored
    if 0xD800 <= cp <= 0xDFFF:
        return 0

    # 3 bytes
    if cp < 0x10000:
        if _write_byte(writer, <uint8_t>(0xe0 | (cp >> 12))) < 0:
            return -1
        if _write_byte(writer, <uint8_t>(0x80 | ((cp >> 6) & 0x3f))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (cp & 0x3f)))

    # symbol is too large
    if cp > 0x10FFFF:
        return 0

    # 4 bytes
    if _write_byte(writer, <uint8_t>(0xf0 | (cp >> 18))) < 0:
        return -1
    if _write_byte(writer, <uint8_t>(0x80 | ((cp >> 12) & 0x3f))) < 0:
        return -1
    if _write_byte(writer, <uint8_t>(0x80 | ((cp >> 6) & 0x3f))) < 0:
        return -1
    return _write_byte(writer, <uint8_t>(0x80 | (cp & 0x3f)))
cdef inline int _write_str(Writer* writer, str s):
    # Append every character of `s`, UTF-8 encoded; stop with -1 at the
    # first failing write.
    cdef Py_UCS4 symbol
    for symbol in s:
        if _write_utf8(writer, symbol) < 0:
            return -1
    # Explicit form of the implicit zero result of falling off the end.
    return 0
# --------------- _serialize_headers ----------------------
cdef str to_str(object s):
    # Coerce a header key/value to `str`: exact str objects are returned
    # as-is, istr and other str subclasses are converted, anything else
    # raises TypeError.
    kind = type(s)
    if kind is str:
        return <str>s
    if kind is _istr:
        return PyObject_Str(s)
    if isinstance(s, str):
        return str(s)
    raise TypeError("Cannot serialize non-str key {!r}".format(s))
def _serialize_headers(str status_line, headers):
    # Serialise "status_line CRLF (key ': ' value CRLF)* CRLF" to bytes.
    # `headers` must provide .items(); keys and values go through to_str.
    cdef Writer writer
    cdef object key
    cdef object val
    cdef bytes ret

    _init_writer(&writer)
    try:
        # status line
        if (_write_str(&writer, status_line) < 0
                or _write_byte(&writer, b'\r') < 0
                or _write_byte(&writer, b'\n') < 0):
            raise

        # one "key: value" line per header
        for key, val in headers.items():
            if (_write_str(&writer, to_str(key)) < 0
                    or _write_byte(&writer, b':') < 0
                    or _write_byte(&writer, b' ') < 0
                    or _write_str(&writer, to_str(val)) < 0
                    or _write_byte(&writer, b'\r') < 0
                    or _write_byte(&writer, b'\n') < 0):
                raise

        # blank line terminating the header section
        if (_write_byte(&writer, b'\r') < 0
                or _write_byte(&writer, b'\n') < 0):
            raise

        ret = PyBytes_FromStringAndSize(writer.buf, writer.pos)
        return ret
    finally:
        # Frees the buffer if it was moved off the static BUFFER.
        _release_writer(&writer)

View File

@ -0,0 +1,54 @@
from cpython cimport PyBytes_AsString
#from cpython cimport PyByteArray_AsString # cython still not exports that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
# XOR `data` with the 4-byte `mask`, repeating the mask over the whole buffer.
# `mask` is converted to bytes if needed. A bytearray `data` is modified in
# place; any other input is first copied into a new bytearray, and that copy
# is what gets masked. Nothing is returned.
def _websocket_mask_cython(object mask, object data):
"""Note, this function mutates its `data` argument
"""
cdef:
Py_ssize_t data_len, i
# bit operations on signed integers are implementation-specific
unsigned char * in_buf
const unsigned char * mask_buf
uint32_t uint32_msk
uint64_t uint64_msk
assert len(mask) == 4
if not isinstance(mask, bytes):
mask = bytes(mask)
if isinstance(data, bytearray):
data = <bytearray>data
else:
data = bytearray(data)
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
# The four mask bytes reinterpreted as one 32-bit word.
uint32_msk = (<uint32_t*>mask_buf)[0]
# TODO: align in_data ptr to achieve even faster speeds
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
# Where size_t is at least 8 bytes wide, process 8 bytes per step using the
# mask word repeated twice.
if sizeof(size_t) >= 8:
uint64_msk = uint32_msk
uint64_msk = (uint64_msk << 32) | uint32_msk
while data_len >= 8:
(<uint64_t*>in_buf)[0] ^= uint64_msk
in_buf += 8
data_len -= 8
# Then 4 bytes per step.
while data_len >= 4:
(<uint32_t*>in_buf)[0] ^= uint32_msk
in_buf += 4
data_len -= 4
# Remaining tail bytes (fewer than 4), masked one at a time starting from
# the first mask byte.
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]

View File

@ -0,0 +1,208 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel # noqa
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
)
from multidict import CIMultiDict # noqa
from yarl import URL
from .helpers import get_running_loop
from .typedefs import LooseCookies
if TYPE_CHECKING: # pragma: no cover
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
from .web_app import Application
from .web_exceptions import HTTPException
else:
BaseRequest = Request = Application = StreamResponse = None
HTTPException = None
class AbstractRouter(ABC):
    """Abstract base for routers: resolve a request into match info."""

    def __init__(self) -> None:
        self._frozen = False

    def post_init(self, app: Application) -> None:
        """Post-init stage.

        Intentionally not abstract, for backward compatibility: the default
        does nothing, and a router that wants to know its application can
        override it.
        """

    @property
    def frozen(self) -> bool:
        """True once :meth:`freeze` has been called."""
        return self._frozen

    def freeze(self) -> None:
        """Mark the router as frozen."""
        self._frozen = True

    @abstractmethod
    async def resolve(self, request: Request) -> 'AbstractMatchInfo':
        """Return MATCH_INFO for the given request."""
class AbstractMatchInfo(ABC):
    """Abstract result of resolving a request through a router."""

    @property  # pragma: no branch
    @abstractmethod
    def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
        """The matched request handler."""

    @property
    @abstractmethod
    def expect_handler(self) -> Callable[[Request], Awaitable[None]]:
        """The handler used for 100-continue processing."""

    @property  # pragma: no branch
    @abstractmethod
    def http_exception(self) -> Optional[HTTPException]:
        """HTTPException raised while the router was resolving, or None."""

    @abstractmethod  # pragma: no branch
    def get_info(self) -> Dict[str, Any]:
        """Return a dict of extra information useful for introspection."""

    @property  # pragma: no branch
    @abstractmethod
    def apps(self) -> Tuple[Application, ...]:
        """Stack of nested applications.

        The top-level application is the left-most element.
        """

    @abstractmethod
    def add_app(self, app: Application) -> None:
        """Push an application onto the nested-apps stack."""

    @abstractmethod
    def freeze(self) -> None:
        """Freeze the match info.

        Called after route resolution; calling .add_app() afterwards is
        forbidden.
        """
class AbstractView(ABC):
    """Abstract base for class-based views."""

    def __init__(self, request: Request) -> None:
        self._request = request

    @property
    def request(self) -> Request:
        """The request instance this view was created with."""
        return self._request

    @abstractmethod
    def __await__(self) -> Generator[Any, None, StreamResponse]:
        """Run the view handler."""
class AbstractResolver(ABC):
    """Abstract base for DNS resolvers."""

    @abstractmethod
    async def resolve(self, host: str,
                      port: int, family: int) -> List[Dict[str, Any]]:
        """Return IP address information for the given hostname."""

    @abstractmethod
    async def close(self) -> None:
        """Release the resolver."""
# The subscripted base is used only while type checking; at runtime the class
# derives from plain Iterable.
if not TYPE_CHECKING:
    IterableBase = Iterable
else:  # pragma: no cover
    IterableBase = Iterable[Morsel[str]]


class AbstractCookieJar(Sized, IterableBase):
    """Abstract base for cookie jars."""

    def __init__(self, *,
                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        self._loop = get_running_loop(loop)

    @abstractmethod
    def clear(self) -> None:
        """Remove all cookies."""

    @abstractmethod
    def update_cookies(self,
                       cookies: LooseCookies,
                       response_url: URL = URL()) -> None:
        """Update the jar with the given cookies."""

    @abstractmethod
    def filter_cookies(self, request_url: URL) -> 'BaseCookie[str]':
        """Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
    """Abstract base for stream writers."""

    # Class-level defaults for the writer's counters.
    buffer_size = 0
    output_size = 0
    length = 0  # type: Optional[int]

    @abstractmethod
    async def write(self, chunk: bytes) -> None:
        """Write a chunk into the stream."""

    @abstractmethod
    async def write_eof(self, chunk: bytes = b'') -> None:
        """Write the last chunk."""

    @abstractmethod
    async def drain(self) -> None:
        """Flush the write buffer."""

    @abstractmethod
    def enable_compression(self, encoding: str = 'deflate') -> None:
        """Enable HTTP body compression."""

    @abstractmethod
    def enable_chunking(self) -> None:
        """Enable HTTP chunked mode."""

    @abstractmethod
    async def write_headers(self, status_line: str,
                            headers: 'CIMultiDict[str]') -> None:
        """Write the HTTP headers."""
class AbstractAccessLogger(ABC):
    """Abstract base for access-log writers."""

    def __init__(self, logger: logging.Logger, log_format: str) -> None:
        self.log_format = log_format
        self.logger = logger

    @abstractmethod
    def log(self,
            request: BaseRequest,
            response: StreamResponse,
            time: float) -> None:
        """Emit a log record to the logger."""

View File

@ -0,0 +1,81 @@
import asyncio
from typing import Optional, cast
from .tcp_helpers import tcp_nodelay
# asyncio protocol base class that tracks the transport, the paused state of
# writing and reading, and a future that _drain_helper() waits on while
# writing is paused.
class BaseProtocol(asyncio.Protocol):
__slots__ = ('_loop', '_paused', '_drain_waiter',
'_connection_lost', '_reading_paused', 'transport')
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop # type: asyncio.AbstractEventLoop
# True between pause_writing() and resume_writing().
self._paused = False
# Future awaited by _drain_helper() while writing is paused.
self._drain_waiter = None # type: Optional[asyncio.Future[None]]
self._connection_lost = False
self._reading_paused = False
self.transport = None # type: Optional[asyncio.Transport]
def pause_writing(self) -> None:
assert not self._paused
self._paused = True
# Clears the paused flag and wakes up a pending _drain_helper() waiter.
def resume_writing(self) -> None:
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
# Asks the transport to stop reading; AttributeError, NotImplementedError and
# RuntimeError raised by the transport are ignored.
def pause_reading(self) -> None:
if not self._reading_paused and self.transport is not None:
try:
self.transport.pause_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
# Counterpart of pause_reading(), with the same ignored errors.
def resume_reading(self) -> None:
if self._reading_paused and self.transport is not None:
try:
self.transport.resume_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
# Stores the transport after calling tcp_nodelay(tr, True) on it.
def connection_made(self, transport: asyncio.BaseTransport) -> None:
tr = cast(asyncio.Transport, transport)
tcp_nodelay(tr, True)
self.transport = tr
# Drops the transport; if writing is paused and a drain waiter is pending,
# the waiter is completed with `exc` as its exception, or with None when
# `exc` is None.
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
# Returns immediately unless writing is paused, in which case it waits until
# resume_writing() or connection_lost() completes the waiter.
# Raises ConnectionResetError if the connection is already lost.
async def _drain_helper(self) -> None:
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,292 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
from .typedefs import _CIMultiDict
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore
if TYPE_CHECKING: # pragma: no cover
from .client_reqrep import (RequestInfo, ClientResponse, ConnectionKey, # noqa
Fingerprint)
else:
RequestInfo = ClientResponse = ConnectionKey = None
# Public names of this module (what ``from <module> import *`` exports).
# TooManyRedirects is a public exception class defined in this module and was
# missing from the list.
__all__ = (
    'ClientError',

    'ClientConnectionError',
    'ClientOSError', 'ClientConnectorError', 'ClientProxyConnectionError',

    'ClientSSLError',
    'ClientConnectorSSLError', 'ClientConnectorCertificateError',

    'ServerConnectionError', 'ServerTimeoutError', 'ServerDisconnectedError',
    'ServerFingerprintMismatch',

    'ClientResponseError', 'ClientHttpProxyError',
    'WSServerHandshakeError', 'ContentTypeError', 'TooManyRedirects',

    'ClientPayloadError', 'InvalidURL')
class ClientError(Exception):
    """Root of the client exception hierarchy in this module."""
class ClientResponseError(ClientError):
"""Connection error during reading response.
request_info: instance of RequestInfo
"""
# `code` is the deprecated spelling of `status`: passing both raises
# ValueError, passing only `code` emits a DeprecationWarning. When neither is
# given, status is 0.
def __init__(self, request_info: RequestInfo,
history: Tuple[ClientResponse, ...], *,
code: Optional[int]=None,
status: Optional[int]=None,
message: str='',
headers: Optional[_CIMultiDict]=None) -> None:
self.request_info = request_info
if code is not None:
if status is not None:
raise ValueError(
"Both code and status arguments are provided; "
"code is deprecated, use status instead")
warnings.warn("code argument is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
if status is not None:
self.status = status
elif code is not None:
self.status = code
else:
self.status = 0
self.message = message
self.headers = headers
self.history = history
# Only the two positional arguments are kept in args.
self.args = (request_info, history)
def __str__(self) -> str:
return ("%s, message=%r, url=%r" %
(self.status, self.message, self.request_info.real_url))
# Shows the class name with request_info and history, plus status, message
# and headers only when they differ from their defaults.
def __repr__(self) -> str:
args = "%r, %r" % (self.request_info, self.history)
if self.status != 0:
args += ", status=%r" % (self.status,)
if self.message != '':
args += ", message=%r" % (self.message,)
if self.headers is not None:
args += ", headers=%r" % (self.headers,)
return "%s(%s)" % (type(self).__name__, args)
# Deprecated alias of `status`; reading or writing it emits a
# DeprecationWarning.
@property
def code(self) -> int:
warnings.warn("code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
return self.status
@code.setter
def code(self, value: int) -> None:
warnings.warn("code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
self.status = value
class ContentTypeError(ClientResponseError):
    """The content type found is not valid."""


class WSServerHandshakeError(ClientResponseError):
    """WebSocket server handshake error."""


class ClientHttpProxyError(ClientResponseError):
    """HTTP proxy error.

    Raised in :class:`hyper_internal_service.connector.TCPConnector` if
    the proxy responds to a ``CONNECT`` request with a status other than
    ``200 OK``.
    """


class TooManyRedirects(ClientResponseError):
    """The client was redirected too many times."""
class ClientConnectionError(ClientError):
    """Base class for client socket errors."""


class ClientOSError(ClientConnectionError, OSError):
    """Client connection error that is also an :class:`OSError`."""
class ClientConnectorError(ClientOSError):
    """Client connector error.

    Raised in :class:`hyper_internal_service.connector.TCPConnector` if
    connection to proxy can not be established.
    """

    def __init__(self, connection_key: ConnectionKey,
                 os_error: OSError) -> None:
        self._conn_key = connection_key
        self._os_error = os_error
        # Initialise the OSError part (errno, strerror) first, then expose
        # the original constructor arguments through args.
        super().__init__(os_error.errno, os_error.strerror)
        self.args = (connection_key, os_error)

    @property
    def os_error(self) -> OSError:
        """The OSError this exception was created from."""
        return self._os_error

    @property
    def host(self) -> str:
        """Host taken from the connection key."""
        return self._conn_key.host

    @property
    def port(self) -> Optional[int]:
        """Port taken from the connection key."""
        return self._conn_key.port

    @property
    def ssl(self) -> Union[SSLContext, None, bool, 'Fingerprint']:
        """SSL setting taken from the connection key."""
        return self._conn_key.ssl

    def __str__(self) -> str:
        ssl_label = self.ssl if self.ssl is not None else 'default'
        template = 'Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]'
        return template.format(self, ssl_label, self.strerror)

    # OSError.__reduce__ does too much black magic
    __reduce__ = BaseException.__reduce__
class ClientProxyConnectionError(ClientConnectorError):
    """Proxy connection error.

    Raised in :class:`hyper_internal_service.connector.TCPConnector` when
    the connection to the proxy can not be established.
    """
class ServerConnectionError(ClientConnectionError):
    """Base class for server connection errors."""


class ServerDisconnectedError(ServerConnectionError):
    """The server disconnected."""

    def __init__(self, message: Optional[str] = None) -> None:
        # Fall back to a generic message when none is given.
        text = 'Server disconnected' if message is None else message
        self.args = (text,)
        self.message = text


class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
    """Server timeout error; also an :class:`asyncio.TimeoutError`."""
class ServerFingerprintMismatch(ServerConnectionError):
    """The SSL certificate does not match the expected fingerprint."""

    def __init__(self, expected: bytes, got: bytes,
                 host: str, port: int) -> None:
        self.args = (expected, got, host, port)
        self.expected = expected
        self.got = got
        self.host = host
        self.port = port

    def __repr__(self) -> str:
        template = '<{} expected={!r} got={!r} host={!r} port={!r}>'
        return template.format(type(self).__name__, self.expected,
                               self.got, self.host, self.port)
class ClientPayloadError(ClientError):
    """Error in the response payload."""
class InvalidURL(ClientError, ValueError):
    """Invalid URL.

    The URL used for fetching is malformed, e.g. it does not contain a
    host part.
    """

    # Derives from ValueError for backward compatibility.

    def __init__(self, url: Any) -> None:
        # `url` is typed Any rather than yarl.URL because this exception
        # can be raised by the URL(url) call itself.
        super().__init__(url)

    @property
    def url(self) -> Any:
        """The offending URL, as passed to the constructor."""
        return self.args[0]

    def __repr__(self) -> str:
        return '<{} {}>'.format(type(self).__name__, self.url)
class ClientSSLError(ClientConnectorError):
    """Base error for ssl.*Errors."""


# `*_errors` hold the ssl exception types; `*_bases` are the base-class tuples
# for the two SSL exception classes defined below. Both depend on whether the
# ssl module could be imported.
if ssl is None:  # pragma: no cover
    cert_errors = ()
    cert_errors_bases = (ClientSSLError, ValueError)

    ssl_errors = ()
    ssl_error_bases = (ClientSSLError,)
else:
    cert_errors = (ssl.CertificateError,)
    cert_errors_bases = (ClientSSLError, ssl.CertificateError)

    ssl_errors = (ssl.SSLError,)
    ssl_error_bases = (ClientSSLError, ssl.SSLError)
class ClientConnectorSSLError(*ssl_error_bases):  # type: ignore
    """SSL error; its bases are taken from ``ssl_error_bases``."""
class ClientConnectorCertificateError(*cert_errors_bases):  # type: ignore
    """Certificate error; its bases are taken from ``cert_errors_bases``.

    Stores the connection key and the underlying certificate error, and
    exposes the key's ``host``, ``port`` and ``is_ssl`` through properties.
    """

    def __init__(self, connection_key:
                 ConnectionKey, certificate_error: Exception) -> None:
        self._conn_key = connection_key
        self._certificate_error = certificate_error
        self.args = (connection_key, certificate_error)

    @property
    def certificate_error(self) -> Exception:
        """The wrapped certificate exception."""
        return self._certificate_error

    @property
    def host(self) -> str:
        """Host taken from the connection key."""
        return self._conn_key.host

    @property
    def port(self) -> Optional[int]:
        """Port taken from the connection key."""
        return self._conn_key.port

    @property
    def ssl(self) -> bool:
        """The connection key's ``is_ssl`` flag."""
        return self._conn_key.is_ssl

    def __str__(self) -> str:
        template = ('Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} '
                    '[{0.certificate_error.__class__.__name__}: '
                    '{0.certificate_error.args}]')
        return template.format(self)

View File

@ -0,0 +1,239 @@
import asyncio
from contextlib import suppress
from typing import Any, Optional, Tuple
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext
from .http import HttpResponseParser, RawResponseMessage
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol,
                      DataQueue[Tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader.

    Received bytes go to a custom payload parser when one was installed
    with ``set_parser``; otherwise to the ``HttpResponseParser`` created by
    ``set_response_params``.  Parsed ``(message, payload)`` pairs are pushed
    into the inherited data queue.  Bytes that cannot be parsed yet are
    kept in ``self._tail``.
    """

    def __init__(self,
                 loop: asyncio.AbstractEventLoop) -> None:
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)

        self._should_close = False

        self._payload = None
        self._skip_payload = False
        self._payload_parser = None

        self._timer = None

        # Bytes received while no suitable parser was available.
        self._tail = b''
        self._upgraded = False
        self._parser = None  # type: Optional[HttpResponseParser]

        self._read_timeout = None  # type: Optional[float]
        self._read_timeout_handle = None  # type: Optional[asyncio.TimerHandle]

    @property
    def upgraded(self) -> bool:
        """Upgrade flag reported by the HTTP parser for the last feed."""
        return self._upgraded

    @property
    def should_close(self) -> bool:
        """Whether the connection should be closed.

        True when a payload is still incomplete, the connection was
        upgraded, closing was requested, an exception is stored, a custom
        payload parser is installed, messages are still queued, or
        unparsed tail bytes remain.
        """
        payload_incomplete = (self._payload is not None and
                              not self._payload.is_eof())
        if payload_incomplete or self._upgraded:
            return True

        return (self._should_close or self._upgraded or
                self.exception() is not None or
                self._payload_parser is not None or
                len(self) > 0 or bool(self._tail))

    def force_close(self) -> None:
        """Set the flag that makes ``should_close`` return True."""
        self._should_close = True

    def close(self) -> None:
        """Close the transport, if any, and forget payload and timeout."""
        transport = self.transport
        if transport is None:
            return
        transport.close()
        self.transport = None
        self._payload = None
        self._drop_timeout()

    def is_connected(self) -> bool:
        """Return True while a transport is attached."""
        return self.transport is not None

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Flush the parsers, record the failure and reset parsing state."""
        self._drop_timeout()

        payload_parser = self._payload_parser
        if payload_parser is not None:
            with suppress(Exception):
                payload_parser.feed_eof()

        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception:
                if self._payload is not None:
                    self._payload.set_exception(
                        ClientPayloadError(
                            'Response payload is not completed'))

        if not self.is_eof():
            if isinstance(exc, OSError):
                exc = ClientOSError(*exc.args)
            elif exc is None:
                exc = ServerDisconnectedError(uncompleted)
            # set_exception() also sets self._should_close to True;
            # it is set again unconditionally below.
            self.set_exception(exc)

        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        self._reading_paused = False

        super().connection_lost(exc)

    def eof_received(self) -> None:
        """Cancel the read timeout when the peer signals EOF."""
        # should call parser.feed_eof() most likely
        self._drop_timeout()

    def pause_reading(self) -> None:
        """Pause reading and cancel the read timeout."""
        super().pause_reading()
        self._drop_timeout()

    def resume_reading(self) -> None:
        """Resume reading and re-arm the read timeout."""
        super().resume_reading()
        self._reschedule_timeout()

    def set_exception(self, exc: BaseException) -> None:
        """Store ``exc``, mark the connection for closing, drop the timeout."""
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc)

    def set_parser(self, parser: Any, payload: Any) -> None:
        """Install a custom payload parser and replay any buffered tail."""
        # TODO: actual types are:
        #   parser: WebSocketReader
        #   payload: FlowControlDataQueue
        # but they are not generic enough
        # Need an ABC for both types
        self._payload = payload
        self._payload_parser = parser

        self._drop_timeout()

        if self._tail:
            buffered, self._tail = self._tail, b''
            self.data_received(buffered)

    def set_response_params(self, *, timer: BaseTimerContext=None,
                            skip_payload: bool=False,
                            read_until_eof: bool=False,
                            auto_decompress: bool=True,
                            read_timeout: Optional[float]=None) -> None:
        """Configure response parsing and create the HTTP response parser.

        Stores ``skip_payload`` and ``read_timeout``, re-arms the read
        timeout, builds a new ``HttpResponseParser`` and replays any bytes
        that were buffered before the parser existed.
        """
        self._skip_payload = skip_payload
        self._read_timeout = read_timeout
        self._reschedule_timeout()

        self._parser = HttpResponseParser(
            self, self._loop, timer=timer,
            payload_exception=ClientPayloadError,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress)

        if self._tail:
            buffered, self._tail = self._tail, b''
            self.data_received(buffered)

    def _drop_timeout(self) -> None:
        """Cancel and forget the pending read-timeout callback, if any."""
        handle = self._read_timeout_handle
        if handle is not None:
            handle.cancel()
            self._read_timeout_handle = None

    def _reschedule_timeout(self) -> None:
        """Cancel the pending timeout and arm a new one if one is set."""
        timeout = self._read_timeout
        previous = self._read_timeout_handle
        if previous is not None:
            previous.cancel()

        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout)
        else:
            self._read_timeout_handle = None

    def _on_read_timeout(self) -> None:
        """Timer callback: fail the queue and current payload with a timeout."""
        exc = ServerTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            self._payload.set_exception(exc)

    def data_received(self, data: bytes) -> None:
        """Dispatch received bytes to the active parser or buffer them."""
        self._reschedule_timeout()

        if not data:
            return

        # custom payload parser
        if self._payload_parser is not None:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self._payload = None
                self._payload_parser = None

                if tail:
                    self.data_received(tail)
            return

        if self._upgraded or self._parser is None:
            # i.e. websocket connection, websocket parser is not set yet
            self._tail += data
            return

        # parse http messages
        try:
            messages, upgraded, tail = self._parser.feed_data(data)
        except BaseException as exc:
            if self.transport is not None:
                # connection.release() could be called BEFORE
                # data_received(), the transport is already
                # closed in this case
                self.transport.close()
            # should_close is True after the call
            self.set_exception(exc)
            return

        self._upgraded = upgraded

        payload = None
        for message, payload in messages:
            if message.should_close:
                self._should_close = True

            self._payload = payload

            if self._skip_payload or message.code in (204, 304):
                self.feed_data((message, EMPTY_PAYLOAD), 0)  # type: ignore  # noqa
            else:
                self.feed_data((message, payload), 0)

        if payload is not None:
            # At least one message was processed: drop the read timeout
            # immediately for EMPTY_PAYLOAD, otherwise once the payload
            # reaches end-of-stream.
            if payload is EMPTY_PAYLOAD:
                self._drop_timeout()
            else:
                payload.on_eof(self._drop_timeout)

        if tail:
            if upgraded:
                self.data_received(tail)
            else:
                self._tail = tail

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,301 @@
"""WebSocket client for asyncio."""
import asyncio
from typing import Any, Optional
import async_timeout
from .client_exceptions import ClientError
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSMessage,
WSMsgType,
)
from .http_websocket import WebSocketWriter # WSMessage
from .streams import EofStream, FlowControlDataQueue # noqa
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONDecoder,
JSONEncoder,
)
class ClientWebSocketResponse:
    """Client-side WebSocket connection.

    Combines a message ``reader`` queue with a ``WebSocketWriter`` and adds
    the closing handshake, optional automatic replies to PING messages,
    optional automatic close, and an optional heartbeat that sends pings
    and closes the connection when no message arrives in time.
    """

    def __init__(self,
                 reader: 'FlowControlDataQueue[WSMessage]',
                 writer: WebSocketWriter,
                 protocol: Optional[str],
                 response: ClientResponse,
                 timeout: float,
                 autoclose: bool,
                 autoping: bool,
                 loop: asyncio.AbstractEventLoop,
                 *,
                 receive_timeout: Optional[float]=None,
                 heartbeat: Optional[float]=None,
                 compress: int=0,
                 client_notakeover: bool=False) -> None:
        self._response = response
        self._conn = response.connection

        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code = None  # type: Optional[int]
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb = None
        if heartbeat is not None:
            # The pong deadline is half of the heartbeat interval.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb = None
        self._loop = loop
        self._waiting = None  # type: Optional[asyncio.Future[bool]]
        self._exception = None  # type: Optional[BaseException]
        self._compress = compress
        self._client_notakeover = client_notakeover

        self._reset_heartbeat()

    def _cancel_heartbeat(self) -> None:
        """Cancel the pending pong-deadline and heartbeat callbacks."""
        pong_cb = self._pong_response_cb
        if pong_cb is not None:
            pong_cb.cancel()
            self._pong_response_cb = None

        heartbeat_cb = self._heartbeat_cb
        if heartbeat_cb is not None:
            heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _reset_heartbeat(self) -> None:
        """Restart the heartbeat timer when a heartbeat is configured."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            self._heartbeat_cb = call_later(
                self._send_heartbeat, self._heartbeat, self._loop)

    def _send_heartbeat(self) -> None:
        """Timer callback: send a ping and arm the pong deadline."""
        if self._heartbeat is None or self._closed:
            return

        # fire-and-forget a task is not perfect but maybe ok for
        # sending ping. Otherwise we need a long-living heartbeat
        # task in the class.
        self._loop.create_task(self._writer.ping())

        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
        self._pong_response_cb = call_later(
            self._pong_not_received, self._pong_heartbeat, self._loop)

    def _pong_not_received(self) -> None:
        """Timer callback: the pong deadline expired, drop the connection."""
        if self._closed:
            return
        self._closed = True
        self._close_code = 1006
        self._exception = asyncio.TimeoutError()
        self._response.close()

    @property
    def closed(self) -> bool:
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        return self._close_code

    @property
    def protocol(self) -> Optional[str]:
        return self._protocol

    @property
    def compress(self) -> int:
        return self._compress

    @property
    def client_notakeover(self) -> bool:
        return self._client_notakeover

    def get_extra_info(self, name: str, default: Any=None) -> Any:
        """extra info from connection transport"""
        conn = self._response.connection
        transport = None if conn is None else conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)

    def exception(self) -> Optional[BaseException]:
        """Return the exception recorded for this connection, if any."""
        return self._exception

    async def ping(self, message: bytes=b'') -> None:
        await self._writer.ping(message)

    async def pong(self, message: bytes=b'') -> None:
        await self._writer.pong(message)

    async def send_str(self, data: str,
                       compress: Optional[int]=None) -> None:
        """Send a text message; raise TypeError unless ``data`` is ``str``."""
        if not isinstance(data, str):
            raise TypeError('data argument must be str (%r)' % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes,
                         compress: Optional[int]=None) -> None:
        """Send a binary message; raise TypeError for non byte-like data."""
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError('data argument must be byte-ish (%r)' %
                            type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(self, data: Any,
                        compress: Optional[int]=None,
                        *, dumps: JSONEncoder=DEFAULT_JSON_ENCODER) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a text message."""
        await self.send_str(dumps(data), compress=compress)

    async def close(self, *, code: int=1000, message: bytes=b'') -> bool:
        """Perform the closing handshake.

        Returns False when the connection was already closed, True
        otherwise.  ``asyncio.CancelledError`` is re-raised after the
        response is closed; other errors are stored and reported through
        ``exception()``, with close code 1006.
        """
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting is not None and not self._closed:
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if self._closed:
            return False

        self._cancel_heartbeat()
        self._closed = True
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = 1006
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = 1006
            self._exception = exc
            self._response.close()
            return True

        if self._closing:
            self._response.close()
            return True

        # Read until a CLOSE message arrives.
        while True:
            try:
                with async_timeout.timeout(self._timeout, loop=self._loop):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = 1006
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = 1006
                self._exception = exc
                self._response.close()
                return True

            if msg.type == WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True

    async def receive(self, timeout: Optional[float]=None) -> WSMessage:
        """Return the next message.

        With ``autoping`` enabled, PING messages are answered with a pong
        and PONG messages are skipped; neither is returned.  Only one
        ``receive()`` may be in progress at a time; a concurrent call
        raises RuntimeError.
        """
        while True:
            if self._waiting is not None:
                raise RuntimeError(
                    'Concurrent call to receive() is not allowed')

            if self._closed:
                return WS_CLOSED_MESSAGE
            if self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE

            try:
                self._waiting = self._loop.create_future()
                try:
                    with async_timeout.timeout(
                            timeout or self._receive_timeout,
                            loop=self._loop):
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    # Wake up a close() call waiting for this read to end.
                    pending = self._waiting
                    self._waiting = None
                    set_result(pending, True)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = 1006
                raise
            except EofStream:
                self._close_code = 1000
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                self._closed = True
                self._close_code = 1006
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = 1006
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            msg_type = msg.type
            if msg_type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg_type == WSMsgType.CLOSING:
                self._closing = True
            elif msg_type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg_type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float]=None) -> str:
        """Receive a message and return its data; it must be TEXT."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not str".format(msg.type,
                                                             msg.data))
        return msg.data

    async def receive_bytes(self, *, timeout: Optional[float]=None) -> bytes:
        """Receive a message and return its data; it must be BINARY."""
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(
                "Received message {}:{!r} is not bytes".format(msg.type,
                                                               msg.data))
        return msg.data

    async def receive_json(self,
                           *, loads: JSONDecoder=DEFAULT_JSON_DECODER,
                           timeout: Optional[float]=None) -> Any:
        """Receive a TEXT message and decode it with ``loads``."""
        text = await self.receive_str(timeout=timeout)
        return loads(text)

    def __aiter__(self) -> 'ClientWebSocketResponse':
        return self

    async def __anext__(self) -> WSMessage:
        """Yield messages until a CLOSE, CLOSING or CLOSED one arrives."""
        msg = await self.receive()
        terminal_types = (WSMsgType.CLOSE,
                          WSMsgType.CLOSING,
                          WSMsgType.CLOSED)
        if msg.type in terminal_types:
            raise StopAsyncIteration  # NOQA
        return msg

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,368 @@
import asyncio
import datetime
import os # noqa
import pathlib
import pickle
import re
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie # noqa
from typing import ( # noqa
DefaultDict,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
Set,
Tuple,
Union,
cast,
)
from yarl import URL
from .abc import AbstractCookieJar
from .helpers import is_ip_address, next_whole_second
from .typedefs import LooseCookies, PathLike
__all__ = ('CookieJar', 'DummyCookieJar')
CookieItem = Union[str, 'Morsel[str]']
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265.

    Cookies are kept per domain in ``SimpleCookie`` containers.  Expiration
    times are tracked in a separate mapping keyed by ``(domain, name)`` and
    applied lazily by ``_do_expiration``.  With ``unsafe=False`` (the
    default), cookies are neither accepted from nor sent to hosts that are
    IP addresses.
    """

    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # One capture group per month, so ``lastindex`` is the month number.
    DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
                               "(aug)|(sep)|(oct)|(nov)|(dec)", re.I)

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    MAX_TIME = datetime.datetime.max.replace(
        tzinfo=datetime.timezone.utc)

    def __init__(self, *, unsafe: bool=False,
                 loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
        super().__init__(loop=loop)
        self._cookies = defaultdict(SimpleCookie)  #type: DefaultDict[str, SimpleCookie[str]]  # noqa
        self._host_only_cookies = set()  # type: Set[Tuple[str, str]]
        self._unsafe = unsafe
        self._next_expiration = next_whole_second()
        self._expirations = {}  # type: Dict[Tuple[str, str], datetime.datetime]  # noqa: E501

    def save(self, file_path: PathLike) -> None:
        """Pickle the cookie mapping to ``file_path``.

        Only ``self._cookies`` is written; host-only flags and expiration
        times are not part of the file.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode='wb') as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the cookie mapping with the one pickled in ``file_path``."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode='rb') as f:
            # SECURITY: unpickling can execute arbitrary code.  Only load
            # files that come from a trusted source (e.g. written by save()).
            self._cookies = pickle.load(f)

    def clear(self) -> None:
        """Remove all cookies together with their bookkeeping data."""
        self._cookies.clear()
        self._host_only_cookies.clear()
        self._next_expiration = next_whole_second()
        self._expirations.clear()

    def __iter__(self) -> 'Iterator[Morsel[str]]':
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        return sum(1 for i in self)

    def _do_expiration(self) -> None:
        """Drop expired cookies and compute the next expiration check time."""
        now = datetime.datetime.now(datetime.timezone.utc)
        if self._next_expiration > now:
            return
        if not self._expirations:
            return
        next_expiration = self.MAX_TIME
        to_del = []
        cookies = self._cookies
        expirations = self._expirations
        for (domain, name), when in expirations.items():
            if when <= now:
                cookies[domain].pop(name, None)
                to_del.append((domain, name))
                self._host_only_cookies.discard((domain, name))
            else:
                next_expiration = min(next_expiration, when)
        # Deleted after the loop: the dict must not change while iterating.
        for key in to_del:
            del expirations[key]

        try:
            self._next_expiration = (next_expiration.replace(microsecond=0) +
                                     datetime.timedelta(seconds=1))
        except OverflowError:
            self._next_expiration = self.MAX_TIME

    def _expire_cookie(self, when: datetime.datetime, domain: str, name: str
                       ) -> None:
        """Record that cookie ``(domain, name)`` expires at ``when``."""
        self._next_expiration = min(self._next_expiration, when)
        self._expirations[(domain, name)] = when

    def update_cookies(self,
                       cookies: LooseCookies,
                       response_url: URL=URL()) -> None:
        """Update cookies.

        ``cookies`` is a mapping or an iterable of ``(name, cookie)`` pairs;
        values that are not ``Morsel`` instances are converted through a
        temporary ``SimpleCookie``.  Cookies whose domain does not match the
        host of ``response_url`` are skipped.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp = SimpleCookie()  # type: SimpleCookie[str]
                tmp[name] = cookie  # type: ignore
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain.endswith('.'):
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain.startswith("."):
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or not path.startswith("/"):
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1:path.rfind("/")]
                cookie["path"] = path

            # Max-Age takes precedence over Expires.
            max_age = cookie["max-age"]
            if max_age:
                try:
                    delta_seconds = int(max_age)
                    try:
                        max_age_expiration = (
                            datetime.datetime.now(datetime.timezone.utc) +
                            datetime.timedelta(seconds=delta_seconds))
                    except OverflowError:
                        max_age_expiration = self.MAX_TIME
                    self._expire_cookie(max_age_expiration,
                                        domain, name)
                except ValueError:
                    cookie["max-age"] = ""

            else:
                expires = cookie["expires"]
                if expires:
                    expire_time = self._parse_date(expires)
                    if expire_time:
                        self._expire_cookie(expire_time,
                                            domain, name)
                    else:
                        cookie["expires"] = ""

            self._cookies[domain][name] = cookie

        self._do_expiration()

    def filter_cookies(self, request_url: URL=URL()) -> 'BaseCookie[str]':
        """Returns this jar's cookies filtered by their attributes."""
        self._do_expiration()
        request_url = URL(request_url)
        filtered = SimpleCookie()  # type: SimpleCookie[str]
        hostname = request_url.raw_host or ""
        is_not_secure = request_url.scheme not in ("https", "wss")

        for cookie in self:
            name = cookie.key
            domain = cookie["domain"]

            # Send shared cookies
            if not domain:
                filtered[name] = cookie.value
                continue

            if not self._unsafe and is_ip_address(hostname):
                continue

            if (domain, name) in self._host_only_cookies:
                if domain != hostname:
                    continue
            elif not self._is_domain_match(domain, hostname):
                continue

            if not self._is_path_match(request_url.path, cookie["path"]):
                continue

            if is_not_secure and cookie["secure"]:
                continue

            # It's critical we use the Morsel so the coded_value
            # (based on cookie version) is preserved
            mrsl_val = cast('Morsel[str]', cookie.get(cookie.key, Morsel()))
            mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
            filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        # The part of the hostname before the domain must end with a dot.
        non_matching = hostname[:-len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @staticmethod
    def _is_path_match(req_path: str, cookie_path: str) -> bool:
        """Implements path matching adhering to RFC 6265."""
        if not req_path.startswith("/"):
            req_path = "/"

        if req_path == cookie_path:
            return True

        if not req_path.startswith(cookie_path):
            return False

        if cookie_path.endswith("/"):
            return True

        non_matching = req_path[len(cookie_path):]

        return non_matching.startswith("/")

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
        """Implements date string parsing adhering to RFC 6265.

        Returns a timezone-aware UTC ``datetime``, or ``None`` when the
        string is empty, a component is missing or out of range, or the
        components do not form an existing calendar date (e.g. 31 Feb).
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = [
                        int(s) for s in time_match.groups()]
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        try:
            return datetime.datetime(year, month, day,
                                     hour, minute, second,
                                     tzinfo=datetime.timezone.utc)
        except ValueError:
            # The day does not exist in that month (e.g. "31 Feb"); RFC 6265
            # requires such a cookie-date to fail to parse rather than let
            # the error escape through update_cookies().
            return None
class DummyCookieJar(AbstractCookieJar):
    """Implements a dummy cookie storage.

    It stores nothing: iteration yields no cookies, the length is always 0,
    updates are ignored and filtering returns an empty ``SimpleCookie``.
    It can be used with the ClientSession when no cookie processing is
    needed.
    """

    def __init__(self, *,
                 loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> 'Iterator[Morsel[str]]':
        # An empty generator: delegates to an empty tuple.
        yield from ()

    def __len__(self) -> int:
        return 0

    def clear(self) -> None:
        """Nothing to clear."""

    def update_cookies(self,
                       cookies: LooseCookies,
                       response_url: URL=URL()) -> None:
        """Ignore the given cookies."""

    def filter_cookies(self, request_url: URL) -> 'BaseCookie[str]':
        """Return an empty cookie container for any URL."""
        return SimpleCookie()

View File

@ -0,0 +1,154 @@
import io
from typing import Any, Iterable, List, Optional # noqa
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .payload import Payload
__all__ = ('FormData',)
class FormData:
    """Helper class for multipart/form-data and
    application/x-www-form-urlencoded body generation.

    Fields are collected with ``add_field`` / ``add_fields``.  Calling the
    instance returns a multipart writer when any field requires multipart
    encoding (file object, filename, content type or content transfer
    encoding), and a urlencoded bytes payload otherwise.
    """

    def __init__(self, fields:
                 Iterable[Any]=(),
                 quote_fields: bool=True,
                 charset: Optional[str]=None) -> None:
        self._writer = multipart.MultipartWriter('form-data')
        self._fields = []  # type: List[Any]
        self._is_multipart = False
        self._is_processed = False
        self._quote_fields = quote_fields
        self._charset = charset

        if isinstance(fields, dict):
            fields = list(fields.items())
        elif not isinstance(fields, (list, tuple)):
            # A single record (e.g. a file object or a multidict).
            fields = (fields,)
        self.add_fields(*fields)

    @property
    def is_multipart(self) -> bool:
        """True once any added field requires multipart encoding."""
        return self._is_multipart

    def add_field(self, name: str, value: Any, *,
                  content_type: Optional[str]=None,
                  filename: Optional[str]=None,
                  content_transfer_encoding: Optional[str]=None) -> None:
        """Add one field.

        Raises:
            TypeError: if ``filename``, ``content_type`` or
                ``content_transfer_encoding`` is given but is not a ``str``.
        """
        if isinstance(value, io.IOBase):
            self._is_multipart = True
        elif isinstance(value, (bytes, bytearray, memoryview)):
            if filename is None and content_transfer_encoding is None:
                # Bytes-like values default to the field name as filename.
                filename = name

        type_options = MultiDict({'name': name})  # type: MultiDict[str]
        # NOTE: the offending value is wrapped in a 1-tuple in the messages
        # below so that a tuple value is formatted as a whole instead of
        # being interpreted as the %-format argument list.
        if filename is not None and not isinstance(filename, str):
            raise TypeError('filename must be an instance of str. '
                            'Got: %s' % (filename,))
        if filename is None and isinstance(value, io.IOBase):
            filename = guess_filename(value, name)
        if filename is not None:
            type_options['filename'] = filename
            self._is_multipart = True

        headers = {}
        if content_type is not None:
            if not isinstance(content_type, str):
                raise TypeError('content_type must be an instance of str. '
                                'Got: %s' % (content_type,))
            headers[hdrs.CONTENT_TYPE] = content_type
            self._is_multipart = True
        if content_transfer_encoding is not None:
            if not isinstance(content_transfer_encoding, str):
                raise TypeError('content_transfer_encoding must be an instance'
                                ' of str. Got: %s'
                                % (content_transfer_encoding,))
            headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
            self._is_multipart = True

        self._fields.append((type_options, headers, value))

    def add_fields(self, *fields: Any) -> None:
        """Add several fields.

        Each record may be a file object, a multidict, or a 2-item
        ``(name, value)`` list/tuple.

        Raises:
            TypeError: for any other kind of record.
        """
        to_add = list(fields)

        while to_add:
            rec = to_add.pop(0)

            if isinstance(rec, io.IOBase):
                k = guess_filename(rec, 'unknown')
                self.add_field(k, rec)  # type: ignore

            elif isinstance(rec, (MultiDictProxy, MultiDict)):
                # Queue the multidict's items to be processed as pairs.
                to_add.extend(rec.items())

            elif isinstance(rec, (list, tuple)) and len(rec) == 2:
                k, fp = rec
                self.add_field(k, fp)  # type: ignore

            else:
                raise TypeError('Only io.IOBase, multidict and (name, file) '
                                'pairs allowed, use .add_field() for passing '
                                'more complex parameters, got {!r}'
                                .format(rec))

    def _gen_form_urlencoded(self) -> payload.BytesPayload:
        """Encode the fields as application/x-www-form-urlencoded."""
        # form data (x-www-form-urlencoded)
        data = []
        for type_options, _, value in self._fields:
            data.append((type_options['name'], value))

        charset = self._charset if self._charset is not None else 'utf-8'

        if charset == 'utf-8':
            content_type = 'application/x-www-form-urlencoded'
        else:
            content_type = ('application/x-www-form-urlencoded; '
                            'charset=%s' % charset)

        return payload.BytesPayload(
            urlencode(data, doseq=True, encoding=charset).encode(),
            content_type=content_type)

    def _gen_form_data(self) -> multipart.MultipartWriter:
        """Encode a list of fields using the multipart/form-data MIME format

        Raises:
            RuntimeError: if the form was already processed once.
            TypeError: if a value cannot be converted into a payload.
        """
        if self._is_processed:
            raise RuntimeError('Form data has been processed already')
        for dispparams, headers, value in self._fields:
            try:
                if hdrs.CONTENT_TYPE in headers:
                    part = payload.get_payload(
                        value, content_type=headers[hdrs.CONTENT_TYPE],
                        headers=headers, encoding=self._charset)
                else:
                    part = payload.get_payload(
                        value, headers=headers, encoding=self._charset)
            except Exception as exc:
                raise TypeError(
                    'Can not serialize value type: %r\n '
                    'headers: %r\n value: %r' % (
                        type(value), headers, value)) from exc

            if dispparams:
                part.set_content_disposition(
                    'form-data', quote_fields=self._quote_fields, **dispparams
                )
                # FIXME cgi.FieldStorage doesn't likes body parts with
                # Content-Length which were sent via chunked transfer encoding
                assert part.headers is not None
                part.headers.popall(hdrs.CONTENT_LENGTH, None)

            self._writer.append_payload(part)

        self._is_processed = True
        return self._writer

    def __call__(self) -> Payload:
        """Build the payload in the encoding required by the fields."""
        if self._is_multipart:
            return self._gen_form_data()
        else:
            return self._gen_form_urlencoded()

View File

@ -0,0 +1,72 @@
from collections.abc import MutableSequence
from functools import total_ordering
from .helpers import NO_EXTENSIONS
@total_ordering
class FrozenList(MutableSequence):
    """A list-like sequence that can be made read-only with ``freeze()``.

    After ``freeze()`` every mutating operation raises ``RuntimeError``.
    Comparisons are delegated to a plain ``list`` copy of the items.
    """

    __slots__ = ('_frozen', '_items')

    def __init__(self, items=None):
        self._frozen = False
        # Always keep a private list copy of the given items.
        self._items = [] if items is None else list(items)

    @property
    def frozen(self):
        """True once ``freeze()`` has been called."""
        return self._frozen

    def freeze(self):
        """Forbid all further modifications."""
        self._frozen = True

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        if self._frozen:
            raise RuntimeError("Cannot modify frozen list.")
        self._items[index] = value

    def __delitem__(self, index):
        if self._frozen:
            raise RuntimeError("Cannot modify frozen list.")
        del self._items[index]

    def __len__(self):
        return len(self._items)

    def __iter__(self):
        return iter(self._items)

    def __reversed__(self):
        return reversed(self._items)

    def __eq__(self, other):
        as_list = list(self)
        return as_list == other

    def __le__(self, other):
        as_list = list(self)
        return as_list <= other

    def insert(self, pos, item):
        if self._frozen:
            raise RuntimeError("Cannot modify frozen list.")
        self._items.insert(pos, item)

    def __repr__(self):
        template = '<FrozenList(frozen={}, {!r})>'
        return template.format(self._frozen, self._items)
# Keep a reference to the pure-Python implementation under its own name.
PyFrozenList = FrozenList

# Replace FrozenList with the C implementation when it can be imported and
# extensions are not disabled through NO_EXTENSIONS.
try:
    from hyper_internal_service._frozenlist import FrozenList as CFrozenList  # type: ignore
except ImportError:  # pragma: no cover
    pass
else:
    if not NO_EXTENSIONS:
        FrozenList = CFrozenList  # type: ignore

View File

@ -0,0 +1,63 @@
from typing import (
Generic,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
TypeVar,
Union,
overload,
)
_T = TypeVar('_T')
# Accepted constructor argument: a list or any other iterable of items.
_Arg = Union[List[_T], Iterable[_T]]


class FrozenList(MutableSequence[_T], Generic[_T]):
    """Type stub for the freezable list."""

    def __init__(self, items: Optional[_Arg[_T]]=...) -> None: ...

    @property
    def frozen(self) -> bool: ...

    def freeze(self) -> None: ...

    @overload
    def __getitem__(self, i: int) -> _T: ...

    @overload
    def __getitem__(self, s: slice) -> FrozenList[_T]: ...

    @overload
    def __setitem__(self, i: int, o: _T) -> None: ...

    @overload
    def __setitem__(self, s: slice, o: Iterable[_T]) -> None: ...

    @overload
    def __delitem__(self, i: int) -> None: ...

    @overload
    def __delitem__(self, i: slice) -> None: ...

    def __len__(self) -> int: ...

    def __iter__(self) -> Iterator[_T]: ...

    def __reversed__(self) -> Iterator[_T]: ...

    def __eq__(self, other: object) -> bool: ...

    def __le__(self, other: FrozenList[_T]) -> bool: ...

    def __ne__(self, other: object) -> bool: ...

    def __lt__(self, other: FrozenList[_T]) -> bool: ...

    def __ge__(self, other: FrozenList[_T]) -> bool: ...

    def __gt__(self, other: FrozenList[_T]) -> bool: ...

    def insert(self, pos: int, item: _T) -> None: ...

    def __repr__(self) -> str: ...


# types for C accelerators are the same
CFrozenList = PyFrozenList = FrozenList

View File

@ -0,0 +1,101 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
from multidict import istr
# HTTP request method names as plain strings.
# METH_ANY ('*') is presumably a wildcard marker; it is not a member of
# METH_ALL below.
METH_ANY = '*'
METH_CONNECT = 'CONNECT'
METH_HEAD = 'HEAD'
METH_GET = 'GET'
METH_DELETE = 'DELETE'
METH_OPTIONS = 'OPTIONS'
METH_PATCH = 'PATCH'
METH_POST = 'POST'
METH_PUT = 'PUT'
METH_TRACE = 'TRACE'
# Set of every concrete method constant defined above (METH_ANY excluded).
METH_ALL = {METH_CONNECT, METH_HEAD, METH_GET, METH_DELETE,
            METH_OPTIONS, METH_PATCH, METH_POST, METH_PUT, METH_TRACE}
# HTTP header field names, each wrapped in multidict's ``istr`` type.
# Per the note at the top of this file, ./tools/gen.py must be re-run
# after this list changes so the generated headers parser stays in sync.
ACCEPT = istr('Accept')
ACCEPT_CHARSET = istr('Accept-Charset')
ACCEPT_ENCODING = istr('Accept-Encoding')
ACCEPT_LANGUAGE = istr('Accept-Language')
ACCEPT_RANGES = istr('Accept-Ranges')
ACCESS_CONTROL_MAX_AGE = istr('Access-Control-Max-Age')
ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('Access-Control-Allow-Credentials')
ACCESS_CONTROL_ALLOW_HEADERS = istr('Access-Control-Allow-Headers')
ACCESS_CONTROL_ALLOW_METHODS = istr('Access-Control-Allow-Methods')
ACCESS_CONTROL_ALLOW_ORIGIN = istr('Access-Control-Allow-Origin')
ACCESS_CONTROL_EXPOSE_HEADERS = istr('Access-Control-Expose-Headers')
ACCESS_CONTROL_REQUEST_HEADERS = istr('Access-Control-Request-Headers')
ACCESS_CONTROL_REQUEST_METHOD = istr('Access-Control-Request-Method')
AGE = istr('Age')
ALLOW = istr('Allow')
AUTHORIZATION = istr('Authorization')
CACHE_CONTROL = istr('Cache-Control')
CONNECTION = istr('Connection')
CONTENT_DISPOSITION = istr('Content-Disposition')
CONTENT_ENCODING = istr('Content-Encoding')
CONTENT_LANGUAGE = istr('Content-Language')
CONTENT_LENGTH = istr('Content-Length')
CONTENT_LOCATION = istr('Content-Location')
CONTENT_MD5 = istr('Content-MD5')
CONTENT_RANGE = istr('Content-Range')
CONTENT_TRANSFER_ENCODING = istr('Content-Transfer-Encoding')
CONTENT_TYPE = istr('Content-Type')
COOKIE = istr('Cookie')
DATE = istr('Date')
DESTINATION = istr('Destination')
DIGEST = istr('Digest')
ETAG = istr('Etag')
EXPECT = istr('Expect')
EXPIRES = istr('Expires')
FORWARDED = istr('Forwarded')
FROM = istr('From')
HOST = istr('Host')
IF_MATCH = istr('If-Match')
IF_MODIFIED_SINCE = istr('If-Modified-Since')
IF_NONE_MATCH = istr('If-None-Match')
IF_RANGE = istr('If-Range')
IF_UNMODIFIED_SINCE = istr('If-Unmodified-Since')
KEEP_ALIVE = istr('Keep-Alive')
LAST_EVENT_ID = istr('Last-Event-ID')
LAST_MODIFIED = istr('Last-Modified')
LINK = istr('Link')
LOCATION = istr('Location')
MAX_FORWARDS = istr('Max-Forwards')
ORIGIN = istr('Origin')
PRAGMA = istr('Pragma')
PROXY_AUTHENTICATE = istr('Proxy-Authenticate')
PROXY_AUTHORIZATION = istr('Proxy-Authorization')
RANGE = istr('Range')
REFERER = istr('Referer')
RETRY_AFTER = istr('Retry-After')
SEC_WEBSOCKET_ACCEPT = istr('Sec-WebSocket-Accept')
SEC_WEBSOCKET_VERSION = istr('Sec-WebSocket-Version')
SEC_WEBSOCKET_PROTOCOL = istr('Sec-WebSocket-Protocol')
SEC_WEBSOCKET_EXTENSIONS = istr('Sec-WebSocket-Extensions')
SEC_WEBSOCKET_KEY = istr('Sec-WebSocket-Key')
SEC_WEBSOCKET_KEY1 = istr('Sec-WebSocket-Key1')
SERVER = istr('Server')
SET_COOKIE = istr('Set-Cookie')
TE = istr('TE')
TRAILER = istr('Trailer')
TRANSFER_ENCODING = istr('Transfer-Encoding')
UPGRADE = istr('Upgrade')
WEBSOCKET = istr('WebSocket')
URI = istr('URI')
USER_AGENT = istr('User-Agent')
VARY = istr('Vary')
VIA = istr('Via')
WANT_DIGEST = istr('Want-Digest')
WARNING = istr('Warning')
WWW_AUTHENTICATE = istr('WWW-Authenticate')
X_POWERED_BY = istr('X-Powered-By')
X_FORWARDED_FOR = istr('X-Forwarded-For')
X_FORWARDED_HOST = istr('X-Forwarded-Host')
X_FORWARDED_PROTO = istr('X-Forwarded-Proto')

View File

@ -0,0 +1,702 @@
"""Various helper functions"""
import asyncio
import base64
import binascii
import cgi
import datetime
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
from math import ceil
from pathlib import Path
from types import TracebackType
from typing import ( # noqa
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
Optional,
Pattern,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import quote
from urllib.request import getproxies
import async_timeout
import attr
from multidict import MultiDict, MultiDictProxy
from yarl import URL
from . import hdrs
from .log import client_logger, internal_logger
from .typedefs import PathLike # noqa
__all__ = ('BasicAuth', 'ChainMapProxy')
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)
PY_38 = sys.version_info >= (3, 8)
if not PY_37:
import idna_ssl
idna_ssl.patch_match_hostname()
try:
from typing import ContextManager
except ImportError:
from typing_extensions import ContextManager
def all_tasks(
        loop: Optional[asyncio.AbstractEventLoop] = None
) -> Set['asyncio.Task[Any]']:
    """Return the tasks known for ``loop`` that have not finished yet.

    Fallback built on ``asyncio.Task.all_tasks``; it is replaced right
    below by ``asyncio.all_tasks`` whenever PY_37 is true.
    """
    pending = set()
    for task in list(asyncio.Task.all_tasks(loop)):
        if not task.done():
            pending.add(task)
    return pending


if PY_37:
    all_tasks = getattr(asyncio, 'all_tasks')  # noqa
_T = TypeVar('_T')
# Unique placeholder object, distinguishable from None and any real value.
sentinel = object()  # type: Any
# True when the AIOHTTP_NO_EXTENSIONS environment variable holds a
# non-empty string.
NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS'))  # type: bool
# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
# DEBUG is on in dev mode, or when PYTHONASYNCIODEBUG is set and the
# interpreter was not told to ignore environment variables.
DEBUG = (getattr(sys.flags, 'dev_mode', False) or
         (not sys.flags.ignore_environment and
          bool(os.environ.get('PYTHONASYNCIODEBUG'))))  # type: bool
# Character classes from the HTTP/1.1 basic rules (RFC 2616, section 2.2).
# CHAR: every 7-bit US-ASCII character.
CHAR = set(chr(i) for i in range(0, 128))
# CTL: control characters 0-31 plus DEL (127).
CTL = set(chr(i) for i in range(0, 32)) | {chr(127), }
SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']',
              '?', '=', '{', '}', ' ', chr(9)}
# token = 1*<any CHAR except CTLs or separators>.
# Plain set difference is required: HT (chr(9)) belongs to both CTL and
# SEPARATORS, so chaining symmetric differences (``^``) would remove it
# and then add it straight back, letting a tab pass as a token character.
TOKEN = CHAR - CTL - SEPARATORS
class noop:
    """Awaitable whose ``__await__`` generator yields a single ``None``."""

    def __await__(self) -> Generator[None, None, None]:
        yield None
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
    """Credentials for the HTTP ``Basic`` authentication scheme.

    A named tuple of ``(login, password, encoding)``; ``encoding`` is the
    codec used when the credentials are turned into header bytes.
    """

    def __new__(cls, login: str,
                password: str='',
                encoding: str='latin1') -> 'BasicAuth':
        # None is rejected for both fields before the ':' check, which
        # needs a real string to search in.
        required = (
            (login, 'None is not allowed as login value'),
            (password, 'None is not allowed as password value'),
        )
        for value, complaint in required:
            if value is None:
                raise ValueError(complaint)

        if ':' in login:
            raise ValueError(
                'A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str='latin1') -> 'BasicAuth':
        """Build a BasicAuth from the value of an Authorization header.

        Raises ValueError when the header cannot be split into scheme and
        payload, the scheme is not ``basic`` (case-insensitive), the
        payload is not valid base64, or the decoded text has no ':'.
        """
        try:
            scheme, payload = auth_header.split(' ', 1)
        except ValueError:
            raise ValueError('Could not parse authorization header.')

        if scheme.lower() != 'basic':
            raise ValueError('Unknown authorization method %s' % scheme)

        try:
            raw = base64.b64decode(payload.encode('ascii'), validate=True)
            text = raw.decode(encoding)
        except binascii.Error:
            raise ValueError('Invalid base64 encoding.')

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # The colon is mandatory; either side of it may be empty.
            username, password = text.split(':', 1)
        except ValueError:
            raise ValueError('Invalid credentials.')

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL,
                 *, encoding: str='latin1') -> Optional['BasicAuth']:
        """Build a BasicAuth from the user info of ``url``.

        Returns None when the URL has no user part; raises TypeError when
        ``url`` is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        user = url.user
        if user is None:
            return None
        return cls(user, url.password or '', encoding=encoding)

    def encode(self) -> str:
        """Return the ``Basic <base64>`` value for an Authorization header."""
        plain = '%s:%s' % (self.login, self.password)
        token = base64.b64encode(plain.encode(self.encoding))
        return 'Basic %s' % token.decode(self.encoding)
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split ``url`` into a credential-free URL and its BasicAuth.

    When the URL carries no user info it is returned unchanged together
    with None.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load a netrc file, returning None when that is not possible.

    The path comes from the NETRC environment variable when it is set;
    otherwise it is ``.netrc`` (``_netrc`` on Windows) in the user's home
    directory.  Parse and read failures are logged, not raised.
    """
    netrc_env = os.environ.get('NETRC')

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug('Could not resolve home directory when '
                                'trying to look for .netrc file: %s', e)
            return None

        if platform.system() == 'Windows':
            filename = '_netrc'
        else:
            filename = '.netrc'
        netrc_path = home_dir / filename
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning('Could not parse .netrc file: %s', e)
    except OSError as e:
        # The file could not be read (missing, permissions, ...).  Warn
        # only when the environment asked for it explicitly, or when the
        # default file does appear to exist.
        if netrc_env or netrc_path.is_file():
            client_logger.warning('Could not read .netrc file: %s', e)

    return None
@attr.s(frozen=True, slots=True)
class ProxyInfo:
    """Frozen attrs record pairing a proxy URL with optional credentials."""
    # URL of the proxy.
    proxy = attr.ib(type=URL)
    # BasicAuth for the proxy, or None when no credentials are known.
    proxy_auth = attr.ib(type=Optional[BasicAuth])
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect the http/https proxies reported by ``getproxies()``.

    Credentials embedded in a proxy URL are split off into a BasicAuth;
    when a URL has none, the netrc file (if one loads) is consulted for
    the proxy host.  Proxies whose own scheme is ``https`` are skipped
    with a warning.
    """
    env_urls = {}
    for scheme, raw in getproxies().items():
        if scheme in ('http', 'https'):
            env_urls[scheme] = URL(raw)

    netrc_obj = netrc_from_env()
    split = {scheme: strip_auth_from_url(target)
             for scheme, target in env_urls.items()}

    result = {}
    for proto, (proxy, auth) in split.items():
        if proxy.scheme == 'https':
            client_logger.warning(
                "HTTPS proxies %s are not supported, ignoring", proxy)
            continue

        if netrc_obj and auth is None and proxy.host is not None:
            auth_from_netrc = netrc_obj.authenticators(proxy.host)
            if auth_from_netrc is not None:
                # auth_from_netrc is a (`user`, `account`, `password`)
                # tuple; `user` and `account` can both hold the username,
                # so fall back to `account` when `user` is empty.
                *logins, password = auth_from_netrc
                login = logins[0] or logins[-1]
                auth = BasicAuth(cast(str, login), cast(str, password))

        result[proto] = ProxyInfo(proxy, auth)
    return result
def current_task(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.Task:  # type: ignore  # noqa  # Return type is intentionally Generic here
    """Return the task currently running on ``loop``.

    Dispatches on PY_37: ``asyncio.current_task`` when it is true,
    ``asyncio.Task.current_task`` otherwise.
    """
    if not PY_37:
        return asyncio.Task.current_task(loop=loop)
    return asyncio.current_task(loop=loop)  # type: ignore
def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop]=None
) -> asyncio.AbstractEventLoop:
    """Return ``loop``, or the current event loop when none is given.

    If the chosen loop is not running, a DeprecationWarning is issued
    (and, when the loop is in debug mode, a warning with stack info is
    logged); the loop is returned either way.
    """
    chosen = asyncio.get_event_loop() if loop is None else loop
    if chosen.is_running():
        return chosen

    warnings.warn("The object should be created within an async function",
                  DeprecationWarning, stacklevel=3)
    if chosen.get_debug():
        internal_logger.warning(
            "The object should be created within an async function",
            stack_info=True)
    return chosen
def isasyncgenfunction(obj: Any) -> bool:
    """Tell whether ``obj`` is an async generator function.

    Uses ``inspect.isasyncgenfunction`` when the running interpreter
    provides it and answers False otherwise.
    """
    checker = getattr(inspect, 'isasyncgenfunction', None)
    if checker is None:
        return False
    return checker(obj)
@attr.s(frozen=True, slots=True)
class MimeType:
    """Frozen attrs record holding the components of a MIME type string."""
    # Part before the '/' (e.g. 'text' in 'text/html').
    type = attr.ib(type=str)
    # Part after the '/', without any '+suffix'.
    subtype = attr.ib(type=str)
    # Text following '+' in the subtype, or '' when there is none.
    suffix = attr.ib(type=str)
    # Read-only multidict of the ';'-separated parameters.
    parameters = attr.ib(type=MultiDictProxy)  # type: MultiDictProxy[str]
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Break a MIME type string into type, subtype, suffix and parameters.

    Results are memoised with an LRU cache.  An empty input yields a
    MimeType whose fields are all empty; a bare ``*`` is read as ``*/*``.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(type='', subtype='', suffix='',
                        parameters=MultiDictProxy(MultiDict()))

    head, *raw_params = mimetype.split(';')

    params = MultiDict()  # type: MultiDict[str]
    for item in raw_params:
        if not item:
            continue
        # A parameter without '=' is stored with an empty value.
        key, _, value = item.partition('=')
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == '*':
        fulltype = '*/*'

    mtype, _, remainder = fulltype.partition('/')
    stype, _, suffix = remainder.partition('+')

    return MimeType(type=mtype, subtype=stype, suffix=suffix,
                    parameters=MultiDictProxy(params))
def guess_filename(obj: Any, default: Optional[str]=None) -> Optional[str]:
    """Derive a file name from ``obj.name``, else return ``default``.

    The name must be a non-empty string that neither starts with '<' nor
    ends with '>'; only its final path component is returned.
    """
    name = getattr(obj, 'name', None)
    if not name or not isinstance(name, str):
        return default
    if name[0] == '<' or name[-1] == '>':
        return default
    return Path(name).name
def content_disposition_header(disptype: str,
                               quote_fields: bool=True,
                               **params: str) -> str:
    """Build the value of a ``Content-Disposition`` header.

    disptype is a disposition type: inline, attachment, form-data.
    It must consist of TOKEN characters only (see RFC 2183).

    params holds the disposition parameters; values are percent-quoted
    unless quote_fields is false, and a ``filename`` parameter is
    followed by a matching ``filename*`` entry.

    Raises ValueError for an empty or non-token type or parameter name.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError('bad content disposition type {!r}'
                         ''.format(disptype))

    if not params:
        return disptype

    rendered = []
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError('bad content disposition parameter'
                             ' {!r}={!r}'.format(key, val))
        qval = quote(val, '') if quote_fields else val
        rendered.append('='.join((key, '"%s"' % qval)))
        if key == 'filename':
            rendered.append('='.join(('filename*', "utf-8''" + qval)))

    return '; '.join((disptype, '; '.join(rendered)))
class reify:
    """Caching, read-only replacement for ``@property``.

    The wrapped method runs on first access; its result is stored in the
    instance's ``_cache`` dict under the method name and served from
    there afterwards.  Assignment is refused, which makes this a data
    descriptor.
    """

    def __init__(self, wrapped: Callable[..., Any]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: Any, owner: Any) -> Any:
        # Access through the class (no instance) hands back the descriptor.
        if inst is None:
            return self
        try:
            return inst._cache[self.name]
        except KeyError:
            val = self.wrapped(inst)
            inst._cache[self.name] = val
            return val

    def __set__(self, inst: Any, value: Any) -> None:
        raise AttributeError("reified property is read-only")
reify_py = reify
try:
from ._helpers import reify as reify_c
if not NO_EXTENSIONS:
reify = reify_c # type: ignore
except ImportError:
pass
# Anchored patterns for dotted-quad IPv4 and textual IPv6 addresses.
_ipv4_pattern = (r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
                 r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$')
_ipv6_pattern = (
    r'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
    r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
    r'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
    r'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
    r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
    r'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
    r'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
    r':|:(:[A-F0-9]{1,4}){7})$')
# Each pattern is compiled twice: once for str input, once for bytes.
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)


def _is_ip_address(
        regex: Pattern[str], regexb: Pattern[bytes],
        host: Optional[Union[str, bytes]]) -> bool:
    """Match ``host`` against the pattern that fits its type.

    None is never an address; str goes to ``regex``; bytes, bytearray
    and memoryview go to ``regexb``; anything else raises TypeError.
    """
    if host is None:
        return False
    if isinstance(host, str):
        matcher = regex
    elif isinstance(host, (bytes, bytearray, memoryview)):
        matcher = regexb
    else:
        raise TypeError("{} [{}] is not a str or bytes"
                        .format(host, type(host)))
    return bool(matcher.match(host))


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(
        host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True when ``host`` is an IPv4 or an IPv6 address."""
    if is_ipv4_address(host):
        return True
    return is_ipv6_address(host)
def next_whole_second() -> datetime.datetime:
    """Return the current UTC time with its microseconds cleared.

    NOTE(review): the name suggests rounding *up*, but the offset added
    below is zero seconds, so the value is truncated to the current
    whole second.  Confirm which behaviour callers rely on before
    changing the offset.
    """
    utc_now = datetime.datetime.now(datetime.timezone.utc)
    offset = datetime.timedelta(seconds=0)
    return utc_now.replace(microsecond=0) + offset
# One-entry cache: the whole second last formatted, and its rendering.
_cached_current_datetime = None  # type: Optional[int]
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time as an HTTP date string.

    The string is rebuilt only when the integer second changes; calls
    within the same second reuse the cached text.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    _monthname = ("",  # Dummy so we can use 1-based month numbers
                  "Jan", "Feb", "Mar", "Apr", "May", "Jun",
                  "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        _weekdayname[stamp.tm_wday], stamp.tm_mday,
        _monthname[stamp.tm_mon], stamp.tm_year,
        stamp.tm_hour, stamp.tm_min, stamp.tm_sec
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
def _weakref_handle(info):  # type: ignore
    """Call the named method on a weakly referenced object, if still alive.

    ``info`` is a ``(weakref, method_name)`` pair.  Exceptions raised by
    the method are suppressed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()


def weakref_handle(ob, name, timeout, loop, ceil_timeout=True):  # type: ignore
    """Schedule ``ob.<name>()`` on ``loop`` after ``timeout`` seconds.

    Only a weak reference to ``ob`` is handed to the loop.  The deadline
    is rounded up to a whole number unless ``ceil_timeout`` is false.
    Returns the loop's handle, or None when ``timeout`` is None or not
    positive.
    """
    if timeout is None or not timeout > 0:
        return None
    when = loop.time() + timeout
    if ceil_timeout:
        when = ceil(when)
    payload = (weakref.ref(ob), name)
    return loop.call_at(when, _weakref_handle, payload)
def call_later(cb, timeout, loop):  # type: ignore
    """Schedule ``cb`` on ``loop`` at ``now + timeout``, rounded up.

    Returns the loop's handle, or None when ``timeout`` is None or not
    positive.
    """
    if timeout is None or not timeout > 0:
        return None
    deadline = ceil(loop.time() + timeout)
    return loop.call_at(deadline, cb)
class TimeoutHandle:
    """Registry of callbacks that are fired together when a timeout hits."""

    def __init__(self,
                 loop: asyncio.AbstractEventLoop,
                 timeout: Optional[float]) -> None:
        self._timeout = timeout
        self._loop = loop
        self._callbacks = []  # type: List[Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]]  # noqa

    def register(self, callback: Callable[..., None],
                 *args: Any, **kwargs: Any) -> None:
        """Remember ``callback`` together with the arguments to call it with."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Schedule this handle on the loop; None when no timeout is set.

        The deadline is ``loop.time() + timeout`` rounded up.
        """
        delay = self._timeout
        if delay is not None and delay > 0:
            deadline = ceil(self._loop.time() + delay)
            return self._loop.call_at(deadline, self.__call__)
        return None

    def timer(self) -> 'BaseTimerContext':
        """Return a timer context tied to this handle.

        A TimerContext registered for the timeout when one is set,
        otherwise a TimerNoop.
        """
        if self._timeout is not None and self._timeout > 0:
            context = TimerContext(self._loop)
            self.register(context.timeout)
            return context
        return TimerNoop()

    def __call__(self) -> None:
        # Run every callback, ignoring their exceptions, then drop them.
        for callback, args, kwargs in self._callbacks:
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()
class BaseTimerContext(ContextManager['BaseTimerContext']):
    """Common base type of the timer context managers defined here."""


class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> None:
        # Nothing to release; exceptions propagate unchanged.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    Tracks the tasks currently inside the context; ``timeout()`` cancels
    them, and the resulting CancelledError is turned into
    ``asyncio.TimeoutError`` on exit.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks = []  # type: List[asyncio.Task[Any]]
        self._cancelled = False

    def __enter__(self) -> BaseTimerContext:
        task = current_task(loop=self._loop)

        if task is None:
            raise RuntimeError('Timeout context manager should be used '
                               'inside a task')

        # The timeout already fired: cancel the entering task right away.
        if self._cancelled:
            task.cancel()
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> Optional[bool]:
        if self._tasks:
            self._tasks.pop()

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if timed_out:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel every tracked task once and mark the context as fired."""
        if self._cancelled:
            return
        for task in set(self._tasks):
            task.cancel()
        self._cancelled = True
class CeilTimeout(async_timeout.timeout):
    """``async_timeout.timeout`` whose deadline is rounded up with ``ceil``."""

    def __enter__(self) -> async_timeout.timeout:
        # Without a timeout there is nothing to schedule.
        if self._timeout is None:
            return self

        self._task = current_task(loop=self._loop)
        if self._task is None:
            raise RuntimeError(
                'Timeout context manager should be used inside a task')

        deadline = ceil(self._loop.time() + self._timeout)
        self._cancel_handler = self._loop.call_at(deadline, self._cancel_task)
        return self
class HeadersMixin:
    """Mixin exposing parsed Content-Type / Content-Length header values.

    Reads headers from ``self._headers``.  The parsed Content-Type is
    cached and re-parsed only when the raw header value changes.
    """

    ATTRS = frozenset([
        '_content_type', '_content_dict', '_stored_content_type'])

    _content_type = None  # type: Optional[str]
    _content_dict = None  # type: Optional[Dict[str, str]]
    # Raw header value the cached fields were parsed from.
    _stored_content_type = sentinel

    def _parse_content_type(self, raw: str) -> None:
        self._stored_content_type = raw
        if raw is not None:
            self._content_type, self._content_dict = cgi.parse_header(raw)
            return
        # default value according to RFC 2616
        self._content_type = 'application/octet-stream'
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        if self._stored_content_type != current:
            self._parse_content_type(current)
        return self._content_type  # type: ignore

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore
        if self._stored_content_type != current:
            self._parse_content_type(current)
        return self._content_dict.get('charset')  # type: ignore

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        value = self._headers.get(hdrs.CONTENT_LENGTH)  # type: ignore
        return None if value is None else int(value)
def set_result(fut: 'asyncio.Future[_T]', result: _T) -> None:
    """Set ``result`` on ``fut`` unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
def set_exception(fut: 'asyncio.Future[_T]', exc: BaseException) -> None:
    """Set ``exc`` on ``fut`` unless the future is already done."""
    if fut.done():
        return
    fut.set_exception(exc)
class ChainMapProxy(Mapping[str, Any]):
    """Read-only view over several mappings, searched in the given order.

    The first mapping containing a key wins.  Subclassing is refused.
    """

    __slots__ = ('_maps',)

    def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError("Inheritance class {} from ChainMapProxy "
                        "is forbidden".format(cls.__name__))

    def __getitem__(self, key: str) -> Any:
        for mapping in self._maps:
            try:
                value = mapping[key]
            except KeyError:
                continue
            return value
        raise KeyError(key)

    def get(self, key: str, default: Any=None) -> Any:
        if key in self:
            return self[key]
        return default

    def __len__(self) -> int:
        # reuses stored hash values if possible
        distinct = set().union(*self._maps)  # type: ignore
        return len(distinct)

    def __iter__(self) -> Iterator[str]:
        merged = {}  # type: Dict[str, Any]
        # Later maps are merged first so earlier maps take precedence.
        for mapping in reversed(self._maps):
            # reuses stored hash values if possible
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        parts = [repr(mapping) for mapping in self._maps]
        return 'ChainMapProxy({})'.format(", ".join(parts))

View File

@ -0,0 +1,50 @@
import http.server
import sys
from typing import Mapping, Tuple # noqa
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import HeadersParser as HeadersParser
from .http_parser import HttpParser as HttpParser
from .http_parser import HttpRequestParser as HttpRequestParser
from .http_parser import HttpResponseParser as HttpResponseParser
from .http_parser import RawRequestMessage as RawRequestMessage
from .http_parser import RawResponseMessage as RawResponseMessage
from .http_websocket import WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE
from .http_websocket import WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE
from .http_websocket import WS_KEY as WS_KEY
from .http_websocket import WebSocketError as WebSocketError
from .http_websocket import WebSocketReader as WebSocketReader
from .http_websocket import WebSocketWriter as WebSocketWriter
from .http_websocket import WSCloseCode as WSCloseCode
from .http_websocket import WSMessage as WSMessage
from .http_websocket import WSMsgType as WSMsgType
from .http_websocket import ws_ext_gen as ws_ext_gen
from .http_websocket import ws_ext_parse as ws_ext_parse
from .http_writer import HttpVersion as HttpVersion
from .http_writer import HttpVersion10 as HttpVersion10
from .http_writer import HttpVersion11 as HttpVersion11
from .http_writer import StreamWriter as StreamWriter
__all__ = (
'HttpProcessingError', 'RESPONSES', 'SERVER_SOFTWARE',
# .http_writer
'StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
# .http_parser
'HeadersParser', 'HttpParser',
'HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage',
# .http_websocket
'WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
'WebSocketReader', 'WebSocketWriter', 'ws_ext_gen', 'ws_ext_parse',
'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode',
)
# Identification string built from the interpreter's major.minor version
# and this package's __version__.
SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} hyper_internal_service/{1}'.format(
    sys.version_info, __version__)  # type: str
# Status-code table taken from the standard library's
# BaseHTTPRequestHandler.responses.
RESPONSES = http.server.BaseHTTPRequestHandler.responses  # type: Mapping[int, Tuple[str, str]]  # noqa

View File

@ -0,0 +1,108 @@
"""Low-level http related exceptions."""
from typing import Optional, Union
from .typedefs import _CIMultiDict
__all__ = ('HttpProcessingError',)
class HttpProcessingError(Exception):
    """HTTP error.

    Convenience exception for raising an HTTP error that carries a
    status code, a message and optional headers.

    code: HTTP Error code.
    message: (optional) Error message.
    headers: (optional) Headers to be sent in response, a list of pairs
    """

    # Class-level defaults; ``code`` stays at the class value unless a
    # code is passed to the constructor.
    code = 0
    message = ''
    headers = None

    def __init__(self, *,
                 code: Optional[int]=None,
                 message: str='',
                 headers: Optional[_CIMultiDict]=None) -> None:
        self.message = message
        self.headers = headers
        if code is not None:
            self.code = code

    def __str__(self) -> str:
        return "%s, message=%r" % (self.code, self.message)

    def __repr__(self) -> str:
        return "<%s: %s>" % (type(self).__name__, self)
class BadHttpMessage(HttpProcessingError):
    """HttpProcessingError with status code 400 and a mandatory message."""

    code = 400
    message = 'Bad Request'

    def __init__(self, message: str, *,
                 headers: Optional[_CIMultiDict]=None) -> None:
        super().__init__(headers=headers, message=message)
        # Expose the message through ``args`` as well.
        self.args = (message,)
class HttpBadRequest(BadHttpMessage):
    """BadHttpMessage that restates code 400 and the 'Bad Request' message."""
    code = 400
    message = 'Bad Request'


class PayloadEncodingError(BadHttpMessage):
    """Base class for payload errors"""


class ContentEncodingError(PayloadEncodingError):
    """Content encoding error."""


class TransferEncodingError(PayloadEncodingError):
    """Transfer encoding error."""


class ContentLengthError(PayloadEncodingError):
    """Not enough data to satisfy the Content-Length header."""
class LineTooLong(BadHttpMessage):
    """Raised when a piece of the message exceeds its size limit.

    ``args`` is set to ``(line, limit, actual_size)``.
    """

    def __init__(self, line: str,
                 limit: str='Unknown',
                 actual_size: str='Unknown') -> None:
        text = "Got more than %s bytes (%s) when reading %s." % (
            limit, actual_size, line)
        super().__init__(text)
        self.args = (line, limit, actual_size)
class InvalidHeader(BadHttpMessage):
    """Raised for a header that is rejected as invalid.

    A bytes header is decoded as UTF-8 with ``surrogateescape`` first;
    the text form is kept in ``hdr`` and in ``args``.
    """

    def __init__(self, hdr: Union[bytes, str]) -> None:
        if isinstance(hdr, bytes):
            text = hdr.decode('utf-8', 'surrogateescape')
        else:
            text = hdr
        super().__init__('Invalid HTTP Header: {}'.format(text))
        self.hdr = text
        self.args = (text,)
class BadStatusLine(BadHttpMessage):
    """Raised for a malformed status/request line.

    The parent constructor is not called; ``args`` and ``line`` hold the
    offending line (its ``repr`` when it is not a str).  String
    conversion is the plain ``Exception`` behaviour.
    """

    def __init__(self, line: str='') -> None:
        text = line if isinstance(line, str) else repr(line)
        self.line = text
        self.args = (text,)

    __str__ = Exception.__str__
    __repr__ = Exception.__repr__
class InvalidURLError(BadHttpMessage):
    """BadHttpMessage subtype for invalid URLs; adds no behaviour."""

View File

@ -0,0 +1,776 @@
import abc
import asyncio
import collections
import re
import string
import zlib
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union # noqa
from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL
from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadStatusLine,
ContentEncodingError,
ContentLengthError,
InvalidHeader,
LineTooLong,
TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders
try:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
__all__ = (
'HeadersParser', 'HttpParser', 'HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
# Characters from string.printable, as a set.
ASCIISET = set(string.printable)

# See https://tools.ietf.org/html/rfc7230#section-3.1.1
#     and https://tools.ietf.org/html/rfc7230#appendix-B
#
#     method = token
#     tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
#             "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
#     token = 1*tchar
METHRE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
# HTTP-version: digits, a literal dot, digits.  The dot must be escaped;
# unescaped it matches any character, so e.g. 'HTTP/1x1' would be
# accepted as a version.
VERSRE = re.compile(r'HTTP/(\d+)\.(\d+)')
# Bytes that may not appear in a header field name.
HDRRE = re.compile(rb'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')
# Named-tuple record for a parsed request head.  Compared with the
# response record it carries 'method', 'path' and 'url' instead of
# 'code' and 'reason'.
RawRequestMessage = collections.namedtuple(
    'RawRequestMessage',
    ['method', 'path', 'version', 'headers', 'raw_headers',
     'should_close', 'compression', 'upgrade', 'chunked', 'url'])
# Named-tuple record for a parsed response head.
RawResponseMessage = collections.namedtuple(
    'RawResponseMessage',
    ['version', 'code', 'reason', 'headers', 'raw_headers',
     'should_close', 'compression', 'upgrade', 'chunked'])
class ParseState(IntEnum):
    # Integer-valued parsing modes.
    # NOTE(review): the code that consumes these values lies outside
    # this excerpt; the names suggest how a message body is delimited
    # (none / by length / chunked / until EOF) -- confirm at the use site.
    PARSE_NONE = 0
    PARSE_LENGTH = 1
    PARSE_CHUNKED = 2
    PARSE_UNTIL_EOF = 3
class ChunkState(IntEnum):
    # Integer-valued states.
    # NOTE(review): presumably the steps of decoding a chunked body
    # (size line, chunk data, chunk terminator, optional trailers);
    # the consumer is outside this excerpt -- confirm at the use site.
    PARSE_CHUNKED_SIZE = 0
    PARSE_CHUNKED_CHUNK = 1
    PARSE_CHUNKED_CHUNK_EOF = 2
    PARSE_MAYBE_TRAILERS = 3
    PARSE_TRAILERS = 4
class HeadersParser:
    """Parser for the header lines of an HTTP message.

    All three limits are stored on the instance; ``parse_headers``
    itself only enforces ``max_field_size``.
    """

    def __init__(self,
                 max_line_size: int=8190,
                 max_headers: int=32768,
                 max_field_size: int=8190) -> None:
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size

    def parse_headers(
            self,
            lines: List[bytes]
    ) -> Tuple['CIMultiDictProxy[str]', RawHeaders]:
        """Turn raw header lines into a multidict plus the raw pairs.

        ``lines[0]`` is skipped (presumably the start line); parsing
        begins at ``lines[1]`` and stops at the first empty line, so the
        list must contain one.

        Returns ``(CIMultiDictProxy, tuple of (name, value) bytes pairs)``.

        Raises InvalidHeader for a line without ':' or a name containing
        a forbidden byte, and LineTooLong when a name or a (folded)
        value exceeds ``max_field_size``.
        """
        headers = CIMultiDict()  # type: CIMultiDict[str]
        raw_headers = []

        lines_idx = 1
        line = lines[1]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b':', 1)
            except ValueError:
                raise InvalidHeader(line) from None

            bname = bname.strip(b' \t')
            bvalue = bvalue.lstrip()
            if HDRRE.search(bname):
                raise InvalidHeader(bname)
            if len(bname) > self.max_field_size:
                # "backslashreplace" works when decoding; the previous
                # "xmlcharrefreplace" handler is encode-only and raised
                # TypeError for names that are not valid UTF-8, masking
                # the LineTooLong error.
                raise LineTooLong(
                    "request header name {}".format(
                        bname.decode("utf8", "backslashreplace")),
                    str(self.max_field_size),
                    str(len(bname)))

            header_length = len(bvalue)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            continuation = line and line[0] in (32, 9)  # (' ', '\t')

            if continuation:
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        raise LineTooLong(
                            'request header field {}'.format(
                                bname.decode("utf8", "backslashreplace")),
                            str(self.max_field_size),
                            str(header_length))
                    bvalue_lst.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx < line_count:
                        line = lines[lines_idx]
                        if line:
                            continuation = line[0] in (32, 9)  # (' ', '\t')
                    else:
                        line = b''
                        break
                bvalue = b''.join(bvalue_lst)
            else:
                if header_length > self.max_field_size:
                    raise LineTooLong(
                        'request header field {}'.format(
                            bname.decode("utf8", "backslashreplace")),
                        str(self.max_field_size),
                        str(header_length))

            bvalue = bvalue.strip()
            # surrogateescape keeps undecodable bytes round-trippable.
            name = bname.decode('utf-8', 'surrogateescape')
            value = bvalue.decode('utf-8', 'surrogateescape')

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))
class HttpParser(abc.ABC):
def __init__(self, protocol: Optional[BaseProtocol]=None,
             loop: Optional[asyncio.AbstractEventLoop]=None,
             max_line_size: int=8190,
             max_headers: int=32768,
             max_field_size: int=8190,
             timer: Optional[BaseTimerContext]=None,
             code: Optional[int]=None,
             method: Optional[str]=None,
             readall: bool=False,
             payload_exception: Optional[Type[BaseException]]=None,
             response_with_body: bool=True,
             read_until_eof: bool=False,
             auto_decompress: bool=True) -> None:
    """Store the parser configuration and reset the parsing state.

    Every argument is kept on the instance under the same name
    (``auto_decompress`` as ``_auto_decompress``).  The three size
    limits are also handed to the HeadersParser created here.
    """
    self.protocol = protocol
    self.loop = loop
    self.max_line_size = max_line_size
    self.max_headers = max_headers
    self.max_field_size = max_field_size
    self.timer = timer
    self.code = code
    self.method = method
    self.readall = readall
    self.payload_exception = payload_exception
    self.response_with_body = response_with_body
    self.read_until_eof = read_until_eof

    # Lines collected for the message head currently being read.
    self._lines = []  # type: List[bytes]
    # Bytes received after the last complete line.
    self._tail = b''
    self._upgraded = False
    self._payload = None
    # Set while a message body is being parsed.
    self._payload_parser = None  # type: Optional[HttpPayloadParser]
    self._auto_decompress = auto_decompress
    self._headers_parser = HeadersParser(max_line_size,
                                         max_headers,
                                         max_field_size)
@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> Any:
    """Parse the collected head lines into a message object.

    Abstract: concrete parsers must override this.
    """
    pass
def feed_eof(self) -> Any:
    """Handle end of input.

    While a payload is being parsed, EOF is forwarded to the payload
    parser, which is then dropped.  Otherwise any buffered lines (plus
    the unterminated tail) are parsed as a partial message on a
    best-effort basis: the parsed message is returned, or None when
    there is nothing buffered or parsing fails.
    """
    if self._payload_parser is not None:
        self._payload_parser.feed_eof()
        self._payload_parser = None
    else:
        # try to extract partial message
        if self._tail:
            self._lines.append(self._tail)

        if self._lines:
            # _lines holds bytes, so the comparison must use a bytes
            # literal; comparing against the str '\r\n' is never equal
            # and triggers BytesWarning under ``python -b``.
            if self._lines[-1] != b'\r\n':
                self._lines.append(b'')
            try:
                return self.parse_message(self._lines)
            except Exception:
                # Deliberate best-effort: a truncated head is not an error.
                return None
def feed_data(
self,
data: bytes,
SEP: bytes=b'\r\n',
EMPTY: bytes=b'',
CONTENT_LENGTH: istr=hdrs.CONTENT_LENGTH,
METH_CONNECT: str=hdrs.METH_CONNECT,
SEC_WEBSOCKET_KEY1: istr=hdrs.SEC_WEBSOCKET_KEY1
) -> Tuple[List[Any], bool, bytes]:
messages = []
if self._tail:
data, self._tail = self._tail + data, b''
data_len = len(data)
start_pos = 0
loop = self.loop
while start_pos < data_len:
# read HTTP message (request/response line + headers), \r\n\r\n
# and split by lines
if self._payload_parser is None and not self._upgraded:
pos = data.find(SEP, start_pos)
# consume \r\n
if pos == start_pos and not self._lines:
start_pos = pos + 2
continue
if pos >= start_pos:
# line found
self._lines.append(data[start_pos:pos])
start_pos = pos + 2
# \r\n\r\n found
if self._lines[-1] == EMPTY:
try:
msg = self.parse_message(self._lines)
finally:
self._lines.clear()
# payload length
length = msg.headers.get(CONTENT_LENGTH)
if length is not None:
try:
length = int(length)
except ValueError:
raise InvalidHeader(CONTENT_LENGTH)
if length < 0:
raise InvalidHeader(CONTENT_LENGTH)
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in msg.headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
self._upgraded = msg.upgrade
method = getattr(msg, 'method', self.method)
assert self.protocol is not None
# calculate payload
if ((length is not None and length > 0) or
msg.chunked and not msg.upgrade):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
elif method == METH_CONNECT:
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
self._upgraded = True
self._payload_parser = HttpPayloadParser(
payload, method=msg.method,
compression=msg.compression, readall=True,
auto_decompress=self._auto_decompress)
else:
if (getattr(msg, 'code', 100) >= 199 and
length is None and self.read_until_eof):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD # type: ignore
messages.append((msg, payload))
else:
self._tail = data[start_pos:]
data = EMPTY
break
# no parser, just store
elif self._payload_parser is None and self._upgraded:
assert not self._lines
break
# feed payload
elif data and start_pos < data_len:
assert not self._lines
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(
data[start_pos:])
except BaseException as exc:
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc)))
else:
self._payload_parser.payload.set_exception(exc)
eof = True
data = b''
if eof:
start_pos = 0
data_len = len(data)
self._payload_parser = None
continue
else:
break
if data and start_pos < data_len:
data = data[start_pos:]
else:
data = EMPTY
return messages, self._upgraded, data
def parse_headers(
self,
lines: List[bytes]
) -> Tuple['CIMultiDictProxy[str]',
RawHeaders,
Optional[bool],
Optional[str],
bool,
bool]:
"""Parses RFC 5322 headers from a stream.
Line continuations are supported. Returns list of header name
and value pairs. Header name is in upper case.
"""
headers, raw_headers = self._headers_parser.parse_headers(lines)
close_conn = None
encoding = None
upgrade = False
chunked = False
# keep-alive
conn = headers.get(hdrs.CONNECTION)
if conn:
v = conn.lower()
if v == 'close':
close_conn = True
elif v == 'keep-alive':
close_conn = False
elif v == 'upgrade':
upgrade = True
# encoding
enc = headers.get(hdrs.CONTENT_ENCODING)
if enc:
enc = enc.lower()
if enc in ('gzip', 'deflate', 'br'):
encoding = enc
# chunking
te = headers.get(hdrs.TRANSFER_ENCODING)
if te and 'chunked' in te.lower():
chunked = True
return (headers, raw_headers, close_conn, encoding, upgrade, chunked)
def set_upgraded(self, val: bool) -> None:
"""Set connection upgraded (to websocket) mode.
:param bool val: new state.
"""
self._upgraded = val
class HttpRequestParser(HttpParser):
    """Parser for the request line and headers of an HTTP request.

    ``BadStatusLine`` is raised for a malformed request line and
    ``LineTooLong`` when the request target exceeds ``max_line_size``.
    Produces ``RawRequestMessage`` objects.
    """

    def parse_message(self, lines: List[bytes]) -> Any:
        """Build a ``RawRequestMessage`` from the collected head lines."""
        request_line = lines[0].decode('utf-8', 'surrogateescape')

        # request line: METHOD SP target SP HTTP/x.y
        try:
            method, path, version = request_line.split(None, 2)
        except ValueError:
            raise BadStatusLine(request_line) from None

        path_len = len(path)
        if path_len > self.max_line_size:
            raise LineTooLong(
                'Status line is too long',
                str(self.max_line_size),
                str(path_len))

        # method token
        if not METHRE.match(method):
            raise BadStatusLine(method)

        # protocol version; any failure here is reported as a bad line
        try:
            if not version.startswith('HTTP/'):
                raise BadStatusLine(version)
            major, minor = version[5:].split('.', 1)
            version_o = HttpVersion(int(major), int(minor))
        except Exception:
            raise BadStatusLine(version)

        # headers and the flags derived from them
        parsed = self.parse_headers(lines)
        headers, raw_headers, close, compression, upgrade, chunked = parsed

        if close is None:
            # No explicit Connection header: HTTP/1.0 and older close by
            # default, HTTP/1.1 keeps the connection alive by default.
            close = version_o <= HttpVersion10

        return RawRequestMessage(
            method, path, version_o, headers, raw_headers,
            close, compression, upgrade, chunked, URL(path))
class HttpResponseParser(HttpParser):
    """Parser for the status line and headers of an HTTP response.

    ``BadStatusLine`` is raised for a malformed status line and
    ``LineTooLong`` when the reason phrase exceeds ``max_line_size``.
    Produces ``RawResponseMessage`` objects.
    """

    def parse_message(self, lines: List[bytes]) -> Any:
        """Build a ``RawResponseMessage`` from the collected head lines."""
        line = lines[0].decode('utf-8', 'surrogateescape')
        try:
            version, status = line.split(None, 1)
        except ValueError:
            raise BadStatusLine(line) from None

        # the reason phrase is optional
        try:
            status, reason = status.split(None, 1)
        except ValueError:
            reason = ''

        if len(reason) > self.max_line_size:
            raise LineTooLong(
                'Status line is too long',
                str(self.max_line_size),
                str(len(reason)))

        # version
        match = VERSRE.match(version)
        if match is None:
            raise BadStatusLine(line)
        version_o = HttpVersion(int(match.group(1)), int(match.group(2)))

        # The status code is a three-digit number.  int() alone would also
        # accept "-5", "+200", "2_00" or non-ASCII digits, so require plain
        # ASCII digits before converting.
        if status.strip('0123456789'):
            raise BadStatusLine(line)
        status_i = int(status)

        if status_i > 999:
            raise BadStatusLine(line)

        # read headers
        (headers, raw_headers,
         close, compression, upgrade, chunked) = self.parse_headers(lines)

        if close is None:
            # no Connection header: HTTP/1.0 and older close by default
            close = version_o <= HttpVersion10

        return RawResponseMessage(
            version_o, status_i, reason.strip(),
            headers, raw_headers, close, compression, upgrade, chunked)
class HttpPayloadParser:
    """Feeds a message body into a payload stream.

    Depending on the constructor arguments the body is read by
    Content-Length, by chunked transfer coding, until EOF, or not at all.
    ``done`` is set to True by the constructor when no body bytes are
    expected; in that case EOF has already been fed to the payload.
    """

    # Characters allowed in a chunk-size field (chunk-size = 1*HEXDIG).
    _HEXDIGITS = b'0123456789abcdefABCDEF'

    def __init__(self, payload: StreamReader,
                 length: Optional[int]=None,
                 chunked: bool=False,
                 compression: Optional[str]=None,
                 code: Optional[int]=None,
                 method: Optional[str]=None,
                 readall: bool=False,
                 response_with_body: bool=True,
                 auto_decompress: bool=True) -> None:
        """Select the parsing mode for one message body.

        :param payload: stream that receives the decoded body.
        :param length: value of Content-Length, if any.
        :param chunked: True when chunked transfer coding is used
            (takes precedence over ``length``).
        :param compression: content encoding; when set together with
            ``auto_decompress`` and ``response_with_body`` the payload is
            wrapped in a ``DeflateBuffer``.
        :param code: response status; 204 disables read-until-EOF.
        :param method: request method; PUT/POST without length or chunking
            only logs a warning.
        :param readall: read until EOF when neither length nor chunking
            is given.
        :param response_with_body: False when no body is expected at all.
        """
        self._length = 0
        self._type = ParseState.PARSE_NONE
        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
        self._chunk_size = 0
        self._chunk_tail = b''
        self._auto_decompress = auto_decompress
        self.done = False

        # payload decompression wrapper
        if response_with_body and compression and self._auto_decompress:
            real_payload = DeflateBuffer(payload, compression)  # type: Union[StreamReader, DeflateBuffer]  # noqa
        else:
            real_payload = payload

        # payload parser
        if not response_with_body:
            # don't parse payload if it's not expected to be received
            self._type = ParseState.PARSE_NONE
            real_payload.feed_eof()
            self.done = True

        elif chunked:
            self._type = ParseState.PARSE_CHUNKED
        elif length is not None:
            self._type = ParseState.PARSE_LENGTH
            self._length = length
            if self._length == 0:
                real_payload.feed_eof()
                self.done = True
        else:
            if readall and code != 204:
                self._type = ParseState.PARSE_UNTIL_EOF
            elif method in ('PUT', 'POST'):
                internal_logger.warning(  # pragma: no cover
                    'Content-Length or Transfer-Encoding header is required')
                self._type = ParseState.PARSE_NONE
                real_payload.feed_eof()
                self.done = True

        self.payload = real_payload

    def feed_eof(self) -> None:
        """Handle end of input.

        Finishes the payload in read-until-EOF mode; raises
        ``ContentLengthError`` / ``TransferEncodingError`` when a
        length-delimited or chunked body was cut short.
        """
        if self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_eof()
        elif self._type == ParseState.PARSE_LENGTH:
            raise ContentLengthError(
                "Not enough data for satisfy content length header.")
        elif self._type == ParseState.PARSE_CHUNKED:
            raise TransferEncodingError(
                "Not enough data for satisfy transfer length header.")

    def feed_data(self,
                  chunk: bytes,
                  SEP: bytes=b'\r\n',
                  CHUNK_EXT: bytes=b';') -> Tuple[bool, bytes]:
        """Consume body bytes.

        Returns ``(eof, tail)``: ``eof`` is True once the body is complete
        and ``tail`` then holds the bytes following it.  Raises
        ``TransferEncodingError`` for an invalid chunk-size field (the
        exception is also set on the payload).
        """
        # Read specified amount of bytes
        if self._type == ParseState.PARSE_LENGTH:
            required = self._length
            chunk_len = len(chunk)

            if required >= chunk_len:
                self._length = required - chunk_len
                self.payload.feed_data(chunk, chunk_len)
                if self._length == 0:
                    self.payload.feed_eof()
                    return True, b''
            else:
                self._length = 0
                self.payload.feed_data(chunk[:required], required)
                self.payload.feed_eof()
                return True, chunk[required:]

        # Chunked transfer encoding parser
        elif self._type == ParseState.PARSE_CHUNKED:
            if self._chunk_tail:
                chunk = self._chunk_tail + chunk
                self._chunk_tail = b''

            while chunk:

                # read next chunk size
                if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
                    pos = chunk.find(SEP)
                    if pos >= 0:
                        i = chunk.find(CHUNK_EXT, 0, pos)
                        if i >= 0:
                            size_b = chunk[:i]  # strip chunk-extensions
                        else:
                            size_b = chunk[:pos]

                        # Surrounding whitespace stays tolerated (int() did
                        # so too), but the size itself must be hex digits
                        # only: int(x, 16) would also accept a sign, a "0x"
                        # prefix or underscores, and a negative size would
                        # corrupt the slicing below.
                        size_b = bytes(size_b).strip()
                        if not size_b or size_b.strip(self._HEXDIGITS):
                            exc = TransferEncodingError(
                                chunk[:pos].decode('ascii', 'surrogateescape'))
                            self.payload.set_exception(exc)
                            raise exc
                        size = int(size_b, 16)

                        chunk = chunk[pos+2:]
                        if size == 0:  # eof marker
                            self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
                        else:
                            self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
                            self._chunk_size = size
                            self.payload.begin_http_chunk_receiving()
                    else:
                        # size line incomplete, wait for more data
                        self._chunk_tail = chunk
                        return False, b''

                # read chunk and feed buffer
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
                    required = self._chunk_size
                    chunk_len = len(chunk)

                    if required > chunk_len:
                        self._chunk_size = required - chunk_len
                        self.payload.feed_data(chunk, chunk_len)
                        return False, b''
                    else:
                        self._chunk_size = 0
                        self.payload.feed_data(chunk[:required], required)
                        chunk = chunk[required:]
                        self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
                        self.payload.end_http_chunk_receiving()

                # toss the CRLF at the end of the chunk
                if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
                    if chunk[:2] == SEP:
                        chunk = chunk[2:]
                        self._chunk = ChunkState.PARSE_CHUNKED_SIZE
                    else:
                        self._chunk_tail = chunk
                        return False, b''

                # if stream does not contain trailer, after 0\r\n
                # we should get another \r\n otherwise
                # trailers needs to be skipped until \r\n\r\n
                if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
                    head = chunk[:2]
                    if head == SEP:
                        # end of stream
                        self.payload.feed_eof()
                        return True, chunk[2:]
                    if head == SEP[:1]:
                        # Only the first byte of the terminating CRLF has
                        # arrived.  Keep it and wait; treating it as the
                        # start of a trailer would swallow the terminator
                        # and the body would never complete.
                        self._chunk_tail = head
                        return False, b''
                    self._chunk = ChunkState.PARSE_TRAILERS

                # read and discard trailer up to the CRLF terminator
                if self._chunk == ChunkState.PARSE_TRAILERS:
                    pos = chunk.find(SEP)
                    if pos >= 0:
                        chunk = chunk[pos+2:]
                        self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
                    else:
                        self._chunk_tail = chunk
                        return False, b''

        # Read all bytes until eof
        elif self._type == ParseState.PARSE_UNTIL_EOF:
            self.payload.feed_data(chunk, len(chunk))

        return False, b''
class DeflateBuffer:
    """Decompresses incoming data and forwards the result to ``out``."""

    def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
        """Pick a decompressor for ``encoding`` ('br', 'gzip' or deflate).

        Raises ``ContentEncodingError`` for 'br' when brotli support is
        not available.
        """
        self.out = out
        self.size = 0
        self.encoding = encoding
        self._started_decoding = False

        if encoding == 'br':
            if not HAS_BROTLI:  # pragma: no cover
                raise ContentEncodingError(
                    'Can not decode content-encoding: brotli (br). '
                    'Please install `brotlipy`')
            self.decompressor = brotli.Decompressor()
            return

        # gzip needs the +16 window-bits offset, zlib-wrapped deflate not
        if encoding == 'gzip':
            wbits = 16 + zlib.MAX_WBITS
        else:
            wbits = zlib.MAX_WBITS
        self.decompressor = zlib.decompressobj(wbits=wbits)

    def set_exception(self, exc: BaseException) -> None:
        """Forward ``exc`` to the wrapped stream."""
        self.out.set_exception(exc)

    def feed_data(self, chunk: bytes, size: int) -> None:
        """Decompress ``chunk`` and pass any output to the wrapped stream.

        Raises ``ContentEncodingError`` when decompression fails.
        """
        if not size:
            return

        self.size += size

        # RFC1950
        # bits 0..3 = CM = 0b1000 = 8 = "deflate"
        # bits 4..7 = CINFO = 1..7 = windows size.
        first_chunk_of_deflate = (not self._started_decoding and
                                  self.encoding == 'deflate')
        if first_chunk_of_deflate and chunk[0] & 0xf != 8:
            # No zlib header in front of the data: switch to a raw-deflate
            # decoder so non-RFC-compliant streams can still be read.
            self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)

        try:
            decoded = self.decompressor.decompress(chunk)
        except Exception:
            raise ContentEncodingError(
                'Can not decode content-encoding: %s' % self.encoding)

        self._started_decoding = True

        if decoded:
            self.out.feed_data(decoded, len(decoded))

    def feed_eof(self) -> None:
        """Flush the decompressor and finish the wrapped stream.

        Raises ``ContentEncodingError`` when a deflate stream that received
        data did not reach its end marker.
        """
        rest = self.decompressor.flush()

        if rest or self.size > 0:
            self.out.feed_data(rest, len(rest))
            if self.encoding == 'deflate' and not self.decompressor.eof:
                raise ContentEncodingError('deflate')

        self.out.feed_eof()

    def begin_http_chunk_receiving(self) -> None:
        """Forward the chunk-start notification to the wrapped stream."""
        self.out.begin_http_chunk_receiving()

    def end_http_chunk_receiving(self) -> None:
        """Forward the chunk-end notification to the wrapped stream."""
        self.out.end_http_chunk_receiving()
# Keep explicit handles on the pure-Python implementations before the
# public names are (optionally) rebound to the C accelerated versions.
HttpRequestParserPy, HttpResponseParserPy = (HttpRequestParser,
                                             HttpResponseParser)
RawRequestMessagePy, RawResponseMessagePy = (RawRequestMessage,
                                             RawResponseMessage)

try:
    if not NO_EXTENSIONS:
        # Rebinds the four public names to the extension module's versions.
        from ._http_parser import (HttpRequestParser,  # type: ignore  # noqa
                                   HttpResponseParser,
                                   RawRequestMessage,
                                   RawResponseMessage)
        HttpRequestParserC, HttpResponseParserC = (HttpRequestParser,
                                                   HttpResponseParser)
        RawRequestMessageC, RawResponseMessageC = (RawRequestMessage,
                                                   RawResponseMessage)
except ImportError:  # pragma: no cover
    # extension not built: keep using the pure-Python parsers
    pass

View File

@ -0,0 +1,659 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import collections
import json
import random
import re
import sys
import zlib
from enum import IntEnum
from struct import Struct
from typing import Any, Callable, List, Optional, Tuple, Union
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .log import ws_logger
from .streams import DataQueue
__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
'WebSocketReader', 'WebSocketWriter', 'WSMessage',
'WebSocketError', 'WSMsgType', 'WSCloseCode')
class WSCloseCode(IntEnum):
    """Status codes carried by a WebSocket close frame."""

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013


# Plain-int view of the codes above, used to validate received close codes.
ALLOWED_CLOSE_CODES = set(map(int, WSCloseCode))
class WSMsgType(IntEnum):
    """Message types: frame opcodes plus library-internal markers.

    The lower-case names are aliases of the upper-case members.
    """

    # websocket spec types (frame opcodes)
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # hyper_internal_service specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # lower-case aliases
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
# Fixed GUID defined by RFC 6455 for the opening handshake.
# NOTE(review): its use is outside this section; presumably it is appended
# to the client key when computing Sec-WebSocket-Accept - confirm.
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# Pre-compiled struct helpers (network byte order):
# 16-bit and 64-bit extended payload lengths, and the 16-bit close code.
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
UNPACK_CLOSE_CODE = Struct('!H').unpack
# Frame headers: two header bytes alone, or followed by a 16-bit / 64-bit
# extended length.
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack
# Unmasked payloads longer than this are written to the transport
# separately from their header instead of being concatenated with it.
MSG_SIZE = 2 ** 14
# Default number of written bytes after which the writer drains.
DEFAULT_LIMIT = 2 ** 16
_WSMessageBase = collections.namedtuple('_WSMessageBase',
['type', 'data', 'extra'])
class WSMessage(_WSMessageBase):
def json(self, *,
loads: Callable[[Any], Any]=json.loads) -> Any:
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
# Shared, payload-less messages for the library-internal CLOSED and CLOSING
# message types.
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is the close code associated with the failure; ``str()`` of
    the exception yields only the message.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        _, message = self.args[0], self.args[1]
        return message
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error."""

    pass
native_byteorder = sys.byteorder
# Used by _websocket_mask_python
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
object of any length. The contents of `data` are masked with `mask`,
as specified in section 5.3 of RFC 6455.
Note that this function mutates the `data` argument.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
if data:
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
# Choose the masking implementation: start from the pure-Python one and
# switch to the Cython version when extensions are enabled and importable.
_websocket_mask = _websocket_mask_python
if not NO_EXTENSIONS:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore
    except ImportError:  # pragma: no cover
        pass
    else:
        _websocket_mask = _websocket_mask_cython

# Four bytes that the writer strips from, and the reader appends to, every
# compressed message.
_WS_DEFLATE_TRAILING = b'\x00\x00\xff\xff'
# Parameters of one permessage-deflate offer.  Capture groups:
#   1: server_no_context_takeover   2: client_no_context_takeover
#   3: server_max_window_bits[=N]   4: N of group 3
#   5: client_max_window_bits[=N]   6: N of group 5
_WS_EXT_RE = re.compile(r'^(?:;\s*(?:'
                        r'(server_no_context_takeover)|'
                        r'(client_no_context_takeover)|'
                        r'(server_max_window_bits(?:=(\d+))?)|'
                        r'(client_max_window_bits(?:=(\d+))?)))*$')

# Splits a header value into permessage-deflate offers; group 1 holds the
# parameter string of the offer (if any).
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]:
    """Parse a ``Sec-WebSocket-Extensions`` value for permessage-deflate.

    Returns ``(compress, notakeover)`` where ``compress`` is the window
    bits to use (0 when compression is not enabled) and ``notakeover``
    tells whether the relevant ``*_no_context_takeover`` flag was present.

    With ``isserver`` true, unusable offers are skipped silently.
    Otherwise ``WSHandshakeError`` is raised for an unknown parameter
    string or a window size outside 9..15.
    """
    if not extstr:
        return 0, False

    # Groups read depend on the side: (window bits, no-takeover flag).
    # The server side reads server_* (groups 4 and 1) and ignores the
    # client_max_window_bits groups 5/6; the other side reads client_*
    # (groups 6 and 2) and ignores the server_max_window_bits groups 3/4.
    if isserver:
        bits_group, notakeover_group = 4, 1
    else:
        bits_group, notakeover_group = 6, 2

    compress = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        # Bare `permessage-deflate`: use the maximum window.
        if not params:
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            # Only the non-server side fails on an unparsable offer.
            if not isserver:
                raise WSHandshakeError('Extension for deflate not supported' +
                                       offer.group(1))
            continue

        compress = 15
        wbits = parsed.group(bits_group)
        if wbits:
            compress = int(wbits)
            # zlib does not support wbits=8; anything outside 9..15 is
            # unusable.
            if compress > 15 or compress < 9:
                if not isserver:
                    raise WSHandshakeError('Invalid window size')
                # server side: drop this offer and try the next one
                compress = 0
                continue
        if parsed.group(notakeover_group):
            notakeover = True
        break

    return compress, notakeover
def ws_ext_gen(compress: int=15, isserver: bool=False,
               server_notakeover: bool=False) -> str:
    """Build a permessage-deflate ``Sec-WebSocket-Extensions`` value.

    ``compress`` is the window size in bits and must be within 9..15
    (zlib does not support 8), otherwise ``ValueError`` is raised.
    ``client_max_window_bits`` is advertised unless ``isserver`` is true,
    ``server_max_window_bits`` is added when ``compress`` is below 15 and
    ``server_no_context_takeover`` when ``server_notakeover`` is set.
    There is no client_notakeover option.
    """
    if compress > 15 or compress < 9:
        raise ValueError('Compress wbits must between 9 and 15, '
                         'zlib does not support wbits=8')

    parts = ['permessage-deflate']
    optional = (
        (not isserver, 'client_max_window_bits'),
        (compress < 15, 'server_max_window_bits=' + str(compress)),
        (server_notakeover, 'server_no_context_takeover'),
    )
    parts.extend(text for wanted, text in optional if wanted)
    return '; '.join(parts)
class WSParserState(IntEnum):
    """Stages of the frame parser, in the order a frame is read."""

    READ_HEADER = 1           # the two fixed header bytes
    READ_PAYLOAD_LENGTH = 2   # optional extended length
    READ_PAYLOAD_MASK = 3     # 4-byte mask, when the mask bit is set
    READ_PAYLOAD = 4          # payload bytes
class WebSocketReader:
    """Parses WebSocket frames and queues the resulting messages."""

    def __init__(self, queue: DataQueue[WSMessage],
                 max_msg_size: int, compress: bool=True) -> None:
        """
        :param queue: destination for parsed ``WSMessage`` objects.
        :param max_msg_size: message size limit; a falsy value disables
            the check.
        :param compress: whether frames with the RSV1 bit are accepted.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._compress = compress

        self._exc = None  # type: Optional[BaseException]

        # message reassembly state
        self._partial = bytearray()
        self._opcode = None  # type: Optional[int]
        self._compressed = None  # type: Optional[bool]
        self._decompressobj = None  # type: Any  # zlib.decompressobj actually

        # frame parsing state
        self._state = WSParserState.READ_HEADER
        self._tail = b''
        self._frame_fin = False
        self._frame_opcode = None  # type: Optional[int]
        self._frame_payload = bytearray()
        self._has_mask = False
        self._frame_mask = None  # type: Optional[bytes]
        self._payload_length = 0
        self._payload_length_flag = 0

    def feed_eof(self) -> None:
        """Propagate end of input to the message queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Parse ``data``; returns ``(failed, unprocessed_bytes)``.

        After a failure the exception is stored on the queue and every
        later call returns its input unprocessed.
        """
        if self._exc:
            return True, data

        try:
            outcome = self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b''
        return outcome

    def _check_partial_size(self) -> None:
        """Raise MESSAGE_TOO_BIG once the assembled message hits the limit."""
        limit = self._max_msg_size
        if limit and len(self._partial) >= limit:
            raise WebSocketError(
                WSCloseCode.MESSAGE_TOO_BIG,
                "Message size {} exceeds limit {}".format(
                    len(self._partial), self._max_msg_size))

    def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
        """Turn every frame completed by ``data`` into queued messages."""
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)

            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    # codes below 3000 must be one of the known ones
                    if (close_code < 3000 and
                            close_code not in ALLOWED_CLOSE_CODES):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Invalid close code: {}'.format(close_code))
                    try:
                        close_message = payload[2:].decode('utf-8')
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT,
                            'Invalid UTF-8 text message') from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # a single payload byte cannot hold a close code
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Invalid close frame: {} {} {!r}'.format(
                            fin, opcode, payload))
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, '')

                self.queue.feed_data(msg, 0)
                continue

            if opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ''), len(payload))
                continue

            if opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ''), len(payload))
                continue

            if opcode not in (
                    WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Unexpected opcode={!r}".format(opcode))

            # text/binary data, possibly fragmented
            if not fin:
                # partial message: remember its type and keep collecting
                if opcode != WSMsgType.CONTINUATION:
                    self._opcode = opcode
                self._partial.extend(payload)
                self._check_partial_size()
                continue

            # Final frame.  If earlier fragments are pending, this one has
            # to be a continuation frame.
            if self._partial:
                if opcode != WSMsgType.CONTINUATION:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'The opcode in non-fin frame is expected '
                        'to be zero, got {!r}'.format(opcode))

            if opcode == WSMsgType.CONTINUATION:
                # the message type is the one of its first fragment
                assert self._opcode is not None
                opcode = self._opcode
                self._opcode = None

            self._partial.extend(payload)
            self._check_partial_size()

            # Decompression has to happen after all fragments are received.
            if compressed:
                self._partial.extend(_WS_DEFLATE_TRAILING)
                payload_merged = self._decompressobj.decompress(
                    self._partial, self._max_msg_size)
                if self._decompressobj.unconsumed_tail:
                    left = len(self._decompressobj.unconsumed_tail)
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        "Decompressed message size {} exceeds limit {}"
                        .format(
                            self._max_msg_size + left,
                            self._max_msg_size
                        )
                    )
            else:
                payload_merged = bytes(self._partial)

            self._partial.clear()

            if opcode == WSMsgType.TEXT:
                try:
                    text = payload_merged.decode('utf-8')
                    self.queue.feed_data(
                        WSMessage(WSMsgType.TEXT, text, ''), len(text))
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT,
                        'Invalid UTF-8 text message') from exc
            else:
                self.queue.feed_data(
                    WSMessage(WSMsgType.BINARY, payload_merged, ''),
                    len(payload_merged))

        return False, b''

    def parse_frame(self, buf: bytes) -> List[Tuple[bool, Optional[int],
                                                    bytearray,
                                                    Optional[bool]]]:
        """Return every frame completed by ``buf``.

        Each entry is ``(fin, opcode, payload, compressed)``.  Parsing state
        and incomplete trailing bytes are kept on the instance, so a frame
        may span several calls.  Raises ``WebSocketError`` for frames that
        violate the protocol.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b''

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos < 2:
                    break
                first_byte, second_byte = buf[start_pos:start_pos+2]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xf

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1, frame-rsv2, frame-rsv3 = %x0 ;
                #    1 bit each, MUST be 0 unless negotiated otherwise
                #
                # rsv1 is tolerated only when compression is enabled
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Received frame with non-zero reserved bits')

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Received fragmented control frame')

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7f

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Control frame payload cannot be '
                        'larger than 125 bytes')

                # The compression flag is taken from the first frame seen
                # and from any frame following a FIN frame; other frames
                # must not carry rsv1.
                if self._frame_fin or self._compressed is None:
                    self._compressed = True if rsv1 else False
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Received frame with non-zero reserved bits')

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_length_flag = length
                self._state = WSParserState.READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended length follows
                    if buf_length - start_pos < 2:
                        break
                    length = UNPACK_LEN2(buf[start_pos:start_pos+2])[0]
                    start_pos += 2
                elif length > 126:
                    # 64-bit extended length follows
                    if buf_length - start_pos < 8:
                        break
                    length = UNPACK_LEN3(buf[start_pos:start_pos+8])[0]
                    start_pos += 8

                self._payload_length = length
                if self._has_mask:
                    self._state = WSParserState.READ_PAYLOAD_MASK
                else:
                    self._state = WSParserState.READ_PAYLOAD

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos < 4:
                    break
                self._frame_mask = buf[start_pos:start_pos+4]
                start_pos += 4
                self._state = WSParserState.READ_PAYLOAD

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                available = buf_length - start_pos
                if length >= available:
                    # take everything that is left in the buffer
                    self._payload_length = length - available
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos:start_pos+length])
                    start_pos = start_pos + length

                if self._payload_length != 0:
                    # frame not complete yet
                    break

                if self._has_mask:
                    assert self._frame_mask is not None
                    _websocket_mask(self._frame_mask, payload)

                frames.append((
                    self._frame_fin,
                    self._frame_opcode,
                    payload,
                    self._compressed))

                self._frame_payload = bytearray()
                self._state = WSParserState.READ_HEADER

        self._tail = buf[start_pos:]

        return frames
class WebSocketWriter:
    """Serialises WebSocket frames and writes them to a transport."""

    def __init__(self, protocol: BaseProtocol, transport: asyncio.Transport, *,
                 use_mask: bool=False, limit: int=DEFAULT_LIMIT,
                 random: Any=random.Random(),
                 compress: int=0, notakeover: bool=False) -> None:
        """
        :param protocol: protocol whose ``_drain_helper`` is awaited once
            more than ``limit`` bytes have been written.
        :param transport: transport the frames are written to.
        :param use_mask: mask every outgoing frame.
        :param limit: number of written bytes after which to drain.
        :param random: object providing ``randrange`` for masking keys.
        :param compress: deflate window bits; 0 disables compression.
        :param notakeover: flush with ``Z_FULL_FLUSH`` after each message
            instead of ``Z_SYNC_FLUSH``.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj = None  # type: Any  # actually compressobj

    async def _send_frame(self, message: bytes, opcode: int,
                          compress: Optional[int]=None) -> None:
        """Send a frame over the websocket with message as its payload.

        ``compress`` overrides the connection-wide window bits for this
        frame only.  Control frames (opcode >= 8) are never compressed.
        """
        if self._closing:
            ws_logger.warning('websocket connection is closing.')

        rsv = 0

        # Only compress larger packets (disabled)
        # Does small packet needs to be compressed?
        # if self.compress and opcode < 8 and len(message) > 124:
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Per-frame compression: use a throw-away compressor and
                # leave the shared one untouched.
                compressobj = zlib.compressobj(wbits=-compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = zlib.compressobj(wbits=-self.compress)
                compressobj = self._compressobj

            message = compressobj.compress(message)
            message = message + compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
            # the trailing flush marker is not transmitted
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            rsv = rsv | 0x40

        msg_length = len(message)

        use_mask = self.use_mask
        if use_mask:
            mask_bit = 0x80
        else:
            mask_bit = 0

        # 0x80 marks the frame as final (FIN)
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)

        if use_mask:
            # randrange() excludes its upper bound, so use 2**32 to cover
            # the whole 32-bit key space (0xffffffff was unreachable before).
            mask = self.randrange(0, 0x100000000)
            mask = mask.to_bytes(4, 'big')
            message = bytearray(message)
            _websocket_mask(mask, message)
            self.transport.write(header + mask + message)
            self._output_size += len(header) + len(mask) + len(message)
        else:
            if len(message) > MSG_SIZE:
                # avoid copying a large payload just to prepend the header
                self.transport.write(header)
                self.transport.write(message)
            else:
                self.transport.write(header + message)

            self._output_size += len(header) + len(message)

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    async def pong(self, message: bytes=b'') -> None:
        """Send pong message (``str`` input is encoded as UTF-8)."""
        if isinstance(message, str):
            message = message.encode('utf-8')
        await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message: bytes=b'') -> None:
        """Send ping message (``str`` input is encoded as UTF-8)."""
        if isinstance(message, str):
            message = message.encode('utf-8')
        await self._send_frame(message, WSMsgType.PING)

    async def send(self, message: Union[str, bytes],
                   binary: bool=False,
                   compress: Optional[int]=None) -> None:
        """Send a frame over the websocket with message as its payload.

        The frame is BINARY when ``binary`` is true, TEXT otherwise;
        ``str`` input is encoded as UTF-8 in both cases.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        if binary:
            await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code: int=1000, message: bytes=b'') -> None:
        """Close the websocket, sending the specified code and message.

        The writer is marked as closing even when sending the frame fails.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        try:
            await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
        finally:
            self._closing = True

View File

@ -0,0 +1,172 @@
"""Http related parsers and protocol."""
import asyncio
import collections
import zlib
from typing import Any, Awaitable, Callable, Optional, Union # noqa
from multidict import CIMultiDict # noqa
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
__all__ = ('StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
# (major, minor) protocol version; being a tuple, versions compare in order.
HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
# Optional async callback; StreamWriter awaits it with each chunk before the
# chunk is compressed or framed.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
def __init__(self,
             protocol: BaseProtocol,
             loop: asyncio.AbstractEventLoop,
             on_chunk_sent: _T_OnChunkSent = None) -> None:
    """Bind the writer to ``protocol`` and its current transport.

    ``on_chunk_sent`` is an optional coroutine function awaited with every
    chunk handed to ``write()`` / ``write_eof()``.
    """
    self.loop = loop
    self._protocol = protocol
    self._transport = protocol.transport
    self._on_chunk_sent = on_chunk_sent  # type: _T_OnChunkSent

    # body framing options
    self.length = None      # when set, write() truncates output to it
    self.chunked = False    # chunked transfer coding enabled
    self._compress = None  # type: Any

    # byte counters
    self.buffer_size = 0    # bytes written since the last drain in write()
    self.output_size = 0    # total bytes passed to _write()

    self._eof = False
    self._drain_waiter = None
@property
def transport(self) -> Optional[asyncio.Transport]:
    """Transport being written to; ``None`` after ``write_eof()``."""
    current = self._transport
    return current
@property
def protocol(self) -> BaseProtocol:
    """Protocol this writer was created with."""
    owner = self._protocol
    return owner
def enable_chunking(self) -> None:
    """Frame all subsequently written data as HTTP chunks."""
    setattr(self, 'chunked', True)
def enable_compression(self, encoding: str='deflate') -> None:
    """Compress subsequently written data.

    ``'gzip'`` selects the gzip container; any other value selects the
    zlib container.
    """
    if encoding == 'gzip':
        wbits = 16 + zlib.MAX_WBITS
    else:
        wbits = zlib.MAX_WBITS
    self._compress = zlib.compressobj(wbits=wbits)
def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
if self._transport is None or self._transport.is_closing():
raise ConnectionResetError('Cannot write to closing transport')
self._transport.write(chunk)
async def write(self, chunk: bytes,
*, drain: bool=True, LIMIT: int=0x10000) -> None:
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
if self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress is not None:
chunk = self._compress.compress(chunk)
if not chunk:
return
if self.length is not None:
chunk_len = len(chunk)
if self.length >= chunk_len:
self.length = self.length - chunk_len
else:
chunk = chunk[:self.length]
self.length = 0
if not chunk:
return
if chunk:
if self.chunked:
chunk_len_pre = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len_pre + chunk + b'\r\n'
self._write(chunk)
if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
await self.drain()
async def write_headers(self, status_line: str,
headers: 'CIMultiDict[str]') -> None:
"""Write request/response status and headers."""
# status + headers
buf = _serialize_headers(status_line, headers)
self._write(buf)
async def write_eof(self, chunk: bytes=b'') -> None:
if self._eof:
return
if chunk and self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress:
if chunk:
chunk = self._compress.compress(chunk)
chunk = chunk + self._compress.flush()
if chunk and self.chunked:
chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
else:
if self.chunked:
if chunk:
chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
else:
chunk = b'0\r\n\r\n'
if chunk:
self._write(chunk)
await self.drain()
self._eof = True
self._transport = None
async def drain(self) -> None:
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()
def _py_serialize_headers(status_line: str,
headers: 'CIMultiDict[str]') -> bytes:
line = status_line + '\r\n' + ''.join(
[k + ': ' + v + '\r\n' for k, v in headers.items()])
return line.encode('utf-8') + b'\r\n'
# Start with the pure-Python serializer as the active implementation.
_serialize_headers = _py_serialize_headers
try:
    # Optional accelerated implementation (presumably a compiled extension
    # module; confirm against the package build).
    import hyper_internal_service._http_writer as _http_writer  # type: ignore
    _c_serialize_headers = _http_writer._serialize_headers
    # The accelerated serializer is used unless NO_EXTENSIONS is set.
    if not NO_EXTENSIONS:
        _serialize_headers = _c_serialize_headers
except ImportError:
    # Module not importable: keep the pure-Python implementation.
    pass

View File

@ -0,0 +1,44 @@
import asyncio
import collections
from typing import Any, Optional
try:
from typing import Deque
except ImportError:
from typing_extensions import Deque # noqa
class EventResultOrError:
    """An asyncio.Event wrapper whose waiters can be woken with an error.

    ``set()`` wakes every waiter; when an exception was supplied, each
    waiter re-raises it instead of returning normally.  ``cancel()``
    cancels all tasks currently waiting.

    thanks to @vorpalsmith for the simple design.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._exc = None  # type: Optional[BaseException]
        self._event = asyncio.Event()
        self._waiters = collections.deque()  # type: Deque[asyncio.Future[Any]]

    def set(self, exc: Optional[BaseException]=None) -> None:
        """Wake all waiters, optionally making them raise ``exc``."""
        self._exc = exc
        self._event.set()

    async def wait(self) -> Any:
        """Block until set() is called; raise the stored exception if any."""
        pending = self._loop.create_task(self._event.wait())
        self._waiters.append(pending)
        try:
            outcome = await pending
        finally:
            # Always unregister, including on cancellation.
            self._waiters.remove(pending)
        failure = self._exc
        if failure is None:
            return outcome
        raise failure

    def cancel(self) -> None:
        """ Cancel all waiters """
        for pending in self._waiters:
            pending.cancel()

View File

@ -0,0 +1,8 @@
import logging
# Named loggers shared across the package; each can be configured
# independently through the standard ``logging`` hierarchy under the
# 'hyper_internal_service' prefix.
access_logger = logging.getLogger('hyper_internal_service.access')
client_logger = logging.getLogger('hyper_internal_service.client')
internal_logger = logging.getLogger('hyper_internal_service.internal')
server_logger = logging.getLogger('hyper_internal_service.server')
web_logger = logging.getLogger('hyper_internal_service.web')
ws_logger = logging.getLogger('hyper_internal_service.websocket')

View File

@ -0,0 +1,959 @@
import base64
import binascii
import json
import re
import uuid
import warnings
import zlib
from collections import deque
from types import TracebackType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from urllib.parse import parse_qsl, unquote, urlencode
from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping # noqa
from .hdrs import (
CONTENT_DISPOSITION,
CONTENT_ENCODING,
CONTENT_LENGTH,
CONTENT_TRANSFER_ENCODING,
CONTENT_TYPE,
)
from .helpers import CHAR, TOKEN, parse_mimetype, reify
from .http import HeadersParser
from .payload import (
JsonPayload,
LookupError,
Order,
Payload,
StringPayload,
get_payload,
payload_type,
)
from .streams import StreamReader
__all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader',
'BadContentDispositionHeader', 'BadContentDispositionParam',
'parse_content_disposition', 'content_disposition_filename')
if TYPE_CHECKING: # pragma: no cover
from .client_reqrep import ClientResponse # noqa
class BadContentDispositionHeader(RuntimeWarning):
    """Warning issued when a whole Content-Disposition header is rejected."""
class BadContentDispositionParam(RuntimeWarning):
    """Warning issued when one Content-Disposition parameter is skipped."""
def parse_content_disposition(header: Optional[str]) -> Tuple[Optional[str],
                                                              Dict[str, str]]:
    """Parse a Content-Disposition header value.

    header: raw header value, or None / empty when the header is absent.

    Returns ``(disposition_type, params)`` where the type is lower-cased
    and ``params`` maps lower-cased parameter names to decoded values.
    Returns ``(None, {})`` when the header is missing or malformed as a
    whole (a BadContentDispositionHeader warning is issued in the latter
    case).  Individual bad parameters are skipped with a
    BadContentDispositionParam warning.
    """
    # Imported locally: this module rebinds the name ``LookupError`` to the
    # payload registry's exception, so the builtin must be reached through
    # ``builtins``.
    import builtins

    def is_token(string: str) -> bool:
        return bool(string) and TOKEN >= set(string)

    def is_quoted(string: str) -> bool:
        # Guard against empty values ("filename="), which used to raise
        # IndexError here.
        return bool(string) and string[0] == string[-1] == '"'

    def is_rfc5987(string: str) -> bool:
        return is_token(string) and string.count("'") == 2

    def is_extended_param(string: str) -> bool:
        return string.endswith('*')

    def is_continuous_param(string: str) -> bool:
        # Matches "name*N" and "name*N*" where N is a decimal number.
        pos = string.find('*') + 1
        if not pos:
            return False
        substring = string[pos:-1] if string.endswith('*') else string[pos:]
        return substring.isdigit()

    def unescape(text: str, *,
                 chars: str=''.join(map(re.escape, CHAR))) -> str:
        # Drop the backslash of every quoted-pair.
        return re.sub('\\\\([{}])'.format(chars), '\\1', text)

    if not header:
        return None, {}

    disptype, *parts = header.split(';')
    if not is_token(disptype):
        warnings.warn(BadContentDispositionHeader(header))
        return None, {}

    params = {}  # type: Dict[str, str]
    while parts:
        item = parts.pop(0)

        if '=' not in item:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        key, value = item.split('=', 1)
        key = key.lower().strip()
        value = value.lstrip()

        # A repeated parameter invalidates the whole header.
        if key in params:
            warnings.warn(BadContentDispositionHeader(header))
            return None, {}

        if not is_token(key):
            warnings.warn(BadContentDispositionParam(item))
            continue

        elif is_continuous_param(key):
            # Continuation section: kept raw here and joined later by
            # content_disposition_filename().
            if is_quoted(value):
                value = unescape(value[1:-1])
            elif not is_token(value):
                warnings.warn(BadContentDispositionParam(item))
                continue

        elif is_extended_param(key):
            # Extended value of the form charset'language'pct-encoded.
            if is_rfc5987(value):
                encoding, _, value = value.split("'", 2)
                encoding = encoding or 'utf-8'
            else:
                warnings.warn(BadContentDispositionParam(item))
                continue

            try:
                value = unquote(value, encoding, 'strict')
            except (UnicodeDecodeError,
                    builtins.LookupError):  # pragma: nocover
                # Undecodable bytes or an unknown charset name: skip the
                # parameter instead of propagating the error.
                warnings.warn(BadContentDispositionParam(item))
                continue

        else:
            failed = True
            if is_quoted(value):
                failed = False
                value = unescape(value[1:-1].lstrip('\\/'))
            elif is_token(value):
                failed = False
            elif parts:
                # maybe just ; in filename, in any case this is just
                # one case fix, for proper fix we need to redesign parser
                _value = '%s;%s' % (value, parts[0])
                if is_quoted(_value):
                    parts.pop(0)
                    value = unescape(_value[1:-1].lstrip('\\/'))
                    failed = False

            if failed:
                warnings.warn(BadContentDispositionHeader(header))
                return None, {}

        params[key] = value

    return disptype.lower(), params
def content_disposition_filename(params: Mapping[str, str],
                                 name: str='filename') -> Optional[str]:
    """Extract the value of ``name`` from parsed Content-Disposition params.

    params: parameter mapping as returned by parse_content_disposition().
    name: base parameter name to look up.

    Lookup order: the extended form ``name*``, then the plain ``name``,
    then continuation sections ``name*0``, ``name*1``, ... (each optionally
    suffixed with ``*``), joined in numeric order.  Returns None when
    nothing matches.
    """
    name_suf = '%s*' % name
    if not params:
        return None
    if name_suf in params:
        return params[name_suf]
    if name in params:
        return params[name]

    # Collect sections by looking up each index explicitly.  Sorting the
    # keys as strings would put "*10" before "*2" and truncate values made
    # of ten or more sections.
    parts = []
    num = 0
    while True:
        plain_key = '%s%d' % (name_suf, num)
        for key in (plain_key, plain_key + '*'):
            if key in params:
                parts.append(params[key])
                break
        else:
            # First missing index ends the sequence.
            break
        num += 1

    if not parts:
        return None
    value = ''.join(parts)
    # charset'language'data needs two apostrophes; a value with a single
    # apostrophe is plain text and must not be unpacked as three fields.
    if value.count("'") >= 2:
        encoding, _, value = value.split("'", 2)
        encoding = encoding or 'utf-8'
        return unquote(value, encoding, 'strict')
    return value
class MultipartResponseWrapper:
    """Couples a multipart reader with the response it was built from.

    The wrapper releases the response once the reader reports that it has
    reached the end of the multipart stream.
    """

    def __init__(
        self,
        resp: 'ClientResponse',
        stream: 'MultipartReader',
    ) -> None:
        self.resp, self.stream = resp, stream

    def __aiter__(self) -> 'MultipartResponseWrapper':
        return self

    async def __anext__(
        self,
    ) -> Union['MultipartReader', 'BodyPartReader']:
        part = await self.next()
        if part is not None:
            return part
        raise StopAsyncIteration  # NOQA

    def at_eof(self) -> bool:
        """Tell whether the response content has been read completely."""
        content = self.resp.content
        return content.at_eof()

    async def next(
        self,
    ) -> Optional[Union['MultipartReader', 'BodyPartReader']]:
        """Return the next part from the reader (None when exhausted)."""
        part = await self.stream.next()
        exhausted = self.stream.at_eof()
        if exhausted:
            await self.release()
        return part

    async def release(self) -> None:
        """Release the underlying response."""
        await self.resp.release()
class BodyPartReader:
    """Multipart reader for single body part."""

    # Default number of bytes requested per read_chunk() call.
    chunk_size = 8192

    def __init__(self, boundary: bytes,
                 headers: 'CIMultiDictProxy[str]',
                 content: StreamReader) -> None:
        """Create a reader for one part.

        boundary: boundary marker the part data is compared against.
        headers: headers of this body part.
        content: stream positioned at the start of the part's data.
        """
        self.headers = headers
        self._boundary = boundary
        self._content = content
        self._at_eof = False
        length = self.headers.get(CONTENT_LENGTH, None)
        # Declared part length, or None when no Content-Length was sent.
        self._length = int(length) if length is not None else None
        self._read_bytes = 0
        # TODO: typeing.Deque is not supported by Python 3.5
        # Lines read ahead by readline() and not yet consumed.
        self._unread = deque()  # type: Any
        # Look-behind buffer used by _read_chunk_from_stream().
        self._prev_chunk = None  # type: Optional[bytes]
        # Number of stream reads after which the content reported EOF.
        self._content_eof = 0
        self._cache = {}  # type: Dict[str, Any]

    def __aiter__(self) -> 'BodyPartReader':
        return self

    async def __anext__(self) -> bytes:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration  # NOQA
        return part

    async def next(self) -> Optional[bytes]:
        """Return the result of read(), or None when it is empty."""
        item = await self.read()
        if not item:
            return None
        return item

    async def read(self, *, decode: bool=False) -> bytes:
        """Reads body part data.

        decode: when True, the data is passed through decode(), which
                applies the Content-Transfer-Encoding / Content-Encoding
                headers; otherwise the data is returned untouched.

        Note: the undecoded result is the internal ``bytearray``.
        """
        if self._at_eof:
            return b''
        data = bytearray()
        while not self._at_eof:
            data.extend((await self.read_chunk(self.chunk_size)))
        if decode:
            return self.decode(data)
        return data

    async def read_chunk(self, size: int=chunk_size) -> bytes:
        """Reads body part content chunk of the specified size.

        size: chunk size

        Returns b'' once the end of the part has been reached.
        """
        if self._at_eof:
            return b''
        # A missing or zero Content-Length falls back to boundary scanning.
        if self._length:
            chunk = await self._read_chunk_from_length(size)
        else:
            chunk = await self._read_chunk_from_stream(size)
        self._read_bytes += len(chunk)
        if self._read_bytes == self._length:
            self._at_eof = True
        if self._at_eof:
            # The part data must be followed by a bare CRLF line.
            clrf = await self._content.readline()
            assert b'\r\n' == clrf, \
                'reader did not read all the data or it is malformed'
        return chunk

    async def _read_chunk_from_length(self, size: int) -> bytes:
        # Reads body part content chunk of the specified size.
        # The body part must has Content-Length header with proper value.
        assert self._length is not None, \
            'Content-Length required for chunked read'
        # Never read past the declared length.
        chunk_size = min(size, self._length - self._read_bytes)
        chunk = await self._content.read(chunk_size)
        return chunk

    async def _read_chunk_from_stream(self, size: int) -> bytes:
        # Reads content chunk of body part with unknown length.
        # The Content-Length header for body part is not necessary.
        assert size >= len(self._boundary) + 2, \
            'Chunk size must be greater or equal than boundary length + 2'
        first_chunk = self._prev_chunk is None
        if first_chunk:
            self._prev_chunk = await self._content.read(size)
        chunk = await self._content.read(size)
        self._content_eof += int(self._content.at_eof())
        assert self._content_eof < 3, "Reading after EOF"
        assert self._prev_chunk is not None
        # Search the previous and the new chunk together so that a boundary
        # split across the two reads is still found.
        window = self._prev_chunk + chunk
        sub = b'\r\n' + self._boundary
        if first_chunk:
            idx = window.find(sub)
        else:
            # Only the tail of the previous chunk can hold a new match.
            idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
        if idx >= 0:
            # pushing boundary back to content
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore",
                                        category=DeprecationWarning)
                self._content.unread_data(window[idx:])
            if size > idx:
                self._prev_chunk = self._prev_chunk[:idx]
            chunk = window[len(self._prev_chunk):idx]
            if not chunk:
                self._at_eof = True
        # Return the older chunk; keep the newer one as look-behind.
        result = self._prev_chunk
        self._prev_chunk = chunk
        return result

    async def readline(self) -> bytes:
        """Reads the body part line by line.

        Returns b'' once the boundary line has been reached.
        """
        if self._at_eof:
            return b''
        if self._unread:
            line = self._unread.popleft()
        else:
            line = await self._content.readline()
        if line.startswith(self._boundary):
            # the very last boundary may not come with \r\n,
            # so set single rules for everyone
            sline = line.rstrip(b'\r\n')
            boundary = self._boundary
            last_boundary = self._boundary + b'--'
            # ensure that we read exactly the boundary, not something alike
            if sline == boundary or sline == last_boundary:
                self._at_eof = True
                # Keep the boundary line available for whoever reads next.
                self._unread.append(line)
                return b''
        else:
            # Peek at the following line: the CRLF before a boundary is
            # not part of the data.
            next_line = await self._content.readline()
            if next_line.startswith(self._boundary):
                line = line[:-2]  # strip CRLF but only once
            self._unread.append(next_line)
        return line

    async def release(self) -> None:
        """Like read(), but reads all the data to the void."""
        if self._at_eof:
            return
        while not self._at_eof:
            await self.read_chunk(self.chunk_size)

    async def text(self, *, encoding: Optional[str]=None) -> str:
        """Like read(), but assumes that body part contains text data.

        encoding: overrides the charset taken from Content-Type
                  (default 'utf-8').
        """
        data = await self.read(decode=True)
        # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA
        # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA
        encoding = encoding or self.get_charset(default='utf-8')
        return data.decode(encoding)

    async def json(self,
                   *,
                   encoding: Optional[str]=None) -> Optional[Dict[str, Any]]:
        """Like read(), but assumes that body parts contains JSON data.

        Returns None when the part is empty.
        """
        data = await self.read(decode=True)
        if not data:
            return None
        encoding = encoding or self.get_charset(default='utf-8')
        return json.loads(data.decode(encoding))

    async def form(self, *,
                   encoding: Optional[str]=None) -> List[Tuple[str, str]]:
        """Like read(), but assumes that body parts contains form
        urlencoded data.

        Returns a list of (key, value) pairs; blank values are kept.
        """
        data = await self.read(decode=True)
        if not data:
            return []
        if encoding is not None:
            real_encoding = encoding
        else:
            real_encoding = self.get_charset(default='utf-8')
        return parse_qsl(data.rstrip().decode(real_encoding),
                         keep_blank_values=True,
                         encoding=real_encoding)

    def at_eof(self) -> bool:
        """Returns True if the boundary was reached or False otherwise."""
        return self._at_eof

    def decode(self, data: bytes) -> bytes:
        """Decodes data according the specified Content-Encoding
        or Content-Transfer-Encoding headers value.

        The transfer encoding is undone first, then the content encoding.
        """
        if CONTENT_TRANSFER_ENCODING in self.headers:
            data = self._decode_content_transfer(data)
        if CONTENT_ENCODING in self.headers:
            return self._decode_content(data)
        return data

    def _decode_content(self, data: bytes) -> bytes:
        # Undo Content-Encoding; raises RuntimeError for unknown values.
        encoding = self.headers.get(CONTENT_ENCODING, '').lower()
        if encoding == 'deflate':
            # Negative wbits: raw deflate stream without zlib header.
            return zlib.decompress(data, -zlib.MAX_WBITS)
        elif encoding == 'gzip':
            return zlib.decompress(data, 16 + zlib.MAX_WBITS)
        elif encoding == 'identity':
            return data
        else:
            raise RuntimeError('unknown content encoding: {}'.format(encoding))

    def _decode_content_transfer(self, data: bytes) -> bytes:
        # Undo Content-Transfer-Encoding; raises RuntimeError for unknown
        # values.
        encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, '').lower()
        if encoding == 'base64':
            return base64.b64decode(data)
        elif encoding == 'quoted-printable':
            return binascii.a2b_qp(data)
        elif encoding in ('binary', '8bit', '7bit'):
            return data
        else:
            raise RuntimeError('unknown content transfer encoding: {}'
                               ''.format(encoding))

    def get_charset(self, default: str) -> str:
        """Returns charset parameter from Content-Type header or default."""
        ctype = self.headers.get(CONTENT_TYPE, '')
        mimetype = parse_mimetype(ctype)
        return mimetype.parameters.get('charset', default)

    @reify
    def name(self) -> Optional[str]:
        """Returns name specified in Content-Disposition header or None
        if missed or header is malformed.
        """
        _, params = parse_content_disposition(
            self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, 'name')

    @reify
    def filename(self) -> Optional[str]:
        """Returns filename specified in Content-Disposition header or None
        if missed or header is malformed.
        """
        _, params = parse_content_disposition(
            self.headers.get(CONTENT_DISPOSITION))
        return content_disposition_filename(params, 'filename')
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that forwards the content of a BodyPartReader."""

    def __init__(self, value: BodyPartReader,
                 *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)
        # Carry over whichever of name / filename the source part defines.
        candidates = (('name', value.name), ('filename', value.filename))
        params = {
            key: val for key, val in candidates if val is not None
        }  # type: Dict[str, str]
        if params:
            self.set_content_disposition('attachment', True, **params)

    async def write(self, writer: Any) -> None:
        """Copy the part to ``writer`` in decoded 64 KiB chunks."""
        field = self._value
        while True:
            chunk = await field.read_chunk(size=2**16)
            if not chunk:
                break
            await writer.write(field.decode(chunk))
class MultipartReader:
    """Multipart body reader."""

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(self, headers: Mapping[str, str],
                 content: StreamReader) -> None:
        """Create a reader over ``content``.

        headers: headers of the multipart entity; its Content-Type must
            carry the ``boundary`` parameter.
        content: stream holding the multipart body.
        """
        self.headers = headers
        # Delimiter line as it appears in the body: b'--' + boundary.
        self._boundary = ('--' + self._get_boundary()).encode()
        self._content = content
        self._last_part = None  # type: Optional[Union['MultipartReader', BodyPartReader]]  # noqa
        self._at_eof = False
        self._at_bof = True
        # Lines pushed back for re-reading; consumed from the end (stack).
        self._unread = []  # type: List[bytes]

    def __aiter__(self) -> 'MultipartReader':
        return self

    async def __anext__(
        self,
    ) -> Union['MultipartReader', BodyPartReader]:
        part = await self.next()
        if part is None:
            raise StopAsyncIteration  # NOQA
        return part

    @classmethod
    def from_response(
        cls,
        response: 'ClientResponse',
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~hyper_internal_service.client.ClientResponse` instance

        The reader is returned wrapped in ``response_wrapper_cls``.
        """
        obj = cls.response_wrapper_cls(response, cls(response.headers,
                                                     response.content))
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached or
        False otherwise.
        """
        return self._at_eof

    async def next(
        self,
    ) -> Optional[Union['MultipartReader', BodyPartReader]]:
        """Emits the next multipart body part.

        Returns None once the final boundary has been read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        # The previous part must be fully consumed before moving on.
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            return None
        self._last_part = await self.fetch_next_part()
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union['MultipartReader', BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: 'CIMultiDictProxy[str]',
    ) -> Union['MultipartReader', BodyPartReader]:
        """Dispatches the response by the `Content-Type` header, returning
        suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, '')
        mimetype = parse_mimetype(ctype)
        if mimetype.type == 'multipart':
            # Nested multipart: hand the same stream to a nested reader.
            if self.multipart_reader_cls is None:
                return type(self)(headers, self._content)
            return self.multipart_reader_cls(headers, self._content)
        else:
            return self.part_reader_cls(self._boundary, headers, self._content)

    def _get_boundary(self) -> str:
        """Return the ``boundary`` parameter of the Content-Type header.

        Raises ValueError when the parameter is missing or longer than
        70 characters.
        """
        mimetype = parse_mimetype(self.headers[CONTENT_TYPE])
        assert mimetype.type == 'multipart', (
            'multipart/* content type expected'
        )
        if 'boundary' not in mimetype.parameters:
            raise ValueError('boundary missed for Content-Type: %s'
                             % self.headers[CONTENT_TYPE])
        boundary = mimetype.parameters['boundary']
        if len(boundary) > 70:
            raise ValueError('boundary %r is too long (70 chars max)'
                             % boundary)
        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines take priority over the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skip lines until the first delimiter; raises ValueError when the
        # stream ends before one is found.
        while True:
            chunk = await self._readline()
            if chunk == b'':
                raise ValueError("Could not find starting boundary %r"
                                 % (self._boundary))
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            elif chunk == self._boundary + b'--':
                # Closing delimiter first: the body has no parts.
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        # Consume the delimiter line that follows a part; raises ValueError
        # when the line is neither the delimiter nor the closing delimiter.
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b'--':
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()
            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b'--':
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                # _unread is popped from the end, so epilogue is re-read
                # first.
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError('Invalid boundary %r, expected %r'
                             % (chunk, self._boundary))

    async def _read_headers(self) -> 'CIMultiDictProxy[str]':
        # Collect header lines up to and including the blank separator
        # line, then parse them.
        lines = [b'']
        while True:
            chunk = await self._content.readline()
            chunk = chunk.strip()
            lines.append(chunk)
            if not chunk:
                break
        parser = HeadersParser()
        headers, raw_headers = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over lines the part read ahead but did not consume.
            self._unread.extend(self._last_part._unread)
            self._last_part = None
_Part = Tuple[Payload, str, str]
class MultipartWriter(Payload):
    """Multipart body writer."""

    def __init__(self, subtype: str='mixed',
                 boundary: Optional[str]=None) -> None:
        """Create an empty multipart payload.

        subtype: the ``multipart/<subtype>`` used in the Content-Type.
        boundary: delimiter to use; a random hex string when None.

        Raises ValueError when the boundary is not ASCII or contains
        characters that cannot be represented in a quoted-string.
        """
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, require the boundary to be ASCII only.
        # In both situations.
        try:
            self._boundary = boundary.encode('ascii')
        except UnicodeEncodeError:
            raise ValueError('boundary should contain ASCII only chars') \
                from None
        ctype = ('multipart/{}; boundary={}'
                 .format(subtype, self._boundary_value))
        super().__init__(None, content_type=ctype)
        # (payload, content encoding, transfer encoding) per appended part.
        self._parts = []  # type: List[_Part]  # noqa

    def __enter__(self) -> 'MultipartWriter':
        return self

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # Truthy even with no parts; otherwise __len__ would make an empty
        # writer falsy.
        return True

    # Matches a value made only of token characters (no quoting needed).
    _valid_tchar_regex = re.compile(br"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    # Matches control characters that are not allowed in a quoted-string.
    _invalid_qdtext_char_regex = re.compile(br"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self._boundary and returns a unicode string.
        Raises ValueError when the boundary contains characters that are
        invalid even inside quotes.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode('ascii')  # cannot fail
        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")
        # escape %x5C and %x22
        quoted_value_content = value.replace(b'\\', b'\\\\')
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')
        return '"' + quoted_value_content.decode('ascii') + '"'

    @property
    def boundary(self) -> str:
        """The boundary as text, without quoting."""
        return self._boundary.decode('ascii')

    def append(
            self,
            obj: Any,
            headers: Optional[MultiMapping[str]]=None
    ) -> Payload:
        """Append ``obj`` as a new part and return its payload.

        A Payload instance is used as is (its headers are updated with
        ``headers``); anything else is converted through get_payload().
        Raises TypeError when no payload can be created for ``obj``.
        """
        if headers is None:
            headers = CIMultiDict()
        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        else:
            try:
                payload = get_payload(obj, headers=headers)
            except LookupError:
                raise TypeError('Cannot create payload from %r' % obj)
            else:
                return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer.

        Raises RuntimeError for an unsupported Content-Encoding or
        Content-Transfer-Encoding header on the payload.
        """
        # compression
        encoding = payload.headers.get(
            CONTENT_ENCODING,
            '',
        ).lower()  # type: Optional[str]
        if encoding and encoding not in ('deflate', 'gzip', 'identity'):
            raise RuntimeError('unknown content encoding: {}'.format(encoding))
        if encoding == 'identity':
            encoding = None
        # te encoding
        te_encoding = payload.headers.get(
            CONTENT_TRANSFER_ENCODING,
            '',
        ).lower()  # type: Optional[str]
        if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'):
            raise RuntimeError('unknown content transfer encoding: {}'
                               ''.format(te_encoding))
        if te_encoding == 'binary':
            te_encoding = None
        # size
        size = payload.size
        # Content-Length is only set when the part is written unmodified.
        if size is not None and not (encoding or te_encoding):
            payload.headers[CONTENT_LENGTH] = str(size)
        self._parts.append((payload, encoding, te_encoding))  # type: ignore
        return payload

    def append_json(
            self,
            obj: Any,
            headers: Optional[MultiMapping[str]]=None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()
        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
            self,
            obj: Union[Sequence[Tuple[str, str]],
                       Mapping[str, str]],
            headers: Optional[MultiMapping[str]]=None
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))
        if headers is None:
            headers = CIMultiDict()
        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)
        return self.append_payload(
            StringPayload(data, headers=headers,
                          content_type='application/x-www-form-urlencoded'))

    @property
    def size(self) -> Optional[int]:
        """Size of the payload.

        None when any part is encoded/compressed or has no known size.
        """
        total = 0
        for part, encoding, te_encoding in self._parts:
            if encoding or te_encoding or part.size is None:
                return None
            total += int(
                2 + len(self._boundary) + 2 +  # b'--'+self._boundary+b'\r\n'
                part.size + len(part._binary_headers) +
                2  # b'\r\n'
            )
        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    async def write(self, writer: Any,
                    close_boundary: bool=True) -> None:
        """Write body.

        close_boundary: when True, the closing delimiter is written after
            the last part.
        """
        for part, encoding, te_encoding in self._parts:
            await writer.write(b'--' + self._boundary + b'\r\n')
            await writer.write(part._binary_headers)
            if encoding or te_encoding:
                # Route the part through a wrapper that compresses and/or
                # transfer-encodes the data on the fly.
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore
                await w.write_eof()
            else:
                await part.write(writer)
            await writer.write(b'\r\n')
        if close_boundary:
            await writer.write(b'--' + self._boundary + b'--\r\n')
class MultipartPayloadWriter:
    """Writer wrapper that compresses and/or transfer-encodes part data.

    Data passed to write() is optionally compressed, then optionally
    encoded as base64 or quoted-printable, and finally forwarded to the
    wrapped writer.
    """

    def __init__(self, writer: Any) -> None:
        self._writer = writer
        self._encoding = None  # type: Optional[str]
        self._compress = None  # type: Any
        self._encoding_buffer = None  # type: Optional[bytearray]

    def enable_encoding(self, encoding: str) -> None:
        """Turn on base64 or quoted-printable; other values are ignored."""
        if encoding == 'quoted-printable':
            self._encoding = 'quoted-printable'
            return
        if encoding == 'base64':
            self._encoding = encoding
            # base64 works on 3-byte groups, so leftovers are buffered.
            self._encoding_buffer = bytearray()

    def enable_compression(self, encoding: str='deflate') -> None:
        """Turn on gzip (for 'gzip') or raw deflate (anything else)."""
        if encoding == 'gzip':
            wbits = 16 + zlib.MAX_WBITS
        else:
            wbits = -zlib.MAX_WBITS
        self._compress = zlib.compressobj(wbits=wbits)

    async def write_eof(self) -> None:
        """Flush the compressor and any buffered base64 remainder."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                # Disable compression so the flushed tail is not fed back
                # into the compressor by write().
                self._compress = None
                await self.write(tail)
        if self._encoding == 'base64' and self._encoding_buffer:
            await self._writer.write(
                base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Process ``chunk`` and forward the result to the wrapped writer."""
        if self._compress is not None:
            if chunk:
                chunk = self._compress.compress(chunk)
                if not chunk:
                    return
        mode = self._encoding
        if mode == 'base64':
            pending = self._encoding_buffer
            assert pending is not None
            pending.extend(chunk)
            if not pending:
                return
            # Encode only whole 3-byte groups; keep the rest for later.
            cut = len(pending) - len(pending) % 3
            ready = pending[:cut]
            self._encoding_buffer = pending[cut:]
            if ready:
                await self._writer.write(base64.b64encode(ready))
            return
        if mode == 'quoted-printable':
            chunk = binascii.b2a_qp(chunk)
        await self._writer.write(chunk)

View File

@ -0,0 +1,456 @@
import asyncio
import enum
import io
import json
import mimetypes
import os
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from typing import (
IO,
TYPE_CHECKING,
Any,
ByteString,
Dict,
Iterable,
Optional,
Text,
TextIO,
Tuple,
Type,
Union,
)
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
PY_36,
content_disposition_header,
guess_filename,
parse_mimetype,
sentinel,
)
from .streams import DEFAULT_LIMIT, StreamReader
from .typedefs import JSONEncoder, _CIMultiDict
__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
'BytesPayload', 'StringPayload',
'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
'TextIOPayload', 'StringIOPayload', 'JsonPayload',
'AsyncIterablePayload')
TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB
if TYPE_CHECKING: # pragma: no cover
from typing import List # noqa
class LookupError(Exception):
    """Raised by the payload registry when no factory matches the data.

    Note: within this module the name shadows the builtin ``LookupError``;
    it derives from ``Exception``, not from the builtin class.
    """
class Order(str, enum.Enum):
    """Priority group a payload factory is registered in."""

    # Consulted before the normal group.
    try_first = 'try_first'
    # Default group.
    normal = 'normal'
    # Consulted after the normal group.
    try_last = 'try_last'
def get_payload(data: Any, *args: Any, **kwargs: Any) -> 'Payload':
    """Build a Payload for ``data`` using the module-wide registry.

    Extra positional and keyword arguments are forwarded to the registry
    lookup.
    """
    registry = PAYLOAD_REGISTRY
    return registry.get(data, *args, **kwargs)
def register_payload(factory: Type['Payload'],
                     type: Any,
                     *,
                     order: Order=Order.normal) -> None:
    """Register ``factory`` for ``type`` in the module-wide registry."""
    registry = PAYLOAD_REGISTRY
    registry.register(factory, type, order=order)
class payload_type:
    """Class decorator registering the decorated factory for ``type``."""

    def __init__(self, type: Any, *, order: Order=Order.normal) -> None:
        self.type, self.order = type, order

    def __call__(self, factory: Type['Payload']) -> Type['Payload']:
        # Register, then hand the class back unchanged.
        register_payload(factory, self.type, order=self.order)
        return factory
class PayloadRegistry:
    """Registry mapping data types to Payload factories.

    Factories live in three groups that are searched in order: first,
    normal, last.

    note: we need zope.interface for more efficient adapter search
    """

    def __init__(self) -> None:
        self._first = []  # type: List[Tuple[Type[Payload], Any]]
        self._normal = []  # type: List[Tuple[Type[Payload], Any]]
        self._last = []  # type: List[Tuple[Type[Payload], Any]]

    def get(self,
            data: Any,
            *args: Any,
            _CHAIN: Any=chain,
            **kwargs: Any) -> 'Payload':
        """Return a Payload for ``data``.

        A Payload instance is returned unchanged.  Otherwise the first
        registered factory whose type matches ``data`` is called with
        ``data`` and the extra arguments.  Raises LookupError when nothing
        matches.
        """
        if isinstance(data, Payload):
            return data
        candidates = _CHAIN(self._first, self._normal, self._last)
        matched = next(
            (factory for factory, type in candidates
             if isinstance(data, type)),
            None)
        if matched is None:
            raise LookupError()
        return matched(data, *args, **kwargs)

    def register(self,
                 factory: Type['Payload'],
                 type: Any,
                 *,
                 order: Order=Order.normal) -> None:
        """Add ``factory`` for ``type`` to the group selected by ``order``.

        Raises ValueError for an unsupported ``order``.
        """
        groups = (
            (Order.try_first, self._first),
            (Order.normal, self._normal),
            (Order.try_last, self._last),
        )
        for member, bucket in groups:
            # Identity comparison: only real Order members are accepted.
            if order is member:
                bucket.append((factory, type))
                return
        raise ValueError("Unsupported order {!r}".format(order))
class Payload(ABC):
    """Abstract base for objects that can write themselves to a stream.

    A payload carries a value plus the headers (at least Content-Type)
    describing it; subclasses implement write().
    """

    # Content type used when none is given and none can be guessed.
    _default_content_type = 'application/octet-stream'  # type: str
    # Size in bytes when known; subclasses may set it.
    _size = None  # type: Optional[int]

    def __init__(self,
                 value: Any,
                 headers: Optional[
                     Union[
                         _CIMultiDict,
                         Dict[str, str],
                         Iterable[Tuple[str, str]]
                     ]
                 ] = None,
                 content_type: Optional[str]=sentinel,
                 filename: Optional[str]=None,
                 encoding: Optional[str]=None,
                 **kwargs: Any) -> None:
        """Store the value and build the payload headers.

        value: the object to be written.
        headers: extra headers; applied last, so they can override the
            Content-Type chosen here.
        content_type: explicit Content-Type; when omitted or None it is
            guessed from ``filename`` or falls back to the class default.
        filename: optional file name associated with the value.
        encoding: optional text encoding associated with the value.
        kwargs: accepted but not used by this base class.
        """
        self._encoding = encoding
        self._filename = filename
        self._headers = CIMultiDict()  # type: _CIMultiDict
        self._value = value
        if content_type is not sentinel and content_type is not None:
            self._headers[hdrs.CONTENT_TYPE] = content_type
        elif self._filename is not None:
            # Guess from the file name extension.
            content_type = mimetypes.guess_type(self._filename)[0]
            if content_type is None:
                content_type = self._default_content_type
            self._headers[hdrs.CONTENT_TYPE] = content_type
        else:
            self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
        self._headers.update(headers or {})

    @property
    def size(self) -> Optional[int]:
        """Size of the payload."""
        return self._size

    @property
    def filename(self) -> Optional[str]:
        """Filename of the payload."""
        return self._filename

    @property
    def headers(self) -> _CIMultiDict:
        """Custom item headers"""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        # Headers serialized as "Name: value" lines, UTF-8 encoded and
        # terminated by an empty line.
        return ''.join(
            [k + ': ' + v + '\r\n' for k, v in self.headers.items()]
        ).encode('utf-8') + b'\r\n'

    @property
    def encoding(self) -> Optional[str]:
        """Payload encoding"""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Content type"""
        return self._headers[hdrs.CONTENT_TYPE]

    def set_content_disposition(self,
                                disptype: str,
                                quote_fields: bool=True,
                                **params: Any) -> None:
        """Sets ``Content-Disposition`` header.

        disptype: disposition type.
        quote_fields: forwarded to content_disposition_header().
        params: disposition parameters such as ``name`` or ``filename``.
        """
        self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
            disptype, quote_fields=quote_fields, **params)

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write payload.

        writer is an AbstractStreamWriter instance:
        """
class BytesPayload(Payload):
    """Payload wrapping an in-memory bytes-like object."""

    def __init__(self,
                 value: ByteString,
                 *args: Any,
                 **kwargs: Any) -> None:
        """Validate ``value``, default the content type and record the size.

        Raises ``TypeError`` unless ``value`` is bytes, bytearray or
        memoryview. Emits a ``ResourceWarning`` when the body is larger than
        ``TOO_LARGE_BYTES_BODY``.
        """
        if not isinstance(value, (bytes, bytearray, memoryview)):
            raise TypeError("value argument must be byte-ish, not {!r}"
                            .format(type(value)))

        kwargs.setdefault('content_type', 'application/octet-stream')

        super().__init__(value, *args, **kwargs)

        self._size = len(value)

        if self._size > TOO_LARGE_BYTES_BODY:
            # The ``source`` keyword is only passed when PY_36 is true.
            warn_kwargs = {'source': self} if PY_36 else {}
            warnings.warn("Sending a large body directly with raw bytes might"
                          " lock the event loop. You should probably pass an "
                          "io.BytesIO object instead", ResourceWarning,
                          **warn_kwargs)

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write the whole stored value to ``writer`` in one call."""
        body = self._value
        await writer.write(body)
class StringPayload(BytesPayload):
    """Payload for a text string, encoded to bytes on construction."""

    def __init__(self,
                 value: Text,
                 *args: Any,
                 encoding: Optional[str]=None,
                 content_type: Optional[str]=None,
                 **kwargs: Any) -> None:
        """Resolve encoding and content type, then encode ``value``.

        An explicit ``encoding`` wins. Otherwise the charset parameter of
        ``content_type`` is used, and UTF-8 is the fallback. A missing
        ``content_type`` becomes ``text/plain`` with the chosen charset.
        """
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = 'text/plain; charset=%s' % encoding
        elif content_type is None:
            real_encoding = 'utf-8'
            content_type = 'text/plain; charset=utf-8'
        else:
            mimetype = parse_mimetype(content_type)
            real_encoding = mimetype.parameters.get('charset', 'utf-8')

        encoded = value.encode(real_encoding)
        super().__init__(
            encoded,
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
class StringIOPayload(StringPayload):
    """String payload built by reading a text stream at construction."""

    def __init__(self,
                 value: IO[str],
                 *args: Any,
                 **kwargs: Any) -> None:
        # The stream is read eagerly; only its text is kept.
        text = value.read()
        super().__init__(text, *args, **kwargs)
class IOBasePayload(Payload):
    """Payload that streams the contents of a file-like object."""

    def __init__(self,
                 value: IO[Any],
                 disposition: str='attachment',
                 *args: Any,
                 **kwargs: Any) -> None:
        """Wrap ``value``; set Content-Disposition when a filename is known.

        If no ``filename`` keyword is given it is guessed from ``value``. The
        disposition header is set only when a filename exists, ``disposition``
        is not None, and the header is not already present.
        """
        if 'filename' not in kwargs:
            kwargs['filename'] = guess_filename(value)

        super().__init__(value, *args, **kwargs)

        wants_disposition = (self._filename is not None and
                             disposition is not None)
        if wants_disposition and hdrs.CONTENT_DISPOSITION not in self.headers:
            self.set_content_disposition(
                disposition, filename=self._filename
            )

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Copy the stream to ``writer`` in ``DEFAULT_LIMIT``-sized reads.

        Reads and the final close run in the loop's default executor; the
        stream is closed even if writing fails.
        """
        loop = asyncio.get_event_loop()
        try:
            while True:
                chunk = await loop.run_in_executor(
                    None, self._value.read, DEFAULT_LIMIT
                )
                if not chunk:
                    break
                await writer.write(chunk)
        finally:
            await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
    """Payload that streams a text-mode file object, encoding each chunk."""

    def __init__(self,
                 value: TextIO,
                 *args: Any,
                 encoding: Optional[str]=None,
                 content_type: Optional[str]=None,
                 **kwargs: Any) -> None:
        """Resolve encoding and content type before delegating to the base.

        An explicit ``encoding`` wins. Otherwise the charset parameter of
        ``content_type`` is used, and UTF-8 is the fallback. A missing
        ``content_type`` becomes ``text/plain`` with the chosen charset.
        """
        if encoding is not None:
            if content_type is None:
                content_type = 'text/plain; charset=%s' % encoding
        elif content_type is None:
            encoding = 'utf-8'
            content_type = 'text/plain; charset=utf-8'
        else:
            mimetype = parse_mimetype(content_type)
            encoding = mimetype.parameters.get('charset', 'utf-8')

        super().__init__(
            value,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or None on ``OSError``."""
        try:
            stream = self._value
            return os.fstat(stream.fileno()).st_size - stream.tell()
        except OSError:
            return None

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Read text chunks in the executor, encode them and write them.

        The stream is closed (also in the executor) even if writing fails.
        """
        loop = asyncio.get_event_loop()
        try:
            while True:
                chunk = await loop.run_in_executor(
                    None, self._value.read, DEFAULT_LIMIT
                )
                if not chunk:
                    break
                await writer.write(chunk.encode(self._encoding))
        finally:
            await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
    """IO payload whose size is measured by seeking the stream."""

    @property
    def size(self) -> int:
        """Number of bytes between the current position and the end.

        Seeks to the end to measure, then restores the original position.
        """
        stream = self._value
        start = stream.tell()
        end = stream.seek(0, os.SEEK_END)
        stream.seek(start)
        return end - start
class BufferedReaderPayload(IOBasePayload):
    """IO payload whose size comes from ``os.fstat`` on the descriptor."""

    @property
    def size(self) -> Optional[int]:
        """File size minus the current position, or None on ``OSError``."""
        stream = self._value
        try:
            total = os.fstat(stream.fileno()).st_size
            return total - stream.tell()
        except OSError:
            # data.fileno() is not supported, e.g.
            # io.BufferedReader(io.BytesIO(b'data'))
            return None
class JsonPayload(BytesPayload):
    """Bytes payload holding a value serialized with ``dumps``."""

    def __init__(self,
                 value: Any,
                 encoding: str='utf-8',
                 content_type: str='application/json',
                 dumps: JSONEncoder=json.dumps,
                 *args: Any,
                 **kwargs: Any) -> None:
        # Serialize first, then encode the resulting text with ``encoding``.
        serialized = dumps(value)
        body = serialized.encode(encoding)
        super().__init__(
            body, *args,
            content_type=content_type, encoding=encoding, **kwargs)
if TYPE_CHECKING: # pragma: no cover
from typing import AsyncIterator, AsyncIterable
_AsyncIterator = AsyncIterator[bytes]
_AsyncIterable = AsyncIterable[bytes]
else:
from collections.abc import AsyncIterable, AsyncIterator
_AsyncIterator = AsyncIterator
_AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
    """Payload that streams chunks produced by an async iterable."""

    # Iterator obtained from the wrapped value; reset to None once exhausted.
    _iter = None  # type: Optional[_AsyncIterator]

    def __init__(self,
                 value: _AsyncIterable,
                 *args: Any,
                 **kwargs: Any) -> None:
        """Wrap ``value`` and obtain its async iterator.

        Raises ``TypeError`` if ``value`` is not an ``AsyncIterable``. The
        content type defaults to ``application/octet-stream``.
        """
        if not isinstance(value, AsyncIterable):
            # The message previously named a garbled "AsyncIterablebe".
            raise TypeError("value argument must support "
                            "collections.abc.AsyncIterable interface, "
                            "got {!r}".format(type(value)))

        if 'content_type' not in kwargs:
            kwargs['content_type'] = 'application/octet-stream'

        super().__init__(value, *args, **kwargs)

        self._iter = value.__aiter__()

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Write every chunk from the iterator to ``writer``.

        Once the iterator is exhausted it is dropped, so a second call
        writes nothing.
        """
        if self._iter:
            try:
                # The ``self._iter`` check above guards against the rare case
                # where the same iterable is used twice.
                while True:
                    chunk = await self._iter.__anext__()
                    await writer.write(chunk)
            except StopAsyncIteration:
                self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
    """Async-iterable payload fed by a ``StreamReader``'s ``iter_any()``."""

    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
        chunk_source = value.iter_any()
        super().__init__(chunk_source, *args, **kwargs)
# Module-level registry used by get_payload() and register_payload().
# Within a bucket, lookup takes the first isinstance() match in registration
# order, so the order of the calls below matters.
PAYLOAD_REGISTRY = PayloadRegistry()
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(
BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last for giving a chance to more specialized async iterables like
# multidict.BodyPartReaderPayload override the default
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable,
order=Order.try_last)

View File

@ -0,0 +1,74 @@
""" Payload implementation for coroutines as data provider.
As a simple case, you can upload data from file::
@hyper_internal_service.streamer
async def file_sender(writer, file_name=None):
with open(file_name, 'rb') as f:
chunk = f.read(2**16)
while chunk:
await writer.write(chunk)
chunk = f.read(2**16)
Then you can use `file_sender` like this:
async with session.post('http://httpbin.org/post',
data=file_sender(file_name='huge_file')) as resp:
print(await resp.text())
.. note:: Coroutine must accept `writer` as first argument
"""
import types
import warnings
from typing import Any, Awaitable, Callable, Dict, Tuple
from .abc import AbstractStreamWriter
from .payload import Payload, payload_type
__all__ = ('streamer',)
class _stream_wrapper:
    """Binds a streaming callable to the arguments it will be called with.

    Calling the instance with a writer awaits the wrapped callable as
    ``coro(writer, *args, **kwargs)``.
    """

    def __init__(self,
                 coro: Callable[..., Awaitable[None]],
                 args: Tuple[Any, ...],
                 kwargs: Dict[str, Any]) -> None:
        self.kwargs = kwargs
        self.args = args
        # Passed through types.coroutine() before being stored.
        self.coro = types.coroutine(coro)

    async def __call__(self, writer: AbstractStreamWriter) -> None:
        bound_args = self.args
        bound_kwargs = self.kwargs
        await self.coro(writer, *bound_args, **bound_kwargs)  # type: ignore
class streamer:
    """Deprecated decorator turning a coroutine function into a data source.

    Constructing it emits a ``DeprecationWarning``. Calling the decorated
    object returns a ``_stream_wrapper`` holding the call arguments.
    """

    def __init__(self, coro: Callable[..., Awaitable[None]]) -> None:
        message = "@streamer is deprecated, use async generators instead"
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        self.coro = coro

    def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper:
        wrapper = _stream_wrapper(self.coro, args, kwargs)
        return wrapper
@payload_type(_stream_wrapper)
class StreamWrapperPayload(Payload):
    """Payload whose value is a callable awaited with the writer."""

    async def write(self, writer: AbstractStreamWriter) -> None:
        # The stored value does the writing itself.
        produce = self._value
        await produce(writer)
@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):
    """Payload for a bare ``streamer`` object.

    The streamer is called with no arguments at construction time and its
    result is stored as the payload value.
    """

    def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None:
        wrapped = value()
        super().__init__(wrapped, *args, **kwargs)

    async def write(self, writer: AbstractStreamWriter) -> None:
        produce = self._value
        await produce(writer)

View File

@ -0,0 +1 @@
Marker

View File

@ -0,0 +1,356 @@
import asyncio
import contextlib
import warnings
from collections.abc import Callable
import pytest
from hyper_internal_service.helpers import PY_37, isasyncgenfunction
from hyper_internal_service.web import Application
from .test_utils import (
BaseTestServer,
RawTestServer,
TestClient,
TestServer,
loop_context,
setup_test_loop,
teardown_test_loop,
)
from .test_utils import unused_port as _unused_port
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None
try:
import tokio
except ImportError: # pragma: no cover
tokio = None
def pytest_addoption(parser):  # type: ignore
    """Register the plugin's command-line options on ``parser``."""
    option_specs = (
        ('--hyper_internal_service-fast',
         dict(action='store_true', default=False,
              help='run tests faster by disabling extra checks')),
        ('--hyper_internal_service-loop',
         dict(action='store', default='pyloop',
              help='run tests with specific loop: pyloop, uvloop, tokio or all')),
        ('--hyper_internal_service-enable-loop-debug',
         dict(action='store_true', default=False,
              help='enable event loop debug mode')),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def pytest_fixture_setup(fixturedef):  # type: ignore
    """
    Let fixtures be coroutine functions or async generators.

    Such fixtures are replaced by a synchronous wrapper that drives them on
    the event loop provided by the 'loop' fixture. Other fixtures are left
    untouched.
    """
    fixture_func = fixturedef.func

    is_async_gen = isasyncgenfunction(fixture_func)
    if not is_async_gen and not asyncio.iscoroutinefunction(fixture_func):
        # not an async fixture, nothing to do
        return

    # The wrapper needs 'request'; if the fixture did not ask for it, add it
    # to the arguments and strip it again before calling the fixture.
    strip_request = 'request' not in fixturedef.argnames
    if strip_request:
        fixturedef.argnames += ('request',)

    def wrapper(*args, **kwargs):  # type: ignore
        request = kwargs['request']
        if strip_request:
            del kwargs['request']

        # if neither the fixture nor the test use the 'loop' fixture,
        # 'getfixturevalue' will fail because the test is not parameterized
        # (this can be removed someday if 'loop' is no longer parameterized)
        if 'loop' not in request.fixturenames:
            raise Exception(
                "Asynchronous fixtures must depend on the 'loop' fixture or "
                "be used in tests depending from it."
            )

        _loop = request.getfixturevalue('loop')

        if not is_async_gen:
            # regular async fixture: just run it to completion
            return _loop.run_until_complete(fixture_func(*args, **kwargs))

        # async generator fixture: advance once for the value now, and
        # once more in a finalizer to run its teardown part
        agen = fixture_func(*args, **kwargs)

        def finalizer():  # type: ignore
            try:
                return _loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:  # NOQA
                pass

        request.addfinalizer(finalizer)
        return _loop.run_until_complete(agen.__anext__())

    fixturedef.func = wrapper
@pytest.fixture
def fast(request):  # type: ignore
    """Value of the --hyper_internal_service-fast command-line option."""
    config = request.config
    return config.getoption('--hyper_internal_service-fast')
@pytest.fixture
def loop_debug(request):  # type: ignore
    """Value of the --hyper_internal_service-enable-loop-debug option."""
    config = request.config
    return config.getoption('--hyper_internal_service-enable-loop-debug')
@contextlib.contextmanager
def _runtime_warning_context(): # type: ignore
"""
Context manager which checks for RuntimeWarnings, specifically to
avoid "coroutine 'X' was never awaited" warnings being missed.
If RuntimeWarnings occur in the context a RuntimeError is raised.
"""
with warnings.catch_warnings(record=True) as _warnings:
yield
rw = ['{w.filename}:{w.lineno}:{w.message}'.format(w=w)
for w in _warnings
if w.category == RuntimeWarning]
if rw:
raise RuntimeError('{} Runtime Warning{},\n{}'.format(
len(rw),
'' if len(rw) == 1 else 's',
'\n'.join(rw)
))
@contextlib.contextmanager
def _passthrough_loop_context(loop, fast=False):  # type: ignore
    """
    Yield ``loop`` if one is passed in; otherwise set up a test loop, yield
    it and tear it down afterwards.

    A loop supplied by the caller is passed straight through and is never
    torn down here. A loop created here is torn down even when the managed
    block raises, so a failing test does not leak it.
    """
    if loop:
        # loop already exists, pass it straight through
        yield loop
    else:
        # this shadows loop_context's standard behavior
        loop = setup_test_loop()
        try:
            yield loop
        finally:
            # run teardown on the error path too
            teardown_test_loop(loop, fast=fast)
def pytest_pycollect_makeitem(collector, name, obj):  # type: ignore
    """
    Collect coroutine functions as test items.

    Returns the items generated by the collector when ``name`` passes its
    function-name filter and ``obj`` is a coroutine function, else None.
    """
    if not collector.funcnamefilter(name):
        return None
    if not asyncio.iscoroutinefunction(obj):
        return None
    generated = collector._genfunctions(name, obj)
    return list(generated)
def pytest_pyfunc_call(pyfuncitem):  # type: ignore
    """
    Run coroutine test functions on an event loop.

    Uses the 'proactor_loop' or 'loop' fixture value when the test has one,
    otherwise a temporary loop. Returns True after running a coroutine test
    and None for other functions.
    """
    fast = pyfuncitem.config.getoption("--hyper_internal_service-fast")
    if not asyncio.iscoroutinefunction(pyfuncitem.function):
        return None

    existing_loop = (pyfuncitem.funcargs.get('proactor_loop') or
                     pyfuncitem.funcargs.get('loop', None))
    with _runtime_warning_context():
        with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
            # Pass only the arguments the test function declares.
            testargs = {}
            for arg in pyfuncitem._fixtureinfo.argnames:
                testargs[arg] = pyfuncitem.funcargs[arg]
            _loop.run_until_complete(pyfuncitem.obj(**testargs))

    return True
def pytest_generate_tests(metafunc):  # type: ignore
    """
    Parametrize 'loop_factory' with the selected event-loop policies.

    The selection comes from the --hyper_internal_service-loop option: a
    comma-separated list of loop names, or 'all'. A name with a trailing
    '?' is optional and is skipped when that loop is not available.

    Raises ValueError when a required loop name is not available.
    """
    if 'loop_factory' not in metafunc.fixturenames:
        return

    loops = metafunc.config.option.hyper_internal_service_loop
    avail_factories = {'pyloop': asyncio.DefaultEventLoopPolicy}

    if uvloop is not None:  # pragma: no cover
        avail_factories['uvloop'] = uvloop.EventLoopPolicy

    if tokio is not None:  # pragma: no cover
        avail_factories['tokio'] = tokio.EventLoopPolicy

    if loops == 'all':
        loops = 'pyloop,uvloop?,tokio?'

    factories = {}  # type: ignore
    for spec in loops.split(','):
        # Trim whitespace first so "uvloop? " still counts as optional.
        spec = spec.strip()
        required = not spec.endswith('?')
        name = spec.strip(' ?')
        if name not in avail_factories:  # pragma: no cover
            if required:
                # List the loops that are actually available, not just the
                # ones selected so far.
                raise ValueError(
                    "Unknown loop '%s', available loops: %s" % (
                        name, list(avail_factories.keys())))
            else:
                continue
        factories[name] = avail_factories[name]
    metafunc.parametrize("loop_factory",
                         list(factories.values()),
                         ids=list(factories.keys()))
@pytest.fixture
def loop(loop_factory, fast, loop_debug):  # type: ignore
    """Install the policy from ``loop_factory`` and yield a test loop.

    The loop comes from ``loop_context`` and is set as the current event
    loop. Debug mode is enabled when ``loop_debug`` is true.
    """
    asyncio.set_event_loop_policy(loop_factory())
    with loop_context(fast=fast) as event_loop:
        if loop_debug:
            event_loop.set_debug(True)  # pragma: no cover
        asyncio.set_event_loop(event_loop)
        yield event_loop
@pytest.fixture
def proactor_loop():  # type: ignore
    """Yield a test loop created from a proactor event-loop policy."""
    if PY_37:
        policy = asyncio.WindowsProactorEventLoopPolicy()  # type: ignore
        asyncio.set_event_loop_policy(policy)
    else:
        # Otherwise patch the current policy's loop factory in place.
        policy = asyncio.get_event_loop_policy()
        policy._loop_factory = asyncio.ProactorEventLoop  # type: ignore

    with loop_context(policy.new_event_loop) as event_loop:
        asyncio.set_event_loop(event_loop)
        yield event_loop
@pytest.fixture
def unused_port(hyper_internal_service_unused_port):  # type: ignore # pragma: no cover
    """Deprecated alias of the hyper_internal_service_unused_port fixture."""
    message = "Deprecated, use hyper_internal_service_unused_port fixture instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return hyper_internal_service_unused_port
@pytest.fixture
def hyper_internal_service_unused_port():  # type: ignore
    """Return the ``unused_port`` helper imported from test_utils.

    The fixture value is the helper itself, not the result of calling it.
    """
    port_finder = _unused_port
    return port_finder
@pytest.fixture
def hyper_internal_service_server(loop):  # type: ignore
    """Factory fixture creating started TestServer instances for an app.

    hyper_internal_service_server(app, **kwargs)

    Every server created through the factory is closed at teardown.
    """
    started = []

    async def go(app, *, port=None, **kwargs):  # type: ignore
        instance = TestServer(app, port=port)
        await instance.start_server(loop=loop, **kwargs)
        started.append(instance)
        return instance

    async def finalize():  # type: ignore
        # Close in reverse order of creation.
        while started:
            await started.pop().close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def test_server(hyper_internal_service_server):  # type: ignore # pragma: no cover
    """Deprecated alias of the hyper_internal_service_server fixture."""
    message = "Deprecated, use hyper_internal_service_server fixture instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return hyper_internal_service_server
@pytest.fixture
def hyper_internal_service_raw_server(loop):  # type: ignore
    """Factory fixture creating started RawTestServer instances.

    hyper_internal_service_raw_server(handler, **kwargs)

    Every server created through the factory is closed at teardown.
    """
    started = []

    async def go(handler, *, port=None, **kwargs):  # type: ignore
        instance = RawTestServer(handler, port=port)
        await instance.start_server(loop=loop, **kwargs)
        started.append(instance)
        return instance

    async def finalize():  # type: ignore
        # Close in reverse order of creation.
        while started:
            await started.pop().close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def raw_test_server(hyper_internal_service_raw_server):  # type: ignore # pragma: no cover
    """Deprecated alias of the hyper_internal_service_raw_server fixture."""
    message = "Deprecated, use hyper_internal_service_raw_server fixture instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return hyper_internal_service_raw_server
@pytest.fixture
def hyper_internal_service_client(loop):  # type: ignore
    """Factory fixture creating started TestClient instances.

    hyper_internal_service_client(app, **kwargs)
    hyper_internal_service_client(server, **kwargs)
    hyper_internal_service_client(raw_server, **kwargs)

    A callable that is neither an Application nor a test server is first
    called as ``factory(loop, *args, **kwargs)`` and its result is used.
    Every client created through the factory is closed at teardown.
    """
    created = []

    async def go(__param, *args, server_kwargs=None, **kwargs):  # type: ignore

        is_factory = (
            isinstance(__param, Callable) and  # type: ignore
            not isinstance(__param, (Application, BaseTestServer)))
        if is_factory:
            # The factory consumes the extra arguments; none are left for
            # the client.
            __param = __param(loop, *args, **kwargs)
            kwargs = {}
        else:
            assert not args, "args should be empty"

        if isinstance(__param, Application):
            server_kwargs = server_kwargs or {}
            server = TestServer(__param, loop=loop, **server_kwargs)
            client = TestClient(server, loop=loop, **kwargs)
        elif isinstance(__param, BaseTestServer):
            client = TestClient(__param, loop=loop, **kwargs)
        else:
            raise ValueError("Unknown argument type: %r" % type(__param))

        await client.start_server()
        created.append(client)
        return client

    async def finalize():  # type: ignore
        # Close in reverse order of creation.
        while created:
            await created.pop().close()

    yield go

    loop.run_until_complete(finalize())
@pytest.fixture
def test_client(hyper_internal_service_client):  # type: ignore # pragma: no cover
    """Deprecated alias of the hyper_internal_service_client fixture."""
    message = "Deprecated, use hyper_internal_service_client fixture instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return hyper_internal_service_client

View File

@ -0,0 +1,112 @@
import asyncio
import socket
from typing import Any, Dict, List, Optional
from .abc import AbstractResolver
from .helpers import get_running_loop
__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
try:
import aiodns
# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
except ImportError: # pragma: no cover
aiodns = None
aiodns_default = False
class ThreadedResolver(AbstractResolver):
    """Use Executor for synchronous getaddrinfo() calls, which defaults to
    concurrent.futures.ThreadPoolExecutor.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
        self._loop = get_running_loop(loop)

    async def resolve(self, host: str, port: int=0,
                      family: int=socket.AF_INET) -> List[Dict[str, Any]]:
        """Resolve ``host`` through the loop's ``getaddrinfo``.

        Returns one dict per result with the keys hostname, host, port,
        family, proto and flags.
        """
        infos = await self._loop.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, family=family)

        # Each entry unpacks as (family, _, proto, _, address); host and
        # port are the first two items of ``address``.
        return [
            {'hostname': host,
             'host': address[0], 'port': address[1],
             'family': info_family, 'proto': proto,
             'flags': socket.AI_NUMERICHOST}
            for info_family, _, proto, _, address in infos
        ]

    async def close(self) -> None:
        """Nothing to release."""
        return None
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups"""

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None,
                 *args: Any, **kwargs: Any) -> None:
        """Create the underlying ``aiodns.DNSResolver``.

        Extra positional and keyword arguments are forwarded to
        ``aiodns.DNSResolver``. Raises ``RuntimeError`` when aiodns could not
        be imported.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = get_running_loop(loop)
        self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)

        if not hasattr(self._resolver, 'gethostbyname'):
            # aiodns 1.1 is not available, fallback to DNSResolver.query
            self.resolve = self._resolve_with_query  # type: ignore

    @staticmethod
    def _dns_error_message(exc: BaseException) -> Any:
        """Pick the message out of a DNS error's ``args``.

        Uses the second element of ``args`` when it exists. The previous
        guard, ``len(args) >= 1``, indexed ``args[1]`` on a one-element tuple
        and raised ``IndexError`` instead of the intended ``OSError``.
        """
        return exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"

    async def resolve(self, host: str, port: int=0,
                      family: int=socket.AF_INET) -> List[Dict[str, Any]]:
        """Resolve ``host`` with ``gethostbyname``.

        Returns one dict per address with the keys hostname, host, port,
        family, proto and flags. Raises ``OSError`` when the lookup fails or
        yields no addresses.
        """
        try:
            resp = await self._resolver.gethostbyname(host, family)
        except aiodns.error.DNSError as exc:
            raise OSError(self._dns_error_message(exc)) from exc
        hosts = []
        for address in resp.addresses:
            hosts.append(
                {'hostname': host,
                 'host': address, 'port': port,
                 'family': family, 'proto': 0,
                 'flags': socket.AI_NUMERICHOST})

        if not hosts:
            raise OSError("DNS lookup failed")

        return hosts

    async def _resolve_with_query(
            self, host: str, port: int=0,
            family: int=socket.AF_INET) -> List[Dict[str, Any]]:
        """Fallback resolver using ``DNSResolver.query``.

        Queries 'AAAA' records for ``AF_INET6`` and 'A' records otherwise.
        Same return shape and errors as ``resolve``.
        """
        if family == socket.AF_INET6:
            qtype = 'AAAA'
        else:
            qtype = 'A'

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            raise OSError(self._dns_error_message(exc)) from exc

        hosts = []
        for rr in resp:
            hosts.append(
                {'hostname': host,
                 'host': rr.host, 'port': port,
                 'family': family, 'proto': 0,
                 'flags': socket.AI_NUMERICHOST})

        if not hosts:
            raise OSError("DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Cancel the underlying resolver."""
        return self._resolver.cancel()
# Resolver class used by default: AsyncResolver when aiodns_default is true,
# otherwise ThreadedResolver.
DefaultResolver = AsyncResolver if aiodns_default else ThreadedResolver

View File

@ -0,0 +1,34 @@
from hyper_internal_service.frozenlist import FrozenList
__all__ = ('Signal',)
class Signal(FrozenList):
    """Coroutine-based signal implementation.

    Receivers are connected with the ordinary list methods. The signal is
    fired with the send() coroutine, which awaits each receiver in turn with
    the arguments it was given.
    """

    __slots__ = ('_owner',)

    def __init__(self, owner):
        super().__init__()
        self._owner = owner

    def __repr__(self):
        receivers = list(self)
        template = '<Signal owner={}, frozen={}, {!r}>'
        return template.format(self._owner, self.frozen, receivers)

    async def send(self, *args, **kwargs):
        """
        Await every registered receiver with the given arguments.

        Raises RuntimeError if the signal is not frozen.
        """
        if not self.frozen:
            raise RuntimeError("Cannot send non-frozen signal.")

        for callback in self:
            await callback(*args, **kwargs)  # type: ignore

View File

@ -0,0 +1,17 @@
from typing import Any, Generic, TypeVar
from hyper_internal_service.frozenlist import FrozenList
__all__ = ('Signal',)
_T = TypeVar('_T')
# Type stub for Signal: a FrozenList generic over its item type _T.
class Signal(FrozenList[_T], Generic[_T]):
# owner is typed as Any.
def __init__(self, owner: Any) -> None: ...
def __repr__(self) -> str: ...
# send() accepts arbitrary positional and keyword arguments.
async def send(self, *args: Any, **kwargs: Any) -> None: ...

View File

@ -0,0 +1,634 @@
import asyncio
import collections
import warnings
from typing import List # noqa
from typing import Awaitable, Callable, Generic, Optional, Tuple, TypeVar
from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from .log import internal_logger
try: # pragma: no cover
from typing import Deque # noqa
except ImportError:
from typing_extensions import Deque # noqa
# Public names exported by this module.
__all__ = (
'EMPTY_PAYLOAD', 'EofStream', 'StreamReader', 'DataQueue',
'FlowControlDataQueue')
# Default value of StreamReader's ``limit`` argument (64 KiB); the reader uses
# it as the low-water mark and twice that as the high-water mark.
DEFAULT_LIMIT = 2 ** 16
# Type variable for the item type produced by AsyncStreamIterator.
_T = TypeVar('_T')
class EofStream(Exception):
    """Signals that the end of a stream has been reached."""
class AsyncStreamIterator(Generic[_T]):
    """Async iterator that repeatedly awaits a zero-argument read function.

    Iteration stops when the function raises ``EofStream`` or returns
    ``b''``.
    """

    def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
        self.read_func = read_func

    def __aiter__(self) -> 'AsyncStreamIterator[_T]':
        return self

    async def __anext__(self) -> _T:
        try:
            item = await self.read_func()
        except EofStream:
            raise StopAsyncIteration  # NOQA
        # An empty bytes result also marks the end of the data.
        if item == b'':
            raise StopAsyncIteration  # NOQA
        return item
class ChunkTupleAsyncStreamIterator:
    """Async iterator over the tuples returned by ``stream.readchunk()``.

    Iteration stops when ``readchunk()`` returns ``(b'', False)``.
    """

    def __init__(self, stream: 'StreamReader') -> None:
        self._stream = stream

    def __aiter__(self) -> 'ChunkTupleAsyncStreamIterator':
        return self

    async def __anext__(self) -> Tuple[bytes, bool]:
        chunk_info = await self._stream.readchunk()
        if chunk_info == (b'', False):
            raise StopAsyncIteration  # NOQA
        return chunk_info
class AsyncStreamReaderMixin:
    """Mixin adding asynchronous-iteration helpers to a stream reader.

    It relies on the host class providing ``readline``, ``read``,
    ``readany`` and ``readchunk``.
    """

    def __aiter__(self) -> AsyncStreamIterator[bytes]:
        # Plain ``async for`` iterates line by line.
        line_reader = self.readline  # type: ignore
        return AsyncStreamIterator(line_reader)

    def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
        """Returns an asynchronous iterator whose items come from
        ``read(n)``.
        """
        return AsyncStreamIterator(lambda: self.read(n))  # type: ignore

    def iter_any(self) -> AsyncStreamIterator[bytes]:
        """Returns an asynchronous iterator whose items come from
        ``readany()``.
        """
        any_reader = self.readany  # type: ignore
        return AsyncStreamIterator(any_reader)

    def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
        """Returns an asynchronous iterator over the tuples returned by
        ``readchunk()``.
        """
        return ChunkTupleAsyncStreamIterator(self)  # type: ignore
class StreamReader(AsyncStreamReaderMixin):
"""An enhancement of asyncio.StreamReader.
Supports asynchronous iteration by line, chunk or as available::
async for line in reader:
...
async for chunk in reader.iter_chunked(1024):
...
async for slice in reader.iter_any():
...
"""
total_bytes = 0
def __init__(self, protocol: BaseProtocol,
*, limit: int=DEFAULT_LIMIT,
timer: Optional[BaseTimerContext]=None,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._protocol = protocol
self._low_water = limit
self._high_water = limit * 2
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._size = 0
self._cursor = 0
self._http_chunk_splits = None # type: Optional[List[int]]
self._buffer = collections.deque() # type: Deque[bytes]
self._buffer_offset = 0
self._eof = False
self._waiter = None # type: Optional[asyncio.Future[None]]
self._eof_waiter = None # type: Optional[asyncio.Future[None]]
self._exception = None # type: Optional[BaseException]
self._timer = timer
self._eof_callbacks = [] # type: List[Callable[[], None]]
def __repr__(self) -> str:
info = [self.__class__.__name__]
if self._size:
info.append('%d bytes' % self._size)
if self._eof:
info.append('eof')
if self._low_water != DEFAULT_LIMIT:
info.append('low=%d high=%d' % (self._low_water, self._high_water))
if self._waiter:
info.append('w=%r' % self._waiter)
if self._exception:
info.append('e=%r' % self._exception)
return '<%s>' % ' '.join(info)
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(self, exc: BaseException) -> None:
self._exception = exc
self._eof_callbacks.clear()
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_exception(waiter, exc)
def on_eof(self, callback: Callable[[], None]) -> None:
if self._eof:
try:
callback()
except Exception:
internal_logger.exception('Exception in eof callback')
else:
self._eof_callbacks.append(callback)
def feed_eof(self) -> None:
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_result(waiter, None)
for cb in self._eof_callbacks:
try:
cb()
except Exception:
internal_logger.exception('Exception in eof callback')
self._eof_callbacks.clear()
def is_eof(self) -> bool:
"""Return True if 'feed_eof' was called."""
return self._eof
def at_eof(self) -> bool:
"""Return True if the buffer is empty and 'feed_eof' was called."""
return self._eof and not self._buffer
async def wait_eof(self) -> None:
if self._eof:
return
assert self._eof_waiter is None
self._eof_waiter = self._loop.create_future()
try:
await self._eof_waiter
finally:
self._eof_waiter = None
def unread_data(self, data: bytes) -> None:
""" rollback reading some data from stream, inserting it to buffer head.
"""
warnings.warn("unread_data() is deprecated "
"and will be removed in future releases (#3260)",
DeprecationWarning,
stacklevel=2)
if not data:
return
if self._buffer_offset:
self._buffer[0] = self._buffer[0][self._buffer_offset:]
self._buffer_offset = 0
self._size += len(data)
self._cursor -= len(data)
self._buffer.appendleft(data)
self._eof_counter = 0
# TODO: size is ignored, remove the param later
def feed_data(self, data: bytes, size: int=0) -> None:
assert not self._eof, 'feed_data after feed_eof'
if not data:
return
self._size += len(data)
self._buffer.append(data)
self.total_bytes += len(data)
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if (self._size > self._high_water and
not self._protocol._reading_paused):
self._protocol.pause_reading()
def begin_http_chunk_receiving(self) -> None:
    """Start tracking HTTP chunk boundaries.

    On the first call the list of chunk split offsets is initialised. Later
    calls, once the list exists, do nothing.

    Raises RuntimeError if data was already fed before chunk tracking
    started.
    """
    if self._http_chunk_splits is None:
        if self.total_bytes:
            # The two literals previously joined as "...whensome data...".
            raise RuntimeError("Called begin_http_chunk_receiving when "
                               "some data was already fed")
        self._http_chunk_splits = []
def end_http_chunk_receiving(self) -> None:
if self._http_chunk_splits is None:
raise RuntimeError("Called end_chunk_receiving without calling "
"begin_chunk_receiving first")
# self._http_chunk_splits contains logical byte offsets from start of
# the body transfer. Each offset is the offset of the end of a chunk.
# "Logical" means bytes, accessible for a user.
# If no chunks containing logical data were received, current position
# is definitely zero.
pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
if self.total_bytes == pos:
# We should not add empty chunks here. So we check for that.
# Note, when chunked + gzip is used, we can receive a chunk
# of compressed data, but that data may not be enough for gzip FSM
# to yield any uncompressed data. That's why current position may
# not change after receiving a chunk.
return
self._http_chunk_splits.append(self.total_bytes)
# wake up readchunk when end of http chunk received
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
async def _wait(self, func_name: str) -> None:
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not possible to know
# which coroutine would get the next data.
if self._waiter is not None:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
waiter = self._waiter = self._loop.create_future()
try:
if self._timer:
with self._timer:
await waiter
else:
await waiter
finally:
self._waiter = None
async def readline(self) -> bytes:
if self._exception is not None:
raise self._exception
line = []
line_size = 0
not_enough = True
while not_enough:
while self._buffer and not_enough:
offset = self._buffer_offset
ichar = self._buffer[0].find(b'\n', offset) + 1
# Read from current offset to found b'\n' or to the end.
data = self._read_nowait_chunk(ichar - offset if ichar else -1)
line.append(data)
line_size += len(data)
if ichar:
not_enough = False
if line_size > self._high_water:
raise ValueError('Line is too long')
if self._eof:
break
if not_enough:
await self._wait('readline')
return b''.join(line)
async def read(self, n: int=-1) -> bytes:
    """Read up to ``n`` bytes, or everything until EOF when ``n`` < 0.

    :param n: maximum number of bytes; ``0`` returns ``b''`` at once and
        a negative value drains the stream through ``readany()``.
    :raises BaseException: the stored stream exception, if any.
    """
    if self._exception is not None:
        raise self._exception

    # Migration aid: code written for DataQueue tends to call read() in
    # an endless loop and rely on EofStream to exit.  This reader never
    # raises that, so warn on repeated reads in the EOF state.
    if __debug__:
        if self._eof and not self._buffer:
            eof_reads = getattr(self, '_eof_counter', 0) + 1
            self._eof_counter = eof_reads
            if eof_reads > 5:
                internal_logger.warning(
                    'Multiple access to StreamReader in eof state, '
                    'might be infinite loop.', stack_info=True)

    if not n:
        return b''

    if n < 0:
        # Collect blocks via readany() until it returns an empty one,
        # rather than waiting for the whole body to be buffered.
        collected = []
        block = await self.readany()
        while block:
            collected.append(block)
            block = await self.readany()
        return b''.join(collected)

    # A loop, not a single wait: the waiter may also be woken at the end
    # of an HTTP chunk, without any data having been fed.
    while not (self._buffer or self._eof):
        await self._wait('read')

    return self._read_nowait(n)
async def readany(self) -> bytes:
    """Return everything currently buffered, waiting only if it is empty.

    Returns whatever ``_read_nowait(-1)`` yields once the buffer is
    non-empty or EOF has been reached.

    :raises BaseException: the stored stream exception, if any.
    """
    if self._exception is not None:
        raise self._exception

    # A loop, not a single wait: the waiter may also be woken at the end
    # of an HTTP chunk, without any data having been fed.
    while not (self._buffer or self._eof):
        await self._wait('readany')

    return self._read_nowait(-1)
async def readchunk(self) -> Tuple[bytes, bool]:
    """Returns a tuple of (data, end_of_http_chunk).

    With chunked transfer encoding, end_of_http_chunk tells whether the
    end of the returned data coincides with the end of an HTTP chunk;
    otherwise it is always False.  At EOF the result is ``(b'', False)``.
    """
    while True:
        if self._exception is not None:
            raise self._exception

        while self._http_chunk_splits:
            boundary = self._http_chunk_splits.pop(0)
            distance = boundary - self._cursor
            if distance == 0:
                # The cursor sits exactly on a chunk end.
                return (b"", True)
            if distance > 0:
                # Hand out the bytes up to the chunk end.
                return (self._read_nowait(distance), True)
            internal_logger.warning('Skipping HTTP chunk end due to data '
                                    'consumption beyond chunk boundary')

        if self._buffer:
            return (self._read_nowait_chunk(-1), False)

        if self._eof:
            # Special case for signifying EOF.
            # (b'', True) is not a final return value actually.
            return (b'', False)

        await self._wait('readchunk')
async def readexactly(self, n: int) -> bytes:
    """Read exactly ``n`` bytes.

    :param n: number of bytes wanted; values <= 0 yield ``b''``.
    :raises asyncio.IncompleteReadError: if ``read()`` returns an empty
        block before ``n`` bytes were collected; it carries the partial
        data and the originally expected size.
    :raises BaseException: the stored stream exception, if any.
    """
    if self._exception is not None:
        raise self._exception

    received = []  # type: List[bytes]
    remaining = n
    while remaining > 0:
        piece = await self.read(remaining)
        if not piece:
            partial = b''.join(received)
            # len(partial) + remaining equals the requested size.
            raise asyncio.IncompleteReadError(
                partial, len(partial) + remaining)
        received.append(piece)
        remaining -= len(piece)

    return b''.join(received)
def read_nowait(self, n: int=-1) -> bytes:
    """Return buffered data without waiting.

    The default of ``-1`` matches ``.read(-1)``.

    :param n: maximum number of bytes, or ``-1`` for the whole buffer.
    :raises RuntimeError: if a coroutine is currently waiting for data.
    :raises BaseException: the stored stream exception, if any.
    """
    if self._exception is not None:
        raise self._exception

    waiter = self._waiter
    if waiter and not waiter.done():
        raise RuntimeError(
            'Called while some coroutine is waiting for incoming data.')

    return self._read_nowait(n)
def _read_nowait_chunk(self, n: int) -> bytes:
first_buffer = self._buffer[0]
offset = self._buffer_offset
if n != -1 and len(first_buffer) - offset > n:
data = first_buffer[offset:offset + n]
self._buffer_offset += n
elif offset:
self._buffer.popleft()
data = first_buffer[offset:]
self._buffer_offset = 0
else:
data = self._buffer.popleft()
self._size -= len(data)
self._cursor += len(data)
chunk_splits = self._http_chunk_splits
# Prevent memory leak: drop useless chunk splits
while chunk_splits and chunk_splits[0] < self._cursor:
chunk_splits.pop(0)
if self._size < self._low_water and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
def _read_nowait(self, n: int) -> bytes:
""" Read not more than n bytes, or whole buffer if n == -1 """
chunks = []
while self._buffer:
chunk = self._read_nowait_chunk(n)
chunks.append(chunk)
if n != -1:
n -= len(chunk)
if n == 0:
break
return b''.join(chunks) if chunks else b''
class EmptyStreamReader(AsyncStreamReaderMixin):
    """Stream reader stand-in that is permanently at EOF.

    Every read method returns empty data, feeding data or setting an
    exception is ignored, and EOF callbacks run immediately.
    """

    def exception(self) -> Optional[BaseException]:
        """Return ``None``: this reader never stores an exception."""
        return None

    def set_exception(self, exc: BaseException) -> None:
        """Ignore ``exc``."""
        pass

    def on_eof(self, callback: Callable[[], None]) -> None:
        """Call ``callback`` right away, since the reader is already at EOF.

        Exceptions raised by the callback are logged, not propagated.
        """
        try:
            callback()
        except Exception:
            internal_logger.exception('Exception in eof callback')

    def feed_eof(self) -> None:
        """Do nothing."""
        pass

    def is_eof(self) -> bool:
        """Always ``True``."""
        return True

    def at_eof(self) -> bool:
        """Always ``True``."""
        return True

    async def wait_eof(self) -> None:
        """Return immediately."""
        return

    def feed_data(self, data: bytes, n: int=0) -> None:
        """Discard ``data``."""
        pass

    async def readline(self) -> bytes:
        """Return ``b''``."""
        return b''

    async def read(self, n: int=-1) -> bytes:
        """Return ``b''`` regardless of ``n``."""
        return b''

    async def readany(self) -> bytes:
        """Return ``b''``."""
        return b''

    async def readchunk(self) -> Tuple[bytes, bool]:
        """Return ``(b'', True)``."""
        return (b'', True)

    async def readexactly(self, n: int) -> bytes:
        """Always fail: no data is ever available.

        :raises asyncio.IncompleteReadError: with empty partial data and
            ``n`` as the expected size.
        """
        raise asyncio.IncompleteReadError(b'', n)

    def read_nowait(self, n: int=-1) -> bytes:
        """Return ``b''`` regardless of ``n``.

        ``n`` is accepted (and ignored) so the signature agrees with
        ``read_nowait(self, n=-1)`` on the regular stream reader in this
        file; callers that pass a size no longer get a ``TypeError``.
        """
        return b''
# Module-level EmptyStreamReader instance: a reader that is always at EOF
# and never yields data.
EMPTY_PAYLOAD = EmptyStreamReader()
class DataQueue(Generic[_T]):
    """DataQueue is a general-purpose blocking queue with one reader.

    Items are stored together with a caller-supplied size, and the sum
    of the sizes of the queued items is tracked in ``_size``.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._buffer = collections.deque()  # type: Deque[Tuple[_T, int]]
        self._size = 0
        self._eof = False
        self._exception = None  # type: Optional[BaseException]
        # Future the single reader is parked on while the queue is empty.
        self._waiter = None  # type: Optional[asyncio.Future[None]]

    def __len__(self) -> int:
        """Number of queued items."""
        return len(self._buffer)

    def is_eof(self) -> bool:
        """Whether EOF (or an exception) has been signalled."""
        return self._eof

    def at_eof(self) -> bool:
        """Whether EOF was signalled and no items are left."""
        return self._eof and not self._buffer

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, if any."""
        return self._exception

    def set_exception(self, exc: BaseException) -> None:
        """Store ``exc``, mark the queue as EOF and fail a parked reader."""
        self._eof = True
        self._exception = exc

        waiter, self._waiter = self._waiter, None
        if waiter is not None:
            set_exception(waiter, exc)

    def feed_data(self, data: _T, size: int=0) -> None:
        """Queue ``data`` with its ``size`` and wake a parked reader."""
        self._size += size
        self._buffer.append((data, size))

        waiter, self._waiter = self._waiter, None
        if waiter is not None:
            set_result(waiter, None)

    def feed_eof(self) -> None:
        """Mark the queue as finished and wake a parked reader."""
        self._eof = True

        waiter, self._waiter = self._waiter, None
        if waiter is not None:
            set_result(waiter, None)

    async def read(self) -> _T:
        """Return the next item, waiting while the queue is empty.

        :raises EofStream: when EOF was reached and nothing is queued.
        :raises BaseException: the stored exception instead of
            ``EofStream`` when one was set.
        """
        if not (self._buffer or self._eof):
            assert not self._waiter
            waiter = self._waiter = self._loop.create_future()
            try:
                await waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Drop the abandoned waiter so a later read can wait again.
                self._waiter = None
                raise

        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream

        data, size = self._buffer.popleft()
        self._size -= size
        return data

    def __aiter__(self) -> AsyncStreamIterator[_T]:
        """Iterate asynchronously by calling ``read()`` repeatedly."""
        return AsyncStreamIterator(self.read)
class FlowControlDataQueue(DataQueue[_T]):
    """FlowControlDataQueue resumes and pauses an underlying stream.

    It is a destination for parsed data."""

    def __init__(self, protocol: BaseProtocol, *,
                 limit: int=DEFAULT_LIMIT,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(loop=loop)

        self._protocol = protocol
        # The threshold used for pausing/resuming is twice ``limit``.
        self._limit = limit * 2

    def feed_data(self, data: _T, size: int=0) -> None:
        """Queue ``data``; pause the protocol once the limit is exceeded."""
        super().feed_data(data, size)

        protocol = self._protocol
        if self._size > self._limit and not protocol._reading_paused:
            protocol.pause_reading()

    async def read(self) -> _T:
        """Read one item; resume the protocol when below the limit again.

        The resume check runs even if the underlying read raises.
        """
        try:
            item = await super().read()
            return item
        finally:
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()

View File

@ -0,0 +1,63 @@
"""Helper methods to tune a TCP connection"""
import asyncio
import socket
from contextlib import suppress
from typing import Optional # noqa
__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork')
# Socket option used by tcp_cork(): TCP_CORK when the socket module has
# it, otherwise TCP_NOPUSH, otherwise None (corking unsupported).
CORK = getattr(
    socket, 'TCP_CORK',
    getattr(socket, 'TCP_NOPUSH', None))  # type: Optional[int]
if not hasattr(socket, 'SO_KEEPALIVE'):
    def tcp_keepalive(
            transport: asyncio.Transport) -> None:  # pragma: no cover
        """No-op: the socket module has no SO_KEEPALIVE here."""
        pass
else:
    def tcp_keepalive(transport: asyncio.Transport) -> None:
        """Enable SO_KEEPALIVE on the transport's socket, if it has one."""
        sock = transport.get_extra_info('socket')
        if sock is None:
            return
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None:
    """Set TCP_NODELAY to ``bool(value)`` on the transport's socket.

    Nothing happens when the transport has no socket or the socket is
    not AF_INET/AF_INET6.  ``OSError`` from ``setsockopt`` is suppressed.
    """
    sock = transport.get_extra_info('socket')
    if sock is None or sock.family not in (socket.AF_INET,
                                           socket.AF_INET6):
        return

    flag = bool(value)

    # socket may be closed already, on windows OSError get raised
    try:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, flag)
    except OSError:
        pass
def tcp_cork(transport: asyncio.Transport, value: bool) -> None:
    """Set the module's CORK option to ``bool(value)`` on the socket.

    Nothing happens when CORK is ``None``, the transport has no socket,
    or the socket is not AF_INET/AF_INET6.  ``OSError`` from
    ``setsockopt`` is suppressed.
    """
    sock = transport.get_extra_info('socket')

    unsupported = CORK is None or sock is None
    if unsupported or sock.family not in (socket.AF_INET, socket.AF_INET6):
        return

    flag = bool(value)

    try:
        sock.setsockopt(socket.IPPROTO_TCP, CORK, flag)
    except OSError:
        pass

View File

@ -0,0 +1,679 @@
"""Utilities shared by tests."""
import asyncio
import contextlib
import functools
import gc
import inspect
import os
import socket
import sys
import unittest
from abc import ABC, abstractmethod
from types import TracebackType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Optional,
Type,
Union,
)
from unittest import mock
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
import hyper_internal_service
from hyper_internal_service.client import (
ClientResponse,
_RequestContextManager,
_WSRequestContextManager,
)
from . import ClientSession, hdrs
from .abc import AbstractCookieJar
from .client_reqrep import ClientResponse # noqa
from .client_ws import ClientWebSocketResponse # noqa
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .signals import Signal
from .web import (
Application,
AppRunner,
BaseRunner,
Request,
Server,
ServerRunner,
SockSite,
UrlMappingMatchInfo,
)
from .web_protocol import _RequestHandler
if TYPE_CHECKING: # pragma: no cover
from ssl import SSLContext
else:
SSLContext = None
# True on POSIX platforms other than Cygwin.  get_port_socket() only sets
# SO_REUSEADDR on the sockets it creates when this flag is true.
REUSE_ADDRESS = os.name == 'posix' and sys.platform != 'cygwin'
def get_unused_port_socket(host: str) -> socket.socket:
    """Return a socket bound to ``host`` on port 0.

    Thin wrapper around ``get_port_socket``.
    """
    any_port = 0
    return get_port_socket(host, any_port)
def get_port_socket(host: str, port: int) -> socket.socket:
    """Create a TCP/IPv4 socket bound to ``(host, port)``.

    :param host: address to bind to.
    :param port: port to bind to; ``0`` is what
        ``get_unused_port_socket`` passes.
    :returns: the bound socket; the caller owns it and must close it.
    :raises OSError: if binding fails.  The socket is closed before the
        error propagates, so a failed bind no longer leaks a descriptor.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if REUSE_ADDRESS:
            # Windows has different semantics for SO_REUSEADDR,
            # so don't set it. Ref:
            # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((host, port))
    except BaseException:
        # Nobody else holds a reference yet, so release it here.
        s.close()
        raise
    return s
def unused_port() -> int:
    """Return a port that is unused on the current host.

    A temporary socket is bound to 127.0.0.1 with port 0 and closed
    again; the port number it was given is returned.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('127.0.0.1', 0))
        address = probe.getsockname()
        return address[1]
class BaseTestServer(ABC):
    """Base class for test servers started on a local socket.

    Subclasses supply the runner through ``_make_runner``.  Instances
    are meant to be used with ``async with``.
    """

    # Keeps test collectors from treating this class as a test case.
    __test__ = False

    def __init__(self,
                 *,
                 scheme: Union[str, object]=sentinel,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 host: str='127.0.0.1',
                 port: Optional[int]=None,
                 skip_url_asserts: bool=False,
                 **kwargs: Any) -> None:
        self.host = host
        self.port = port
        self.scheme = scheme
        self.skip_url_asserts = skip_url_asserts
        self.runner = None  # type: Optional[BaseRunner]
        self._loop = loop
        self._root = None  # type: Optional[URL]
        self._closed = False

    async def start_server(self,
                           loop: Optional[asyncio.AbstractEventLoop]=None,
                           **kwargs: Any) -> None:
        """Create the runner, bind a socket and start serving.

        Does nothing if a runner already exists.  The ``ssl`` keyword is
        taken out of ``kwargs`` and used as the site's SSL context; the
        remaining keywords go to ``_make_runner``.
        """
        if self.runner:
            return

        self._loop = loop
        self._ssl = kwargs.pop('ssl', None)

        self.runner = await self._make_runner(**kwargs)
        await self.runner.setup()

        self.port = self.port or 0
        _sock = get_port_socket(self.host, self.port)
        self.host, self.port = _sock.getsockname()[:2]

        site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
        await site.start()

        server = site._server
        assert server is not None
        sockets = server.sockets
        assert sockets is not None
        # Take the port from the socket that is actually being served.
        self.port = sockets[0].getsockname()[1]

        if self.scheme is sentinel:
            # No explicit scheme: derive it from whether SSL is in use.
            self.scheme = 'https' if self._ssl else 'http'

        self._root = URL('{}://{}:{}'.format(self.scheme,
                                             self.host,
                                             self.port))

    @abstractmethod  # pragma: no cover
    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Build the runner that ``start_server`` sets up."""
        pass

    def make_url(self, path: str) -> URL:
        """Return the URL of ``path`` on this server.

        Unless ``skip_url_asserts`` is set, ``path`` must not be an
        absolute URL and is joined to the server root; otherwise it is
        appended to the root as plain text.
        """
        assert self._root is not None
        url = URL(path)
        if self.skip_url_asserts:
            return URL(str(self._root) + path)
        assert not url.is_absolute()
        return self._root.join(url)

    @property
    def started(self) -> bool:
        """Whether a runner has been created."""
        return self.runner is not None

    @property
    def closed(self) -> bool:
        """Whether ``close()`` has completed."""
        return self._closed

    @property
    def handler(self) -> Server:
        """The runner's ``web.Server`` instance (backward compatibility)."""
        runner = self.runner
        assert runner is not None
        server = runner.server
        assert server is not None
        return server

    async def close(self) -> None:
        """Clean up the runner and mark the server as closed.

        Has an effect only on a started, not yet closed server, so
        running it several times is harmless.  It is also run on exit
        from ``async with``.
        """
        if not self.started or self.closed:
            return

        assert self.runner is not None
        await self.runner.cleanup()
        self._root = None
        self.port = None
        self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_value: Optional[BaseException],
                 traceback: Optional[TracebackType]) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> 'BaseTestServer':
        await self.start_server(loop=self._loop)
        return self

    async def __aexit__(self,
                        exc_type: Optional[Type[BaseException]],
                        exc_value: Optional[BaseException],
                        traceback: Optional[TracebackType]) -> None:
        await self.close()
class TestServer(BaseTestServer):
    """Test server that serves an ``Application`` through ``AppRunner``."""

    def __init__(self, app: Application, *,
                 scheme: Union[str, object]=sentinel,
                 host: str='127.0.0.1',
                 port: Optional[int]=None,
                 **kwargs: Any):
        # The application is stored before the base initialiser runs.
        self.app = app
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self, **kwargs: Any) -> BaseRunner:
        """Wrap the application in an ``AppRunner``."""
        runner = AppRunner(self.app, **kwargs)
        return runner
class RawTestServer(BaseTestServer):
    """Test server that serves a bare request handler via ``Server``."""

    def __init__(self, handler: _RequestHandler, *,
                 scheme: Union[str, object]=sentinel,
                 host: str='127.0.0.1',
                 port: Optional[int]=None,
                 **kwargs: Any) -> None:
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self,
                           debug: bool=True,
                           **kwargs: Any) -> ServerRunner:
        """Build a ``Server`` around the handler and wrap it in a runner.

        ``debug`` and the extra keywords are passed to both objects.
        """
        low_level_server = Server(
            self._handler, loop=self._loop, debug=debug, **kwargs)
        runner = ServerRunner(low_level_server, debug=debug, **kwargs)
        return runner
class TestClient:
    """
    A test client implementation.

    To write functional tests for hyper_internal_service based servers.
    Responses and websockets obtained through the client are remembered
    and closed by ``close()``.
    """

    # Keeps test collectors from treating this class as a test case.
    __test__ = False

    def __init__(self, server: BaseTestServer, *,
                 cookie_jar: Optional[AbstractCookieJar]=None,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 **kwargs: Any) -> None:
        """Create a client bound to ``server``.

        :param server: the test server requests are sent to.
        :param cookie_jar: jar for the session; defaults to a
            ``CookieJar(unsafe=True)``.
        :param loop: event loop handed to the session and the server.
        :param kwargs: extra keyword arguments for ``ClientSession``.
        :raises TypeError: if ``server`` is not a ``BaseTestServer``.
        """
        if not isinstance(server, BaseTestServer):
            raise TypeError("server must be TestServer "
                            "instance, found type: %r" % type(server))
        self._server = server
        self._loop = loop
        if cookie_jar is None:
            cookie_jar = hyper_internal_service.CookieJar(unsafe=True, loop=loop)
        self._session = ClientSession(loop=loop,
                                      cookie_jar=cookie_jar,
                                      **kwargs)
        self._closed = False
        self._responses = []  # type: List[ClientResponse]
        self._websockets = []  # type: List[ClientWebSocketResponse]

    async def start_server(self) -> None:
        """Start the underlying test server on the client's loop."""
        await self._server.start_server(loop=self._loop)

    @property
    def host(self) -> str:
        """Host of the underlying test server."""
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        """Port of the underlying test server."""
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        """The underlying test server."""
        return self._server

    @property
    def app(self) -> Optional[Application]:
        """The server's ``app`` attribute, or ``None`` if it has none."""
        return getattr(self._server, "app", None)

    @property
    def session(self) -> ClientSession:
        """An internal hyper_internal_service.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
        will require an absolute path to the resource.
        """
        return self._session

    def make_url(self, path: str) -> URL:
        """Return the URL of ``path`` on the test server."""
        return self._server.make_url(path)

    async def _request(self, method: str, path: str,
                       **kwargs: Any) -> ClientResponse:
        """Send a request through the session and remember the response."""
        resp = await self._session.request(
            method, self.make_url(path), **kwargs
        )
        # save it to close later
        self._responses.append(resp)
        return resp

    def request(self, method: str, path: str,
                **kwargs: Any) -> _RequestContextManager:
        """Routes a request to tested http server.

        The interface is identical to hyper_internal_service.ClientSession.request,
        except the loop kwarg is overridden by the instance used by the
        test server.
        """
        return _RequestContextManager(
            self._request(method, path, **kwargs)
        )

    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(
            self._request(hdrs.METH_GET, path, **kwargs)
        )

    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(
            self._request(hdrs.METH_POST, path, **kwargs)
        )

    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(
            self._request(hdrs.METH_OPTIONS, path, **kwargs)
        )

    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(
            self._request(hdrs.METH_HEAD, path, **kwargs)
        )

    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(
            self._request(hdrs.METH_PUT, path, **kwargs)
        )

    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(
            self._request(hdrs.METH_PATCH, path, **kwargs)
        )

    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(
            self._request(hdrs.METH_DELETE, path, **kwargs)
        )

    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to hyper_internal_service.ClientSession.ws_connect.
        """
        return _WSRequestContextManager(
            self._ws_connect(path, **kwargs)
        )

    async def _ws_connect(self, path: str,
                          **kwargs: Any) -> ClientWebSocketResponse:
        """Open a websocket through the session and remember it."""
        ws = await self._session.ws_connect(
            self.make_url(path), **kwargs)
        self._websockets.append(ws)
        return ws

    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

        This is an idempotent function: running close multiple times
        will not have any additional effects.

        close is also run on exit when used as a(n) (asynchronous)
        context manager.
        """
        if not self._closed:
            # Responses and websockets first, then the session that
            # produced them, then the server.
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True

    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc: Optional[BaseException],
                 tb: Optional[TracebackType]) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> 'TestClient':
        await self.start_server()
        return self

    async def __aexit__(self,
                        exc_type: Optional[Type[BaseException]],
                        exc: Optional[BaseException],
                        tb: Optional[TracebackType]) -> None:
        await self.close()
class AioHTTPTestCase(unittest.TestCase):
    """A base class to allow for unittest web applications using
    hyper_internal_service.

    Provides the following:

    * self.client (hyper_internal_service.test_utils.TestClient): an hyper_internal_service test client.
    * self.loop (asyncio.BaseEventLoop): the event loop in which the
        application and server are running.
    * self.app (hyper_internal_service.web.Application): the application returned by
        self.get_application()

    Note that the TestClient's methods are asynchronous: you have to
    execute function on the test client using asynchronous methods.
    """

    async def get_application(self) -> Application:
        """
        This method should be overridden
        to return the hyper_internal_service.web.Application
        object to test.
        """
        return self.get_app()

    def get_app(self) -> Application:
        """Obsolete method used to constructing web application.

        Use .get_application() coroutine instead
        """
        raise RuntimeError("Did you forget to define get_application()?")

    def setUp(self) -> None:
        """Create the loop, application, server and client, in that order.

        The client's server is started and ``setUpAsync()`` is run last.
        """
        self.loop = setup_test_loop()
        run = self.loop.run_until_complete

        self.app = run(self.get_application())
        self.server = run(self.get_server(self.app))
        self.client = run(self.get_client(self.server))

        run(self.client.start_server())
        run(self.setUpAsync())

    async def setUpAsync(self) -> None:
        """Hook for asynchronous set-up; does nothing by default."""
        pass

    def tearDown(self) -> None:
        """Run ``tearDownAsync()``, close the client, tear the loop down."""
        run = self.loop.run_until_complete
        run(self.tearDownAsync())
        run(self.client.close())
        teardown_test_loop(self.loop)

    async def tearDownAsync(self) -> None:
        """Hook for asynchronous tear-down; does nothing by default."""
        pass

    async def get_server(self, app: Application) -> TestServer:
        """Return a TestServer instance."""
        return TestServer(app, loop=self.loop)

    async def get_client(self, server: TestServer) -> TestClient:
        """Return a TestClient instance."""
        return TestClient(server, loop=self.loop)
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """A decorator dedicated to use with asynchronous methods of an
    AioHTTPTestCase.

    The returned wrapper calls ``func`` and runs the result to
    completion on ``self.loop``.  Extra positional and keyword arguments
    are forwarded to ``functools.wraps``.
    """
    decorate = functools.wraps(func, *args, **kwargs)

    def runner(self: Any, *inner_args: Any, **inner_kwargs: Any) -> Any:
        pending = func(self, *inner_args, **inner_kwargs)
        return self.loop.run_until_complete(pending)

    return decorate(runner)
# Type alias: a zero-argument callable that returns an event loop.  Used
# to annotate the ``loop_factory`` parameters below.
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
@contextlib.contextmanager
def loop_context(loop_factory: _LOOP_FACTORY=asyncio.new_event_loop,
                 fast: bool=False) -> Iterator[asyncio.AbstractEventLoop]:
    """A contextmanager that creates an event_loop, for test purposes.

    Handles the creation and cleanup of a test loop.

    :param loop_factory: callable that creates the loop.
    :param fast: passed on to ``teardown_test_loop``.
    :returns: yields the loop created by ``setup_test_loop``.
    """
    loop = setup_test_loop(loop_factory)
    try:
        yield loop
    finally:
        # Tear down even when the with-body raises; otherwise the loop
        # would be left open and installed as the current loop.
        teardown_test_loop(loop, fast=fast)
def setup_test_loop(
        loop_factory: _LOOP_FACTORY=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
    """Create and return an asyncio.BaseEventLoop
    instance.

    The loop is installed as the current event loop.  On non-Windows
    platforms, for loops whose class is not defined in a uvloop module,
    a ``SafeChildWatcher`` is attached as well, provided the running
    Python still has that class.

    The caller should also call teardown_test_loop,
    once they are done with the loop.

    :param loop_factory: callable that creates the loop.
    :returns: the new loop.
    """
    loop = loop_factory()
    try:
        module = loop.__class__.__module__
        skip_watcher = 'uvloop' in module
    except AttributeError:  # pragma: no cover
        # Just in case
        skip_watcher = True
    asyncio.set_event_loop(loop)

    # Child watchers are deprecated and absent from newer Python
    # versions; look the class up defensively so the helper keeps
    # working there instead of raising AttributeError.
    watcher_factory = getattr(asyncio, 'SafeChildWatcher', None)
    if (sys.platform != "win32" and not skip_watcher
            and watcher_factory is not None):
        policy = asyncio.get_event_loop_policy()
        watcher = watcher_factory()
        watcher.attach_loop(loop)
        with contextlib.suppress(NotImplementedError):
            policy.set_child_watcher(watcher)
    return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop,
                       fast: bool=False) -> None:
    """Teardown and cleanup an event_loop created
    by setup_test_loop.

    :param loop: the loop to shut down; an already closed loop is left
        untouched.
    :param fast: when true, the garbage collection pass is skipped.
    """
    if not loop.is_closed():
        # Let the loop run one last time, stopping right away, and then
        # close it.
        loop.call_soon(loop.stop)
        loop.run_forever()
        loop.close()

    if fast is False or not fast:
        gc.collect()

    asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
    """Build a ``MagicMock`` standing in for an application.

    Item access is backed by a real dict, ``_debug`` is ``False`` and
    ``on_response_prepare`` is a frozen ``Signal``.
    """
    def read_item(app: Any, key: str) -> Any:
        return app.__app_dict[key]

    def write_item(app: Any, key: str, value: Any) -> None:
        app.__app_dict[key] = value

    app_mock = mock.MagicMock()
    # Storage used by the item-access helpers above.
    app_mock.__app_dict = {}
    app_mock.__getitem__ = read_item
    app_mock.__setitem__ = write_item

    app_mock._debug = False
    app_mock.on_response_prepare = Signal(app_mock)
    app_mock.on_response_prepare.freeze()
    return app_mock
def _create_transport(sslcontext: Optional[SSLContext]=None) -> mock.Mock:
transport = mock.Mock()
def get_extra_info(key: str) -> Optional[SSLContext]:
if key == 'sslcontext':
return sslcontext
else:
return None
transport.get_extra_info.side_effect = get_extra_info
return transport
def make_mocked_request(method: str, path: str,
                        headers: Any=None, *,
                        match_info: Any=sentinel,
                        version: HttpVersion=HttpVersion(1, 1),
                        closing: bool=False,
                        app: Any=None,
                        writer: Any=sentinel,
                        protocol: Any=sentinel,
                        transport: Any=sentinel,
                        payload: Any=sentinel,
                        sslcontext: Optional[SSLContext]=None,
                        client_max_size: int=1024**2,
                        loop: Any=...) -> Any:
    """Creates mocked web.Request testing purposes.

    Useful in unit tests, when spinning full web server is overkill or
    specific conditions and errors are hard to trigger.

    Every collaborator left at its default (``sentinel``, ``None`` or
    ``...``) is replaced by a mock built here.
    """
    task = mock.Mock()

    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    # Versions below 1.1 are always treated as closing.
    if version < HttpVersion(1, 1):
        closing = True

    # Normalise headers into a proxy plus the raw byte pairs; with no
    # headers the raw tuple is empty.
    header_dict = CIMultiDict(headers) if headers else CIMultiDict()
    headers = CIMultiDictProxy(header_dict)
    raw_hdrs = tuple(
        (k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items())

    transfer_encoding = headers.get(hdrs.TRANSFER_ENCODING, '')
    chunked = 'chunked' in transfer_encoding.lower()

    message = RawRequestMessage(
        method, path, version, headers,
        raw_hdrs, closing, False, False, chunked, URL(path))

    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()
        protocol.transport = transport

    if writer is sentinel:
        writer = mock.Mock()
        for coro_name in ('write_headers', 'write', 'write_eof', 'drain'):
            setattr(writer, coro_name, make_mocked_coro(None))
        writer.transport = transport

    # Applied to caller-supplied protocols as well.
    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(message, payload,
                  protocol, writer, task, loop,
                  client_max_size=client_max_size)

    route_vars = {} if match_info is sentinel else match_info
    match_info = UrlMappingMatchInfo(route_vars, mock.Mock())
    match_info.add_app(app)
    req._match_info = match_info

    return req
def make_mocked_coro(return_value: Any=sentinel,
                     raise_exception: Any=sentinel) -> Any:
    """Creates a coroutine mock.

    The result is a ``Mock`` wrapping a coroutine function that raises
    ``raise_exception`` when one was given.  Otherwise it returns
    ``return_value``, unless that value is awaitable, in which case it
    is awaited and ``None`` is returned.
    """
    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if inspect.isawaitable(return_value):
            await return_value
            return None
        return return_value

    return mock.Mock(wraps=mock_coro)

View File

@ -0,0 +1,387 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Callable, Type, Union
import attr
from multidict import CIMultiDict # noqa
from yarl import URL
from .client_reqrep import ClientResponse
from .signals import Signal
if TYPE_CHECKING: # pragma: no cover
from .client import ClientSession # noqa
_SignalArgs = Union[
'TraceRequestStartParams',
'TraceRequestEndParams',
'TraceRequestExceptionParams',
'TraceConnectionQueuedStartParams',
'TraceConnectionQueuedEndParams',
'TraceConnectionCreateStartParams',
'TraceConnectionCreateEndParams',
'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams',
'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams',
'TraceDnsCacheMissParams',
'TraceRequestRedirectParams',
'TraceRequestChunkSentParams',
'TraceResponseChunkReceivedParams',
]
_Signal = Signal[Callable[[ClientSession, SimpleNamespace, _SignalArgs],
Awaitable[None]]]
else:
_Signal = Signal
__all__ = (
'TraceConfig', 'TraceRequestStartParams', 'TraceRequestEndParams',
'TraceRequestExceptionParams', 'TraceConnectionQueuedStartParams',
'TraceConnectionQueuedEndParams', 'TraceConnectionCreateStartParams',
'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams',
'TraceRequestRedirectParams',
'TraceRequestChunkSentParams', 'TraceResponseChunkReceivedParams',
)
class TraceConfig:
    """First-class used to trace requests launched via ClientSession
    objects.

    Holds one ``Signal`` per traced event, each exposed through a
    read-only property, plus the factory used to create trace contexts.
    """

    def __init__(
        self,
        trace_config_ctx_factory: Type[SimpleNamespace]=SimpleNamespace
    ) -> None:
        # Request life cycle.
        self._on_request_start = Signal(self)  # type: _Signal
        self._on_request_chunk_sent = Signal(self)  # type: _Signal
        self._on_response_chunk_received = Signal(self)  # type: _Signal
        self._on_request_end = Signal(self)  # type: _Signal
        self._on_request_exception = Signal(self)  # type: _Signal
        self._on_request_redirect = Signal(self)  # type: _Signal
        # Connection handling.
        self._on_connection_queued_start = Signal(self)  # type: _Signal
        self._on_connection_queued_end = Signal(self)  # type: _Signal
        self._on_connection_create_start = Signal(self)  # type: _Signal
        self._on_connection_create_end = Signal(self)  # type: _Signal
        self._on_connection_reuseconn = Signal(self)  # type: _Signal
        # DNS resolution.
        self._on_dns_resolvehost_start = Signal(self)  # type: _Signal
        self._on_dns_resolvehost_end = Signal(self)  # type: _Signal
        self._on_dns_cache_hit = Signal(self)  # type: _Signal
        self._on_dns_cache_miss = Signal(self)  # type: _Signal

        self._trace_config_ctx_factory = trace_config_ctx_factory  # type: Type[SimpleNamespace] # noqa

    def trace_config_ctx(
        self,
        trace_request_ctx: SimpleNamespace=None
    ) -> SimpleNamespace:  # noqa
        """ Return a new trace_config_ctx instance """
        factory = self._trace_config_ctx_factory
        return factory(trace_request_ctx=trace_request_ctx)

    def freeze(self) -> None:
        """Freeze every signal owned by this configuration."""
        owned_signals = (
            self._on_request_start,
            self._on_request_chunk_sent,
            self._on_response_chunk_received,
            self._on_request_end,
            self._on_request_exception,
            self._on_request_redirect,
            self._on_connection_queued_start,
            self._on_connection_queued_end,
            self._on_connection_create_start,
            self._on_connection_create_end,
            self._on_connection_reuseconn,
            self._on_dns_resolvehost_start,
            self._on_dns_resolvehost_end,
            self._on_dns_cache_hit,
            self._on_dns_cache_miss,
        )
        for signal in owned_signals:
            signal.freeze()

    # Read-only access to the signals.

    @property
    def on_request_start(self) -> _Signal:
        return self._on_request_start

    @property
    def on_request_chunk_sent(self) -> _Signal:
        return self._on_request_chunk_sent

    @property
    def on_response_chunk_received(self) -> _Signal:
        return self._on_response_chunk_received

    @property
    def on_request_end(self) -> _Signal:
        return self._on_request_end

    @property
    def on_request_exception(self) -> _Signal:
        return self._on_request_exception

    @property
    def on_request_redirect(self) -> _Signal:
        return self._on_request_redirect

    @property
    def on_connection_queued_start(self) -> _Signal:
        return self._on_connection_queued_start

    @property
    def on_connection_queued_end(self) -> _Signal:
        return self._on_connection_queued_end

    @property
    def on_connection_create_start(self) -> _Signal:
        return self._on_connection_create_start

    @property
    def on_connection_create_end(self) -> _Signal:
        return self._on_connection_create_end

    @property
    def on_connection_reuseconn(self) -> _Signal:
        return self._on_connection_reuseconn

    @property
    def on_dns_resolvehost_start(self) -> _Signal:
        return self._on_dns_resolvehost_start

    @property
    def on_dns_resolvehost_end(self) -> _Signal:
        return self._on_dns_resolvehost_end

    @property
    def on_dns_cache_hit(self) -> _Signal:
        return self._on_dns_cache_hit

    @property
    def on_dns_cache_miss(self) -> _Signal:
        return self._on_dns_cache_miss
@attr.s(frozen=True, slots=True)
class TraceRequestStartParams:
    """ Parameters sent by the `on_request_start` signal"""

    # Filled positionally, in this order: method, url, headers.
    method = attr.ib(type=str)
    url = attr.ib(type=URL)
    headers = attr.ib(type='CIMultiDict[str]')
@attr.s(frozen=True, slots=True)
class TraceRequestChunkSentParams:
    """ Parameters sent by the `on_request_chunk_sent` signal"""

    # The chunk handed to the trace call.
    chunk = attr.ib(type=bytes)
@attr.s(frozen=True, slots=True)
class TraceResponseChunkReceivedParams:
    """ Parameters sent by the `on_response_chunk_received` signal"""

    # Single field holding the chunk bytes.
    chunk = attr.ib(type=bytes)
@attr.s(frozen=True, slots=True)
class TraceRequestEndParams:
    """ Parameters sent by the `on_request_end` signal"""

    # Field order: method, url, headers, response.
    method = attr.ib(type=str)
    url = attr.ib(type=URL)
    headers = attr.ib(type='CIMultiDict[str]')
    response = attr.ib(type=ClientResponse)
@attr.s(frozen=True, slots=True)
class TraceRequestExceptionParams:
    """ Parameters sent by the `on_request_exception` signal"""

    # Field order: method, url, headers, exception.
    method = attr.ib(type=str)
    url = attr.ib(type=URL)
    headers = attr.ib(type='CIMultiDict[str]')
    exception = attr.ib(type=BaseException)
@attr.s(frozen=True, slots=True)
class TraceRequestRedirectParams:
    """ Parameters sent by the `on_request_redirect` signal"""

    # Field order: method, url, headers, response.
    method = attr.ib(type=str)
    url = attr.ib(type=URL)
    headers = attr.ib(type='CIMultiDict[str]')
    response = attr.ib(type=ClientResponse)
@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedStartParams:
    """ Parameters sent by the `on_connection_queued_start` signal"""
    # Carries no fields.
@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedEndParams:
    """ Parameters sent by the `on_connection_queued_end` signal"""
    # Carries no fields.
# Frozen, slotted attrs record with no fields;
# Trace.send_connection_create_start instantiates it without arguments.
@attr.s(frozen=True, slots=True)
class TraceConnectionCreateStartParams:
""" Parameters sent by the `on_connection_create_start` signal"""
# Frozen, slotted attrs record with no fields;
# Trace.send_connection_create_end instantiates it without arguments.
@attr.s(frozen=True, slots=True)
class TraceConnectionCreateEndParams:
""" Parameters sent by the `on_connection_create_end` signal"""
# Frozen, slotted attrs record with no fields;
# Trace.send_connection_reuseconn instantiates it without arguments.
@attr.s(frozen=True, slots=True)
class TraceConnectionReuseconnParams:
""" Parameters sent by the `on_connection_reuseconn` signal"""
# Frozen, slotted attrs record. Trace.send_dns_resolvehost_start builds it
# from the single `host` string argument.
@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostStartParams:
""" Parameters sent by the `on_dns_resolvehost_start` signal"""
host = attr.ib(type=str)
# Frozen, slotted attrs record. Trace.send_dns_resolvehost_end builds it
# from the single `host` string argument.
@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostEndParams:
""" Parameters sent by the `on_dns_resolvehost_end` signal"""
host = attr.ib(type=str)
# Frozen, slotted attrs record. Trace.send_dns_cache_hit builds it
# from the single `host` string argument.
@attr.s(frozen=True, slots=True)
class TraceDnsCacheHitParams:
""" Parameters sent by the `on_dns_cache_hit` signal"""
host = attr.ib(type=str)
# Frozen, slotted attrs record. Trace.send_dns_cache_miss builds it
# from the single `host` string argument.
@attr.s(frozen=True, slots=True)
class TraceDnsCacheMissParams:
""" Parameters sent by the `on_dns_cache_miss` signal"""
host = attr.ib(type=str)
class Trace:
    """Internal class that keeps together the main dependencies needed
    at the moment of sending a signal.

    Every ``send_*`` coroutine fires the matching signal of the wrapped
    trace config with three arguments: the session, the trace config
    context, and a freshly built ``Trace*Params`` instance.
    """

    def __init__(self,
                 session: 'ClientSession',
                 trace_config: TraceConfig,
                 trace_config_ctx: SimpleNamespace) -> None:
        self._trace_config = trace_config
        self._trace_config_ctx = trace_config_ctx
        self._session = session

    async def _emit(self, signal: '_Signal', params: object) -> None:
        """Send ``params`` through ``signal`` together with the session
        and the trace config context."""
        return await signal.send(
            self._session,
            self._trace_config_ctx,
            params
        )

    async def send_request_start(self,
                                 method: str,
                                 url: URL,
                                 headers: 'CIMultiDict[str]') -> None:
        """Fire ``on_request_start`` with ``TraceRequestStartParams``."""
        return await self._emit(
            self._trace_config.on_request_start,
            TraceRequestStartParams(method, url, headers)
        )

    async def send_request_chunk_sent(self, chunk: bytes) -> None:
        """Fire ``on_request_chunk_sent`` with the sent ``chunk``."""
        return await self._emit(
            self._trace_config.on_request_chunk_sent,
            TraceRequestChunkSentParams(chunk)
        )

    async def send_response_chunk_received(self, chunk: bytes) -> None:
        """Fire ``on_response_chunk_received`` with the received ``chunk``."""
        return await self._emit(
            self._trace_config.on_response_chunk_received,
            TraceResponseChunkReceivedParams(chunk)
        )

    async def send_request_end(self,
                               method: str,
                               url: URL,
                               headers: 'CIMultiDict[str]',
                               response: ClientResponse) -> None:
        """Fire ``on_request_end`` with ``TraceRequestEndParams``."""
        return await self._emit(
            self._trace_config.on_request_end,
            TraceRequestEndParams(method, url, headers, response)
        )

    async def send_request_exception(self,
                                     method: str,
                                     url: URL,
                                     headers: 'CIMultiDict[str]',
                                     exception: BaseException) -> None:
        """Fire ``on_request_exception`` with ``TraceRequestExceptionParams``."""
        return await self._emit(
            self._trace_config.on_request_exception,
            TraceRequestExceptionParams(method, url, headers, exception)
        )

    async def send_request_redirect(self,
                                    method: str,
                                    url: URL,
                                    headers: 'CIMultiDict[str]',
                                    response: ClientResponse) -> None:
        """Fire ``on_request_redirect`` with ``TraceRequestRedirectParams``."""
        # Goes through the public property like every other sender, instead
        # of reaching into the private ``_on_request_redirect`` attribute.
        return await self._emit(
            self._trace_config.on_request_redirect,
            TraceRequestRedirectParams(method, url, headers, response)
        )

    async def send_connection_queued_start(self) -> None:
        """Fire ``on_connection_queued_start`` (no parameters)."""
        return await self._emit(
            self._trace_config.on_connection_queued_start,
            TraceConnectionQueuedStartParams()
        )

    async def send_connection_queued_end(self) -> None:
        """Fire ``on_connection_queued_end`` (no parameters)."""
        return await self._emit(
            self._trace_config.on_connection_queued_end,
            TraceConnectionQueuedEndParams()
        )

    async def send_connection_create_start(self) -> None:
        """Fire ``on_connection_create_start`` (no parameters)."""
        return await self._emit(
            self._trace_config.on_connection_create_start,
            TraceConnectionCreateStartParams()
        )

    async def send_connection_create_end(self) -> None:
        """Fire ``on_connection_create_end`` (no parameters)."""
        return await self._emit(
            self._trace_config.on_connection_create_end,
            TraceConnectionCreateEndParams()
        )

    async def send_connection_reuseconn(self) -> None:
        """Fire ``on_connection_reuseconn`` (no parameters)."""
        return await self._emit(
            self._trace_config.on_connection_reuseconn,
            TraceConnectionReuseconnParams()
        )

    async def send_dns_resolvehost_start(self, host: str) -> None:
        """Fire ``on_dns_resolvehost_start`` for ``host``."""
        return await self._emit(
            self._trace_config.on_dns_resolvehost_start,
            TraceDnsResolveHostStartParams(host)
        )

    async def send_dns_resolvehost_end(self, host: str) -> None:
        """Fire ``on_dns_resolvehost_end`` for ``host``."""
        return await self._emit(
            self._trace_config.on_dns_resolvehost_end,
            TraceDnsResolveHostEndParams(host)
        )

    async def send_dns_cache_hit(self, host: str) -> None:
        """Fire ``on_dns_cache_hit`` for ``host``."""
        return await self._emit(
            self._trace_config.on_dns_cache_hit,
            TraceDnsCacheHitParams(host)
        )

    async def send_dns_cache_miss(self, host: str) -> None:
        """Fire ``on_dns_cache_miss`` for ``host``."""
        return await self._emit(
            self._trace_config.on_dns_cache_miss,
            TraceDnsCacheMissParams(host)
        )

View File

@ -0,0 +1,63 @@
import json
import os # noqa
import pathlib # noqa
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Mapping,
Tuple,
Union,
)
from multidict import (
CIMultiDict,
CIMultiDictProxy,
MultiDict,
MultiDictProxy,
istr,
)
from yarl import URL
# Default JSON codec: the standard-library ``json`` functions.
DEFAULT_JSON_ENCODER = json.dumps
DEFAULT_JSON_DECODER = json.loads

# The multidict classes are subscripted only for the type checker; at
# runtime the plain classes are used.
if TYPE_CHECKING:  # pragma: no cover
    _CIMultiDict = CIMultiDict[str]
    _CIMultiDictProxy = CIMultiDictProxy[str]
    _MultiDict = MultiDict[str]
    _MultiDictProxy = MultiDictProxy[str]
    from http.cookies import BaseCookie, Morsel  # noqa
else:
    _CIMultiDict = CIMultiDict
    _CIMultiDictProxy = CIMultiDictProxy
    _MultiDict = MultiDict
    _MultiDictProxy = MultiDictProxy

Byteish = Union[bytes, bytearray, memoryview]
JSONEncoder = Callable[[Any], str]
JSONDecoder = Callable[[str], Any]

# Headers may arrive as a plain mapping or as one of the multidict flavours.
_HeaderMapping = Mapping[Union[str, istr], str]
LooseHeaders = Union[_HeaderMapping, _CIMultiDict, _CIMultiDictProxy]
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
StrOrURL = Union[str, URL]

# A single cookie value; the cookie classes are forward references because
# they are imported only under TYPE_CHECKING.
_CookieValue = Union[str, 'BaseCookie[str]', 'Morsel[Any]']
LooseCookiesMappings = Mapping[str, _CookieValue]
LooseCookiesIterables = Iterable[Tuple[str, _CookieValue]]
LooseCookies = Union[
    LooseCookiesMappings,
    LooseCookiesIterables,
    'BaseCookie[str]',
]

if sys.version_info >= (3, 6):
    PathLike = Union[str, 'os.PathLike[str]']
else:
    PathLike = Union[str, pathlib.PurePath]

View File

@ -0,0 +1,515 @@
import asyncio
import logging
import socket
import sys
from argparse import ArgumentParser
from collections.abc import Iterable
from importlib import import_module
from typing import (
Any,
Awaitable,
Callable,
List,
Optional,
Set,
Type,
Union,
cast,
)
from .abc import AbstractAccessLogger
from .helpers import all_tasks
from .log import access_logger
from .web_app import Application as Application
from .web_app import CleanupError as CleanupError
from .web_exceptions import HTTPAccepted as HTTPAccepted
from .web_exceptions import HTTPBadGateway as HTTPBadGateway
from .web_exceptions import HTTPBadRequest as HTTPBadRequest
from .web_exceptions import HTTPClientError as HTTPClientError
from .web_exceptions import HTTPConflict as HTTPConflict
from .web_exceptions import HTTPCreated as HTTPCreated
from .web_exceptions import HTTPError as HTTPError
from .web_exceptions import HTTPException as HTTPException
from .web_exceptions import HTTPExpectationFailed as HTTPExpectationFailed
from .web_exceptions import HTTPFailedDependency as HTTPFailedDependency
from .web_exceptions import HTTPForbidden as HTTPForbidden
from .web_exceptions import HTTPFound as HTTPFound
from .web_exceptions import HTTPGatewayTimeout as HTTPGatewayTimeout
from .web_exceptions import HTTPGone as HTTPGone
from .web_exceptions import HTTPInsufficientStorage as HTTPInsufficientStorage
from .web_exceptions import HTTPInternalServerError as HTTPInternalServerError
from .web_exceptions import HTTPLengthRequired as HTTPLengthRequired
from .web_exceptions import HTTPMethodNotAllowed as HTTPMethodNotAllowed
from .web_exceptions import HTTPMisdirectedRequest as HTTPMisdirectedRequest
from .web_exceptions import HTTPMovedPermanently as HTTPMovedPermanently
from .web_exceptions import HTTPMultipleChoices as HTTPMultipleChoices
from .web_exceptions import (
HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired,
)
from .web_exceptions import HTTPNoContent as HTTPNoContent
from .web_exceptions import (
HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation,
)
from .web_exceptions import HTTPNotAcceptable as HTTPNotAcceptable
from .web_exceptions import HTTPNotExtended as HTTPNotExtended
from .web_exceptions import HTTPNotFound as HTTPNotFound
from .web_exceptions import HTTPNotImplemented as HTTPNotImplemented
from .web_exceptions import HTTPNotModified as HTTPNotModified
from .web_exceptions import HTTPOk as HTTPOk
from .web_exceptions import HTTPPartialContent as HTTPPartialContent
from .web_exceptions import HTTPPaymentRequired as HTTPPaymentRequired
from .web_exceptions import HTTPPermanentRedirect as HTTPPermanentRedirect
from .web_exceptions import HTTPPreconditionFailed as HTTPPreconditionFailed
from .web_exceptions import (
HTTPPreconditionRequired as HTTPPreconditionRequired,
)
from .web_exceptions import (
HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired,
)
from .web_exceptions import HTTPRedirection as HTTPRedirection
from .web_exceptions import (
HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge,
)
from .web_exceptions import (
HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge,
)
from .web_exceptions import (
HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable,
)
from .web_exceptions import HTTPRequestTimeout as HTTPRequestTimeout
from .web_exceptions import HTTPRequestURITooLong as HTTPRequestURITooLong
from .web_exceptions import HTTPResetContent as HTTPResetContent
from .web_exceptions import HTTPSeeOther as HTTPSeeOther
from .web_exceptions import HTTPServerError as HTTPServerError
from .web_exceptions import HTTPServiceUnavailable as HTTPServiceUnavailable
from .web_exceptions import HTTPSuccessful as HTTPSuccessful
from .web_exceptions import HTTPTemporaryRedirect as HTTPTemporaryRedirect
from .web_exceptions import HTTPTooManyRequests as HTTPTooManyRequests
from .web_exceptions import HTTPUnauthorized as HTTPUnauthorized
from .web_exceptions import (
HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons,
)
from .web_exceptions import HTTPUnprocessableEntity as HTTPUnprocessableEntity
from .web_exceptions import (
HTTPUnsupportedMediaType as HTTPUnsupportedMediaType,
)
from .web_exceptions import HTTPUpgradeRequired as HTTPUpgradeRequired
from .web_exceptions import HTTPUseProxy as HTTPUseProxy
from .web_exceptions import (
HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates,
)
from .web_exceptions import HTTPVersionNotSupported as HTTPVersionNotSupported
from .web_fileresponse import FileResponse as FileResponse
from .web_log import AccessLogger
from .web_middlewares import middleware as middleware
from .web_middlewares import (
normalize_path_middleware as normalize_path_middleware,
)
from .web_protocol import PayloadAccessError as PayloadAccessError
from .web_protocol import RequestHandler as RequestHandler
from .web_protocol import RequestPayloadError as RequestPayloadError
from .web_request import BaseRequest as BaseRequest
from .web_request import FileField as FileField
from .web_request import Request as Request
from .web_response import ContentCoding as ContentCoding
from .web_response import Response as Response
from .web_response import StreamResponse as StreamResponse
from .web_response import json_response as json_response
from .web_routedef import AbstractRouteDef as AbstractRouteDef
from .web_routedef import RouteDef as RouteDef
from .web_routedef import RouteTableDef as RouteTableDef
from .web_routedef import StaticDef as StaticDef
from .web_routedef import delete as delete
from .web_routedef import get as get
from .web_routedef import head as head
from .web_routedef import options as options
from .web_routedef import patch as patch
from .web_routedef import post as post
from .web_routedef import put as put
from .web_routedef import route as route
from .web_routedef import static as static
from .web_routedef import view as view
from .web_runner import AppRunner as AppRunner
from .web_runner import BaseRunner as BaseRunner
from .web_runner import BaseSite as BaseSite
from .web_runner import GracefulExit as GracefulExit
from .web_runner import NamedPipeSite as NamedPipeSite
from .web_runner import ServerRunner as ServerRunner
from .web_runner import SockSite as SockSite
from .web_runner import TCPSite as TCPSite
from .web_runner import UnixSite as UnixSite
from .web_server import Server as Server
from .web_urldispatcher import AbstractResource as AbstractResource
from .web_urldispatcher import AbstractRoute as AbstractRoute
from .web_urldispatcher import DynamicResource as DynamicResource
from .web_urldispatcher import PlainResource as PlainResource
from .web_urldispatcher import Resource as Resource
from .web_urldispatcher import ResourceRoute as ResourceRoute
from .web_urldispatcher import StaticResource as StaticResource
from .web_urldispatcher import UrlDispatcher as UrlDispatcher
from .web_urldispatcher import UrlMappingMatchInfo as UrlMappingMatchInfo
from .web_urldispatcher import View as View
from .web_ws import WebSocketReady as WebSocketReady
from .web_ws import WebSocketResponse as WebSocketResponse
from .web_ws import WSMsgType as WSMsgType
__all__ = (
# web_app
'Application',
'CleanupError',
# web_exceptions
'HTTPAccepted',
'HTTPBadGateway',
'HTTPBadRequest',
'HTTPClientError',
'HTTPConflict',
'HTTPCreated',
'HTTPError',
'HTTPException',
'HTTPExpectationFailed',
'HTTPFailedDependency',
'HTTPForbidden',
'HTTPFound',
'HTTPGatewayTimeout',
'HTTPGone',
'HTTPInsufficientStorage',
'HTTPInternalServerError',
'HTTPLengthRequired',
'HTTPMethodNotAllowed',
'HTTPMisdirectedRequest',
'HTTPMovedPermanently',
'HTTPMultipleChoices',
'HTTPNetworkAuthenticationRequired',
'HTTPNoContent',
'HTTPNonAuthoritativeInformation',
'HTTPNotAcceptable',
'HTTPNotExtended',
'HTTPNotFound',
'HTTPNotImplemented',
'HTTPNotModified',
'HTTPOk',
'HTTPPartialContent',
'HTTPPaymentRequired',
'HTTPPermanentRedirect',
'HTTPPreconditionFailed',
'HTTPPreconditionRequired',
'HTTPProxyAuthenticationRequired',
'HTTPRedirection',
'HTTPRequestEntityTooLarge',
'HTTPRequestHeaderFieldsTooLarge',
'HTTPRequestRangeNotSatisfiable',
'HTTPRequestTimeout',
'HTTPRequestURITooLong',
'HTTPResetContent',
'HTTPSeeOther',
'HTTPServerError',
'HTTPServiceUnavailable',
'HTTPSuccessful',
'HTTPTemporaryRedirect',
'HTTPTooManyRequests',
'HTTPUnauthorized',
'HTTPUnavailableForLegalReasons',
'HTTPUnprocessableEntity',
'HTTPUnsupportedMediaType',
'HTTPUpgradeRequired',
'HTTPUseProxy',
'HTTPVariantAlsoNegotiates',
'HTTPVersionNotSupported',
# web_fileresponse
'FileResponse',
# web_middlewares
'middleware',
'normalize_path_middleware',
# web_protocol
'PayloadAccessError',
'RequestHandler',
'RequestPayloadError',
# web_request
'BaseRequest',
'FileField',
'Request',
# web_response
'ContentCoding',
'Response',
'StreamResponse',
'json_response',
# web_routedef
'AbstractRouteDef',
'RouteDef',
'RouteTableDef',
'StaticDef',
'delete',
'get',
'head',
'options',
'patch',
'post',
'put',
'route',
'static',
'view',
# web_runner
'AppRunner',
'BaseRunner',
'BaseSite',
'GracefulExit',
'ServerRunner',
'SockSite',
'TCPSite',
'UnixSite',
'NamedPipeSite',
# web_server
'Server',
# web_urldispatcher
'AbstractResource',
'AbstractRoute',
'DynamicResource',
'PlainResource',
'Resource',
'ResourceRoute',
'StaticResource',
'UrlDispatcher',
'UrlMappingMatchInfo',
'View',
# web_ws
'WebSocketReady',
'WebSocketResponse',
'WSMsgType',
# web
'run_app',
)
try:
from ssl import SSLContext
except ImportError: # pragma: no cover
SSLContext = Any # type: ignore
async def _run_app(app: Union[Application, Awaitable[Application]], *,
                   host: Optional[str]=None,
                   port: Optional[int]=None,
                   path: Optional[str]=None,
                   sock: Optional[socket.socket]=None,
                   shutdown_timeout: float=60.0,
                   ssl_context: Optional[SSLContext]=None,
                   print: Callable[..., None]=print,
                   backlog: int=128,
                   access_log_class: Type[AbstractAccessLogger]=AccessLogger,
                   access_log_format: str=AccessLogger.LOG_FORMAT,
                   access_log: Optional[logging.Logger]=access_logger,
                   handle_signals: bool=True,
                   reuse_address: Optional[bool]=None,
                   reuse_port: Optional[bool]=None) -> None:
    """Internal worker that does the actual job of running the application.

    Sets up an ``AppRunner``, creates one site per requested host, path and
    socket, starts them, announces the site names through ``print`` (when
    truthy) and then sleeps until cancelled; the runner is always cleaned
    up on the way out.  ``host``, ``path`` and ``sock`` may each be a single
    value or an iterable of values.
    """
    if asyncio.iscoroutine(app):
        app = await app  # type: ignore
    app = cast(Application, app)

    runner = AppRunner(app, handle_signals=handle_signals,
                       access_log_class=access_log_class,
                       access_log_format=access_log_format,
                       access_log=access_log)
    await runner.setup()

    sites = []  # type: List[BaseSite]

    # Values of these types count as ONE host/path, not as an iterable.
    scalar_types = (str, bytes, bytearray, memoryview)
    # Keyword arguments shared by every site, and the TCP-only extras.
    site_options = dict(shutdown_timeout=shutdown_timeout,
                        ssl_context=ssl_context,
                        backlog=backlog)
    tcp_options = dict(site_options,
                       reuse_address=reuse_address,
                       reuse_port=reuse_port)

    try:
        if host is not None:
            hosts = [host] if isinstance(host, scalar_types) else host
            for h in hosts:
                sites.append(TCPSite(runner, h, port, **tcp_options))
        elif path is None and sock is None or port is not None:
            # No host given: serve TCP on the default host unless only a
            # path/socket was requested without an explicit port.
            sites.append(TCPSite(runner, port=port, **tcp_options))

        if path is not None:
            paths = [path] if isinstance(path, scalar_types) else path
            for p in paths:
                sites.append(UnixSite(runner, p, **site_options))

        if sock is not None:
            socks = sock if isinstance(sock, Iterable) else [sock]
            for s in socks:
                sites.append(SockSite(runner, s, **site_options))

        for site in sites:
            await site.start()

        if print:  # pragma: no branch
            names = sorted(str(s.name) for s in runner.sites)
            print("======== Hyper Internal Service Running on {} ========\n"
                  "(Press CTRL+C to quit)".format(', '.join(names)))

        while True:
            await asyncio.sleep(3600)  # sleep forever by 1 hour intervals
    finally:
        await runner.cleanup()
def _cancel_tasks(to_cancel: Set['asyncio.Task[Any]'],
                  loop: asyncio.AbstractEventLoop) -> None:
    """Cancel every task in ``to_cancel`` and run ``loop`` until all finish.

    Tasks that end with an exception instead of being cancelled are
    reported through ``loop.call_exception_handler``.  An empty set is a
    no-op.

    :param to_cancel: tasks to cancel; they must belong to ``loop``.
    :param loop: the (not running) event loop used to drive the tasks.
    """
    if not to_cancel:
        return

    for task in to_cancel:
        task.cancel()

    # No ``loop=`` keyword here: it was deprecated in Python 3.8 and removed
    # in 3.10 (passing it raises TypeError).  gather() takes the loop from
    # the tasks themselves, which is the same loop.
    loop.run_until_complete(
        asyncio.gather(*to_cancel, return_exceptions=True))

    for task in to_cancel:
        if task.cancelled():
            continue
        exc = task.exception()
        if exc is not None:
            loop.call_exception_handler({
                'message': 'unhandled exception during asyncio.run() shutdown',
                'exception': exc,
                'task': task,
            })
def run_app(app: Union[Application, Awaitable[Application]], *,
            host: Optional[str]=None,
            port: Optional[int]=None,
            path: Optional[str]=None,
            sock: Optional[socket.socket]=None,
            shutdown_timeout: float=60.0,
            ssl_context: Optional[SSLContext]=None,
            print: Callable[..., None]=print,
            backlog: int=128,
            access_log_class: Type[AbstractAccessLogger]=AccessLogger,
            access_log_format: str=AccessLogger.LOG_FORMAT,
            access_log: Optional[logging.Logger]=access_logger,
            handle_signals: bool=True,
            reuse_address: Optional[bool]=None,
            reuse_port: Optional[bool]=None) -> None:
    """Run ``app`` on the current event loop until it is interrupted.

    All keyword arguments are forwarded unchanged to ``_run_app``.  On exit
    (including ``GracefulExit`` / ``KeyboardInterrupt``, which are
    swallowed) the main task and every remaining task are cancelled, async
    generators are shut down and the loop is closed.
    """
    loop = asyncio.get_event_loop()

    # Configure if and only if in debugging mode and using the default logger
    if loop.get_debug() and access_log:
        if access_log.name == 'hyper_internal_service.access':
            if access_log.level == logging.NOTSET:
                access_log.setLevel(logging.DEBUG)
            if not access_log.hasHandlers():
                access_log.addHandler(logging.StreamHandler())

    options = dict(
        host=host,
        port=port,
        path=path,
        sock=sock,
        shutdown_timeout=shutdown_timeout,
        ssl_context=ssl_context,
        print=print,
        backlog=backlog,
        access_log_class=access_log_class,
        access_log_format=access_log_format,
        access_log=access_log,
        handle_signals=handle_signals,
        reuse_address=reuse_address,
        reuse_port=reuse_port,
    )

    try:
        main_task = loop.create_task(_run_app(app, **options))
        loop.run_until_complete(main_task)
    except (GracefulExit, KeyboardInterrupt):  # pragma: no cover
        pass
    finally:
        # Cancel the server task first, then whatever else is still pending.
        _cancel_tasks({main_task}, loop)
        _cancel_tasks(all_tasks(loop), loop)
        if sys.version_info >= (3, 6):  # don't use PY_36 to pass mypy
            loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
def main(argv: List[str]) -> None:
    """Command-line entry point.

    Parses ``argv`` (without the program name), imports the application
    factory given as ``module:function``, calls it with the unrecognised
    arguments and serves the returned application with ``run_app``.  Any
    usage problem is reported through ``ArgumentParser.error``, which exits.
    """
    parser = ArgumentParser(
        description="hyper_internal_service.web Application server",
        prog="hyper_internal_service.web"
    )
    parser.add_argument(
        "entry_func",
        help=("Callable returning the `hyper_internal_service.web.Application` instance to "
              "run. Should be specified in the 'module:function' syntax."),
        metavar="entry-func"
    )
    parser.add_argument(
        "-H", "--hostname",
        help="TCP/IP hostname to serve on (default: %(default)r)",
        default="localhost"
    )
    parser.add_argument(
        "-P", "--port",
        help="TCP/IP port to serve on (default: %(default)r)",
        type=int,
        default="8080"
    )
    parser.add_argument(
        "-U", "--path",
        help="Unix file system path to serve on. Specifying a path will cause "
             "hostname and port arguments to be ignored.",
    )
    options, remaining = parser.parse_known_args(argv)

    # Import logic
    module_name, _, func_name = options.entry_func.partition(":")
    if not (func_name and module_name):
        parser.error(
            "'entry-func' not in 'module:function' syntax"
        )
    if module_name.startswith("."):
        parser.error("relative module names not supported")
    try:
        module = import_module(module_name)
    except ImportError as ex:
        parser.error("unable to import %s: %s" % (module_name, ex))
    try:
        func = getattr(module, func_name)
    except AttributeError:
        parser.error("module %r has no attribute %r" % (module_name, func_name))

    # Compatibility logic
    if options.path is not None and not hasattr(socket, 'AF_UNIX'):
        parser.error("file system paths not supported by your operating"
                     " environment")

    logging.basicConfig(level=logging.DEBUG)

    app = func(remaining)
    run_app(app, host=options.hostname, port=options.port, path=options.path)
    parser.exit(message="Stopped\n")
# Script entry point: pass the command-line arguments, minus the program
# name, to main().
if __name__ == "__main__": # pragma: no branch
main(sys.argv[1:]) # pragma: no cover

View File

@ -0,0 +1,518 @@
import asyncio
import logging
import warnings
from functools import partial, update_wrapper
from typing import ( # noqa
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from . import hdrs
from .abc import (
AbstractAccessLogger,
AbstractMatchInfo,
AbstractRouter,
AbstractStreamWriter,
)
from .frozenlist import FrozenList
from .helpers import DEBUG
from .http_parser import RawRequestMessage
from .log import web_logger
from .signals import Signal
from .streams import StreamReader
from .web_log import AccessLogger
from .web_middlewares import _fix_request_current_app
from .web_protocol import RequestHandler
from .web_request import Request
from .web_response import StreamResponse
from .web_routedef import AbstractRouteDef
from .web_server import Server
from .web_urldispatcher import (
AbstractResource,
AbstractRoute,
Domain,
MaskDomain,
MatchedSubAppResource,
PrefixedSubAppResource,
UrlDispatcher,
)
__all__ = ('Application', 'CleanupError')
# Type aliases used by the annotations in this module. The precise,
# subscripted forms are built only for the type checker; at runtime the
# else-branch binds the same names to plain, unsubscripted types.
if TYPE_CHECKING: # pragma: no cover
# Signal whose receivers take the application.
_AppSignal = Signal[Callable[['Application'], Awaitable[None]]]
# Signal whose receivers take (request, response).
_RespPrepareSignal = Signal[Callable[[Request, StreamResponse],
Awaitable[None]]]
_Handler = Callable[[Request], Awaitable[StreamResponse]]
# A middleware is either (request, handler) -> response or the old-style
# (app, handler) -> handler factory.
_Middleware = Union[Callable[[Request, _Handler],
Awaitable[StreamResponse]],
Callable[['Application', _Handler], # old-style
Awaitable[_Handler]]]
_Middlewares = FrozenList[_Middleware]
# Pairs of (middleware, bool); Application._prepare_middleware yields True
# for middlewares whose `__middleware_version__` is 1.
_MiddlewaresHandlers = Optional[Sequence[Tuple[_Middleware, bool]]]
_Subapps = List['Application']
else:
# No type checker mode, skip types
_AppSignal = Signal
_RespPrepareSignal = Signal
_Handler = Callable
_Middleware = Callable
_Middlewares = FrozenList
_MiddlewaresHandlers = Optional[Sequence]
_Subapps = List
class Application(MutableMapping[str, Any]):
ATTRS = frozenset([
'logger', '_debug', '_router', '_loop', '_handler_args',
'_middlewares', '_middlewares_handlers', '_run_middlewares',
'_state', '_frozen', '_pre_frozen', '_subapps',
'_on_response_prepare', '_on_startup', '_on_shutdown',
'_on_cleanup', '_client_max_size', '_cleanup_ctx'])
def __init__(self, *,
             logger: logging.Logger=web_logger,
             router: Optional[UrlDispatcher]=None,
             middlewares: Iterable[_Middleware]=(),
             handler_args: Optional[Mapping[str, Any]]=None,
             client_max_size: int=1024**2,
             loop: Optional[asyncio.AbstractEventLoop]=None,
             debug: Any=...  # mypy doesn't support ellipsis
             ) -> None:
    """Create the application.

    :param logger: stored as ``self.logger``.
    :param router: deprecated; a fresh ``UrlDispatcher`` is created when
        omitted, and passing one emits a ``DeprecationWarning``.
    :param middlewares: copied into a ``FrozenList``.
    :param handler_args: optional extra keyword arguments that
        ``_make_handler`` merges into the handler kwargs.
    :param client_max_size: forwarded to every request built by
        ``_make_request``.
    :param loop: deprecated; passing one emits a ``DeprecationWarning``.
    :param debug: deprecated; the ``...`` default means "not given" and
        is resolved from the loop in ``_set_loop``.
    """
    if router is None:
        router = UrlDispatcher()
    else:
        warnings.warn("router argument is deprecated", DeprecationWarning,
                      stacklevel=2)
    assert isinstance(router, AbstractRouter), router

    if loop is not None:
        warnings.warn("loop argument is deprecated", DeprecationWarning,
                      stacklevel=2)

    if debug is not ...:
        warnings.warn("debug argument is deprecated",
                      DeprecationWarning,
                      stacklevel=2)
    self._debug = debug
    self._router = router  # type: UrlDispatcher
    self._loop = loop
    self._handler_args = handler_args
    self.logger = logger

    self._middlewares = FrozenList(middlewares)  # type: _Middlewares

    # initialized on freezing
    self._middlewares_handlers = None  # type: _MiddlewaresHandlers
    # initialized on freezing
    self._run_middlewares = None  # type: Optional[bool]

    self._state = {}  # type: Dict[str, Any]
    self._frozen = False
    self._pre_frozen = False
    self._subapps = []  # type: _Subapps

    self._on_response_prepare = Signal(self)  # type: _RespPrepareSignal
    self._on_startup = Signal(self)  # type: _AppSignal
    self._on_shutdown = Signal(self)  # type: _AppSignal
    self._on_cleanup = Signal(self)  # type: _AppSignal
    # The cleanup context hooks into startup and cleanup through its own
    # receivers.
    self._cleanup_ctx = CleanupContext()
    self._on_startup.append(self._cleanup_ctx._on_startup)
    self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
    self._client_max_size = client_max_size
def __init_subclass__(cls: Type['Application']) -> None:
    """Emit a DeprecationWarning naming the class that subclasses
    ``Application``."""
    message = ("Inheritance class {} from web.Application "
               "is discouraged".format(cls.__name__))
    warnings.warn(message, DeprecationWarning, stacklevel=2)
# Defined only when the DEBUG flag imported from .helpers is truthy.
if DEBUG: # pragma: no cover
def __setattr__(self, name: str, val: Any) -> None:
# Names outside the ATTRS whitelist trigger a DeprecationWarning.
if name not in self.ATTRS:
warnings.warn("Setting custom web.Application.{} attribute "
"is discouraged".format(name),
DeprecationWarning,
stacklevel=2)
# NOTE(review): indentation is lost in this view; presumably the attribute
# is set unconditionally after the warning check - confirm.
super().__setattr__(name, val)
# MutableMapping API
def __eq__(self, other: object) -> bool:
# Identity comparison: an application equals only itself.
return self is other
def __getitem__(self, key: str) -> Any:
# Mapping read access is served from the `_state` dict.
return self._state[key]
def _check_frozen(self) -> None:
    """Warn with a DeprecationWarning when the application is frozen.

    ``stacklevel=3`` attributes the warning to the caller of the mutating
    method that invoked this check.
    """
    if not self._frozen:
        return
    warnings.warn("Changing state of started or joined "
                  "application is deprecated",
                  DeprecationWarning,
                  stacklevel=3)
def __setitem__(self, key: str, value: Any) -> None:
# Runs the frozen-state check (which only warns) before storing into `_state`.
self._check_frozen()
self._state[key] = value
def __delitem__(self, key: str) -> None:
# Runs the frozen-state check (which only warns) before deleting from `_state`.
self._check_frozen()
del self._state[key]
def __len__(self) -> int:
# Number of keys currently held in `_state`.
return len(self._state)
def __iter__(self) -> Iterator[str]:
# Iterates over the keys of `_state`.
return iter(self._state)
########
# Deprecated accessor for the stored event loop: every read emits a
# DeprecationWarning.
@property
def loop(self) -> asyncio.AbstractEventLoop:
# Technically the loop can be None
# but we mask it by explicit type cast
# to provide more convenient type annotation
warnings.warn("loop property is deprecated",
DeprecationWarning,
stacklevel=2)
return cast(asyncio.AbstractEventLoop, self._loop)
def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
    """Bind the application, and recursively its sub-applications, to a loop.

    ``None`` means "use ``asyncio.get_event_loop()``".  Binding to a loop
    different from the one already stored raises ``RuntimeError``.  When
    the debug flag was left at its ``...`` default it is taken from the
    loop.
    """
    target = asyncio.get_event_loop() if loop is None else loop
    current = self._loop
    if current is not None and current is not target:
        raise RuntimeError(
            "web.Application instance initialized with different loop")

    self._loop = target

    # set loop debug
    if self._debug is ...:
        self._debug = target.get_debug()

    # set loop to sub applications
    for child in self._subapps:
        child._set_loop(target)
# Read-only flag; pre_freeze() sets `_pre_frozen` to True.
@property
def pre_frozen(self) -> bool:
return self._pre_frozen
def pre_freeze(self) -> None:
    """Freeze middlewares, router, signals and cleanup context (idempotent).

    Also materialises the middleware handler tuple and decides whether the
    middleware machinery has to run at all, taking sub-applications into
    account (each of them is pre-frozen as well).
    """
    if self._pre_frozen:
        return
    self._pre_frozen = True

    freezables = (
        self._middlewares,
        self._router,
        self._on_response_prepare,
        self._cleanup_ctx,
        self._on_startup,
        self._on_shutdown,
        self._on_cleanup,
    )
    for item in freezables:
        item.freeze()

    self._middlewares_handlers = tuple(self._prepare_middleware())

    # If current app and any subapp do not have middlewares avoid run all
    # of the code footprint that it implies, which have a middleware
    # hardcoded per app that sets up the current_app attribute. If no
    # middlewares are configured the handler will receive the proper
    # current_app without needing all of this code.
    self._run_middlewares = bool(self.middlewares)

    for child in self._subapps:
        child.pre_freeze()
        self._run_middlewares = (self._run_middlewares or
                                 child._run_middlewares)
# Read-only flag; freeze() sets `_frozen` to True.
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
    """Pre-freeze and freeze this application, then freeze every
    sub-application; does nothing when already frozen."""
    if not self._frozen:
        self.pre_freeze()
        self._frozen = True
        for child in self._subapps:
            child.freeze()
# Deprecated accessor for `_debug`: every read emits a DeprecationWarning.
@property
def debug(self) -> bool:
warnings.warn("debug property is deprecated",
DeprecationWarning,
stacklevel=2)
return self._debug
def _reg_subapp_signals(self, subapp: 'Application') -> None:
    """Chain the sub-application's lifecycle signals to this application.

    For each of ``on_startup``, ``on_shutdown`` and ``on_cleanup`` a
    receiver is appended to this application's signal; when called it
    sends the sub-application's signal of the same name with ``subapp``
    as the argument.
    """

    def reg_handler(signame: str) -> None:
        # Resolve the child signal now so the receiver closes over it.
        child_signal = getattr(subapp, signame)

        async def handler(app: 'Application') -> None:
            await child_signal.send(subapp)

        getattr(self, signame).append(handler)

    for signame in ('on_startup', 'on_shutdown', 'on_cleanup'):
        reg_handler(signame)
def add_subapp(self, prefix: str,
               subapp: 'Application') -> AbstractResource:
    """Mount ``subapp`` under the URL ``prefix``.

    Trailing slashes are stripped from the prefix.  Raises ``TypeError``
    for a non-string prefix and ``ValueError`` when the stripped prefix is
    empty.  Returns the resource produced by ``_add_subapp``.
    """
    if not isinstance(prefix, str):
        raise TypeError("Prefix must be str")
    stripped = prefix.rstrip('/')
    if not stripped:
        raise ValueError("Prefix cannot be empty")
    return self._add_subapp(
        partial(PrefixedSubAppResource, stripped, subapp),
        subapp)
def _add_subapp(self,
                resource_factory: Callable[[], AbstractResource],
                subapp: 'Application') -> AbstractResource:
    """Register ``subapp`` behind the resource built by ``resource_factory``.

    Raises ``RuntimeError`` when either application is already frozen.
    The sub-application's signals are chained to this application, it is
    pre-frozen, and it inherits this application's loop when one is set.
    """
    if self.frozen:
        raise RuntimeError(
            "Cannot add sub application to frozen application")
    if subapp.frozen:
        raise RuntimeError("Cannot add frozen application")

    new_resource = resource_factory()
    self.router.register_resource(new_resource)
    self._reg_subapp_signals(subapp)
    self._subapps.append(subapp)
    subapp.pre_freeze()

    parent_loop = self._loop
    if parent_loop is not None:
        subapp._set_loop(parent_loop)
    return new_resource
def add_domain(self, domain: str,
               subapp: 'Application') -> AbstractResource:
    """Mount ``subapp`` for requests matching ``domain``.

    A domain containing ``*`` is wrapped in ``MaskDomain``, any other in
    ``Domain``.  Raises ``TypeError`` for a non-string domain.  Returns
    the resource produced by ``_add_subapp``.
    """
    if not isinstance(domain, str):
        raise TypeError("Domain must be str")
    if '*' in domain:
        rule = MaskDomain(domain)  # type: Domain
    else:
        rule = Domain(domain)
    return self._add_subapp(
        partial(MatchedSubAppResource, rule, subapp),
        subapp)
# Convenience wrapper: delegates to the router's add_routes() and returns
# its result.
def add_routes(self,
routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]:
return self.router.add_routes(routes)
# Read-only property exposing the `_on_response_prepare` signal created in __init__.
@property
def on_response_prepare(self) -> _RespPrepareSignal:
return self._on_response_prepare
# Read-only property exposing the `_on_startup` signal; startup() sends it.
@property
def on_startup(self) -> _AppSignal:
return self._on_startup
# Read-only property exposing the `_on_shutdown` signal; shutdown() sends it.
@property
def on_shutdown(self) -> _AppSignal:
return self._on_shutdown
# Read-only property exposing the `_on_cleanup` signal; cleanup() sends it.
@property
def on_cleanup(self) -> _AppSignal:
return self._on_cleanup
# Read-only property exposing the CleanupContext instance created in __init__.
@property
def cleanup_ctx(self) -> 'CleanupContext':
return self._cleanup_ctx
# Read-only property exposing the router stored in `_router`.
@property
def router(self) -> UrlDispatcher:
return self._router
# Read-only property exposing the FrozenList of middlewares built in __init__.
@property
def middlewares(self) -> _Middlewares:
return self._middlewares
def _make_handler(self, *,
                  loop: Optional[asyncio.AbstractEventLoop]=None,
                  access_log_class: Type[
                      AbstractAccessLogger]=AccessLogger,
                  **kwargs: Any) -> Server:
    """Bind the loop, freeze the application and build its ``Server``.

    ``kwargs`` are forwarded to ``Server`` after ``debug`` and
    ``access_log_class`` are filled in; entries of the ``handler_args``
    given to the constructor are applied last and therefore win.
    Raises ``TypeError`` when ``access_log_class`` is not a subclass of
    ``AbstractAccessLogger``.
    """
    if not issubclass(access_log_class, AbstractAccessLogger):
        raise TypeError(
            'access_log_class must be subclass of '
            'hyper_internal_service.abc.AbstractAccessLogger, got {}'.format(
                access_log_class))

    self._set_loop(loop)
    self.freeze()

    kwargs.update(debug=self._debug, access_log_class=access_log_class)
    extra = self._handler_args
    if extra:
        kwargs.update(extra.items())

    return Server(self._handle,  # type: ignore
                  request_factory=self._make_request,
                  loop=self._loop, **kwargs)
# Deprecated public wrapper: emits a DeprecationWarning and then delegates
# to _make_handler() with the same arguments.
def make_handler(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None,
access_log_class: Type[
AbstractAccessLogger]=AccessLogger,
**kwargs: Any) -> Server:
warnings.warn("Application.make_handler(...) is deprecated, "
"use AppRunner API instead",
DeprecationWarning,
stacklevel=2)
return self._make_handler(loop=loop,
access_log_class=access_log_class,
**kwargs)
async def startup(self) -> None:
    """Causes on_startup signal

    Should be called in the event loop along with the request handler.
    """
    startup_signal = self.on_startup
    await startup_signal.send(self)
async def shutdown(self) -> None:
    """Causes on_shutdown signal

    Should be called before cleanup()
    """
    shutdown_signal = self.on_shutdown
    await shutdown_signal.send(self)
async def cleanup(self) -> None:
    """Causes on_cleanup signal

    Should be called after shutdown()
    """
    cleanup_signal = self.on_cleanup
    await cleanup_signal.send(self)
def _make_request(self, message: RawRequestMessage,
                  payload: StreamReader,
                  protocol: RequestHandler,
                  writer: AbstractStreamWriter,
                  task: 'asyncio.Task[None]',
                  _cls: Type[Request]=Request) -> Request:
    """Request factory passed to ``Server`` by ``_make_handler``.

    Instantiates ``_cls`` with the connection objects, this application's
    loop and its ``_client_max_size`` limit.
    """
    request = _cls(message, payload, protocol, writer, task,
                   self._loop,
                   client_max_size=self._client_max_size)
    return request
def _prepare_middleware(self) -> Iterator[Tuple[_Middleware, bool]]:
    """Yield ``(middleware, is_new_style)`` pairs in reverse registration order.

    A middleware is new-style when its ``__middleware_version__`` attribute
    equals ``1``; anything else triggers a ``DeprecationWarning``.  The
    final pair is always the ``_fix_request_current_app`` middleware for
    this application.
    """
    for mw in reversed(self._middlewares):
        is_new_style = getattr(mw, '__middleware_version__', None) == 1
        if not is_new_style:
            warnings.warn('old-style middleware "{!r}" deprecated, '
                          'see #2252'.format(mw),
                          DeprecationWarning, stacklevel=2)
        yield mw, is_new_style

    yield _fix_request_current_app(self), True
async def _handle(self, request: Request) -> StreamResponse:
    """Resolve ``request`` through the router and run the matched handler.

    Steps: resolve the route, register this app on the match info and
    freeze it, run the expect-handler when an ``Expect`` header is
    present, and otherwise wrap the handler with every application's
    middlewares before awaiting it.
    """
    debug = asyncio.get_event_loop().get_debug()
    match_info = await self._router.resolve(request)
    if debug and not isinstance(match_info, AbstractMatchInfo):  # pragma: no cover  # noqa
        raise TypeError("match_info should be AbstractMatchInfo "
                        "instance, not {!r}".format(match_info))
    match_info.add_app(self)
    match_info.freeze()

    request._match_info = match_info  # type: ignore

    response = None
    if request.headers.get(hdrs.EXPECT):
        response = await match_info.expect_handler(request)
        await request.writer.drain()
    if response is not None:
        # The expect-handler already produced the response.
        return response

    handler = match_info.handler
    if self._run_middlewares:
        for app in match_info.apps[::-1]:
            for mw, new_style in app._middlewares_handlers:  # type: ignore  # noqa
                if new_style:
                    wrapped = partial(mw, handler=handler)
                    handler = update_wrapper(wrapped, handler)
                else:
                    # Old-style middleware is a factory awaited per request.
                    handler = await mw(app, handler)  # type: ignore

    return await handler(request)
def __call__(self) -> 'Application':
"""gunicorn compatibility"""
return self
def __repr__(self) -> str:
return "<Application 0x{:x}>".format(id(self))
def __bool__(self) -> bool:
return True
class CleanupError(RuntimeError):
    """RuntimeError whose second positional argument is a list of exceptions."""

    @property
    def exceptions(self) -> List[BaseException]:
        """Return the second constructor argument (``args[1]``)."""
        _message, collected = self.args[0], self.args[1]
        return collected
# The subscripted FrozenList is selected only under TYPE_CHECKING;
# otherwise the plain class is used as the base.
if not TYPE_CHECKING:
    _CleanupContextBase = FrozenList
else:  # pragma: no cover
    _CleanupContextBase = FrozenList[Callable[[Application],
                                              AsyncIterator[None]]]
class CleanupContext(_CleanupContextBase):
    """List of async-generator callbacks run in two phases.

    ``_on_startup`` advances every callback to its first ``yield``;
    ``_on_cleanup`` resumes them in reverse order and expects each to
    finish.
    """

    def __init__(self) -> None:
        super().__init__()
        # Iterators that have been started and still await their cleanup step.
        self._exits = []  # type: List[AsyncIterator[None]]

    async def _on_startup(self, app: Application) -> None:
        """Start every registered callback and remember its iterator."""
        for callback in self:
            iterator = callback(app).__aiter__()
            await iterator.__anext__()
            self._exits.append(iterator)

    async def _on_cleanup(self, app: Application) -> None:
        """Resume started iterators in reverse order and collect failures.

        Raises the single collected error as-is, or ``CleanupError`` when
        more than one was collected.
        """
        failures = []
        for iterator in reversed(self._exits):
            try:
                await iterator.__anext__()
            except StopAsyncIteration:
                # Expected: the generator finished after its single yield.
                continue
            except Exception as exc:
                failures.append(exc)
                continue
            # No exception means the generator yielded a second time.
            failures.append(RuntimeError("{!r} has more than one 'yield'"
                                         .format(iterator)))

        if not failures:
            return
        if len(failures) == 1:
            raise failures[0]
        raise CleanupError("Multiple errors on cleanup stage", failures)

View File

@ -0,0 +1,413 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set # noqa
from yarl import URL
from .typedefs import LooseHeaders, StrOrURL
from .web_response import Response
# Names exported by ``from <this module> import *``; one entry per
# exception class defined below.
__all__ = (
'HTTPException',
'HTTPError',
'HTTPRedirection',
'HTTPSuccessful',
'HTTPOk',
'HTTPCreated',
'HTTPAccepted',
'HTTPNonAuthoritativeInformation',
'HTTPNoContent',
'HTTPResetContent',
'HTTPPartialContent',
'HTTPMultipleChoices',
'HTTPMovedPermanently',
'HTTPFound',
'HTTPSeeOther',
'HTTPNotModified',
'HTTPUseProxy',
'HTTPTemporaryRedirect',
'HTTPPermanentRedirect',
'HTTPClientError',
'HTTPBadRequest',
'HTTPUnauthorized',
'HTTPPaymentRequired',
'HTTPForbidden',
'HTTPNotFound',
'HTTPMethodNotAllowed',
'HTTPNotAcceptable',
'HTTPProxyAuthenticationRequired',
'HTTPRequestTimeout',
'HTTPConflict',
'HTTPGone',
'HTTPLengthRequired',
'HTTPPreconditionFailed',
'HTTPRequestEntityTooLarge',
'HTTPRequestURITooLong',
'HTTPUnsupportedMediaType',
'HTTPRequestRangeNotSatisfiable',
'HTTPExpectationFailed',
'HTTPMisdirectedRequest',
'HTTPUnprocessableEntity',
'HTTPFailedDependency',
'HTTPUpgradeRequired',
'HTTPPreconditionRequired',
'HTTPTooManyRequests',
'HTTPRequestHeaderFieldsTooLarge',
'HTTPUnavailableForLegalReasons',
'HTTPServerError',
'HTTPInternalServerError',
'HTTPNotImplemented',
'HTTPBadGateway',
'HTTPServiceUnavailable',
'HTTPGatewayTimeout',
'HTTPVersionNotSupported',
'HTTPVariantAlsoNegotiates',
'HTTPInsufficientStorage',
'HTTPNotExtended',
'HTTPNetworkAuthenticationRequired',
)
############################################################
# HTTP Exceptions
############################################################
class HTTPException(Response, Exception):
    """A ``Response`` that can also be raised as an exception.

    Subclasses set ``status_code`` and may set ``empty_body`` to suppress
    the default "<status>: <reason>" text body.
    """

    # Subclasses are expected to override this, e.g. status_code = 200
    status_code = -1
    empty_body = False

    __http_exception__ = True

    def __init__(self, *,
                 headers: Optional[LooseHeaders]=None,
                 reason: Optional[str]=None,
                 body: Any=None,
                 text: Optional[str]=None,
                 content_type: Optional[str]=None) -> None:
        """Initialise both base classes from the keyword arguments.

        Passing ``body`` still works but emits a ``DeprecationWarning``.
        """
        if body is not None:
            warnings.warn(
                "body argument is deprecated for http web exceptions",
                DeprecationWarning)
        Response.__init__(self, status=self.status_code,
                          headers=headers, reason=reason,
                          body=body, text=text, content_type=content_type)
        Exception.__init__(self, self.reason)

        needs_default_text = self.body is None and not self.empty_body
        if needs_default_text:
            self.text = "{}: {}".format(self.status, self.reason)

    def __bool__(self) -> bool:
        """Instances are always truthy."""
        return True
class HTTPError(HTTPException):
    """Base class for exceptions with status codes in the 400s and 500s."""


class HTTPRedirection(HTTPException):
    """Base class for exceptions with status codes in the 300s."""


class HTTPSuccessful(HTTPException):
    """Base class for exceptions with status codes in the 200s."""


class HTTPOk(HTTPSuccessful):
    """Exception/response carrying HTTP status 200."""
    status_code = 200


class HTTPCreated(HTTPSuccessful):
    """Exception/response carrying HTTP status 201."""
    status_code = 201


class HTTPAccepted(HTTPSuccessful):
    """Exception/response carrying HTTP status 202."""
    status_code = 202


class HTTPNonAuthoritativeInformation(HTTPSuccessful):
    """Exception/response carrying HTTP status 203."""
    status_code = 203


class HTTPNoContent(HTTPSuccessful):
    """Exception/response carrying HTTP status 204; no default body."""
    status_code = 204
    empty_body = True


class HTTPResetContent(HTTPSuccessful):
    """Exception/response carrying HTTP status 205; no default body."""
    status_code = 205
    empty_body = True


class HTTPPartialContent(HTTPSuccessful):
    """Exception/response carrying HTTP status 206."""
    status_code = 206
############################################################
# 3xx redirection
############################################################
class _HTTPMove(HTTPRedirection):
    """Redirection base class that requires a target location.

    The location is written to the ``Location`` header (normalised via
    ``yarl.URL``) and kept unchanged on ``self.location``.
    """

    def __init__(self,
                 location: StrOrURL,
                 *,
                 headers: Optional[LooseHeaders]=None,
                 reason: Optional[str]=None,
                 body: Any=None,
                 text: Optional[str]=None,
                 content_type: Optional[str]=None) -> None:
        """Raise ``ValueError`` when ``location`` is falsy."""
        if not location:
            raise ValueError("HTTP redirects need a location to redirect to.")
        super().__init__(headers=headers, reason=reason,
                         body=body, text=text, content_type=content_type)
        target = URL(location)
        self.headers['Location'] = str(target)
        self.location = location
class HTTPMultipleChoices(_HTTPMove):
    """Redirect carrying HTTP status 300."""
    status_code = 300


class HTTPMovedPermanently(_HTTPMove):
    """Redirect carrying HTTP status 301."""
    status_code = 301


class HTTPFound(_HTTPMove):
    """Redirect carrying HTTP status 302."""
    status_code = 302


# This one is safe after a POST (the redirected location will be
# retrieved with GET):
class HTTPSeeOther(_HTTPMove):
    """Redirect carrying HTTP status 303."""
    status_code = 303


class HTTPNotModified(HTTPRedirection):
    """Exception/response carrying HTTP status 304; no default body."""
    # FIXME: this should include a date or etag header
    status_code = 304
    empty_body = True


class HTTPUseProxy(_HTTPMove):
    """Exception/response carrying HTTP status 305."""
    # Not a move, but looks a little like one
    status_code = 305


class HTTPTemporaryRedirect(_HTTPMove):
    """Redirect carrying HTTP status 307."""
    status_code = 307


class HTTPPermanentRedirect(_HTTPMove):
    """Redirect carrying HTTP status 308."""
    status_code = 308
############################################################
# 4xx client error
############################################################
class HTTPClientError(HTTPError):
    """Common base class of the 4xx exceptions defined in this section."""


class HTTPBadRequest(HTTPClientError):
    """Exception/response carrying HTTP status 400."""
    status_code = 400


class HTTPUnauthorized(HTTPClientError):
    """Exception/response carrying HTTP status 401."""
    status_code = 401


class HTTPPaymentRequired(HTTPClientError):
    """Exception/response carrying HTTP status 402."""
    status_code = 402


class HTTPForbidden(HTTPClientError):
    """Exception/response carrying HTTP status 403."""
    status_code = 403


class HTTPNotFound(HTTPClientError):
    """Exception/response carrying HTTP status 404."""
    status_code = 404
class HTTPMethodNotAllowed(HTTPClientError):
    """Exception/response carrying HTTP status 405.

    Stores the rejected method (upper-cased) on ``self.method``, the
    permitted methods on ``self.allowed_methods`` and advertises them in
    the ``Allow`` header as a sorted, comma-separated list.
    """
    status_code = 405

    def __init__(self,
                 method: str,
                 allowed_methods: Iterable[str],
                 *,
                 headers: Optional[LooseHeaders]=None,
                 reason: Optional[str]=None,
                 body: Any=None,
                 text: Optional[str]=None,
                 content_type: Optional[str]=None) -> None:
        """Build the response and record the method information.

        :param method: the HTTP method that was rejected.
        :param allowed_methods: any iterable of permitted method names;
            it is consumed exactly once.
        """
        # Materialise the iterable a single time: iterating it twice left
        # ``allowed_methods`` empty when a generator was passed in.
        allowed = set(allowed_methods)  # type: Set[str]
        super().__init__(headers=headers, reason=reason,
                         body=body, text=text, content_type=content_type)
        self.headers['Allow'] = ','.join(sorted(allowed))
        self.allowed_methods = allowed
        self.method = method.upper()
class HTTPNotAcceptable(HTTPClientError):
    """Exception/response carrying HTTP status 406."""
    status_code = 406


class HTTPProxyAuthenticationRequired(HTTPClientError):
    """Exception/response carrying HTTP status 407."""
    status_code = 407


class HTTPRequestTimeout(HTTPClientError):
    """Exception/response carrying HTTP status 408."""
    status_code = 408


class HTTPConflict(HTTPClientError):
    """Exception/response carrying HTTP status 409."""
    status_code = 409


class HTTPGone(HTTPClientError):
    """Exception/response carrying HTTP status 410."""
    status_code = 410


class HTTPLengthRequired(HTTPClientError):
    """Exception/response carrying HTTP status 411."""
    status_code = 411


class HTTPPreconditionFailed(HTTPClientError):
    """Exception/response carrying HTTP status 412."""
    status_code = 412
class HTTPRequestEntityTooLarge(HTTPClientError):
    """Exception/response carrying HTTP status 413.

    Unless the caller supplies ``text``, the body reports the permitted
    and the actual request body size.
    """
    status_code = 413

    def __init__(self,
                 max_size: float,
                 actual_size: float,
                 **kwargs: Any) -> None:
        """Fill in a default ``text`` and defer to the base initialiser."""
        default_text = ('Maximum request body size {} exceeded, '
                        'actual body size {}'.format(max_size, actual_size))
        kwargs.setdefault('text', default_text)
        super().__init__(**kwargs)
class HTTPRequestURITooLong(HTTPClientError):
    """Exception/response carrying HTTP status 414."""
    status_code = 414


class HTTPUnsupportedMediaType(HTTPClientError):
    """Exception/response carrying HTTP status 415."""
    status_code = 415


class HTTPRequestRangeNotSatisfiable(HTTPClientError):
    """Exception/response carrying HTTP status 416."""
    status_code = 416


class HTTPExpectationFailed(HTTPClientError):
    """Exception/response carrying HTTP status 417."""
    status_code = 417


class HTTPMisdirectedRequest(HTTPClientError):
    """Exception/response carrying HTTP status 421."""
    status_code = 421


class HTTPUnprocessableEntity(HTTPClientError):
    """Exception/response carrying HTTP status 422."""
    status_code = 422


class HTTPFailedDependency(HTTPClientError):
    """Exception/response carrying HTTP status 424."""
    status_code = 424


class HTTPUpgradeRequired(HTTPClientError):
    """Exception/response carrying HTTP status 426."""
    status_code = 426


class HTTPPreconditionRequired(HTTPClientError):
    """Exception/response carrying HTTP status 428."""
    status_code = 428


class HTTPTooManyRequests(HTTPClientError):
    """Exception/response carrying HTTP status 429."""
    status_code = 429


class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
    """Exception/response carrying HTTP status 431."""
    status_code = 431
class HTTPUnavailableForLegalReasons(HTTPClientError):
    """Exception/response carrying HTTP status 451.

    ``link`` is exposed on ``self.link`` and emitted in a ``Link`` header
    with ``rel="blocked-by"``.
    """
    status_code = 451

    def __init__(self,
                 link: str,
                 *,
                 headers: Optional[LooseHeaders]=None,
                 reason: Optional[str]=None,
                 body: Any=None,
                 text: Optional[str]=None,
                 content_type: Optional[str]=None) -> None:
        """Build the response, then attach the ``Link`` header."""
        super().__init__(headers=headers, reason=reason,
                         body=body, text=text, content_type=content_type)
        link_header = '<%s>; rel="blocked-by"' % link
        self.headers['Link'] = link_header
        self.link = link
############################################################
# 5xx Server Error
############################################################
# Response status codes beginning with the digit "5" indicate cases in
# which the server is aware that it has erred or is incapable of
# performing the request. Except when responding to a HEAD request, the
# server SHOULD include an entity containing an explanation of the error
# situation, and whether it is a temporary or permanent condition. User
# agents SHOULD display any included entity to the user. These response
# codes are applicable to any request method.
class HTTPServerError(HTTPError):
    """Common base class of the 5xx exceptions defined in this section."""


class HTTPInternalServerError(HTTPServerError):
    """Exception/response carrying HTTP status 500."""
    status_code = 500


class HTTPNotImplemented(HTTPServerError):
    """Exception/response carrying HTTP status 501."""
    status_code = 501


class HTTPBadGateway(HTTPServerError):
    """Exception/response carrying HTTP status 502."""
    status_code = 502


class HTTPServiceUnavailable(HTTPServerError):
    """Exception/response carrying HTTP status 503."""
    status_code = 503


class HTTPGatewayTimeout(HTTPServerError):
    """Exception/response carrying HTTP status 504."""
    status_code = 504


class HTTPVersionNotSupported(HTTPServerError):
    """Exception/response carrying HTTP status 505."""
    status_code = 505


class HTTPVariantAlsoNegotiates(HTTPServerError):
    """Exception/response carrying HTTP status 506."""
    status_code = 506


class HTTPInsufficientStorage(HTTPServerError):
    """Exception/response carrying HTTP status 507."""
    status_code = 507


class HTTPNotExtended(HTTPServerError):
    """Exception/response carrying HTTP status 510."""
    status_code = 510


class HTTPNetworkAuthenticationRequired(HTTPServerError):
    """Exception/response carrying HTTP status 511."""
    status_code = 511

View File

@ -0,0 +1,372 @@
import asyncio
import mimetypes
import os
import pathlib
from functools import partial
from typing import ( # noqa
IO,
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Union,
cast,
)
from . import hdrs
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import set_exception, set_result
from .http_writer import StreamWriter
from .log import server_logger
from .typedefs import LooseHeaders
from .web_exceptions import (
HTTPNotModified,
HTTPPartialContent,
HTTPPreconditionFailed,
HTTPRequestRangeNotSatisfiable,
)
from .web_response import StreamResponse
__all__ = ('FileResponse',)

if TYPE_CHECKING:  # pragma: no cover
    # Imported for annotations only.
    from .web_request import BaseRequest  # noqa

# Optional async callback that receives a chunk of bytes.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]

# Any non-empty value of AIOHTTP_NOSENDFILE makes FileResponse use the
# chunked fallback instead of os.sendfile().
_nosendfile_env = os.environ.get("AIOHTTP_NOSENDFILE")
NOSENDFILE = bool(_nosendfile_env)
class SendfileStreamWriter(StreamWriter):
    """Stream writer that buffers headers and ships the body via sendfile.

    Data passed to ``_write`` is only collected in memory; ``sendfile()``
    later emits that buffered data followed by ``count`` bytes of ``fobj``
    starting at ``offset``.
    """

    def __init__(self,
                 protocol: BaseProtocol,
                 loop: asyncio.AbstractEventLoop,
                 fobj: IO[Any],
                 offset: int,
                 count: int,
                 on_chunk_sent: _T_OnChunkSent=None) -> None:
        super().__init__(protocol, loop, on_chunk_sent)
        self._sendfile_buffer = []  # type: List[bytes]
        self._fobj = fobj
        self._in_fd = fobj.fileno()
        self._offset = offset
        self._count = count

    def _write(self, chunk: bytes) -> None:
        # we overwrite StreamWriter._write, so nothing can be appended to
        # _buffer, and nothing is written to the transport directly by the
        # parent class
        self._sendfile_buffer.append(chunk)
        self.output_size += len(chunk)

    def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None:
        """Writer callback: push more data and resolve ``fut`` when done."""
        if fut.cancelled():
            return
        try:
            finished = self._do_sendfile(out_fd)
        except Exception as exc:
            set_exception(fut, exc)
        else:
            if finished:
                set_result(fut, None)

    def _do_sendfile(self, out_fd: int) -> bool:
        """Run one ``os.sendfile`` call; return True once nothing remains."""
        try:
            sent = os.sendfile(out_fd,
                               self._in_fd,
                               self._offset,
                               self._count)
        except (BlockingIOError, InterruptedError):
            sent = 0
        else:
            if sent == 0:  # in_fd EOF reached
                sent = self._count
        self.output_size += sent
        self._offset += sent
        self._count -= sent
        assert self._count >= 0
        return self._count == 0

    def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None:
        """Done-callback that unregisters the writer for ``out_fd``."""
        self.loop.remove_writer(out_fd)

    async def sendfile(self) -> None:
        """Send the buffered data and then the requested file range."""
        assert self.transport is not None
        loop = self.loop
        buffered = b''.join(self._sendfile_buffer)

        if hasattr(loop, "sendfile"):
            # Python 3.7+
            self.transport.write(buffered)
            await loop.sendfile(
                self.transport,
                self._fobj,
                self._offset,
                self._count
            )
            await super().write_eof()
            return

        # Manual path: work on a duplicate of the transport's socket.
        self._fobj.seek(self._offset)
        out_socket = self.transport.get_extra_info('socket').dup()
        out_socket.setblocking(False)
        out_fd = out_socket.fileno()

        try:
            await loop.sock_sendall(out_socket, buffered)
            if not self._do_sendfile(out_fd):
                pending = loop.create_future()
                pending.add_done_callback(partial(self._done_fut, out_fd))
                loop.add_writer(out_fd, self._sendfile_cb, pending, out_fd)
                await pending
        except asyncio.CancelledError:
            raise
        except Exception:
            server_logger.debug('Socket error')
            self.transport.close()
        finally:
            out_socket.close()

        await super().write_eof()

    async def write_eof(self, chunk: bytes=b'') -> None:
        """Intentionally a no-op; ``sendfile()`` calls the parent's version."""
        pass
class FileResponse(StreamResponse):
    """A response object can be used to send files."""

    def __init__(self, path: Union[str, pathlib.Path],
                 chunk_size: int=256*1024,
                 status: int=200,
                 reason: Optional[str]=None,
                 headers: Optional[LooseHeaders]=None) -> None:
        """Remember the file to serve.

        :param path: file location; a ``str`` is converted to
            ``pathlib.Path``.
        :param chunk_size: maximum number of bytes read per step by the
            non-sendfile fallback.
        :param status: initial response status.
        :param reason: optional reason phrase passed to the base class.
        :param headers: optional headers passed to the base class.
        """
        super().__init__(status=status, reason=reason, headers=headers)

        if isinstance(path, str):
            path = pathlib.Path(path)

        self._path = path
        self._chunk_size = chunk_size

    async def _sendfile_system(self, request: 'BaseRequest',
                               fobj: IO[Any],
                               offset: int,
                               count: int) -> AbstractStreamWriter:
        # Write count bytes of fobj to resp using
        # the os.sendfile system call.
        #
        # For details check
        # https://github.com/KeepSafe/hyper_internal_service/issues/1177
        # See https://github.com/KeepSafe/hyper_internal_service/issues/958 for details
        #
        # request should be an hyper_internal_service.web.Request instance.
        # fobj should be an open file object.
        # count should be an integer > 0.

        transport = request.transport
        assert transport is not None

        # TLS connections, transports without a socket and compressed
        # responses all go through the chunked fallback instead.
        if (transport.get_extra_info("sslcontext") or
                transport.get_extra_info("socket") is None or
                self.compression):
            writer = await self._sendfile_fallback(
                request,
                fobj,
                offset,
                count
            )
        else:
            writer = SendfileStreamWriter(
                request.protocol,
                request._loop,
                fobj,
                offset,
                count
            )
            request._payload_writer = writer

            await super().prepare(request)
            await writer.sendfile()

        return writer

    async def _sendfile_fallback(self, request: 'BaseRequest',
                                 fobj: IO[Any],
                                 offset: int,
                                 count: int) -> AbstractStreamWriter:
        # Mimic the _sendfile_system() method, but without using the
        # os.sendfile() system call. This should be used on systems
        # that don't support the os.sendfile().

        # To keep memory usage low, fobj is transferred in chunks
        # controlled by the constructor's chunk_size argument.

        writer = await super().prepare(request)
        assert writer is not None

        chunk_size = self._chunk_size
        loop = asyncio.get_event_loop()

        await loop.run_in_executor(None, fobj.seek, offset)

        # Never read more than what is still owed to the client: the first
        # read used to request a full chunk_size even for shorter ranges,
        # which sent more bytes than the announced Content-Length.
        remaining = count
        while remaining > 0:
            chunk = await loop.run_in_executor(
                None, fobj.read, min(chunk_size, remaining)
            )
            if not chunk:
                # The file ended before ``count`` bytes were available.
                break
            await writer.write(chunk)
            # Account for what was actually read; short reads are possible.
            remaining -= len(chunk)

        await writer.drain()
        return writer

    # Pick the implementation once, at class-creation time.
    if hasattr(os, "sendfile") and not NOSENDFILE:  # pragma: no cover
        _sendfile = _sendfile_system
    else:  # pragma: no cover
        _sendfile = _sendfile_fallback

    async def prepare(
            self,
            request: 'BaseRequest'
    ) -> Optional[AbstractStreamWriter]:
        """Evaluate conditional/range headers, set headers and send the file.

        Returns whatever the base ``prepare`` returns for the early-exit
        statuses (304, 412, 416), otherwise the writer returned by
        ``_sendfile``.
        """
        filepath = self._path

        # Prefer a pre-compressed "<name>.gz" sibling when the client
        # accepts gzip and that file exists.
        gzip = False
        if 'gzip' in request.headers.get(hdrs.ACCEPT_ENCODING, ''):
            gzip_path = filepath.with_name(filepath.name + '.gz')

            if gzip_path.is_file():
                filepath = gzip_path
                gzip = True

        loop = asyncio.get_event_loop()
        st = await loop.run_in_executor(None, filepath.stat)

        modsince = request.if_modified_since
        if modsince is not None and st.st_mtime <= modsince.timestamp():
            self.set_status(HTTPNotModified.status_code)
            self._length_check = False
            # Delete any Content-Length headers provided by user. HTTP 304
            # should always have empty response body
            return await super().prepare(request)

        unmodsince = request.if_unmodified_since
        if unmodsince is not None and st.st_mtime > unmodsince.timestamp():
            self.set_status(HTTPPreconditionFailed.status_code)
            return await super().prepare(request)

        if hdrs.CONTENT_TYPE not in self.headers:
            ct, encoding = mimetypes.guess_type(str(filepath))
            if not ct:
                ct = 'application/octet-stream'
            should_set_ct = True
        else:
            encoding = 'gzip' if gzip else None
            should_set_ct = False

        status = self._status
        file_size = st.st_size
        count = file_size

        start = None

        ifrange = request.if_range
        if ifrange is None or st.st_mtime <= ifrange.timestamp():
            # If-Range header check:
            # condition = cached date >= last modification date
            # return 206 if True else 200.
            # if False:
            #   Range header would not be processed, return 200
            # if True but Range header missing
            #   return 200
            try:
                rng = request.http_range
                start = rng.start
                end = rng.stop
            except ValueError:
                # https://tools.ietf.org/html/rfc7233:
                # A server generating a 416 (Range Not Satisfiable) response to
                # a byte-range request SHOULD send a Content-Range header field
                # with an unsatisfied-range value.
                # The complete-length in a 416 response indicates the current
                # length of the selected representation.
                #
                # Will do the same below. Many servers ignore this and do not
                # send a Content-Range header with HTTP 416
                self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
                    file_size)
                self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
                return await super().prepare(request)

            # If a range request has been made, convert start, end slice
            # notation into file pointer offset and count
            if start is not None or end is not None:
                if start < 0 and end is None:  # return tail of file
                    start += file_size
                    if start < 0:
                        # if Range:bytes=-1000 in request header but file size
                        # is only 200, there would be trouble without this
                        start = 0
                    count = file_size - start
                else:
                    # rfc7233:If the last-byte-pos value is
                    # absent, or if the value is greater than or equal to
                    # the current length of the representation data,
                    # the byte range is interpreted as the remainder
                    # of the representation (i.e., the server replaces the
                    # value of last-byte-pos with a value that is one less than
                    # the current length of the selected representation).
                    count = min(end if end is not None else file_size,
                                file_size) - start

                if start >= file_size:
                    # HTTP 416 should be returned in this case.
                    #
                    # According to https://tools.ietf.org/html/rfc7233:
                    # If a valid byte-range-set includes at least one
                    # byte-range-spec with a first-byte-pos that is less than
                    # the current length of the representation, or at least one
                    # suffix-byte-range-spec with a non-zero suffix-length,
                    # then the byte-range-set is satisfiable. Otherwise, the
                    # byte-range-set is unsatisfiable.
                    self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
                        file_size)
                    self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
                    return await super().prepare(request)

                status = HTTPPartialContent.status_code
                # Even though you are sending the whole file, you should still
                # return a HTTP 206 for a Range request.
                self.set_status(status)

        if should_set_ct:
            self.content_type = ct  # type: ignore
        if encoding:
            self.headers[hdrs.CONTENT_ENCODING] = encoding
        if gzip:
            self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
        self.last_modified = st.st_mtime  # type: ignore
        self.content_length = count

        self.headers[hdrs.ACCEPT_RANGES] = 'bytes'

        real_start = cast(int, start)

        if status == HTTPPartialContent.status_code:
            self.headers[hdrs.CONTENT_RANGE] = 'bytes {0}-{1}/{2}'.format(
                real_start, real_start + count - 1, file_size)

        fobj = await loop.run_in_executor(None, filepath.open, 'rb')
        if start:  # be aware that start could be None or int=0 here.
            offset = start
        else:
            offset = 0

        try:
            return await self._sendfile(request, fobj, offset, count)
        finally:
            # Always close the file, even when sending fails.
            await loop.run_in_executor(None, fobj.close)

View File

@ -0,0 +1,235 @@
import datetime
import functools
import logging
import os
import re
from collections import namedtuple
from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa
from .abc import AbstractAccessLogger
from .web_request import BaseRequest
from .web_response import StreamResponse
# Pair of (key, method): ``key`` names the entry placed in the log record's
# ``extra`` mapping and ``method`` is the callable that produces the value.
KeyMethod = namedtuple('KeyMethod', 'key method')
class AccessLogger(AbstractAccessLogger):
    """Helper object to log access.

    Usage:
        log = logging.getLogger("spam")
        log_format = "%a %{User-Agent}i"
        access_logger = AccessLogger(log, log_format)
        access_logger.log(request, response, time)

    Format:
        %%  The percent sign
        %a  Remote IP-address (IP-address of proxy if using reverse proxy)
        %t  Time when the request was started to process
        %P  The process ID of the child that serviced the request
        %r  First line of request
        %s  Response status code
        %b  Size of response in bytes (``response.body_length``)
        %T  Time taken to serve the request, in seconds
        %Tf Time taken to serve the request, in seconds with floating fraction
            in .06f format
        %D  Time taken to serve the request, in microseconds
        %{FOO}i  request.headers['FOO']
        %{FOO}o  response.headers['FOO']
        %{FOO}e  os.environ['FOO']

    """
    # Maps each format atom to the key used in the log record's ``extra``.
    LOG_FORMAT_MAP = {
        'a': 'remote_address',
        't': 'request_start_time',
        'P': 'process_id',
        'r': 'first_request_line',
        's': 'response_status',
        'b': 'response_size',
        'T': 'request_time',
        'Tf': 'request_time_frac',
        'D': 'request_time_micro',
        'i': 'request_header',
        'o': 'response_header',
        # '%{FOO}e' is documented and matched by FORMAT_RE; without this
        # entry (and _format_e) compiling such a format raised KeyError.
        'e': 'environ',
    }

    LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
    # Only atoms that have a handler are recognised.  '%O' used to be
    # accepted here without a matching _format_O, which raised KeyError;
    # it is now escaped by CLEANUP_RE like any other unknown atom.
    FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbD]|Tf?)')
    # Escapes every '%' that is not part of a '%s' placeholder.
    CLEANUP_RE = re.compile(r'(%[^s])')
    # Compiled formats shared by all instances, keyed by format string.
    _FORMAT_CACHE = {}  # type: Dict[str, Tuple[str, List[KeyMethod]]]

    def __init__(self, logger: logging.Logger,
                 log_format: str=LOG_FORMAT) -> None:
        """Initialise the logger.

        logger is a logger object to be used for logging.
        log_format is a string with apache compatible log format description.

        """
        super().__init__(logger, log_format=log_format)

        _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
        if not _compiled_format:
            _compiled_format = self.compile_format(log_format)
            AccessLogger._FORMAT_CACHE[log_format] = _compiled_format

        self._log_format, self._methods = _compiled_format

    def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]:
        """Translate log_format into form usable by modulo formatting

        All known atoms will be replaced with %s
        Also methods for formatting of those atoms will be added to
        _methods in appropriate order

        For example we have log_format = "%a %t"
        This format will be translated to "%s %s"
        Also contents of _methods will be
        [self._format_a, self._format_t]
        These method will be called and results will be passed
        to translated string format.

        Each _format_* method receive 'args' which is list of arguments
        given to self.log

        Exceptions are _format_e, _format_i and _format_o methods which
        also receive key name (by functools.partial)

        """
        # list of (key, method) tuples, we don't use an OrderedDict as users
        # can repeat the same key more than once
        methods = list()

        # findall() yields (whole_atom, braced_name, kind) for every match;
        # braced_name is empty for the plain single-letter atoms.
        for atom in self.FORMAT_RE.findall(log_format):
            if atom[1] == '':
                format_key1 = self.LOG_FORMAT_MAP[atom[0]]
                m = getattr(AccessLogger, '_format_%s' % atom[0])
                key_method = KeyMethod(format_key1, m)
            else:
                format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1])
                m = getattr(AccessLogger, '_format_%s' % atom[2])
                key_method = KeyMethod(format_key2,
                                       functools.partial(m, atom[1]))

            methods.append(key_method)

        log_format = self.FORMAT_RE.sub(r'%s', log_format)
        log_format = self.CLEANUP_RE.sub(r'%\1', log_format)
        return log_format, methods

    @staticmethod
    def _format_e(key: str,
                  request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Value of environment variable ``key``, or '-' when unset."""
        return os.environ.get(key, '-')

    @staticmethod
    def _format_i(key: str,
                  request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Request header ``key``, or '-' when the header is absent."""
        if request is None:
            return '(no headers)'

        # suboptimal, make istr(key) once
        return request.headers.get(key, '-')

    @staticmethod
    def _format_o(key: str,
                  request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Response header ``key``, or '-' when the header is absent."""
        # suboptimal, make istr(key) once
        return response.headers.get(key, '-')

    @staticmethod
    def _format_a(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """``request.remote``, or '-' when it is unavailable."""
        if request is None:
            return '-'
        ip = request.remote
        return ip if ip is not None else '-'

    @staticmethod
    def _format_t(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Request start time in UTC, computed as now minus ``time``."""
        # Timezone-aware replacement for the deprecated utcnow(); the
        # rendered text is identical.
        now = datetime.datetime.now(datetime.timezone.utc)
        start_time = now - datetime.timedelta(seconds=time)
        return start_time.strftime('[%d/%b/%Y:%H:%M:%S +0000]')

    @staticmethod
    def _format_P(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Current process id wrapped in angle brackets."""
        return "<%s>" % os.getpid()

    @staticmethod
    def _format_r(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Method, path with query string and HTTP version of the request."""
        if request is None:
            return '-'
        return '%s %s HTTP/%s.%s' % (request.method, request.path_qs,
                                     request.version.major,
                                     request.version.minor)

    @staticmethod
    def _format_s(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> int:
        """``response.status``."""
        return response.status

    @staticmethod
    def _format_b(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> int:
        """``response.body_length``."""
        return response.body_length

    @staticmethod
    def _format_T(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Elapsed time rounded to whole seconds."""
        return str(round(time))

    @staticmethod
    def _format_Tf(request: BaseRequest,
                   response: StreamResponse,
                   time: float) -> str:
        """Elapsed time in seconds with six fractional digits."""
        return '%06f' % time

    @staticmethod
    def _format_D(request: BaseRequest,
                  response: StreamResponse,
                  time: float) -> str:
        """Elapsed time rounded to whole microseconds."""
        return str(round(time * 1000000))

    def _format_line(self,
                     request: BaseRequest,
                     response: StreamResponse,
                     time: float) -> Iterable[Tuple[str,
                                                    Callable[[BaseRequest,
                                                              StreamResponse,
                                                              float],
                                                             str]]]:
        """Evaluate every compiled atom; returns ``(key, value)`` pairs."""
        return [(key, method(request, response, time))
                for key, method in self._methods]

    def log(self,
            request: BaseRequest,
            response: StreamResponse,
            time: float) -> None:
        """Emit one access-log record at INFO level.

        Formatted values fill the compiled format string and are also
        attached as ``extra``; keyed atoms (``%{FOO}i`` etc.) are grouped
        into a nested dict per kind.  Any exception is logged via
        ``logger.exception`` rather than propagated.
        """
        try:
            fmt_info = self._format_line(request, response, time)

            values = list()
            extra = dict()
            for key, value in fmt_info:
                values.append(value)

                if key.__class__ is str:
                    extra[key] = value
                else:
                    k1, k2 = key
                    dct = extra.get(k1, {})  # type: Any
                    dct[k2] = value
                    extra[k1] = dct

            self.logger.info(self._log_format % tuple(values), extra=extra)
        except Exception:
            self.logger.exception("Error in logging")

View File

@ -0,0 +1,120 @@
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, Type, TypeVar
from .web_exceptions import HTTPPermanentRedirect, _HTTPMove
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import SystemRoute
__all__ = ('middleware', 'normalize_path_middleware')

if TYPE_CHECKING:  # pragma: no cover
    # Imported for annotations only.
    from .web_app import Application  # noqa

# Type variable used to annotate ``middleware`` as returning its argument.
_Func = TypeVar('_Func')
async def _check_request_resolves(request: Request,
                                  path: str) -> Tuple[bool, Request]:
    """Check whether ``path`` resolves on the application's router.

    A clone of ``request`` with ``rel_url=path`` is resolved.  When the
    resulting match info carries no ``http_exception`` the result is
    ``(True, clone)``; otherwise it is ``(False, request)``.
    """
    candidate = request.clone(rel_url=path)
    resolved = await request.app.router.resolve(candidate)
    candidate._match_info = resolved  # type: ignore

    if resolved.http_exception is not None:
        return False, request
    return True, candidate
def middleware(f: _Func) -> _Func:
    """Decorator that tags ``f`` with ``__middleware_version__ = 1``.

    The same object is returned; only the attribute is added.
    """
    setattr(f, '__middleware_version__', 1)
    return f
# A handler takes a request and returns an awaitable response.
_Handler = Callable[[Request], Awaitable[StreamResponse]]
# A middleware takes the request plus the next handler in the chain.
_Middleware = Callable[[Request, _Handler], Awaitable[StreamResponse]]
def normalize_path_middleware(
        *, append_slash: bool=True, remove_slash: bool=False,
        merge_slashes: bool=True,
        redirect_class: Type[_HTTPMove]=HTTPPermanentRedirect) -> _Middleware:
    """
    Middleware factory which produces a middleware that normalizes
    the path of a request. By normalizing it means:

        - Add or remove a trailing slash to the path.
        - Double slashes are replaced by one.

    The middleware returns as soon as it finds a path that resolves
    correctly. The order if both merge and append/remove are enabled is
        1) merge slashes
        2) append/remove slash
        3) both merge slashes and append/remove slash.
    If the path resolves with at least one of those conditions, it will
    redirect to the new path.

    Only one of `append_slash` and `remove_slash` can be enabled. If both
    are `True` the factory will raise an assertion error

    If `append_slash` is `True` the middleware will append a slash when
    needed. If a resource is defined with trailing slash and the request
    comes without it, it will append it automatically.

    If `remove_slash` is `True`, `append_slash` must be `False`. When enabled
    the middleware will remove trailing slashes and redirect if the resource
    is defined

    If merge_slashes is True, merge multiple consecutive slashes in the
    path into one.

    Regardless of `merge_slashes`, leading slashes of every candidate path
    are collapsed into one before it is resolved, so the redirect target
    can never be interpreted as a ``//host/...`` network-path reference.
    """

    correct_configuration = not (append_slash and remove_slash)
    assert correct_configuration, "Cannot both remove and append slash"

    @middleware
    async def impl(request: Request, handler: _Handler) -> StreamResponse:
        # Only requests that did not match a real route are normalised.
        if isinstance(request.match_info.route, SystemRoute):
            paths_to_check = []
            if '?' in request.raw_path:
                path, query = request.raw_path.split('?', 1)
                query = '?' + query
            else:
                query = ''
                path = request.raw_path

            if merge_slashes:
                paths_to_check.append(re.sub('//+', '/', path))
            if append_slash and not request.path.endswith('/'):
                paths_to_check.append(path + '/')
            if remove_slash and request.path.endswith('/'):
                paths_to_check.append(path[:-1])
            if merge_slashes and append_slash:
                paths_to_check.append(
                    re.sub('//+', '/', path + '/'))
            if merge_slashes and remove_slash:
                merged_slashes = re.sub('//+', '/', path)
                paths_to_check.append(merged_slashes[:-1])

            for path in paths_to_check:
                # SECURITY: a candidate such as '//evil.example/x/' would be
                # parsed as host + path, resolve successfully and then be
                # echoed in the Location header (open redirect).
                path = re.sub('^//+', '/', path)
                resolves, request = await _check_request_resolves(
                    request, path)
                if resolves:
                    raise redirect_class(request.raw_path + query)

        return await handler(request)

    return impl
def _fix_request_current_app(app: 'Application') -> _Middleware:
    """Build a middleware that runs the handler with ``app`` as current app.

    The handler is awaited inside the context manager returned by
    ``request.match_info.set_current_app(app)``.
    """

    @middleware
    async def impl(request: Request, handler: _Handler) -> StreamResponse:
        current_app_scope = request.match_info.set_current_app(app)
        with current_app_scope:
            response = await handler(request)
            return response

    return impl

View File

@ -0,0 +1,627 @@
import asyncio
import asyncio.streams
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Tuple,
Type,
cast,
)
import yarl
from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .http import (
HttpProcessingError,
HttpRequestParser,
HttpVersion10,
RawRequestMessage,
StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse
__all__ = ('RequestHandler', 'RequestPayloadError', 'PayloadAccessError')
if TYPE_CHECKING: # pragma: no cover
from .web_server import Server # noqa
# Signature of the factory that ``RequestHandler.start`` calls to build a
# request object: (message, payload, protocol, writer, task) -> request.
_RequestFactory = Callable[[RawRequestMessage,
StreamReader,
'RequestHandler',
AbstractStreamWriter,
'asyncio.Task[None]'],
BaseRequest]
# Signature of the coroutine that turns a request into a response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
# Placeholder message handed to ``BaseRequest`` by ``handle_parse_error`` when
# the incoming bytes could not be parsed into a real request message.
ERROR = RawRequestMessage(
'UNKNOWN', '/', HttpVersion10, {},
{}, True, False, False, False, yarl.URL('/'))
class RequestPayloadError(Exception):
    """Raised when the request payload cannot be parsed."""
class PayloadAccessError(Exception):
    """Raised when the payload is accessed after the response was sent."""
class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.
RequestHandler handles incoming HTTP requests. It reads the request line,
request headers and request payload, builds a request object with the
manager's request factory and passes it to the manager's request handler.
By default it always returns with 404 response.
RequestHandler handles errors in incoming request, like bad
status line, bad headers or incomplete payload. If any error occurs,
connection gets closed.
:param manager: owning server object; ``request_handler`` and
``request_factory`` are taken from it
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
:param bool tcp_keepalive: TCP keep-alive is on, default is on
:param bool debug: enable debug mode
:param logger: custom logger object
:type logger: hyper_internal_service.log.server_logger
:param access_log_class: custom class for access_logger
:type access_log_class: hyper_internal_service.abc.AbstractAccessLogger
:param access_log: custom logging object
:type access_log: hyper_internal_service.log.server_logger
:param str access_log_format: access log format string
:param loop: Optional event loop
:param int max_line_size: Optional maximum header line size
:param int max_field_size: Optional maximum header field size
:param int max_headers: Optional maximum header size
:param float lingering_time: maximum number of seconds spent reading and
discarding an unread request payload after the response was sent
"""
# Delay, in seconds, before _process_keepalive() re-checks the connection.
KEEPALIVE_RESCHEDULE_DELAY = 1
# Instance attributes assigned in __init__.
__slots__ = ('_request_count', '_keepalive', '_manager',
'_request_handler', '_request_factory', '_tcp_keepalive',
'_keepalive_time', '_keepalive_handle', '_keepalive_timeout',
'_lingering_time', '_messages', '_message_tail',
'_waiter', '_error_handler', '_task_handler',
'_upgrade', '_payload_parser', '_request_parser',
'_reading_paused', 'logger', 'debug', 'access_log',
'access_logger', '_close', '_force_close')
def __init__(self, manager: 'Server', *,
             loop: asyncio.AbstractEventLoop,
             keepalive_timeout: float=75.,  # NGINX default is 75 secs
             tcp_keepalive: bool=True,
             logger: Logger=server_logger,
             access_log_class: Type[AbstractAccessLogger]=AccessLogger,
             access_log: Logger=access_logger,
             access_log_format: str=AccessLogger.LOG_FORMAT,
             debug: bool=False,
             max_line_size: int=8190,
             max_headers: int=32768,
             max_field_size: int=8190,
             lingering_time: float=10.0):
    """Set up per-connection state; see the class docstring for the
    meaning of the individual parameters."""
    super().__init__(loop)

    # Collaborators borrowed from the owning server.
    self._manager = manager  # type: Optional[Server]
    self._request_handler = manager.request_handler  # type: Optional[_RequestHandler]  # noqa
    self._request_factory = manager.request_factory  # type: Optional[_RequestFactory]  # noqa

    # Request bookkeeping.
    self._request_count = 0
    self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
    self._message_tail = b''
    self._upgrade = False

    # Keep-alive and lingering configuration.
    self._keepalive = False
    self._tcp_keepalive = tcp_keepalive
    # placeholder to be replaced on keepalive timeout setup
    self._keepalive_time = 0.0
    self._keepalive_handle = None  # type: Optional[asyncio.Handle]
    self._keepalive_timeout = keepalive_timeout
    self._lingering_time = float(lingering_time)

    # Futures and tasks owned by this connection.
    self._waiter = None  # type: Optional[asyncio.Future[None]]
    self._error_handler = None  # type: Optional[asyncio.Task[None]]
    self._task_handler = None  # type: Optional[asyncio.Task[None]]

    # Parsers.
    self._payload_parser = None  # type: Any
    self._request_parser = HttpRequestParser(
        self, loop,
        max_line_size=max_line_size,
        max_field_size=max_field_size,
        max_headers=max_headers,
        payload_exception=RequestPayloadError)  # type: Optional[HttpRequestParser]  # noqa

    # Logging.
    self.logger = logger
    self.debug = debug
    self.access_log = access_log
    # An access logger is only built when an access log object was given.
    self.access_logger = (
        access_log_class(access_log, access_log_format)
        if access_log else None
    )  # type: Optional[AbstractAccessLogger]

    self._close = False
    self._force_close = False
def __repr__(self) -> str:
return "<{} {}>".format(
self.__class__.__name__,
'connected' if self.transport is not None else 'disconnected')
@property
def keepalive_timeout(self) -> float:
    """Keep-alive timeout this handler was configured with."""
    timeout = self._keepalive_timeout
    return timeout
async def shutdown(self, timeout: Optional[float]=15.0) -> None:
    """Stop accepting requests and clean up before the worker exits.

    Pending handlers get up to ``timeout`` seconds to finish; after that
    the request task is cancelled and the transport is closed. This is
    especially important for keep-alive connections.
    """
    self._force_close = True

    keepalive_handle = self._keepalive_handle
    if keepalive_handle is not None:
        keepalive_handle.cancel()

    if self._waiter:
        self._waiter.cancel()

    # wait for handlers
    with suppress(asyncio.CancelledError, asyncio.TimeoutError):
        with CeilTimeout(timeout, loop=self._loop):
            error_handler = self._error_handler
            if error_handler is not None and not error_handler.done():
                await error_handler

            task_handler = self._task_handler
            if task_handler is not None and not task_handler.done():
                await task_handler

    # force-close non-idle handler
    if self._task_handler is not None:
        self._task_handler.cancel()

    transport = self.transport
    if transport is not None:
        transport.close()
        self.transport = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    """Register the new connection and start the request-processing task."""
    super().connection_made(transport)

    stream_transport = cast(asyncio.Transport, transport)
    if self._tcp_keepalive:
        tcp_keepalive(stream_transport)

    # start() runs for the lifetime of the connection.
    self._task_handler = self._loop.create_task(self.start())

    manager = self._manager
    assert manager is not None
    manager.connection_made(self, stream_transport)
def connection_lost(self, exc: Optional[BaseException]) -> None:
    """Tear down the handler after the transport went away.

    Does nothing when the manager reference is already cleared.
    """
    manager = self._manager
    if manager is None:
        return

    manager.connection_lost(self, exc)
    super().connection_lost(exc)

    # Drop references to collaborators and refuse further work.
    self._manager = None
    self._force_close = True
    self._request_factory = None
    self._request_handler = None
    self._request_parser = None

    # Cancel whatever is still scheduled or running for this connection.
    for pending in (self._keepalive_handle,
                    self._task_handler,
                    self._error_handler):
        if pending is not None:
            pending.cancel()
    self._task_handler = None

    payload_parser = self._payload_parser
    if payload_parser is not None:
        payload_parser.feed_eof()
        self._payload_parser = None
def set_parser(self, parser: Any) -> None:
    """Install a payload parser and hand it any bytes buffered so far."""
    # Actual type is WebReader
    assert self._payload_parser is None
    self._payload_parser = parser

    buffered = self._message_tail
    if not buffered:
        return
    parser.feed_data(buffered)
    self._message_tail = b''
def eof_received(self) -> None:
    """Deliberately ignore EOF notifications from the transport."""
    return None
def data_received(self, data: bytes) -> None:
# Entry point for incoming bytes. Depending on the handler state the data is
# parsed into request messages, buffered in _message_tail, or fed to the
# installed payload parser.
# Data arriving after close()/force_close() is dropped.
if self._force_close or self._close:
return
# parse http messages
if self._payload_parser is None and not self._upgrade:
assert self._request_parser is not None
try:
messages, upgraded, tail = self._request_parser.feed_data(data)
except HttpProcessingError as exc:
# something happened during parsing
# A 400 response is written by a separate task; no new messages are
# accepted afterwards.
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop),
400, exc, exc.message))
self.close()
except Exception as exc:
# 500: internal error
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop),
500, exc))
self.close()
else:
if messages:
# sometimes the parser returns no messages
for (msg, payload) in messages:
self._request_count += 1
self._messages.append((msg, payload))
# Wake up start() if it is waiting for the next request.
waiter = self._waiter
if waiter is not None:
if not waiter.done():
# don't set result twice
waiter.set_result(None)
self._upgrade = upgraded
# Bytes following an upgrade request are kept until set_parser() or
# finish_response() consumes them.
if upgraded and tail:
self._message_tail = tail
# no parser, just store
elif self._payload_parser is None and self._upgrade and data:
self._message_tail += data
# feed payload
elif data:
eof, tail = self._payload_parser.feed_data(data)
# The payload parser reporting EOF stops further message processing.
if eof:
self.close()
def keep_alive(self, val: bool) -> None:
    """Set keep-alive connection mode.

    Any keep-alive timer that is currently scheduled is cancelled.

    :param bool val: new state.
    """
    self._keepalive = val

    scheduled = self._keepalive_handle
    if not scheduled:
        return
    scheduled.cancel()
    self._keepalive_handle = None
def close(self) -> None:
    """Stop accepting new pipelined messages; the connection is closed
    once the handlers are done processing the messages already queued."""
    self._close = True
    waiter = self._waiter
    if waiter:
        waiter.cancel()
def force_close(self) -> None:
    """Close the connection immediately, without waiting for handlers."""
    self._force_close = True

    waiter = self._waiter
    if waiter:
        waiter.cancel()

    transport = self.transport
    if transport is not None:
        transport.close()
        self.transport = None
def log_access(self,
               request: BaseRequest,
               response: StreamResponse,
               time: float) -> None:
    """Write an access-log entry for a served request.

    :param request: the request that was served.
    :param response: the response that was sent.
    :param float time: loop timestamp (``loop.time()``) taken when
        processing of the request started; this is what ``start()``
        passes down through ``finish_response()``.

    Nothing is logged when no access logger is configured.
    """
    if self.access_logger is not None:
        # The access logger expects the time taken to serve the request,
        # not an absolute timestamp, so convert the start stamp here.
        self.access_logger.log(request, response,
                               self._loop.time() - time)
def log_debug(self, *args: Any, **kw: Any) -> None:
    """Forward to ``logger.debug`` only when debug mode is enabled."""
    if not self.debug:
        return
    self.logger.debug(*args, **kw)
def log_exception(self, *args: Any, **kw: Any) -> None:
    """Forward unconditionally to ``logger.exception``."""
    emit = self.logger.exception
    emit(*args, **kw)
def _process_keepalive(self) -> None:
if self._force_close or not self._keepalive:
return
next = self._keepalive_time + self._keepalive_timeout
# handler in idle state
if self._waiter:
if self._loop.time() > next:
self.force_close()
return
# not all request handlers are done,
# reschedule itself to next second
self._keepalive_handle = self._loop.call_later(
self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive)
async def _handle_request(self,
                          request: BaseRequest,
                          start_time: float,
                          ) -> Tuple[StreamResponse, bool]:
    """Run the request handler and send its response.

    Returns ``(response, reset)`` where ``reset`` is the value returned
    by ``finish_response()``. Cancellation is propagated unchanged.
    """
    handler = self._request_handler
    assert handler is not None
    try:
        resp = await handler(request)
    except HTTPException as exc:
        # An HTTP exception raised by the handler becomes a plain response.
        resp = Response(status=exc.status,
                        reason=exc.reason,
                        text=exc.text,
                        headers=exc.headers)
        # finish_response() stays inside the except block on purpose: it
        # must run while the exception is still being handled.
        reset = await self.finish_response(request, resp, start_time)
    except asyncio.CancelledError:
        raise
    except asyncio.TimeoutError as exc:
        self.log_debug('Request handler timed out.', exc_info=exc)
        resp = self.handle_error(request, 504)
        reset = await self.finish_response(request, resp, start_time)
    except Exception as exc:
        resp = self.handle_error(request, 500, exc)
        reset = await self.finish_response(request, resp, start_time)
    else:
        reset = await self.finish_response(request, resp, start_time)

    return resp, reset
async def start(self) -> None:
"""Process incoming requests for this connection.
Waits for parsed request messages queued by data_received(), builds a
request object with the request factory and runs _handle_request() for
it in a separate task. start() handles various exceptions in request
or response handling. The loop is left after a request unless the
response asked for keep-alive.
"""
loop = self._loop
handler = self._task_handler
assert handler is not None
manager = self._manager
assert manager is not None
keepalive_timeout = self._keepalive_timeout
resp = None
assert self._request_factory is not None
assert self._request_handler is not None
while not self._force_close:
if not self._messages:
try:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
# the waiter is cancelled by close(), force_close() and shutdown()
break
finally:
self._waiter = None
message, payload = self._messages.popleft()
# start stamp; passed down to finish_response() for access logging
start = loop.time()
manager.requests_count += 1
writer = StreamWriter(self, loop)
request = self._request_factory(
message, payload, self, writer, handler)
try:
# a new task is used for copy context vars (#3406)
task = self._loop.create_task(
self._handle_request(request, start))
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
self.log_debug('Ignored premature client disconnection')
break
# Deprecation warning (See #2415)
if getattr(resp, '__http_exception__', False):
warnings.warn(
"returning HTTPException object is deprecated "
"(#2415) and will be removed, "
"please raise the exception instead",
DeprecationWarning)
# Drop the processed task from asyncio.Task.all_tasks() early
del task
# reset is True when finish_response() hit a ConnectionError
if reset:
self.log_debug('Ignored premature client disconnection 2')
break
# notify server about keep-alive
self._keepalive = bool(resp.keep_alive)
# check payload
if not payload.is_eof():
lingering_time = self._lingering_time
if not self._force_close and lingering_time:
self.log_debug(
'Start lingering close timer for %s sec.',
lingering_time)
now = loop.time()
end_t = now + lingering_time
# drain the unread request body for at most lingering_time seconds
with suppress(
asyncio.TimeoutError, asyncio.CancelledError):
while not payload.is_eof() and now < end_t:
with CeilTimeout(end_t - now, loop=loop):
# read and ignore
await payload.readany()
now = loop.time()
# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
self.log_debug('Uncompleted request.')
self.close()
# later reads of this payload fail with PayloadAccessError
payload.set_exception(PayloadAccessError())
except asyncio.CancelledError:
self.log_debug('Ignored premature client disconnection ')
break
except RuntimeError as exc:
# runtime errors are only logged in debug mode
if self.debug:
self.log_exception(
'Unhandled runtime exception', exc_info=exc)
self.force_close()
except Exception as exc:
self.log_exception('Unhandled exception', exc_info=exc)
self.force_close()
finally:
if self.transport is None and resp is not None:
self.log_debug('Ignored premature client disconnection.')
elif not self._force_close:
if self._keepalive and not self._close:
# start keep-alive timer
if keepalive_timeout is not None:
now = self._loop.time()
self._keepalive_time = now
if self._keepalive_handle is None:
self._keepalive_handle = loop.call_at(
now + keepalive_timeout,
self._process_keepalive)
else:
# no keep-alive requested (or close() was called): stop serving
break
# remove handler, close transport if no handlers left
if not self._force_close:
self._task_handler = None
# a pending error handler closes the transport itself when it is done
if self.transport is not None and self._error_handler is None:
self.transport.close()
async def finish_response(self,
                          request: BaseRequest,
                          resp: StreamResponse,
                          start_time: float) -> bool:
    """Prepare the response, write EOF and log the access.

    This has to be called within the context of any exception so the
    access logger can get exception information.

    Returns True if the client disconnected prematurely (a
    ``ConnectionError`` was raised while sending), False otherwise.
    Raises ``RuntimeError`` when ``resp`` has no ``prepare`` attribute.
    """
    parser = self._request_parser
    if parser is not None:
        # Leave upgrade mode and re-feed bytes buffered while upgraded.
        parser.set_upgraded(False)
        self._upgrade = False
        if self._message_tail:
            parser.feed_data(self._message_tail)
            self._message_tail = b''

    try:
        prepare_meth = resp.prepare
    except AttributeError:
        if resp is None:
            raise RuntimeError("Missing return "
                               "statement on request handler")
        raise RuntimeError("Web-handler should return "
                           "a response instance, "
                           "got {!r}".format(resp))

    try:
        await prepare_meth(request)
        await resp.write_eof()
    except ConnectionError:
        # Logged inside the except block so the exception is still active.
        self.log_access(request, resp, start_time)
        return True
    else:
        self.log_access(request, resp, start_time)
        return False
def handle_error(self,
                 request: BaseRequest,
                 status: int=500,
                 exc: Optional[BaseException]=None,
                 message: Optional[str]=None) -> StreamResponse:
    """Handle errors.

    Returns HTTP response with specific status code. Logs additional
    information. It always closes current connection.

    :param request: request the error belongs to.
    :param int status: HTTP status of the error response.
    :param exc: exception that caused the error, if any; in debug mode
        its traceback is included in a 500 response.
    :param message: response text; replaced by a generated page when
        ``status`` is 500.
    """
    self.log_exception("Error handling request", exc_info=exc)

    ct = 'text/plain'
    if status == HTTPStatus.INTERNAL_SERVER_ERROR:
        title = '{0.value} {0.phrase}'.format(
            HTTPStatus.INTERNAL_SERVER_ERROR
        )
        msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
        tb = None
        if self.debug:
            with suppress(Exception):
                if exc is not None:
                    # Format the exception that was passed in: this method
                    # may run outside of any ``except`` block (e.g. from
                    # the handle_parse_error() task), where format_exc()
                    # would only yield "NoneType: None".
                    tb = ''.join(traceback.format_exception(
                        type(exc), exc, exc.__traceback__))
                else:
                    tb = traceback.format_exc()

        if 'text/html' in request.headers.get('Accept', ''):
            if tb:
                tb = html_escape(tb)
                msg = '<h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
            message = (
                "<html><head>"
                "<title>{title}</title>"
                "</head><body>\n<h1>{title}</h1>"
                "\n{msg}\n</body></html>\n"
            ).format(title=title, msg=msg)
            ct = 'text/html'
        else:
            if tb:
                msg = tb
            message = title + '\n\n' + msg

    resp = Response(status=status, text=message, content_type=ct)
    resp.force_close()

    # some data already got sent, connection is broken
    if request.writer.output_size > 0 or self.transport is None:
        self.force_close()

    return resp
async def handle_parse_error(self,
                             writer: AbstractStreamWriter,
                             status: int,
                             exc: Optional[BaseException]=None,
                             message: Optional[str]=None) -> None:
    """Send an error response for data that could not be parsed.

    A placeholder request built from ``ERROR`` is used because no real
    request exists; the transport is closed after the response is sent.
    """
    placeholder = BaseRequest(
        ERROR,
        EMPTY_PAYLOAD,  # type: ignore
        self, writer,
        current_task(),
        self._loop)

    error_resp = self.handle_error(placeholder, status, exc, message)
    await error_resp.prepare(placeholder)
    await error_resp.write_eof()

    transport = self.transport
    if transport is not None:
        transport.close()

    self._error_handler = None

View File

@ -0,0 +1,766 @@
import asyncio
import datetime
import io
import re
import socket
import string
import tempfile
import types
import warnings
from email.utils import parsedate
from http.cookies import SimpleCookie
from types import MappingProxyType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
from urllib.parse import parse_qsl
import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
from .http_parser import RawRequestMessage
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
JSONDecoder,
LooseHeaders,
RawHeaders,
StrOrURL,
)
from .web_exceptions import HTTPRequestEntityTooLarge
from .web_response import StreamResponse
__all__ = ('BaseRequest', 'FileField', 'Request')
if TYPE_CHECKING: # pragma: no cover
from .web_app import Application # noqa
from .web_urldispatcher import UrlMappingMatchInfo # noqa
from .web_protocol import RequestHandler # noqa
@attr.s(frozen=True, slots=True)
class FileField:
    # One uploaded file from a multipart form; instances are created by
    # ``BaseRequest.post()``.
    name = attr.ib(type=str)  # form field name
    filename = attr.ib(type=str)  # file name of the uploaded part
    file = attr.ib(type=io.BufferedReader)  # file object holding the content
    content_type = attr.ib(type=str)  # Content-Type header of the part
    headers = attr.ib(type=CIMultiDictProxy)  # type: CIMultiDictProxy[str]
# Building blocks for the regular expression that matches one
# ``name=value`` pair of a Forwarded header (see ``BaseRequest.forwarded``).
# Characters allowed in a token.
_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"
# '-' at the end to prevent interpretation as range in a char class
_TOKEN = r'[{tchar}]+'.format(tchar=_TCHAR)
# Character class of tab, space, '!' and 0x23-0x7E (i.e. everything printable
# except the double quote).
_QDTEXT = r'[{}]'.format(
r''.join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F))))
# qdtext includes 0x5C to escape 0x5D ('\]')
# qdtext excludes obs-text (because obsoleted, and encoding not specified)
# A backslash followed by a tab, a space or a printable character.
_QUOTED_PAIR = r'\\[\t !-~]'
_QUOTED_STRING = r'"(?:{quoted_pair}|{qdtext})*"'.format(
qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR)
# Groups: (name), (value as token or quoted string), (optional ':port' of one
# to four digits).
_FORWARDED_PAIR = (
r'({token})=({token}|{quoted_string})(:\d{{1,4}})?'.format(
token=_TOKEN,
quoted_string=_QUOTED_STRING))
_QUOTED_PAIR_REPLACE_RE = re.compile(r'\\([\t !-~])')
# same pattern as _QUOTED_PAIR but contains a capture group
_FORWARDED_PAIR_RE = re.compile(_FORWARDED_PAIR)
############################################################
# HTTP Request
############################################################
class BaseRequest(MutableMapping[str, Any], HeadersMixin):
# HTTP methods for which post() tries to parse the request body; for any
# other method post() returns an empty mapping.
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
hdrs.METH_TRACE, hdrs.METH_DELETE}
# Names of the instance attributes assigned in __init__, added to those
# declared by HeadersMixin.
# NOTE(review): presumably used to police attribute assignment elsewhere in
# the module - confirm against the code that reads ATTRS.
ATTRS = HeadersMixin.ATTRS | frozenset([
'_message', '_protocol', '_payload_writer', '_payload', '_headers',
'_method', '_version', '_rel_url', '_post', '_read_bytes',
'_state', '_cache', '_task', '_client_max_size', '_loop',
'_transport_sslcontext', '_transport_peername'])
def __init__(self, message: RawRequestMessage,
             payload: StreamReader, protocol: 'RequestHandler',
             payload_writer: AbstractStreamWriter,
             task: 'asyncio.Task[None]',
             loop: asyncio.AbstractEventLoop,
             *, client_max_size: int=1024**2,
             state: Optional[Dict[str, Any]]=None,
             scheme: Optional[str]=None,
             host: Optional[str]=None,
             remote: Optional[str]=None) -> None:
    """Wrap a parsed request message and its payload stream.

    ``scheme``, ``host`` and ``remote``, when given, pre-populate the
    cache so the properties of the same name return these values.
    """
    # Collaborators.
    self._protocol = protocol
    self._payload_writer = payload_writer
    self._payload = payload
    self._task = task
    self._loop = loop

    # Values taken from the parsed message.
    self._message = message
    self._headers = message.headers
    self._method = message.method
    self._version = message.version
    self._rel_url = message.url

    # Lazily filled state.
    self._post = None  # type: Optional[MultiDictProxy[Union[str, bytes, FileField]]]  # noqa
    self._read_bytes = None  # type: Optional[bytes]
    self._state = {} if state is None else state
    self._cache = {}  # type: Dict[str, Any]
    self._client_max_size = client_max_size

    # Connection details are captured while the transport is still there.
    transport = self._protocol.transport
    assert transport is not None
    self._transport_sslcontext = transport.get_extra_info('sslcontext')
    self._transport_peername = transport.get_extra_info('peername')

    overrides = (('scheme', scheme), ('host', host), ('remote', remote))
    for cache_key, override in overrides:
        if override is not None:
            self._cache[cache_key] = override
def clone(self, *, method: str=sentinel, rel_url: StrOrURL=sentinel,
          headers: LooseHeaders=sentinel, scheme: str=sentinel,
          host: str=sentinel,
          remote: str=sentinel) -> 'BaseRequest':
    """Clone the request, replacing some of its attributes.

    Creates and returns a new instance of the same class. Every
    parameter that is not passed keeps the value of the current request;
    with no parameters an exact copy is returned.

    Raises ``RuntimeError`` once body content has been read.
    """
    if self._read_bytes:
        raise RuntimeError("Cannot clone request "
                           "after reading its content")

    message_fields = {}  # type: Dict[str, Any]
    if method is not sentinel:
        message_fields['method'] = method
    if rel_url is not sentinel:
        replaced_url = URL(rel_url)
        message_fields['url'] = replaced_url
        message_fields['path'] = str(replaced_url)
    if headers is not sentinel:
        # a copy semantic
        message_fields['headers'] = CIMultiDictProxy(CIMultiDict(headers))
        message_fields['raw_headers'] = tuple(
            (k.encode('utf-8'), v.encode('utf-8'))
            for k, v in headers.items())

    message = self._message._replace(**message_fields)

    # Only explicitly supplied overrides are forwarded to the constructor.
    candidates = (('scheme', scheme), ('host', host), ('remote', remote))
    kwargs = {key: value for key, value in candidates
              if value is not sentinel}

    return self.__class__(
        message,
        self._payload,
        self._protocol,
        self._payload_writer,
        self._task,
        self._loop,
        client_max_size=self._client_max_size,
        state=self._state.copy(),
        **kwargs)
@property
def task(self) -> 'asyncio.Task[None]':
    """Task object that was passed to the constructor."""
    owner_task = self._task
    return owner_task
@property
def protocol(self) -> 'RequestHandler':
    """Protocol object this request was created with."""
    handler = self._protocol
    return handler
@property
def transport(self) -> Optional[asyncio.Transport]:
    """Transport of the underlying protocol, or None without a protocol."""
    handler = self._protocol
    return None if handler is None else handler.transport
@property
def writer(self) -> AbstractStreamWriter:
    """Payload writer that was passed to the constructor."""
    payload_writer = self._payload_writer
    return payload_writer
@reify
def message(self) -> RawRequestMessage:
    """Raw request message (deprecated accessor)."""
    warnings.warn(
        "Request.message is deprecated",
        DeprecationWarning,
        stacklevel=3,
    )
    raw_message = self._message
    return raw_message
@reify
def rel_url(self) -> URL:
    """URL taken from the request message."""
    relative = self._rel_url
    return relative
@reify
def loop(self) -> asyncio.AbstractEventLoop:
    """Event loop passed to the constructor (deprecated accessor)."""
    warnings.warn(
        "request.loop property is deprecated",
        DeprecationWarning,
        stacklevel=2,
    )
    event_loop = self._loop
    return event_loop
# MutableMapping API
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
########
@reify
def secure(self) -> bool:
    """True when the request scheme is ``'https'``."""
    is_https = self.scheme == 'https'
    return is_https
@reify
def forwarded(self) -> Tuple[Mapping[str, str], ...]:
    """A tuple containing all parsed Forwarded header(s).

    Makes an effort to parse Forwarded headers as specified by RFC 7239:

    - One read-only mapping is produced per forwarded-element, i.e. per
      Forwarded header and per comma-separated element inside a header,
      in the order in which they appear.
    - Every value must be either a 'token' or a 'quoted-string'
      (section 4).
    - Escape sequences inside quoted strings are un-escaped.
    - 'by' and 'for' contents (section 6) are NOT validated.
    - 'host' contents (Host ABNF) are NOT validated.
    - 'proto' contents are NOT validated as URI scheme names.

    Pair names are lower-cased. On a syntax error the rest of the
    current element is skipped up to the next comma.
    """
    parsed = []
    for header in self._message.headers.getall(hdrs.FORWARDED, ()):
        total = len(header)
        cursor = 0
        expect_separator = False
        current = {}  # type: Dict[str, str]
        parsed.append(types.MappingProxyType(current))

        # find() returning -1 ends the scan of this header.
        while 0 <= cursor < total:
            pair = _FORWARDED_PAIR_RE.match(header, cursor)
            if pair is not None:  # got a valid forwarded-pair
                if expect_separator:
                    # bad syntax here, skip to next comma
                    cursor = header.find(',', cursor)
                    continue

                name, value, port = pair.groups()
                if value[0] == '"':
                    # quoted string: remove quotes and unescape
                    value = _QUOTED_PAIR_REPLACE_RE.sub(r'\1',
                                                        value[1:-1])
                if port:
                    value += port
                current[name.lower()] = value
                cursor = pair.end()
                expect_separator = True
                continue

            char = header[cursor]
            if char == ',':  # next forwarded-element
                expect_separator = False
                current = {}
                parsed.append(types.MappingProxyType(current))
                cursor += 1
            elif char == ';':  # next forwarded-pair
                expect_separator = False
                cursor += 1
            elif char in ' \t':
                # Allow whitespace even between forwarded-pairs, though
                # RFC 7239 doesn't. This simplifies code and is in line
                # with Postel's law.
                cursor += 1
            else:
                # bad syntax here, skip to next comma
                cursor = header.find(',', cursor)

    return tuple(parsed)
@reify
def scheme(self) -> str:
    """A string representing the scheme of the request.

    The scheme is resolved in this order:

    - overridden value by .clone(scheme=new_scheme) call.
    - type of connection to peer: HTTPS if socket is SSL, HTTP otherwise.

    'http' or 'https'.
    """
    return 'https' if self._transport_sslcontext else 'http'
@reify
def method(self) -> str:
    """HTTP method of the request (read only).

    The value is an upper-cased str like 'GET', 'POST', 'PUT' etc.
    """
    http_method = self._method
    return http_method
@reify
def version(self) -> Tuple[int, int]:
    """HTTP version of the request (read only).

    Returns hyper_internal_service.protocol.HttpVersion instance.
    """
    http_version = self._version
    return http_version
@reify
def host(self) -> str:
    """Hostname of the request.

    Hostname is resolved in this order:

    - overridden value by .clone(host=new_host) call.
    - HOST HTTP header
    - socket.getfqdn() value
    """
    header_host = self._message.headers.get(hdrs.HOST)
    if header_host is None:
        return socket.getfqdn()
    return header_host
@reify
def remote(self) -> Optional[str]:
    """Remote IP of client initiated HTTP request.

    The IP is resolved in this order:

    - overridden value by .clone(remote=new_remote) call.
    - peername of opened socket
    """
    peername = self._transport_peername
    # A list/tuple peername carries the address as its first item.
    if isinstance(peername, (list, tuple)):
        return peername[0]
    return peername
@reify
def url(self) -> URL:
    """URL built from the request scheme and host, joined with the
    relative URL of the request."""
    base = URL.build(scheme=self.scheme, host=self.host)
    return base.join(self._rel_url)
@reify
def path(self) -> str:
    """The URL including *PATH INFO* without the host or scheme.

    E.g., ``/app/blog``
    """
    relative = self._rel_url
    return relative.path
@reify
def path_qs(self) -> str:
    """The URL including PATH_INFO and the query string.

    E.g., ``/app/blog?id=10``
    """
    relative = self._rel_url
    return str(relative)
@reify
def raw_path(self) -> str:
    """The URL including raw *PATH INFO* without the host or scheme.

    Warning: the path is taken from the request message as is and may
    contain characters that are not valid in a URL.

    E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
    """
    raw_message = self._message
    return raw_message.path
@reify
def query(self) -> 'MultiDictProxy[str]':
    """A multidict with all the variables in the query string."""
    relative = self._rel_url
    return relative.query
@reify
def query_string(self) -> str:
    """The query string in the URL.

    E.g., ``id=10``
    """
    relative = self._rel_url
    return relative.query_string
@reify
def headers(self) -> 'CIMultiDictProxy[str]':
    """A case-insensitive multidict proxy with all headers."""
    header_map = self._headers
    return header_map
@reify
def raw_headers(self) -> RawHeaders:
    """A sequence of pairs for all headers."""
    raw_message = self._message
    return raw_message.raw_headers
@staticmethod
def _http_date(_date_str: str) -> Optional[datetime.datetime]:
    """Convert a date header value into a UTC ``datetime``.

    Returns None when the value is None or cannot be parsed.
    """
    if _date_str is None:
        return None
    parsed = parsedate(_date_str)
    if parsed is None:
        return None
    year, month, day, hour, minute, second = parsed[:6]
    return datetime.datetime(year, month, day, hour, minute, second,
                             tzinfo=datetime.timezone.utc)
@reify
def if_modified_since(self) -> Optional[datetime.datetime]:
    """The value of If-Modified-Since HTTP header, or None.

    This header is represented as a `datetime` object.
    """
    raw_value = self.headers.get(hdrs.IF_MODIFIED_SINCE)
    return self._http_date(raw_value)
@reify
def if_unmodified_since(self) -> Optional[datetime.datetime]:
    """The value of If-Unmodified-Since HTTP header, or None.

    This header is represented as a `datetime` object.
    """
    raw_value = self.headers.get(hdrs.IF_UNMODIFIED_SINCE)
    return self._http_date(raw_value)
@reify
def if_range(self) -> Optional[datetime.datetime]:
    """The value of If-Range HTTP header, or None.

    This header is represented as a `datetime` object.
    """
    raw_value = self.headers.get(hdrs.IF_RANGE)
    return self._http_date(raw_value)
@reify
def keep_alive(self) -> bool:
    """Is keepalive enabled by client?"""
    must_close = self._message.should_close
    return not must_close
@reify
def cookies(self) -> Mapping[str, str]:
    """Return request cookies.

    A read-only dictionary-like object mapping cookie names to values.
    """
    cookie_header = self.headers.get(hdrs.COOKIE, '')
    jar = SimpleCookie(cookie_header)  # type: SimpleCookie[str]
    values = {}  # type: Dict[str, str]
    for cookie_name, morsel in jar.items():
        values[cookie_name] = morsel.value
    return MappingProxyType(values)
@reify
def http_range(self) -> slice:
    """The content of Range HTTP header.

    Return a slice instance; without a Range header the slice is
    ``slice(None, None, 1)``. Raises ``ValueError`` for a header that is
    malformed, empty (``bytes=-``) or whose start lies after its end.
    """
    rng = self._headers.get(hdrs.RANGE)
    if rng is None:
        return slice(None, None, 1)

    try:
        pattern = r'^bytes=(\d*)-(\d*)$'
        first, last = re.findall(pattern, rng)[0]
    except IndexError:  # pattern was not found in header
        raise ValueError("range not in acceptable format")

    end = int(last) if last else None
    start = int(first) if first else None

    if start is None and end is not None:
        # end with no start is to return tail of content
        start, end = -end, None

    if start is not None and end is not None:
        # end is inclusive in range header, exclusive for slice
        end += 1
        if start >= end:
            raise ValueError('start cannot be after end')

    if start is end is None:  # No valid range supplied
        raise ValueError('No start or end of range specified')

    return slice(start, end, 1)
@reify
def content(self) -> StreamReader:
    """Return raw payload stream."""
    stream = self._payload
    return stream
@property
def has_body(self) -> bool:
    """Return True if request's HTTP BODY can be read, False otherwise.

    Deprecated in favour of ``can_read_body``.
    """
    warnings.warn("Deprecated, use .can_read_body #2005",
                  DeprecationWarning,
                  stacklevel=2)
    exhausted = self._payload.at_eof()
    return not exhausted
@property
def can_read_body(self) -> bool:
    """Return True if request's HTTP BODY can be read, False otherwise."""
    exhausted = self._payload.at_eof()
    return not exhausted
@reify
def body_exists(self) -> bool:
    """Return True if request has HTTP BODY, False otherwise."""
    payload_type = type(self._payload)
    return payload_type is not EmptyStreamReader
async def release(self) -> None:
    """Release request.

    Eat unread part of HTTP BODY if present.
    """
    while True:
        if self._payload.at_eof():
            break
        # The data is read only to be thrown away.
        await self._payload.readany()
async def read(self) -> bytes:
    """Read request body if present.

    Returns bytes object with full request content. The result is kept,
    so later calls return the same bytes without touching the payload.

    Raises ``HTTPRequestEntityTooLarge`` when a non-zero
    ``client_max_size`` is configured and the body grows beyond it. A
    body of exactly ``client_max_size`` bytes is accepted, which matches
    the strict comparison used by ``post()``.
    """
    if self._read_bytes is None:
        max_size = self._client_max_size
        body = bytearray()
        while True:
            chunk = await self._payload.readany()
            body.extend(chunk)
            # A limit of 0 disables the size check.
            if max_size:
                body_size = len(body)
                if body_size > max_size:
                    raise HTTPRequestEntityTooLarge(
                        max_size=max_size,
                        actual_size=body_size
                    )
            # An empty chunk marks the end of the payload.
            if not chunk:
                break
        self._read_bytes = bytes(body)
    return self._read_bytes
async def text(self) -> str:
    """Return BODY as text using encoding from .charset.

    UTF-8 is used when no charset is set.
    """
    raw_body = await self.read()
    return raw_body.decode(self.charset or 'utf-8')
async def json(self, *, loads: JSONDecoder=DEFAULT_JSON_DECODER) -> Any:
    """Return BODY as JSON, decoded from the body text with ``loads``."""
    document = await self.text()
    decoded = loads(document)
    return decoded
async def multipart(self) -> MultipartReader:
    """Return async iterator to process BODY as multipart."""
    reader = MultipartReader(self._headers, self._payload)
    return reader
async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
    """Return POST parameters.

    The parsed result is cached in ``self._post``.  An empty mapping is
    returned when the request method is not in ``POST_METHODS`` or when the
    content type is not one of '', 'application/x-www-form-urlencoded' or
    'multipart/form-data'.

    For multipart bodies, parts that carry both a filename and a
    Content-Type are written to a temporary file and exposed as
    ``FileField``; other parts are stored as text (no Content-Type, or a
    ``text/*`` one) or as raw bytes.  Other bodies are read in full and
    parsed with ``parse_qsl``.

    Raises:
        HTTPRequestEntityTooLarge: a single multipart field is larger than
            ``client_max_size`` (only checked when that limit is positive).
        ValueError: a part is not a ``BodyPartReader`` (nested multipart).
    """
    if self._post is not None:
        return self._post
    if self._method not in self.POST_METHODS:
        self._post = MultiDictProxy(MultiDict())
        return self._post

    content_type = self.content_type
    if (content_type not in ('',
                             'application/x-www-form-urlencoded',
                             'multipart/form-data')):
        self._post = MultiDictProxy(MultiDict())
        return self._post

    out = MultiDict()  # type: MultiDict[Union[str, bytes, FileField]]

    if content_type == 'multipart/form-data':
        multipart = await self.multipart()
        max_size = self._client_max_size

        field = await multipart.next()
        while field is not None:
            # The size limit is applied to each field separately.
            size = 0
            field_ct = field.headers.get(hdrs.CONTENT_TYPE)

            if isinstance(field, BodyPartReader):
                if field.filename and field_ct:
                    # store file in temp file
                    tmp = tempfile.TemporaryFile()
                    try:
                        chunk = await field.read_chunk(size=2**16)
                        while chunk:
                            chunk = field.decode(chunk)
                            tmp.write(chunk)
                            size += len(chunk)
                            if 0 < max_size < size:
                                raise HTTPRequestEntityTooLarge(
                                    max_size=max_size,
                                    actual_size=size
                                )
                            chunk = await field.read_chunk(size=2**16)
                        tmp.seek(0)
                    except BaseException:
                        # The file has not been handed to the caller yet, so
                        # nobody else can close it: release it before the
                        # error (or cancellation) propagates.
                        tmp.close()
                        raise

                    ff = FileField(field.name, field.filename,
                                   cast(io.BufferedReader, tmp),
                                   field_ct, field.headers)
                    out.add(field.name, ff)
                else:
                    # deal with ordinary data
                    value = await field.read(decode=True)
                    if field_ct is None or \
                            field_ct.startswith('text/'):
                        charset = field.get_charset(default='utf-8')
                        out.add(field.name, value.decode(charset))
                    else:
                        out.add(field.name, value)
                    size += len(value)
                    if 0 < max_size < size:
                        raise HTTPRequestEntityTooLarge(
                            max_size=max_size,
                            actual_size=size
                        )
            else:
                raise ValueError(
                    'To decode nested multipart you need '
                    'to use custom reader',
                )

            field = await multipart.next()
    else:
        data = await self.read()
        if data:
            charset = self.charset or 'utf-8'
            out.extend(
                parse_qsl(
                    data.rstrip().decode(charset),
                    keep_blank_values=True,
                    encoding=charset))

    self._post = MultiDictProxy(out)
    return self._post
def get_extra_info(self, name: str, default: Any = None) -> Any:
    """Extra info from protocol transport.

    Returns ``default`` when the request has no protocol or the protocol
    has no transport; otherwise delegates to the transport's own
    ``get_extra_info(name, default)``.
    """
    protocol = self._protocol
    transport = None if protocol is None else protocol.transport
    if transport is None:
        return default
    return transport.get_extra_info(name, default)
def __repr__(self) -> str:
    """Debug representation: class name, HTTP method and ASCII-safe path."""
    # Non-ASCII characters in the path are shown as backslash escapes.
    escaped = self.path.encode('ascii', 'backslashreplace')
    ascii_encodable_path = escaped.decode('ascii')
    cls_name = self.__class__.__name__
    return "<{} {} {} >".format(cls_name, self._method, ascii_encodable_path)
def __eq__(self, other: object) -> bool:
# Identity comparison: two request objects are equal only if they are the
# same object.
return id(self) == id(other)
def __bool__(self) -> bool:
# Always truthy.
return True
async def _prepare_hook(self, response: StreamResponse) -> None:
# No-op at this level; Request overrides it in this file.
return
class Request(BaseRequest):
    """Request object that additionally carries the route-resolution result."""

    ATTRS = BaseRequest.ATTRS | frozenset(['_match_info'])

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Holds matchdict, route_name, handler or information about the
        # traversal lookup; it is initialized after route resolving.
        self._match_info = None  # type: Optional[UrlMappingMatchInfo]

    if DEBUG:
        def __setattr__(self, name: str, val: Any) -> None:
            # In debug mode, warn about attributes outside ATTRS; the
            # assignment itself still goes through.
            if name not in self.ATTRS:
                message = ("Setting custom {}.{} attribute "
                           "is discouraged".format(self.__class__.__name__,
                                                   name))
                warnings.warn(message, DeprecationWarning, stacklevel=2)
            super().__setattr__(name, val)

    def clone(self, *, method: str=sentinel,
              rel_url: StrOrURL=sentinel,
              headers: LooseHeaders=sentinel,
              scheme: str=sentinel,
              host: str=sentinel,
              remote: str=sentinel) -> 'Request':
        """Clone via the base class and carry ``_match_info`` over to the copy."""
        overrides = dict(method=method, rel_url=rel_url, headers=headers,
                         scheme=scheme, host=host, remote=remote)
        copied = cast(Request, super().clone(**overrides))
        copied._match_info = self._match_info
        return copied

    @reify
    def match_info(self) -> 'UrlMappingMatchInfo':
        """Result of route resolving."""
        resolved = self._match_info
        assert resolved is not None
        return resolved

    @property
    def app(self) -> 'Application':
        """Application instance."""
        resolved = self._match_info
        assert resolved is not None
        return resolved.current_app

    @property
    def config_dict(self) -> ChainMapProxy:
        """ChainMapProxy over the apps up to the current one, current app first."""
        resolved = self._match_info
        assert resolved is not None
        all_apps = resolved.apps
        position = all_apps.index(self.app)
        # Everything up to and including the current app, in reverse order.
        chain = list(all_apps[:position + 1])
        chain.reverse()
        return ChainMapProxy(chain)

    async def _prepare_hook(self, response: StreamResponse) -> None:
        """Send ``on_response_prepare`` for every app in the match info."""
        resolved = self._match_info
        if resolved is not None:
            for app in resolved._apps:
                await app.on_response_prepare.send(self, response)

View File

@ -0,0 +1,727 @@
import asyncio # noqa
import collections.abc # noqa
import datetime
import enum
import json
import math
import time
import warnings
import zlib
from concurrent.futures import Executor
from email.utils import parsedate
from http.cookies import Morsel, SimpleCookie
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
from multidict import CIMultiDict, istr
from . import hdrs, payload
from .abc import AbstractStreamWriter
from .helpers import PY_38, HeadersMixin, rfc822_formatted_time, sentinel
from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response')
if TYPE_CHECKING: # pragma: no cover
# Type-checking only: import BaseRequest for annotations and give BaseClass
# a parametrised mapping type.
from .web_request import BaseRequest # noqa
BaseClass = MutableMapping[str, Any]
else:
# At runtime the plain abstract base class is used instead.
BaseClass = collections.abc.MutableMapping
if not PY_38:
# allow samesite to be used in python < 3.8
# already permitted in python 3.8, see https://bugs.python.org/issue29613
Morsel._reserved['samesite'] = 'SameSite' # type: ignore
class ContentCoding(enum.Enum):
"""Content codings that can be selected for a response body.

Each member's value is the token written to the Content-Encoding header
and matched against the request's Accept-Encoding header.
"""
# The content codings that we have support for.
#
# Additional registered codings are listed at:
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
deflate = 'deflate'
gzip = 'gzip'
identity = 'identity'
############################################################
# HTTP Response classes
############################################################
class StreamResponse(BaseClass, HeadersMixin):
_length_check = True
def __init__(self, *,
status: int=200,
reason: Optional[str]=None,
headers: Optional[LooseHeaders]=None) -> None:
"""Create a response that has not been prepared yet.

``status`` and ``reason`` are applied through ``set_status()``.
``headers``, when given, are copied into a new CIMultiDict.
"""
self._body = None
# None means "not decided": _start() then takes the value from the request.
self._keep_alive = None # type: Optional[bool]
self._chunked = False
self._compression = False
self._compression_force = None # type: Optional[ContentCoding]
self._cookies = SimpleCookie() # type: SimpleCookie[str]
# _req and _payload_writer are filled in by _start() and cleared again by
# write_eof().
self._req = None # type: Optional[BaseRequest]
self._payload_writer = None # type: Optional[AbstractStreamWriter]
self._eof_sent = False
self._body_length = 0
# Backing store for the mapping interface (__getitem__ and friends).
self._state = {} # type: Dict[str, Any]
if headers is not None:
self._headers = CIMultiDict(headers) # type: CIMultiDict[str]
else:
self._headers = CIMultiDict()
self.set_status(status, reason)
@property
def prepared(self) -> bool:
# True while a payload writer is attached, i.e. between _start() and
# write_eof().
return self._payload_writer is not None
@property
def task(self) -> 'Optional[asyncio.Task[None]]':
# Taken from the request this response is bound to; None when there is no
# request or it has no ``task`` attribute.
return getattr(self._req, 'task', None)
@property
def status(self) -> int:
# Status code as stored by set_status().
return self._status
@property
def chunked(self) -> bool:
# Whether chunked transfer encoding has been enabled.
return self._chunked
@property
def compression(self) -> bool:
# Whether enable_compression() has been called.
return self._compression
@property
def reason(self) -> str:
# Reason phrase as stored by set_status().
return self._reason
def set_status(self, status: int,
               reason: Optional[str]=None,
               _RESPONSES: Mapping[int,
                                   Tuple[str, str]]=RESPONSES) -> None:
    """Store the status code and reason phrase.

    When ``reason`` is None, the phrase is looked up in ``_RESPONSES``;
    if that lookup fails for any reason, an empty string is used.  Must be
    called before the response is prepared.
    """
    assert not self.prepared, \
        'Cannot change the response status code after ' \
        'the headers have been sent'
    code = int(status)
    self._status = code
    if reason is None:
        try:
            phrase = _RESPONSES[code][0]
        except Exception:
            phrase = ''
        reason = phrase
    self._reason = reason
@property
def keep_alive(self) -> Optional[bool]:
# None until decided: either force_close() sets False or _start() copies
# the request's value.
return self._keep_alive
def force_close(self) -> None:
# Marks the response as not keep-alive.
self._keep_alive = False
@property
def body_length(self) -> int:
# 0 until write_eof() records the payload writer's output_size.
return self._body_length
@property
def output_length(self) -> int:
    """Deprecated: ``buffer_size`` of the attached payload writer.

    Emits a DeprecationWarning.  Only usable while a payload writer is
    attached (asserted).
    """
    # stacklevel=2 attributes the warning to the code that accessed the
    # property, like the other deprecations in this package.
    warnings.warn('output_length is deprecated', DeprecationWarning,
                  stacklevel=2)
    assert self._payload_writer
    return self._payload_writer.buffer_size
def enable_chunked_encoding(self, chunk_size: Optional[int]=None) -> None:
    """Enables automatic chunked transfer encoding.

    ``chunk_size`` is deprecated and ignored apart from triggering a
    DeprecationWarning.

    Raises:
        RuntimeError: a Content-Length header is already set.  The response
            is left unchanged in that case.
    """
    # Validate before mutating: a rejected call must not leave the response
    # flagged as chunked while it still carries Content-Length.
    if hdrs.CONTENT_LENGTH in self._headers:
        raise RuntimeError("You can't enable chunked encoding when "
                           "a content length is set")
    self._chunked = True
    if chunk_size is not None:
        warnings.warn('Chunk size is deprecated #1615', DeprecationWarning,
                      stacklevel=2)
def enable_compression(self,
                       force: Optional[Union[bool, ContentCoding]]=None
                       ) -> None:
    """Enables response compression encoding.

    ``force`` selects the coding:

    * None -- the coding is chosen from the request's Accept-Encoding
      header when the response starts;
    * a ``ContentCoding`` member -- that coding is always used;
    * a bool (deprecated) -- True maps to ``ContentCoding.deflate``,
      False to ``ContentCoding.identity``.
    """
    # Backwards compatibility for when force was a bool <0.17.
    if isinstance(force, bool):
        force = ContentCoding.deflate if force else ContentCoding.identity
        warnings.warn("Using boolean for force is deprecated #3318",
                      DeprecationWarning, stacklevel=2)
    elif force is not None:
        assert isinstance(force, ContentCoding), ("force should be one of "
                                                  "None, bool or "
                                                  "ContentCoding")

    self._compression = True
    self._compression_force = force
@property
def headers(self) -> 'CIMultiDict[str]':
# The live header mapping, not a copy: mutations affect the response.
return self._headers
@property
def cookies(self) -> 'SimpleCookie[str]':
# The live cookie jar; _start() turns each entry into a Set-Cookie header.
return self._cookies
def set_cookie(self, name: str, value: str, *,
               expires: Optional[str]=None,
               domain: Optional[str]=None,
               max_age: Optional[Union[int, str]]=None,
               path: str='/',
               secure: Optional[bool]=None,
               httponly: Optional[bool]=None,
               version: Optional[str]=None,
               samesite: Optional[str]=None) -> None:
    """Set or update response cookie.

    Sets new cookie or updates existent with new value.
    Also updates only those params which are not None.
    ``path`` is always written.  ``max-age`` is removed when ``max_age`` is
    None, and an ``expires`` equal to the epoch marker left behind by
    ``del_cookie`` is removed when no new ``expires`` is given.
    """
    jar = self._cookies
    previous = jar.get(name)
    if previous is not None and previous.coded_value == '':
        # deleted cookie
        jar.pop(name, None)

    jar[name] = value
    morsel = jar[name]

    if expires is not None:
        morsel['expires'] = expires
    elif morsel.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
        del morsel['expires']

    if domain is not None:
        morsel['domain'] = domain

    if max_age is not None:
        morsel['max-age'] = str(max_age)
    elif 'max-age' in morsel:
        del morsel['max-age']

    morsel['path'] = path

    # Remaining attributes are written only when explicitly provided.
    optional_attrs = (('secure', secure),
                      ('httponly', httponly),
                      ('version', version),
                      ('samesite', samesite))
    for attr_name, attr_value in optional_attrs:
        if attr_value is not None:
            morsel[attr_name] = attr_value
def del_cookie(self, name: str, *,
               domain: Optional[str]=None,
               path: str='/') -> None:
    """Delete cookie.

    Drops any existing entry and stores a new empty cookie that is already
    expired (max-age 0, expires at the epoch).
    """
    # TODO: do we need domain/path here?
    epoch = "Thu, 01 Jan 1970 00:00:00 GMT"
    self._cookies.pop(name, None)
    self.set_cookie(name, '',
                    max_age=0,
                    expires=epoch,
                    domain=domain,
                    path=path)
@property
def content_length(self) -> Optional[int]:
    """Content length as reported by the parent class."""
    # Just a placeholder for adding setter
    return super().content_length

@content_length.setter
def content_length(self, value: Optional[int]) -> None:
    """Set or (with None) remove the Content-Length header.

    Raises:
        RuntimeError: a non-None value is given while chunked encoding is
            enabled.
    """
    if value is None:
        self._headers.pop(hdrs.CONTENT_LENGTH, None)
        return
    length = int(value)
    if self._chunked:
        raise RuntimeError("You can't set content length when "
                           "chunked encoding is enable")
    self._headers[hdrs.CONTENT_LENGTH] = str(length)
@property
def content_type(self) -> str:
# Just a placeholder for adding setter
return super().content_type
@content_type.setter
def content_type(self, value: str) -> None:
# The bare attribute access below is intentional: it runs the getter so the
# cached header values are populated before _content_type is overwritten.
self.content_type # read header values if needed
self._content_type = str(value)
# Rewrite the Content-Type header from the new type plus stored parameters.
self._generate_content_type_header()
@property
def charset(self) -> Optional[str]:
# Just a placeholder for adding setter
return super().charset
@charset.setter
def charset(self, value: Optional[str]) -> None:
# Sets (lower-cased) or, with None, removes the charset parameter and then
# regenerates the Content-Type header.  Raises RuntimeError while the
# content type is still application/octet-stream.
ctype = self.content_type # read header values if needed
if ctype == 'application/octet-stream':
raise RuntimeError("Setting charset for application/octet-stream "
"doesn't make sense, setup content_type first")
assert self._content_dict is not None
if value is None:
self._content_dict.pop('charset', None)
else:
self._content_dict['charset'] = str(value).lower()
self._generate_content_type_header()
@property
def last_modified(self) -> Optional[datetime.datetime]:
    """The value of Last-Modified HTTP header, or None.

    This header is represented as a `datetime` object (UTC).  None is
    returned when the header is absent, cannot be parsed, or parses to
    field values that do not form a valid date.
    """
    httpdate = self._headers.get(hdrs.LAST_MODIFIED)
    if httpdate is None:
        return None
    timetuple = parsedate(httpdate)
    if timetuple is None:
        return None
    try:
        return datetime.datetime(*timetuple[:6],
                                 tzinfo=datetime.timezone.utc)
    except ValueError:
        # The parsed fields can be out of range (e.g. day 32); honour the
        # documented "or None" contract instead of raising from a getter.
        return None

@last_modified.setter
def last_modified(self,
                  value: Optional[
                      Union[int, float, datetime.datetime, str]]) -> None:
    """Set or (with None) remove the Last-Modified header.

    Numbers are treated as POSIX timestamps (rounded up to a whole second),
    datetimes are formatted from their UTC time tuple, strings are stored
    verbatim.  Values of any other type are ignored.
    """
    if value is None:
        self._headers.pop(hdrs.LAST_MODIFIED, None)
    elif isinstance(value, (int, float)):
        self._headers[hdrs.LAST_MODIFIED] = time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
    elif isinstance(value, datetime.datetime):
        self._headers[hdrs.LAST_MODIFIED] = time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
    elif isinstance(value, str):
        self._headers[hdrs.LAST_MODIFIED] = value
def _generate_content_type_header(
        self,
        CONTENT_TYPE: istr=hdrs.CONTENT_TYPE) -> None:
    """Rebuild the Content-Type header from the stored type and parameters.

    Parameters from ``_content_dict`` are appended as ``; key=value``.
    """
    assert self._content_dict is not None
    assert self._content_type is not None
    pairs = ["{}={}".format(key, val)
             for key, val in self._content_dict.items()]
    header_value = self._content_type
    if pairs:
        header_value = header_value + '; ' + '; '.join(pairs)
    self._headers[CONTENT_TYPE] = header_value
async def _do_start_compression(self, coding: ContentCoding) -> None:
# Applies ``coding`` through the payload writer; ``identity`` is a no-op.
if coding != ContentCoding.identity:
assert self._payload_writer is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(coding.value)
# Compressed payload may have different content length,
# remove the header
self._headers.popall(hdrs.CONTENT_LENGTH, None)
async def _start_compression(self, request: 'BaseRequest') -> None:
# A forced coding wins.  Otherwise the first ContentCoding member (in
# definition order) whose value occurs in the lower-cased Accept-Encoding
# header is used.  This is a plain substring test, not a parse of the header.
if self._compression_force:
await self._do_start_compression(self._compression_force)
else:
accept_encoding = request.headers.get(
hdrs.ACCEPT_ENCODING, '').lower()
for coding in ContentCoding:
if coding.value in accept_encoding:
await self._do_start_compression(coding)
return
async def prepare(
        self,
        request: 'BaseRequest'
) -> Optional[AbstractStreamWriter]:
    """Start the response for ``request`` unless that already happened.

    Returns None once EOF has been sent, the existing payload writer if the
    response is already prepared, and otherwise the writer produced by
    ``_start()`` after the request's ``_prepare_hook`` has run.
    """
    if self._eof_sent:
        return None
    writer = self._payload_writer
    if writer is None:
        await request._prepare_hook(self)
        writer = await self._start(request)
    return writer
async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
"""Bind to ``request``, finalise the headers and write them out.

Returns the payload writer taken from ``request._payload_writer``.
Raises RuntimeError if chunked encoding was requested for a request
whose version is not HTTP/1.1.
"""
self._req = request
# An explicit setting on the response wins; otherwise inherit the request's.
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
self._keep_alive = keep_alive
version = request.version
writer = self._payload_writer = request._payload_writer
headers = self._headers
# One Set-Cookie header per cookie; [1:] drops the first character of
# output(header='') (presumably the separator left by the empty header
# name -- confirm against http.cookies).
for cookie in self._cookies.values():
value = cookie.output(header='')[1:]
headers.add(hdrs.SET_COOKIE, value)
if self._compression:
await self._start_compression(request)
if self._chunked:
if version != HttpVersion11:
raise RuntimeError(
"Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(request.version))
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = 'chunked'
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
elif self._length_check:
writer.length = self.content_length
if writer.length is None:
# Unknown length: HTTP/1.1 and later switch to chunked encoding,
# older versions drop keep-alive for the Connection header below.
# NOTE(review): only the local keep_alive is cleared here, self._keep_alive
# keeps its earlier value -- confirm this is intended.
if version >= HttpVersion11:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = 'chunked'
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
else:
keep_alive = False
# Defaults that never override headers already set by the caller.
headers.setdefault(hdrs.CONTENT_TYPE, 'application/octet-stream')
headers.setdefault(hdrs.DATE, rfc822_formatted_time())
headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)
# connection header
if hdrs.CONNECTION not in headers:
if keep_alive:
if version == HttpVersion10:
headers[hdrs.CONNECTION] = 'keep-alive'
else:
if version == HttpVersion11:
headers[hdrs.CONNECTION] = 'close'
# status line
status_line = 'HTTP/{}.{} {} {}'.format(
version[0], version[1], self._status, self._reason)
await writer.write_headers(status_line, headers)
return writer
async def write(self, data: bytes) -> None:
    """Write ``data`` through the payload writer.

    Raises:
        RuntimeError: called after ``write_eof()`` or before ``prepare()``.
    """
    assert isinstance(data, (bytes, bytearray, memoryview)), \
        "data argument must be byte-ish (%r)" % type(data)

    if self._eof_sent:
        raise RuntimeError("Cannot call write() after write_eof()")
    writer = self._payload_writer
    if writer is None:
        raise RuntimeError("Cannot call write() before prepare()")

    await writer.write(data)
async def drain(self) -> None:
    """Deprecated: drain the payload writer.

    Emits a DeprecationWarning pointing to ``await resp.write()``.
    """
    assert not self._eof_sent, "EOF has already been sent"
    writer = self._payload_writer
    assert writer is not None, \
        "Response has not been started"
    warnings.warn("drain method is deprecated, use await resp.write()",
                  category=DeprecationWarning,
                  stacklevel=2)
    await writer.drain()
async def write_eof(self, data: bytes=b'') -> None:
"""Finish the response, optionally writing a final ``data`` chunk.

Calling it again after EOF has been sent is a no-op.
"""
assert isinstance(data, (bytes, bytearray, memoryview)), \
"data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
return
assert self._payload_writer is not None, \
"Response has not been started"
await self._payload_writer.write_eof(data)
self._eof_sent = True
# Detach from the request and the writer; the writer's output_size is
# recorded first so body_length stays available afterwards.
self._req = None
self._body_length = self._payload_writer.output_size
self._payload_writer = None
def __repr__(self) -> str:
    """Debug representation: class name, reason and lifecycle state."""
    if self._eof_sent:
        state = "eof"
    elif not self.prepared:
        state = "not prepared"
    else:
        request = self._req
        assert request is not None
        state = "{} {} ".format(request.method, request.path)
    return "<{} {} {}>".format(self.__class__.__name__,
                               self.reason, state)
# Mapping interface: per-response key/value storage backed by self._state.
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
# Hashing and equality are identity-based, independent of the stored state.
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, other: object) -> bool:
return self is other
class Response(StreamResponse):
def __init__(self, *,
body: Any=None,
status: int=200,
reason: Optional[str]=None,
text: Optional[str]=None,
headers: Optional[LooseHeaders]=None,
content_type: Optional[str]=None,
charset: Optional[str]=None,
zlib_executor_size: Optional[int]=None,
zlib_executor: Optional[Executor]=None) -> None:
"""Create a response with an in-memory body.

``body`` and ``text`` are mutually exclusive.  ``content_type`` must not
contain a charset, and neither ``content_type`` nor ``charset`` may be
combined with an explicit Content-Type header.  Violations raise
ValueError; a non-str ``text`` raises TypeError on the fast path.
``zlib_executor_size`` and ``zlib_executor`` are stored for use when the
body is compressed.
"""
if body is not None and text is not None:
raise ValueError("body and text are not allowed together")
if headers is None:
real_headers = CIMultiDict() # type: CIMultiDict[str]
elif not isinstance(headers, CIMultiDict):
real_headers = CIMultiDict(headers)
else:
# A CIMultiDict passed by the caller is used as-is, not copied here.
real_headers = headers # = cast('CIMultiDict[str]', headers)
if content_type is not None and "charset" in content_type:
raise ValueError("charset must not be in content_type "
"argument")
if text is not None:
if hdrs.CONTENT_TYPE in real_headers:
if content_type or charset:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
# fast path for filling headers
if not isinstance(text, str):
raise TypeError("text argument must be str (%r)" %
type(text))
if content_type is None:
content_type = 'text/plain'
if charset is None:
charset = 'utf-8'
real_headers[hdrs.CONTENT_TYPE] = (
content_type + '; charset=' + charset)
body = text.encode(charset)
# text has been encoded into body, so the body branch below is taken.
text = None
else:
if hdrs.CONTENT_TYPE in real_headers:
if content_type is not None or charset is not None:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
if content_type is not None:
if charset is not None:
content_type += '; charset=' + charset
real_headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, headers=real_headers)
# text is still set only when a Content-Type header was supplied: the text
# setter then picks the encoding from that header.
if text is not None:
self.text = text
else:
self.body = body
self._compressed_body = None # type: Optional[bytes]
self._zlib_executor_size = zlib_executor_size
self._zlib_executor = zlib_executor
@property
def body(self) -> Optional[Union[bytes, Payload]]:
# Raw bytes, a Payload, or None.
return self._body
@body.setter
def body(self, body: bytes,
CONTENT_TYPE: istr=hdrs.CONTENT_TYPE,
CONTENT_LENGTH: istr=hdrs.CONTENT_LENGTH) -> None:
# Accepts None, bytes/bytearray, or any object the payload registry can
# wrap; anything else raises ValueError.  _body_payload records whether
# _body is a Payload.
if body is None:
self._body = None # type: Optional[bytes]
self._body_payload = False # type: bool
elif isinstance(body, (bytes, bytearray)):
self._body = body
self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError('Unsupported body type %r' % type(body))
self._body_payload = True
headers = self._headers
# set content-length header if needed
if not self._chunked and CONTENT_LENGTH not in headers:
size = body.size
if size is not None:
headers[CONTENT_LENGTH] = str(size)
# set content-type
if CONTENT_TYPE not in headers:
headers[CONTENT_TYPE] = body.content_type
# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
if key not in headers:
headers[key] = value
# A compressed copy of the previous body is no longer valid.
self._compressed_body = None
@property
def text(self) -> Optional[str]:
# Decodes the body with the response charset, 'utf-8' when none is set.
# NOTE(review): calls .decode() on _body directly, so this presumably fails
# when the body is a Payload rather than bytes -- confirm.
if self._body is None:
return None
return self._body.decode(self.charset or 'utf-8')
@text.setter
def text(self, text: str) -> None:
# NOTE(review): the assert admits None, but text.encode() below would then
# raise AttributeError -- confirm whether None is meant to be supported.
assert text is None or isinstance(text, str), \
"text argument must be str (%r)" % type(text)
# Default content type and charset for text bodies.
if self.content_type == 'application/octet-stream':
self.content_type = 'text/plain'
if self.charset is None:
self.charset = 'utf-8'
self._body = text.encode(self.charset)
self._body_payload = False
self._compressed_body = None
@property
def content_length(self) -> Optional[int]:
    """Length of the body that will be sent, or None when unknown.

    Checked in order: chunked responses have no length; an explicit
    Content-Length header wins; then the compressed body's length; a
    Payload body has no known length here; otherwise the length of the raw
    body, or 0 when there is no body.
    """
    if self._chunked:
        return None

    if hdrs.CONTENT_LENGTH in self._headers:
        return super().content_length

    compressed = self._compressed_body
    if compressed is not None:
        # Return length of the compressed body
        return len(compressed)
    if self._body_payload:
        # A payload without content length, or a compressed payload
        return None
    if self._body is None:
        return 0
    return len(self._body)

@content_length.setter
def content_length(self, value: Optional[int]) -> None:
    """Always raises: the length is derived from the body."""
    raise RuntimeError("Content length is set automatically")
async def write_eof(self, data: bytes=b'') -> None:
    """Send the stored body (if any) and finish the response.

    ``data`` must be empty.  The compressed body is preferred when one was
    produced.  Nothing is sent as body for HEAD requests or for status 204
    and 304.  Calling it again after EOF has been sent is a no-op.
    """
    if self._eof_sent:
        return
    compressed = self._compressed_body
    if compressed is None:
        body = self._body  # type: Optional[Union[bytes, Payload]]
    else:
        body = compressed
    assert not data, "data arg is not supported, got {!r}".format(data)
    assert self._req is not None
    assert self._payload_writer is not None

    # Short-circuits in the original order: the request method is only
    # inspected when there is a body at all.
    bodiless = (body is None or
                self._req._method == hdrs.METH_HEAD or
                self._status in [204, 304])
    if bodiless:
        await super().write_eof()
        return

    if self._body_payload:
        body_payload = cast(Payload, body)
        await body_payload.write(self._payload_writer)
        await super().write_eof()
        return

    await super().write_eof(cast(bytes, body))
async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
    """Fill in Content-Length for plain bodies, then start the response.

    The header is added only when the response is not chunked, has no
    Content-Length yet and the body is not a Payload.
    """
    needs_length = (not self._chunked and
                    hdrs.CONTENT_LENGTH not in self._headers and
                    not self._body_payload)
    if needs_length:
        length = 0 if self._body is None else len(self._body)
        self._headers[hdrs.CONTENT_LENGTH] = str(length)

    return await super()._start(request)
def _compress_body(self, zlib_mode: int) -> None:
    """Compress the whole body in one go and store it in ``_compressed_body``.

    ``zlib_mode`` is passed to ``zlib.compressobj`` as ``wbits``.
    """
    assert zlib_mode > 0
    compressor = zlib.compressobj(wbits=zlib_mode)
    raw = self._body
    assert raw is not None
    head = compressor.compress(raw)
    tail = compressor.flush()
    self._compressed_body = head + tail
async def _do_start_compression(self, coding: ContentCoding) -> None:
# Payload bodies and chunked responses use the parent's writer-level
# compression.  Plain bodies are compressed up front so that an exact
# Content-Length can be sent.
if self._body_payload or self._chunked:
return await super()._do_start_compression(coding)
if coding != ContentCoding.identity:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (16 + zlib.MAX_WBITS
if coding == ContentCoding.gzip else zlib.MAX_WBITS)
body_in = self._body
assert body_in is not None
# Bodies longer than the configured threshold are compressed in the
# executor; smaller ones are compressed inline.
if self._zlib_executor_size is not None and \
len(body_in) > self._zlib_executor_size:
await asyncio.get_event_loop().run_in_executor(
self._zlib_executor, self._compress_body, zlib_mode)
else:
self._compress_body(zlib_mode)
body_out = self._compressed_body
assert body_out is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
def json_response(data: Any=sentinel, *,
                  text: str=None,
                  body: bytes=None,
                  status: int=200,
                  reason: Optional[str]=None,
                  headers: LooseHeaders=None,
                  content_type: str='application/json',
                  dumps: JSONEncoder=json.dumps) -> Response:
    """Build a ``Response`` carrying JSON.

    ``data``, when supplied, is serialised with ``dumps`` and used as the
    response text; it cannot be combined with a truthy ``text`` or ``body``
    (ValueError).  Without ``data``, ``text``/``body`` are passed through
    unchanged.
    """
    data_given = data is not sentinel
    if data_given and (text or body):
        raise ValueError(
            "only one of data, text, or body should be specified"
        )
    if data_given:
        text = dumps(data)
    return Response(text=text, body=body, status=status, reason=reason,
                    headers=headers, content_type=content_type)

View File

@ -0,0 +1,199 @@
import abc
import os # noqa
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
overload,
)
import attr
from . import hdrs
from .abc import AbstractView
from .typedefs import PathLike
if TYPE_CHECKING: # pragma: no cover
from .web_urldispatcher import (
UrlDispatcher,
AbstractRoute
)
from .web_request import Request
from .web_response import StreamResponse
else:
Request = StreamResponse = UrlDispatcher = AbstractRoute = None
__all__ = ('AbstractRouteDef', 'RouteDef', 'StaticDef', 'RouteTableDef',
'head', 'options', 'get', 'post', 'patch', 'put', 'delete',
'route', 'view', 'static')
class AbstractRouteDef(abc.ABC):
"""Interface of a route definition that can register itself on a router."""
@abc.abstractmethod
def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
# Implementations add their route(s) to ``router`` and return them.
pass # pragma: no cover
_SimpleHandler = Callable[[Request], Awaitable[StreamResponse]]
_HandlerType = Union[Type[AbstractView], _SimpleHandler]
@attr.s(frozen=True, repr=False, slots=True)
class RouteDef(AbstractRouteDef):
    """Immutable description of one route: method, path, handler, options."""

    method = attr.ib(type=str)
    path = attr.ib(type=str)
    handler = attr.ib()  # type: _HandlerType
    kwargs = attr.ib(type=Dict[str, Any])

    def __repr__(self) -> str:
        # Extra keyword options are listed in sorted order.
        extras = ''.join(", {}={!r}".format(key, val)
                         for key, val in sorted(self.kwargs.items()))
        return ("<RouteDef {method} {path} -> {handler.__name__!r}"
                "{info}>".format(method=self.method, path=self.path,
                                 handler=self.handler, info=extras))

    def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
        """Add the route to ``router`` and return it in a one-element list.

        Methods listed in ``hdrs.METH_ALL`` go through the router's
        ``add_<method>`` helper; anything else uses ``add_route``.
        """
        if self.method not in hdrs.METH_ALL:
            return [router.add_route(self.method, self.path, self.handler,
                                     **self.kwargs)]
        adder = getattr(router, 'add_' + self.method.lower())
        return [adder(self.path, self.handler, **self.kwargs)]
@attr.s(frozen=True, repr=False, slots=True)
class StaticDef(AbstractRouteDef):
    """Immutable description of a static-files route."""

    prefix = attr.ib(type=str)
    path = attr.ib()  # type: PathLike
    kwargs = attr.ib(type=Dict[str, Any])

    def __repr__(self) -> str:
        info = []
        # Extra keyword options are listed in sorted order.
        for name, value in sorted(self.kwargs.items()):
            info.append(", {}={!r}".format(name, value))
        return ("<StaticDef {prefix} -> {path}"
                "{info}>".format(prefix=self.prefix, path=self.path,
                                 info=''.join(info)))

    def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
        """Add the static resource to ``router`` and return its routes.

        The routes are read from the ``'routes'`` entry of the resource's
        ``get_info()``; an empty list is returned when it is missing.
        """
        resource = router.add_static(self.prefix, self.path, **self.kwargs)
        routes = resource.get_info().get('routes', {})
        # Materialise the view so the result really is a list, as annotated
        # and as RouteDef.register() returns.
        return list(routes.values())
# Module-level helpers that build route definitions without registering them.
def route(method: str, path: str, handler: _HandlerType,
**kwargs: Any) -> RouteDef:
"""Build a RouteDef for an arbitrary HTTP method; kwargs are stored as-is."""
return RouteDef(method, path, handler, kwargs)
def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for HEAD."""
return route(hdrs.METH_HEAD, path, handler, **kwargs)
def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for OPTIONS."""
return route(hdrs.METH_OPTIONS, path, handler, **kwargs)
def get(path: str, handler: _HandlerType, *, name: Optional[str]=None,
allow_head: bool=True, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for GET.

``name`` and ``allow_head`` are always stored in the definition's kwargs,
even when left at their defaults.
"""
return route(hdrs.METH_GET, path, handler, name=name,
allow_head=allow_head, **kwargs)
def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for POST."""
return route(hdrs.METH_POST, path, handler, **kwargs)
def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for PUT."""
return route(hdrs.METH_PUT, path, handler, **kwargs)
def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for PATCH."""
return route(hdrs.METH_PATCH, path, handler, **kwargs)
def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
"""Build a RouteDef for DELETE."""
return route(hdrs.METH_DELETE, path, handler, **kwargs)
def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef:
"""Build a RouteDef for a class-based view, using hdrs.METH_ANY as method."""
return route(hdrs.METH_ANY, path, handler, **kwargs)
def static(prefix: str, path: PathLike,
**kwargs: Any) -> StaticDef:
"""Build a StaticDef for serving ``path`` under ``prefix``."""
return StaticDef(prefix, path, kwargs)
_Deco = Callable[[_HandlerType], _HandlerType]
class RouteTableDef(Sequence[AbstractRouteDef]):
    """Route definition table"""

    def __init__(self) -> None:
        self._items = []  # type: List[AbstractRouteDef]

    def __repr__(self) -> str:
        count = len(self._items)
        return "<RouteTableDef count={}>".format(count)

    @overload
    def __getitem__(self, index: int) -> AbstractRouteDef: ...  # noqa

    @overload  # noqa
    def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ...  # noqa

    def __getitem__(self, index):  # type: ignore  # noqa
        return self._items[index]

    def __iter__(self) -> Iterator[AbstractRouteDef]:
        return iter(self._items)

    def __len__(self) -> int:
        return len(self._items)

    def __contains__(self, item: object) -> bool:
        return item in self._items

    def route(self,
              method: str,
              path: str,
              **kwargs: Any) -> _Deco:
        """Return a decorator that records the handler as a RouteDef.

        The decorated handler itself is returned unchanged.
        """
        def register_handler(handler: _HandlerType) -> _HandlerType:
            definition = RouteDef(method, path, handler, kwargs)
            self._items.append(definition)
            return handler
        return register_handler

    # Per-method shortcuts: each one delegates to route() with a fixed method.

    def head(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_HEAD, path, **kwargs)

    def get(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_GET, path, **kwargs)

    def post(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_POST, path, **kwargs)

    def put(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_PUT, path, **kwargs)

    def patch(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_PATCH, path, **kwargs)

    def delete(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_DELETE, path, **kwargs)

    def view(self, path: str, **kwargs: Any) -> _Deco:
        return self.route(hdrs.METH_ANY, path, **kwargs)

    def static(self, prefix: str, path: PathLike,
               **kwargs: Any) -> None:
        """Record a StaticDef directly; this one is not a decorator."""
        definition = StaticDef(prefix, path, kwargs)
        self._items.append(definition)

View File

@ -0,0 +1,337 @@
import asyncio
import signal
import socket
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Set
from yarl import URL
from .web_app import Application
from .web_server import Server
try:
from ssl import SSLContext
except ImportError:
SSLContext = object # type: ignore
__all__ = ('BaseSite', 'TCPSite', 'UnixSite', 'NamedPipeSite', 'SockSite',
'BaseRunner', 'AppRunner', 'ServerRunner', 'GracefulExit')
class GracefulExit(SystemExit):
"""SystemExit subclass raised by _raise_graceful_exit(); carries exit code 1."""
code = 1
def _raise_graceful_exit() -> None:
# Small helper that raises GracefulExit; it takes no arguments.
raise GracefulExit()
class BaseSite(ABC):
"""Base class for a listening endpoint that belongs to a runner.

Subclasses provide ``name`` and ``start()``; ``stop()`` is shared.
"""
__slots__ = ('_runner', '_shutdown_timeout', '_ssl_context', '_backlog',
'_server')
def __init__(self, runner: 'BaseRunner', *,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
backlog: int=128) -> None:
# The runner must already have a server, i.e. runner.setup() has been
# called; otherwise RuntimeError is raised.
if runner.server is None:
raise RuntimeError("Call runner.setup() before making a site")
self._runner = runner
self._shutdown_timeout = shutdown_timeout
self._ssl_context = ssl_context
self._backlog = backlog
# Set by the subclass's start(); None means "not started yet".
self._server = None # type: Optional[asyncio.AbstractServer]
@property
@abstractmethod
def name(self) -> str:
pass # pragma: no cover
@abstractmethod
async def start(self) -> None:
# Abstract, but with a body: subclasses call super().start() to register
# the site with the runner before creating the actual server.
self._runner._reg_site(self)
async def stop(self) -> None:
# Stop listening, shut the runner and its server down, then unregister.
self._runner._check_site(self)
if self._server is None:
self._runner._unreg_site(self)
return # not started yet
self._server.close()
# named pipes do not have wait_closed property
if hasattr(self._server, 'wait_closed'):
await self._server.wait_closed()
await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
self._runner._unreg_site(self)
class TCPSite(BaseSite):
    """Site that listens on a TCP host/port pair."""

    __slots__ = ('_host', '_port', '_reuse_address', '_reuse_port')

    def __init__(self, runner: 'BaseRunner',
                 host: str=None, port: int=None, *,
                 shutdown_timeout: float=60.0,
                 ssl_context: Optional[SSLContext]=None,
                 backlog: int=128, reuse_address: Optional[bool]=None,
                 reuse_port: Optional[bool]=None) -> None:
        """Store the listening parameters.

        ``host`` defaults to "0.0.0.0"; ``port`` defaults to 8443 when an
        SSL context is set and to 8080 otherwise.
        """
        super().__init__(runner, shutdown_timeout=shutdown_timeout,
                         ssl_context=ssl_context, backlog=backlog)
        self._host = "0.0.0.0" if host is None else host
        if port is None:
            port = 8443 if self._ssl_context else 8080
        self._port = port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the site, built from scheme, host and port."""
        scheme = 'http'
        if self._ssl_context:
            scheme = 'https'
        url = URL.build(scheme=scheme, host=self._host, port=self._port)
        return str(url)

    async def start(self) -> None:
        """Register with the runner and create the TCP server."""
        await super().start()
        server = self._runner.server
        assert server is not None
        loop = asyncio.get_event_loop()
        self._server = await loop.create_server(
            server, self._host, self._port,
            ssl=self._ssl_context, backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port)
class UnixSite(BaseSite):
    """Site that serves the runner's server on a UNIX domain socket path."""

    __slots__ = ('_path', )

    def __init__(self, runner: 'BaseRunner', path: str, *,
                 shutdown_timeout: float=60.0,
                 ssl_context: Optional[SSLContext]=None,
                 backlog: int=128) -> None:
        """Remember ``path``; the remaining options go to the base site."""
        super().__init__(runner, shutdown_timeout=shutdown_timeout,
                         ssl_context=ssl_context, backlog=backlog)
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL of the form ``<scheme>://unix:<path>:``."""
        if self._ssl_context:
            scheme = 'https'
        else:
            scheme = 'http'
        return '{}://unix:{}:'.format(scheme, self._path)

    async def start(self) -> None:
        """Register the site and create the listening UNIX socket server."""
        await super().start()
        loop = asyncio.get_event_loop()
        web_server = self._runner.server
        assert web_server is not None
        self._server = await loop.create_unix_server(
            web_server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class NamedPipeSite(BaseSite):
    """Site that serves the runner's server on a Windows named pipe."""

    __slots__ = ('_path', )

    def __init__(self, runner: 'BaseRunner', path: str, *,
                 shutdown_timeout: float=60.0) -> None:
        """Validate the current event loop and remember the pipe path.

        Raises:
            RuntimeError: the current event loop is not a proactor loop
                (including platforms where asyncio has no proactor loop).
        """
        loop = asyncio.get_event_loop()
        # asyncio only defines ProactorEventLoop on Windows; look it up
        # defensively so other platforms get the intended RuntimeError
        # instead of an AttributeError.
        proactor_cls = getattr(asyncio, 'ProactorEventLoop', None)
        if proactor_cls is None or not isinstance(loop, proactor_cls):
            raise RuntimeError("Named Pipes only available in proactor "
                               "loop under windows")
        super().__init__(runner, shutdown_timeout=shutdown_timeout)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path given to the constructor."""
        return self._path

    async def start(self) -> None:
        """Register the site and start serving on the named pipe."""
        await super().start()
        loop = asyncio.get_event_loop()
        server = self._runner.server
        assert server is not None
        _server = await loop.start_serving_pipe(  # type: ignore
            server, self._path
        )
        # start_serving_pipe() returns a sequence; the first item is kept
        # as the server object that stop() later closes.
        self._server = _server[0]
class SockSite(BaseSite):
    """Site that serves the runner's server on an already created socket."""

    __slots__ = ('_sock', '_name')

    def __init__(self, runner: 'BaseRunner', sock: socket.socket, *,
                 shutdown_timeout: float=60.0,
                 ssl_context: Optional[SSLContext]=None,
                 backlog: int=128) -> None:
        """Keep ``sock`` and precompute the site name from its address."""
        super().__init__(runner, shutdown_timeout=shutdown_timeout,
                         ssl_context=ssl_context, backlog=backlog)
        self._sock = sock
        scheme = 'https' if self._ssl_context else 'http'
        # AF_UNIX is not defined on every platform, hence the hasattr() guard.
        is_unix = hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX
        if is_unix:
            self._name = '{}://unix:{}:'.format(scheme, sock.getsockname())
        else:
            # Only the first two items of the address are used as host/port.
            host, port = sock.getsockname()[:2]
            self._name = str(URL.build(scheme=scheme, host=host, port=port))

    @property
    def name(self) -> str:
        """Name computed in the constructor."""
        return self._name

    async def start(self) -> None:
        """Register the site and start serving on the wrapped socket."""
        await super().start()
        loop = asyncio.get_event_loop()
        web_server = self._runner.server
        assert web_server is not None
        self._server = await loop.create_server(
            web_server,
            sock=self._sock,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class BaseRunner(ABC):
    """Abstract owner of one server object and the sites registered on it.

    Subclasses implement ``shutdown()``, ``_make_server()`` and
    ``_cleanup_server()``.
    """

    __slots__ = ('_handle_signals', '_kwargs', '_server', '_sites')

    def __init__(self, *, handle_signals: bool=False, **kwargs: Any) -> None:
        """Store the options; the server is created later by ``setup()``."""
        self._handle_signals = handle_signals
        self._kwargs = kwargs
        self._server = None  # type: Optional[Server]
        self._sites = []  # type: List[BaseSite]

    @property
    def server(self) -> Optional[Server]:
        """The server made by ``setup()``, or ``None`` before setup/after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:
        """``getsockname()`` of every socket of every started site.

        Items are returned exactly as ``socket.getsockname()`` yields them
        (for instance a tuple for TCP sockets), so the list is not
        restricted to strings.
        """
        ret = []  # type: List[Any]
        for site in self._sites:
            server = site._server
            if server is None:
                continue
            sockets = server.sockets
            if sockets is None:
                continue
            ret.extend(sock.getsockname() for sock in sockets)
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """A new set holding the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Optionally install SIGINT/SIGTERM handlers, then make the server."""
        loop = asyncio.get_event_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:  # pragma: no cover
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Hook awaited by a site's ``stop()`` after its listener is closed."""
        pass  # pragma: no cover

    async def cleanup(self) -> None:
        """Stop every site, clean up the server and remove signal handlers.

        Does nothing when ``setup()`` has not produced a server yet.
        """
        loop = asyncio.get_event_loop()

        if self._server is None:
            # not started yet, do nothing
            return

        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()
        await self._cleanup_server()
        self._server = None
        if self._handle_signals:
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:  # pragma: no cover
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server:
        """Create the server object stored by ``setup()``."""
        pass  # pragma: no cover

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Release server resources; awaited by ``cleanup()`` after all sites stop."""
        pass  # pragma: no cover

    def _reg_site(self, site: BaseSite) -> None:
        """Add ``site``; raise RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError("Site {} is already registered in runner {}"
                               .format(site, self))
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless ``site`` is registered."""
        if site not in self._sites:
            raise RuntimeError("Site {} is not registered in runner {}"
                               .format(site, self))

    def _unreg_site(self, site: BaseSite) -> None:
        """Remove ``site``; raise RuntimeError if it is not registered."""
        # Same membership check and error message as _check_site().
        self._check_site(site)
        self._sites.remove(site)
class ServerRunner(BaseRunner):
    """Runner around an already constructed low-level ``Server``."""

    __slots__ = ('_web_server',)

    def __init__(self, web_server: Server, *,
                 handle_signals: bool=False, **kwargs: Any) -> None:
        """Keep ``web_server``; other options are forwarded to the base class."""
        super().__init__(handle_signals=handle_signals, **kwargs)
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No-op for this runner."""
        return None

    async def _make_server(self) -> Server:
        """Return the server handed to the constructor."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """No-op for this runner."""
        return None
class AppRunner(BaseRunner):
    """Runner that drives the lifecycle of an ``Application``."""

    __slots__ = ('_app',)

    def __init__(self, app: Application, *,
                 handle_signals: bool=False, **kwargs: Any) -> None:
        """Validate and keep ``app``.

        Raises:
            TypeError: ``app`` is not an ``Application`` instance.
        """
        super().__init__(handle_signals=handle_signals, **kwargs)
        if isinstance(app, Application):
            self._app = app
            return
        raise TypeError("The first argument should be web.Application "
                        "instance, got {!r}".format(app))

    @property
    def app(self) -> Application:
        """The application passed to the constructor."""
        return self._app

    async def shutdown(self) -> None:
        """Await the application's ``shutdown()``."""
        await self._app.shutdown()

    async def _make_server(self) -> Server:
        """Start the application up, freeze it and return its handler."""
        loop = asyncio.get_event_loop()
        application = self._app
        application._set_loop(loop)
        # on_startup is frozen before startup() runs; the whole application
        # is frozen only afterwards.
        application.on_startup.freeze()
        await application.startup()
        application.freeze()
        return application._make_handler(loop=loop, **self._kwargs)

    async def _cleanup_server(self) -> None:
        """Await the application's ``cleanup()``."""
        await self._app.cleanup()

View File

@ -0,0 +1,57 @@
"""Low level HTTP server."""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
from .abc import AbstractStreamWriter
from .helpers import get_running_loop
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
from .web_request import BaseRequest
__all__ = ('Server',)
class Server:
    """Protocol factory that tracks the connections it has created.

    Calling an instance returns a new ``RequestHandler`` bound to it.
    """

    def __init__(self,
                 handler: _RequestHandler,
                 *,
                 request_factory: Optional[_RequestFactory]=None,
                 loop: Optional[asyncio.AbstractEventLoop]=None,
                 **kwargs: Any) -> None:
        """Store the request handler, request factory and handler options.

        ``kwargs`` are passed to every ``RequestHandler`` this server makes.
        When ``request_factory`` is falsy, ``_make_request`` is used.
        """
        self._loop = get_running_loop(loop)
        self._connections = {}  # type: Dict[RequestHandler, asyncio.Transport]
        self._kwargs = kwargs
        self.requests_count = 0
        self.request_handler = handler
        if request_factory:
            self.request_factory = request_factory
        else:
            self.request_factory = self._make_request

    @property
    def connections(self) -> List[RequestHandler]:
        """Snapshot list of the handlers with a registered transport."""
        return list(self._connections)

    def connection_made(self, handler: RequestHandler,
                        transport: asyncio.Transport) -> None:
        """Record ``transport`` for ``handler``."""
        self._connections[handler] = transport

    def connection_lost(self, handler: RequestHandler,
                        exc: Optional[BaseException]=None) -> None:
        """Forget ``handler``; unknown handlers are ignored."""
        self._connections.pop(handler, None)

    def _make_request(self, message: RawRequestMessage,
                      payload: StreamReader,
                      protocol: RequestHandler,
                      writer: AbstractStreamWriter,
                      task: 'asyncio.Task[None]') -> BaseRequest:
        """Default request factory: build a ``BaseRequest`` on this loop."""
        request = BaseRequest(
            message, payload, protocol, writer, task, self._loop)
        return request

    async def shutdown(self, timeout: Optional[float]=None) -> None:
        """Shut down every tracked connection, then clear the registry."""
        pending = [handler.shutdown(timeout) for handler in self._connections]
        await asyncio.gather(*pending)
        self._connections.clear()

    def __call__(self) -> RequestHandler:
        """Create a ``RequestHandler`` for a new connection."""
        return RequestHandler(self, loop=self._loop, **self._kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,457 @@
import asyncio
import base64
import binascii
import hashlib
import json
from typing import Any, Iterable, Optional, Tuple
import async_timeout
import attr
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import call_later, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WS_KEY,
WebSocketError,
WebSocketReader,
WebSocketWriter,
WSMessage,
)
from .http import WSMsgType as WSMsgType
from .http import ws_ext_gen, ws_ext_parse
from .log import ws_logger
from .streams import EofStream, FlowControlDataQueue
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_request import BaseRequest
from .web_response import StreamResponse
__all__ = ('WebSocketResponse', 'WebSocketReady', 'WSMsgType',)
THRESHOLD_CONNLOST_ACCESS = 5
@attr.s(frozen=True, slots=True)
class WebSocketReady:
    """Result of ``WebSocketResponse.can_prepare()``; truthy when ``ok`` is."""

    # Whether the handshake check succeeded.
    ok = attr.ib(type=bool)
    # Sub-protocol selected by the handshake, if any.
    protocol = attr.ib(type=Optional[str])

    def __bool__(self) -> bool:
        """Truthiness mirrors the ``ok`` field."""
        return self.ok
class WebSocketResponse(StreamResponse):
    """Server-side WebSocket connection built on top of ``StreamResponse``.

    ``prepare()`` performs the handshake; afterwards messages are sent with
    the ``send_*`` methods and read with ``receive*()`` or ``async for``.
    """

    _length_check = False

    def __init__(self, *,
                 timeout: float=10.0, receive_timeout: Optional[float]=None,
                 autoclose: bool=True, autoping: bool=True,
                 heartbeat: Optional[float]=None,
                 protocols: Iterable[str]=(),
                 compress: bool=True, max_msg_size: int=4*1024*1024) -> None:
        """Store the options; no I/O happens until ``prepare()``.

        ``timeout`` bounds the wait for the peer's reply in ``close()``;
        ``receive_timeout`` is the default timeout of ``receive()``;
        ``heartbeat`` is the ping interval, and a pong is awaited for half
        of that interval.
        """
        super().__init__(status=101)
        self._protocols = protocols
        self._ws_protocol = None  # type: Optional[str]
        self._writer = None  # type: Optional[WebSocketWriter]
        self._reader = None  # type: Optional[FlowControlDataQueue[WSMessage]]
        self._closed = False
        self._closing = False
        self._conn_lost = 0
        self._close_code = None  # type: Optional[int]
        self._loop = None  # type: Optional[asyncio.AbstractEventLoop]
        self._waiting = None  # type: Optional[asyncio.Future[bool]]
        self._exception = None  # type: Optional[BaseException]
        self._timeout = timeout
        self._receive_timeout = receive_timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb = None
        # _pong_heartbeat only exists when a heartbeat interval was given;
        # _send_heartbeat() reads it under the same condition.
        if heartbeat is not None:
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb = None
        self._compress = compress
        self._max_msg_size = max_msg_size

    def _cancel_heartbeat(self) -> None:
        """Cancel the pending pong-timeout and heartbeat callbacks, if any."""
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None

        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None

    def _reset_heartbeat(self) -> None:
        """Cancel pending callbacks and, if enabled, schedule the next ping."""
        self._cancel_heartbeat()

        if self._heartbeat is not None:
            self._heartbeat_cb = call_later(
                self._send_heartbeat, self._heartbeat, self._loop)

    def _send_heartbeat(self) -> None:
        """Send a ping and (re)arm the pong-timeout callback."""
        if self._heartbeat is not None and not self._closed:
            # fire-and-forget a task is not perfect but maybe ok for
            # sending ping. Otherwise we need a long-living heartbeat
            # task in the class.
            self._loop.create_task(self._writer.ping())  # type: ignore

            if self._pong_response_cb is not None:
                self._pong_response_cb.cancel()
            self._pong_response_cb = call_later(
                self._pong_not_received, self._pong_heartbeat, self._loop)

    def _pong_not_received(self) -> None:
        """Pong timeout: mark closed with code 1006 and close the transport."""
        if self._req is not None and self._req.transport is not None:
            self._closed = True
            self._close_code = 1006
            self._exception = asyncio.TimeoutError()
            self._req.transport.close()

    async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
        """Perform the handshake and return the payload writer.

        Returns the existing payload writer when already prepared.
        """
        # Check for an existing writer first, so that a repeated call is not
        # masked by exceptions raised from the handshake.
        if self._payload_writer is not None:
            return self._payload_writer

        protocol, writer = self._pre_start(request)
        payload_writer = await super().prepare(request)
        assert payload_writer is not None
        self._post_start(request, protocol, writer)
        await payload_writer.drain()
        return payload_writer

    def _handshake(self, request: BaseRequest) -> Tuple['CIMultiDict[str]',
                                                        str,
                                                        bool,
                                                        bool]:
        """Validate the upgrade request and build the response headers.

        Returns ``(response_headers, protocol, compress, notakeover)``.

        Raises:
            HTTPBadRequest: missing/invalid Upgrade or Connection header,
                unsupported version, or invalid Sec-WebSocket-Key.
        """
        headers = request.headers
        if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
            raise HTTPBadRequest(
                text=('No WebSocket UPGRADE hdr: {}\n Can '
                      '"Upgrade" only to "WebSocket".')
                .format(headers.get(hdrs.UPGRADE)))

        if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
            raise HTTPBadRequest(
                text='No CONNECTION upgrade hdr: {}'.format(
                    headers.get(hdrs.CONNECTION)))

        # find common sub-protocol between client and server
        protocol = None
        if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
            req_protocols = [str(proto.strip()) for proto in
                             headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]

            # The first client-listed protocol known to the server wins.
            for proto in req_protocols:
                if proto in self._protocols:
                    protocol = proto
                    break
            else:
                # No overlap found: Return no protocol as per spec
                ws_logger.warning(
                    'Client protocols %r dont overlap server-known ones %r',
                    req_protocols, self._protocols)

        # check supported version
        version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
        if version not in ('13', '8', '7'):
            raise HTTPBadRequest(
                text='Unsupported version: {}'.format(version))

        # check client handshake for validity
        key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
        try:
            # The key must be present and decode to exactly 16 bytes.
            if not key or len(base64.b64decode(key)) != 16:
                raise HTTPBadRequest(
                    text='Handshake error: {!r}'.format(key))
        except binascii.Error:
            raise HTTPBadRequest(
                text='Handshake error: {!r}'.format(key)) from None

        # Accept value: base64 of SHA-1 over the client key plus WS_KEY.
        accept_val = base64.b64encode(
            hashlib.sha1(key.encode() + WS_KEY).digest()).decode()
        response_headers = CIMultiDict(  # type: ignore
            {hdrs.UPGRADE: 'websocket',  # type: ignore
             hdrs.CONNECTION: 'upgrade',
             hdrs.SEC_WEBSOCKET_ACCEPT: accept_val})

        notakeover = False
        compress = 0
        if self._compress:
            extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
            # Server side always get return with no exception.
            # If something happened, just drop compress extension
            compress, notakeover = ws_ext_parse(extensions, isserver=True)
            if compress:
                enabledext = ws_ext_gen(compress=compress, isserver=True,
                                        server_notakeover=notakeover)
                response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

        if protocol:
            response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
        return (response_headers,  # type: ignore
                protocol,
                compress,
                notakeover)

    def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
        """Run the handshake, set response headers and create the writer."""
        self._loop = request._loop

        headers, protocol, compress, notakeover = self._handshake(
            request)

        self.set_status(101)
        self.headers.update(headers)
        self.force_close()
        # From here on _compress holds the value negotiated by the handshake
        # rather than the constructor flag.
        self._compress = compress
        transport = request._protocol.transport
        assert transport is not None
        writer = WebSocketWriter(request._protocol,
                                 transport,
                                 compress=compress,
                                 notakeover=notakeover)

        return protocol, writer

    def _post_start(self, request: BaseRequest,
                    protocol: str, writer: WebSocketWriter) -> None:
        """Install the reader/parser and start the heartbeat after prepare."""
        self._ws_protocol = protocol
        self._writer = writer

        self._reset_heartbeat()

        loop = self._loop
        assert loop is not None
        self._reader = FlowControlDataQueue(
            request._protocol, limit=2 ** 16, loop=loop)
        request.protocol.set_parser(WebSocketReader(
            self._reader, self._max_msg_size, compress=self._compress))
        # disable HTTP keepalive for WebSocket
        request.protocol.keep_alive(False)

    def can_prepare(self, request: BaseRequest) -> WebSocketReady:
        """Check whether ``request`` passes the handshake, without preparing.

        Raises:
            RuntimeError: the response has already been started.
        """
        if self._writer is not None:
            raise RuntimeError('Already started')
        try:
            _, protocol, _, _ = self._handshake(request)
        except HTTPException:
            return WebSocketReady(False, None)
        else:
            return WebSocketReady(True, protocol)

    @property
    def closed(self) -> bool:
        """Value of the internal closed flag."""
        return self._closed

    @property
    def close_code(self) -> Optional[int]:
        """Close code recorded so far, or ``None``."""
        return self._close_code

    @property
    def ws_protocol(self) -> Optional[str]:
        """Sub-protocol chosen during the handshake, or ``None``."""
        return self._ws_protocol

    @property
    def compress(self) -> bool:
        """Compression setting: the constructor flag, replaced in
        ``_pre_start()`` by the value the handshake returned."""
        return self._compress

    def exception(self) -> Optional[BaseException]:
        """Last exception recorded on the connection, or ``None``."""
        return self._exception

    async def ping(self, message: bytes=b'') -> None:
        """Send a ping frame; requires ``prepare()`` to have been called."""
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')
        await self._writer.ping(message)

    async def pong(self, message: bytes=b'') -> None:
        """Send a pong frame; requires ``prepare()`` to have been called."""
        # unsolicited pong
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')
        await self._writer.pong(message)

    async def send_str(self, data: str, compress: Optional[bool]=None) -> None:
        """Send ``data`` as a text message.

        Raises:
            RuntimeError: ``prepare()`` was not called.
            TypeError: ``data`` is not a ``str``.
        """
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')
        if not isinstance(data, str):
            raise TypeError('data argument must be str (%r)' % type(data))
        await self._writer.send(data, binary=False, compress=compress)

    async def send_bytes(self, data: bytes,
                         compress: Optional[bool]=None) -> None:
        """Send ``data`` as a binary message.

        Raises:
            RuntimeError: ``prepare()`` was not called.
            TypeError: ``data`` is not bytes, bytearray or memoryview.
        """
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError('data argument must be byte-ish (%r)' %
                            type(data))
        await self._writer.send(data, binary=True, compress=compress)

    async def send_json(self, data: Any, compress: Optional[bool]=None, *,
                        dumps: JSONEncoder=json.dumps) -> None:
        """Serialize ``data`` with ``dumps`` and send it as a text message."""
        await self.send_str(dumps(data), compress=compress)

    async def write_eof(self) -> None:  # type: ignore
        """Close the WebSocket once; later calls return immediately.

        Raises:
            RuntimeError: the response has not been started.
        """
        if self._eof_sent:
            return
        if self._payload_writer is None:
            raise RuntimeError("Response has not been started")

        await self.close()
        self._eof_sent = True

    async def close(self, *, code: int=1000, message: bytes=b'') -> bool:
        """Run the closing handshake.

        Returns ``True`` when this call performed the close and ``False``
        when the connection was already closed.

        Raises:
            RuntimeError: ``prepare()`` was not called.
        """
        if self._writer is None:
            raise RuntimeError('Call .prepare() first')

        self._cancel_heartbeat()
        reader = self._reader
        assert reader is not None

        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting is not None and not self._closed:
            reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._waiting

        if not self._closed:
            # Set the flag before awaiting, so later calls take the
            # "already closed" branch.
            self._closed = True
            try:
                await self._writer.close(code, message)
                writer = self._payload_writer
                assert writer is not None
                await writer.drain()
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = 1006
                raise
            except Exception as exc:
                self._close_code = 1006
                self._exception = exc
                return True

            # A close/closing message was already seen by receive(), so
            # there is no reply to wait for.
            if self._closing:
                return True

            reader = self._reader
            assert reader is not None
            # Wait up to self._timeout for the next message from the peer.
            try:
                with async_timeout.timeout(self._timeout, loop=self._loop):
                    msg = await reader.read()
            except asyncio.CancelledError:
                self._close_code = 1006
                raise
            except Exception as exc:
                self._close_code = 1006
                self._exception = exc
                return True

            if msg.type == WSMsgType.CLOSE:
                self._close_code = msg.data
                return True
            # Any other message type is recorded as an abnormal close.
            self._close_code = 1006
            self._exception = asyncio.TimeoutError()
            return True
        else:
            return False

    async def receive(self, timeout: Optional[float]=None) -> WSMessage:
        """Wait for and return the next message.

        A falsy ``timeout`` falls back to the ``receive_timeout`` given to
        the constructor.  With ``autoping`` enabled, PING messages are
        answered and PONG messages are skipped instead of being returned.

        Raises:
            RuntimeError: ``prepare()`` was not called, another
                ``receive()`` is in progress, or the closed connection has
                been read ``THRESHOLD_CONNLOST_ACCESS`` times.
        """
        if self._reader is None:
            raise RuntimeError('Call .prepare() first')

        loop = self._loop
        assert loop is not None
        while True:
            if self._waiting is not None:
                raise RuntimeError(
                    'Concurrent call to receive() is not allowed')

            if self._closed:
                # Count reads after close; too many of them raise.
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError('WebSocket connection is closed.')
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                # close() awaits this future to learn that the read ended.
                self._waiting = loop.create_future()
                try:
                    with async_timeout.timeout(
                            timeout or self._receive_timeout, loop=self._loop):
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    waiter = self._waiting
                    set_result(waiter, True)
                    self._waiting = None
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = 1006
                raise
            except EofStream:
                self._close_code = 1000
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._closing = True
                self._close_code = 1006
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg

    async def receive_str(self, *, timeout: Optional[float]=None) -> str:
        """Receive a message and return its data.

        Raises:
            TypeError: the received message is not of type TEXT.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.TEXT:
            raise TypeError(
                "Received message {}:{!r} is not WSMsgType.TEXT".format(
                    msg.type, msg.data))
        return msg.data

    async def receive_bytes(self, *, timeout: Optional[float]=None) -> bytes:
        """Receive a message and return its data.

        Raises:
            TypeError: the received message is not of type BINARY.
        """
        msg = await self.receive(timeout)
        if msg.type != WSMsgType.BINARY:
            raise TypeError(
                "Received message {}:{!r} is not bytes".format(msg.type,
                                                               msg.data))
        return msg.data

    async def receive_json(self, *, loads: JSONDecoder=json.loads,
                           timeout: Optional[float]=None) -> Any:
        """Receive a text message and decode it with ``loads``."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)

    async def write(self, data: bytes) -> None:
        """Always raises RuntimeError: raw writes are not allowed here."""
        raise RuntimeError("Cannot call .write() for websocket")

    def __aiter__(self) -> 'WebSocketResponse':
        """Return self, so the response can be used with ``async for``."""
        return self

    async def __anext__(self) -> WSMessage:
        """Return the next message; stop on CLOSE, CLOSING or CLOSED."""
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE,
                        WSMsgType.CLOSING,
                        WSMsgType.CLOSED):
            raise StopAsyncIteration  # NOQA
        return msg

View File

@ -0,0 +1,242 @@
"""Async gunicorn worker for hyper_internal_service.web"""
import asyncio
import os
import re
import signal
import sys
from types import FrameType
from typing import Any, Awaitable, Callable, Optional, Union # noqa
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base
from hyper_internal_service import web
from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger
try:
import ssl
SSLContext = ssl.SSLContext # noqa
except ImportError: # pragma: no cover
ssl = None # type: ignore
SSLContext = object # type: ignore
__all__ = ('GunicornWebWorker',
'GunicornUVLoopWebWorker',
'GunicornTokioWebWorker')
class GunicornWebWorker(base.Worker):
    """Gunicorn worker that serves an ``Application`` on an asyncio loop."""

    DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
    DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default

    def __init__(self, *args: Any, **kw: Any) -> None:  # pragma: no cover
        """Initialise the base worker and this worker's bookkeeping fields."""
        super().__init__(*args, **kw)

        self._task = None  # type: Optional[asyncio.Task[None]]
        self.exit_code = 0
        self._notify_waiter = None  # type: Optional[asyncio.Future[bool]]

    def init_process(self) -> None:
        """Replace the current event loop with a fresh one, then init."""
        # create new event_loop after fork
        asyncio.get_event_loop().close()

        fresh_loop = asyncio.new_event_loop()
        self.loop = fresh_loop
        asyncio.set_event_loop(fresh_loop)

        super().init_process()

    def run(self) -> None:
        """Run ``_run()`` to completion, close the loop and exit the process."""
        self._task = self.loop.create_task(self._run())

        try:  # ignore all finalization problems
            self.loop.run_until_complete(self._task)
        except Exception:
            self.log.exception("Exception in gunicorn worker")
        if sys.version_info >= (3, 6):
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
        self.loop.close()

        sys.exit(self.exit_code)

    async def _run(self) -> None:
        """Serve on every gunicorn socket until the worker should stop.

        Raises:
            RuntimeError: ``self.wsgi`` is neither an ``Application`` nor a
                coroutine function.
        """
        if isinstance(self.wsgi, Application):
            app = self.wsgi
        elif asyncio.iscoroutinefunction(self.wsgi):
            app = await self.wsgi()
        else:
            raise RuntimeError("wsgi app should be either Application or "
                               "async function returning Application, got {}"
                               .format(self.wsgi))

        access_log = self.log.access_log if self.cfg.accesslog else None
        runner_options = dict(
            logger=self.log,
            keepalive_timeout=self.cfg.keepalive,
            access_log=access_log,
            access_log_format=self._get_valid_log_format(
                self.cfg.access_log_format),
        )
        runner = web.AppRunner(app, **runner_options)
        await runner.setup()

        if self.cfg.is_ssl:
            ctx = self._create_ssl_context(self.cfg)
        else:
            ctx = None

        server = runner.server
        assert server is not None
        for sock in self.sockets:
            # Give each site 95% of gunicorn's graceful timeout.
            site_timeout = self.cfg.graceful_timeout / 100 * 95
            site = web.SockSite(
                runner, sock, ssl_context=ctx,
                shutdown_timeout=site_timeout)
            await site.start()

        # If our parent changed then we shut down.
        own_pid = os.getpid()
        try:
            while self.alive:  # type: ignore
                self.notify()

                served = server.requests_count
                request_limit = self.cfg.max_requests
                if request_limit and served > request_limit:
                    self.alive = False
                    self.log.info("Max requests, shutting down: %s", self)
                    continue

                if own_pid == os.getpid() and self.ppid != os.getppid():
                    self.alive = False
                    self.log.info("Parent changed, shutting down: %s", self)
                    continue

                await self._wait_next_notify()
        except BaseException:
            # Deliberate best effort: whatever ends the loop, cleanup
            # below must still run.
            pass

        await runner.cleanup()

    def _wait_next_notify(self) -> 'asyncio.Future[bool]':
        """Return a future that is resolved after one second or on wake-up."""
        self._notify_waiter_done()

        loop = self.loop
        assert loop is not None
        waiter = loop.create_future()
        self._notify_waiter = waiter
        self.loop.call_later(1.0, self._notify_waiter_done, waiter)

        return waiter

    def _notify_waiter_done(self, waiter: 'asyncio.Future[bool]'=None) -> None:
        """Resolve ``waiter`` (default: the current one) and forget it."""
        target = self._notify_waiter if waiter is None else waiter
        if target is not None:
            set_result(target, True)

        # Only drop the stored waiter when it is the one just handled.
        if target is self._notify_waiter:
            self._notify_waiter = None

    def init_signals(self) -> None:
        """Install the worker's signal handlers through the event loop API."""
        handlers = (
            (signal.SIGQUIT, self.handle_quit),
            (signal.SIGTERM, self.handle_exit),
            (signal.SIGINT, self.handle_quit),
            (signal.SIGWINCH, self.handle_winch),
            (signal.SIGUSR1, self.handle_usr1),
            (signal.SIGABRT, self.handle_abort),
        )
        for signum, callback in handlers:
            # Handlers receive (signum, None) as their sig/frame arguments.
            self.loop.add_signal_handler(signum, callback, signum, None)

        # Don't let SIGTERM and SIGUSR1 disturb active requests
        # by interrupting system calls
        for signum in (signal.SIGTERM, signal.SIGUSR1):
            signal.siginterrupt(signum, False)

    def handle_quit(self, sig: int, frame: FrameType) -> None:
        """Stop the serving loop and wake it up immediately."""
        self.alive = False

        # worker_int callback
        self.cfg.worker_int(self)

        # wakeup closing process
        self._notify_waiter_done()

    def handle_abort(self, sig: int, frame: FrameType) -> None:
        """Mark the worker dead, run the abort hook and exit with status 1."""
        self.alive = False
        self.exit_code = 1
        self.cfg.worker_abort(self)
        sys.exit(1)

    @staticmethod
    def _create_ssl_context(cfg: Any) -> 'SSLContext':
        """Build an ``SSLContext`` from the gunicorn SSL settings in ``cfg``.

        Raises:
            RuntimeError: the ``ssl`` module is not available.
        """
        if ssl is None:  # pragma: no cover
            raise RuntimeError('SSL is not supported.')

        context = ssl.SSLContext(cfg.ssl_version)
        context.load_cert_chain(cfg.certfile, cfg.keyfile)
        context.verify_mode = cfg.cert_reqs
        if cfg.ca_certs:
            context.load_verify_locations(cfg.ca_certs)
        if cfg.ciphers:
            context.set_ciphers(cfg.ciphers)
        return context

    def _get_valid_log_format(self, source_format: str) -> str:
        """Map gunicorn's default format to this package's default.

        Raises:
            ValueError: ``source_format`` uses gunicorn ``%(name)s`` options.
        """
        if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
            return self.DEFAULT_AIOHTTP_LOG_FORMAT

        if re.search(r'%\([^\)]+\)', source_format):
            raise ValueError(
                "Gunicorn's style options in form of `%(name)s` are not "
                "supported for the log formatting. Please use hyper_internal_service's "
                "format specification to configure access log formatting: "
                "http://docs.hyper_internal_service.org/en/stable/logging.html"
                "#format-specification"
            )

        return source_format
class GunicornUVLoopWebWorker(GunicornWebWorker):
    """Worker variant that installs the uvloop event loop policy."""

    def init_process(self) -> None:
        """Switch to the uvloop policy, then run the base initialisation."""
        import uvloop

        # The existing loop has to be closed before the policy changes.
        asyncio.get_event_loop().close()

        # With the uvloop policy installed, loops created through asyncio
        # afterwards are uvloop event loops.
        policy = uvloop.EventLoopPolicy()
        asyncio.set_event_loop_policy(policy)

        super().init_process()
class GunicornTokioWebWorker(GunicornWebWorker):
    """Worker variant that installs the tokio event loop policy."""

    def init_process(self) -> None:  # pragma: no cover
        """Switch to the tokio policy, then run the base initialisation."""
        import tokio

        # The existing loop has to be closed before the policy changes.
        asyncio.get_event_loop().close()

        # With the tokio policy installed, loops created through asyncio
        # afterwards are tokio event loops.
        policy = tokio.EventLoopPolicy()
        asyncio.set_event_loop_policy(policy)

        super().init_process()

View File

@ -0,0 +1,25 @@
-r flake.txt
attrs==19.3.0
async-generator==1.10
async-timeout==3.0.1
brotlipy==0.7.0
cchardet==2.1.6
chardet==3.0.4
coverage==5.1
gunicorn==20.0.4
multidict==4.7.6
pytest==5.4.2
pytest-cov==2.8.1
pytest-mock==3.1.0
typing_extensions==3.7.4.2
yarl==1.4.2
# Using PEP 508 env markers to control dependency on runtimes:
# the required c-ares library does not build on Windows and has build problems on macOS with Python < 3.7
aiodns==2.0.0; sys_platform=="linux" or sys_platform=="darwin" and python_version>="3.7"
cryptography==2.9.2; platform_machine!="i686" # no 32-bit wheels
trustme==0.6.0; platform_machine!="i686" # no 32-bit wheels
codecov==2.1.0
uvloop==0.12.1; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.7" # MagicStack/uvloop#14
idna-ssl==1.1.0; python_version<"3.7"

View File

@ -0,0 +1,8 @@
setuptools-git==1.2
mypy==0.770; implementation_name=="cpython"
mypy-extensions==0.4.3; implementation_name=="cpython"
freezegun==0.3.15
-r ci-wheel.txt
-r doc.txt
-e .

View File

@ -0,0 +1 @@
cython==0.29.18

View File

@ -0,0 +1,3 @@
-r ci.txt
-r towncrier.txt
cherry_picker==1.3.2; python_version>="3.6"

View File

@ -0,0 +1,2 @@
-r doc.txt
sphinxcontrib-spelling==5.0.0; platform_system!="Windows" # We only use it in Travis CI

View File

@ -0,0 +1,5 @@
sphinx==2.4.4
sphinxcontrib-asyncio==0.2.0
pygments==2.6.1
aiohttp-theme==0.1.6
sphinxcontrib-blockdiag==2.0.0

View File

@ -0,0 +1,2 @@
flake8==3.7.9
isort==4.3.21

View File

@ -0,0 +1,5 @@
mypy==0.770; implementation_name=="cpython"
flake8==3.7.9
flake8-pyi==20.5.0; python_version >= "3.6"
black==19.10b0; python_version >= "3.6"
isort==4.3.21

View File

@ -0,0 +1 @@
towncrier==19.2.0

View File

@ -0,0 +1 @@
pytest==5.4.2

View File

@ -0,0 +1,131 @@
import pathlib
import re
import sys
from distutils.command.build_ext import build_ext
from distutils.errors import (
CCompilerError,
DistutilsExecError,
DistutilsPlatformError,
)
from setuptools import Extension, setup
# Refuse to run on interpreters older than the minimum supported version.
if sys.version_info < (3, 5, 3):
    raise RuntimeError("hyper_internal_service 3.x requires Python 3.5.3+")

# Directory containing this setup script.
here = pathlib.Path(__file__).parent

# A git checkout without the vendored http-parser sources cannot be built.
if (here / '.git').exists():
    if not (here / 'vendor/http-parser/README.md').exists():
        print("Install submodules when building from git clone", file=sys.stderr)
        print("Hint:", file=sys.stderr)
        print(" git submodule update --init", file=sys.stderr)
        sys.exit(2)

# NOTE: makefile cythonizes all Cython modules

extensions = [
    Extension(
        'hyper_internal_service._websocket',
        ['hyper_internal_service/_websocket.c'],
    ),
    Extension(
        'hyper_internal_service._http_parser',
        [
            'hyper_internal_service/_http_parser.c',
            'vendor/http-parser/http_parser.c',
            'hyper_internal_service/_find_header.c',
        ],
        define_macros=[('HTTP_PARSER_STRICT', 0)],
    ),
    Extension(
        'hyper_internal_service._frozenlist',
        ['hyper_internal_service/_frozenlist.c'],
    ),
    Extension(
        'hyper_internal_service._helpers',
        ['hyper_internal_service/_helpers.c'],
    ),
    Extension(
        'hyper_internal_service._http_writer',
        ['hyper_internal_service/_http_writer.c'],
    ),
]
class BuildFailed(Exception):
    """Raised by ``ve_build_ext`` when the C extensions cannot be built."""
class ve_build_ext(build_ext):
    """``build_ext`` variant that allows C extension building to fail.

    Known build errors are re-raised as ``BuildFailed``.
    """

    def run(self):
        """Run the build, mapping platform/missing-file errors to BuildFailed."""
        try:
            super().run()
        except (DistutilsPlatformError, FileNotFoundError):
            raise BuildFailed()

    def build_extension(self, ext):
        """Build one extension, mapping compiler errors to BuildFailed."""
        build_errors = (CCompilerError, DistutilsExecError,
                        DistutilsPlatformError, ValueError)
        try:
            super().build_extension(ext)
        except build_errors:
            raise BuildFailed()
# The package version is the single-quoted ``__version__`` assignment in the
# package's __init__.py.
txt = (here / 'hyper_internal_service' / '__init__.py').read_text('utf-8')
try:
    version = re.findall(
        r"^__version__ = '([^']+)'\r?$",
        txt,
        re.MULTILINE,
    )[0]
except IndexError:
    raise RuntimeError('Unable to determine version.')

# Runtime dependencies; the last two apply only to Python older than 3.7.
install_requires = [
    'attrs>=17.3.0',
    'chardet>=2.0,<4.0',
    'multidict>=4.5,<5.0',
    'async_timeout>=3.0,<4.0',
    'yarl>=1.0,<2.0',
    'idna-ssl>=1.0; python_version<"3.7"',
    'typing_extensions>=3.6.5; python_version<"3.7"',
]
def read(f):
    """Return the stripped UTF-8 text of file ``f`` located next to setup.py."""
    target = here / f
    content = target.read_text('utf-8')
    return content.strip()
# Keyword arguments for setup(); kept in a dict so the fallback at the end of
# this script can retry without the C extensions.
args = dict(
    name='hyper_internal_service',
    version=version,
    description='Async http client/server framework (asyncio)',
    # read() returns a single string, so it is used directly: joining over it
    # would insert the separator between every character of the README.
    long_description=read('README.md'),
    # The README is Markdown; declare it so package indexes render it as such.
    long_description_content_type='text/markdown',
    classifiers=[
        'Intended Audience :: Developers',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Development Status :: 5 - Production/Stable',
        'Operating System :: POSIX',
        'Operating System :: MacOS :: MacOS X',
        'Operating System :: Microsoft :: Windows',
        'Topic :: Internet :: WWW/HTTP',
    ],
    author='Intellivoid Technologies',
    author_email='netkas@intellivoid.net',
    maintainer='Zi Xing Narrakas <netkas@intellivoid.net>',
    maintainer_email='netkas@intellivoid.net',
    url='https://github.com/intellivoid/Hyper-Internal-Service',
    project_urls={
        'GitHub: issues': 'https://github.com/intellivoid/Hyper-Internal-Service/issues',
        'GitHub: repo': 'https://github.com/intellivoid/Hyper-Internal-Service',
    },
    license='Apache 2',
    packages=['hyper_internal_service'],
    python_requires='>=3.5.3',
    install_requires=install_requires,
    include_package_data=True,
    ext_modules=extensions,
    # Custom build_ext that turns build errors into BuildFailed.
    cmdclass=dict(build_ext=ve_build_ext),
)
# First try to build with the C accelerators; if that fails, install the
# pure-Python version instead.
try:
    setup(**args)
except BuildFailed:
    for message in (
        "************************************************************",
        "Cannot compile C accelerator module, use pure python version",
        "************************************************************",
    ):
        print(message)
    # Retry without the extension modules and the custom build command.
    args.pop('ext_modules')
    args.pop('cmdclass')
    setup(**args)

View File

@ -0,0 +1,30 @@
/out/
core
tags
*.o
test
test_g
test_fast
bench
url_parser
parsertrace
parsertrace_g
*.mk
*.Makefile
*.so.*
*.exe.*
*.exe
*.a
# Visual Studio uglies
*.suo
*.sln
*.vcxproj
*.vcxproj.filters
*.vcxproj.user
*.opensdf
*.ncrunchsolution*
*.sdf
*.vsp
*.psess

View File

@ -0,0 +1,8 @@
# update AUTHORS with:
# git log --all --reverse --format='%aN <%aE>' | perl -ne 'BEGIN{print "# Authors ordered by first contribution.\n"} print unless $h{$_}; $h{$_} = 1' > AUTHORS
Ryan Dahl <ry@tinyclouds.org>
Salman Haq <salman.haq@asti-usa.com>
Simon Zimmermann <simonz05@gmail.com>
Thomas LE ROUX <thomas@november-eleven.fr> LE ROUX Thomas <thomas@procheo.fr>
Thomas LE ROUX <thomas@november-eleven.fr> Thomas LE ROUX <thomas@procheo.fr>
Fedor Indutny <fedor@indutny.com>

View File

@ -0,0 +1,13 @@
language: c
compiler:
- clang
- gcc
script:
- "make"
notifications:
email: false
irc:
- "irc.freenode.net#node-ci"

View File

@ -0,0 +1,68 @@
# Authors ordered by first contribution.
Ryan Dahl <ry@tinyclouds.org>
Jeremy Hinegardner <jeremy@hinegardner.org>
Sergey Shepelev <temotor@gmail.com>
Joe Damato <ice799@gmail.com>
tomika <tomika_nospam@freemail.hu>
Phoenix Sol <phoenix@burninglabs.com>
Cliff Frey <cliff@meraki.com>
Ewen Cheslack-Postava <ewencp@cs.stanford.edu>
Santiago Gala <sgala@apache.org>
Tim Becker <tim.becker@syngenio.de>
Jeff Terrace <jterrace@gmail.com>
Ben Noordhuis <info@bnoordhuis.nl>
Nathan Rajlich <nathan@tootallnate.net>
Mark Nottingham <mnot@mnot.net>
Aman Gupta <aman@tmm1.net>
Tim Becker <tim.becker@kuriositaet.de>
Sean Cunningham <sean.cunningham@mandiant.com>
Peter Griess <pg@std.in>
Salman Haq <salman.haq@asti-usa.com>
Cliff Frey <clifffrey@gmail.com>
Jon Kolb <jon@b0g.us>
Fouad Mardini <f.mardini@gmail.com>
Paul Querna <pquerna@apache.org>
Felix Geisendörfer <felix@debuggable.com>
koichik <koichik@improvement.jp>
Andre Caron <andre.l.caron@gmail.com>
Ivo Raisr <ivosh@ivosh.net>
James McLaughlin <jamie@lacewing-project.org>
David Gwynne <loki@animata.net>
Thomas LE ROUX <thomas@november-eleven.fr>
Randy Rizun <rrizun@ortivawireless.com>
Andre Louis Caron <andre.louis.caron@usherbrooke.ca>
Simon Zimmermann <simonz05@gmail.com>
Erik Dubbelboer <erik@dubbelboer.com>
Martell Malone <martellmalone@gmail.com>
Bertrand Paquet <bpaquet@octo.com>
BogDan Vatra <bogdan@kde.org>
Peter Faiman <peter@thepicard.org>
Corey Richardson <corey@octayn.net>
Tóth Tamás <tomika_nospam@freemail.hu>
Cam Swords <cam.swords@gmail.com>
Chris Dickinson <christopher.s.dickinson@gmail.com>
Uli Köhler <ukoehler@btronik.de>
Charlie Somerville <charlie@charliesomerville.com>
Patrik Stutz <patrik.stutz@gmail.com>
Fedor Indutny <fedor.indutny@gmail.com>
runner <runner.mei@gmail.com>
Alexis Campailla <alexis@janeasystems.com>
David Wragg <david@wragg.org>
Vinnie Falco <vinnie.falco@gmail.com>
Alex Butum <alexbutum@linux.com>
Rex Feng <rexfeng@gmail.com>
Alex Kocharin <alex@kocharin.ru>
Mark Koopman <markmontymark@yahoo.com>
Helge Heß <me@helgehess.eu>
Alexis La Goutte <alexis.lagoutte@gmail.com>
George Miroshnykov <george.miroshnykov@gmail.com>
Maciej Małecki <me@mmalecki.com>
Marc O'Morain <github.com@marcomorain.com>
Jeff Pinner <jpinner@twitter.com>
Timothy J Fontaine <tjfontaine@gmail.com>
Akagi201 <akagi201@gmail.com>
Romain Giraud <giraud.romain@gmail.com>
Jay Satiro <raysatiro@yahoo.com>
Arne Steen <Arne.Steen@gmx.de>
Kjell Schubert <kjell.schubert@gmail.com>
Olivier Mengué <dolmen@cpan.org>

View File

@ -0,0 +1,19 @@
Copyright Joyent, Inc. and other Node contributors.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.

View File

@ -0,0 +1,160 @@
# Copyright Joyent, Inc. and other Node contributors. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
# Target platform: the lower-cased output of `uname -s` unless the caller
# already set PLATFORM (e.g. `make PLATFORM=wine`).
PLATFORM ?= $(shell sh -c 'uname -s | tr "[A-Z]" "[a-z]"')
# HELPER is prefixed to the test binaries when they are run and BINEXT is
# appended to their names; both are empty unless set by the caller or by the
# wine branch below.
HELPER ?=
BINEXT ?=
# Shared library base name and version components. The comment next to
# HTTP_PARSER_VERSION_* in http_parser.h asks for these to be updated together.
SOLIBNAME = libhttp_parser
SOMAJOR = 2
SOMINOR = 8
SOREV = 1
# Platform specific naming of the shared library.
ifeq (darwin,$(PLATFORM))
# darwin: version numbers come before the .dylib extension.
SOEXT ?= dylib
SONAME ?= $(SOLIBNAME).$(SOMAJOR).$(SOMINOR).$(SOEXT)
LIBNAME ?= $(SOLIBNAME).$(SOMAJOR).$(SOMINOR).$(SOREV).$(SOEXT)
else ifeq (wine,$(PLATFORM))
# wine: build with winegcc and run the resulting binaries through wine.
# NOTE(review): this branch defines no SOEXT/SONAME/LIBNAME, so the
# library/install targets are presumably not meant to be used with it - confirm.
CC = winegcc
BINEXT = .exe.so
HELPER = wine
else
# everything else: version numbers come after the .so extension.
SOEXT ?= so
SONAME ?= $(SOLIBNAME).$(SOEXT).$(SOMAJOR).$(SOMINOR)
LIBNAME ?= $(SOLIBNAME).$(SOEXT).$(SOMAJOR).$(SOMINOR).$(SOREV)
endif
# Compiler and archiver; only assigned here when not already defined.
CC?=gcc
AR?=ar
# Preprocessor flags. The DEBUG variant enables strict parsing
# (HTTP_PARSER_STRICT=1), the FAST variant disables it; the *_EXTRA variables
# let the caller append flags to a single variant.
CPPFLAGS ?=
LDFLAGS ?=
CPPFLAGS += -I.
CPPFLAGS_DEBUG = $(CPPFLAGS) -DHTTP_PARSER_STRICT=1
CPPFLAGS_DEBUG += $(CPPFLAGS_DEBUG_EXTRA)
CPPFLAGS_FAST = $(CPPFLAGS) -DHTTP_PARSER_STRICT=0
CPPFLAGS_FAST += $(CPPFLAGS_FAST_EXTRA)
CPPFLAGS_BENCH = $(CPPFLAGS_FAST)
# Compiler flags: warnings are errors everywhere; DEBUG is unoptimized with
# debug info, FAST is -O3, BENCH additionally silences unused-parameter
# warnings, LIB adds position independent code for the shared library.
CFLAGS += -Wall -Wextra -Werror
CFLAGS_DEBUG = $(CFLAGS) -O0 -g $(CFLAGS_DEBUG_EXTRA)
CFLAGS_FAST = $(CFLAGS) -O3 $(CFLAGS_FAST_EXTRA)
CFLAGS_BENCH = $(CFLAGS_FAST) -Wno-unused-parameter
CFLAGS_LIB = $(CFLAGS_FAST) -fPIC
LDFLAGS_LIB = $(LDFLAGS) -shared
# Installation tool and destination directories (all derived from PREFIX).
INSTALL ?= install
PREFIX ?= /usr/local
LIBDIR = $(PREFIX)/lib
INCLUDEDIR = $(PREFIX)/include
# Record the library's name in the shared object: an install_name pointing
# into LIBDIR on darwin, a soname everywhere else.
ifeq (darwin,$(PLATFORM))
LDFLAGS_LIB += -Wl,-install_name,$(LIBDIR)/$(SONAME)
else
# TODO(bnoordhuis) The native SunOS linker expects -h rather than -soname...
LDFLAGS_LIB += -Wl,-soname=$(SONAME)
endif
# First rule in the file, hence the default goal: build both test binaries
# and run them, the strict (debug) build first.
test: test_g test_fast
	$(HELPER) ./test_g$(BINEXT)
	$(HELPER) ./test_fast$(BINEXT)

# --- strict / debug build -----------------------------------------------------
test_g: http_parser_g.o test_g.o
	$(CC) $(CFLAGS_DEBUG) $(LDFLAGS) $^ -o $@

test_g.o: test.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) -c $< -o $@

http_parser_g.o: http_parser.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) -c $< -o $@

# --- non-strict / optimized build ---------------------------------------------
# http_parser.h is a prerequisite here, so only the object files are linked.
test_fast: http_parser.o test.o http_parser.h
	$(CC) $(CFLAGS_FAST) $(LDFLAGS) $(filter %.o,$^) -o $@

test.o: test.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) -c $< -o $@

bench: http_parser.o bench.o
	$(CC) $(CFLAGS_BENCH) $(LDFLAGS) $^ -o $@

bench.o: bench.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_BENCH) $(CFLAGS_BENCH) -c $< -o $@

http_parser.o: http_parser.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) -c $<

# --- test helpers -------------------------------------------------------------
# Runs the optimized test binary in an endless timed loop (stop with Ctrl-C).
test-run-timed: test_fast
	while(true) do time $(HELPER) ./test_fast$(BINEXT) > /dev/null; done

test-valgrind: test_g
	valgrind ./test_g

# --- libraries ----------------------------------------------------------------
libhttp_parser.o: http_parser.c http_parser.h Makefile
	$(CC) $(CPPFLAGS_FAST) $(CFLAGS_LIB) -c $< -o $@

# Shared library, written to $(LIBNAME).
library: libhttp_parser.o
	$(CC) $(LDFLAGS_LIB) -o $(LIBNAME) $^

# Static library, written to libhttp_parser.a.
package: http_parser.o
	$(AR) rcs libhttp_parser.a $<

# --- contrib tools ------------------------------------------------------------
url_parser: http_parser.o contrib/url_parser.c
	$(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) $^ -o $@

url_parser_g: http_parser_g.o contrib/url_parser.c
	$(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) $^ -o $@

parsertrace: http_parser.o contrib/parsertrace.c
	$(CC) $(CPPFLAGS_FAST) $(CFLAGS_FAST) $^ -o parsertrace$(BINEXT)

parsertrace_g: http_parser_g.o contrib/parsertrace.c
	$(CC) $(CPPFLAGS_DEBUG) $(CFLAGS_DEBUG) $^ -o parsertrace_g$(BINEXT)

tags: http_parser.c http_parser.h test.c
	ctags $^
# Install the header and the shared library below $(DESTDIR)$(PREFIX) and
# create the SONAME and unversioned development symlinks next to the library.
# `ln -sf` replaces links left behind by an earlier install, so re-installing
# (or upgrading) no longer fails with "File exists".
install: library
	$(INSTALL) -D http_parser.h $(DESTDIR)$(INCLUDEDIR)/http_parser.h
	$(INSTALL) -D $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(LIBNAME)
	ln -sf $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(SONAME)
	ln -sf $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(SOLIBNAME).$(SOEXT)

# Same as `install`, but the library is stripped while it is copied (-s).
install-strip: library
	$(INSTALL) -D http_parser.h $(DESTDIR)$(INCLUDEDIR)/http_parser.h
	$(INSTALL) -D -s $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(LIBNAME)
	ln -sf $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(SONAME)
	ln -sf $(LIBNAME) $(DESTDIR)$(LIBDIR)/$(SOLIBNAME).$(SOEXT)

# Remove everything `install` created. `rm -f` keeps going when a file is
# already gone, so a partial installation can still be cleaned up completely.
uninstall:
	rm -f $(DESTDIR)$(INCLUDEDIR)/http_parser.h
	rm -f $(DESTDIR)$(LIBDIR)/$(SOLIBNAME).$(SOEXT)
	rm -f $(DESTDIR)$(LIBDIR)/$(SONAME)
	rm -f $(DESTDIR)$(LIBDIR)/$(LIBNAME)
# Remove everything this Makefile builds. The `bench` binary and the macOS
# shared library are listed as well so that no build product survives.
clean:
	rm -f *.o *.a tags test test_fast test_g bench \
		http_parser.tar libhttp_parser.so.* $(SOLIBNAME).*.dylib \
		url_parser url_parser_g parsertrace parsertrace_g \
		*.exe *.exe.so

# Recipe-less rules: the contrib sources count as out of date whenever the
# public header changes, which forces the tools built from them to be rebuilt.
contrib/url_parser.c:	http_parser.h
contrib/parsertrace.c:	http_parser.h

# Targets that name actions rather than files. `test` and `library` must be
# listed: a stale file called `test` (still removed by `clean`) would
# otherwise make `make test` report "up to date" without running anything.
# The former `test-run` entry was dropped because no such target exists.
.PHONY: test library clean package test-run-timed test-valgrind install install-strip uninstall

View File

@ -0,0 +1,246 @@
HTTP Parser
===========
[![Build Status](https://api.travis-ci.org/nodejs/http-parser.svg?branch=master)](https://travis-ci.org/nodejs/http-parser)
This is a parser for HTTP messages written in C. It parses both requests and
responses. The parser is designed to be used in high-performance HTTP
applications. It does not make any syscalls nor allocations, it does not
buffer data, it can be interrupted at any time. Depending on your
architecture, it only requires about 40 bytes of data per message
stream (in a web server that is per connection).
Features:
* No dependencies
* Handles persistent streams (keep-alive).
* Decodes chunked encoding.
* Upgrade support
* Defends against buffer overflow attacks.
The parser extracts the following information from HTTP messages:
* Header fields and values
* Content-Length
* Request method
* Response status code
* Transfer-Encoding
* HTTP version
* Request URL
* Message body
Usage
-----
One `http_parser` object is used per TCP connection. Initialize the struct
using `http_parser_init()` and set the callbacks. That might look something
like this for a request parser:
```c
http_parser_settings settings;
settings.on_url = my_url_callback;
settings.on_header_field = my_header_field_callback;
/* ... */
http_parser *parser = malloc(sizeof(http_parser));
http_parser_init(parser, HTTP_REQUEST);
parser->data = my_socket;
```
When data is received on the socket execute the parser and check for errors.
```c
size_t len = 80*1024, nparsed;
char buf[len];
ssize_t recved;
recved = recv(fd, buf, len, 0);
if (recved < 0) {
/* Handle error. */
}
/* Start up / continue the parser.
* Note we pass recved==0 to signal that EOF has been received.
*/
nparsed = http_parser_execute(parser, &settings, buf, recved);
if (parser->upgrade) {
/* handle new protocol */
} else if (nparsed != recved) {
/* Handle error. Usually just close the connection. */
}
```
`http_parser` needs to know where the end of the stream is. For example, sometimes
servers send responses without Content-Length and expect the client to
consume input (for the body) until EOF. To tell `http_parser` about EOF, give
`0` as the fourth parameter to `http_parser_execute()`. Callbacks and errors
can still be encountered during an EOF, so one must still be prepared
to receive them.
Scalar valued message information such as `status_code`, `method`, and the
HTTP version are stored in the parser structure. This data is only
temporarily stored in `http_parser` and gets reset on each new message. If
this information is needed later, copy it out of the structure during the
`headers_complete` callback.
The parser decodes the transfer-encoding for both requests and responses
transparently. That is, a chunked encoding is decoded before being sent to
the on_body callback.
The Special Problem of Upgrade
------------------------------
`http_parser` supports upgrading the connection to a different protocol. An
increasingly common example of this is the WebSocket protocol which sends
a request like
GET /demo HTTP/1.1
Upgrade: WebSocket
Connection: Upgrade
Host: example.com
Origin: http://example.com
WebSocket-Protocol: sample
followed by non-HTTP data.
(See [RFC6455](https://tools.ietf.org/html/rfc6455) for more information on the
WebSocket protocol.)
To support this, the parser will treat this as a normal HTTP message without a
body, issuing both on_headers_complete and on_message_complete callbacks. However
http_parser_execute() will stop parsing at the end of the headers and return.
The user is expected to check if `parser->upgrade` has been set to 1 after
`http_parser_execute()` returns. Non-HTTP data begins at the buffer supplied
offset by the return value of `http_parser_execute()`.
Callbacks
---------
During the `http_parser_execute()` call, the callbacks set in
`http_parser_settings` will be executed. The parser maintains state and
never looks behind, so buffering the data is not necessary. If you need to
save certain data for later usage, you can do that from the callbacks.
There are two types of callbacks:
* notification `typedef int (*http_cb) (http_parser*);`
Callbacks: on_message_begin, on_headers_complete, on_message_complete.
* data `typedef int (*http_data_cb) (http_parser*, const char *at, size_t length);`
Callbacks: (requests only) on_url,
(common) on_header_field, on_header_value, on_body;
Callbacks must return 0 on success. Returning a non-zero value indicates
error to the parser, making it exit immediately.
For cases where it is necessary to pass local information to/from a callback,
the `http_parser` object's `data` field can be used.
An example of such a case is when using threads to handle a socket connection,
parse a request, and then give a response over that socket. By instantiation
of a thread-local struct containing relevant data (e.g. accepted socket,
allocated memory for callbacks to write into, etc), a parser's callbacks are
able to communicate data between the scope of the thread and the scope of the
callback in a threadsafe manner. This allows `http_parser` to be used in
multi-threaded contexts.
Example:
```c
typedef struct {
socket_t sock;
void* buffer;
int buf_len;
} custom_data_t;
int my_url_callback(http_parser* parser, const char *at, size_t length) {
/* access to thread local custom_data_t struct.
  Use this access to save parsed data for later use into thread local
buffer, or communicate over socket
*/
parser->data;
...
return 0;
}
...
void http_parser_thread(socket_t sock) {
int nparsed = 0;
/* allocate memory for user data */
custom_data_t *my_data = malloc(sizeof(custom_data_t));
/* some information for use by callbacks.
* achieves thread -> callback information flow */
my_data->sock = sock;
/* instantiate a thread-local parser */
http_parser *parser = malloc(sizeof(http_parser));
http_parser_init(parser, HTTP_REQUEST); /* initialise parser */
/* this custom data reference is accessible through the reference to the
parser supplied to callback functions */
parser->data = my_data;
http_parser_settings settings; /* set up callbacks */
settings.on_url = my_url_callback;
/* execute parser */
nparsed = http_parser_execute(parser, &settings, buf, recved);
...
/* parsed information copied from callback.
can now perform action on data copied into thread-local memory from callbacks.
achieves callback -> thread information flow */
my_data->buffer;
...
}
```
In case you parse HTTP message in chunks (i.e. `read()` request line
from socket, parse, read half headers, parse, etc) your data callbacks
may be called more than once. `http_parser` guarantees that the data pointer is
only valid for the lifetime of the callback. You can also `read()` into a heap allocated
buffer to avoid copying memory around if this fits your application.
Reading headers may be a tricky task if you read/parse headers partially.
Basically, you need to remember whether last header callback was field or value
and apply the following logic:
(on_header_field and on_header_value shortened to on_h_*)
------------------------ ------------ --------------------------------------------
| State (prev. callback) | Callback | Description/action |
------------------------ ------------ --------------------------------------------
| nothing (first call) | on_h_field | Allocate new buffer and copy callback data |
| | | into it |
------------------------ ------------ --------------------------------------------
| value | on_h_field | New header started. |
| | | Copy current name,value buffers to headers |
| | | list and allocate new buffer for new name |
------------------------ ------------ --------------------------------------------
| field | on_h_field | Previous name continues. Reallocate name |
| | | buffer and append callback data to it |
------------------------ ------------ --------------------------------------------
| field | on_h_value | Value for current header started. Allocate |
| | | new buffer and copy callback data to it |
------------------------ ------------ --------------------------------------------
| value | on_h_value | Value continues. Reallocate value buffer |
| | | and append callback data to it |
------------------------ ------------ --------------------------------------------
Parsing URLs
------------
A simplistic zero-copy URL parser is provided as `http_parser_parse_url()`.
Users of this library may wish to use it to parse URLs constructed from
consecutive `on_url` callbacks.
See examples of reading in headers:
* [partial example](http://gist.github.com/155877) in C
* [from http-parser tests](http://github.com/joyent/http-parser/blob/37a0ff8/test.c#L403) in C
* [from Node library](http://github.com/joyent/node/blob/842eaf4/src/http.js#L284) in Javascript

View File

@ -0,0 +1,128 @@
/* Copyright Fedor Indutny. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
#include "http_parser.h"
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/time.h>
/* 8 gb */
static const int64_t kBytes = 8LL << 30;
static const char data[] =
"POST /joyent/http-parser HTTP/1.1\r\n"
"Host: github.com\r\n"
"DNT: 1\r\n"
"Accept-Encoding: gzip, deflate, sdch\r\n"
"Accept-Language: ru-RU,ru;q=0.8,en-US;q=0.6,en;q=0.4\r\n"
"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/39.0.2171.65 Safari/537.36\r\n"
"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,"
"image/webp,*/*;q=0.8\r\n"
"Referer: https://github.com/joyent/http-parser\r\n"
"Connection: keep-alive\r\n"
"Transfer-Encoding: chunked\r\n"
"Cache-Control: max-age=0\r\n\r\nb\r\nhello world\r\n0\r\n";
static const size_t data_len = sizeof(data) - 1;
/* Notification callback stub: ignores the parser and reports success. */
static int on_info(http_parser* p) {
  (void) p;
  return 0;
}
/* Data callback stub: discards the reported span and reports success. */
static int on_data(http_parser* p, const char *at, size_t length) {
  (void) p;
  (void) at;
  (void) length;
  return 0;
}
/* Callback table used by bench(): every callback is one of the no-op stubs. */
static http_parser_settings settings = {
  /* data callbacks */
  .on_url = on_data,
  .on_status = on_data,
  .on_header_field = on_data,
  .on_header_value = on_data,
  .on_body = on_data,
  /* notification callbacks */
  .on_message_begin = on_info,
  .on_headers_complete = on_info,
  .on_message_complete = on_info
};
/* Parse the canned request `iter_count` times with a freshly initialised
 * parser for every iteration.
 *
 * The request length is always written to stderr. Unless `silent` is
 * non-zero the run is timed with gettimeofday() and a throughput summary is
 * written to stdout. Always returns 0; failures trip an assert instead.
 */
int bench(int iter_count, int silent) {
  struct http_parser parser;
  struct timeval start;
  struct timeval end;
  int err;
  int done;

  if (!silent) {
    err = gettimeofday(&start, NULL);
    assert(err == 0);
  }

  fprintf(stderr, "req_len=%d\n", (int) data_len);

  for (done = 0; done < iter_count; done++) {
    size_t parsed;

    http_parser_init(&parser, HTTP_REQUEST);
    parsed = http_parser_execute(&parser, &settings, data, data_len);
    /* The whole request must have been consumed. */
    assert(parsed == data_len);
  }

  if (!silent) {
    const double mib = 1024 * 1024;
    double elapsed;
    double total;
    double bw;

    err = gettimeofday(&end, NULL);
    assert(err == 0);

    fprintf(stdout, "Benchmark result:\n");

    /* Whole seconds plus the microsecond remainder. */
    elapsed = (double) (end.tv_sec - start.tv_sec) +
              (end.tv_usec - start.tv_usec) * 1e-6f;
    total = (double) iter_count * data_len;
    bw = (double) total / elapsed;

    fprintf(stdout, "%.2f mb | %.2f mb/s | %.2f req/sec | %.2f s\n",
            (double) total / mib,
            bw / mib,
            (double) iter_count / elapsed,
            elapsed);
    fflush(stdout);
  }

  return 0;
}
/* Entry point of the benchmark.
 *
 * Runs enough iterations to push kBytes of request data through the parser.
 * With the single argument "infinite" the benchmark is repeated silently
 * forever; otherwise it runs once and prints its result.
 */
int main(int argc, char** argv) {
  const int64_t iterations = kBytes / (int64_t) data_len;
  const int run_forever = (argc == 2) && (strcmp(argv[1], "infinite") == 0);

  if (!run_forever)
    return bench(iterations, 0);

  while (1)
    bench(iterations, 1);
  return 0;
}

View File

@ -0,0 +1,157 @@
/* Copyright Joyent, Inc. and other Node contributors.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
/* Dump what the parser finds to stdout as it happens */
#include "http_parser.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Notification callback: announce the start of a message on stdout. */
int on_message_begin(http_parser* parser) {
  (void) parser;
  fputs("\n***MESSAGE BEGIN***\n\n", stdout);
  return 0;
}
/* Notification callback: announce the end of the header section on stdout. */
int on_headers_complete(http_parser* parser) {
  (void) parser;
  fputs("\n***HEADERS COMPLETE***\n\n", stdout);
  return 0;
}
/* Notification callback: announce the end of a message on stdout. */
int on_message_complete(http_parser* parser) {
  (void) parser;
  fputs("\n***MESSAGE COMPLETE***\n\n", stdout);
  return 0;
}
/* Data callback: print the reported piece of the URL (`length` bytes at `at`). */
int on_url(http_parser* parser, const char* at, size_t length) {
  const int chunk_len = (int)length;

  (void) parser;
  printf("Url: %.*s\n", chunk_len, at);
  return 0;
}
/* Data callback: print the reported piece of a header name. */
int on_header_field(http_parser* parser, const char* at, size_t length) {
  const int chunk_len = (int)length;

  (void) parser;
  printf("Header field: %.*s\n", chunk_len, at);
  return 0;
}
/* Data callback: print the reported piece of a header value. */
int on_header_value(http_parser* parser, const char* at, size_t length) {
  const int chunk_len = (int)length;

  (void) parser;
  printf("Header value: %.*s\n", chunk_len, at);
  return 0;
}
/* Data callback: print the reported piece of the message body. */
int on_body(http_parser* parser, const char* at, size_t length) {
  const int chunk_len = (int)length;

  (void) parser;
  printf("Body: %.*s\n", chunk_len, at);
  return 0;
}
/* Print the command line synopsis for program `name` to stderr and
 * terminate the process with EXIT_FAILURE. Does not return.
 */
void usage(const char* name) {
  fprintf(stderr, "Usage: %s $type $filename\n", name);
  fputs(" type: -x, where x is one of {r,b,q}\n"
        " parses file as a Response, reQuest, or Both\n",
        stderr);
  exit(EXIT_FAILURE);
}
/* Usage: parsertrace -r|-q|-b FILE
 *
 * Reads FILE completely into memory, feeds it to the parser as a response
 * (-r), request (-q) or either (-b) and lets the callbacks above print what
 * is found. Returns EXIT_SUCCESS when the whole file was parsed, otherwise
 * EXIT_FAILURE after reporting the problem on stderr.
 */
int main(int argc, char* argv[]) {
  enum http_parser_type file_type;

  if (argc != 3) {
    usage(argv[0]);
  }

  char* type = argv[1];
  if (type[0] != '-') {
    usage(argv[0]);
  }

  switch (type[1]) {
    /* in the case of "-", type[1] will be NUL */
    case 'r':
      file_type = HTTP_RESPONSE;
      break;
    case 'q':
      file_type = HTTP_REQUEST;
      break;
    case 'b':
      file_type = HTTP_BOTH;
      break;
    default:
      usage(argv[0]);
  }

  char* filename = argv[2];
  FILE* file = fopen(filename, "r");
  if (file == NULL) {
    perror("fopen");
    /* Nothing was opened: return directly, fclose(NULL) is undefined. */
    return EXIT_FAILURE;
  }

  /* Determine the file size by seeking to the end and back. */
  if (fseek(file, 0, SEEK_END) != 0) {
    perror("fseek");
    goto fail;
  }
  long file_length = ftell(file);
  if (file_length == -1) {
    perror("ftell");
    goto fail;
  }
  if (fseek(file, 0, SEEK_SET) != 0) {
    perror("fseek");
    goto fail;
  }

  char* data = malloc(file_length);
  /* malloc(0) may legitimately return NULL, so only a failed non-empty
   * allocation is an error. */
  if (data == NULL && file_length != 0) {
    perror("malloc");
    goto fail;
  }

  size_t nread = fread(data, 1, file_length, file);
  /* The file is no longer needed, whether or not the read succeeded. */
  fclose(file);
  if (nread != (size_t)file_length) {
    fprintf(stderr, "couldn't read entire file\n");
    free(data);
    return EXIT_FAILURE;
  }

  http_parser_settings settings;
  memset(&settings, 0, sizeof(settings));
  settings.on_message_begin = on_message_begin;
  settings.on_url = on_url;
  settings.on_header_field = on_header_field;
  settings.on_header_value = on_header_value;
  settings.on_headers_complete = on_headers_complete;
  settings.on_body = on_body;
  settings.on_message_complete = on_message_complete;

  http_parser parser;
  http_parser_init(&parser, file_type);
  size_t nparsed = http_parser_execute(&parser, &settings, data, file_length);
  free(data);

  /* Anything short of the full length means the parser stopped early. */
  if (nparsed != (size_t)file_length) {
    fprintf(stderr,
            "Error: %s (%s)\n",
            http_errno_description(HTTP_PARSER_ERRNO(&parser)),
            http_errno_name(HTTP_PARSER_ERRNO(&parser)));
    return EXIT_FAILURE;
  }

  return EXIT_SUCCESS;

fail:
  /* Only reached while `file` is open. */
  fclose(file);
  return EXIT_FAILURE;
}

View File

@ -0,0 +1,47 @@
#include "http_parser.h"
#include <stdio.h>
#include <string.h>
/* Print the parse result `u` for `url`: first the field bitmask and port,
 * then one line per URL field with its offset, length and text, or "unset"
 * when the field's bit is missing from field_set.
 */
void
dump_url (const char *url, const struct http_parser_url *u)
{
  unsigned int field;

  printf("\tfield_set: 0x%x, port: %u\n", u->field_set, u->port);

  for (field = 0; field < UF_MAX; field++) {
    if (u->field_set & (1 << field)) {
      printf("\tfield_data[%u]: off: %u, len: %u, part: %.*s\n",
             field,
             u->field_data[field].off,
             u->field_data[field].len,
             u->field_data[field].len,
             url + u->field_data[field].off);
    } else {
      printf("\tfield_data[%u]: unset\n", field);
    }
  }
}
/* Usage: url_parser connect|get URL
 *
 * Parses URL with http_parser_parse_url(), in CONNECT mode when the first
 * argument is exactly "connect", and dumps the resulting fields. Returns 0
 * on success, 1 on wrong usage, otherwise the parser's non-zero result.
 */
int main(int argc, char ** argv) {
  struct http_parser_url u;
  int len, connect, result;

  if (argc != 3) {
    printf("Syntax : %s connect|get url\n", argv[0]);
    return 1;
  }

  connect = (strcmp("connect", argv[1]) == 0);
  len = strlen(argv[2]);
  printf("Parsing %s, connect %d\n", argv[2], connect);

  http_parser_url_init(&u);
  result = http_parser_parse_url(argv[2], len, connect, &u);
  if (result == 0) {
    printf("Parse ok, result : \n");
    dump_url(argv[2], &u);
    return 0;
  }

  printf("Parse error : %d\n", result);
  return result;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,111 @@
# This file is used with the GYP meta build system.
# http://code.google.com/p/gyp/
# To build try this:
# svn co http://gyp.googlecode.com/svn/trunk gyp
# ./gyp/gyp -f make --depth=`pwd` http_parser.gyp
# ./out/Debug/test
{
'target_defaults': {
'default_configuration': 'Debug',
'configurations': {
# TODO: hoist these out and put them somewhere common, because
# RuntimeLibrary MUST MATCH across the entire project
'Debug': {
'defines': [ 'DEBUG', '_DEBUG' ],
'cflags': [ '-Wall', '-Wextra', '-O0', '-g', '-ftrapv' ],
'msvs_settings': {
'VCCLCompilerTool': {
'RuntimeLibrary': 1, # static debug
},
},
},
'Release': {
'defines': [ 'NDEBUG' ],
'cflags': [ '-Wall', '-Wextra', '-O3' ],
'msvs_settings': {
'VCCLCompilerTool': {
'RuntimeLibrary': 0, # static release
},
},
}
},
'msvs_settings': {
'VCCLCompilerTool': {
},
'VCLibrarianTool': {
},
'VCLinkerTool': {
'GenerateDebugInformation': 'true',
},
},
'conditions': [
['OS == "win"', {
'defines': [
'WIN32'
],
}]
],
},
'targets': [
{
'target_name': 'http_parser',
'type': 'static_library',
'include_dirs': [ '.' ],
'direct_dependent_settings': {
'defines': [ 'HTTP_PARSER_STRICT=0' ],
'include_dirs': [ '.' ],
},
'defines': [ 'HTTP_PARSER_STRICT=0' ],
'sources': [ './http_parser.c', ],
'conditions': [
['OS=="win"', {
'msvs_settings': {
'VCCLCompilerTool': {
# Compile as C++. http_parser.c is actually C99, but C++ is
# close enough in this case.
'CompileAs': 2,
},
},
}]
],
},
{
'target_name': 'http_parser_strict',
'type': 'static_library',
'include_dirs': [ '.' ],
'direct_dependent_settings': {
'defines': [ 'HTTP_PARSER_STRICT=1' ],
'include_dirs': [ '.' ],
},
'defines': [ 'HTTP_PARSER_STRICT=1' ],
'sources': [ './http_parser.c', ],
'conditions': [
['OS=="win"', {
'msvs_settings': {
'VCCLCompilerTool': {
# Compile as C++. http_parser.c is actually C99, but C++ is
# close enough in this case.
'CompileAs': 2,
},
},
}]
],
},
{
'target_name': 'test-nonstrict',
'type': 'executable',
'dependencies': [ 'http_parser' ],
'sources': [ 'test.c' ]
},
{
'target_name': 'test-strict',
'type': 'executable',
'dependencies': [ 'http_parser_strict' ],
'sources': [ 'test.c' ]
}
]
}

View File

@ -0,0 +1,436 @@
/* Copyright Joyent, Inc. and other Node contributors. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
#ifndef http_parser_h
#define http_parser_h
#ifdef __cplusplus
extern "C" {
#endif
/* Also update SONAME in the Makefile whenever you change these. */
#define HTTP_PARSER_VERSION_MAJOR 2
#define HTTP_PARSER_VERSION_MINOR 8
#define HTTP_PARSER_VERSION_PATCH 1
#include <stddef.h>
#if defined(_WIN32) && !defined(__MINGW32__) && \
(!defined(_MSC_VER) || _MSC_VER<1600) && !defined(__WINE__)
#include <BaseTsd.h>
typedef __int8 int8_t;
typedef unsigned __int8 uint8_t;
typedef __int16 int16_t;
typedef unsigned __int16 uint16_t;
typedef __int32 int32_t;
typedef unsigned __int32 uint32_t;
typedef __int64 int64_t;
typedef unsigned __int64 uint64_t;
#else
#include <stdint.h>
#endif
/* Compile with -DHTTP_PARSER_STRICT=0 to perform fewer checks, but run
* faster
*/
#ifndef HTTP_PARSER_STRICT
# define HTTP_PARSER_STRICT 1
#endif
/* Maximum header size allowed. If the macro is not defined
* before including this header then the default is used. To
* change the maximum header size, define the macro in the build
* environment (e.g. -DHTTP_MAX_HEADER_SIZE=<value>). To remove
* the effective limit on the size of the header, define the macro
* to a very large number (e.g. -DHTTP_MAX_HEADER_SIZE=0x7fffffff)
*/
#ifndef HTTP_MAX_HEADER_SIZE
# define HTTP_MAX_HEADER_SIZE (80*1024)
#endif
/* Typedef names for the two parser structs. Declared ahead of the struct
 * bodies so the callback types below can refer to http_parser by its
 * short name.
 */
typedef struct http_parser http_parser;
typedef struct http_parser_settings http_parser_settings;
/* Callbacks should return non-zero to indicate an error. The parser will
 * then halt execution.
 *
 * The one exception is on_headers_complete. In a HTTP_RESPONSE parser,
 * returning '1' from on_headers_complete will tell the parser that it
 * should not expect a body. This is used when receiving a response to a
 * HEAD request which may contain 'Content-Length' or 'Transfer-Encoding:
 * chunked' headers that indicate the presence of a body.
 *
 * Returning `2` from on_headers_complete will tell the parser that it should
 * expect neither a body nor any further responses on this connection. This
 * is useful for handling responses to a CONNECT request which may not
 * contain `Upgrade` or `Connection: upgrade` headers.
 *
 * http_data_cb does not return data chunks. It will be called arbitrarily
 * many times for each string. E.g. you might get 10 callbacks for "on_url",
 * each providing just a few characters more data.
 */
/* Data callback: receives the parser plus a pointer/length pair (`at`,
 * `length`) describing the piece of data being reported. */
typedef int (*http_data_cb) (http_parser*, const char *at, size_t length);
/* Notification callback: receives only the parser. */
typedef int (*http_cb) (http_parser*);
/* Status Codes
 *
 * X-macro table of HTTP status codes. XX is expanded once per row as
 * XX(numeric code, identifier suffix, status text); the status text is
 * written as bare tokens, not as a string literal.
 *
 * The last row intentionally carries NO trailing line-continuation
 * backslash: a backslash there splices whatever line follows into the
 * macro body, which would swallow the `enum http_status` keyword below.
 */
#define HTTP_STATUS_MAP(XX) \
  XX(100, CONTINUE, Continue) \
  XX(101, SWITCHING_PROTOCOLS, Switching Protocols) \
  XX(102, PROCESSING, Processing) \
  XX(200, OK, OK) \
  XX(201, CREATED, Created) \
  XX(202, ACCEPTED, Accepted) \
  XX(203, NON_AUTHORITATIVE_INFORMATION, Non-Authoritative Information) \
  XX(204, NO_CONTENT, No Content) \
  XX(205, RESET_CONTENT, Reset Content) \
  XX(206, PARTIAL_CONTENT, Partial Content) \
  XX(207, MULTI_STATUS, Multi-Status) \
  XX(208, ALREADY_REPORTED, Already Reported) \
  XX(226, IM_USED, IM Used) \
  XX(300, MULTIPLE_CHOICES, Multiple Choices) \
  XX(301, MOVED_PERMANENTLY, Moved Permanently) \
  XX(302, FOUND, Found) \
  XX(303, SEE_OTHER, See Other) \
  XX(304, NOT_MODIFIED, Not Modified) \
  XX(305, USE_PROXY, Use Proxy) \
  XX(307, TEMPORARY_REDIRECT, Temporary Redirect) \
  XX(308, PERMANENT_REDIRECT, Permanent Redirect) \
  XX(400, BAD_REQUEST, Bad Request) \
  XX(401, UNAUTHORIZED, Unauthorized) \
  XX(402, PAYMENT_REQUIRED, Payment Required) \
  XX(403, FORBIDDEN, Forbidden) \
  XX(404, NOT_FOUND, Not Found) \
  XX(405, METHOD_NOT_ALLOWED, Method Not Allowed) \
  XX(406, NOT_ACCEPTABLE, Not Acceptable) \
  XX(407, PROXY_AUTHENTICATION_REQUIRED, Proxy Authentication Required) \
  XX(408, REQUEST_TIMEOUT, Request Timeout) \
  XX(409, CONFLICT, Conflict) \
  XX(410, GONE, Gone) \
  XX(411, LENGTH_REQUIRED, Length Required) \
  XX(412, PRECONDITION_FAILED, Precondition Failed) \
  XX(413, PAYLOAD_TOO_LARGE, Payload Too Large) \
  XX(414, URI_TOO_LONG, URI Too Long) \
  XX(415, UNSUPPORTED_MEDIA_TYPE, Unsupported Media Type) \
  XX(416, RANGE_NOT_SATISFIABLE, Range Not Satisfiable) \
  XX(417, EXPECTATION_FAILED, Expectation Failed) \
  XX(421, MISDIRECTED_REQUEST, Misdirected Request) \
  XX(422, UNPROCESSABLE_ENTITY, Unprocessable Entity) \
  XX(423, LOCKED, Locked) \
  XX(424, FAILED_DEPENDENCY, Failed Dependency) \
  XX(426, UPGRADE_REQUIRED, Upgrade Required) \
  XX(428, PRECONDITION_REQUIRED, Precondition Required) \
  XX(429, TOO_MANY_REQUESTS, Too Many Requests) \
  XX(431, REQUEST_HEADER_FIELDS_TOO_LARGE, Request Header Fields Too Large) \
  XX(451, UNAVAILABLE_FOR_LEGAL_REASONS, Unavailable For Legal Reasons) \
  XX(500, INTERNAL_SERVER_ERROR, Internal Server Error) \
  XX(501, NOT_IMPLEMENTED, Not Implemented) \
  XX(502, BAD_GATEWAY, Bad Gateway) \
  XX(503, SERVICE_UNAVAILABLE, Service Unavailable) \
  XX(504, GATEWAY_TIMEOUT, Gateway Timeout) \
  XX(505, HTTP_VERSION_NOT_SUPPORTED, HTTP Version Not Supported) \
  XX(506, VARIANT_ALSO_NEGOTIATES, Variant Also Negotiates) \
  XX(507, INSUFFICIENT_STORAGE, Insufficient Storage) \
  XX(508, LOOP_DETECTED, Loop Detected) \
  XX(510, NOT_EXTENDED, Not Extended) \
  XX(511, NETWORK_AUTHENTICATION_REQUIRED, Network Authentication Required)

/* One HTTP_STATUS_<name> enumerator per table row, valued with the row's
 * numeric code. The status-text argument is ignored here.
 */
enum http_status
  {
#define XX(num, name, string) HTTP_STATUS_##name = num,
  HTTP_STATUS_MAP(XX)
#undef XX
  };
/* Request Methods
 *
 * X-macro table of HTTP request methods. XX is expanded once per row as
 * XX(numeric id, identifier suffix, method text); the method text is
 * written as bare tokens (note M-SEARCH, whose identifier is MSEARCH).
 *
 * The last row intentionally carries NO trailing line-continuation
 * backslash: a backslash there splices whatever line follows into the
 * macro body, which would swallow the `enum http_method` keyword below.
 */
#define HTTP_METHOD_MAP(XX) \
  XX(0, DELETE, DELETE) \
  XX(1, GET, GET) \
  XX(2, HEAD, HEAD) \
  XX(3, POST, POST) \
  XX(4, PUT, PUT) \
  /* pathological */ \
  XX(5, CONNECT, CONNECT) \
  XX(6, OPTIONS, OPTIONS) \
  XX(7, TRACE, TRACE) \
  /* WebDAV */ \
  XX(8, COPY, COPY) \
  XX(9, LOCK, LOCK) \
  XX(10, MKCOL, MKCOL) \
  XX(11, MOVE, MOVE) \
  XX(12, PROPFIND, PROPFIND) \
  XX(13, PROPPATCH, PROPPATCH) \
  XX(14, SEARCH, SEARCH) \
  XX(15, UNLOCK, UNLOCK) \
  XX(16, BIND, BIND) \
  XX(17, REBIND, REBIND) \
  XX(18, UNBIND, UNBIND) \
  XX(19, ACL, ACL) \
  /* subversion */ \
  XX(20, REPORT, REPORT) \
  XX(21, MKACTIVITY, MKACTIVITY) \
  XX(22, CHECKOUT, CHECKOUT) \
  XX(23, MERGE, MERGE) \
  /* upnp */ \
  XX(24, MSEARCH, M-SEARCH) \
  XX(25, NOTIFY, NOTIFY) \
  XX(26, SUBSCRIBE, SUBSCRIBE) \
  XX(27, UNSUBSCRIBE, UNSUBSCRIBE) \
  /* RFC-5789 */ \
  XX(28, PATCH, PATCH) \
  XX(29, PURGE, PURGE) \
  /* CalDAV */ \
  XX(30, MKCALENDAR, MKCALENDAR) \
  /* RFC-2068, section 19.6.1.2 */ \
  XX(31, LINK, LINK) \
  XX(32, UNLINK, UNLINK) \
  /* icecast */ \
  XX(33, SOURCE, SOURCE)

/* One HTTP_<name> enumerator per table row, valued with the row's numeric
 * id. The method-text argument is ignored here.
 */
enum http_method
  {
#define XX(num, name, string) HTTP_##name = num,
  HTTP_METHOD_MAP(XX)
#undef XX
  };
/* Parser mode, as passed to http_parser_init() and stored in
 * http_parser.type. Values are implicit: 0, 1, 2 in declaration order.
 */
enum http_parser_type
  {
    HTTP_REQUEST,
    HTTP_RESPONSE,
    HTTP_BOTH
  };
/* Bit flags kept in the http_parser.flags field. Each enumerator occupies
 * one distinct bit, from bit 0 through bit 7.
 */
enum flags
  {
    F_CHUNKED               = 1 << 0,
    F_CONNECTION_KEEP_ALIVE = 1 << 1,
    F_CONNECTION_CLOSE      = 1 << 2,
    F_CONNECTION_UPGRADE    = 1 << 3,
    F_TRAILING              = 1 << 4,
    F_UPGRADE               = 1 << 5,
    F_SKIPBODY              = 1 << 6,
    F_CONTENTLENGTH         = 1 << 7
  };
/* X-macro table of every parser error constant.
 *
 * XX must be a macro taking 2 arguments. It is expanded once per row as
 * XX(name, description): `name` is a bare identifier suffix and
 * `description` is a string literal.
 */
#define HTTP_ERRNO_MAP(XX) \
  /* No error */ \
  XX(OK, "success") \
  /* Callback-related errors */ \
  XX(CB_message_begin, "the on_message_begin callback failed") \
  XX(CB_url, "the on_url callback failed") \
  XX(CB_header_field, "the on_header_field callback failed") \
  XX(CB_header_value, "the on_header_value callback failed") \
  XX(CB_headers_complete, "the on_headers_complete callback failed") \
  XX(CB_body, "the on_body callback failed") \
  XX(CB_message_complete, "the on_message_complete callback failed") \
  XX(CB_status, "the on_status callback failed") \
  XX(CB_chunk_header, "the on_chunk_header callback failed") \
  XX(CB_chunk_complete, "the on_chunk_complete callback failed") \
  /* Parsing-related errors */ \
  XX(INVALID_EOF_STATE, "stream ended at an unexpected time") \
  XX(HEADER_OVERFLOW, "too many header bytes seen; overflow detected") \
  XX(CLOSED_CONNECTION, "data received after completed connection: close message") \
  XX(INVALID_VERSION, "invalid HTTP version") \
  XX(INVALID_STATUS, "invalid HTTP status code") \
  XX(INVALID_METHOD, "invalid HTTP method") \
  XX(INVALID_URL, "invalid URL") \
  XX(INVALID_HOST, "invalid host") \
  XX(INVALID_PORT, "invalid port") \
  XX(INVALID_PATH, "invalid path") \
  XX(INVALID_QUERY_STRING, "invalid query string") \
  XX(INVALID_FRAGMENT, "invalid fragment") \
  XX(LF_EXPECTED, "LF character expected") \
  XX(INVALID_HEADER_TOKEN, "invalid character in header") \
  XX(INVALID_CONTENT_LENGTH, "invalid character in content-length header") \
  XX(UNEXPECTED_CONTENT_LENGTH, "unexpected content-length header") \
  XX(INVALID_CHUNK_SIZE, "invalid character in chunk size header") \
  XX(INVALID_CONSTANT, "invalid constant string") \
  XX(INVALID_INTERNAL_STATE, "encountered unexpected internal state") \
  XX(STRICT, "strict mode assertion failed") \
  XX(PAUSED, "parser is paused") \
  XX(UNKNOWN, "an unknown error occurred")
/* Define HPE_* values for each errno value above.
 *
 * HTTP_ERRNO_GEN turns a row XX(n, s) of HTTP_ERRNO_MAP into the enumerator
 * `HPE_n,` (the description string is dropped). Enumerators take implicit
 * values in table order, so the first row (OK) becomes HPE_OK == 0. The
 * helper macro is #undef'd right after the enum, so it is not visible to
 * code that includes this header.
 */
#define HTTP_ERRNO_GEN(n, s) HPE_##n,
enum http_errno {
HTTP_ERRNO_MAP(HTTP_ERRNO_GEN)
};
#undef HTTP_ERRNO_GEN
/* Get an http_errno value from an http_parser.
 * `p` is a pointer to the parser; the macro reads its `http_errno`
 * bit-field and casts the value to `enum http_errno`.
 */
#define HTTP_PARSER_ERRNO(p) ((enum http_errno) (p)->http_errno)
/* Per-connection parser state.
 *
 * Members are grouped by intended access level: PRIVATE (parser internals),
 * READ-ONLY (results the caller may inspect) and PUBLIC (caller-owned).
 * The bit-field widths in the PRIVATE group add up to 32 bits
 * (2+8+7+7+7+1), as do those in the READ-ONLY group (16+8+7+1). Member
 * order and widths determine the struct layout, so do not reorder them.
 */
struct http_parser {
/** PRIVATE **/
unsigned int type : 2; /* enum http_parser_type */
unsigned int flags : 8; /* F_* values from 'flags' enum; semi-public */
unsigned int state : 7; /* enum state from http_parser.c */
unsigned int header_state : 7; /* enum header_state from http_parser.c */
unsigned int index : 7; /* index into current matcher */
/* NOTE(review): presumably relaxes header validation when set -- the
 * semantics live in http_parser.c; confirm there. */
unsigned int lenient_http_headers : 1;
uint32_t nread; /* # bytes read in various scenarios */
uint64_t content_length; /* # bytes in body (0 if no Content-Length header) */
/** READ-ONLY **/
unsigned short http_major;
unsigned short http_minor;
unsigned int status_code : 16; /* responses only */
unsigned int method : 8; /* requests only */
unsigned int http_errno : 7; /* read it through HTTP_PARSER_ERRNO() */
/* 1 = Upgrade header was present and the parser has exited because of that.
* 0 = No upgrade header present.
* Should be checked when http_parser_execute() returns in addition to
* error checking.
*/
unsigned int upgrade : 1;
/** PUBLIC **/
void *data; /* Caller-owned pointer, e.g. a hook to the "connection" or "socket" object */
};
/* Table of user callbacks handed to http_parser_execute().
 *
 * Members typed http_data_cb receive a (pointer, length) pair; members
 * typed http_cb receive only the parser. See the comment above the
 * http_data_cb / http_cb typedefs for the meaning of return values.
 * http_parser_settings_init() sets every member to 0.
 */
struct http_parser_settings {
http_cb on_message_begin;
http_data_cb on_url;
http_data_cb on_status;
http_data_cb on_header_field;
http_data_cb on_header_value;
http_cb on_headers_complete;
http_data_cb on_body;
http_cb on_message_complete;
/* When on_chunk_header is called, the current chunk length is stored
* in parser->content_length.
*/
http_cb on_chunk_header;
http_cb on_chunk_complete;
};
/* Indices of the URL components reported through struct http_parser_url.
 * UF_MAX is the number of components and sizes the field_data[] array.
 */
enum http_parser_url_fields
  {
    UF_SCHEMA   = 0,
    UF_HOST     = 1,
    UF_PORT     = 2,
    UF_PATH     = 3,
    UF_QUERY    = 4,
    UF_FRAGMENT = 5,
    UF_USERINFO = 6,
    UF_MAX      = 7
  };
/* Result structure for http_parser_parse_url().
 *
 * Callers should index into field_data[] with UF_* values if and only if
 * field_set has the relevant (1 << UF_*) bit set. As a courtesy to clients
 * (and because we probably have padding left over), we convert any port to
 * a uint16_t.
 *
 * Offsets and lengths are 16-bit, so each field_data[] entry can describe
 * positions and run lengths of at most 65535.
 */
struct http_parser_url {
uint16_t field_set; /* Bitmask of (1 << UF_*) values */
uint16_t port; /* Converted UF_PORT string */
struct {
uint16_t off; /* Offset into buffer in which field starts */
uint16_t len; /* Length of run in buffer */
} field_data[UF_MAX]; /* one entry per UF_* component */
};
/* Returns the library version. Bits 16-23 contain the major version number,
 * bits 8-15 the minor version number and bits 0-7 the patch level.
 * Usage example:
 *
 *   unsigned long version = http_parser_version();
 *   unsigned major = (version >> 16) & 255;
 *   unsigned minor = (version >> 8) & 255;
 *   unsigned patch = version & 255;
 *   printf("http_parser v%u.%u.%u\n", major, minor, patch);
 */
unsigned long http_parser_version(void);
/* Initializes `parser` for the given `type` (one of enum http_parser_type).
 * NOTE(review): which members are reset, and whether `data` is preserved,
 * is decided in http_parser.c -- confirm there before relying on it.
 */
void http_parser_init(http_parser *parser, enum http_parser_type type);
/* Initialize http_parser_settings members to 0.
 */
void http_parser_settings_init(http_parser_settings *settings);
/* Executes the parser over the `len` bytes starting at `data`, using the
 * callbacks in `settings`. Returns the number of parsed bytes. Sets
 * `parser->http_errno` on error (read it with HTTP_PARSER_ERRNO()); also
 * check `parser->upgrade` after this returns.
 */
size_t http_parser_execute(http_parser *parser,
const http_parser_settings *settings,
const char *data,
size_t len);
/* If http_should_keep_alive() in the on_headers_complete or
 * on_message_complete callback returns 0, then this should be
 * the last message on the connection.
 * If you are the server, respond with the "Connection: close" header.
 * If you are the client, close the connection.
 */
int http_should_keep_alive(const http_parser *parser);
/* Returns a string version of the HTTP method `m`. */
const char *http_method_str(enum http_method m);
/* Returns a string version of the HTTP status code `s`. */
const char *http_status_str(enum http_status s);
/* Returns a string name of the given error `err`. */
const char *http_errno_name(enum http_errno err);
/* Returns a string description of the given error `err`. */
const char *http_errno_description(enum http_errno err);
/* Initialize all http_parser_url members to 0. */
void http_parser_url_init(struct http_parser_url *u);
/* Parse the URL held in the `buflen` bytes at `buf`, storing the result in
 * `u`; returns nonzero on failure.
 * NOTE(review): `is_connect` is presumably nonzero when the URL is the
 * target of a CONNECT request -- confirm against http_parser.c.
 */
int http_parser_parse_url(const char *buf, size_t buflen,
int is_connect,
struct http_parser_url *u);
/* Pause or un-pause the parser; a nonzero `paused` value pauses. */
void http_parser_pause(http_parser *parser, int paused);
/* Checks if this is the final chunk of the body. */
int http_body_is_final(const http_parser *parser);
#ifdef __cplusplus
}
#endif
#endif

File diff suppressed because it is too large Load Diff