join() partially fixed

This commit is contained in:
nocturn9x 2020-11-16 21:49:13 +01:00
parent 981a598ae7
commit 10c1b33e20
5 changed files with 33 additions and 42 deletions

View File

@ -61,7 +61,6 @@ class TaskManager:
for task in self.tasks: for task in self.tasks:
try: try:
await task.join() await task.join()
except BaseException as task_error: except BaseException:
for dead in self.tasks: for to_cancel in self.tasks:
await dead.cancel() await to_cancel.cancel()
raise task.exc

View File

@ -129,19 +129,19 @@ class AsyncScheduler:
# Sneaky method call, thanks to David Beazley for this ;) # Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args) getattr(self, method)(*args)
except CancelledError: except CancelledError:
self.current_task.status = "end" self.current_task.status = "cancelled"
self.current_task.cancelled = True self.current_task.cancelled = True
self.current_task.cancel_pending = False self.current_task.cancel_pending = False
self.join(self.current_task, self.current_task.parent)
except StopIteration as ret: except StopIteration as ret:
# Coroutine ends # Coroutine ends
self.current_task.status = "end" self.current_task.status = "end"
self.current_task.result = ret.value self.current_task.result = ret.value
self.current_task.finished = True self.current_task.finished = True
self.join()
except BaseException as err: except BaseException as err:
self.current_task.exc = err self.current_task.exc = err
self.current_task.status = "crashed" self.current_task.status = "crashed"
self.join(self.current_task, self.current_task.parent) self.join()
def do_cancel(self): def do_cancel(self):
""" """
@ -151,6 +151,7 @@ class AsyncScheduler:
""" """
self.current_task.throw(CancelledError) self.current_task.throw(CancelledError)
self.current_task.coroutine.close()
def get_running(self): def get_running(self):
""" """
@ -218,40 +219,30 @@ class AsyncScheduler:
currently running task, if any currently running task, if any
""" """
parent = self.current_task.parent if parent := self.current_task.parent:
if parent:
self.tasks.append(parent) self.tasks.append(parent)
return parent
def reschedule_joinee(self): def reschedule_joinee(self):
""" """
Reschedules the joinee task of the Reschedules the joinee(s) task of the
currently running task, if any currently running task, if any
""" """
self.tasks.extend(self.current_task.waiters) self.tasks.extend(self.current_task.waiters)
def join(self, child: types.coroutine, parent): def join(self):
""" """
Handler for the 'join' event, does some magic to tell the scheduler Handler for the 'join' event, does some magic to tell the scheduler
to wait in the given parent until the current coroutine ends to wait until the current coroutine ends
""" """
child = self.current_task
child.joined = True child.joined = True
if parent: if child.finished:
print("p") self.reschedule_joinee()
child.waiters.append(parent) self.reschedule_parent()
if child.cancelled or child.exc: elif child.exc:
print("f") raise child.exc
# Task was cancelled or has errored
if child.parent:
self.tasks.append(child.parent)
self.tasks.extend(child.waiters)
elif child.finished:
print("finish")
# if parent:
# self.tasks.append(parent)
self.tasks.extend(child.waiters)
def sleep(self, seconds: int or float): def sleep(self, seconds: int or float):
""" """
@ -326,12 +317,16 @@ class AsyncScheduler:
else: else:
self.event_waiting[event].append(self.current_task) self.event_waiting[event].append(self.current_task)
def cancel(self, task): def cancel(self):
""" """
Handler for the 'cancel' event, sets the task to be cancelled later Handler for the 'cancel' event, schedules the task to be cancelled later
""" """
task.cancel_pending = True # Cancellation is deferred if self.current_task.status in ("I/O", "sleep"):
# We cancel right away
self.do_cancel()
else:
self.current_task.cancel_pending = True # Cancellation is deferred
def wrap_socket(self, sock): def wrap_socket(self, sock):
""" """
@ -384,7 +379,6 @@ class AsyncScheduler:
await want_write(sock) await want_write(sock)
self.selector.unregister(sock) self.selector.unregister(sock)
sock.setblocking(False)
return sock.close() return sock.close()
async def connect_sock(self, sock: socket.socket, addr: tuple): async def connect_sock(self, sock: socket.socket, addr: tuple):

View File

@ -61,17 +61,15 @@ class Task:
Joins the task Joins the task
""" """
res = await join(self) return await join(self)
if self.exc:
raise self.exc
return res
async def cancel(self): async def cancel(self):
""" """
Cancels the task Cancels the task
""" """
await cancel(self) if not self.exc and not self.cancelled and not self.finished:
await cancel(self)
def __del__(self): def __del__(self):
self.coroutine.close() self.coroutine.close()
@ -92,7 +90,7 @@ class Event:
self.timeout = None self.timeout = None
self.waiting = 0 self.waiting = 0
async def set(self): async def activate(self):
""" """
Sets the event, waking up all tasks that called Sets the event, waking up all tasks that called
pause() on us pause() on us

View File

@ -73,7 +73,7 @@ async def join(task):
:type task: class: Task :type task: class: Task
""" """
return await create_trap("join", task, await current_task()) return await create_trap("join")
async def cancel(task): async def cancel(task):
@ -89,7 +89,7 @@ async def cancel(task):
code, so if you really wanna do that be sure to re-raise it when done! code, so if you really wanna do that be sure to re-raise it when done!
""" """
await create_trap("cancel", task) await create_trap("cancel")
assert task.cancelled, f"Coroutine ignored CancelledError" assert task.cancelled, f"Coroutine ignored CancelledError"

View File

@ -27,8 +27,8 @@ async def countup(stop: int, step: int = 1):
async def main(): async def main():
try: try:
async with giambio.create_pool() as pool: async with giambio.create_pool() as pool:
pool.spawn(countdown, 10) pool.spawn(countdown, 5)
pool.spawn(countup, 5, 2) pool.spawn(countup, 5, 1)
except Exception as e: except Exception as e:
print(f"Got -> {type(e).__name__}: {e}") print(f"Got -> {type(e).__name__}: {e}")
print("Task execution complete") print("Task execution complete")