Bug fixes with exception handling and minor documentation improvements

This commit is contained in:
Nocturn9x 2022-10-10 09:55:04 +02:00
parent 6d089d7d5f
commit d408cffa87
2 changed files with 148 additions and 129 deletions

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
@ -17,8 +17,7 @@ limitations under the License.
""" """
import giambio import giambio
from giambio.task import Task from typing import List, Optional, Any, Coroutine, Callable
from typing import List, Optional, Callable, Coroutine, Any
class TaskManager: class TaskManager:
@ -32,13 +31,13 @@ class TaskManager:
:type raise_on_timeout: bool, optional :type raise_on_timeout: bool, optional
""" """
def __init__(self, current_task: Task, timeout: float = None, raise_on_timeout: bool = False) -> None: def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None:
""" """
Object constructor Object constructor
""" """
# All the tasks that belong to this pool # All the tasks that belong to this pool
self.tasks: List[Task] = [] self.tasks: List[giambio.task.Task] = []
# Whether we have been cancelled or not # Whether we have been cancelled or not
self.cancelled: bool = False self.cancelled: bool = False
# The clock time of when we started running, used for # The clock time of when we started running, used for
@ -51,19 +50,10 @@ class TaskManager:
self.timeout = None self.timeout = None
# Whether our timeout expired or not # Whether our timeout expired or not
self.timed_out: bool = False self.timed_out: bool = False
# Internal check so users don't try
# to use the pool manually
self._proper_init = False self._proper_init = False
# We keep track of any inner pools to propagate
# exceptions properly
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# Do we raise an error after timeout?
self.raise_on_timeout: bool = raise_on_timeout self.raise_on_timeout: bool = raise_on_timeout
# The task that created the pool. We keep track of self.entry_point: Optional[giambio.Task] = None
# it because we only cancel ourselves if this task
# errors out (so if the error is caught before reaching
# it we just do nothing)
self.owner: Task = current_task
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task": async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
""" """
@ -80,6 +70,7 @@ class TaskManager:
""" """
self._proper_init = True self._proper_init = True
self.entry_point = await giambio.traps.current_task()
return self return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb): async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
@ -95,14 +86,13 @@ class TaskManager:
# children to exit # children to exit
await task.join() await task.join()
self.tasks.remove(task) self.tasks.remove(task)
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
except giambio.exceptions.TooSlowError: except giambio.exceptions.TooSlowError:
if self.raise_on_timeout: if self.raise_on_timeout:
raise raise
finally:
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
async def cancel(self): async def cancel(self):
""" """
Cancels the pool entirely, iterating over all Cancels the pool entirely, iterating over all
@ -120,4 +110,4 @@ class TaskManager:
pool have exited, False otherwise pool have exited, False otherwise
""" """
return self._proper_init and all([task.done() for task in self.tasks]) return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
@ -17,6 +17,7 @@ limitations under the License.
""" """
# Import libraries and internal resources # Import libraries and internal resources
from numbers import Number
from giambio.task import Task from giambio.task import Task
from collections import deque from collections import deque
from functools import partial from functools import partial
@ -38,8 +39,8 @@ from giambio.exceptions import (
class AsyncScheduler: class AsyncScheduler:
""" """
A simple task scheduler implementation that tries to mimic thread programming A simple task scheduler implementation that tries to mimic thread programming
in its simplicity, without using actual threads, but rather alternating in its simplicity, without using actual threads, but rather alternating the
across coroutines execution to let more than one thing at a time to proceed execution of coroutines to let more than one thing at a time to proceed
with its calculations. An attempt to fix the threaded model has been made with its calculations. An attempt to fix the threaded model has been made
without making the API unnecessarily complicated. without making the API unnecessarily complicated.
@ -55,7 +56,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call, :param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: Callable :type clock: :class: types.FunctionType
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output :param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger :type debugger: :class: giambio.util.BaseDebugger
@ -72,7 +73,7 @@ class AsyncScheduler:
def __init__( def __init__(
self, self,
clock: Callable = default_timer, clock: Callable[[], Number] = default_timer,
debugger: Optional[BaseDebugger] = None, debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None, selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None, io_skip_limit: Optional[int] = None,
@ -94,7 +95,7 @@ class AsyncScheduler:
or type( or type(
"DumbDebugger", "DumbDebugger",
(object,), (object,),
{"__getattr__": lambda *args: lambda *arg: None}, {"__getattr__": lambda *_: lambda *_: None},
)() )()
) )
# All tasks the loop has # All tasks the loop has
@ -106,7 +107,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object) # This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably # Monotonic clock to keep track of elapsed time reliably
self.clock: Callable = clock self.clock: Callable[[], Number] = clock
# Tasks that are asleep # Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock) self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever ran? # Have we ever ran?
@ -129,6 +130,7 @@ class AsyncScheduler:
self.entry_point: Optional[Task] = None self.entry_point: Optional[Task] = None
# Suspended tasks # Suspended tasks
self.suspended: deque = deque() self.suspended: deque = deque()
def __repr__(self): def __repr__(self):
""" """
@ -150,6 +152,8 @@ class AsyncScheduler:
"_data", "_data",
"io_skip_limit", "io_skip_limit",
"io_max_timeout", "io_max_timeout",
"suspended",
"entry_point"
} }
data = ", ".join( data = ", ".join(
name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields)) name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields))
@ -168,7 +172,7 @@ class AsyncScheduler:
Shuts down the event loop Shuts down the event loop
""" """
for task in self.tasks: for task in self.get_all_tasks():
self.io_release_task(task) self.io_release_task(task)
self.selector.close() self.selector.close()
# TODO: Anything else? # TODO: Anything else?
@ -206,7 +210,10 @@ class AsyncScheduler:
# after it is set, but it makes the implementation easier # after it is set, but it makes the implementation easier
if not self.current_pool and self.current_task.pool: if not self.current_pool and self.current_task.pool:
self.current_pool = self.current_task.pool self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool) pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = self.current_pool.enclosed_pool
# If there are no actively running tasks, we start by # If there are no actively running tasks, we start by
# checking for I/O. This method will wait for I/O until # checking for I/O. This method will wait for I/O until
# the closest deadline to avoid starving sleeping tasks # the closest deadline to avoid starving sleeping tasks
@ -230,9 +237,10 @@ class AsyncScheduler:
# some tricky behaviors, and this is one of them. When a coroutine # some tricky behaviors, and this is one of them. When a coroutine
# hits a return statement (either explicit or implicit), it raises # hits a return statement (either explicit or implicit), it raises
# a StopIteration exception, which has an attribute named value that # a StopIteration exception, which has an attribute named value that
# represents the return value of the coroutine, if any. Of course this # represents the return value of the coroutine, if it has one. Of course
# exception is not an error and we should happily keep going after it, # this exception is not an error and we should happily keep going after it:
# most of this code below is just useful for internal/debugging purposes # most of this code below is just useful for internal/debugging purposes
self.current_task.status = "end"
self.current_task.result = ret.value self.current_task.result = ret.value
self.current_task.finished = True self.current_task.finished = True
self.join(self.current_task) self.join(self.current_task)
@ -244,20 +252,22 @@ class AsyncScheduler:
self.current_task.exc = err self.current_task.exc = err
self.join(self.current_task) self.join(self.current_task)
def create_task(self, coro: Coroutine[Any, Any, Any], pool) -> Task:
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
""" """
Creates a task from a coroutine function and schedules it Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also to run. The associated pool that spawned said task is also
needed, while any extra keyword or positional arguments are needed, while any extra keyword or positional arguments are
passed to the function itself passed to the function itself
:param coro: The coroutine to spawn :param corofunc: The coroutine function (NOT a coroutine!) to
:type coro: Coroutine[Any, Any, Any] spawn
:type corofunc: function
:param pool: The giambio.context.TaskManager object that :param pool: The giambio.context.TaskManager object that
spawned the task spawned the task
""" """
task = Task(coro.__name__ or str(coro), coro, pool) task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
task.next_deadline = pool.timeout or 0.0 task.next_deadline = pool.timeout or 0.0
task.joiners = {self.current_task} task.joiners = {self.current_task}
self._data[self.current_task] = task self._data[self.current_task] = task
@ -288,9 +298,15 @@ class AsyncScheduler:
# We need to make sure we don't try to execute # We need to make sure we don't try to execute
# exited tasks that are on the running queue # exited tasks that are on the running queue
return return
if not self.current_pool and self.current_task.pool: if self.current_pool:
if self.current_task.pool and self.current_task.pool is not self.current_pool:
self.current_task.pool.enclosed_pool = self.current_pool
else:
self.current_pool = self.current_task.pool self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool) pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = self.current_pool.enclosed_pool
self.debugger.before_task_step(self.current_task) self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here # Some debugging and internal chatter here
self.current_task.status = "run" self.current_task.status = "run"
@ -319,7 +335,7 @@ class AsyncScheduler:
def io_release(self, sock): def io_release(self, sock):
""" """
Releases the given resource from our Releases the given resource from our
selector. selector
:param sock: The resource to be released :param sock: The resource to be released
""" """
@ -334,7 +350,7 @@ class AsyncScheduler:
if self.selector.get_map(): if self.selector.get_map():
for k in filter( for k in filter(
lambda o: o.data == self.current_task, lambda o: o.data == task,
dict(self.selector.get_map()).values(), dict(self.selector.get_map()).values(),
): ):
self.io_release(k.fileobj) self.io_release(k.fileobj)
@ -344,11 +360,16 @@ class AsyncScheduler:
""" """
Suspends execution of the current task. This is basically Suspends execution of the current task. This is basically
a do-nothing method, since it will not reschedule the task a do-nothing method, since it will not reschedule the task
before returning. The task will stay suspended until a timer, before returning. The task will stay suspended as long as
I/O operation or cancellation wakes it up, or until another something else outside the loop calls a trap to reschedule it.
running task reschedules it. Any pending I/O for the task is temporarily unscheduled to
avoid some previous network operation to reschedule the task
before it's due
""" """
if self.current_task.last_io or self.current_task.status == "io":
self.io_release_task(self.current_task)
self.current_task.status = "sleep"
self.suspended.append(self.current_task) self.suspended.append(self.current_task)
def reschedule_running(self): def reschedule_running(self):
@ -408,27 +429,32 @@ class AsyncScheduler:
try: try:
to_call() to_call()
except StopIteration as ret: except StopIteration as ret:
task.status = "end"
task.result = ret.value task.result = ret.value
task.finished = True task.finished = True
self.join(task) self.join(task)
self.tasks.remove(task)
except BaseException as err: except BaseException as err:
task.exc = err task.exc = err
self.join(task) self.join(task)
if task in self.tasks:
self.tasks.remove(task)
def prune_deadlines(self): def prune_deadlines(self):
""" """
Removes expired deadlines after their timeout Removes expired deadlines after their timeout
has expired and cancels their associated pool has expired
""" """
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
pool = self.deadlines.get() pool = self.deadlines.get()
pool.timed_out = True pool.timed_out = True
self.cancel_pool(pool)
for task in pool.tasks: for task in pool.tasks:
if task is not pool.owner: self.join(task)
self.handle_task_exit(task, partial(task.throw, TooSlowError(self.current_task))) if pool.entry_point is self.entry_point:
if pool.raise_on_timeout: self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
self.handle_task_exit(pool.owner, partial(pool.owner.throw, TooSlowError(self.current_task))) self.run_ready.append(self.entry_point)
def schedule_tasks(self, tasks: List[Task]): def schedule_tasks(self, tasks: List[Task]):
""" """
@ -439,7 +465,8 @@ class AsyncScheduler:
for task in tasks: for task in tasks:
self.paused.discard(task) self.paused.discard(task)
self.suspended.remove(task) if task in self.suspended:
self.suspended.remove(task)
self.run_ready.extend(tasks) self.run_ready.extend(tasks)
self.reschedule_running() self.reschedule_running()
@ -462,6 +489,7 @@ class AsyncScheduler:
self.run_ready.append(task) self.run_ready.append(task)
self.debugger.after_sleep(task, slept) self.debugger.after_sleep(task, slept)
def get_closest_deadline(self) -> float: def get_closest_deadline(self) -> float:
""" """
Gets the closest expiration deadline (asleep tasks, timeouts) Gets the closest expiration deadline (asleep tasks, timeouts)
@ -469,7 +497,7 @@ class AsyncScheduler:
:return: The closest deadline according to our clock :return: The closest deadline according to our clock
:rtype: float :rtype: float
""" """
if not self.deadlines: if not self.deadlines:
# If there are no deadlines just wait until the first task wakeup # If there are no deadlines just wait until the first task wakeup
timeout = max(0.0, self.paused.get_closest_deadline() - self.clock()) timeout = max(0.0, self.paused.get_closest_deadline() - self.clock())
@ -535,9 +563,12 @@ class AsyncScheduler:
self.run_ready.append(entry) self.run_ready.append(entry)
self.debugger.on_start() self.debugger.on_start()
if loop: if loop:
self.run() try:
self.has_ran = True self.run()
self.debugger.on_exit() finally:
self.has_ran = True
self.close()
self.debugger.on_exit()
def cancel_pool(self, pool: TaskManager) -> bool: def cancel_pool(self, pool: TaskManager) -> bool:
""" """
@ -589,8 +620,9 @@ class AsyncScheduler:
If ensure_done equals False, the loop will cancel ALL If ensure_done equals False, the loop will cancel ALL
running and scheduled tasks and then tear itself down. running and scheduled tasks and then tear itself down.
If ensure_done equals True, which is the default behavior, If ensure_done equals True, which is the default behavior,
this method will raise a GiambioError if the loop hasn't this method will raise a GiambioError exception if the loop
finished running. hasn't finished running. The state of the event loop is reset
so it can be reused with another run() call
""" """
if ensure_done: if ensure_done:
@ -598,6 +630,16 @@ class AsyncScheduler:
elif not self.done(): elif not self.done():
raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit") raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit")
self.shutdown() self.shutdown()
# We reset the event loop's state
self.tasks = []
self.entry_point = None
self.current_pool = None
self.current_task = None
self.paused = TimeQueue(self.clock)
self.deadlines = DeadlinesQueue()
self.run_ready = deque()
self.suspended = deque()
def reschedule_joiners(self, task: Task): def reschedule_joiners(self, task: Task):
""" """
@ -605,87 +647,69 @@ class AsyncScheduler:
given task, if any given task, if any
""" """
for t in task.joiners: if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
self.run_ready.append(t) return
self.run_ready.extend(task.joiners)
# noinspection PyMethodMayBeStatic
def is_pool_done(self, pool: Optional[TaskManager]):
"""
Returns True if a given pool has finished
executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
def join(self, task: Task): def join(self, task: Task):
""" """
Joins a task to its callers (implicitly, the parent Joins a task to its callers (implicitly the parent
task, but also every other task who called await task, but also every other task who called await
task.join() on the task object) task.join() on the task object)
""" """
task.joined = True task.joined = True
if any([task.finished, task.cancelled, task.exc]) and task in self.tasks:
self.io_release_task(task)
self.tasks.remove(task)
self.paused.discard(task)
if task.finished or task.cancelled: if task.finished or task.cancelled:
task.status = "end"
if not task.cancelled: if not task.cancelled:
task.status = "cancelled"
# This way join() returns the
# task's return value
for joiner in task.joiners:
self._data[joiner] = task.result
self.debugger.on_task_exit(task) self.debugger.on_task_exit(task)
# If the pool has finished executing or we're at the first parent if task.last_io:
# task that kicked the loop, we can safely reschedule the parent(s) self.io_release_task(task)
if self.is_pool_done(task.pool): if task in self.suspended:
self.suspended.remove(task)
# If the pool (including any enclosing pools) has finished executing
# or we're at the first task that kicked the loop, we can safely
# reschedule the parent(s)
if task.pool is None:
return
if task.pool.done():
self.reschedule_joiners(task) self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc: elif task.exc:
if task in self.suspended:
self.suspended.remove(task)
task.status = "crashed" task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first few
# seems a sensible approach (it's us catching it so we don't care about that)
for _ in range(5):
if task.exc.__traceback__.tb_next:
task.exc.__traceback__ = task.exc.__traceback__.tb_next
self.debugger.on_exception_raised(task, task.exc) self.debugger.on_exception_raised(task, task.exc)
if task is self.entry_point and not task.pool: if task.pool is None or task is self.entry_point:
try: # Parent task has no pool, so we propagate
task.throw(task.exc) raise task.exc
except StopIteration: if self.cancel_pool(task.pool):
... # TODO: ? # This will reschedule the parent(s)
except BaseException: # only if all the tasks inside the task's
# TODO: No idea what to do here # pool have finished executing, either
raise # by cancellation, an exception
elif any(map(lambda tk: tk is task.pool.owner, task.joiners)) or task is task.pool.owner: # or just returned
# We check if the pool's for t in task.joiners.copy():
# owner catches our error # Propagate the exception
# or not. If they don't, we try:
# cancel the entire pool, but t.throw(task.exc)
# if they do, we do nothing except (StopIteration, CancelledError, RuntimeError) as e:
if task.pool.owner is not task: # TODO: Need anything else?
self.handle_task_exit(task.pool.owner, partial(task.pool.owner.coroutine.throw, task.exc)) task.joiners.remove(t)
if any([task.pool.owner.exc, task.pool.owner.cancelled, task.pool.owner.finished]): if isinstance(e, StopIteration):
for t in task.joiners.copy(): t.status = "end"
# Propagate the exception t.result = e.value
self.handle_task_exit(t, partial(t.throw, task.exc)) t.finished = True
if any([t.exc, t.finished, t.cancelled]): elif isinstance(e, CancelledError):
task.joiners.remove(t) t = e.task
for t in task.pool.tasks: t.cancel_pending = False
if not t.joined: t.cancelled = True
self.handle_task_exit(t, partial(t.throw, task.exc)) t.status = "cancelled"
if any([t.exc, t.finished, t.cancelled]): self.debugger.after_cancel(t)
task.joiners.discard(t) elif isinstance(e, BaseException):
self.reschedule_joiners(task) t.exc = e
self.reschedule_running() finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task)
def sleep(self, seconds: int or float): def sleep(self, seconds: int or float):
""" """
@ -727,6 +751,8 @@ class AsyncScheduler:
self.io_release_task(task) self.io_release_task(task)
elif task.status == "sleep": elif task.status == "sleep":
self.paused.discard(task) self.paused.discard(task)
if task in self.suspended:
self.suspended.remove(task)
try: try:
self.do_cancel(task) self.do_cancel(task)
except CancelledError as cancel: except CancelledError as cancel:
@ -742,24 +768,24 @@ class AsyncScheduler:
task = cancel.task task = cancel.task
task.cancel_pending = False task.cancel_pending = False
task.cancelled = True task.cancelled = True
self.io_release_task(self.current_task) task.status = "cancelled"
self.debugger.after_cancel(task) self.debugger.after_cancel(task)
self.tasks.remove(task) self.tasks.remove(task)
self.join(task)
else: else:
# If the task ignores our exception, we'll # If the task ignores our exception, we'll
# raise it later again # raise it later again
task.cancel_pending = True task.cancel_pending = True
self.join(task)
def register_sock(self, sock, evt_type: str): def register_sock(self, sock, evt_type: str):
""" """
Registers the given socket inside the Registers the given socket inside the
selector to perform I/0 multiplexing selector to perform I/O multiplexing
:param sock: The socket on which a read or write operation :param sock: The socket on which a read or write operation
has to be performed has to be performed
:param evt_type: The type of event to perform on the given :param evt_type: The type of event to perform on the given
socket, either "read" or "write" socket, either "read" or "write"
:type evt_type: str :type evt_type: str
""" """
@ -793,5 +819,8 @@ class AsyncScheduler:
try: try:
self.selector.register(sock, evt, self.current_task) self.selector.register(sock, evt, self.current_task)
except KeyError: except KeyError:
# The socket is already registered doing something else # The socket is already registered doing something else, we
raise ResourceBusy("The given socket is being read/written by another task") from None # modify the socket instead (or maybe not?)
self.selector.modify(sock, evt, self.current_task)
# TODO: Does this break stuff?
# raise ResourceBusy("The given socket is being read/written by another task") from None