From 8f3d7056b7152dbcd0f33ffffcdf03e9045c46c8 Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Wed, 19 Oct 2022 12:22:02 +0200 Subject: [PATCH] Ported NetworkChannel and related test from giambio --- aiosched/__init__.py | 7 ++- aiosched/io.py | 61 ++++++++++--------- aiosched/sync.py | 126 +++++++++++++++++++++++++++++++++++++++ tests/network_channel.py | 31 ++++++++++ 4 files changed, 192 insertions(+), 33 deletions(-) create mode 100644 tests/network_channel.py diff --git a/aiosched/__init__.py b/aiosched/__init__.py index 8ec0d0e..5b7c833 100644 --- a/aiosched/__init__.py +++ b/aiosched/__init__.py @@ -20,7 +20,8 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint import aiosched.task import aiosched.errors import aiosched.context -from aiosched.sync import Event, Queue, Channel, MemoryChannel +import aiosched.socket +from aiosched.sync import Event, Queue, Channel, MemoryChannel, NetworkChannel __all__ = [ "run", @@ -37,5 +38,7 @@ __all__ = [ "Queue", "Channel", "MemoryChannel", - "checkpoint" + "checkpoint", + "NetworkChannel", + "socket" ] diff --git a/aiosched/io.py b/aiosched/io.py index aa976df..6b98c20 100644 --- a/aiosched/io.py +++ b/aiosched/io.py @@ -17,6 +17,7 @@ limitations under the License. """ import socket +import ssl import warnings import os import aiosched @@ -27,20 +28,18 @@ from aiosched.internals.syscalls import wait_writable, wait_readable, io_release try: from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket - WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) - WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) + ReadBlock = (BlockingIOError, InterruptedError, SSLWantReadError) + WriteBlock = (BlockingIOError, InterruptedError, SSLWantWriteError) except ImportError: - WantRead = (BlockingIOError, InterruptedError) - WantWrite = (BlockingIOError, InterruptedError) + ReadBlock = (BlockingIOError, InterruptedError) + WriteBlock = (BlockingIOError, InterruptedError) class AsyncStream: """ A generic asynchronous stream over - a file descriptor. Only works on Linux - & co because windows doesn't like select() - to be called on non-socket objects - (Thanks, Microsoft) + a file descriptor. Functionality + is OS-dependent """ def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs): @@ -61,7 +60,7 @@ class AsyncStream: while True: try: return self.stream.read(size) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def write(self, data): @@ -74,7 +73,7 @@ class AsyncStream: while True: try: return self.stream.write(data) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) async def close(self): @@ -155,9 +154,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.recv(max_size, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) async def connect(self, address): @@ -173,7 +172,7 @@ class AsyncSocket(AsyncStream): if self.do_handshake_on_connect: await self.do_handshake() break - except WantWrite: + except WriteBlock: await wait_writable(self.stream) self.needs_closing = True @@ -196,7 +195,7 @@ class AsyncSocket(AsyncStream): try: remote, addr = self.stream.accept() return type(self)(remote), addr - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def send_all(self, data: bytes, flags: int = 0): @@ -210,9 +209,9 @@ class AsyncSocket(AsyncStream): while data: try: sent_no = self.stream.send(data, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) data = data[sent_no:] @@ -283,9 +282,9 @@ class AsyncSocket(AsyncStream): try: self.stream: SSLSocket # Silences pycharm warnings return self.stream.do_handshake() - except WantRead: + except ReadBlock: await wait_readable(self.stream) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) async def recvfrom(self, buffersize, flags=0): @@ -296,9 +295,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.recvfrom(buffersize, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) async def recvfrom_into(self, buffer, bytes=0, flags=0): @@ -309,9 +308,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.recvfrom_into(buffer, bytes, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) async def sendto(self, bytes, flags_or_address, address=None): @@ -327,9 +326,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.sendto(bytes, flags, address) - except WantWrite: + except WriteBlock: await wait_writable(self.stream) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def getpeername(self): @@ -340,9 +339,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.getpeername() - except WantWrite: + except WriteBlock: await wait_writable(self.stream) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def getsockname(self): @@ -353,9 +352,9 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.getpeername() - except WantWrite: + except WriteBlock: await wait_writable(self.stream) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def recvmsg(self, bufsize, ancbufsize=0, flags=0): @@ -366,7 +365,7 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.recvmsg(bufsize, ancbufsize, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): @@ -377,7 +376,7 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.recvmsg_into(buffers, ancbufsize, flags) - except WantRead: + except ReadBlock: await wait_readable(self.stream) async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): @@ -388,7 +387,7 @@ class AsyncSocket(AsyncStream): while True: try: return self.stream.sendmsg(buffers, ancdata, flags, address) - except WantRead: + except ReadBlock: await wait_writable(self.stream) def __repr__(self): diff --git a/aiosched/sync.py b/aiosched/sync.py index 133ddf8..fa615a5 100644 --- a/aiosched/sync.py +++ b/aiosched/sync.py @@ -24,6 +24,9 @@ from aiosched.internals.syscalls import ( schedule, current_task, ) +from aiosched.task import Task +from aiosched.socket import wrap_socket +from socket import socketpair class Event: @@ -266,3 +269,126 @@ class MemoryChannel(Channel): """ return bool(len(self.buffer)) + + +class NetworkChannel(Channel): + """ + A two-way communication channel between tasks + that uses an underlying socket pair to communicate + instead of in-memory queues. Not thread safe + """ + + def __init__(self): + """ + Public object constructor + """ + + super().__init__(None) + # We use a socket as our buffer instead of a queue + sockets = socketpair() + self.reader = wrap_socket(sockets[0]) + self.writer = wrap_socket(sockets[1]) + self._internal_buffer = b"" + + async def write(self, data: bytes): + """ + Writes data to the channel. Blocks if the internal + socket is not currently available. Does nothing + if the channel has been closed + """ + + if self.closed: + return + await self.writer.send_all(data) + + async def read(self, size: int): + """ + Reads exactly size bytes from the channel. Blocks until + enough data arrives. Extra data is cached and used on the + next read + """ + + data = self._internal_buffer + while len(data) < size: + data += await self.reader.receive(size) + self._internal_buffer = data[size:] + data = data[:size] + return data + + async def close(self): + """ + Closes the memory channel. Any underlying + data is flushed out of the internal socket + and is lost + """ + + self.closed = True + await self.reader.close() + await self.writer.close() + + async def pending(self): + """ + Returns if there's pending + data to be read + """ + + # TODO: Ugly! + if self.closed: + return False + try: + self._internal_buffer += self.reader.stream.recv(1) + except BlockingIOError: + return False + return True + + +class Lock: + """ + A simple asynchronous single-owner lock. + Not thread safe + """ + + def __init__(self): + """ + Public constructor + """ + + self.owner: Task | None = None + self.tasks: deque[Event] = deque() + + async def acquire(self): + """ + Acquires the lock + """ + + task = await current_task() + if self.owner is None: + self.owner = task + elif task is self.owner: + raise RuntimeError("lock is already acquired by current task") + elif self.owner is not task: + self.tasks.append(Event()) + await self.tasks[-1].wait() + self.owner = task + + async def release(self): + """ + Releases the lock + """ + + task = await current_task() + if self.owner is None: + raise RuntimeError("lock is not acquired") + elif self.owner is not task: + raise RuntimeError("lock can only released by its owner") + elif self.tasks: + await self.tasks.popleft().trigger() + else: + self.owner = None + + async def __aenter__(self): + await self.acquire() + return self + + async def __aexit__(self, *args): + await self.release() diff --git a/tests/network_channel.py b/tests/network_channel.py new file mode 100644 index 0000000..625bef7 --- /dev/null +++ b/tests/network_channel.py @@ -0,0 +1,31 @@ +import aiosched +from debugger import Debugger + + +async def sender(c: aiosched.NetworkChannel, n: int): + for i in range(n): + await c.write(str(i).encode()) + print(f"Sent {i}") + await c.close() + print("Sender done") + + +async def receiver(c: aiosched.NetworkChannel): + while True: + if not await c.pending() and c.closed: + print("Receiver done") + break + item = (await c.read(1)).decode() + print(f"Received {item}") + await aiosched.sleep(1) + + +async def main(channel: aiosched.NetworkChannel, n: int): + print("Starting sender and receiver") + async with aiosched.with_context() as ctx: + await ctx.spawn(sender, channel, n) + await ctx.spawn(receiver, channel) + print("All done!") + + +aiosched.run(main, aiosched.NetworkChannel(), 5, debugger=()) # 2 is the max size of the channel