diff --git a/structio/__init__.py b/structio/__init__.py index fd26e11..3392059 100644 --- a/structio/__init__.py +++ b/structio/__init__.py @@ -92,6 +92,8 @@ def clock(): async def _join(self: Task): + if self.done(): + return self.result self.waiters.add(_run.current_task()) await _suspend() if self.state == TaskState.CRASHED: diff --git a/structio/core/kernels/fifo.py b/structio/core/kernels/fifo.py index 925e3a3..aa13f67 100644 --- a/structio/core/kernels/fifo.py +++ b/structio/core/kernels/fifo.py @@ -199,8 +199,6 @@ class FIFOKernel(BaseKernel): def throw(self, task: Task, err: BaseException): if task.done(): return - if task.pool.scope.shielded: - return if task.state == TaskState.PAUSED: self.paused.discard(task) elif task.state == TaskState.IO: @@ -372,7 +370,8 @@ class FIFOKernel(BaseKernel): self.event("on_exception_raised", task, task.exc) for waiter in task.waiters: self.reschedule(waiter) - self.throw(task.pool.scope.owner, task.exc) + if task.pool.scope.owner is not self.current_task: + self.throw(task.pool.scope.owner, task.exc) task.waiters.clear() self.release(task) @@ -403,10 +402,20 @@ class FIFOKernel(BaseKernel): def cancel_task(self, task: Task): if task.done(): return + if task.state == TaskState.RUNNING: + # Can't cancel a task while it's + # running, will raise ValueError + # if we try. We defer it for later + task.pending_cancellation = True + return err = Cancelled() err.scope = task.pool.scope self.throw(task, err) if task.state != TaskState.CANCELLED: + # Task is stubborn. But so are we, + # so we'll redeliver the cancellation + # every time said task tries to call any + # event loop primitive task.pending_cancellation = True def cancel_scope(self, scope: TaskScope): @@ -416,18 +425,20 @@ class FIFOKernel(BaseKernel): # called synchronously by TaskScope.cancel(), # so there is nowhere to throw an exception # to - if self.current_task in scope.tasks: + if self.current_task in scope.tasks and self.current_task is not scope.owner: self.current_task.pending_cancellation = True inner = scope.inner if inner and not inner.shielded: self.cancel_scope(inner) - for task in scope.tasks.copy(): + for task in scope.tasks: if task is self.current_task: continue - # We make a copy of the list because we - # need to make sure that tasks aren't - # removed out from under us self.cancel_task(task) + if scope is not self.current_task.pool.scope and scope.owner is not self.current_task: + # Handles the case where the current task calls + # cancel() for a scope which it doesn't own, which + # is an entirely reasonable thing to do + self.cancel_task(scope.owner) def init_pool(self, pool: TaskPool): pool.outer = self.current_pool diff --git a/tests/shields.py b/tests/shields.py index dc2aba7..13baae9 100644 --- a/tests/shields.py +++ b/tests/shields.py @@ -14,7 +14,7 @@ async def shielded(i): async def main(i): print(f"[main] Parent has started, finishing in {i} seconds") t = structio.clock() - with structio.skip_after(1): + with structio.skip_after(i): await shielded(i) print(f"[main] Exited in {structio.clock() - t:.2f} seconds") @@ -22,7 +22,6 @@ async def main(i): async def canceller(s, i): print("[canceller] Entering shielded section") with s: - s.shielded = True await structio.sleep(i) @@ -30,9 +29,10 @@ async def main_cancel(i, j): print(f"[main] Parent has started, finishing in {j} seconds") t = structio.clock() async with structio.create_pool() as p: - s = structio.TaskScope() - p.spawn(canceller, s, i) + s = structio.TaskScope(shielded=True) + task = p.spawn(canceller, s, i) await structio.sleep(j) + assert not task.done() print("[main] Canceling scope") # Shields only protect from indirect cancellations # coming from outer scopes: they are still cancellable @@ -41,5 +41,5 @@ async def main_cancel(i, j): print(f"[main] Exited in {structio.clock() - t:.2f} seconds") -structio.run(main, 5) -structio.run(main_cancel, 5, 3) +structio.run(main, 2) +structio.run(main_cancel, 5, 2)