Fixed bugs, added shielded task scopes and related test
This commit is contained in:
parent
52daf54ee3
commit
242d4818bb
|
@ -102,5 +102,9 @@ __all__ = ["run",
|
||||||
"ChannelReader",
|
"ChannelReader",
|
||||||
"ChannelWriter",
|
"ChannelWriter",
|
||||||
"Semaphore",
|
"Semaphore",
|
||||||
"TimedOut"
|
"TimedOut",
|
||||||
|
"Task",
|
||||||
|
"TaskState",
|
||||||
|
"TaskScope",
|
||||||
|
"TaskPool"
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,21 +10,23 @@ class TaskScope:
|
||||||
A task scope
|
A task scope
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, timeout: int | float | None = None, silent: bool = False):
|
def __init__(self, timeout: int | float | None = None, silent: bool = False, shielded: bool = False):
|
||||||
"""
|
"""
|
||||||
Public object constructor
|
Public object constructor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# When do we expire?
|
# When do we expire?
|
||||||
self.timeout = timeout
|
self.timeout = timeout or float("inf")
|
||||||
# Do we raise an error on timeout?
|
# Do we raise an error on timeout?
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
# Have we timed out?
|
# Has a cancellation attempt been done?
|
||||||
self.timed_out: bool = False
|
self.attempted_cancel: bool = False
|
||||||
|
# Have we been cancelled?
|
||||||
|
self.cancelled: bool = False
|
||||||
# Can we be indirectly cancelled? Note that this
|
# Can we be indirectly cancelled? Note that this
|
||||||
# does not affect explicit cancellations via the
|
# does not affect explicit cancellations via the
|
||||||
# cancel() method
|
# cancel() method
|
||||||
self.cancellable: bool = True
|
self.shielded: bool = shielded
|
||||||
# Data about inner and outer scopes.
|
# Data about inner and outer scopes.
|
||||||
# This is used internally to make sure
|
# This is used internally to make sure
|
||||||
# nesting task scopes works as expected
|
# nesting task scopes works as expected
|
||||||
|
@ -40,6 +42,7 @@ class TaskScope:
|
||||||
that belongs to it
|
that belongs to it
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self.attempted_cancel = True
|
||||||
current_loop().cancel_scope(self)
|
current_loop().cancel_scope(self)
|
||||||
|
|
||||||
def get_actual_timeout(self):
|
def get_actual_timeout(self):
|
||||||
|
@ -54,6 +57,11 @@ class TaskScope:
|
||||||
|
|
||||||
if self.outer is None:
|
if self.outer is None:
|
||||||
return self.timeout
|
return self.timeout
|
||||||
|
current = self.inner
|
||||||
|
while current:
|
||||||
|
if current.shielded:
|
||||||
|
return float("inf")
|
||||||
|
current = current.inner
|
||||||
return min([self.timeout, self.outer.get_actual_timeout()])
|
return min([self.timeout, self.outer.get_actual_timeout()])
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
@ -63,7 +71,8 @@ class TaskScope:
|
||||||
|
|
||||||
def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
||||||
current_loop().close_scope(self)
|
current_loop().close_scope(self)
|
||||||
if isinstance(exc_val, structio.TimedOut) and exc_val.scope is self:
|
if isinstance(exc_val, structio.TimedOut):
|
||||||
|
self.cancelled = True
|
||||||
return self.silent
|
return self.silent
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -103,13 +112,12 @@ class TaskPool:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
||||||
if exc_val:
|
|
||||||
self.error = exc_val
|
|
||||||
self.scope.cancel()
|
|
||||||
await checkpoint()
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
await suspend()
|
if exc_val:
|
||||||
|
await checkpoint()
|
||||||
|
raise exc_val
|
||||||
|
else:
|
||||||
|
await suspend()
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
self.error = e
|
self.error = e
|
||||||
self.scope.cancel()
|
self.scope.cancel()
|
||||||
|
@ -117,7 +125,6 @@ class TaskPool:
|
||||||
current_loop().close_pool(self)
|
current_loop().close_pool(self)
|
||||||
self.scope.__exit__(exc_type, exc_val, exc_tb)
|
self.scope.__exit__(exc_type, exc_val, exc_tb)
|
||||||
if self.error:
|
if self.error:
|
||||||
self.outer.scope.cancel()
|
|
||||||
raise self.error
|
raise self.error
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
|
|
|
@ -38,7 +38,7 @@ class FIFOKernel(BaseKernel):
|
||||||
self.pool = TaskPool()
|
self.pool = TaskPool()
|
||||||
self.current_pool = self.pool
|
self.current_pool = self.pool
|
||||||
self.current_scope = self.current_pool.scope
|
self.current_scope = self.current_pool.scope
|
||||||
self.current_scope.cancellable = False
|
self.current_scope.shielded = False
|
||||||
self.scopes.append(self.current_scope)
|
self.scopes.append(self.current_scope)
|
||||||
|
|
||||||
def get_closest_deadline(self):
|
def get_closest_deadline(self):
|
||||||
|
@ -168,11 +168,7 @@ class FIFOKernel(BaseKernel):
|
||||||
|
|
||||||
def check_scopes(self):
|
def check_scopes(self):
|
||||||
for scope in self.scopes:
|
for scope in self.scopes:
|
||||||
if scope.timed_out:
|
|
||||||
continue
|
|
||||||
if scope.get_actual_timeout() <= self.clock.current_time():
|
if scope.get_actual_timeout() <= self.clock.current_time():
|
||||||
scope.timed_out = True
|
|
||||||
scope.cancel()
|
|
||||||
error = TimedOut("timed out")
|
error = TimedOut("timed out")
|
||||||
error.scope = scope
|
error.scope = scope
|
||||||
self.throw(scope.owner, error)
|
self.throw(scope.owner, error)
|
||||||
|
@ -301,7 +297,8 @@ class FIFOKernel(BaseKernel):
|
||||||
for waiter in task.waiters:
|
for waiter in task.waiters:
|
||||||
self.reschedule(waiter)
|
self.reschedule(waiter)
|
||||||
task.waiters.clear()
|
task.waiters.clear()
|
||||||
if task.pool.done() and task is not self.entry_point:
|
if task.pool.done():
|
||||||
|
task.pool.scope.cancelled = True
|
||||||
self.reschedule(task.pool.entry_point)
|
self.reschedule(task.pool.entry_point)
|
||||||
self.release(task)
|
self.release(task)
|
||||||
|
|
||||||
|
@ -327,15 +324,11 @@ class FIFOKernel(BaseKernel):
|
||||||
task.pending_cancellation = True
|
task.pending_cancellation = True
|
||||||
|
|
||||||
def cancel_scope(self, scope: TaskScope):
|
def cancel_scope(self, scope: TaskScope):
|
||||||
if scope.done():
|
|
||||||
return
|
|
||||||
inner = scope.inner
|
inner = scope.inner
|
||||||
if inner and inner.cancellable and inner is not self.pool.scope:
|
if inner and not inner.shielded:
|
||||||
scope.inner.cancel()
|
scope.inner.cancel()
|
||||||
self.reschedule(inner.owner)
|
|
||||||
for task in scope.tasks:
|
for task in scope.tasks:
|
||||||
self.cancel_task(task)
|
self.cancel_task(task)
|
||||||
self.reschedule(scope.owner)
|
|
||||||
|
|
||||||
def init_pool(self, pool: TaskPool):
|
def init_pool(self, pool: TaskPool):
|
||||||
pool.outer = self.current_pool
|
pool.outer = self.current_pool
|
||||||
|
|
|
@ -63,7 +63,7 @@ async def main_nested(
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
structio.run(
|
structio.run(
|
||||||
main_nested,
|
main,
|
||||||
[("first", 1), ("third", 3)],
|
[("first", 1), ("third", 3)],
|
||||||
[("second", 2), ("fourth", 4)],
|
[("second", 2), ("fourth", 4)],
|
||||||
)
|
)
|
||||||
|
@ -72,3 +72,5 @@ if __name__ == "__main__":
|
||||||
[("first", 1), ("third", 3)],
|
[("first", 1), ("third", 3)],
|
||||||
[("second", 2), ("fourth", 4)],
|
[("second", 2), ("fourth", 4)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ async def main_nested(
|
||||||
async with structio.create_pool() as p3:
|
async with structio.create_pool() as p3:
|
||||||
print(f"[main] Spawning children in third context ({hex(id(p3))})")
|
print(f"[main] Spawning children in third context ({hex(id(p3))})")
|
||||||
for name, delay in children_inner:
|
for name, delay in children_inner:
|
||||||
p3.spawn(successful(), name, delay)
|
p3.spawn(successful, name, delay)
|
||||||
print("[main] Children spawned")
|
print("[main] Children spawned")
|
||||||
except TypeError:
|
except TypeError:
|
||||||
print("[main] TypeError caught!")
|
print("[main] TypeError caught!")
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
import structio
|
||||||
|
|
||||||
|
|
||||||
|
async def shielded(i):
|
||||||
|
print("[shielded] Entering shielded section")
|
||||||
|
with structio.TaskScope(shielded=True) as s:
|
||||||
|
await structio.sleep(i)
|
||||||
|
print(f"[shielded] Slept {i} seconds")
|
||||||
|
s.shielded = False
|
||||||
|
print(f"[shielded] Exited shielded section, sleeping {i} more seconds")
|
||||||
|
await structio.sleep(i)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(i):
|
||||||
|
print(f"[main] Parent has started, finishing in {i} seconds")
|
||||||
|
t = structio.clock()
|
||||||
|
with structio.skip_after(1):
|
||||||
|
await shielded(i)
|
||||||
|
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
structio.run(main, 5)
|
|
@ -7,14 +7,14 @@ async def test_silent(i, j):
|
||||||
with structio.skip_after(i) as scope:
|
with structio.skip_after(i) as scope:
|
||||||
print(f"[test] Sleeping for {j} seconds")
|
print(f"[test] Sleeping for {j} seconds")
|
||||||
await structio.sleep(j)
|
await structio.sleep(j)
|
||||||
print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.timed_out})")
|
print(f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.cancelled})")
|
||||||
|
|
||||||
|
|
||||||
async def test_loud(i, j):
|
async def test_loud(i, j):
|
||||||
print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
|
print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
|
||||||
k = structio.clock()
|
k = structio.clock()
|
||||||
try:
|
try:
|
||||||
with structio.with_timeout(i) as scope:
|
with structio.with_timeout(i):
|
||||||
print(f"[test] Sleeping for {j} seconds")
|
print(f"[test] Sleeping for {j} seconds")
|
||||||
await structio.sleep(j)
|
await structio.sleep(j)
|
||||||
except structio.TimedOut:
|
except structio.TimedOut:
|
||||||
|
|
Loading…
Reference in New Issue