Ported NetworkChannel and related test from giambio

This commit is contained in:
Nocturn9x 2022-10-19 12:22:02 +02:00
parent 3b81702c2b
commit 8f3d7056b7
4 changed files with 192 additions and 33 deletions

View File

@ -20,7 +20,8 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
import aiosched.task import aiosched.task
import aiosched.errors import aiosched.errors
import aiosched.context import aiosched.context
from aiosched.sync import Event, Queue, Channel, MemoryChannel import aiosched.socket
from aiosched.sync import Event, Queue, Channel, MemoryChannel, NetworkChannel
__all__ = [ __all__ = [
"run", "run",
@ -37,5 +38,7 @@ __all__ = [
"Queue", "Queue",
"Channel", "Channel",
"MemoryChannel", "MemoryChannel",
"checkpoint" "checkpoint",
"NetworkChannel",
"socket"
] ]

View File

@ -17,6 +17,7 @@ limitations under the License.
""" """
import socket import socket
import ssl
import warnings import warnings
import os import os
import aiosched import aiosched
@ -27,20 +28,18 @@ from aiosched.internals.syscalls import wait_writable, wait_readable, io_release
try: try:
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) ReadBlock = (BlockingIOError, InterruptedError, SSLWantReadError)
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) WriteBlock = (BlockingIOError, InterruptedError, SSLWantWriteError)
except ImportError: except ImportError:
WantRead = (BlockingIOError, InterruptedError) ReadBlock = (BlockingIOError, InterruptedError)
WantWrite = (BlockingIOError, InterruptedError) WriteBlock = (BlockingIOError, InterruptedError)
class AsyncStream: class AsyncStream:
""" """
A generic asynchronous stream over A generic asynchronous stream over
a file descriptor. Only works on Linux a file descriptor. Functionality
& co because windows doesn't like select() is OS-dependent
to be called on non-socket objects
(Thanks, Microsoft)
""" """
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs): def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
@ -61,7 +60,7 @@ class AsyncStream:
while True: while True:
try: try:
return self.stream.read(size) return self.stream.read(size)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def write(self, data): async def write(self, data):
@ -74,7 +73,7 @@ class AsyncStream:
while True: while True:
try: try:
return self.stream.write(data) return self.stream.write(data)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
async def close(self): async def close(self):
@ -155,9 +154,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.recv(max_size, flags) return self.stream.recv(max_size, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
async def connect(self, address): async def connect(self, address):
@ -173,7 +172,7 @@ class AsyncSocket(AsyncStream):
if self.do_handshake_on_connect: if self.do_handshake_on_connect:
await self.do_handshake() await self.do_handshake()
break break
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
self.needs_closing = True self.needs_closing = True
@ -196,7 +195,7 @@ class AsyncSocket(AsyncStream):
try: try:
remote, addr = self.stream.accept() remote, addr = self.stream.accept()
return type(self)(remote), addr return type(self)(remote), addr
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def send_all(self, data: bytes, flags: int = 0): async def send_all(self, data: bytes, flags: int = 0):
@ -210,9 +209,9 @@ class AsyncSocket(AsyncStream):
while data: while data:
try: try:
sent_no = self.stream.send(data, flags) sent_no = self.stream.send(data, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
data = data[sent_no:] data = data[sent_no:]
@ -283,9 +282,9 @@ class AsyncSocket(AsyncStream):
try: try:
self.stream: SSLSocket # Silences pycharm warnings self.stream: SSLSocket # Silences pycharm warnings
return self.stream.do_handshake() return self.stream.do_handshake()
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
async def recvfrom(self, buffersize, flags=0): async def recvfrom(self, buffersize, flags=0):
@ -296,9 +295,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.recvfrom(buffersize, flags) return self.stream.recvfrom(buffersize, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
async def recvfrom_into(self, buffer, bytes=0, flags=0): async def recvfrom_into(self, buffer, bytes=0, flags=0):
@ -309,9 +308,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.recvfrom_into(buffer, bytes, flags) return self.stream.recvfrom_into(buffer, bytes, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
async def sendto(self, bytes, flags_or_address, address=None): async def sendto(self, bytes, flags_or_address, address=None):
@ -327,9 +326,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.sendto(bytes, flags, address) return self.stream.sendto(bytes, flags, address)
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def getpeername(self): async def getpeername(self):
@ -340,9 +339,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.getpeername() return self.stream.getpeername()
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def getsockname(self): async def getsockname(self):
@ -353,9 +352,9 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.getpeername() return self.stream.getpeername()
except WantWrite: except WriteBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def recvmsg(self, bufsize, ancbufsize=0, flags=0): async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
@ -366,7 +365,7 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.recvmsg(bufsize, ancbufsize, flags) return self.stream.recvmsg(bufsize, ancbufsize, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
@ -377,7 +376,7 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.recvmsg_into(buffers, ancbufsize, flags) return self.stream.recvmsg_into(buffers, ancbufsize, flags)
except WantRead: except ReadBlock:
await wait_readable(self.stream) await wait_readable(self.stream)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
@ -388,7 +387,7 @@ class AsyncSocket(AsyncStream):
while True: while True:
try: try:
return self.stream.sendmsg(buffers, ancdata, flags, address) return self.stream.sendmsg(buffers, ancdata, flags, address)
except WantRead: except ReadBlock:
await wait_writable(self.stream) await wait_writable(self.stream)
def __repr__(self): def __repr__(self):

View File

@ -24,6 +24,9 @@ from aiosched.internals.syscalls import (
schedule, schedule,
current_task, current_task,
) )
from aiosched.task import Task
from aiosched.socket import wrap_socket
from socket import socketpair
class Event: class Event:
@ -266,3 +269,126 @@ class MemoryChannel(Channel):
""" """
return bool(len(self.buffer)) return bool(len(self.buffer))
class NetworkChannel(Channel):
"""
A two-way communication channel between tasks
that uses an underlying socket pair to communicate
instead of in-memory queues. Not thread safe
"""
def __init__(self):
"""
Public object constructor
"""
super().__init__(None)
# We use a socket as our buffer instead of a queue
sockets = socketpair()
self.reader = wrap_socket(sockets[0])
self.writer = wrap_socket(sockets[1])
self._internal_buffer = b""
async def write(self, data: bytes):
"""
Writes data to the channel. Blocks if the internal
socket is not currently available. Does nothing
if the channel has been closed
"""
if self.closed:
return
await self.writer.send_all(data)
async def read(self, size: int):
"""
Reads exactly size bytes from the channel. Blocks until
enough data arrives. Extra data is cached and used on the
next read
"""
data = self._internal_buffer
while len(data) < size:
data += await self.reader.receive(size)
self._internal_buffer = data[size:]
data = data[:size]
return data
async def close(self):
"""
Closes the memory channel. Any underlying
data is flushed out of the internal socket
and is lost
"""
self.closed = True
await self.reader.close()
await self.writer.close()
async def pending(self):
"""
Returns if there's pending
data to be read
"""
# TODO: Ugly!
if self.closed:
return False
try:
self._internal_buffer += self.reader.stream.recv(1)
except BlockingIOError:
return False
return True
class Lock:
"""
A simple asynchronous single-owner lock.
Not thread safe
"""
def __init__(self):
"""
Public constructor
"""
self.owner: Task | None = None
self.tasks: deque[Event] = deque()
async def acquire(self):
"""
Acquires the lock
"""
task = await current_task()
if self.owner is None:
self.owner = task
elif task is self.owner:
raise RuntimeError("lock is already acquired by current task")
elif self.owner is not task:
self.tasks.append(Event())
await self.tasks[-1].wait()
self.owner = task
async def release(self):
"""
Releases the lock
"""
task = await current_task()
if self.owner is None:
raise RuntimeError("lock is not acquired")
elif self.owner is not task:
raise RuntimeError("lock can only released by its owner")
elif self.tasks:
await self.tasks.popleft().trigger()
else:
self.owner = None
async def __aenter__(self):
await self.acquire()
return self
async def __aexit__(self, *args):
await self.release()

31
tests/network_channel.py Normal file
View File

@ -0,0 +1,31 @@
import aiosched
from debugger import Debugger
async def sender(c: aiosched.NetworkChannel, n: int):
for i in range(n):
await c.write(str(i).encode())
print(f"Sent {i}")
await c.close()
print("Sender done")
async def receiver(c: aiosched.NetworkChannel):
while True:
if not await c.pending() and c.closed:
print("Receiver done")
break
item = (await c.read(1)).decode()
print(f"Received {item}")
await aiosched.sleep(1)
async def main(channel: aiosched.NetworkChannel, n: int):
print("Starting sender and receiver")
async with aiosched.with_context() as ctx:
await ctx.spawn(sender, channel, n)
await ctx.spawn(receiver, channel)
print("All done!")
aiosched.run(main, aiosched.NetworkChannel(), 5, debugger=()) # 2 is the max size of the channel