From e37ffdeb062e873b705e92dc0f521158df85adb2 Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Mon, 10 Oct 2022 13:35:22 +0200 Subject: [PATCH] Initial broken work on a generic streams interface --- README.md | 4 + giambio/core.py | 28 ++++- giambio/exceptions.py | 2 +- giambio/io.py | 246 ++++++++++++++++++++++++++------------- giambio/traps.py | 13 +++ tests/chatroom_client.py | 31 ++--- tests/chatroom_server.py | 45 ++++--- tests/echo_server.py | 2 +- 8 files changed, 252 insertions(+), 119 deletions(-) diff --git a/README.md b/README.md index db44bfc..1cb08ab 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,10 @@ rock-solid and structured concurrency framework (I personally recommend trio and that most of the content of this document is ~~stolen~~ inspired from its documentation) +# Disclaimer #2 + +This is a toy project. Don't try to use it in production, it *will* explode + ## Goals of this project Making yet another async library might sound dumb in an already fragmented ecosystem like Python's. diff --git a/giambio/core.py b/giambio/core.py index 871bb83..9c1f004 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -15,7 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ - +import functools # Import libraries and internal resources from numbers import Number from giambio.task import Task @@ -33,6 +33,7 @@ from giambio.exceptions import ( ResourceBusy, GiambioError, TooSlowError, + ResourceClosed ) @@ -433,12 +434,15 @@ class AsyncScheduler: task.result = ret.value task.finished = True self.join(task) - self.tasks.remove(task) + except CancelledError as cancel: + task.status = "cancelled" + task.cancel_pending = False + task.cancelled = True + self.join(task) except BaseException as err: task.exc = err self.join(task) - if task in self.tasks: - self.tasks.remove(task) + def prune_deadlines(self): """ @@ -666,6 +670,8 @@ class AsyncScheduler: self.io_release_task(task) if task in self.suspended: self.suspended.remove(task) + if task in self.tasks: + self.tasks.remove(task) # If the pool (including any enclosing pools) has finished executing # or we're at the first task that kicked the loop, we can safely # reschedule the parent(s) @@ -770,13 +776,25 @@ class AsyncScheduler: task.cancelled = True task.status = "cancelled" self.debugger.after_cancel(task) - self.tasks.remove(task) self.join(task) else: # If the task ignores our exception, we'll # raise it later again task.cancel_pending = True + def notify_closing(self, stream): + """ + Implements the notify_closing trap + """ + + if self.selector.get_map(): + for k in filter( + lambda o: o.data == self.current_task, + dict(self.selector.get_map()).values(), + ): + self.handle_task_exit(k.data, + functools.partial(k.data.throw(ResourceClosed("stream has been closed")))) + def register_sock(self, sock, evt_type: str): """ Registers the given socket inside the diff --git a/giambio/exceptions.py b/giambio/exceptions.py index dc45242..15e9852 100644 --- a/giambio/exceptions.py +++ b/giambio/exceptions.py @@ -37,7 +37,7 @@ class InternalError(GiambioError): ... -class CancelledError(GiambioError): +class CancelledError(BaseException): """ Exception raised by the giambio.objects.Task.cancel() method to terminate a child task. This should NOT be caught, or diff --git a/giambio/io.py b/giambio/io.py index 1d060f4..cd3013e 100644 --- a/giambio/io.py +++ b/giambio/io.py @@ -15,14 +15,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +import socket import warnings - +import os import giambio from giambio.exceptions import ResourceClosed -from giambio.traps import want_write, want_read, io_release +from giambio.traps import want_write, want_read, io_release, notify_closing + try: - from ssl import SSLWantReadError, SSLWantWriteError + from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) @@ -31,16 +33,115 @@ except ImportError: WantWrite = (BlockingIOError, InterruptedError) -class AsyncSocket: +class AsyncStream: + """ + A generic asynchronous stream over + a file descriptor. Only works on Linux + & co because windows doesn't like select() + to be called on non-socket objects + (Thanks, Microsoft) + """ + + def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs): + self._fd = fd + self.stream = None + if open_fd: + self.stream = os.fdopen(self._fd, **kwargs) + os.set_blocking(self._fd, False) + self.close_on_context_exit = close_on_context_exit + + async def read(self, size: int = -1): + """ + Reads up to size bytes from the + given stream. If size == -1, read + until EOF is reached + """ + + while True: + try: + return self.stream.read(size) + except WantRead: + await want_read(self.stream) + + async def write(self, data): + """ + Writes data b to the file. + Returns the number of bytes + written + """ + + while True: + try: + return self.stream.write(data) + except WantWrite: + await want_write(self.stream) + + async def close(self): + """ + Closes the stream asynchronously + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed stream") + self._fd = -1 + await notify_closing(self.stream) + await io_release(self.stream) + self.stream.close() + self.stream = None + + @property + async def fileno(self): + """ + Wrapper socket method + """ + + return self._fd + + async def __aenter__(self): + self.stream.__enter__() + return self + + async def __aexit__(self, *args): + if self._fd != -1 and self.close_on_context_exit: + await self.close() + + async def dup(self): + """ + Wrapper stream method + """ + + return type(self)(os.dup(self._fd)) + + def __repr__(self): + return f"AsyncStream({self.stream})" + + def __del__(self): + """ + Stream destructor. Do *not* call + this directly: stuff will break + """ + + if self._fd != -1: + try: + os.set_blocking(self._fd, False) + os.close(self._fd) + except OSError as e: + warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}") + + +class AsyncSocket(AsyncStream): """ Abstraction layer for asynchronous sockets """ - def __init__(self, sock, do_handshake_on_connect: bool = True): - self.sock = sock + def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True): + super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit) self.do_handshake_on_connect = do_handshake_on_connect - self._fd = sock.fileno() - self.sock.setblocking(False) + self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto) + self.stream.setblocking(False) + # A socket that isn't connected doesn't + # need to be closed + self.needs_closing: bool = False async def receive(self, max_size: int, flags: int = 0) -> bytes: """ @@ -52,11 +153,11 @@ class AsyncSocket: raise ResourceClosed("I/O operation on closed socket") while True: try: - return self.sock.recv(max_size, flags) + return self.stream.recv(max_size, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) async def connect(self, address): """ @@ -67,12 +168,21 @@ class AsyncSocket: raise ResourceClosed("I/O operation on closed socket") while True: try: - self.sock.connect(address) + self.stream.connect(address) if self.do_handshake_on_connect: await self.do_handshake() - return + break except WantWrite: - await want_write(self.sock) + await want_write(self.stream) + self.needs_closing = True + + async def close(self): + """ + Wrapper socket method + """ + + if self.needs_closing: + await super().close() async def accept(self): """ @@ -83,10 +193,10 @@ class AsyncSocket: raise ResourceClosed("I/O operation on closed socket") while True: try: - remote, addr = self.sock.accept() + remote, addr = self.stream.accept() return type(self)(remote), addr except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def send_all(self, data: bytes, flags: int = 0): """ @@ -98,32 +208,20 @@ class AsyncSocket: sent_no = 0 while data: try: - sent_no = self.sock.send(data, flags) + sent_no = self.stream.send(data, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) data = data[sent_no:] - async def close(self): - """ - Closes the socket asynchronously - """ - - if self._fd == -1: - raise ResourceClosed("I/O operation on closed socket") - await io_release(self.sock) - self.sock.close() - self._fd = -1 - self.sock = None - async def shutdown(self, how): """ Wrapper socket method """ - if self.sock: - self.sock.shutdown(how) + if self.stream: + self.stream.shutdown(how) await giambio.sleep(0) # Checkpoint async def bind(self, addr: tuple): @@ -136,7 +234,7 @@ class AsyncSocket: if self._fd == -1: raise ResourceClosed("I/O operation on closed socket") - self.sock.bind(addr) + self.stream.bind(addr) async def listen(self, backlog: int): """ @@ -148,27 +246,12 @@ class AsyncSocket: if self._fd == -1: raise ResourceClosed("I/O operation on closed socket") - self.sock.listen(backlog) - - async def __aenter__(self): - self.sock.__enter__() - return self - - async def __aexit__(self, *args): - if self.sock: - self.sock.__exit__(*args) + self.stream.listen(backlog) # Yes, I stole these from Curio because I could not be # arsed to write a bunch of uninteresting simple socket # methods from scratch, deal with it. - async def fileno(self): - """ - Wrapper socket method - """ - - return self._fd - async def settimeout(self, seconds): """ Wrapper socket method @@ -188,22 +271,23 @@ class AsyncSocket: Wrapper socket method """ - return type(self)(self.sock.dup()) + return type(self)(self.stream.dup(), self.do_handshake_on_connect) async def do_handshake(self): """ Wrapper socket method """ - if not hasattr(self.sock, "do_handshake"): + if not hasattr(self.stream, "do_handshake"): return while True: try: - return self.sock.do_handshake() + self.stream: SSLSocket # Silences pycharm warnings + return self.stream.do_handshake() except WantRead: - await want_read(self.sock) + await want_read(self.stream) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) async def recvfrom(self, buffersize, flags=0): """ @@ -212,11 +296,11 @@ class AsyncSocket: while True: try: - return self.sock.recvfrom(buffersize, flags) + return self.stream.recvfrom(buffersize, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) async def recvfrom_into(self, buffer, bytes=0, flags=0): """ @@ -225,11 +309,11 @@ class AsyncSocket: while True: try: - return self.sock.recvfrom_into(buffer, bytes, flags) + return self.stream.recvfrom_into(buffer, bytes, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) async def sendto(self, bytes, flags_or_address, address=None): """ @@ -243,11 +327,11 @@ class AsyncSocket: flags = 0 while True: try: - return self.sock.sendto(bytes, flags, address) + return self.stream.sendto(bytes, flags, address) except WantWrite: - await want_write(self.sock) + await want_write(self.stream) except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def getpeername(self): """ @@ -256,11 +340,11 @@ class AsyncSocket: while True: try: - return self.sock.getpeername() + return self.stream.getpeername() except WantWrite: - await want_write(self.sock) + await want_write(self.stream) except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def getsockname(self): """ @@ -269,11 +353,11 @@ class AsyncSocket: while True: try: - return self.sock.getpeername() + return self.stream.getpeername() except WantWrite: - await want_write(self.sock) + await want_write(self.stream) except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def recvmsg(self, bufsize, ancbufsize=0, flags=0): """ @@ -282,9 +366,9 @@ class AsyncSocket: while True: try: - return self.sock.recvmsg(bufsize, ancbufsize, flags) + return self.stream.recvmsg(bufsize, ancbufsize, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): """ @@ -293,9 +377,9 @@ class AsyncSocket: while True: try: - return self.sock.recvmsg_into(buffers, ancbufsize, flags) + return self.stream.recvmsg_into(buffers, ancbufsize, flags) except WantRead: - await want_read(self.sock) + await want_read(self.stream) async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): """ @@ -304,17 +388,13 @@ class AsyncSocket: while True: try: - return self.sock.sendmsg(buffers, ancdata, flags, address) + return self.stream.sendmsg(buffers, ancdata, flags, address) except WantRead: - await want_write(self.sock) + await want_write(self.stream) def __repr__(self): - return f"AsyncSocket({self.sock})" + return f"AsyncSocket({self.stream})" def __del__(self): - """ - Socket destructor - """ - - if not self._fd != -1: - warnings.warn(f"socket '{self}' was destroyed, but was not closed, leading to a potential resource leak") + if self.needs_closing: + super().__del__() diff --git a/giambio/traps.py b/giambio/traps.py index e9cbea4..24a21ac 100644 --- a/giambio/traps.py +++ b/giambio/traps.py @@ -178,6 +178,19 @@ async def want_write(stream): await create_trap("register_sock", stream, "write") +async def notify_closing(stream): + """ + Notifies the event loop that a given + stream needs to be closed. This makes + all callers waiting on want_read or + want_write crash with a ResourceClosed + exception, but it doesn't actually close + the socket object itself + """ + + await create_trap("notify_closing", stream) + + async def schedule_tasks(tasks: Iterable[Task]): """ Schedules a list of tasks for execution. Usuaully diff --git a/tests/chatroom_client.py b/tests/chatroom_client.py index e6c0a9b..ed23b64 100644 --- a/tests/chatroom_client.py +++ b/tests/chatroom_client.py @@ -1,44 +1,49 @@ import sys -from typing import Tuple import giambio import logging +from debugger import Debugger -async def sender(sock: giambio.socket.AsyncSocket, q: giambio.Queue): + +async def reader(q: giambio.Queue, prompt: str = ""): + in_stream = giambio.io.AsyncStream(sys.stdin.fileno(), close_on_context_exit=False, mode="r") + out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w") while True: - await sock.send_all(b"yo") - await q.put((0, "")) - await giambio.sleep(1) + await out_stream.write(prompt) + await q.put((0, await in_stream.read())) async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue): data = b"" while True: while not data.endswith(b"\n"): - data += await sock.receive(1024) + temp = await sock.receive(1024) + if not temp: + raise EOFError("end of file") + data += temp data, rest = data.split(b"\n", maxsplit=2) buffer = b"".join(rest) await q.put((1, data.decode())) data = buffer -async def main(host: Tuple[str, int]): +async def main(host: tuple[str, int]): """ Main client entry point """ queue = giambio.Queue() + out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w") async with giambio.create_pool() as pool: async with giambio.socket.socket() as sock: await sock.connect(host) - await pool.spawn(sender, sock, queue) + await out_stream.write("Connection successful\n") await pool.spawn(receiver, sock, queue) + await pool.spawn(reader, queue, "> ") while True: op, data = await queue.get() - if op == 0: - print(f"Sent.") - else: - print(f"Received: {data}") + if op == 1: + await out_stream.write(data) if __name__ == "__main__": @@ -49,7 +54,7 @@ if __name__ == "__main__": datefmt="%d/%m/%Y %p", ) try: - giambio.run(main, ("localhost", port)) + giambio.run(main, ("localhost", port), debugger=Debugger()) except (Exception, KeyboardInterrupt) as error: # Exceptions propagate! if isinstance(error, KeyboardInterrupt): logging.info("Ctrl+C detected, exiting") diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index e46d07a..865de5e 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -1,4 +1,3 @@ -from typing import List import giambio from giambio.socket import AsyncSocket import logging @@ -6,7 +5,8 @@ import sys # An asynchronous chatroom -clients: List[giambio.socket.AsyncSocket] = [] +clients: dict[AsyncSocket, list[str, str]] = {} +names: set[str] = set() async def serve(bind_address: tuple): @@ -26,39 +26,52 @@ async def serve(bind_address: tuple): while True: try: conn, address_tuple = await sock.accept() - clients.append(conn) + clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"] logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - await pool.spawn(handler, conn, address_tuple) + await pool.spawn(handler, conn) except Exception as err: # Because exceptions just *work* logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") -async def handler(sock: AsyncSocket, client_address: tuple): +async def handler(sock: AsyncSocket): """ Handles a single client connection :param sock: The AsyncSocket object connected to the client - :param client_address: The client's address represented as a tuple - (address, port) where address is a string and port is an integer - :type client_address: tuple """ - address = f"{client_address[0]}:{client_address[1]}" + address = clients[sock][1] + name = "" async with sock: # Closes the socket automatically - await sock.send_all(b"Welcome to the chatroom pal, start typing and press enter!\n") + await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ") while True: + while not name.endswith("\n"): + name = (await sock.receive(64)).decode() + name = name[:-1] + if name not in names: + names.add(name) + clients[sock][0] = name + break + else: + await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ") + await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode()) + logging.info(f"{name} has joined the chatroom ({address}), informing clients") + for i, client_sock in enumerate(clients): + if client_sock != sock and clients[client_sock][0]: + await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode()) + while True: + await sock.send_all(b"> ") data = await sock.receive(1024) if not data: break - elif data == b"exit\n": - await sock.send_all(b"I'm dead dude\n") - raise TypeError("Oh, no, I'm gonna die!") logging.info(f"Got: {data!r} from {address}") for i, client_sock in enumerate(clients): - logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}") - if client_sock != sock: - await client_sock.send_all(data) + if client_sock != sock and clients[client_sock][0]: + logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}") + if not data.endswith(b"\n"): + data += b"\n" + await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode()) logging.info(f"Sent {data!r} to {i} clients") logging.info(f"Connection from {address} closed") clients.remove(sock) diff --git a/tests/echo_server.py b/tests/echo_server.py index 761c1f1..d428ea8 100644 --- a/tests/echo_server.py +++ b/tests/echo_server.py @@ -63,7 +63,7 @@ if __name__ == "__main__": logging.basicConfig( level=20, format="[%(levelname)s] %(asctime)s %(message)s", - datefmt="%d/%m/%Y %p", + datefmt="%d/%m/%Y %H:%M:%S %p", ) try: giambio.run(serve, ("localhost", port), debugger=())