diff --git a/giambio/context.py b/giambio/context.py index e64af14..9943dc9 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -20,6 +20,7 @@ limitations under the License. import types from .core import AsyncScheduler from .objects import Task +from .exceptions import CancelledError class TaskManager: @@ -45,6 +46,7 @@ class TaskManager: self.loop.tasks.append(task) self.tasks.append(task) self.loop.debugger.on_task_spawn(task) + return task def spawn_after(self, func: types.FunctionType, n: int, *args): """ @@ -58,15 +60,11 @@ class TaskManager: self.loop.paused.put(task, n) self.tasks.append(task) self.loop.debugger.on_task_schedule(task, n) + return task async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: Exception, exc: Exception, tb): for task in self.tasks: - try: - await task.join() - except BaseException: - self.tasks.remove(task) - for to_cancel in self.tasks: - await to_cancel.cancel() \ No newline at end of file + await task.join() \ No newline at end of file diff --git a/giambio/core.py b/giambio/core.py index aa0fc60..7ff33dc 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -26,6 +26,7 @@ from socket import SOL_SOCKET, SO_ERROR from .traps import want_read, want_write from .util.debug import BaseDebugger from collections import deque +from itertools import chain from .socket import AsyncSocket, WantWrite, WantRead from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE from .exceptions import (InternalError, @@ -56,7 +57,7 @@ class AsyncScheduler: assert issubclass(type(debugger), BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger" self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *args: None})() # Tasks that are ready to run - self.tasks = deque() + self.tasks = [] # Selector object to perform I/O multiplexing self.selector = DefaultSelector() # This will always point to the currently running coroutine (Task object) @@ -117,7 +118,7 @@ class AsyncScheduler: # Otherwise, while there are tasks ready to run, well, run them! while self.tasks: # Sets the currently running task - self.current_task = self.tasks.popleft() + self.current_task = self.tasks.pop(0) self.debugger.before_task_step(self.current_task) if self.current_task.cancel_pending: self.do_cancel() @@ -143,34 +144,33 @@ class AsyncScheduler: self.current_task.cancelled = True self.current_task.cancel_pending = False self.debugger.after_cancel(self.current_task) - self.join() # TODO: Investigate if a call to join() is needed + self.join(self.current_task) # TODO: Investigate if a call to join() is needed except StopIteration as ret: # Coroutine ends self.current_task.status = "end" self.current_task.result = ret.value self.current_task.finished = True self.debugger.on_task_exit(self.current_task) - self.join() + self.join(self.current_task) except BaseException as err: self.current_task.exc = err self.current_task.status = "crashed" - self.join() + self.join(self.current_task) - def do_cancel(self): + def do_cancel(self, task: Task = None): """ Performs task cancellation by throwing CancelledError inside the current task in order to stop it from executing. The loop continues to execute as tasks are independent """ - # TODO: Do we need anything else? - self.debugger.before_cancel(self.current_task) - self.current_task.throw(CancelledError) - + task = task or self.current_task + self.debugger.before_cancel(task) + task.throw(CancelledError) def get_running(self): """ - Returns the current task + Returns the current task to an async caller """ self.tasks.append(self.current_task) @@ -184,7 +184,7 @@ class AsyncScheduler: for event in self.events.copy(): if event.set: event.event_caught = True - event.waiters + event.waiters.append(self.current_task) self.tasks.extend(event.waiters) self.events.remove(event) @@ -210,6 +210,7 @@ class AsyncScheduler: Checks and schedules task to perform I/O """ + before_time = self.clock() if self.tasks or self.events and not self.selector.get_map(): # If there are either tasks or events and no I/O, never wait timeout = 0.0 @@ -220,17 +221,12 @@ class AsyncScheduler: # If there is *only* I/O, we wait a fixed amount of time timeout = 1 # TODO: Is this ok? self.debugger.before_io(timeout) - for key in dict(self.selector.get_map()).values(): - # We make sure we don't reschedule finished tasks - if key.data.finished: - key.data.last_io = () - self.selector.unregister(key.fileobj) - if self.selector.get_map(): # If there is indeed tasks waiting on I/O + if self.selector.get_map(): io_ready = self.selector.select(timeout) # Get sockets that are ready and schedule their tasks for key, _ in io_ready: self.tasks.append(key.data) # Resource ready? Schedule its task - self.debugger.after_io(timeout) + self.debugger.after_io(self.clock() - before_time) def start(self, func: types.FunctionType, *args): """ @@ -243,31 +239,37 @@ class AsyncScheduler: self.run() self.has_ran = True self.debugger.on_exit() - if entry.exc: - raise entry.exc from None - def reschedule_joinee(self): + def reschedule_joinee(self, task: Task): """ - Reschedules the joinee(s) of the - currently running task, if any + Reschedules the joinee of the + given task, if any """ - self.tasks.extend(self.current_task.waiters) + if task.parent: + self.tasks.append(task.parent) - def join(self): + def join(self, child: Task): """ Handler for the 'join' event, does some magic to tell the scheduler to wait until the current coroutine ends """ - child = self.current_task child.joined = True - if child.parent: - child.waiters.append(child.parent) if child.finished: - self.reschedule_joinee() + self.reschedule_joinee(child) elif child.exc: - ... # TODO: Handle exceptions + for task in chain(self.tasks, self.paused): + try: + self.cancel(task) + except CancelledError: + task.status = "cancelled" + task.cancelled = True + task.cancel_pending = False + self.debugger.after_cancel(task) + self.tasks.remove(task) + child.parent.throw(child.exc) + self.tasks.append(child.parent) def sleep(self, seconds: int or float): """ @@ -282,46 +284,19 @@ class AsyncScheduler: else: self.tasks.append(self.current_task) - # TODO: More generic I/O rather than just sockets - def want_read(self, sock: socket.socket): + def cancel(self, task: Task = None): """ - Handler for the 'want_read' event, registers the socket inside the - selector to perform I/0 multiplexing + Handler for the 'cancel' event, schedules the task to be cancelled later + or does so straight away if it is safe to do so """ - self.current_task.status = "I/O" - if self.current_task.last_io: - if self.current_task.last_io == ("READ", sock): - # Socket is already scheduled! - return + task = task or self.current_task + if not task.finished and not task.exc: + if task.status in ("I/O", "sleep"): + # We cancel right away + self.do_cancel(task) else: - self.selector.unregister(sock) - self.current_task.last_io = "READ", sock - try: - self.selector.register(sock, EVENT_READ, self.current_task) - except KeyError: - # The socket is already registered doing something else - raise ResourceBusy("The given resource is busy!") from None - - def want_write(self, sock: socket.socket): - """ - Handler for the 'want_write' event, registers the socket inside the - selector to perform I/0 multiplexing - """ - - self.current_task.status = "I/O" - if self.current_task.last_io: - if self.current_task.last_io == ("WRITE", sock): - # Socket is already scheduled! - return - else: - # TODO: Inspect why modify() causes issues - self.selector.unregister(sock) - self.current_task.last_io = "WRITE", sock - try: - self.selector.register(sock, EVENT_WRITE, self.current_task) - except KeyError: - raise ResourceBusy("The given resource is busy!") from None + task.cancel_pending = True # Cancellation is deferred def event_set(self, event): """ @@ -340,19 +315,44 @@ class AsyncScheduler: event.waiters.append(self.current_task) - - def cancel(self): + # TODO: More generic I/O rather than just sockets + def want_read(self, sock: socket.socket): """ - Handler for the 'cancel' event, schedules the task to be cancelled later - or does so straight away if it is safe to do so + Handler for the 'want_read' event, registers the socket inside the + selector to perform I/0 multiplexing """ - if self.current_task.status in ("I/O", "sleep"): - # We cancel right away - self.do_cancel() - else: - self.current_task.cancel_pending = True # Cancellation is deferred + self.current_task.status = "I/O" + if self.current_task.last_io: + if self.current_task.last_io == ("READ", sock): + # Socket is already scheduled! + return + self.selector.unregister(sock) + self.current_task.last_io = "READ", sock + try: + self.selector.register(sock, EVENT_READ, self.current_task) + except KeyError: + # The socket is already registered doing something else + raise ResourceBusy("The given resource is busy!") from None + def want_write(self, sock: socket.socket): + """ + Handler for the 'want_write' event, registers the socket inside the + selector to perform I/0 multiplexing + """ + + self.current_task.status = "I/O" + if self.current_task.last_io: + if self.current_task.last_io == ("WRITE", sock): + # Socket is already scheduled! + return + # TODO: Inspect why modify() causes issues + self.selector.unregister(sock) + self.current_task.last_io = "WRITE", sock + try: + self.selector.register(sock, EVENT_WRITE, self.current_task) + except KeyError: + raise ResourceBusy("The given resource is busy!") from None def wrap_socket(self, sock): """ Wraps a standard socket into an AsyncSocket object diff --git a/giambio/objects.py b/giambio/objects.py index 058a857..0c44eb5 100644 --- a/giambio/objects.py +++ b/giambio/objects.py @@ -42,7 +42,6 @@ class Task: parent: object = None joined: bool= False cancel_pending: bool = False - waiters: list = field(default_factory=list) sleep_start: int = None def run(self, what=None): @@ -131,7 +130,13 @@ class TimeQueue: return item in self.container def __iter__(self): - return iter(self.container) + return self + + def __next__(self): + try: + return self.get() + except IndexError: + raise StopIteration from None def __getitem__(self, item): return self.container.__getitem__(item) @@ -156,4 +161,4 @@ class TimeQueue: Gets the first task that is meant to run """ - return heappop(self.container)[2] + return heappop(self.container)[2] \ No newline at end of file diff --git a/giambio/run.py b/giambio/run.py index ae6fbd9..535af67 100644 --- a/giambio/run.py +++ b/giambio/run.py @@ -70,7 +70,7 @@ def run(func: FunctionType, *args, **kwargs): elif not isinstance(func, FunctionType): raise GiambioError("gaibmio.run() requires an async function as parameter!") new_event_loop(kwargs.get("debugger", None)) - thread_local.loop.start(func, *args) + get_event_loop().start(func, *args) def clock(): diff --git a/giambio/traps.py b/giambio/traps.py index b97a72b..a1a51be 100644 --- a/giambio/traps.py +++ b/giambio/traps.py @@ -73,7 +73,7 @@ async def join(task): :type task: class: Task """ - return await create_trap("join") + return await create_trap("join", task) async def cancel(task): diff --git a/giambio/util/debug.py b/giambio/util/debug.py index 7d580da..188e28d 100644 --- a/giambio/util/debug.py +++ b/giambio/util/debug.py @@ -162,7 +162,7 @@ class BaseDebugger(ABC): This method is called right after the event loop has checked for I/O events - :param timeout: The max. amount of seconds + :param timeout: The actual amount of seconds that the loop has hung when using the select() system call :type timeout: int diff --git a/tests/sleep.py b/tests/sleep.py index 09e080b..5332ef9 100644 --- a/tests/sleep.py +++ b/tests/sleep.py @@ -17,10 +17,10 @@ class Debugger(giambio.debug.BaseDebugger): def on_task_spawn(self, task): print(f">> A task named '{task.name}' was spawned") - + def on_task_exit(self, task): print(f"<< Task '{task.name}' exited") - + def before_task_step(self, task): print(f"-> About to run a step for '{task.name}'") @@ -29,7 +29,7 @@ class Debugger(giambio.debug.BaseDebugger): def before_sleep(self, task, seconds): print(f"# About to put '{task.name}' to sleep for {seconds:.2f} seconds") - + def after_sleep(self, task, seconds): print(f"# Task '{task.name}' slept for {seconds:.2f} seconds") @@ -37,7 +37,7 @@ class Debugger(giambio.debug.BaseDebugger): print(f"!! About to check for I/O for up to {timeout:.2f} seconds") def after_io(self, timeout): - print(f"!! Done I/O check (timeout {timeout:.2f} seconds)") + print(f"!! Done I/O check (waited for {timeout:.2f} seconds)") def before_cancel(self, task): print(f"// About to cancel '{task.name}'") @@ -50,12 +50,14 @@ async def child(): print("[child] Child spawned!! Sleeping for 2 seconds") await giambio.sleep(2) print("[child] Had a nice nap!") + raise TypeError("rip") async def child1(): print("[child 1] Child spawned!! Sleeping for 2 seconds") await giambio.sleep(2) print("[child 1] Had a nice nap!") + async def main(): start = giambio.clock() try: @@ -63,10 +65,10 @@ async def main(): pool.spawn(child) pool.spawn(child1) print("[main] Children spawned, awaiting completion") - except Exception as e: - print(f"Got -> {type(e).__name__}: {e}") + except Exception as error: + print(f"[main] Exception from child catched! {repr(error)}") print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") - + await giambio.sleep(5) if __name__ == "__main__": giambio.run(main, debugger=Debugger())