From 9bb091cdef05b37f80c91438c709fbe5ec642a63 Mon Sep 17 00:00:00 2001 From: nocturn9x Date: Thu, 26 Aug 2021 16:19:40 +0200 Subject: [PATCH] Fixed nested pools (sorta?) --- giambio/context.py | 7 ++++-- giambio/core.py | 52 ++++++++++++++++++++++++--------------- giambio/run.py | 13 +++------- giambio/task.py | 7 ++++-- giambio/traps.py | 4 +-- tests/nested_exception.py | 6 ++--- tests/nested_pool.py | 1 + 7 files changed, 52 insertions(+), 38 deletions(-) diff --git a/giambio/context.py b/giambio/context.py index bd66fc5..53fbe14 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -18,7 +18,7 @@ limitations under the License. import types import giambio -from typing import List +from typing import List, Optional class TaskManager: @@ -49,6 +49,9 @@ class TaskManager: # Whether our timeout expired or not self.timed_out: bool = False self._proper_init = False + self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool + self.enclosed_pool: Optional["giambio.context.TaskManager"] = None + # giambio.get_event_loop().current_pool = self async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task": """ @@ -56,7 +59,7 @@ class TaskManager: """ assert self._proper_init, "Cannot use improperly initialized pool" - return await giambio.traps.create_task(func, *args) + return await giambio.traps.create_task(func, self, *args) async def __aenter__(self): """ diff --git a/giambio/core.py b/giambio/core.py index 424a99c..9067b75 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -236,7 +236,7 @@ class AsyncScheduler: self.current_task.exc = err self.join(self.current_task) - def create_task(self, corofunc: types.FunctionType, *args, **kwargs) -> Task: + def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task: """ Creates a task from a coroutine function and schedules it to run. Any extra keyword or positional argument are then @@ -247,15 +247,18 @@ class AsyncScheduler: :type corofunc: function """ - task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), self.current_pool) - task.next_deadline = self.current_pool.timeout or 0.0 + task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool) + task.next_deadline = pool.timeout or 0.0 task.joiners = {self.current_task} self._data = task self.tasks.append(task) self.run_ready.append(task) self.debugger.on_task_spawn(task) - self.current_pool.tasks.append(task) + pool.tasks.append(task) self.reschedule_running() + if self.current_pool and task.pool is not self.current_pool: + self.current_pool.enclosed_pool = task.pool + self.current_pool = task.pool return task def run_task_step(self): @@ -270,8 +273,8 @@ class AsyncScheduler: account, that's self.run's job! """ - # Sets the currently running task data = None + # Sets the currently running task self.current_task = self.run_ready.pop(0) self.debugger.before_task_step(self.current_task) if self.current_task.done(): @@ -293,8 +296,7 @@ class AsyncScheduler: # Some debugging and internal chatter here self.current_task.status = "run" self.current_task.steps += 1 - self.debugger.after_task_step(self.current_task) - if not hasattr(self, method): + if not hasattr(self, method) and not callable(getattr(self, method)): # If this happens, that's quite bad! # This if block is meant to be triggered by other async # libraries, which most likely have different trap names and behaviors @@ -305,6 +307,7 @@ class AsyncScheduler: ) from None # Sneaky method call, thanks to David Beazley for this ;) getattr(self, method)(*args) + self.debugger.after_task_step(self.current_task) def io_release_task(self, task: Task): """ @@ -518,7 +521,10 @@ class AsyncScheduler: # current pool. If, however, there are still some # tasks running, we wait for them to exit in order # to avoid orphaned tasks - return pool.done() + if pool.enclosed_pool and self.cancel_pool(pool.enclosed_pool): + return True + else: + return pool.done() else: # If we're at the main task, we're sure everything else exited return True @@ -582,6 +588,8 @@ class AsyncScheduler: """ task.joined = True + if task is not self.current_task: + task.joiners.add(self.current_task) if task.finished or task.cancelled: if not task.cancelled: self.debugger.on_task_exit(task) @@ -589,33 +597,34 @@ class AsyncScheduler: self.io_release_task(task) if task.pool is None: return - if self.current_pool and self.current_pool.done(): - # If the current pool has finished executing or we're at the first parent + if task.pool.done(): + # If the pool has finished executing or we're at the first parent # task that kicked the loop, we can safely reschedule the parent(s) self.reschedule_joiners(task) elif task.exc: task.status = "crashed" - # TODO: We might want to do a bit more complex traceback hacking to remove any extra - # frames from the exception call stack, but for now removing at least the first one - # seems a sensible approach (it's us catching it so we don't care about that) - task.exc.__traceback__ = task.exc.__traceback__.tb_next + if task.exc.__traceback__: + # TODO: We might want to do a bit more complex traceback hacking to remove any extra + # frames from the exception call stack, but for now removing at least the first one + # seems a sensible approach (it's us catching it so we don't care about that) + task.exc.__traceback__ = task.exc.__traceback__.tb_next if task.last_io: self.io_release_task(task) self.debugger.on_exception_raised(task, task.exc) if task.pool is None: # Parent task has no pool, so we propagate raise - if self.cancel_pool(self.current_pool): + if self.cancel_pool(task.pool): # This will reschedule the parent(s) - # only if all the tasks inside the current + # only if all the tasks inside the task's # pool have finished executing, either # by cancellation, an exception # or just returned - for t in task.joiners: + for t in task.joiners.copy(): # Propagate the exception try: t.throw(task.exc) - except (StopIteration, CancelledError): + except (StopIteration, CancelledError, RuntimeError): # TODO: Need anything else? task.joiners.remove(t) self.reschedule_joiners(task) @@ -657,8 +666,7 @@ class AsyncScheduler: # or dangling resource open after being cancelled, so maybe we need # a different approach altogether if task.status == "io": - for k in filter(lambda o: o.data == task, dict(self.selector.get_map()).values()): - self.selector.unregister(k.fileobj) + self.io_release_task(task) elif task.status == "sleep": self.paused.discard(task) try: @@ -679,6 +687,10 @@ class AsyncScheduler: task.status = "cancelled" self.io_release_task(self.current_task) self.debugger.after_cancel(task) + else: + # If the task ignores our exception, we'll + # raise it later again + task.cancel_pending = True else: # If we can't cancel in a somewhat "graceful" way, we just # defer this operation for later (check run() for more info) diff --git a/giambio/run.py b/giambio/run.py index 0b070ba..2d651ae 100644 --- a/giambio/run.py +++ b/giambio/run.py @@ -42,7 +42,8 @@ def get_event_loop(): def new_event_loop(debugger: BaseDebugger, clock: FunctionType): - """ + """ print(hex(id(pool))) + Associates a new event loop to the current thread and deactivates the old one. This should not be called explicitly unless you know what you're doing. @@ -92,10 +93,7 @@ def create_pool(): Creates an async pool """ - loop = get_event_loop() - pool = TaskManager() - loop.current_pool = pool - return pool + return TaskManager() def with_timeout(timeout: int or float): @@ -103,10 +101,7 @@ def with_timeout(timeout: int or float): Creates an async pool with an associated timeout """ - loop = get_event_loop() # We add 1 to make the timeout intuitive and inclusive (i.e. # a 10 seconds timeout means the task is allowed to run 10 # whole seconds instead of cancelling at the tenth second) - pool = TaskManager(timeout + 1) - loop.current_pool = pool - return pool + return TaskManager(timeout + 1) diff --git a/giambio/task.py b/giambio/task.py index e15c5bf..804c802 100644 --- a/giambio/task.py +++ b/giambio/task.py @@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,7 @@ limitations under the License. """ import giambio +import warnings from dataclasses import dataclass, field from typing import Union, Coroutine, Set @@ -95,6 +96,7 @@ class Task: :type err: Exception """ + # self.exc = err return self.coroutine.throw(err) async def join(self): @@ -143,4 +145,5 @@ class Task: self.coroutine.close() except RuntimeError: pass # TODO: This is kinda bad - assert not self.last_io, f"task {self.name} was destroyed, but has pending I/O" + if self.last_io: + warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O") diff --git a/giambio/traps.py b/giambio/traps.py index c73d574..d1d7a6a 100644 --- a/giambio/traps.py +++ b/giambio/traps.py @@ -41,7 +41,7 @@ def create_trap(method, *args): return data -async def create_task(coro: FunctionType, *args): +async def create_task(coro: FunctionType, pool, *args): """ Spawns a new task in the current event loop from a bare coroutine function. All extra positional arguments are passed to the function @@ -56,7 +56,7 @@ async def create_task(coro: FunctionType, *args): "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)" ) elif inspect.iscoroutinefunction(coro): - return await create_trap("create_task", coro, *args) + return await create_trap("create_task", coro, pool, *args) else: raise TypeError("coro must be a coroutine function") diff --git a/tests/nested_exception.py b/tests/nested_exception.py index 005dbc1..4f1f367 100644 --- a/tests/nested_exception.py +++ b/tests/nested_exception.py @@ -5,7 +5,7 @@ from debugger import Debugger async def child(): print("[child] Child spawned!! Sleeping for 2 seconds") await giambio.sleep(2) - print("[child] Had a nice nap!") + print("[child] Had a nice nap, suiciding now!") raise TypeError("rip") # Watch the exception magically propagate! @@ -33,13 +33,13 @@ async def main(): async with giambio.create_pool() as pool: await pool.spawn(child) await pool.spawn(child1) - print("[main] Children spawned, awaiting completion") + print("[main] First 2 children spawned, awaiting completion") async with giambio.create_pool() as new_pool: # This pool will be cancelled by the exception # in the other pool await new_pool.spawn(child2) await new_pool.spawn(child3) - print("[main] 3rd child spawned") + print("[main] Third and fourth children spawned") except Exception as error: # Because exceptions just *work*! print(f"[main] Exception from child caught! {repr(error)}") diff --git a/tests/nested_pool.py b/tests/nested_pool.py index 023f418..428e980 100644 --- a/tests/nested_pool.py +++ b/tests/nested_pool.py @@ -16,6 +16,7 @@ async def main(): async with giambio.create_pool() as a_pool: await a_pool.spawn(child, 3) await a_pool.spawn(child, 4) + # This executes after spawning all 4 tasks print("[main] Children spawned, awaiting completion") # This will *only* execute when everything inside the async with block # has ran, including any other pool