From cc9eccf027ff22f230f46f5e351af005adf9ee6c Mon Sep 17 00:00:00 2001 From: nocturn9x Date: Sat, 14 Nov 2020 12:59:58 +0100 Subject: [PATCH] Identified issue with task.cancel() --- giambio/_core.py | 53 ++++++++++++++++++++++++++------------------ giambio/_layers.py | 3 +-- giambio/_managers.py | 16 ++++++------- giambio/socket.py | 1 - tests/count.py | 36 +++++++++++++++++------------- tests/server.py | 30 ++++++++++++------------- 6 files changed, 73 insertions(+), 66 deletions(-) diff --git a/giambio/_core.py b/giambio/_core.py index 37e3698..32c0c22 100644 --- a/giambio/_core.py +++ b/giambio/_core.py @@ -26,6 +26,7 @@ from .socket import AsyncSocket, WantWrite, WantRead from ._layers import Task, TimeQueue from socket import SOL_SOCKET, SO_ERROR from ._traps import want_read, want_write +import traceback, sys class AsyncScheduler: @@ -45,13 +46,13 @@ class AsyncScheduler: self.tasks = [] # Tasks that are ready to run self.selector = DefaultSelector() # Selector object to perform I/O multiplexing self.current_task = None # This will always point to the currently running coroutine (Task object) - self.catch = True self.joined = ( {} ) # Maps child tasks that need to be joined their respective parent task self.clock = ( default_timer # Monotonic clock to keep track of elapsed time reliably ) + self.some_cancel = False self.paused = TimeQueue(self.clock) # Tasks that are asleep self.events = set() # All Event objects self.event_waiting = defaultdict(list) # Coroutines waiting on event objects @@ -82,30 +83,37 @@ class AsyncScheduler: self._check_events() while self.tasks: # While there are tasks to run self.current_task = self.tasks.pop(0) + if self.some_cancel: + self._check_cancel() # Sets the currently running task - if self.current_task.status == "cancel": # Deferred cancellation - self.current_task.cancelled = True - self.current_task.throw(CancelledError(self.current_task)) method, *args = self.current_task.run() # Run a single step with the calculation self.current_task.status = "run" getattr(self, f"_{method}")(*args) # Sneaky method call, thanks to David Beazley for this ;) - except CancelledError as cancelled: - if cancelled.args[0] in self.tasks: - self.tasks.remove(cancelled.args[0]) # Remove the dead task - self.tasks.append(self.current_task) + except CancelledError: + self.current_task.cancelled = True + self._reschedule_parent() except StopIteration as e: # Coroutine ends self.current_task.result = e.args[0] if e.args else None self.current_task.finished = True self._reschedule_parent() + except RuntimeError: + continue except BaseException as error: # Coroutine raised + print(error) self.current_task.exc = error - if self.catch: - self._reschedule_parent() - self._join(self.current_task) - else: - if not isinstance(error, RuntimeError): - raise + self._reschedule_parent() + self._join(self.current_task) + raise + + def _check_cancel(self): + """ + Checks for task cancellation + """ + + if self.current_task.status == "cancel": # Deferred cancellation + self.current_task.cancelled = True + self.current_task.throw(CancelledError(self.current_task)) def _check_events(self): """ @@ -126,7 +134,7 @@ class AsyncScheduler: wait(max(0.0, self.paused[0][0] - self.clock())) # Sleep until the closest deadline in order not to waste CPU cycles while self.paused[0][0] < self.clock(): - # Reschedules tasks when their deadline has elapsed + # Reschedules tasks when their deadline has elapsed self.tasks.append(self.paused.get()) if not self.paused: break @@ -150,7 +158,7 @@ class AsyncScheduler: entry = Task(func(*args)) self.tasks.append(entry) - self._join(entry) + self._join(entry) # TODO -> Inspect this line, does it actually do anything useful? self._run() return entry @@ -261,12 +269,9 @@ class AsyncScheduler: are independent """ - if task.status in ("sleep", "I/O") and not task.cancelled: - # It is safe to cancel a task while blocking - task.cancelled = True - task.throw(CancelledError(task)) - elif task.status == "run": - task.status = "cancel" # Cancellation is deferred + if not self.some_cancel: + self.some_cancel = True + task.status = "cancel" # Cancellation is deferred def wrap_socket(self, sock): """ @@ -282,6 +287,10 @@ class AsyncScheduler: """ await want_read(sock) + try: + return sock.recv(buffer) + except WantRead: + await want_write(sock) return sock.recv(buffer) async def _accept_sock(self, sock: socket.socket): diff --git a/giambio/_layers.py b/giambio/_layers.py index 68a5561..3dd52c5 100644 --- a/giambio/_layers.py +++ b/giambio/_layers.py @@ -55,12 +55,11 @@ class Task: raise self.exc return res - async def cancel(self): """Cancels the task""" await cancel(self) - assert self.cancelled, "Task ignored cancellation" + # await join(self) # TODO -> Join ourselves after cancellation? def __repr__(self): """Implements repr(self)""" diff --git a/giambio/_managers.py b/giambio/_managers.py index e7a666f..bbe6fcb 100644 --- a/giambio/_managers.py +++ b/giambio/_managers.py @@ -51,7 +51,6 @@ class TaskManager: return task async def __aenter__(self): - self.loop.catch = True # Restore event loop's status return self async def __aexit__(self, exc_type, exc, tb): @@ -59,12 +58,11 @@ class TaskManager: try: await task.join() except BaseException as e: - for task in self.loop.tasks: - await task.cancel() - for _, __, task in self.loop.paused: - await task.cancel() - for tasks in self.loop.event_waiting.values(): - for task in tasks: - await task.cancel() - self.loop.catch = False + for running_task in self.loop.tasks: + await running_task.cancel() + for _, __, asleep_task in self.loop.paused: + await asleep_task.cancel() + for waiting_tasks in self.loop.event_waiting.values(): + for waiting_task in waiting_tasks: + await waiting_task.cancel() raise e diff --git a/giambio/socket.py b/giambio/socket.py index fee93a9..60d16c6 100644 --- a/giambio/socket.py +++ b/giambio/socket.py @@ -24,7 +24,6 @@ from ._traps import sleep try: from ssl import SSLWantReadError, SSLWantWriteError - WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) except ImportError: diff --git a/tests/count.py b/tests/count.py index d761e33..ee26da6 100644 --- a/tests/count.py +++ b/tests/count.py @@ -15,17 +15,14 @@ async def countdown(n: int): async def countup(stop: int, step: int = 1): - try: - x = 0 - while x < stop: - print(f"Up {x}") - x += 1 - await giambio.sleep(step) - print("Countup over") - return 1 - except giambio.exceptions.CancelledError: - print("I'm not gonna die!!") - raise BaseException(2) + x = 0 + while x < stop: + print(f"Up {x}") + x += 1 + await giambio.sleep(step) + print("Countup over") + return 1 + async def main(): try: @@ -33,9 +30,16 @@ async def main(): async with giambio.create_pool() as pool: print("Starting counters") pool.spawn(countdown, 10) - t = pool.spawn(countup, 5, 2) - await giambio.sleep(2) - await t.cancel() + count_up = pool.spawn(countup, 5, 2) + # raise Exception + # Raising an exception here has a weird + # Behavior: The exception is propagated + # *after* all the child tasks complete, + # which is not what we want + # print("Sleeping for 2 seconds before cancelling") + # await giambio.sleep(2) + # await count_up.cancel() # TODO: Cancel _is_ broken, this does not re-schedule the parent! + # print("Cancelled countup") print("Task execution complete") except Exception as e: print(f"Caught this bad boy in here, propagating it -> {type(e).__name__}: {e}") @@ -46,6 +50,6 @@ if __name__ == "__main__": print("Starting event loop") try: giambio.run(main) - except BaseException as e: - print(f"Exception caught from main event loop!! -> {type(e).__name__}: {e}") + except BaseException as error: + print(f"Exception caught from main event loop! -> {type(error).__name__}: {error}") print("Event loop done") diff --git a/tests/server.py b/tests/server.py index 74251c3..9ab945b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -7,29 +7,26 @@ import sys # A test to check for asynchronous I/O -logging.basicConfig( - level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p" -) - -async def server(address: tuple): +async def serve(address: tuple): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(address) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.listen(5) asock = giambio.wrap_socket(sock) # We make the socket an async socket - logging.info(f"Echo server serving asynchronously at {address}") + logging.info(f"Serving asynchronously at {address[0]}:{address[1]}") while True: try: - async with giambio.async_pool() as pool: + async with giambio.create_pool() as pool: conn, addr = await asock.accept() - logging.info(f"{addr} connected") - pool.spawn(echo_handler, conn, addr) + logging.info(f"{addr[0]}:{addr[1]} connected") + pool.spawn(handler, conn, addr) except TypeError: print("Looks like we have a naughty boy here!") -async def echo_handler(sock: AsyncSocket, addr: tuple): +async def handler(sock: AsyncSocket, addr: tuple): + addr = f"{addr[0]}:{addr[1]}" async with sock: await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n") while True: @@ -49,11 +46,12 @@ async def echo_handler(sock: AsyncSocket, addr: tuple): if __name__ == "__main__": - if len(sys.argv) > 1: - port = int(sys.argv[1]) - else: - port = 1500 + port = int(sys.argv[1]) if len(sys.argv) > 1 else 1500 + logging.basicConfig(level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p") try: - giambio.run(server, ("", port)) + giambio.run(serve, ("localhost", port)) except (Exception, KeyboardInterrupt) as error: # Exceptions propagate! - print(f"Exiting due to a {type(error).__name__}: '{error}'") + if isinstance(error, KeyboardInterrupt): + logging.info("Ctrl+C detected, exiting") + else: + logging.error(f"Exiting due to a {type(error).__name__}: {error}")