Ported network primitives from giambio
This commit is contained in:
parent
4a974ab06d
commit
3b81702c2b
|
@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
|
from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context
|
||||||
from aiosched.internals.syscalls import spawn, wait, sleep, cancel
|
from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
|
||||||
import aiosched.task
|
import aiosched.task
|
||||||
import aiosched.errors
|
import aiosched.errors
|
||||||
import aiosched.context
|
import aiosched.context
|
||||||
|
@ -36,5 +36,6 @@ __all__ = [
|
||||||
"Event",
|
"Event",
|
||||||
"Queue",
|
"Queue",
|
||||||
"Channel",
|
"Channel",
|
||||||
"MemoryChannel"
|
"MemoryChannel",
|
||||||
|
"checkpoint"
|
||||||
]
|
]
|
||||||
|
|
|
@ -51,6 +51,7 @@ async def schedule(task: Task):
|
||||||
|
|
||||||
await syscall("schedule", task)
|
await syscall("schedule", task)
|
||||||
|
|
||||||
|
|
||||||
async def spawn(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> Task:
|
async def spawn(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> Task:
|
||||||
"""
|
"""
|
||||||
Spawns a task from a coroutine and returns it. The coroutine
|
Spawns a task from a coroutine and returns it. The coroutine
|
||||||
|
@ -208,6 +209,15 @@ async def wait_writable(stream):
|
||||||
await syscall("perform_io", stream, EVENT_WRITE)
|
await syscall("perform_io", stream, EVENT_WRITE)
|
||||||
|
|
||||||
|
|
||||||
|
async def io_release(stream):
|
||||||
|
"""
|
||||||
|
Signals to the event loop to
|
||||||
|
release a given I/O resource
|
||||||
|
"""
|
||||||
|
|
||||||
|
await syscall("io_release", stream)
|
||||||
|
|
||||||
|
|
||||||
async def set_context(ctx):
|
async def set_context(ctx):
|
||||||
"""
|
"""
|
||||||
Sets the current task context
|
Sets the current task context
|
||||||
|
|
|
@ -0,0 +1,399 @@
|
||||||
|
"""
|
||||||
|
aiosched: Yet another Python async scheduler
|
||||||
|
|
||||||
|
Copyright (C) 2022 nocturn9x
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https:www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import warnings
|
||||||
|
import os
|
||||||
|
import aiosched
|
||||||
|
from aiosched.errors import ResourceClosed
|
||||||
|
from aiosched.internals.syscalls import wait_writable, wait_readable, io_release, closing
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
|
||||||
|
|
||||||
|
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
|
||||||
|
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
|
||||||
|
except ImportError:
|
||||||
|
WantRead = (BlockingIOError, InterruptedError)
|
||||||
|
WantWrite = (BlockingIOError, InterruptedError)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncStream:
|
||||||
|
"""
|
||||||
|
A generic asynchronous stream over
|
||||||
|
a file descriptor. Only works on Linux
|
||||||
|
& co because windows doesn't like select()
|
||||||
|
to be called on non-socket objects
|
||||||
|
(Thanks, Microsoft)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
|
||||||
|
self._fd = fd
|
||||||
|
self.stream = None
|
||||||
|
if open_fd:
|
||||||
|
self.stream = os.fdopen(self._fd, **kwargs)
|
||||||
|
os.set_blocking(self._fd, False)
|
||||||
|
self.close_on_context_exit = close_on_context_exit
|
||||||
|
|
||||||
|
async def read(self, size: int = -1):
|
||||||
|
"""
|
||||||
|
Reads up to size bytes from the
|
||||||
|
given stream. If size == -1, read
|
||||||
|
until EOF is reached
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.read(size)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def write(self, data):
|
||||||
|
"""
|
||||||
|
Writes data b to the file.
|
||||||
|
Returns the number of bytes
|
||||||
|
written
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.write(data)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
Closes the stream asynchronously
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed stream")
|
||||||
|
self._fd = -1
|
||||||
|
await closing(self.stream)
|
||||||
|
await io_release(self.stream)
|
||||||
|
self.stream.close()
|
||||||
|
self.stream = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def fileno(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._fd
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.stream.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *args):
|
||||||
|
if self._fd != -1 and self.close_on_context_exit:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def dup(self):
|
||||||
|
"""
|
||||||
|
Wrapper stream method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return type(self)(os.dup(self._fd))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"AsyncStream({self.stream})"
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
Stream destructor. Do *not* call
|
||||||
|
this directly: stuff will break
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd != -1:
|
||||||
|
try:
|
||||||
|
os.set_blocking(self._fd, False)
|
||||||
|
os.close(self._fd)
|
||||||
|
except OSError as e:
|
||||||
|
warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncSocket(AsyncStream):
|
||||||
|
"""
|
||||||
|
Abstraction layer for asynchronous sockets
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True):
|
||||||
|
super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
|
||||||
|
self.do_handshake_on_connect = do_handshake_on_connect
|
||||||
|
self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto)
|
||||||
|
self.stream.setblocking(False)
|
||||||
|
# A socket that isn't connected doesn't
|
||||||
|
# need to be closed
|
||||||
|
self.needs_closing: bool = False
|
||||||
|
|
||||||
|
async def receive(self, max_size: int, flags: int = 0) -> bytes:
|
||||||
|
"""
|
||||||
|
Receives up to max_size bytes from a socket asynchronously
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert max_size >= 1, "max_size must be >= 1"
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.recv(max_size, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
async def connect(self, address):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.stream.connect(address)
|
||||||
|
if self.do_handshake_on_connect:
|
||||||
|
await self.do_handshake()
|
||||||
|
break
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
self.needs_closing = True
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.needs_closing:
|
||||||
|
await super().close()
|
||||||
|
|
||||||
|
async def accept(self):
|
||||||
|
"""
|
||||||
|
Accepts the socket, completing the 3-step TCP handshake asynchronously
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
remote, addr = self.stream.accept()
|
||||||
|
return type(self)(remote), addr
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def send_all(self, data: bytes, flags: int = 0):
|
||||||
|
"""
|
||||||
|
Sends all data inside the buffer asynchronously until it is empty
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
sent_no = 0
|
||||||
|
while data:
|
||||||
|
try:
|
||||||
|
sent_no = self.stream.send(data, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
data = data[sent_no:]
|
||||||
|
|
||||||
|
async def shutdown(self, how):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
self.stream.shutdown(how)
|
||||||
|
await aiosched.checkpoint()
|
||||||
|
|
||||||
|
async def bind(self, addr: tuple):
|
||||||
|
"""
|
||||||
|
Binds the socket to an address
|
||||||
|
:param addr: The address, port tuple to bind to
|
||||||
|
:type addr: tuple
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
self.stream.bind(addr)
|
||||||
|
|
||||||
|
async def listen(self, backlog: int):
|
||||||
|
"""
|
||||||
|
Starts listening with the given backlog
|
||||||
|
:param backlog: The socket's backlog
|
||||||
|
:type backlog: int
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._fd == -1:
|
||||||
|
raise ResourceClosed("I/O operation on closed socket")
|
||||||
|
self.stream.listen(backlog)
|
||||||
|
|
||||||
|
# Yes, I stole these from Curio because I could not be
|
||||||
|
# arsed to write a bunch of uninteresting simple socket
|
||||||
|
# methods from scratch, deal with it.
|
||||||
|
|
||||||
|
async def settimeout(self, seconds):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise RuntimeError("Use with_timeout() to set a timeout")
|
||||||
|
|
||||||
|
async def gettimeout(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def dup(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
return type(self)(self.stream.dup(), self.do_handshake_on_connect)
|
||||||
|
|
||||||
|
async def do_handshake(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not hasattr(self.stream, "do_handshake"):
|
||||||
|
return
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.stream: SSLSocket # Silences pycharm warnings
|
||||||
|
return self.stream.do_handshake()
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
async def recvfrom(self, buffersize, flags=0):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.recvfrom(buffersize, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.recvfrom_into(buffer, bytes, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
async def sendto(self, bytes, flags_or_address, address=None):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
if address:
|
||||||
|
flags = flags_or_address
|
||||||
|
else:
|
||||||
|
address = flags_or_address
|
||||||
|
flags = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.sendto(bytes, flags, address)
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def getpeername(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.getpeername()
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def getsockname(self):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.getpeername()
|
||||||
|
except WantWrite:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.recvmsg(bufsize, ancbufsize, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.recvmsg_into(buffers, ancbufsize, flags)
|
||||||
|
except WantRead:
|
||||||
|
await wait_readable(self.stream)
|
||||||
|
|
||||||
|
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||||
|
"""
|
||||||
|
Wrapper socket method
|
||||||
|
"""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return self.stream.sendmsg(buffers, ancdata, flags, address)
|
||||||
|
except WantRead:
|
||||||
|
await wait_writable(self.stream)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"AsyncSocket({self.stream})"
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.needs_closing:
|
||||||
|
super().__del__()
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""
|
||||||
|
aiosched: Yet another Python async scheduler
|
||||||
|
|
||||||
|
Copyright (C) 2022 nocturn9x
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https:www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
import socket as _socket
|
||||||
|
from aiosched.io import AsyncSocket
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_socket(sock: _socket.socket) -> AsyncSocket:
|
||||||
|
"""
|
||||||
|
Wraps a standard socket into an async socket
|
||||||
|
"""
|
||||||
|
|
||||||
|
return AsyncSocket(sock)
|
||||||
|
|
||||||
|
|
||||||
|
def socket(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a new giambio socket, taking in the same positional and
|
||||||
|
keyword arguments as the standard library's socket.socket
|
||||||
|
constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
return wrap_socket(_socket.socket(*args, **kwargs))
|
Reference in New Issue