Initial broken work on a generic streams interface

This commit is contained in:
Nocturn9x 2022-10-10 13:35:22 +02:00
parent 55868c450e
commit e37ffdeb06
8 changed files with 252 additions and 119 deletions

View File

@ -16,6 +16,10 @@ rock-solid and structured concurrency framework (I personally recommend trio and
that most of the content of this document is ~~stolen~~ inspired from its documentation)
# Disclaimer #2
This is a toy project. Don't try to use it in production, it *will* explode
## Goals of this project
Making yet another async library might sound dumb in an already fragmented ecosystem like Python's.

View File

@ -15,7 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
# Import libraries and internal resources
from numbers import Number
from giambio.task import Task
@ -33,6 +33,7 @@ from giambio.exceptions import (
ResourceBusy,
GiambioError,
TooSlowError,
ResourceClosed
)
@ -433,12 +434,15 @@ class AsyncScheduler:
task.result = ret.value
task.finished = True
self.join(task)
self.tasks.remove(task)
except CancelledError as cancel:
task.status = "cancelled"
task.cancel_pending = False
task.cancelled = True
self.join(task)
except BaseException as err:
task.exc = err
self.join(task)
if task in self.tasks:
self.tasks.remove(task)
def prune_deadlines(self):
"""
@ -666,6 +670,8 @@ class AsyncScheduler:
self.io_release_task(task)
if task in self.suspended:
self.suspended.remove(task)
if task in self.tasks:
self.tasks.remove(task)
# If the pool (including any enclosing pools) has finished executing
# or we're at the first task that kicked the loop, we can safely
# reschedule the parent(s)
@ -770,13 +776,25 @@ class AsyncScheduler:
task.cancelled = True
task.status = "cancelled"
self.debugger.after_cancel(task)
self.tasks.remove(task)
self.join(task)
else:
# If the task ignores our exception, we'll
# raise it later again
task.cancel_pending = True
def notify_closing(self, stream):
"""
Implements the notify_closing trap
"""
if self.selector.get_map():
for k in filter(
lambda o: o.data == self.current_task,
dict(self.selector.get_map()).values(),
):
self.handle_task_exit(k.data,
functools.partial(k.data.throw(ResourceClosed("stream has been closed"))))
def register_sock(self, sock, evt_type: str):
"""
Registers the given socket inside the

View File

@ -37,7 +37,7 @@ class InternalError(GiambioError):
...
class CancelledError(GiambioError):
class CancelledError(BaseException):
"""
Exception raised by the giambio.objects.Task.cancel() method
to terminate a child task. This should NOT be caught, or

View File

@ -15,14 +15,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import socket
import warnings
import os
import giambio
from giambio.exceptions import ResourceClosed
from giambio.traps import want_write, want_read, io_release
from giambio.traps import want_write, want_read, io_release, notify_closing
try:
from ssl import SSLWantReadError, SSLWantWriteError
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
@ -31,16 +33,115 @@ except ImportError:
WantWrite = (BlockingIOError, InterruptedError)
class AsyncSocket:
class AsyncStream:
"""
A generic asynchronous stream over
a file descriptor. Only works on Linux
& co because windows doesn't like select()
to be called on non-socket objects
(Thanks, Microsoft)
"""
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
self._fd = fd
self.stream = None
if open_fd:
self.stream = os.fdopen(self._fd, **kwargs)
os.set_blocking(self._fd, False)
self.close_on_context_exit = close_on_context_exit
async def read(self, size: int = -1):
"""
Reads up to size bytes from the
given stream. If size == -1, read
until EOF is reached
"""
while True:
try:
return self.stream.read(size)
except WantRead:
await want_read(self.stream)
async def write(self, data):
"""
Writes data b to the file.
Returns the number of bytes
written
"""
while True:
try:
return self.stream.write(data)
except WantWrite:
await want_write(self.stream)
async def close(self):
"""
Closes the stream asynchronously
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed stream")
self._fd = -1
await notify_closing(self.stream)
await io_release(self.stream)
self.stream.close()
self.stream = None
@property
async def fileno(self):
"""
Wrapper socket method
"""
return self._fd
async def __aenter__(self):
self.stream.__enter__()
return self
async def __aexit__(self, *args):
if self._fd != -1 and self.close_on_context_exit:
await self.close()
async def dup(self):
"""
Wrapper stream method
"""
return type(self)(os.dup(self._fd))
def __repr__(self):
return f"AsyncStream({self.stream})"
def __del__(self):
"""
Stream destructor. Do *not* call
this directly: stuff will break
"""
if self._fd != -1:
try:
os.set_blocking(self._fd, False)
os.close(self._fd)
except OSError as e:
warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}")
class AsyncSocket(AsyncStream):
"""
Abstraction layer for asynchronous sockets
"""
def __init__(self, sock, do_handshake_on_connect: bool = True):
self.sock = sock
def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True):
super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
self.do_handshake_on_connect = do_handshake_on_connect
self._fd = sock.fileno()
self.sock.setblocking(False)
self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto)
self.stream.setblocking(False)
# A socket that isn't connected doesn't
# need to be closed
self.needs_closing: bool = False
async def receive(self, max_size: int, flags: int = 0) -> bytes:
"""
@ -52,11 +153,11 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
return self.sock.recv(max_size, flags)
return self.stream.recv(max_size, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def connect(self, address):
"""
@ -67,12 +168,21 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
self.sock.connect(address)
self.stream.connect(address)
if self.do_handshake_on_connect:
await self.do_handshake()
return
break
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
self.needs_closing = True
async def close(self):
"""
Wrapper socket method
"""
if self.needs_closing:
await super().close()
async def accept(self):
"""
@ -83,10 +193,10 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
remote, addr = self.sock.accept()
remote, addr = self.stream.accept()
return type(self)(remote), addr
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def send_all(self, data: bytes, flags: int = 0):
"""
@ -98,32 +208,20 @@ class AsyncSocket:
sent_no = 0
while data:
try:
sent_no = self.sock.send(data, flags)
sent_no = self.stream.send(data, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
data = data[sent_no:]
async def close(self):
"""
Closes the socket asynchronously
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
await io_release(self.sock)
self.sock.close()
self._fd = -1
self.sock = None
async def shutdown(self, how):
"""
Wrapper socket method
"""
if self.sock:
self.sock.shutdown(how)
if self.stream:
self.stream.shutdown(how)
await giambio.sleep(0) # Checkpoint
async def bind(self, addr: tuple):
@ -136,7 +234,7 @@ class AsyncSocket:
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.sock.bind(addr)
self.stream.bind(addr)
async def listen(self, backlog: int):
"""
@ -148,27 +246,12 @@ class AsyncSocket:
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.sock.listen(backlog)
async def __aenter__(self):
self.sock.__enter__()
return self
async def __aexit__(self, *args):
if self.sock:
self.sock.__exit__(*args)
self.stream.listen(backlog)
# Yes, I stole these from Curio because I could not be
# arsed to write a bunch of uninteresting simple socket
# methods from scratch, deal with it.
async def fileno(self):
"""
Wrapper socket method
"""
return self._fd
async def settimeout(self, seconds):
"""
Wrapper socket method
@ -188,22 +271,23 @@ class AsyncSocket:
Wrapper socket method
"""
return type(self)(self.sock.dup())
return type(self)(self.stream.dup(), self.do_handshake_on_connect)
async def do_handshake(self):
"""
Wrapper socket method
"""
if not hasattr(self.sock, "do_handshake"):
if not hasattr(self.stream, "do_handshake"):
return
while True:
try:
return self.sock.do_handshake()
self.stream: SSLSocket # Silences pycharm warnings
return self.stream.do_handshake()
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def recvfrom(self, buffersize, flags=0):
"""
@ -212,11 +296,11 @@ class AsyncSocket:
while True:
try:
return self.sock.recvfrom(buffersize, flags)
return self.stream.recvfrom(buffersize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
"""
@ -225,11 +309,11 @@ class AsyncSocket:
while True:
try:
return self.sock.recvfrom_into(buffer, bytes, flags)
return self.stream.recvfrom_into(buffer, bytes, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def sendto(self, bytes, flags_or_address, address=None):
"""
@ -243,11 +327,11 @@ class AsyncSocket:
flags = 0
while True:
try:
return self.sock.sendto(bytes, flags, address)
return self.stream.sendto(bytes, flags, address)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def getpeername(self):
"""
@ -256,11 +340,11 @@ class AsyncSocket:
while True:
try:
return self.sock.getpeername()
return self.stream.getpeername()
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def getsockname(self):
"""
@ -269,11 +353,11 @@ class AsyncSocket:
while True:
try:
return self.sock.getpeername()
return self.stream.getpeername()
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
"""
@ -282,9 +366,9 @@ class AsyncSocket:
while True:
try:
return self.sock.recvmsg(bufsize, ancbufsize, flags)
return self.stream.recvmsg(bufsize, ancbufsize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
"""
@ -293,9 +377,9 @@ class AsyncSocket:
while True:
try:
return self.sock.recvmsg_into(buffers, ancbufsize, flags)
return self.stream.recvmsg_into(buffers, ancbufsize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
"""
@ -304,17 +388,13 @@ class AsyncSocket:
while True:
try:
return self.sock.sendmsg(buffers, ancdata, flags, address)
return self.stream.sendmsg(buffers, ancdata, flags, address)
except WantRead:
await want_write(self.sock)
await want_write(self.stream)
def __repr__(self):
return f"AsyncSocket({self.sock})"
return f"AsyncSocket({self.stream})"
def __del__(self):
"""
Socket destructor
"""
if not self._fd != -1:
warnings.warn(f"socket '{self}' was destroyed, but was not closed, leading to a potential resource leak")
if self.needs_closing:
super().__del__()

View File

@ -178,6 +178,19 @@ async def want_write(stream):
await create_trap("register_sock", stream, "write")
async def notify_closing(stream):
"""
Notifies the event loop that a given
stream needs to be closed. This makes
all callers waiting on want_read or
want_write crash with a ResourceClosed
exception, but it doesn't actually close
the socket object itself
"""
await create_trap("notify_closing", stream)
async def schedule_tasks(tasks: Iterable[Task]):
"""
Schedules a list of tasks for execution. Usuaully

View File

@ -1,44 +1,49 @@
import sys
from typing import Tuple
import giambio
import logging
from debugger import Debugger
async def sender(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
async def reader(q: giambio.Queue, prompt: str = ""):
in_stream = giambio.io.AsyncStream(sys.stdin.fileno(), close_on_context_exit=False, mode="r")
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
while True:
await sock.send_all(b"yo")
await q.put((0, ""))
await giambio.sleep(1)
await out_stream.write(prompt)
await q.put((0, await in_stream.read()))
async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
data = b""
while True:
while not data.endswith(b"\n"):
data += await sock.receive(1024)
temp = await sock.receive(1024)
if not temp:
raise EOFError("end of file")
data += temp
data, rest = data.split(b"\n", maxsplit=2)
buffer = b"".join(rest)
await q.put((1, data.decode()))
data = buffer
async def main(host: Tuple[str, int]):
async def main(host: tuple[str, int]):
"""
Main client entry point
"""
queue = giambio.Queue()
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
async with giambio.create_pool() as pool:
async with giambio.socket.socket() as sock:
await sock.connect(host)
await pool.spawn(sender, sock, queue)
await out_stream.write("Connection successful\n")
await pool.spawn(receiver, sock, queue)
await pool.spawn(reader, queue, "> ")
while True:
op, data = await queue.get()
if op == 0:
print(f"Sent.")
else:
print(f"Received: {data}")
if op == 1:
await out_stream.write(data)
if __name__ == "__main__":
@ -49,7 +54,7 @@ if __name__ == "__main__":
datefmt="%d/%m/%Y %p",
)
try:
giambio.run(main, ("localhost", port))
giambio.run(main, ("localhost", port), debugger=Debugger())
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")

View File

@ -1,4 +1,3 @@
from typing import List
import giambio
from giambio.socket import AsyncSocket
import logging
@ -6,7 +5,8 @@ import sys
# An asynchronous chatroom
clients: List[giambio.socket.AsyncSocket] = []
clients: dict[AsyncSocket, list[str, str]] = {}
names: set[str] = set()
async def serve(bind_address: tuple):
@ -26,39 +26,52 @@ async def serve(bind_address: tuple):
while True:
try:
conn, address_tuple = await sock.accept()
clients.append(conn)
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
await pool.spawn(handler, conn)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple):
async def handler(sock: AsyncSocket):
"""
Handles a single client connection
:param sock: The AsyncSocket object connected to the client
:param client_address: The client's address represented as a tuple
(address, port) where address is a string and port is an integer
:type client_address: tuple
"""
address = f"{client_address[0]}:{client_address[1]}"
address = clients[sock][1]
name = ""
async with sock: # Closes the socket automatically
await sock.send_all(b"Welcome to the chatroom pal, start typing and press enter!\n")
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
while True:
while not name.endswith("\n"):
name = (await sock.receive(64)).decode()
name = name[:-1]
if name not in names:
names.add(name)
clients[sock][0] = name
break
else:
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
while True:
await sock.send_all(b"> ")
data = await sock.receive(1024)
if not data:
break
elif data == b"exit\n":
await sock.send_all(b"I'm dead dude\n")
raise TypeError("Oh, no, I'm gonna die!")
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if client_sock != sock:
await client_sock.send_all(data)
if not data.endswith(b"\n"):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
logging.info(f"Connection from {address} closed")
clients.remove(sock)

View File

@ -63,7 +63,7 @@ if __name__ == "__main__":
logging.basicConfig(
level=20,
format="[%(levelname)s] %(asctime)s %(message)s",
datefmt="%d/%m/%Y %p",
datefmt="%d/%m/%Y %H:%M:%S %p",
)
try:
giambio.run(serve, ("localhost", port), debugger=())