Fixed bugs, added shielded task scopes and related test

This commit is contained in:
Mattia Giambirtone 2023-05-17 00:27:24 +02:00 committed by nocturn9x
parent 52daf54ee3
commit 242d4818bb
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
7 changed files with 57 additions and 29 deletions

View File

@ -102,5 +102,9 @@ __all__ = ["run",
"ChannelReader", "ChannelReader",
"ChannelWriter", "ChannelWriter",
"Semaphore", "Semaphore",
"TimedOut" "TimedOut",
"Task",
"TaskState",
"TaskScope",
"TaskPool"
] ]

View File

@ -10,21 +10,23 @@ class TaskScope:
A task scope A task scope
""" """
def __init__(self, timeout: int | float | None = None, silent: bool = False): def __init__(self, timeout: int | float | None = None, silent: bool = False, shielded: bool = False):
""" """
Public object constructor Public object constructor
""" """
# When do we expire? # When do we expire?
self.timeout = timeout self.timeout = timeout or float("inf")
# Do we raise an error on timeout? # Do we raise an error on timeout?
self.silent = silent self.silent = silent
# Have we timed out? # Has a cancellation attempt been done?
self.timed_out: bool = False self.attempted_cancel: bool = False
# Have we been cancelled?
self.cancelled: bool = False
# Can we be indirectly cancelled? Note that this # Can we be indirectly cancelled? Note that this
# does not affect explicit cancellations via the # does not affect explicit cancellations via the
# cancel() method # cancel() method
self.cancellable: bool = True self.shielded: bool = shielded
# Data about inner and outer scopes. # Data about inner and outer scopes.
# This is used internally to make sure # This is used internally to make sure
# nesting task scopes works as expected # nesting task scopes works as expected
@ -40,6 +42,7 @@ class TaskScope:
that belongs to it that belongs to it
""" """
self.attempted_cancel = True
current_loop().cancel_scope(self) current_loop().cancel_scope(self)
def get_actual_timeout(self): def get_actual_timeout(self):
@ -54,6 +57,11 @@ class TaskScope:
if self.outer is None: if self.outer is None:
return self.timeout return self.timeout
current = self.inner
while current:
if current.shielded:
return float("inf")
current = current.inner
return min([self.timeout, self.outer.get_actual_timeout()]) return min([self.timeout, self.outer.get_actual_timeout()])
def __enter__(self): def __enter__(self):
@ -63,7 +71,8 @@ class TaskScope:
def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb): def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb):
current_loop().close_scope(self) current_loop().close_scope(self)
if isinstance(exc_val, structio.TimedOut) and exc_val.scope is self: if isinstance(exc_val, structio.TimedOut):
self.cancelled = True
return self.silent return self.silent
return False return False
@ -103,13 +112,12 @@ class TaskPool:
return self return self
async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb): async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb):
if exc_val:
self.error = exc_val
self.scope.cancel()
await checkpoint()
return
try: try:
await suspend() if exc_val:
await checkpoint()
raise exc_val
else:
await suspend()
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
self.error = e self.error = e
self.scope.cancel() self.scope.cancel()
@ -117,7 +125,6 @@ class TaskPool:
current_loop().close_pool(self) current_loop().close_pool(self)
self.scope.__exit__(exc_type, exc_val, exc_tb) self.scope.__exit__(exc_type, exc_val, exc_tb)
if self.error: if self.error:
self.outer.scope.cancel()
raise self.error raise self.error
def done(self): def done(self):

View File

@ -38,7 +38,7 @@ class FIFOKernel(BaseKernel):
self.pool = TaskPool() self.pool = TaskPool()
self.current_pool = self.pool self.current_pool = self.pool
self.current_scope = self.current_pool.scope self.current_scope = self.current_pool.scope
self.current_scope.cancellable = False self.current_scope.shielded = False
self.scopes.append(self.current_scope) self.scopes.append(self.current_scope)
def get_closest_deadline(self): def get_closest_deadline(self):
@ -168,11 +168,7 @@ class FIFOKernel(BaseKernel):
def check_scopes(self): def check_scopes(self):
for scope in self.scopes: for scope in self.scopes:
if scope.timed_out:
continue
if scope.get_actual_timeout() <= self.clock.current_time(): if scope.get_actual_timeout() <= self.clock.current_time():
scope.timed_out = True
scope.cancel()
error = TimedOut("timed out") error = TimedOut("timed out")
error.scope = scope error.scope = scope
self.throw(scope.owner, error) self.throw(scope.owner, error)
@ -301,7 +297,8 @@ class FIFOKernel(BaseKernel):
for waiter in task.waiters: for waiter in task.waiters:
self.reschedule(waiter) self.reschedule(waiter)
task.waiters.clear() task.waiters.clear()
if task.pool.done() and task is not self.entry_point: if task.pool.done():
task.pool.scope.cancelled = True
self.reschedule(task.pool.entry_point) self.reschedule(task.pool.entry_point)
self.release(task) self.release(task)
@ -327,15 +324,11 @@ class FIFOKernel(BaseKernel):
task.pending_cancellation = True task.pending_cancellation = True
def cancel_scope(self, scope: TaskScope): def cancel_scope(self, scope: TaskScope):
if scope.done():
return
inner = scope.inner inner = scope.inner
if inner and inner.cancellable and inner is not self.pool.scope: if inner and not inner.shielded:
scope.inner.cancel() scope.inner.cancel()
self.reschedule(inner.owner)
for task in scope.tasks: for task in scope.tasks:
self.cancel_task(task) self.cancel_task(task)
self.reschedule(scope.owner)
def init_pool(self, pool: TaskPool): def init_pool(self, pool: TaskPool):
pool.outer = self.current_pool pool.outer = self.current_pool

View File

@ -63,7 +63,7 @@ async def main_nested(
if __name__ == "__main__": if __name__ == "__main__":
structio.run( structio.run(
main_nested, main,
[("first", 1), ("third", 3)], [("first", 1), ("third", 3)],
[("second", 2), ("fourth", 4)], [("second", 2), ("fourth", 4)],
) )
@ -72,3 +72,5 @@ if __name__ == "__main__":
[("first", 1), ("third", 3)], [("first", 1), ("third", 3)],
[("second", 2), ("fourth", 4)], [("second", 2), ("fourth", 4)],
) )

View File

@ -40,7 +40,7 @@ async def main_nested(
async with structio.create_pool() as p3: async with structio.create_pool() as p3:
print(f"[main] Spawning children in third context ({hex(id(p3))})") print(f"[main] Spawning children in third context ({hex(id(p3))})")
for name, delay in children_inner: for name, delay in children_inner:
p3.spawn(successful(), name, delay) p3.spawn(successful, name, delay)
print("[main] Children spawned") print("[main] Children spawned")
except TypeError: except TypeError:
print("[main] TypeError caught!") print("[main] TypeError caught!")

22
tests/shields.py Normal file
View File

@ -0,0 +1,22 @@
import structio
async def shielded(i):
print("[shielded] Entering shielded section")
with structio.TaskScope(shielded=True) as s:
await structio.sleep(i)
print(f"[shielded] Slept {i} seconds")
s.shielded = False
print(f"[shielded] Exited shielded section, sleeping {i} more seconds")
await structio.sleep(i)
async def main(i):
print(f"[main] Parent has started, finishing in {i} seconds")
t = structio.clock()
with structio.skip_after(1):
await shielded(i)
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
structio.run(main, 5)

View File

@ -7,14 +7,14 @@ async def test_silent(i, j):
with structio.skip_after(i) as scope: with structio.skip_after(i) as scope:
print(f"[test] Sleeping for {j} seconds") print(f"[test] Sleeping for {j} seconds")
await structio.sleep(j) await structio.sleep(j)
print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.timed_out})") print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.cancelled})")
async def test_loud(i, j): async def test_loud(i, j):
print(f"[test] Parent is alive, exiting after {i:.2f} seconds") print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
k = structio.clock() k = structio.clock()
try: try:
with structio.with_timeout(i) as scope: with structio.with_timeout(i):
print(f"[test] Sleeping for {j} seconds") print(f"[test] Sleeping for {j} seconds")
await structio.sleep(j) await structio.sleep(j)
except structio.TimedOut: except structio.TimedOut: