mirror of https://github.com/nocturn9x/giambio.git
Identified issue with task.cancel()
This commit is contained in:
parent
7b4051f3b9
commit
cc9eccf027
|
@ -26,6 +26,7 @@ from .socket import AsyncSocket, WantWrite, WantRead
|
|||
from ._layers import Task, TimeQueue
|
||||
from socket import SOL_SOCKET, SO_ERROR
|
||||
from ._traps import want_read, want_write
|
||||
import traceback, sys
|
||||
|
||||
|
||||
class AsyncScheduler:
|
||||
|
@ -45,13 +46,13 @@ class AsyncScheduler:
|
|||
self.tasks = [] # Tasks that are ready to run
|
||||
self.selector = DefaultSelector() # Selector object to perform I/O multiplexing
|
||||
self.current_task = None # This will always point to the currently running coroutine (Task object)
|
||||
self.catch = True
|
||||
self.joined = (
|
||||
{}
|
||||
) # Maps child tasks that need to be joined their respective parent task
|
||||
self.clock = (
|
||||
default_timer # Monotonic clock to keep track of elapsed time reliably
|
||||
)
|
||||
self.some_cancel = False
|
||||
self.paused = TimeQueue(self.clock) # Tasks that are asleep
|
||||
self.events = set() # All Event objects
|
||||
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects
|
||||
|
@ -82,30 +83,37 @@ class AsyncScheduler:
|
|||
self._check_events()
|
||||
while self.tasks: # While there are tasks to run
|
||||
self.current_task = self.tasks.pop(0)
|
||||
if self.some_cancel:
|
||||
self._check_cancel()
|
||||
# Sets the currently running task
|
||||
if self.current_task.status == "cancel": # Deferred cancellation
|
||||
self.current_task.cancelled = True
|
||||
self.current_task.throw(CancelledError(self.current_task))
|
||||
method, *args = self.current_task.run() # Run a single step with the calculation
|
||||
self.current_task.status = "run"
|
||||
getattr(self, f"_{method}")(*args)
|
||||
# Sneaky method call, thanks to David Beazley for this ;)
|
||||
except CancelledError as cancelled:
|
||||
if cancelled.args[0] in self.tasks:
|
||||
self.tasks.remove(cancelled.args[0]) # Remove the dead task
|
||||
self.tasks.append(self.current_task)
|
||||
except CancelledError:
|
||||
self.current_task.cancelled = True
|
||||
self._reschedule_parent()
|
||||
except StopIteration as e: # Coroutine ends
|
||||
self.current_task.result = e.args[0] if e.args else None
|
||||
self.current_task.finished = True
|
||||
self._reschedule_parent()
|
||||
except RuntimeError:
|
||||
continue
|
||||
except BaseException as error: # Coroutine raised
|
||||
print(error)
|
||||
self.current_task.exc = error
|
||||
if self.catch:
|
||||
self._reschedule_parent()
|
||||
self._join(self.current_task)
|
||||
else:
|
||||
if not isinstance(error, RuntimeError):
|
||||
raise
|
||||
self._reschedule_parent()
|
||||
self._join(self.current_task)
|
||||
raise
|
||||
|
||||
def _check_cancel(self):
|
||||
"""
|
||||
Checks for task cancellation
|
||||
"""
|
||||
|
||||
if self.current_task.status == "cancel": # Deferred cancellation
|
||||
self.current_task.cancelled = True
|
||||
self.current_task.throw(CancelledError(self.current_task))
|
||||
|
||||
def _check_events(self):
|
||||
"""
|
||||
|
@ -126,7 +134,7 @@ class AsyncScheduler:
|
|||
wait(max(0.0, self.paused[0][0] - self.clock()))
|
||||
# Sleep until the closest deadline in order not to waste CPU cycles
|
||||
while self.paused[0][0] < self.clock():
|
||||
# Reschedules tasks when their deadline has elapsed
|
||||
# Reschedules tasks when their deadline has elapsed
|
||||
self.tasks.append(self.paused.get())
|
||||
if not self.paused:
|
||||
break
|
||||
|
@ -150,7 +158,7 @@ class AsyncScheduler:
|
|||
|
||||
entry = Task(func(*args))
|
||||
self.tasks.append(entry)
|
||||
self._join(entry)
|
||||
self._join(entry) # TODO -> Inspect this line, does it actually do anything useful?
|
||||
self._run()
|
||||
return entry
|
||||
|
||||
|
@ -261,12 +269,9 @@ class AsyncScheduler:
|
|||
are independent
|
||||
"""
|
||||
|
||||
if task.status in ("sleep", "I/O") and not task.cancelled:
|
||||
# It is safe to cancel a task while blocking
|
||||
task.cancelled = True
|
||||
task.throw(CancelledError(task))
|
||||
elif task.status == "run":
|
||||
task.status = "cancel" # Cancellation is deferred
|
||||
if not self.some_cancel:
|
||||
self.some_cancel = True
|
||||
task.status = "cancel" # Cancellation is deferred
|
||||
|
||||
def wrap_socket(self, sock):
|
||||
"""
|
||||
|
@ -282,6 +287,10 @@ class AsyncScheduler:
|
|||
"""
|
||||
|
||||
await want_read(sock)
|
||||
try:
|
||||
return sock.recv(buffer)
|
||||
except WantRead:
|
||||
await want_write(sock)
|
||||
return sock.recv(buffer)
|
||||
|
||||
async def _accept_sock(self, sock: socket.socket):
|
||||
|
|
|
@ -55,12 +55,11 @@ class Task:
|
|||
raise self.exc
|
||||
return res
|
||||
|
||||
|
||||
async def cancel(self):
|
||||
"""Cancels the task"""
|
||||
|
||||
await cancel(self)
|
||||
assert self.cancelled, "Task ignored cancellation"
|
||||
# await join(self) # TODO -> Join ourselves after cancellation?
|
||||
|
||||
def __repr__(self):
|
||||
"""Implements repr(self)"""
|
||||
|
|
|
@ -51,7 +51,6 @@ class TaskManager:
|
|||
return task
|
||||
|
||||
async def __aenter__(self):
|
||||
self.loop.catch = True # Restore event loop's status
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
|
@ -59,12 +58,11 @@ class TaskManager:
|
|||
try:
|
||||
await task.join()
|
||||
except BaseException as e:
|
||||
for task in self.loop.tasks:
|
||||
await task.cancel()
|
||||
for _, __, task in self.loop.paused:
|
||||
await task.cancel()
|
||||
for tasks in self.loop.event_waiting.values():
|
||||
for task in tasks:
|
||||
await task.cancel()
|
||||
self.loop.catch = False
|
||||
for running_task in self.loop.tasks:
|
||||
await running_task.cancel()
|
||||
for _, __, asleep_task in self.loop.paused:
|
||||
await asleep_task.cancel()
|
||||
for waiting_tasks in self.loop.event_waiting.values():
|
||||
for waiting_task in waiting_tasks:
|
||||
await waiting_task.cancel()
|
||||
raise e
|
||||
|
|
|
@ -24,7 +24,6 @@ from ._traps import sleep
|
|||
|
||||
try:
|
||||
from ssl import SSLWantReadError, SSLWantWriteError
|
||||
|
||||
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
|
||||
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
|
||||
except ImportError:
|
||||
|
|
|
@ -15,17 +15,14 @@ async def countdown(n: int):
|
|||
|
||||
|
||||
async def countup(stop: int, step: int = 1):
|
||||
try:
|
||||
x = 0
|
||||
while x < stop:
|
||||
print(f"Up {x}")
|
||||
x += 1
|
||||
await giambio.sleep(step)
|
||||
print("Countup over")
|
||||
return 1
|
||||
except giambio.exceptions.CancelledError:
|
||||
print("I'm not gonna die!!")
|
||||
raise BaseException(2)
|
||||
x = 0
|
||||
while x < stop:
|
||||
print(f"Up {x}")
|
||||
x += 1
|
||||
await giambio.sleep(step)
|
||||
print("Countup over")
|
||||
return 1
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
|
@ -33,9 +30,16 @@ async def main():
|
|||
async with giambio.create_pool() as pool:
|
||||
print("Starting counters")
|
||||
pool.spawn(countdown, 10)
|
||||
t = pool.spawn(countup, 5, 2)
|
||||
await giambio.sleep(2)
|
||||
await t.cancel()
|
||||
count_up = pool.spawn(countup, 5, 2)
|
||||
# raise Exception
|
||||
# Raising an exception here has a weird
|
||||
# Behavior: The exception is propagated
|
||||
# *after* all the child tasks complete,
|
||||
# which is not what we want
|
||||
# print("Sleeping for 2 seconds before cancelling")
|
||||
# await giambio.sleep(2)
|
||||
# await count_up.cancel() # TODO: Cancel _is_ broken, this does not re-schedule the parent!
|
||||
# print("Cancelled countup")
|
||||
print("Task execution complete")
|
||||
except Exception as e:
|
||||
print(f"Caught this bad boy in here, propagating it -> {type(e).__name__}: {e}")
|
||||
|
@ -46,6 +50,6 @@ if __name__ == "__main__":
|
|||
print("Starting event loop")
|
||||
try:
|
||||
giambio.run(main)
|
||||
except BaseException as e:
|
||||
print(f"Exception caught from main event loop!! -> {type(e).__name__}: {e}")
|
||||
except BaseException as error:
|
||||
print(f"Exception caught from main event loop! -> {type(error).__name__}: {error}")
|
||||
print("Event loop done")
|
||||
|
|
|
@ -7,29 +7,26 @@ import sys
|
|||
|
||||
# A test to check for asynchronous I/O
|
||||
|
||||
logging.basicConfig(
|
||||
level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p"
|
||||
)
|
||||
|
||||
|
||||
async def server(address: tuple):
|
||||
async def serve(address: tuple):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(address)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.listen(5)
|
||||
asock = giambio.wrap_socket(sock) # We make the socket an async socket
|
||||
logging.info(f"Echo server serving asynchronously at {address}")
|
||||
logging.info(f"Serving asynchronously at {address[0]}:{address[1]}")
|
||||
while True:
|
||||
try:
|
||||
async with giambio.async_pool() as pool:
|
||||
async with giambio.create_pool() as pool:
|
||||
conn, addr = await asock.accept()
|
||||
logging.info(f"{addr} connected")
|
||||
pool.spawn(echo_handler, conn, addr)
|
||||
logging.info(f"{addr[0]}:{addr[1]} connected")
|
||||
pool.spawn(handler, conn, addr)
|
||||
except TypeError:
|
||||
print("Looks like we have a naughty boy here!")
|
||||
|
||||
|
||||
async def echo_handler(sock: AsyncSocket, addr: tuple):
|
||||
async def handler(sock: AsyncSocket, addr: tuple):
|
||||
addr = f"{addr[0]}:{addr[1]}"
|
||||
async with sock:
|
||||
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
|
||||
while True:
|
||||
|
@ -49,11 +46,12 @@ async def echo_handler(sock: AsyncSocket, addr: tuple):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
port = int(sys.argv[1])
|
||||
else:
|
||||
port = 1500
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1500
|
||||
logging.basicConfig(level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p")
|
||||
try:
|
||||
giambio.run(server, ("", port))
|
||||
giambio.run(serve, ("localhost", port))
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
print(f"Exiting due to a {type(error).__name__}: '{error}'")
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||
|
|
Loading…
Reference in New Issue