Bug fixes with exception handling and minor documentation improvements

This commit is contained in:
Nocturn9x 2022-10-10 09:55:04 +02:00
parent 6d089d7d5f
commit d408cffa87
2 changed files with 148 additions and 129 deletions

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@ -17,8 +17,7 @@ limitations under the License.
"""
import giambio
from giambio.task import Task
from typing import List, Optional, Callable, Coroutine, Any
from typing import List, Optional, Any, Coroutine, Callable
class TaskManager:
@ -32,13 +31,13 @@ class TaskManager:
:type raise_on_timeout: bool, optional
"""
def __init__(self, current_task: Task, timeout: float = None, raise_on_timeout: bool = False) -> None:
def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None:
"""
Object constructor
"""
# All the tasks that belong to this pool
self.tasks: List[Task] = []
self.tasks: List[giambio.task.Task] = []
# Whether we have been cancelled or not
self.cancelled: bool = False
# The clock time of when we started running, used for
@ -51,19 +50,10 @@ class TaskManager:
self.timeout = None
# Whether our timeout expired or not
self.timed_out: bool = False
# Internal check so users don't try
# to use the pool manually
self._proper_init = False
# We keep track of any inner pools to propagate
# exceptions properly
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# Do we raise an error after timeout?
self.raise_on_timeout: bool = raise_on_timeout
# The task that created the pool. We keep track of
# it because we only cancel ourselves if this task
# errors out (so if the error is caught before reaching
# it we just do nothing)
self.owner: Task = current_task
self.entry_point: Optional[giambio.Task] = None
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
"""
@ -80,6 +70,7 @@ class TaskManager:
"""
self._proper_init = True
self.entry_point = await giambio.traps.current_task()
return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
@ -95,14 +86,13 @@ class TaskManager:
# children to exit
await task.join()
self.tasks.remove(task)
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
except giambio.exceptions.TooSlowError:
if self.raise_on_timeout:
raise
finally:
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
async def cancel(self):
"""
Cancels the pool entirely, iterating over all
@ -120,4 +110,4 @@ class TaskManager:
pool have exited, False otherwise
"""
return self._proper_init and all([task.done() for task in self.tasks])
return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())
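For orientation, this is how the pool is driven from user code. A minimal sketch: spawn() and TooSlowError are confirmed by this diff, while create_pool()/with_timeout() as pool factories and giambio.run()/giambio.sleep() are assumed from giambio's public API:

import giambio

async def child(n: int):
    await giambio.sleep(n)

async def main():
    # raise_on_timeout=True (the new default) makes an expired pool
    # raise TooSlowError instead of returning silently
    try:
        async with giambio.with_timeout(5) as pool:
            await pool.spawn(child, 1)    # spawn(func, *args), per the diff
            await pool.spawn(child, 10)   # won't finish in time
    except giambio.exceptions.TooSlowError:
        print("pool timed out")

giambio.run(main)

The done() change above complements this: a pool now reports completion only once any pool nested inside it (tracked through enclosed_pool) has finished as well.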

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@ -17,6 +17,7 @@ limitations under the License.
"""
# Import libraries and internal resources
from numbers import Number
from giambio.task import Task
from collections import deque
from functools import partial
@ -38,8 +39,8 @@ from giambio.exceptions import (
class AsyncScheduler:
"""
A simple task scheduler implementation that tries to mimic thread programming
in its simplicity, without using actual threads, but rather alternating
across coroutines execution to let more than one thing at a time to proceed
in its simplicity, without using actual threads, but rather alternating the
execution of coroutines to let more than one thing at a time proceed
with its calculations. An attempt to fix the threaded model has been made
without making the API unnecessarily complicated.
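A toy illustration of that alternation, independent of giambio: each coroutine suspends at an await point and a bare round-robin loop resumes whichever one is next. This shows the mechanism only, not giambio's actual scheduler:

import types
from collections import deque

@types.coroutine
def pause():
    # A bare yield suspends the coroutine and hands control
    # back to whoever called send()
    yield

async def ticker(name: str, steps: int):
    for i in range(steps):
        print(name, i)
        await pause()

run_queue = deque([ticker("a", 2), ticker("b", 2)])
while run_queue:
    coro = run_queue.popleft()
    try:
        coro.send(None)          # run until the next await
        run_queue.append(coro)   # still alive: reschedule it
    except StopIteration:
        pass                     # the coroutine returned: drop it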
@ -55,7 +56,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: Callable
:type clock: :class: types.FunctionType
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger
@ -72,7 +73,7 @@ class AsyncScheduler:
def __init__(
self,
clock: Callable = default_timer,
clock: Callable[[], Number] = default_timer,
debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None,
@ -94,7 +95,7 @@ class AsyncScheduler:
or type(
"DumbDebugger",
(object,),
{"__getattr__": lambda *args: lambda *arg: None},
{"__getattr__": lambda *_: lambda *_: None},
)()
)
# All tasks the loop has
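The fallback debugger above is a null object built with type(): any attribute lookup returns a function that accepts anything and does nothing, so the debugger hooks can be invoked unconditionally. The same trick in isolation:

# __getattr__ fires for every missing attribute and hands back a
# do-nothing callable, so any hook call becomes a silent no-op
dumb = type(
    "DumbDebugger",
    (object,),
    {"__getattr__": lambda *_: lambda *_: None},
)()
dumb.before_task_step("some task")   # returns None, prints nothing
dumb.literally_anything(1, 2, 3)     # also a no-op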
@ -106,7 +107,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably
self.clock: Callable = clock
self.clock: Callable[[], Number] = clock
# Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever run?
@ -129,6 +130,7 @@ class AsyncScheduler:
self.entry_point: Optional[Task] = None
# Suspended tasks
self.suspended: deque = deque()
def __repr__(self):
"""
@ -150,6 +152,8 @@ class AsyncScheduler:
"_data",
"io_skip_limit",
"io_max_timeout",
"suspended",
"entry_point"
}
data = ", ".join(
name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields))
@ -168,7 +172,7 @@ class AsyncScheduler:
Shuts down the event loop
"""
for task in self.tasks:
for task in self.get_all_tasks():
self.io_release_task(task)
self.selector.close()
# TODO: Anything else?
@ -206,7 +210,10 @@ class AsyncScheduler:
# after it is set, but it makes the implementation easier
if not self.current_pool and self.current_task.pool:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = pool.enclosed_pool
# If there are no actively running tasks, we start by
# checking for I/O. This method will wait for I/O until
# the closest deadline to avoid starving sleeping tasks
@ -230,9 +237,10 @@ class AsyncScheduler:
# some tricky behaviors, and this is one of them. When a coroutine
# hits a return statement (either explicit or implicit), it raises
# a StopIteration exception, which has an attribute named value that
# represents the return value of the coroutine, if any. Of course this
# exception is not an error and we should happily keep going after it,
# represents the return value of the coroutine, if it has one. Of course
# this exception is not an error and we should happily keep going after it:
# most of this code below is just useful for internal/debugging purposes
self.current_task.status = "end"
self.current_task.result = ret.value
self.current_task.finished = True
self.join(self.current_task)
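The comment above is easy to verify in isolation; a coroutine's return value really does travel on StopIteration.value, and send() is exactly how the loop steps tasks:

async def answer():
    return 42

coro = answer()
try:
    coro.send(None)       # step the coroutine: it returns immediately
except StopIteration as ret:
    print(ret.value)      # 42, the coroutine's return value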
@ -244,20 +252,22 @@ class AsyncScheduler:
self.current_task.exc = err
self.join(self.current_task)
def create_task(self, coro: Coroutine[Any, Any, Any], pool) -> Task:
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also
needed, while any extra keyword or positional arguments are
passed to the function itself
:param coro: The coroutine to spawn
:type coro: Coroutine[Any, Any, Any]
:param corofunc: The coroutine function (NOT a coroutine!) to
spawn
:type corofunc: function
:param pool: The giambio.context.TaskManager object that
spawned the task
"""
task = Task(coro.__name__ or str(coro), coro, pool)
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
task.next_deadline = pool.timeout or 0.0
task.joiners = {self.current_task}
self._data[self.current_task] = task
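The switch from a coroutine object to a coroutine function is easy to motivate: the function reference carries a clean __name__ for the Task, and the loop only instantiates the coroutine when it actually schedules it. A standalone illustration:

async def worker(x):
    return x * 2

coro = worker(10)
print(worker.__name__)   # 'worker': a stable, readable task name
print(coro)              # '<coroutine object worker at 0x...>': noisier
# Close the coroutine so Python doesn't warn about it never being
# awaited, the very footgun that passing corofunc + args avoids
coro.close()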
@ -288,9 +298,15 @@ class AsyncScheduler:
# We need to make sure we don't try to execute
# exited tasks that are on the running queue
return
if not self.current_pool and self.current_task.pool:
if self.current_pool:
if self.current_task.pool and self.current_task.pool is not self.current_pool:
self.current_task.pool.enclosed_pool = self.current_pool
else:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = pool.enclosed_pool
self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here
self.current_task.status = "run"
@ -319,7 +335,7 @@ class AsyncScheduler:
def io_release(self, sock):
"""
Releases the given resource from our
selector.
selector
:param sock: The resource to be released
"""
@ -334,7 +350,7 @@ class AsyncScheduler:
if self.selector.get_map():
for k in filter(
lambda o: o.data == self.current_task,
lambda o: o.data == task,
dict(self.selector.get_map()).values(),
):
self.io_release(k.fileobj)
@ -344,11 +360,16 @@ class AsyncScheduler:
"""
Suspends execution of the current task. This is basically
a do-nothing method, since it will not reschedule the task
before returning. The task will stay suspended until a timer,
I/O operation or cancellation wakes it up, or until another
running task reschedules it.
before returning. The task will stay suspended until
something else outside the loop calls a trap to reschedule it.
Any pending I/O for the task is temporarily unscheduled so
that a previous network operation can't reschedule the task
before it's due
"""
if self.current_task.last_io or self.current_task.status == "io":
self.io_release_task(self.current_task)
self.current_task.status = "sleep"
self.suspended.append(self.current_task)
def reschedule_running(self):
@ -408,27 +429,32 @@ class AsyncScheduler:
try:
to_call()
except StopIteration as ret:
task.status = "end"
task.result = ret.value
task.finished = True
self.join(task)
self.tasks.remove(task)
except BaseException as err:
task.exc = err
self.join(task)
if task in self.tasks:
self.tasks.remove(task)
def prune_deadlines(self):
"""
Removes deadlines from the queue once their timeout
has expired and cancels their associated pool
has expired
"""
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
pool = self.deadlines.get()
pool.timed_out = True
self.cancel_pool(pool)
for task in pool.tasks:
if task is not pool.owner:
self.handle_task_exit(task, partial(task.throw, TooSlowError(self.current_task)))
if pool.raise_on_timeout:
self.handle_task_exit(pool.owner, partial(pool.owner.throw, TooSlowError(self.current_task)))
self.join(task)
if pool.entry_point is self.entry_point:
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
self.run_ready.append(self.entry_point)
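The pruning pattern is the standard one for a deadline heap, which is presumably what DeadlinesQueue wraps; a self-contained sketch with heapq:

import heapq
from timeit import default_timer as clock

deadlines = []   # (deadline, payload) pairs, ordered by deadline
heapq.heappush(deadlines, (clock(), "expired pool"))
heapq.heappush(deadlines, (clock() + 60.0, "healthy pool"))

# Pop every deadline that is already in the past
while deadlines and deadlines[0][0] <= clock():
    _, payload = heapq.heappop(deadlines)
    print("timed out:", payload)   # only "expired pool" is printed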
def schedule_tasks(self, tasks: List[Task]):
"""
@ -439,7 +465,8 @@ class AsyncScheduler:
for task in tasks:
self.paused.discard(task)
self.suspended.remove(task)
if task in self.suspended:
self.suspended.remove(task)
self.run_ready.extend(tasks)
self.reschedule_running()
@ -462,6 +489,7 @@ class AsyncScheduler:
self.run_ready.append(task)
self.debugger.after_sleep(task, slept)
def get_closest_deadline(self) -> float:
"""
Gets the closest expiration deadline (asleep tasks, timeouts)
@ -469,7 +497,7 @@ class AsyncScheduler:
:return: The closest deadline according to our clock
:rtype: float
"""
if not self.deadlines:
# If there are no deadlines just wait until the first task wakeup
timeout = max(0.0, self.paused.get_closest_deadline() - self.clock())
@ -535,9 +563,12 @@ class AsyncScheduler:
self.run_ready.append(entry)
self.debugger.on_start()
if loop:
self.run()
self.has_ran = True
self.debugger.on_exit()
try:
self.run()
finally:
self.has_ran = True
self.close()
self.debugger.on_exit()
def cancel_pool(self, pool: TaskManager) -> bool:
"""
@ -589,8 +620,9 @@ class AsyncScheduler:
If ensure_done equals False, the loop will cancel ALL
running and scheduled tasks and then tear itself down.
If ensure_done equals True, which is the default behavior,
this method will raise a GiambioError if the loop hasn't
finished running.
this method will raise a GiambioError exception if the loop
hasn't finished running. The state of the event loop is reset
so it can be reused with another run() call
"""
if ensure_done:
@ -598,6 +630,16 @@ class AsyncScheduler:
elif not self.done():
raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit")
self.shutdown()
# We reset the event loop's state
self.tasks = []
self.entry_point = None
self.current_pool = None
self.current_task = None
self.paused = TimeQueue(self.clock)
self.deadlines = DeadlinesQueue()
self.run_ready = deque()
self.suspended = deque()
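One consequence of this reset is that a single scheduler object can drive several consecutive runs. A sketch, assuming the entry-point method seen around the run() call above is named start() (its name is not visible in this diff):

async def main():
    ...

sched = AsyncScheduler()
sched.start(main)   # hypothetical entry point wrapping run()
# close() resets tasks, queues and deadlines, so a second,
# independent run on the same object is now possible
sched.start(main)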
def reschedule_joiners(self, task: Task):
"""
@ -605,87 +647,69 @@ class AsyncScheduler:
given task, if any
"""
for t in task.joiners:
self.run_ready.append(t)
# noinspection PyMethodMayBeStatic
def is_pool_done(self, pool: Optional[TaskManager]):
"""
Returns True if a given pool has finished
executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
self.run_ready.extend(task.joiners)
def join(self, task: Task):
"""
Joins a task to its callers (implicitly, the parent
Joins a task to its callers (implicitly the parent
task, but also every other task that called await
task.join() on the task object)
"""
task.joined = True
if any([task.finished, task.cancelled, task.exc]) and task in self.tasks:
self.io_release_task(task)
self.tasks.remove(task)
self.paused.discard(task)
if task.finished or task.cancelled:
task.status = "end"
if task.cancelled:
task.status = "cancelled"
# This way join() returns the
# task's return value
for joiner in task.joiners:
self._data[joiner] = task.result
self.debugger.on_task_exit(task)
# If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s)
if self.is_pool_done(task.pool):
if task.last_io:
self.io_release_task(task)
if task in self.suspended:
self.suspended.remove(task)
# If the pool (including any enclosing pools) has finished executing
# or we're at the first task that kicked the loop, we can safely
# reschedule the parent(s)
if task.pool is None:
return
if task.pool.done():
self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc:
if task in self.suspended:
self.suspended.remove(task)
task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first few
# seems a sensible approach (it's us catching it so we don't care about that)
for _ in range(5):
if task.exc.__traceback__.tb_next:
task.exc.__traceback__ = task.exc.__traceback__.tb_next
self.debugger.on_exception_raised(task, task.exc)
if task is self.entry_point and not task.pool:
try:
task.throw(task.exc)
except StopIteration:
... # TODO: ?
except BaseException:
# TODO: No idea what to do here
raise
elif any(map(lambda tk: tk is task.pool.owner, task.joiners)) or task is task.pool.owner:
# We check if the pool's
# owner catches our error
# or not. If they don't, we
# cancel the entire pool, but
# if they do, we do nothing
if task.pool.owner is not task:
self.handle_task_exit(task.pool.owner, partial(task.pool.owner.coroutine.throw, task.exc))
if any([task.pool.owner.exc, task.pool.owner.cancelled, task.pool.owner.finished]):
for t in task.joiners.copy():
# Propagate the exception
self.handle_task_exit(t, partial(t.throw, task.exc))
if any([t.exc, t.finished, t.cancelled]):
task.joiners.remove(t)
for t in task.pool.tasks:
if not t.joined:
self.handle_task_exit(t, partial(t.throw, task.exc))
if any([t.exc, t.finished, t.cancelled]):
task.joiners.discard(t)
self.reschedule_joiners(task)
self.reschedule_running()
if task.pool is None or task is self.entry_point:
# Parent task has no pool, so we propagate
raise task.exc
if self.cancel_pool(task.pool):
# This will reschedule the parent(s)
# only if all the tasks inside the task's
# pool have finished executing, either
# by cancellation, an exception
# or just returned
for t in task.joiners.copy():
# Propagate the exception
try:
t.throw(task.exc)
except (StopIteration, CancelledError, RuntimeError) as e:
# TODO: Need anything else?
task.joiners.remove(t)
if isinstance(e, StopIteration):
t.status = "end"
t.result = e.value
t.finished = True
elif isinstance(e, CancelledError):
t = e.task
t.cancel_pending = False
t.cancelled = True
t.status = "cancelled"
self.debugger.after_cancel(t)
elif isinstance(e, BaseException):
t.exc = e
finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task)
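Seen from user code, the effect of this branch is classic structured concurrency: one child's exception cancels its siblings and resurfaces at the pool's entry point. A sketch, again assuming the create_pool() helper:

import giambio

async def boom():
    raise ValueError("oops")

async def patient():
    await giambio.sleep(99)   # cancelled once boom() fails

async def main():
    try:
        async with giambio.create_pool() as pool:
            await pool.spawn(boom)
            await pool.spawn(patient)
    except ValueError as err:
        # join() threw the child's error into every joiner,
        # so it lands here at the entry point
        print("caught:", err)

giambio.run(main)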
def sleep(self, seconds: float):
"""
@ -727,6 +751,8 @@ class AsyncScheduler:
self.io_release_task(task)
elif task.status == "sleep":
self.paused.discard(task)
if task in self.suspended:
self.suspended.remove(task)
try:
self.do_cancel(task)
except CancelledError as cancel:
@ -742,24 +768,24 @@ class AsyncScheduler:
task = cancel.task
task.cancel_pending = False
task.cancelled = True
self.io_release_task(task)
task.status = "cancelled"
self.debugger.after_cancel(task)
self.tasks.remove(task)
self.join(task)
else:
# If the task ignores our exception, we'll
# raise it later again
task.cancel_pending = True
self.join(task)
def register_sock(self, sock, evt_type: str):
"""
Registers the given socket inside the
selector to perform I/0 multiplexing
selector to perform I/O multiplexing
:param sock: The socket on which a read or write operation
has to be performed
has to be performed
:param evt_type: The type of event to perform on the given
socket, either "read" or "write"
socket, either "read" or "write"
:type evt_type: str
"""
@ -793,5 +819,8 @@ class AsyncScheduler:
try:
self.selector.register(sock, evt, self.current_task)
except KeyError:
# The socket is already registered doing something else
raise ResourceBusy("The given socket is being read/written by another task") from None
# The socket is already registered doing something else, we
# modify the socket instead (or maybe not?)
self.selector.modify(sock, evt, self.current_task)
# TODO: Does this break stuff?
# raise ResourceBusy("The given socket is being read/written by another task") from None
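The behavior the new except branch relies on: selectors raises KeyError when a file object is already registered, and modify() swaps the event mask and data in place instead of failing:

import selectors
import socket

sel = selectors.DefaultSelector()
a, b = socket.socketpair()
sel.register(a, selectors.EVENT_READ, data="task-1")
try:
    sel.register(a, selectors.EVENT_WRITE, data="task-2")
except KeyError:
    # Already registered: replace events and owner in place,
    # which is what the code above now does
    sel.modify(a, selectors.EVENT_WRITE, data="task-2")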