mirror of https://github.com/nocturn9x/giambio.git
Identified issue with task.cancel()
This commit is contained in:
parent
7b4051f3b9
commit
cc9eccf027
|
@ -26,6 +26,7 @@ from .socket import AsyncSocket, WantWrite, WantRead
|
||||||
from ._layers import Task, TimeQueue
|
from ._layers import Task, TimeQueue
|
||||||
from socket import SOL_SOCKET, SO_ERROR
|
from socket import SOL_SOCKET, SO_ERROR
|
||||||
from ._traps import want_read, want_write
|
from ._traps import want_read, want_write
|
||||||
|
import traceback, sys
|
||||||
|
|
||||||
|
|
||||||
class AsyncScheduler:
|
class AsyncScheduler:
|
||||||
|
@ -45,13 +46,13 @@ class AsyncScheduler:
|
||||||
self.tasks = [] # Tasks that are ready to run
|
self.tasks = [] # Tasks that are ready to run
|
||||||
self.selector = DefaultSelector() # Selector object to perform I/O multiplexing
|
self.selector = DefaultSelector() # Selector object to perform I/O multiplexing
|
||||||
self.current_task = None # This will always point to the currently running coroutine (Task object)
|
self.current_task = None # This will always point to the currently running coroutine (Task object)
|
||||||
self.catch = True
|
|
||||||
self.joined = (
|
self.joined = (
|
||||||
{}
|
{}
|
||||||
) # Maps child tasks that need to be joined their respective parent task
|
) # Maps child tasks that need to be joined their respective parent task
|
||||||
self.clock = (
|
self.clock = (
|
||||||
default_timer # Monotonic clock to keep track of elapsed time reliably
|
default_timer # Monotonic clock to keep track of elapsed time reliably
|
||||||
)
|
)
|
||||||
|
self.some_cancel = False
|
||||||
self.paused = TimeQueue(self.clock) # Tasks that are asleep
|
self.paused = TimeQueue(self.clock) # Tasks that are asleep
|
||||||
self.events = set() # All Event objects
|
self.events = set() # All Event objects
|
||||||
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects
|
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects
|
||||||
|
@ -82,30 +83,37 @@ class AsyncScheduler:
|
||||||
self._check_events()
|
self._check_events()
|
||||||
while self.tasks: # While there are tasks to run
|
while self.tasks: # While there are tasks to run
|
||||||
self.current_task = self.tasks.pop(0)
|
self.current_task = self.tasks.pop(0)
|
||||||
|
if self.some_cancel:
|
||||||
|
self._check_cancel()
|
||||||
# Sets the currently running task
|
# Sets the currently running task
|
||||||
if self.current_task.status == "cancel": # Deferred cancellation
|
|
||||||
self.current_task.cancelled = True
|
|
||||||
self.current_task.throw(CancelledError(self.current_task))
|
|
||||||
method, *args = self.current_task.run() # Run a single step with the calculation
|
method, *args = self.current_task.run() # Run a single step with the calculation
|
||||||
self.current_task.status = "run"
|
self.current_task.status = "run"
|
||||||
getattr(self, f"_{method}")(*args)
|
getattr(self, f"_{method}")(*args)
|
||||||
# Sneaky method call, thanks to David Beazley for this ;)
|
# Sneaky method call, thanks to David Beazley for this ;)
|
||||||
except CancelledError as cancelled:
|
except CancelledError:
|
||||||
if cancelled.args[0] in self.tasks:
|
self.current_task.cancelled = True
|
||||||
self.tasks.remove(cancelled.args[0]) # Remove the dead task
|
self._reschedule_parent()
|
||||||
self.tasks.append(self.current_task)
|
|
||||||
except StopIteration as e: # Coroutine ends
|
except StopIteration as e: # Coroutine ends
|
||||||
self.current_task.result = e.args[0] if e.args else None
|
self.current_task.result = e.args[0] if e.args else None
|
||||||
self.current_task.finished = True
|
self.current_task.finished = True
|
||||||
self._reschedule_parent()
|
self._reschedule_parent()
|
||||||
|
except RuntimeError:
|
||||||
|
continue
|
||||||
except BaseException as error: # Coroutine raised
|
except BaseException as error: # Coroutine raised
|
||||||
|
print(error)
|
||||||
self.current_task.exc = error
|
self.current_task.exc = error
|
||||||
if self.catch:
|
self._reschedule_parent()
|
||||||
self._reschedule_parent()
|
self._join(self.current_task)
|
||||||
self._join(self.current_task)
|
raise
|
||||||
else:
|
|
||||||
if not isinstance(error, RuntimeError):
|
def _check_cancel(self):
|
||||||
raise
|
"""
|
||||||
|
Checks for task cancellation
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.current_task.status == "cancel": # Deferred cancellation
|
||||||
|
self.current_task.cancelled = True
|
||||||
|
self.current_task.throw(CancelledError(self.current_task))
|
||||||
|
|
||||||
def _check_events(self):
|
def _check_events(self):
|
||||||
"""
|
"""
|
||||||
|
@ -126,7 +134,7 @@ class AsyncScheduler:
|
||||||
wait(max(0.0, self.paused[0][0] - self.clock()))
|
wait(max(0.0, self.paused[0][0] - self.clock()))
|
||||||
# Sleep until the closest deadline in order not to waste CPU cycles
|
# Sleep until the closest deadline in order not to waste CPU cycles
|
||||||
while self.paused[0][0] < self.clock():
|
while self.paused[0][0] < self.clock():
|
||||||
# Reschedules tasks when their deadline has elapsed
|
# Reschedules tasks when their deadline has elapsed
|
||||||
self.tasks.append(self.paused.get())
|
self.tasks.append(self.paused.get())
|
||||||
if not self.paused:
|
if not self.paused:
|
||||||
break
|
break
|
||||||
|
@ -150,7 +158,7 @@ class AsyncScheduler:
|
||||||
|
|
||||||
entry = Task(func(*args))
|
entry = Task(func(*args))
|
||||||
self.tasks.append(entry)
|
self.tasks.append(entry)
|
||||||
self._join(entry)
|
self._join(entry) # TODO -> Inspect this line, does it actually do anything useful?
|
||||||
self._run()
|
self._run()
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
@ -261,12 +269,9 @@ class AsyncScheduler:
|
||||||
are independent
|
are independent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if task.status in ("sleep", "I/O") and not task.cancelled:
|
if not self.some_cancel:
|
||||||
# It is safe to cancel a task while blocking
|
self.some_cancel = True
|
||||||
task.cancelled = True
|
task.status = "cancel" # Cancellation is deferred
|
||||||
task.throw(CancelledError(task))
|
|
||||||
elif task.status == "run":
|
|
||||||
task.status = "cancel" # Cancellation is deferred
|
|
||||||
|
|
||||||
def wrap_socket(self, sock):
|
def wrap_socket(self, sock):
|
||||||
"""
|
"""
|
||||||
|
@ -282,6 +287,10 @@ class AsyncScheduler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await want_read(sock)
|
await want_read(sock)
|
||||||
|
try:
|
||||||
|
return sock.recv(buffer)
|
||||||
|
except WantRead:
|
||||||
|
await want_write(sock)
|
||||||
return sock.recv(buffer)
|
return sock.recv(buffer)
|
||||||
|
|
||||||
async def _accept_sock(self, sock: socket.socket):
|
async def _accept_sock(self, sock: socket.socket):
|
||||||
|
|
|
@ -55,12 +55,11 @@ class Task:
|
||||||
raise self.exc
|
raise self.exc
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
async def cancel(self):
|
async def cancel(self):
|
||||||
"""Cancels the task"""
|
"""Cancels the task"""
|
||||||
|
|
||||||
await cancel(self)
|
await cancel(self)
|
||||||
assert self.cancelled, "Task ignored cancellation"
|
# await join(self) # TODO -> Join ourselves after cancellation?
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Implements repr(self)"""
|
"""Implements repr(self)"""
|
||||||
|
|
|
@ -51,7 +51,6 @@ class TaskManager:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.loop.catch = True # Restore event loop's status
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
@ -59,12 +58,11 @@ class TaskManager:
|
||||||
try:
|
try:
|
||||||
await task.join()
|
await task.join()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
for task in self.loop.tasks:
|
for running_task in self.loop.tasks:
|
||||||
await task.cancel()
|
await running_task.cancel()
|
||||||
for _, __, task in self.loop.paused:
|
for _, __, asleep_task in self.loop.paused:
|
||||||
await task.cancel()
|
await asleep_task.cancel()
|
||||||
for tasks in self.loop.event_waiting.values():
|
for waiting_tasks in self.loop.event_waiting.values():
|
||||||
for task in tasks:
|
for waiting_task in waiting_tasks:
|
||||||
await task.cancel()
|
await waiting_task.cancel()
|
||||||
self.loop.catch = False
|
|
||||||
raise e
|
raise e
|
||||||
|
|
|
@ -24,7 +24,6 @@ from ._traps import sleep
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ssl import SSLWantReadError, SSLWantWriteError
|
from ssl import SSLWantReadError, SSLWantWriteError
|
||||||
|
|
||||||
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
|
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
|
||||||
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
|
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -15,17 +15,14 @@ async def countdown(n: int):
|
||||||
|
|
||||||
|
|
||||||
async def countup(stop: int, step: int = 1):
|
async def countup(stop: int, step: int = 1):
|
||||||
try:
|
x = 0
|
||||||
x = 0
|
while x < stop:
|
||||||
while x < stop:
|
print(f"Up {x}")
|
||||||
print(f"Up {x}")
|
x += 1
|
||||||
x += 1
|
await giambio.sleep(step)
|
||||||
await giambio.sleep(step)
|
print("Countup over")
|
||||||
print("Countup over")
|
return 1
|
||||||
return 1
|
|
||||||
except giambio.exceptions.CancelledError:
|
|
||||||
print("I'm not gonna die!!")
|
|
||||||
raise BaseException(2)
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
try:
|
try:
|
||||||
|
@ -33,9 +30,16 @@ async def main():
|
||||||
async with giambio.create_pool() as pool:
|
async with giambio.create_pool() as pool:
|
||||||
print("Starting counters")
|
print("Starting counters")
|
||||||
pool.spawn(countdown, 10)
|
pool.spawn(countdown, 10)
|
||||||
t = pool.spawn(countup, 5, 2)
|
count_up = pool.spawn(countup, 5, 2)
|
||||||
await giambio.sleep(2)
|
# raise Exception
|
||||||
await t.cancel()
|
# Raising an exception here has a weird
|
||||||
|
# Behavior: The exception is propagated
|
||||||
|
# *after* all the child tasks complete,
|
||||||
|
# which is not what we want
|
||||||
|
# print("Sleeping for 2 seconds before cancelling")
|
||||||
|
# await giambio.sleep(2)
|
||||||
|
# await count_up.cancel() # TODO: Cancel _is_ broken, this does not re-schedule the parent!
|
||||||
|
# print("Cancelled countup")
|
||||||
print("Task execution complete")
|
print("Task execution complete")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Caught this bad boy in here, propagating it -> {type(e).__name__}: {e}")
|
print(f"Caught this bad boy in here, propagating it -> {type(e).__name__}: {e}")
|
||||||
|
@ -46,6 +50,6 @@ if __name__ == "__main__":
|
||||||
print("Starting event loop")
|
print("Starting event loop")
|
||||||
try:
|
try:
|
||||||
giambio.run(main)
|
giambio.run(main)
|
||||||
except BaseException as e:
|
except BaseException as error:
|
||||||
print(f"Exception caught from main event loop!! -> {type(e).__name__}: {e}")
|
print(f"Exception caught from main event loop! -> {type(error).__name__}: {error}")
|
||||||
print("Event loop done")
|
print("Event loop done")
|
||||||
|
|
|
@ -7,29 +7,26 @@ import sys
|
||||||
|
|
||||||
# A test to check for asynchronous I/O
|
# A test to check for asynchronous I/O
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
async def serve(address: tuple):
|
||||||
async def server(address: tuple):
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
sock.bind(address)
|
sock.bind(address)
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
sock.listen(5)
|
sock.listen(5)
|
||||||
asock = giambio.wrap_socket(sock) # We make the socket an async socket
|
asock = giambio.wrap_socket(sock) # We make the socket an async socket
|
||||||
logging.info(f"Echo server serving asynchronously at {address}")
|
logging.info(f"Serving asynchronously at {address[0]}:{address[1]}")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with giambio.async_pool() as pool:
|
async with giambio.create_pool() as pool:
|
||||||
conn, addr = await asock.accept()
|
conn, addr = await asock.accept()
|
||||||
logging.info(f"{addr} connected")
|
logging.info(f"{addr[0]}:{addr[1]} connected")
|
||||||
pool.spawn(echo_handler, conn, addr)
|
pool.spawn(handler, conn, addr)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
print("Looks like we have a naughty boy here!")
|
print("Looks like we have a naughty boy here!")
|
||||||
|
|
||||||
|
|
||||||
async def echo_handler(sock: AsyncSocket, addr: tuple):
|
async def handler(sock: AsyncSocket, addr: tuple):
|
||||||
|
addr = f"{addr[0]}:{addr[1]}"
|
||||||
async with sock:
|
async with sock:
|
||||||
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
|
await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n")
|
||||||
while True:
|
while True:
|
||||||
|
@ -49,11 +46,12 @@ async def echo_handler(sock: AsyncSocket, addr: tuple):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) > 1:
|
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1500
|
||||||
port = int(sys.argv[1])
|
logging.basicConfig(level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p")
|
||||||
else:
|
|
||||||
port = 1500
|
|
||||||
try:
|
try:
|
||||||
giambio.run(server, ("", port))
|
giambio.run(serve, ("localhost", port))
|
||||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||||
print(f"Exiting due to a {type(error).__name__}: '{error}'")
|
if isinstance(error, KeyboardInterrupt):
|
||||||
|
logging.info("Ctrl+C detected, exiting")
|
||||||
|
else:
|
||||||
|
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||||
|
|
Loading…
Reference in New Issue