Fixed bugs, added shielded task scopes and related test

This commit is contained in:
Mattia Giambirtone 2023-05-17 00:27:24 +02:00 committed by nocturn9x
parent 52daf54ee3
commit 242d4818bb
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
7 changed files with 57 additions and 29 deletions

View File

@ -102,5 +102,9 @@ __all__ = ["run",
"ChannelReader",
"ChannelWriter",
"Semaphore",
"TimedOut"
"TimedOut",
"Task",
"TaskState",
"TaskScope",
"TaskPool"
]

View File

@ -10,21 +10,23 @@ class TaskScope:
A task scope
"""
def __init__(self, timeout: int | float | None = None, silent: bool = False):
def __init__(self, timeout: int | float | None = None, silent: bool = False, shielded: bool = False):
"""
Public object constructor
"""
# When do we expire?
self.timeout = timeout
self.timeout = timeout or float("inf")
# Do we raise an error on timeout?
self.silent = silent
# Have we timed out?
self.timed_out: bool = False
# Has a cancellation attempt been done?
self.attempted_cancel: bool = False
# Have we been cancelled?
self.cancelled: bool = False
# Can we be indirectly cancelled? Note that this
# does not affect explicit cancellations via the
# cancel() method
self.cancellable: bool = True
self.shielded: bool = shielded
# Data about inner and outer scopes.
# This is used internally to make sure
# nesting task scopes works as expected
@ -40,6 +42,7 @@ class TaskScope:
that belongs to it
"""
self.attempted_cancel = True
current_loop().cancel_scope(self)
def get_actual_timeout(self):
@ -54,6 +57,11 @@ class TaskScope:
if self.outer is None:
return self.timeout
current = self.inner
while current:
if current.shielded:
return float("inf")
current = current.inner
return min([self.timeout, self.outer.get_actual_timeout()])
def __enter__(self):
@ -63,7 +71,8 @@ class TaskScope:
def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb):
current_loop().close_scope(self)
if isinstance(exc_val, structio.TimedOut) and exc_val.scope is self:
if isinstance(exc_val, structio.TimedOut):
self.cancelled = True
return self.silent
return False
@ -103,13 +112,12 @@ class TaskPool:
return self
async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb):
if exc_val:
self.error = exc_val
self.scope.cancel()
await checkpoint()
return
try:
await suspend()
if exc_val:
await checkpoint()
raise exc_val
else:
await suspend()
except (Exception, KeyboardInterrupt) as e:
self.error = e
self.scope.cancel()
@ -117,7 +125,6 @@ class TaskPool:
current_loop().close_pool(self)
self.scope.__exit__(exc_type, exc_val, exc_tb)
if self.error:
self.outer.scope.cancel()
raise self.error
def done(self):

View File

@ -38,7 +38,7 @@ class FIFOKernel(BaseKernel):
self.pool = TaskPool()
self.current_pool = self.pool
self.current_scope = self.current_pool.scope
self.current_scope.cancellable = False
self.current_scope.shielded = False
self.scopes.append(self.current_scope)
def get_closest_deadline(self):
@ -168,11 +168,7 @@ class FIFOKernel(BaseKernel):
def check_scopes(self):
for scope in self.scopes:
if scope.timed_out:
continue
if scope.get_actual_timeout() <= self.clock.current_time():
scope.timed_out = True
scope.cancel()
error = TimedOut("timed out")
error.scope = scope
self.throw(scope.owner, error)
@ -301,7 +297,8 @@ class FIFOKernel(BaseKernel):
for waiter in task.waiters:
self.reschedule(waiter)
task.waiters.clear()
if task.pool.done() and task is not self.entry_point:
if task.pool.done():
task.pool.scope.cancelled = True
self.reschedule(task.pool.entry_point)
self.release(task)
@ -327,15 +324,11 @@ class FIFOKernel(BaseKernel):
task.pending_cancellation = True
def cancel_scope(self, scope: TaskScope):
if scope.done():
return
inner = scope.inner
if inner and inner.cancellable and inner is not self.pool.scope:
if inner and not inner.shielded:
scope.inner.cancel()
self.reschedule(inner.owner)
for task in scope.tasks:
self.cancel_task(task)
self.reschedule(scope.owner)
def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool

View File

@ -63,7 +63,7 @@ async def main_nested(
if __name__ == "__main__":
structio.run(
main_nested,
main,
[("first", 1), ("third", 3)],
[("second", 2), ("fourth", 4)],
)
@ -72,3 +72,5 @@ if __name__ == "__main__":
[("first", 1), ("third", 3)],
[("second", 2), ("fourth", 4)],
)

View File

@ -40,7 +40,7 @@ async def main_nested(
async with structio.create_pool() as p3:
print(f"[main] Spawning children in third context ({hex(id(p3))})")
for name, delay in children_inner:
p3.spawn(successful(), name, delay)
p3.spawn(successful, name, delay)
print("[main] Children spawned")
except TypeError:
print("[main] TypeError caught!")

22
tests/shields.py Normal file
View File

@ -0,0 +1,22 @@
import structio
async def shielded(i):
print("[shielded] Entering shielded section")
with structio.TaskScope(shielded=True) as s:
await structio.sleep(i)
print(f"[shielded] Slept {i} seconds")
s.shielded = False
print(f"[shielded] Exited shielded section, sleeping {i} more seconds")
await structio.sleep(i)
async def main(i):
print(f"[main] Parent has started, finishing in {i} seconds")
t = structio.clock()
with structio.skip_after(1):
await shielded(i)
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
structio.run(main, 5)

View File

@ -7,14 +7,14 @@ async def test_silent(i, j):
with structio.skip_after(i) as scope:
print(f"[test] Sleeping for {j} seconds")
await structio.sleep(j)
print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.timed_out})")
print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.cancelled})")
async def test_loud(i, j):
print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
k = structio.clock()
try:
with structio.with_timeout(i) as scope:
with structio.with_timeout(i):
print(f"[test] Sleeping for {j} seconds")
await structio.sleep(j)
except structio.TimedOut: