diff --git a/aiosched/__init__.py b/aiosched/__init__.py index b32f1a3..8ec0d0e 100644 --- a/aiosched/__init__.py +++ b/aiosched/__init__.py @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. """ from aiosched.runtime import run, get_event_loop, new_event_loop, clock, with_context -from aiosched.internals.syscalls import spawn, wait, sleep, cancel +from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint import aiosched.task import aiosched.errors import aiosched.context @@ -36,5 +36,6 @@ __all__ = [ "Event", "Queue", "Channel", - "MemoryChannel" + "MemoryChannel", + "checkpoint" ] diff --git a/aiosched/internals/syscalls.py b/aiosched/internals/syscalls.py index 356aeeb..98e45a2 100644 --- a/aiosched/internals/syscalls.py +++ b/aiosched/internals/syscalls.py @@ -51,6 +51,7 @@ async def schedule(task: Task): await syscall("schedule", task) + async def spawn(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> Task: """ Spawns a task from a coroutine and returns it. The coroutine @@ -208,6 +209,15 @@ async def wait_writable(stream): await syscall("perform_io", stream, EVENT_WRITE) +async def io_release(stream): + """ + Signals to the event loop to + release a given I/O resource + """ + + await syscall("io_release", stream) + + async def set_context(ctx): """ Sets the current task context diff --git a/aiosched/io.py b/aiosched/io.py new file mode 100644 index 0000000..aa976df --- /dev/null +++ b/aiosched/io.py @@ -0,0 +1,399 @@ +""" +aiosched: Yet another Python async scheduler + +Copyright (C) 2022 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https:www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import warnings +import os +import aiosched +from aiosched.errors import ResourceClosed +from aiosched.internals.syscalls import wait_writable, wait_readable, io_release, closing + + +try: + from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket + + WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) + WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) +except ImportError: + WantRead = (BlockingIOError, InterruptedError) + WantWrite = (BlockingIOError, InterruptedError) + + +class AsyncStream: + """ + A generic asynchronous stream over + a file descriptor. Only works on Linux + & co because windows doesn't like select() + to be called on non-socket objects + (Thanks, Microsoft) + """ + + def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs): + self._fd = fd + self.stream = None + if open_fd: + self.stream = os.fdopen(self._fd, **kwargs) + os.set_blocking(self._fd, False) + self.close_on_context_exit = close_on_context_exit + + async def read(self, size: int = -1): + """ + Reads up to size bytes from the + given stream. If size == -1, read + until EOF is reached + """ + + while True: + try: + return self.stream.read(size) + except WantRead: + await wait_readable(self.stream) + + async def write(self, data): + """ + Writes data b to the file. + Returns the number of bytes + written + """ + + while True: + try: + return self.stream.write(data) + except WantWrite: + await wait_writable(self.stream) + + async def close(self): + """ + Closes the stream asynchronously + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed stream") + self._fd = -1 + await closing(self.stream) + await io_release(self.stream) + self.stream.close() + self.stream = None + + @property + async def fileno(self): + """ + Wrapper socket method + """ + + return self._fd + + async def __aenter__(self): + self.stream.__enter__() + return self + + async def __aexit__(self, *args): + if self._fd != -1 and self.close_on_context_exit: + await self.close() + + async def dup(self): + """ + Wrapper stream method + """ + + return type(self)(os.dup(self._fd)) + + def __repr__(self): + return f"AsyncStream({self.stream})" + + def __del__(self): + """ + Stream destructor. Do *not* call + this directly: stuff will break + """ + + if self._fd != -1: + try: + os.set_blocking(self._fd, False) + os.close(self._fd) + except OSError as e: + warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}") + + +class AsyncSocket(AsyncStream): + """ + Abstraction layer for asynchronous sockets + """ + + def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True): + super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit) + self.do_handshake_on_connect = do_handshake_on_connect + self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto) + self.stream.setblocking(False) + # A socket that isn't connected doesn't + # need to be closed + self.needs_closing: bool = False + + async def receive(self, max_size: int, flags: int = 0) -> bytes: + """ + Receives up to max_size bytes from a socket asynchronously + """ + + assert max_size >= 1, "max_size must be >= 1" + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while True: + try: + return self.stream.recv(max_size, flags) + except WantRead: + await wait_readable(self.stream) + except WantWrite: + await wait_writable(self.stream) + + async def connect(self, address): + """ + Wrapper socket method + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while True: + try: + self.stream.connect(address) + if self.do_handshake_on_connect: + await self.do_handshake() + break + except WantWrite: + await wait_writable(self.stream) + self.needs_closing = True + + async def close(self): + """ + Wrapper socket method + """ + + if self.needs_closing: + await super().close() + + async def accept(self): + """ + Accepts the socket, completing the 3-step TCP handshake asynchronously + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while True: + try: + remote, addr = self.stream.accept() + return type(self)(remote), addr + except WantRead: + await wait_readable(self.stream) + + async def send_all(self, data: bytes, flags: int = 0): + """ + Sends all data inside the buffer asynchronously until it is empty + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + sent_no = 0 + while data: + try: + sent_no = self.stream.send(data, flags) + except WantRead: + await wait_readable(self.stream) + except WantWrite: + await wait_writable(self.stream) + data = data[sent_no:] + + async def shutdown(self, how): + """ + Wrapper socket method + """ + + if self.stream: + self.stream.shutdown(how) + await aiosched.checkpoint() + + async def bind(self, addr: tuple): + """ + Binds the socket to an address + :param addr: The address, port tuple to bind to + :type addr: tuple + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + self.stream.bind(addr) + + async def listen(self, backlog: int): + """ + Starts listening with the given backlog + :param backlog: The socket's backlog + :type backlog: int + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + self.stream.listen(backlog) + + # Yes, I stole these from Curio because I could not be + # arsed to write a bunch of uninteresting simple socket + # methods from scratch, deal with it. + + async def settimeout(self, seconds): + """ + Wrapper socket method + """ + + raise RuntimeError("Use with_timeout() to set a timeout") + + async def gettimeout(self): + """ + Wrapper socket method + """ + + return None + + async def dup(self): + """ + Wrapper socket method + """ + + return type(self)(self.stream.dup(), self.do_handshake_on_connect) + + async def do_handshake(self): + """ + Wrapper socket method + """ + + if not hasattr(self.stream, "do_handshake"): + return + while True: + try: + self.stream: SSLSocket # Silences pycharm warnings + return self.stream.do_handshake() + except WantRead: + await wait_readable(self.stream) + except WantWrite: + await wait_writable(self.stream) + + async def recvfrom(self, buffersize, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.recvfrom(buffersize, flags) + except WantRead: + await wait_readable(self.stream) + except WantWrite: + await wait_writable(self.stream) + + async def recvfrom_into(self, buffer, bytes=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.recvfrom_into(buffer, bytes, flags) + except WantRead: + await wait_readable(self.stream) + except WantWrite: + await wait_writable(self.stream) + + async def sendto(self, bytes, flags_or_address, address=None): + """ + Wrapper socket method + """ + + if address: + flags = flags_or_address + else: + address = flags_or_address + flags = 0 + while True: + try: + return self.stream.sendto(bytes, flags, address) + except WantWrite: + await wait_writable(self.stream) + except WantRead: + await wait_readable(self.stream) + + async def getpeername(self): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.getpeername() + except WantWrite: + await wait_writable(self.stream) + except WantRead: + await wait_readable(self.stream) + + async def getsockname(self): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.getpeername() + except WantWrite: + await wait_writable(self.stream) + except WantRead: + await wait_readable(self.stream) + + async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.recvmsg(bufsize, ancbufsize, flags) + except WantRead: + await wait_readable(self.stream) + + async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.recvmsg_into(buffers, ancbufsize, flags) + except WantRead: + await wait_readable(self.stream) + + async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.sendmsg(buffers, ancdata, flags, address) + except WantRead: + await wait_writable(self.stream) + + def __repr__(self): + return f"AsyncSocket({self.stream})" + + def __del__(self): + if self.needs_closing: + super().__del__() diff --git a/aiosched/socket.py b/aiosched/socket.py new file mode 100644 index 0000000..721ee26 --- /dev/null +++ b/aiosched/socket.py @@ -0,0 +1,37 @@ +""" +aiosched: Yet another Python async scheduler + +Copyright (C) 2022 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https:www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import socket as _socket +from aiosched.io import AsyncSocket + + +def wrap_socket(sock: _socket.socket) -> AsyncSocket: + """ + Wraps a standard socket into an async socket + """ + + return AsyncSocket(sock) + + +def socket(*args, **kwargs): + """ + Creates a new giambio socket, taking in the same positional and + keyword arguments as the standard library's socket.socket + constructor + """ + + return wrap_socket(_socket.socket(*args, **kwargs)) \ No newline at end of file