From acc436d518437cb040b30e408342c0d25a8d307b Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Fri, 11 Nov 2022 17:39:11 +0100 Subject: [PATCH] Fixes/additions to I/O mechanism, bugs still exist in network_channel test --- aiosched/context.py | 34 ++++++++++++---- aiosched/internals/syscalls.py | 21 ++++++---- aiosched/io.py | 52 ++++++++++++++++++++---- aiosched/kernel.py | 72 ++++++++++++++++++++++++++++------ aiosched/sync.py | 65 +++++++++++++----------------- aiosched/task.py | 4 +- aiosched/util/debugging.py | 50 +++++++++++++++++++++++ tests/chatroom_server.py | 1 + tests/debugger.py | 20 ++++++++++ tests/network_channel.py | 37 +++++++++-------- 10 files changed, 267 insertions(+), 89 deletions(-) diff --git a/aiosched/context.py b/aiosched/context.py index d1be89b..ddc96e6 100644 --- a/aiosched/context.py +++ b/aiosched/context.py @@ -23,6 +23,7 @@ from aiosched.internals.syscalls import ( set_context, close_context, join, + current_task, ) from typing import Any, Coroutine, Callable @@ -34,13 +35,13 @@ class TaskContext(Task): an exception occurs. A TaskContext object behaves like a regular task and the event loop treats it like a single unit rather than a collection of tasks (in fact, the event - loop doesn't even know whether the current task is a task - context or not, which is by design). TaskContexts can be - nested and will cancel inner ones if an exception is raised - inside them + loop doesn't even know, nor care about, whether the current + task is a task context or not, which is by design). Contexts + can be nested and will cancel inner ones if an exception is + raised inside them """ - def __init__(self, silent: bool = False, gather: bool = True) -> None: + def __init__(self, silent: bool = False, gather: bool = True, timeout: int | float = 0.0) -> None: """ Object constructor """ @@ -49,13 +50,16 @@ class TaskContext(Task): self.tasks: list[Task] = [] # Whether we have been cancelled or not self.cancelled: bool = False - # The context's entry point (needed to forward run() calls and the like) + # The context's entry point (needed to disguise ourselves as a task ;)) self.entry_point: Task | TaskContext | None = None # Do we ignore exceptions? self.silent: bool = silent # Do we gather multiple exceptions from # children tasks? - self.gather: bool = gather + self.gather: bool = gather # TODO: Implement + # For how long do we allow tasks inside us + # to run? + self.timeout: int | float = timeout # TODO: Implement async def spawn( self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs @@ -78,6 +82,17 @@ class TaskContext(Task): await set_context(self) return self + def __eq__(self, other): + """ + Implements self == other + """ + + if isinstance(other, TaskContext): + return super().__eq__(self, other) + elif isinstance(other, Task): + return other == self.entry_point + return False + async def __aexit__(self, exc_type: Exception, exc: Exception, tb): """ Implements the asynchronous context manager interface, waiting @@ -91,6 +106,11 @@ class TaskContext(Task): # end of the block and wait for all # children to exit if task is self.entry_point: + # We don't wait on the entry + # point because that's us! + # Besides, even if we tried, + # wait() would raise an error + # to avoid a deadlock continue await wait(task) except BaseException as exc: diff --git a/aiosched/internals/syscalls.py b/aiosched/internals/syscalls.py index b0255fb..e6024e6 100644 --- a/aiosched/internals/syscalls.py +++ b/aiosched/internals/syscalls.py @@ -103,10 +103,9 @@ async def checkpoint(): async def suspend(): """ - Suspends the current task. The task is not - rescheduled until some other event (for example - a timer, an event or an I/O operation) reschedules - it + Suspends the calling task indefinitely. + The task can be unsuspended by a timer, + an event or an incoming I/O operation """ await syscall("suspend") @@ -125,7 +124,9 @@ async def join(task: Task): """ Tells the event loop that the current task wants to wait on the given one, but without - waiting for its completion + waiting for its completion. This is a low + level trap and should not be used on its + own """ await syscall("join", task) @@ -140,7 +141,8 @@ async def wait(task: Task) -> Any | None: Returns immediately if the task has completed already, but exceptions are propagated only once. Returns the task's - return value, if it has one + return value, if it has one (returned once + for each call). :param task: The task to wait for :type task: :class: Task @@ -148,7 +150,10 @@ async def wait(task: Task) -> Any | None: """ current = await current_task() - if task is current: + if task == current: + # We don't do an "x is y" check because + # tasks and task contexts can compare equal + # despite having different memory addresses raise SchedulerError("a task cannot join itself") if current not in task.joiners: # Luckily we use a set, so this has O(1) @@ -156,6 +161,8 @@ async def wait(task: Task) -> Any | None: await join(task) # Waiting implies joining! await syscall("wait", task) if task.exc and task.state != TaskState.CANCELLED and task.propagate: + # Task raised an error that wasn't directly caused by a cancellation: + # raise it, but do so only the first time wait was called task.propagate = False raise task.exc return task.result diff --git a/aiosched/io.py b/aiosched/io.py index c3a1e74..f2ab17a 100644 --- a/aiosched/io.py +++ b/aiosched/io.py @@ -17,11 +17,10 @@ limitations under the License. """ import socket -import ssl import warnings import os import aiosched -from aiosched.errors import ResourceClosed +from aiosched.errors import ResourceClosed, ResourceBroken from aiosched.internals.syscalls import ( wait_writable, wait_readable, @@ -99,8 +98,8 @@ class AsyncStream: await io_release(self.stream) self.stream.close() self.stream = None + await aiosched.checkpoint() - @property async def fileno(self): """ Wrapper socket method @@ -132,7 +131,7 @@ class AsyncStream: this directly: stuff will break """ - if self._fd != -1: + if self._fd != -1 and self.stream.fileno() != -1: try: os.set_blocking(self._fd, False) os.close(self._fd) @@ -153,11 +152,18 @@ class AsyncSocket(AsyncStream): close_on_context_exit: bool = True, do_handshake_on_connect: bool = True, ): - super().__init__( - sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit - ) + # Do we perform the TCP handshake automatically + # upon connection? This is mostly needed for SSL + # sockets self.do_handshake_on_connect = do_handshake_on_connect - self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto) + # Do we close ourselves upon the end of a context manager? + self.close_on_context_exit = close_on_context_exit + # The socket.fromfd function copies the file descriptor + # instead of using the same one, so we'd be trying to close + # a different resource if we used sock.fileno() instead + # of self.stream.fileno() as our file descriptor + self.stream = socket.fromfd(sock.fileno(), sock.family, sock.type, sock.proto) + self._fd = self.stream.fileno() self.stream.setblocking(False) # A socket that isn't connected doesn't # need to be closed @@ -179,6 +185,21 @@ class AsyncSocket(AsyncStream): except WriteBlock: await wait_writable(self.stream) + async def receive_exactly(self, size: int, flags: int = 0) -> bytes: + """ + Receives exactly size bytes from a socket asynchronously. + """ + + # https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket + buf = bytearray(size) + pos = 0 + while pos < size: + n = await self.recv_into(memoryview(buf)[pos:], flags=flags) + if n == 0: + raise ResourceBroken("incomplete read detected") + pos += n + return bytes(buf) + async def connect(self, address): """ Wrapper socket method @@ -240,6 +261,8 @@ class AsyncSocket(AsyncStream): Wrapper socket method """ + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") if self.stream: self.stream.shutdown(how) await aiosched.checkpoint() @@ -320,6 +343,19 @@ class AsyncSocket(AsyncStream): except WriteBlock: await wait_writable(self.stream) + async def recv_into(self, buffer, nbytes=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.stream.recv_into(buffer, nbytes, flags) + except ReadBlock: + await wait_readable(self.stream) + except WriteBlock: + await wait_writable(self.stream) + async def recvfrom_into(self, buffer, bytes=0, flags=0): """ Wrapper socket method diff --git a/aiosched/kernel.py b/aiosched/kernel.py index d4d56ff..f718d74 100644 --- a/aiosched/kernel.py +++ b/aiosched/kernel.py @@ -113,7 +113,25 @@ class FIFOKernel: to do """ - return not any([self.paused, self.run_ready, self.selector.get_map()]) + if self.current_task and not self.current_task.done(): + # Current task isn't done yet! + return False + if any([self.paused, self.run_ready]): + # There's tasks sleeping and/or on the + # ready queue! + return False + for key in self.selector.get_map().values(): + # We don't just do any([self.paused, self.run_ready, self.selector.get_map()]) + # because we don't want to just know if there's any resources we're waiting on, + # but if there's at least one non-terminated task that owns a resource we're + # waiting on. This avoids issues such as the event loop never exiting if the + # user forgets to close a socket, for example + key.data: Task + if key.data.done(): + continue + elif self.get_task_io(key.data): + return False + return True def close(self, force: bool = False): """ @@ -159,16 +177,30 @@ class FIFOKernel: timeout = 0.0 if self.run_ready: # If there is work to do immediately (tasks to run) we - # can't wait + # can't wait. + # TODO: This could cause I/O starvation in highly concurrent + # environments: maybe a more convoluted scheduling strategy + # where I/O timeouts can only be skipped n times before a + # mandatory x-second timeout occurs is needed? It should of + # course take deadlines into account so that timeouts are + # always delivered in a timely manner and tasks awake from + # sleeping at the right moment timeout = 0.0 elif self.paused: # If there are asleep tasks or deadlines, wait until the closest date - timeout = self.paused.get_closest_deadline() + timeout = self.paused.get_closest_deadline() - self.clock() self.debugger.before_io(timeout) - io_ready = self.selector.select(timeout) # Get sockets that are ready and schedule their tasks - for key, _ in io_ready: - self.run_ready.append(key.data) # Resource ready? Schedule its task + for key, _ in self.selector.select(timeout): + key.data: Task + if key.data.state == TaskState.IO: + # We don't reschedule a task that wasn't + # blocking on I/O before: this way if a + # task waits on a socket and then goes to + # sleep, it won't be woken up early if the + # resource becomes available before its + # deadline expires + self.run_ready.append(key.data) # Resource ready? Schedule its task self.debugger.after_io(self.clock() - before_time) def awake_tasks(self): @@ -220,9 +252,9 @@ class FIFOKernel: our primitives or async methods. Note that this method does NOT catch any - exception arising from tasks, nor does it - take StopIteration or CancelledError into - account: that's the job for run()! + errors arising from tasks, nor does it take + StopIteration or Cancelled exceptions into + account """ # Sets the currently running task @@ -253,12 +285,12 @@ class FIFOKernel: ) if not hasattr(self, method) or not callable(getattr(self, method)): # This if block is meant to be triggered by other async - # libraries, which most likely have different trap names and behaviors + # libraries, which most likely have different method names and behaviors # compared to us. If you get this exception, and you're 100% sure you're # not mixing async primitives from other libraries, then it's a bug! self.current_task.throw( InternalError( - "Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?" + "Uh oh! Something bad just happened: did you try to mix primitives from other async libraries?" ) ) # Sneaky method call, thanks to David Beazley for this ;) @@ -321,7 +353,8 @@ class FIFOKernel: and self.entry_point.propagate ): # Contexts already manage exceptions for us, - # no need to raise it manually + # no need to raise it manually. If a context + # is not used, *then* we can raise the error raise self.entry_point.exc return self.entry_point.result @@ -334,6 +367,7 @@ class FIFOKernel: if self.selector.get_map() and resource in self.selector.get_map(): self.selector.unregister(resource) + self.debugger.on_io_unschedule(resource) def io_release_task(self, task: Task): """ @@ -348,6 +382,14 @@ class FIFOKernel: self.selector.unregister(key.fileobj) task.last_io = () + def get_task_io(self, task: Task) -> list: + """ + Returns the streams currently in use by + the given task + """ + + return list(map(lambda k: k.fileobj, filter(lambda k: k.data == task, self.selector.get_map().values()))) + def notify_closing(self, stream, broken: bool = False): """ Notifies paused tasks that a stream @@ -452,6 +494,7 @@ class FIFOKernel: self.paused.discard(task) self.io_release_task(task) self.run_ready.extend(task.joiners) + self.reschedule_running() def join(self, task: Task): """ @@ -491,6 +534,7 @@ class FIFOKernel: ctx.tasks.append(ctx.entry_point) self.current_task.context = ctx self.current_task = ctx + self.debugger.on_context_creation(ctx) self.reschedule_running() def close_context(self, ctx: TaskContext): @@ -498,6 +542,7 @@ class FIFOKernel: Closes the given context """ + self.debugger.on_context_exit(ctx) task = ctx.entry_point task.context = None self.current_task = task @@ -547,12 +592,14 @@ class FIFOKernel: # If the event to listen for has changed we just modify it self.selector.modify(resource, evt_type, self.current_task) self.current_task.last_io = (evt_type, resource) + self.debugger.on_io_schedule(resource, evt_type) elif not self.current_task.last_io or self.current_task.last_io[1] != resource: # The task has either registered a new resource or is doing # I/O for the first time self.current_task.last_io = evt_type, resource try: self.selector.register(resource, evt_type, self.current_task) + self.debugger.on_io_schedule(resource, evt_type) except KeyError: # The stream is already being used key = self.selector.get_key(resource) @@ -565,6 +612,7 @@ class FIFOKernel: # off a given stream while another one is # writing to it self.selector.modify(resource, evt_type, self.current_task) + self.debugger.on_io_schedule(resource, evt_type) else: # One task reading and one writing on the same # resource is fine (think producer-consumer), diff --git a/aiosched/sync.py b/aiosched/sync.py index 7fda2ad..fefd317 100644 --- a/aiosched/sync.py +++ b/aiosched/sync.py @@ -18,11 +18,12 @@ limitations under the License. from collections import deque from abc import ABC, abstractmethod from typing import Any -from aiosched.errors import SchedulerError +from aiosched.errors import SchedulerError, ResourceClosed from aiosched.internals.syscalls import ( suspend, schedule, current_task, + wait_readable, ) from aiosched.task import Task from aiosched.socket import wrap_socket @@ -72,7 +73,8 @@ class Event: class Queue: """ - An asynchronous FIFO queue. Not thread safe + An asynchronous FIFO queue. As it is based + on events, it is not thread safe """ def __init__(self, maxsize: int | None = None): @@ -167,7 +169,12 @@ class Channel(ABC): """ A generic, two-way, full-duplex communication channel between tasks. This is just an abstract base class and - should not be instantiated directly + should not be instantiated directly. Please also note + that the read() and write() methods are not implemented + here because their signatures vary across subclasses + depending on the underlying communication mechanism + that is used. Implementors must provide those two methods + when subclassing Channel """ def __init__(self, maxsize: int | None = None): @@ -178,26 +185,6 @@ class Channel(ABC): self.maxsize = maxsize self.closed = False - @abstractmethod - async def write(self, data: str): - """ - Writes data to the channel. Blocks if the internal - queue is full until a spot is available. Does nothing - if the channel has been closed - """ - - return NotImplemented - - @abstractmethod - async def read(self): - """ - Reads data from the channel. Blocks until - a message arrives or returns immediately if - one is already waiting - """ - - return NotImplemented - @abstractmethod async def close(self): """ @@ -220,9 +207,11 @@ class Channel(ABC): class MemoryChannel(Channel): """ A two-way communication channel between tasks. - Operations on this object do not perform any I/O - or other system call and are therefore extremely - efficient. Not thread safe + Operations on this object are based on the Queue + class and do not involve any I/O, making this + an extremely efficient way to pass data around + to tasks. Since this channel is based on queues, + it is not thread safe """ def __init__(self, maxsize: int | None = None): @@ -288,7 +277,8 @@ class NetworkChannel(Channel): sockets = socketpair() self.reader = wrap_socket(sockets[0]) self.writer = wrap_socket(sockets[1]) - self._internal_buffer = b"" + self.reader.needs_closing = True + self.writer.needs_closing = True async def write(self, data: bytes): """ @@ -298,7 +288,7 @@ class NetworkChannel(Channel): """ if self.closed: - return + raise ValueError("I/O operation on closed channel") await self.writer.send_all(data) async def read(self, size: int): @@ -308,12 +298,9 @@ class NetworkChannel(Channel): next read """ - data = self._internal_buffer - while len(data) < size: - data += await self.reader.receive(size) - self._internal_buffer = data[size:] - data = data[:size] - return data + if self.closed: + raise ValueError("I/O operation on closed channel") + return await self.reader.receive_exactly(size) async def close(self): """ @@ -332,13 +319,15 @@ class NetworkChannel(Channel): data to be read """ - # TODO: Ugly! if self.closed: return False - try: - self._internal_buffer += self.reader.stream.recv(1) - except BlockingIOError: + elif self.reader.fileno == -1: return False + else: + try: + await wait_readable(self.reader.stream) + except ResourceClosed: + return False return True diff --git a/aiosched/task.py b/aiosched/task.py index 198be10..7d0bde5 100644 --- a/aiosched/task.py +++ b/aiosched/task.py @@ -135,5 +135,7 @@ class Task: Task destructor """ + if not self.done(): + warnings.warn(f"task '{self.name}' was destroyed, but it has not completed yet") if self.last_io: - warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O") + warnings.warn(f"task '{self.name}' was destroyed, but it has pending I/O") diff --git a/aiosched/util/debugging.py b/aiosched/util/debugging.py index c751c74..8d27c71 100644 --- a/aiosched/util/debugging.py +++ b/aiosched/util/debugging.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from aiosched.task import Task +from aiosched.context import TaskContext class BaseDebugger(ABC): @@ -192,3 +193,52 @@ class BaseDebugger(ABC): """ return NotImplemented + + @abstractmethod + def on_context_creation(self, ctx: TaskContext): + """ + This method is called right after a task + context is initialized, i.e. when set_context + in the event loop is called + + :param ctx: The context object + :type ctx: TaskContext + :return: + """ + + return NotImplemented + + @abstractmethod + def on_context_exit(self, ctx: TaskContext): + """ + This method is called right before a task + context is closed, i.e. when close_context + in the event loop is called + + :param ctx: The context object + :type ctx: TaskContext + :return: + """ + + return NotImplemented + + @abstractmethod + def on_io_schedule(self, stream, event: int): + """ + This method is called whenever the + perform_io primitive is called within + the aiosched event loop with the stream + to be registered in the selector and the + chosen event mask + """ + + return NotImplemented + + @abstractmethod + def on_io_unschedule(self, stream): + """ + This method is called whenever a stream + is unregistered from the loop's I/O selector + """ + + return NotImplemented diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index cc8096d..9148f7f 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -75,6 +75,7 @@ async def handler(sock: aiosched.socket.AsyncSocket): logging.info(f"Connection from {address} closed") clients.pop(sock) names.discard(name) + logging.info("Handler shutting down") if __name__ == "__main__": diff --git a/tests/debugger.py b/tests/debugger.py index ab33212..fd95b06 100644 --- a/tests/debugger.py +++ b/tests/debugger.py @@ -1,4 +1,5 @@ from aiosched.util.debugging import BaseDebugger +from selectors import EVENT_READ, EVENT_WRITE class Debugger(BaseDebugger): @@ -51,3 +52,22 @@ class Debugger(BaseDebugger): def on_exception_raised(self, task, exc): print(f"== '{task.name}' raised {repr(exc)}") + + def on_context_creation(self, ctx): + print(f"=> A new context was created by {ctx.entry_point.name!r}") + + def on_context_exit(self, ctx): + print(f"=> A context was closed by {ctx.entry_point.name}") + + def on_io_schedule(self, stream, event: int): + evt = "" + if event == EVENT_READ: + evt = "reading" + elif event == EVENT_WRITE: + evt = "writing" + elif event == EVENT_WRITE | EVENT_READ: + evt = "reading or writing" + print(f"|| Stream {stream!r} was scheduled for {evt}") + + def on_io_unschedule(self, stream): + print(f"|| Stream {stream!r} was unscheduled") diff --git a/tests/network_channel.py b/tests/network_channel.py index 6b7bea8..582af6c 100644 --- a/tests/network_channel.py +++ b/tests/network_channel.py @@ -2,30 +2,35 @@ import aiosched from debugger import Debugger -async def sender(c: aiosched.NetworkChannel, n: int): +async def producer(c: aiosched.NetworkChannel, n: int): + print("[producer] Started") for i in range(n): await c.write(str(i).encode()) - print(f"Sent {i}") - await c.close() - print("Sender done") + print(f"[producer] Sent {i}") + await aiosched.sleep(0.5) # This makes the receiver wait on us! + #await c.close() + print("[producer] Done") -async def receiver(c: aiosched.NetworkChannel): - while True: - if not await c.pending() and c.closed: - print("Receiver done") - break - item = (await c.read(1)).decode() - print(f"Received {item}") - await aiosched.sleep(1) +async def consumer(c: aiosched.NetworkChannel): + print("[receiver] Started") + try: + while await c.pending(): + item = await c.read(1) + print(f"[consumer] Received {item.decode()}") + # await aiosched.sleep(2) # If you uncomment this, the except block will be triggered + except aiosched.errors.ResourceClosed: + print("[consumer] Stream has been closed early!") + print("[consumer] Done") async def main(channel: aiosched.NetworkChannel, n: int): - print("Starting sender and receiver") + t = aiosched.clock() + print("[main] Starting children") async with aiosched.with_context() as ctx: - await ctx.spawn(sender, channel, n) - await ctx.spawn(receiver, channel) - print("All done!") + await ctx.spawn(consumer, channel) + await ctx.spawn(producer, channel, n) + print(f"[main] All done in {aiosched.clock() - t:.2f} seconds") aiosched.run(