Identified issue with task.cancel()

This commit is contained in:
nocturn9x 2020-11-14 12:59:58 +01:00
parent 7b4051f3b9
commit cc9eccf027
6 changed files with 73 additions and 66 deletions

View File

@ -26,6 +26,7 @@ from .socket import AsyncSocket, WantWrite, WantRead
from ._layers import Task, TimeQueue
from socket import SOL_SOCKET, SO_ERROR
from ._traps import want_read, want_write
import traceback, sys
class AsyncScheduler:
@ -45,13 +46,13 @@ class AsyncScheduler:
self.tasks = [] # Tasks that are ready to run
self.selector = DefaultSelector() # Selector object to perform I/O multiplexing
self.current_task = None # This will always point to the currently running coroutine (Task object)
self.catch = True
self.joined = (
{}
) # Maps child tasks that need to be joined their respective parent task
self.clock = (
default_timer # Monotonic clock to keep track of elapsed time reliably
)
self.some_cancel = False
self.paused = TimeQueue(self.clock) # Tasks that are asleep
self.events = set() # All Event objects
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects
@ -82,30 +83,37 @@ class AsyncScheduler:
self._check_events()
while self.tasks: # While there are tasks to run
self.current_task = self.tasks.pop(0)
if self.some_cancel:
self._check_cancel()
# Sets the currently running task
if self.current_task.status == "cancel": # Deferred cancellation
self.current_task.cancelled = True
self.current_task.throw(CancelledError(self.current_task))
method, *args = self.current_task.run() # Run a single step with the calculation
self.current_task.status = "run"
getattr(self, f"_{method}")(*args)
# Sneaky method call, thanks to David Beazley for this ;)
except CancelledError as cancelled:
if cancelled.args[0] in self.tasks:
self.tasks.remove(cancelled.args[0]) # Remove the dead task
self.tasks.append(self.current_task)
except CancelledError:
self.current_task.cancelled = True
self._reschedule_parent()
except StopIteration as e: # Coroutine ends
self.current_task.result = e.args[0] if e.args else None
self.current_task.finished = True
self._reschedule_parent()
except RuntimeError:
continue
except BaseException as error: # Coroutine raised
print(error)
self.current_task.exc = error
if self.catch:
self._reschedule_parent()
self._join(self.current_task)
else:
if not isinstance(error, RuntimeError):
raise
self._reschedule_parent()
self._join(self.current_task)
raise
def _check_cancel(self):
"""
Checks for task cancellation
"""
if self.current_task.status == "cancel": # Deferred cancellation
self.current_task.cancelled = True
self.current_task.throw(CancelledError(self.current_task))
def _check_events(self):
"""
@ -126,7 +134,7 @@ class AsyncScheduler:
wait(max(0.0, self.paused[0][0] - self.clock()))
# Sleep until the closest deadline in order not to waste CPU cycles
while self.paused[0][0] < self.clock():
# Reschedules tasks when their deadline has elapsed
# Reschedules tasks when their deadline has elapsed
self.tasks.append(self.paused.get())
if not self.paused:
break
@ -150,7 +158,7 @@ class AsyncScheduler:
entry = Task(func(*args))
self.tasks.append(entry)
self._join(entry)
self._join(entry) # TODO -> Inspect this line, does it actually do anything useful?
self._run()
return entry
@ -261,12 +269,9 @@ class AsyncScheduler:
are independent
"""
if task.status in ("sleep", "I/O") and not task.cancelled:
# It is safe to cancel a task while blocking
task.cancelled = True
task.throw(CancelledError(task))
elif task.status == "run":
task.status = "cancel" # Cancellation is deferred
if not self.some_cancel:
self.some_cancel = True
task.status = "cancel" # Cancellation is deferred
def wrap_socket(self, sock):
"""
@ -282,6 +287,10 @@ class AsyncScheduler:
"""
await want_read(sock)
try:
return sock.recv(buffer)
except WantRead:
await want_write(sock)
return sock.recv(buffer)
async def _accept_sock(self, sock: socket.socket):

View File

@ -55,12 +55,11 @@ class Task:
raise self.exc
return res
async def cancel(self):
"""Cancels the task"""
await cancel(self)
assert self.cancelled, "Task ignored cancellation"
# await join(self) # TODO -> Join ourselves after cancellation?
def __repr__(self):
"""Implements repr(self)"""

View File

@ -51,7 +51,6 @@ class TaskManager:
return task
async def __aenter__(self):
self.loop.catch = True # Restore event loop's status
return self
async def __aexit__(self, exc_type, exc, tb):
@ -59,12 +58,11 @@ class TaskManager:
try:
await task.join()
except BaseException as e:
for task in self.loop.tasks:
await task.cancel()
for _, __, task in self.loop.paused:
await task.cancel()
for tasks in self.loop.event_waiting.values():
for task in tasks:
await task.cancel()
self.loop.catch = False
for running_task in self.loop.tasks:
await running_task.cancel()
for _, __, asleep_task in self.loop.paused:
await asleep_task.cancel()
for waiting_tasks in self.loop.event_waiting.values():
for waiting_task in waiting_tasks:
await waiting_task.cancel()
raise e

View File

@ -24,7 +24,6 @@ from ._traps import sleep
try:
from ssl import SSLWantReadError, SSLWantWriteError
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
except ImportError:

View File

@ -15,17 +15,14 @@ async def countdown(n: int):
async def countup(stop: int, step: int = 1):
try:
x = 0
while x < stop:
print(f"Up {x}")
x += 1
await giambio.sleep(step)
print("Countup over")
return 1
except giambio.exceptions.CancelledError:
print("I'm not gonna die!!")
raise BaseException(2)
x = 0
while x < stop:
print(f"Up {x}")
x += 1
await giambio.sleep(step)
print("Countup over")
return 1
async def main():
try:
@ -33,9 +30,16 @@ async def main():
async with giambio.create_pool() as pool:
print("Starting counters")
pool.spawn(countdown, 10)
t = pool.spawn(countup, 5, 2)
await giambio.sleep(2)
await t.cancel()
count_up = pool.spawn(countup, 5, 2)
# raise Exception
# Raising an exception here has a weird
# Behavior: The exception is propagated
# *after* all the child tasks complete,
# which is not what we want
# print("Sleeping for 2 seconds before cancelling")
# await giambio.sleep(2)
# await count_up.cancel() # TODO: Cancel _is_ broken, this does not re-schedule the parent!
# print("Cancelled countup")
print("Task execution complete")
except Exception as e:
print(f"Caught this bad boy in here, propagating it -> {type(e).__name__}: {e}")
@ -46,6 +50,6 @@ if __name__ == "__main__":
print("Starting event loop")
try:
giambio.run(main)
except BaseException as e:
print(f"Exception caught from main event loop!! -> {type(e).__name__}: {e}")
except BaseException as error:
print(f"Exception caught from main event loop! -> {type(error).__name__}: {error}")
print("Event loop done")

View File

@ -7,29 +7,26 @@ import sys
# A test to check for asynchronous I/O
logging.basicConfig(
level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p"
)
async def server(address: tuple):
async def serve(address: tuple):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(address)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.listen(5)
asock = giambio.wrap_socket(sock) # We make the socket an async socket
logging.info(f"Echo server serving asynchronously at {address}")
logging.info(f"Serving asynchronously at {address[0]}:{address[1]}")
while True:
try:
async with giambio.async_pool() as pool:
async with giambio.create_pool() as pool:
conn, addr = await asock.accept()
logging.info(f"{addr} connected")
pool.spawn(echo_handler, conn, addr)
logging.info(f"{addr[0]}:{addr[1]} connected")
pool.spawn(handler, conn, addr)
except TypeError:
print("Looks like we have a naughty boy here!")
async def echo_handler(sock: AsyncSocket, addr: tuple):
async def handler(sock: AsyncSocket, addr: tuple):
addr = f"{addr[0]}:{addr[1]}"
async with sock:
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
while True:
@ -49,11 +46,12 @@ async def echo_handler(sock: AsyncSocket, addr: tuple):
if __name__ == "__main__":
if len(sys.argv) > 1:
port = int(sys.argv[1])
else:
port = 1500
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1500
logging.basicConfig(level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p")
try:
giambio.run(server, ("", port))
giambio.run(serve, ("localhost", port))
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
print(f"Exiting due to a {type(error).__name__}: '{error}'")
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")
else:
logging.error(f"Exiting due to a {type(error).__name__}: {error}")