mirror of https://github.com/nocturn9x/giambio.git
Cancellation/Exceptions almost complete
This commit is contained in:
parent
31fa71fd84
commit
caee01977e
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
import types
|
||||
from .core import AsyncScheduler
|
||||
from .objects import Task
|
||||
from .exceptions import CancelledError
|
||||
|
||||
|
||||
class TaskManager:
|
||||
|
@ -45,6 +46,7 @@ class TaskManager:
|
|||
self.loop.tasks.append(task)
|
||||
self.tasks.append(task)
|
||||
self.loop.debugger.on_task_spawn(task)
|
||||
return task
|
||||
|
||||
def spawn_after(self, func: types.FunctionType, n: int, *args):
|
||||
"""
|
||||
|
@ -58,15 +60,11 @@ class TaskManager:
|
|||
self.loop.paused.put(task, n)
|
||||
self.tasks.append(task)
|
||||
self.loop.debugger.on_task_schedule(task, n)
|
||||
return task
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
||||
for task in self.tasks:
|
||||
try:
|
||||
await task.join()
|
||||
except BaseException:
|
||||
self.tasks.remove(task)
|
||||
for to_cancel in self.tasks:
|
||||
await to_cancel.cancel()
|
128
giambio/core.py
128
giambio/core.py
|
@ -26,6 +26,7 @@ from socket import SOL_SOCKET, SO_ERROR
|
|||
from .traps import want_read, want_write
|
||||
from .util.debug import BaseDebugger
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from .socket import AsyncSocket, WantWrite, WantRead
|
||||
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
|
||||
from .exceptions import (InternalError,
|
||||
|
@ -56,7 +57,7 @@ class AsyncScheduler:
|
|||
assert issubclass(type(debugger), BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger"
|
||||
self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *args: None})()
|
||||
# Tasks that are ready to run
|
||||
self.tasks = deque()
|
||||
self.tasks = []
|
||||
# Selector object to perform I/O multiplexing
|
||||
self.selector = DefaultSelector()
|
||||
# This will always point to the currently running coroutine (Task object)
|
||||
|
@ -117,7 +118,7 @@ class AsyncScheduler:
|
|||
# Otherwise, while there are tasks ready to run, well, run them!
|
||||
while self.tasks:
|
||||
# Sets the currently running task
|
||||
self.current_task = self.tasks.popleft()
|
||||
self.current_task = self.tasks.pop(0)
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
if self.current_task.cancel_pending:
|
||||
self.do_cancel()
|
||||
|
@ -143,34 +144,33 @@ class AsyncScheduler:
|
|||
self.current_task.cancelled = True
|
||||
self.current_task.cancel_pending = False
|
||||
self.debugger.after_cancel(self.current_task)
|
||||
self.join() # TODO: Investigate if a call to join() is needed
|
||||
self.join(self.current_task) # TODO: Investigate if a call to join() is needed
|
||||
except StopIteration as ret:
|
||||
# Coroutine ends
|
||||
self.current_task.status = "end"
|
||||
self.current_task.result = ret.value
|
||||
self.current_task.finished = True
|
||||
self.debugger.on_task_exit(self.current_task)
|
||||
self.join()
|
||||
self.join(self.current_task)
|
||||
except BaseException as err:
|
||||
self.current_task.exc = err
|
||||
self.current_task.status = "crashed"
|
||||
self.join()
|
||||
self.join(self.current_task)
|
||||
|
||||
def do_cancel(self):
|
||||
def do_cancel(self, task: Task = None):
|
||||
"""
|
||||
Performs task cancellation by throwing CancelledError inside the current
|
||||
task in order to stop it from executing. The loop continues to execute
|
||||
as tasks are independent
|
||||
"""
|
||||
|
||||
# TODO: Do we need anything else?
|
||||
self.debugger.before_cancel(self.current_task)
|
||||
self.current_task.throw(CancelledError)
|
||||
|
||||
task = task or self.current_task
|
||||
self.debugger.before_cancel(task)
|
||||
task.throw(CancelledError)
|
||||
|
||||
def get_running(self):
|
||||
"""
|
||||
Returns the current task
|
||||
Returns the current task to an async caller
|
||||
"""
|
||||
|
||||
self.tasks.append(self.current_task)
|
||||
|
@ -184,7 +184,7 @@ class AsyncScheduler:
|
|||
for event in self.events.copy():
|
||||
if event.set:
|
||||
event.event_caught = True
|
||||
event.waiters
|
||||
event.waiters.append(self.current_task)
|
||||
self.tasks.extend(event.waiters)
|
||||
self.events.remove(event)
|
||||
|
||||
|
@ -210,6 +210,7 @@ class AsyncScheduler:
|
|||
Checks and schedules task to perform I/O
|
||||
"""
|
||||
|
||||
before_time = self.clock()
|
||||
if self.tasks or self.events and not self.selector.get_map():
|
||||
# If there are either tasks or events and no I/O, never wait
|
||||
timeout = 0.0
|
||||
|
@ -220,17 +221,12 @@ class AsyncScheduler:
|
|||
# If there is *only* I/O, we wait a fixed amount of time
|
||||
timeout = 1 # TODO: Is this ok?
|
||||
self.debugger.before_io(timeout)
|
||||
for key in dict(self.selector.get_map()).values():
|
||||
# We make sure we don't reschedule finished tasks
|
||||
if key.data.finished:
|
||||
key.data.last_io = ()
|
||||
self.selector.unregister(key.fileobj)
|
||||
if self.selector.get_map(): # If there is indeed tasks waiting on I/O
|
||||
if self.selector.get_map():
|
||||
io_ready = self.selector.select(timeout)
|
||||
# Get sockets that are ready and schedule their tasks
|
||||
for key, _ in io_ready:
|
||||
self.tasks.append(key.data) # Resource ready? Schedule its task
|
||||
self.debugger.after_io(timeout)
|
||||
self.debugger.after_io(self.clock() - before_time)
|
||||
|
||||
def start(self, func: types.FunctionType, *args):
|
||||
"""
|
||||
|
@ -243,31 +239,37 @@ class AsyncScheduler:
|
|||
self.run()
|
||||
self.has_ran = True
|
||||
self.debugger.on_exit()
|
||||
if entry.exc:
|
||||
raise entry.exc from None
|
||||
|
||||
def reschedule_joinee(self):
|
||||
def reschedule_joinee(self, task: Task):
|
||||
"""
|
||||
Reschedules the joinee(s) of the
|
||||
currently running task, if any
|
||||
Reschedules the joinee of the
|
||||
given task, if any
|
||||
"""
|
||||
|
||||
self.tasks.extend(self.current_task.waiters)
|
||||
if task.parent:
|
||||
self.tasks.append(task.parent)
|
||||
|
||||
def join(self):
|
||||
def join(self, child: Task):
|
||||
"""
|
||||
Handler for the 'join' event, does some magic to tell the scheduler
|
||||
to wait until the current coroutine ends
|
||||
"""
|
||||
|
||||
child = self.current_task
|
||||
child.joined = True
|
||||
if child.parent:
|
||||
child.waiters.append(child.parent)
|
||||
if child.finished:
|
||||
self.reschedule_joinee()
|
||||
self.reschedule_joinee(child)
|
||||
elif child.exc:
|
||||
... # TODO: Handle exceptions
|
||||
for task in chain(self.tasks, self.paused):
|
||||
try:
|
||||
self.cancel(task)
|
||||
except CancelledError:
|
||||
task.status = "cancelled"
|
||||
task.cancelled = True
|
||||
task.cancel_pending = False
|
||||
self.debugger.after_cancel(task)
|
||||
self.tasks.remove(task)
|
||||
child.parent.throw(child.exc)
|
||||
self.tasks.append(child.parent)
|
||||
|
||||
def sleep(self, seconds: int or float):
|
||||
"""
|
||||
|
@ -282,6 +284,37 @@ class AsyncScheduler:
|
|||
else:
|
||||
self.tasks.append(self.current_task)
|
||||
|
||||
def cancel(self, task: Task = None):
|
||||
"""
|
||||
Handler for the 'cancel' event, schedules the task to be cancelled later
|
||||
or does so straight away if it is safe to do so
|
||||
"""
|
||||
|
||||
task = task or self.current_task
|
||||
if not task.finished and not task.exc:
|
||||
if task.status in ("I/O", "sleep"):
|
||||
# We cancel right away
|
||||
self.do_cancel(task)
|
||||
else:
|
||||
task.cancel_pending = True # Cancellation is deferred
|
||||
|
||||
def event_set(self, event):
|
||||
"""
|
||||
Sets an event
|
||||
"""
|
||||
|
||||
self.events.add(event)
|
||||
event.waiters.append(self.current_task)
|
||||
event.set = True
|
||||
self.reschedule_joinee()
|
||||
|
||||
def event_wait(self, event):
|
||||
"""
|
||||
Pauses the current task on an event
|
||||
"""
|
||||
|
||||
event.waiters.append(self.current_task)
|
||||
|
||||
# TODO: More generic I/O rather than just sockets
|
||||
def want_read(self, sock: socket.socket):
|
||||
"""
|
||||
|
@ -294,7 +327,6 @@ class AsyncScheduler:
|
|||
if self.current_task.last_io == ("READ", sock):
|
||||
# Socket is already scheduled!
|
||||
return
|
||||
else:
|
||||
self.selector.unregister(sock)
|
||||
self.current_task.last_io = "READ", sock
|
||||
try:
|
||||
|
@ -314,7 +346,6 @@ class AsyncScheduler:
|
|||
if self.current_task.last_io == ("WRITE", sock):
|
||||
# Socket is already scheduled!
|
||||
return
|
||||
else:
|
||||
# TODO: Inspect why modify() causes issues
|
||||
self.selector.unregister(sock)
|
||||
self.current_task.last_io = "WRITE", sock
|
||||
|
@ -322,37 +353,6 @@ class AsyncScheduler:
|
|||
self.selector.register(sock, EVENT_WRITE, self.current_task)
|
||||
except KeyError:
|
||||
raise ResourceBusy("The given resource is busy!") from None
|
||||
|
||||
def event_set(self, event):
|
||||
"""
|
||||
Sets an event
|
||||
"""
|
||||
|
||||
self.events.add(event)
|
||||
event.waiters.append(self.current_task)
|
||||
event.set = True
|
||||
self.reschedule_joinee()
|
||||
|
||||
def event_wait(self, event):
|
||||
"""
|
||||
Pauses the current task on an event
|
||||
"""
|
||||
|
||||
event.waiters.append(self.current_task)
|
||||
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Handler for the 'cancel' event, schedules the task to be cancelled later
|
||||
or does so straight away if it is safe to do so
|
||||
"""
|
||||
|
||||
if self.current_task.status in ("I/O", "sleep"):
|
||||
# We cancel right away
|
||||
self.do_cancel()
|
||||
else:
|
||||
self.current_task.cancel_pending = True # Cancellation is deferred
|
||||
|
||||
def wrap_socket(self, sock):
|
||||
"""
|
||||
Wraps a standard socket into an AsyncSocket object
|
||||
|
|
|
@ -42,7 +42,6 @@ class Task:
|
|||
parent: object = None
|
||||
joined: bool= False
|
||||
cancel_pending: bool = False
|
||||
waiters: list = field(default_factory=list)
|
||||
sleep_start: int = None
|
||||
|
||||
def run(self, what=None):
|
||||
|
@ -131,7 +130,13 @@ class TimeQueue:
|
|||
return item in self.container
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.container)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get()
|
||||
except IndexError:
|
||||
raise StopIteration from None
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.container.__getitem__(item)
|
||||
|
|
|
@ -70,7 +70,7 @@ def run(func: FunctionType, *args, **kwargs):
|
|||
elif not isinstance(func, FunctionType):
|
||||
raise GiambioError("gaibmio.run() requires an async function as parameter!")
|
||||
new_event_loop(kwargs.get("debugger", None))
|
||||
thread_local.loop.start(func, *args)
|
||||
get_event_loop().start(func, *args)
|
||||
|
||||
|
||||
def clock():
|
||||
|
|
|
@ -73,7 +73,7 @@ async def join(task):
|
|||
:type task: class: Task
|
||||
"""
|
||||
|
||||
return await create_trap("join")
|
||||
return await create_trap("join", task)
|
||||
|
||||
|
||||
async def cancel(task):
|
||||
|
|
|
@ -162,7 +162,7 @@ class BaseDebugger(ABC):
|
|||
This method is called right after
|
||||
the event loop has checked for I/O events
|
||||
|
||||
:param timeout: The max. amount of seconds
|
||||
:param timeout: The actual amount of seconds
|
||||
that the loop has hung when using the select()
|
||||
system call
|
||||
:type timeout: int
|
||||
|
|
|
@ -37,7 +37,7 @@ class Debugger(giambio.debug.BaseDebugger):
|
|||
print(f"!! About to check for I/O for up to {timeout:.2f} seconds")
|
||||
|
||||
def after_io(self, timeout):
|
||||
print(f"!! Done I/O check (timeout {timeout:.2f} seconds)")
|
||||
print(f"!! Done I/O check (waited for {timeout:.2f} seconds)")
|
||||
|
||||
def before_cancel(self, task):
|
||||
print(f"// About to cancel '{task.name}'")
|
||||
|
@ -50,12 +50,14 @@ async def child():
|
|||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
print("[child] Had a nice nap!")
|
||||
raise TypeError("rip")
|
||||
|
||||
async def child1():
|
||||
print("[child 1] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
print("[child 1] Had a nice nap!")
|
||||
|
||||
|
||||
async def main():
|
||||
start = giambio.clock()
|
||||
try:
|
||||
|
@ -63,10 +65,10 @@ async def main():
|
|||
pool.spawn(child)
|
||||
pool.spawn(child1)
|
||||
print("[main] Children spawned, awaiting completion")
|
||||
except Exception as e:
|
||||
print(f"Got -> {type(e).__name__}: {e}")
|
||||
except Exception as error:
|
||||
print(f"[main] Exception from child catched! {repr(error)}")
|
||||
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
|
||||
|
||||
await giambio.sleep(5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
giambio.run(main, debugger=Debugger())
|
||||
|
|
Loading…
Reference in New Issue