diff --git a/giambio/context.py b/giambio/context.py index 8c355de..906be9c 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -17,8 +17,7 @@ limitations under the License. """ import giambio -from giambio.task import Task -from typing import List, Optional, Callable, Coroutine, Any +from typing import List, Optional, Any, Coroutine, Callable class TaskManager: @@ -32,13 +31,13 @@ class TaskManager: :type raise_on_timeout: bool, optional """ - def __init__(self, current_task: Task, timeout: float = None, raise_on_timeout: bool = False) -> None: + def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None: """ Object constructor """ # All the tasks that belong to this pool - self.tasks: List[Task] = [] + self.tasks: List[giambio.task.Task] = [] # Whether we have been cancelled or not self.cancelled: bool = False # The clock time of when we started running, used for @@ -51,19 +50,10 @@ class TaskManager: self.timeout = None # Whether our timeout expired or not self.timed_out: bool = False - # Internal check so users don't try - # to use the pool manually self._proper_init = False - # We keep track of any inner pools to propagate - # exceptions properly self.enclosed_pool: Optional["giambio.context.TaskManager"] = None - # Do we raise an error after timeout? self.raise_on_timeout: bool = raise_on_timeout - # The task that created the pool. We keep track of - # it because we only cancel ourselves if this task - # errors out (so if the error is caught before reaching - # it we just do nothing) - self.owner: Task = current_task + self.entry_point: Optional[giambio.Task] = None async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task": """ @@ -80,6 +70,7 @@ class TaskManager: """ self._proper_init = True + self.entry_point = await giambio.traps.current_task() return self async def __aexit__(self, exc_type: Exception, exc: Exception, tb): @@ -95,14 +86,13 @@ class TaskManager: # children to exit await task.join() self.tasks.remove(task) + self._proper_init = False + if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout: + return True except giambio.exceptions.TooSlowError: if self.raise_on_timeout: raise - finally: - self._proper_init = False - if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout: - return True - + async def cancel(self): """ Cancels the pool entirely, iterating over all @@ -120,4 +110,4 @@ class TaskManager: pool have exited, False otherwise """ - return self._proper_init and all([task.done() for task in self.tasks]) + return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done()) diff --git a/giambio/core.py b/giambio/core.py index ff888c0..871bb83 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +17,7 @@ limitations under the License. """ # Import libraries and internal resources +from numbers import Number from giambio.task import Task from collections import deque from functools import partial @@ -38,8 +39,8 @@ from giambio.exceptions import ( class AsyncScheduler: """ A simple task scheduler implementation that tries to mimic thread programming - in its simplicity, without using actual threads, but rather alternating - across coroutines execution to let more than one thing at a time to proceed + in its simplicity, without using actual threads, but rather alternating the + execution of coroutines to let more than one thing at a time to proceed with its calculations. An attempt to fix the threaded model has been made without making the API unnecessarily complicated. @@ -55,7 +56,7 @@ class AsyncScheduler: :param clock: A callable returning monotonically increasing values at each call, usually using seconds as units, but this is not enforced, defaults to timeit.default_timer - :type clock: :class: Callable + :type clock: :class: types.FunctionType :param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output is desired, defaults to None :type debugger: :class: giambio.util.BaseDebugger @@ -72,7 +73,7 @@ class AsyncScheduler: def __init__( self, - clock: Callable = default_timer, + clock: Callable[[], Number] = default_timer, debugger: Optional[BaseDebugger] = None, selector: Optional[Any] = None, io_skip_limit: Optional[int] = None, @@ -94,7 +95,7 @@ class AsyncScheduler: or type( "DumbDebugger", (object,), - {"__getattr__": lambda *args: lambda *arg: None}, + {"__getattr__": lambda *_: lambda *_: None}, )() ) # All tasks the loop has @@ -106,7 +107,7 @@ class AsyncScheduler: # This will always point to the currently running coroutine (Task object) self.current_task: Optional[Task] = None # Monotonic clock to keep track of elapsed time reliably - self.clock: Callable = clock + self.clock: Callable[[], Number] = clock # Tasks that are asleep self.paused: TimeQueue = TimeQueue(self.clock) # Have we ever ran? @@ -129,6 +130,7 @@ class AsyncScheduler: self.entry_point: Optional[Task] = None # Suspended tasks self.suspended: deque = deque() + def __repr__(self): """ @@ -150,6 +152,8 @@ class AsyncScheduler: "_data", "io_skip_limit", "io_max_timeout", + "suspended", + "entry_point" } data = ", ".join( name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields)) @@ -168,7 +172,7 @@ class AsyncScheduler: Shuts down the event loop """ - for task in self.tasks: + for task in self.get_all_tasks(): self.io_release_task(task) self.selector.close() # TODO: Anything else? @@ -206,7 +210,10 @@ class AsyncScheduler: # after it is set, but it makes the implementation easier if not self.current_pool and self.current_task.pool: self.current_pool = self.current_task.pool - self.deadlines.put(self.current_pool) + pool = self.current_pool + while pool: + self.deadlines.put(pool) + pool = self.current_pool.enclosed_pool # If there are no actively running tasks, we start by # checking for I/O. This method will wait for I/O until # the closest deadline to avoid starving sleeping tasks @@ -230,9 +237,10 @@ class AsyncScheduler: # some tricky behaviors, and this is one of them. When a coroutine # hits a return statement (either explicit or implicit), it raises # a StopIteration exception, which has an attribute named value that - # represents the return value of the coroutine, if any. Of course this - # exception is not an error and we should happily keep going after it, + # represents the return value of the coroutine, if it has one. Of course + # this exception is not an error and we should happily keep going after it: # most of this code below is just useful for internal/debugging purposes + self.current_task.status = "end" self.current_task.result = ret.value self.current_task.finished = True self.join(self.current_task) @@ -244,20 +252,22 @@ class AsyncScheduler: self.current_task.exc = err self.join(self.current_task) - def create_task(self, coro: Coroutine[Any, Any, Any], pool) -> Task: + + def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task: """ Creates a task from a coroutine function and schedules it to run. The associated pool that spawned said task is also needed, while any extra keyword or positional arguments are passed to the function itself - :param coro: The coroutine to spawn - :type coro: Coroutine[Any, Any, Any] + :param corofunc: The coroutine function (NOT a coroutine!) to + spawn + :type corofunc: function :param pool: The giambio.context.TaskManager object that spawned the task """ - task = Task(coro.__name__ or str(coro), coro, pool) + task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool) task.next_deadline = pool.timeout or 0.0 task.joiners = {self.current_task} self._data[self.current_task] = task @@ -288,9 +298,15 @@ class AsyncScheduler: # We need to make sure we don't try to execute # exited tasks that are on the running queue return - if not self.current_pool and self.current_task.pool: + if self.current_pool: + if self.current_task.pool and self.current_task.pool is not self.current_pool: + self.current_task.pool.enclosed_pool = self.current_pool + else: self.current_pool = self.current_task.pool - self.deadlines.put(self.current_pool) + pool = self.current_pool + while pool: + self.deadlines.put(pool) + pool = self.current_pool.enclosed_pool self.debugger.before_task_step(self.current_task) # Some debugging and internal chatter here self.current_task.status = "run" @@ -319,7 +335,7 @@ class AsyncScheduler: def io_release(self, sock): """ Releases the given resource from our - selector. + selector :param sock: The resource to be released """ @@ -334,7 +350,7 @@ class AsyncScheduler: if self.selector.get_map(): for k in filter( - lambda o: o.data == self.current_task, + lambda o: o.data == task, dict(self.selector.get_map()).values(), ): self.io_release(k.fileobj) @@ -344,11 +360,16 @@ class AsyncScheduler: """ Suspends execution of the current task. This is basically a do-nothing method, since it will not reschedule the task - before returning. The task will stay suspended until a timer, - I/O operation or cancellation wakes it up, or until another - running task reschedules it. + before returning. The task will stay suspended as long as + something else outside the loop calls a trap to reschedule it. + Any pending I/O for the task is temporarily unscheduled to + avoid some previous network operation to reschedule the task + before it's due """ - + + if self.current_task.last_io or self.current_task.status == "io": + self.io_release_task(self.current_task) + self.current_task.status = "sleep" self.suspended.append(self.current_task) def reschedule_running(self): @@ -408,27 +429,32 @@ class AsyncScheduler: try: to_call() except StopIteration as ret: + task.status = "end" task.result = ret.value task.finished = True self.join(task) + self.tasks.remove(task) except BaseException as err: task.exc = err self.join(task) + if task in self.tasks: + self.tasks.remove(task) def prune_deadlines(self): """ Removes expired deadlines after their timeout - has expired and cancels their associated pool + has expired """ while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): pool = self.deadlines.get() pool.timed_out = True + self.cancel_pool(pool) for task in pool.tasks: - if task is not pool.owner: - self.handle_task_exit(task, partial(task.throw, TooSlowError(self.current_task))) - if pool.raise_on_timeout: - self.handle_task_exit(pool.owner, partial(pool.owner.throw, TooSlowError(self.current_task))) + self.join(task) + if pool.entry_point is self.entry_point: + self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point))) + self.run_ready.append(self.entry_point) def schedule_tasks(self, tasks: List[Task]): """ @@ -439,7 +465,8 @@ class AsyncScheduler: for task in tasks: self.paused.discard(task) - self.suspended.remove(task) + if task in self.suspended: + self.suspended.remove(task) self.run_ready.extend(tasks) self.reschedule_running() @@ -462,6 +489,7 @@ class AsyncScheduler: self.run_ready.append(task) self.debugger.after_sleep(task, slept) + def get_closest_deadline(self) -> float: """ Gets the closest expiration deadline (asleep tasks, timeouts) @@ -469,7 +497,7 @@ class AsyncScheduler: :return: The closest deadline according to our clock :rtype: float """ - + if not self.deadlines: # If there are no deadlines just wait until the first task wakeup timeout = max(0.0, self.paused.get_closest_deadline() - self.clock()) @@ -535,9 +563,12 @@ class AsyncScheduler: self.run_ready.append(entry) self.debugger.on_start() if loop: - self.run() - self.has_ran = True - self.debugger.on_exit() + try: + self.run() + finally: + self.has_ran = True + self.close() + self.debugger.on_exit() def cancel_pool(self, pool: TaskManager) -> bool: """ @@ -589,8 +620,9 @@ class AsyncScheduler: If ensure_done equals False, the loop will cancel ALL running and scheduled tasks and then tear itself down. If ensure_done equals True, which is the default behavior, - this method will raise a GiambioError if the loop hasn't - finished running. + this method will raise a GiambioError exception if the loop + hasn't finished running. The state of the event loop is reset + so it can be reused with another run() call """ if ensure_done: @@ -598,6 +630,16 @@ class AsyncScheduler: elif not self.done(): raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit") self.shutdown() + # We reset the event loop's state + self.tasks = [] + self.entry_point = None + self.current_pool = None + self.current_task = None + self.paused = TimeQueue(self.clock) + self.deadlines = DeadlinesQueue() + self.run_ready = deque() + self.suspended = deque() + def reschedule_joiners(self, task: Task): """ @@ -605,87 +647,69 @@ class AsyncScheduler: given task, if any """ - for t in task.joiners: - self.run_ready.append(t) - - # noinspection PyMethodMayBeStatic - def is_pool_done(self, pool: Optional[TaskManager]): - """ - Returns True if a given pool has finished - executing - """ - - while pool: - if not pool.done(): - return False - pool = pool.enclosed_pool - return True + if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done(): + return + self.run_ready.extend(task.joiners) def join(self, task: Task): """ - Joins a task to its callers (implicitly, the parent + Joins a task to its callers (implicitly the parent task, but also every other task who called await task.join() on the task object) """ task.joined = True - if any([task.finished, task.cancelled, task.exc]) and task in self.tasks: - self.io_release_task(task) - self.tasks.remove(task) - self.paused.discard(task) if task.finished or task.cancelled: - task.status = "end" if not task.cancelled: - task.status = "cancelled" - # This way join() returns the - # task's return value - for joiner in task.joiners: - self._data[joiner] = task.result self.debugger.on_task_exit(task) - # If the pool has finished executing or we're at the first parent - # task that kicked the loop, we can safely reschedule the parent(s) - if self.is_pool_done(task.pool): + if task.last_io: + self.io_release_task(task) + if task in self.suspended: + self.suspended.remove(task) + # If the pool (including any enclosing pools) has finished executing + # or we're at the first task that kicked the loop, we can safely + # reschedule the parent(s) + if task.pool is None: + return + if task.pool.done(): self.reschedule_joiners(task) - self.reschedule_running() elif task.exc: + if task in self.suspended: + self.suspended.remove(task) task.status = "crashed" - if task.exc.__traceback__: - # TODO: We might want to do a bit more complex traceback hacking to remove any extra - # frames from the exception call stack, but for now removing at least the first few - # seems a sensible approach (it's us catching it so we don't care about that) - for _ in range(5): - if task.exc.__traceback__.tb_next: - task.exc.__traceback__ = task.exc.__traceback__.tb_next self.debugger.on_exception_raised(task, task.exc) - if task is self.entry_point and not task.pool: - try: - task.throw(task.exc) - except StopIteration: - ... # TODO: ? - except BaseException: - # TODO: No idea what to do here - raise - elif any(map(lambda tk: tk is task.pool.owner, task.joiners)) or task is task.pool.owner: - # We check if the pool's - # owner catches our error - # or not. If they don't, we - # cancel the entire pool, but - # if they do, we do nothing - if task.pool.owner is not task: - self.handle_task_exit(task.pool.owner, partial(task.pool.owner.coroutine.throw, task.exc)) - if any([task.pool.owner.exc, task.pool.owner.cancelled, task.pool.owner.finished]): - for t in task.joiners.copy(): - # Propagate the exception - self.handle_task_exit(t, partial(t.throw, task.exc)) - if any([t.exc, t.finished, t.cancelled]): - task.joiners.remove(t) - for t in task.pool.tasks: - if not t.joined: - self.handle_task_exit(t, partial(t.throw, task.exc)) - if any([t.exc, t.finished, t.cancelled]): - task.joiners.discard(t) - self.reschedule_joiners(task) - self.reschedule_running() + if task.pool is None or task is self.entry_point: + # Parent task has no pool, so we propagate + raise task.exc + if self.cancel_pool(task.pool): + # This will reschedule the parent(s) + # only if all the tasks inside the task's + # pool have finished executing, either + # by cancellation, an exception + # or just returned + for t in task.joiners.copy(): + # Propagate the exception + try: + t.throw(task.exc) + except (StopIteration, CancelledError, RuntimeError) as e: + # TODO: Need anything else? + task.joiners.remove(t) + if isinstance(e, StopIteration): + t.status = "end" + t.result = e.value + t.finished = True + elif isinstance(e, CancelledError): + t = e.task + t.cancel_pending = False + t.cancelled = True + t.status = "cancelled" + self.debugger.after_cancel(t) + elif isinstance(e, BaseException): + t.exc = e + finally: + if t in self.tasks: + self.tasks.remove(t) + self.reschedule_joiners(task) def sleep(self, seconds: int or float): """ @@ -727,6 +751,8 @@ class AsyncScheduler: self.io_release_task(task) elif task.status == "sleep": self.paused.discard(task) + if task in self.suspended: + self.suspended.remove(task) try: self.do_cancel(task) except CancelledError as cancel: @@ -742,24 +768,24 @@ class AsyncScheduler: task = cancel.task task.cancel_pending = False task.cancelled = True - self.io_release_task(self.current_task) + task.status = "cancelled" self.debugger.after_cancel(task) self.tasks.remove(task) + self.join(task) else: # If the task ignores our exception, we'll # raise it later again task.cancel_pending = True - self.join(task) def register_sock(self, sock, evt_type: str): """ Registers the given socket inside the - selector to perform I/0 multiplexing + selector to perform I/O multiplexing :param sock: The socket on which a read or write operation - has to be performed + has to be performed :param evt_type: The type of event to perform on the given - socket, either "read" or "write" + socket, either "read" or "write" :type evt_type: str """ @@ -793,5 +819,8 @@ class AsyncScheduler: try: self.selector.register(sock, evt, self.current_task) except KeyError: - # The socket is already registered doing something else - raise ResourceBusy("The given socket is being read/written by another task") from None + # The socket is already registered doing something else, we + # modify the socket instead (or maybe not?) + self.selector.modify(sock, evt, self.current_task) + # TODO: Does this break stuff? + # raise ResourceBusy("The given socket is being read/written by another task") from None