Ported NetworkChannel and related test from giambio

This commit is contained in:
Nocturn9x 2022-10-19 12:22:02 +02:00
parent 3b81702c2b
commit 8f3d7056b7
4 changed files with 192 additions and 33 deletions

View File

@ -20,7 +20,8 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel, checkpoint
import aiosched.task
import aiosched.errors
import aiosched.context
from aiosched.sync import Event, Queue, Channel, MemoryChannel
import aiosched.socket
from aiosched.sync import Event, Queue, Channel, MemoryChannel, NetworkChannel
__all__ = [
"run",
@ -37,5 +38,7 @@ __all__ = [
"Queue",
"Channel",
"MemoryChannel",
"checkpoint"
"checkpoint",
"NetworkChannel",
"socket"
]

View File

@ -17,6 +17,7 @@ limitations under the License.
"""
import socket
import ssl
import warnings
import os
import aiosched
@ -27,20 +28,18 @@ from aiosched.internals.syscalls import wait_writable, wait_readable, io_release
try:
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
ReadBlock = (BlockingIOError, InterruptedError, SSLWantReadError)
WriteBlock = (BlockingIOError, InterruptedError, SSLWantWriteError)
except ImportError:
WantRead = (BlockingIOError, InterruptedError)
WantWrite = (BlockingIOError, InterruptedError)
ReadBlock = (BlockingIOError, InterruptedError)
WriteBlock = (BlockingIOError, InterruptedError)
class AsyncStream:
"""
A generic asynchronous stream over
a file descriptor. Only works on Linux
& co because windows doesn't like select()
to be called on non-socket objects
(Thanks, Microsoft)
a file descriptor. Functionality
is OS-dependent
"""
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
@ -61,7 +60,7 @@ class AsyncStream:
while True:
try:
return self.stream.read(size)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def write(self, data):
@ -74,7 +73,7 @@ class AsyncStream:
while True:
try:
return self.stream.write(data)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
async def close(self):
@ -155,9 +154,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.recv(max_size, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
async def connect(self, address):
@ -173,7 +172,7 @@ class AsyncSocket(AsyncStream):
if self.do_handshake_on_connect:
await self.do_handshake()
break
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
self.needs_closing = True
@ -196,7 +195,7 @@ class AsyncSocket(AsyncStream):
try:
remote, addr = self.stream.accept()
return type(self)(remote), addr
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def send_all(self, data: bytes, flags: int = 0):
@ -210,9 +209,9 @@ class AsyncSocket(AsyncStream):
while data:
try:
sent_no = self.stream.send(data, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
data = data[sent_no:]
@ -283,9 +282,9 @@ class AsyncSocket(AsyncStream):
try:
self.stream: SSLSocket # Silences pycharm warnings
return self.stream.do_handshake()
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
async def recvfrom(self, buffersize, flags=0):
@ -296,9 +295,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.recvfrom(buffersize, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
@ -309,9 +308,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.recvfrom_into(buffer, bytes, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
async def sendto(self, bytes, flags_or_address, address=None):
@ -327,9 +326,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.sendto(bytes, flags, address)
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def getpeername(self):
@ -340,9 +339,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.getpeername()
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def getsockname(self):
@ -353,9 +352,9 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.getpeername()
except WantWrite:
except WriteBlock:
await wait_writable(self.stream)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
@ -366,7 +365,7 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.recvmsg(bufsize, ancbufsize, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
@ -377,7 +376,7 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.recvmsg_into(buffers, ancbufsize, flags)
except WantRead:
except ReadBlock:
await wait_readable(self.stream)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
@ -388,7 +387,7 @@ class AsyncSocket(AsyncStream):
while True:
try:
return self.stream.sendmsg(buffers, ancdata, flags, address)
except WantRead:
except ReadBlock:
await wait_writable(self.stream)
def __repr__(self):

View File

@ -24,6 +24,9 @@ from aiosched.internals.syscalls import (
schedule,
current_task,
)
from aiosched.task import Task
from aiosched.socket import wrap_socket
from socket import socketpair
class Event:
@ -266,3 +269,126 @@ class MemoryChannel(Channel):
"""
return bool(len(self.buffer))
class NetworkChannel(Channel):
"""
A two-way communication channel between tasks
that uses an underlying socket pair to communicate
instead of in-memory queues. Not thread safe
"""
def __init__(self):
"""
Public object constructor
"""
super().__init__(None)
# We use a socket as our buffer instead of a queue
sockets = socketpair()
self.reader = wrap_socket(sockets[0])
self.writer = wrap_socket(sockets[1])
self._internal_buffer = b""
async def write(self, data: bytes):
"""
Writes data to the channel. Blocks if the internal
socket is not currently available. Does nothing
if the channel has been closed
"""
if self.closed:
return
await self.writer.send_all(data)
async def read(self, size: int):
"""
Reads exactly size bytes from the channel. Blocks until
enough data arrives. Extra data is cached and used on the
next read
"""
data = self._internal_buffer
while len(data) < size:
data += await self.reader.receive(size)
self._internal_buffer = data[size:]
data = data[:size]
return data
async def close(self):
"""
Closes the memory channel. Any underlying
data is flushed out of the internal socket
and is lost
"""
self.closed = True
await self.reader.close()
await self.writer.close()
async def pending(self):
"""
Returns if there's pending
data to be read
"""
# TODO: Ugly!
if self.closed:
return False
try:
self._internal_buffer += self.reader.stream.recv(1)
except BlockingIOError:
return False
return True
class Lock:
"""
A simple asynchronous single-owner lock.
Not thread safe
"""
def __init__(self):
"""
Public constructor
"""
self.owner: Task | None = None
self.tasks: deque[Event] = deque()
async def acquire(self):
"""
Acquires the lock
"""
task = await current_task()
if self.owner is None:
self.owner = task
elif task is self.owner:
raise RuntimeError("lock is already acquired by current task")
elif self.owner is not task:
self.tasks.append(Event())
await self.tasks[-1].wait()
self.owner = task
async def release(self):
"""
Releases the lock
"""
task = await current_task()
if self.owner is None:
raise RuntimeError("lock is not acquired")
elif self.owner is not task:
raise RuntimeError("lock can only released by its owner")
elif self.tasks:
await self.tasks.popleft().trigger()
else:
self.owner = None
async def __aenter__(self):
await self.acquire()
return self
async def __aexit__(self, *args):
await self.release()

31
tests/network_channel.py Normal file
View File

@ -0,0 +1,31 @@
import aiosched
from debugger import Debugger
async def sender(c: aiosched.NetworkChannel, n: int):
for i in range(n):
await c.write(str(i).encode())
print(f"Sent {i}")
await c.close()
print("Sender done")
async def receiver(c: aiosched.NetworkChannel):
while True:
if not await c.pending() and c.closed:
print("Receiver done")
break
item = (await c.read(1)).decode()
print(f"Received {item}")
await aiosched.sleep(1)
async def main(channel: aiosched.NetworkChannel, n: int):
print("Starting sender and receiver")
async with aiosched.with_context() as ctx:
await ctx.spawn(sender, channel, n)
await ctx.spawn(receiver, channel)
print("All done!")
aiosched.run(main, aiosched.NetworkChannel(), 5, debugger=()) # 2 is the max size of the channel