diff --git a/giambio/core.py b/giambio/core.py index 237c76b..2ceb3f8 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -437,7 +437,10 @@ class AsyncScheduler: while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): pool = self.deadlines.get() pool.timed_out = True - if not pool.tasks and self.current_task is self.entry_point: + if self.current_task is self.entry_point: + self.paused.discard(self.current_task) + self.io_release_task(self.current_task) + self.reschedule_running() self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point))) for task in pool.tasks: if not task.done(): @@ -641,14 +644,16 @@ class AsyncScheduler: task.joined = True if task.finished or task.cancelled: if not task.cancelled: + # This way join() returns the + # task's return value + for joiner in task.joiners: + self._data[joiner] = task.result self.debugger.on_task_exit(task) if task.last_io: self.io_release_task(task) # If the pool has finished executing or we're at the first parent # task that kicked the loop, we can safely reschedule the parent(s) - if task.pool is None: - return - if task.pool.done(): + if not task.pool or task.pool.done(): self.reschedule_joiners(task) elif task.exc: task.status = "crashed" diff --git a/tests/timeout3.py b/tests/timeout3.py new file mode 100644 index 0000000..a8496d6 --- /dev/null +++ b/tests/timeout3.py @@ -0,0 +1,26 @@ +import giambio +from debugger import Debugger + + +async def child(name: int): + print(f"[child {name}] Child spawned!! Sleeping for {name} seconds") + await giambio.sleep(name) + print(f"[child {name}] Had a nice nap!") + return name + + +async def main(): + start = giambio.clock() + try: + async with giambio.with_timeout(5) as pool: + task = await pool.spawn(child, 2) + print(await task.join()) + await giambio.sleep(5) + except giambio.exceptions.TooSlowError: + print("[main] One or more children have timed out!") + print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + return 12 + + +if __name__ == "__main__": + giambio.run(main, debugger=())