Minor bug fixes, need to fix I/O

This commit is contained in:
nocturn9x 2020-11-29 12:06:09 +01:00
parent 899e12ead7
commit 2661a153e9
7 changed files with 138 additions and 47 deletions

View File

@ -36,7 +36,7 @@ from .exceptions import (InternalError,
class AsyncScheduler: class AsyncScheduler:
""" """
An asynchronous scheduler implementation. Tries to mimic the threaded A simple asynchronous scheduler implementation. Tries to mimic the threaded
model in its simplicity, without using actual threads, but rather alternating model in its simplicity, without using actual threads, but rather alternating
across coroutines execution to let more than one thing at a time to proceed across coroutines execution to let more than one thing at a time to proceed
with its calculations. An attempt to fix the threaded model has been made with its calculations. An attempt to fix the threaded model has been made
@ -44,7 +44,7 @@ class AsyncScheduler:
A few examples are tasks cancellation and exception propagation. A few examples are tasks cancellation and exception propagation.
""" """
def __init__(self, debugger: BaseDebugger = None): def __init__(self, clock: types.FunctionType = default_timer, debugger: BaseDebugger = None):
""" """
Object constructor Object constructor
""" """
@ -54,7 +54,7 @@ class AsyncScheduler:
if debugger: if debugger:
assert issubclass(type(debugger), assert issubclass(type(debugger),
BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger" BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger"
self.debugger = debugger or type("DumbDebugger", (object,), {"__getattr__": lambda *args: lambda *arg: None})() self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *arg: None})()
# Tasks that are ready to run # Tasks that are ready to run
self.tasks = [] self.tasks = []
# Selector object to perform I/O multiplexing # Selector object to perform I/O multiplexing
@ -62,7 +62,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object) # This will always point to the currently running coroutine (Task object)
self.current_task = None self.current_task = None
# Monotonic clock to keep track of elapsed time reliably # Monotonic clock to keep track of elapsed time reliably
self.clock = default_timer self.clock = clock
# Tasks that are asleep # Tasks that are asleep
self.paused = TimeQueue(self.clock) self.paused = TimeQueue(self.clock)
# All active Event objects # All active Event objects
@ -88,7 +88,6 @@ class AsyncScheduler:
Shuts down the event loop Shuts down the event loop
""" """
# TODO: See if other teardown is required (massive join()?)
self.selector.close() self.selector.close()
def run(self): def run(self):
@ -102,7 +101,7 @@ class AsyncScheduler:
while True: while True:
try: try:
if self.done(): if self.done():
# If we're done, which means there is no # If we're done, which means there are no
# sleeping tasks, no events to deliver, # sleeping tasks, no events to deliver,
# no I/O to do and no running tasks, we # no I/O to do and no running tasks, we
# simply tear us down and return to self.start # simply tear us down and return to self.start
@ -123,6 +122,7 @@ class AsyncScheduler:
while self.tasks: while self.tasks:
# Sets the currently running task # Sets the currently running task
self.current_task = self.tasks.pop(0) self.current_task = self.tasks.pop(0)
# Sets the current pool (for nested pools)
self.current_pool = self.current_task.pool self.current_pool = self.current_task.pool
self.debugger.before_task_step(self.current_task) self.debugger.before_task_step(self.current_task)
if self.current_task.cancel_pending: if self.current_task.cancel_pending:
@ -150,21 +150,15 @@ class AsyncScheduler:
except AttributeError: # If this happens, that's quite bad! except AttributeError: # If this happens, that's quite bad!
raise InternalError("Uh oh! Something very bad just happened, did" raise InternalError("Uh oh! Something very bad just happened, did"
" you try to mix primitives from other async libraries?") from None " you try to mix primitives from other async libraries?") from None
except CancelledError:
self.current_task.status = "cancelled"
self.current_task.cancelled = True
self.current_task.cancel_pending = False
self.debugger.after_cancel(self.current_task)
self.join(self.current_task)
except StopIteration as ret: except StopIteration as ret:
# Coroutine ends # Task finished executing
self.current_task.status = "end" self.current_task.status = "end"
self.current_task.result = ret.value self.current_task.result = ret.value
self.current_task.finished = True self.current_task.finished = True
self.debugger.on_task_exit(self.current_task) self.debugger.on_task_exit(self.current_task)
self.join(self.current_task) self.join(self.current_task)
except BaseException as err: except BaseException as err:
# Coroutine raised # Task raised an exception
self.current_task.exc = err self.current_task.exc = err
self.current_task.status = "crashed" self.current_task.status = "crashed"
self.debugger.on_exception_raised(self.current_task, err) self.debugger.on_exception_raised(self.current_task, err)
@ -269,7 +263,14 @@ class AsyncScheduler:
# occur # occur
self.tasks.append(t) self.tasks.append(t)
def cancel_all(self): def get_event_tasks(self):
"""
Returns all tasks currently waiting on events
"""
return set(waiter for waiter in (evt.waiters for evt in self.events))
def cancel_all_from_current_pool(self):
""" """
Cancels all tasks in the current pool, Cancels all tasks in the current pool,
preparing for the exception throwing preparing for the exception throwing
@ -277,30 +278,47 @@ class AsyncScheduler:
""" """
to_reschedule = [] to_reschedule = []
for to_cancel in chain(self.tasks, self.paused): for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
try: if to_cancel.pool is self.current_pool:
if to_cancel.pool is self.current_pool: try:
self.cancel(to_cancel) self.cancel(to_cancel)
elif to_cancel.status == "sleep": except CancelledError:
deadline = to_cancel.next_deadline - self.clock() # Task was cancelled
to_reschedule.append((to_cancel, deadline)) self.current_task.status = "cancelled"
else: self.current_task.cancelled = True
to_reschedule.append((to_cancel, None)) self.current_task.cancel_pending = False
except CancelledError: self.debugger.after_cancel(self.current_task)
to_cancel.status = "cancelled" elif to_cancel.status == "sleep":
to_cancel.cancelled = True deadline = to_cancel.next_deadline - self.clock()
to_cancel.cancel_pending = False to_reschedule.append((to_cancel, deadline))
self.debugger.after_cancel(to_cancel) else:
self.tasks.remove(to_cancel) to_reschedule.append((to_cancel, None))
for task, deadline in to_reschedule: for task, deadline in to_reschedule:
if deadline is not None: if deadline is not None:
self.paused.put(task, deadline) self.paused.put(task, deadline)
else:
self.tasks.append(task)
# If there is other work to do (nested pools) # If there is other work to do (nested pools)
# we tell so to our caller # we tell so to our caller
return bool(to_reschedule) return bool(to_reschedule)
def cancel_all(self):
"""
Cancels ALL tasks, this method is called as a result
of self.close()
"""
for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
self.cancel(to_cancel)
def close(self):
"""
Closes the event loop, terminating all tasks
inside it and tearing down any extra machinery
"""
# self.cancel_all()
# self.shutdown()
...
def join(self, task: Task): def join(self, task: Task):
""" """
Joins a task to its callers (implicitly, the parent Joins a task to its callers (implicitly, the parent
@ -312,7 +330,10 @@ class AsyncScheduler:
if task.finished or task.cancelled: if task.finished or task.cancelled:
self.reschedule_joinee(task) self.reschedule_joinee(task)
elif task.exc: elif task.exc:
if not self.cancel_all(): if not self.cancel_all_from_current_pool():
# This will reschedule the parent
# only if any enclosed pool has
# already exited, which is what we want
self.reschedule_joinee(task) self.reschedule_joinee(task)
def sleep(self, seconds: int or float): def sleep(self, seconds: int or float):
@ -361,8 +382,7 @@ class AsyncScheduler:
# Since we don't reschedule the task, it will # Since we don't reschedule the task, it will
# not execute until check_events is called # not execute until check_events is called
# TODO: More generic I/O rather than just sockets # TODO: More generic I/O rather than just sockets (threads)
# Best way to do so? Probably threads
def read_or_write(self, sock: socket.socket, evt_type: str): def read_or_write(self, sock: socket.socket, evt_type: str):
""" """
Registers the given socket inside the Registers the given socket inside the
@ -411,7 +431,7 @@ class AsyncScheduler:
async def sock_sendall(self, sock: socket.socket, data: bytes): async def sock_sendall(self, sock: socket.socket, data: bytes):
""" """
Sends all the passed data, as bytes, trough the socket asynchronously Sends all the passed bytes trough a socket asynchronously
""" """
while data: while data:
@ -419,9 +439,10 @@ class AsyncScheduler:
sent_no = sock.send(data) sent_no = sock.send(data)
data = data[sent_no:] data = data[sent_no:]
# TODO: This method seems to cause issues
async def close_sock(self, sock: socket.socket): async def close_sock(self, sock: socket.socket):
""" """
Closes the socket asynchronously Closes a socket asynchronously
""" """
await want_write(sock) await want_write(sock)
@ -440,4 +461,4 @@ class AsyncScheduler:
await want_write(sock) await want_write(sock)
err = sock.getsockopt(SOL_SOCKET, SO_ERROR) err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
if err != 0: if err != 0:
raise OSError(err, f"Connect call failed: {addr}") raise OSError(err, f"Connect call failed: {addr}")

View File

@ -17,6 +17,10 @@ limitations under the License.
""" """
from typing import List
import traceback
class GiambioError(Exception): class GiambioError(Exception):
""" """
Base class for giambio exceptions Base class for giambio exceptions
@ -58,3 +62,35 @@ class ResourceClosed(GiambioError):
""" """
... ...
class ErrorStack(GiambioError):
"""
This exception wraps multiple exceptions
and shows each individual traceback of them when
printed. This is to ensure that no exception is
ever lost even if 2 or more tasks raise at the
same time
"""
def __init__(self, errors: List[BaseException]):
"""
Object constructor
"""
super().__init__()
self.errors = errors
def __str__(self):
"""
Returns str(self)
"""
tracebacks = ""
for i, err in enumerate(self.errors):
if i not in (1, len(self.errors)):
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n"
else:
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}"
return f"Multiple errors occurred:\n{tracebacks}"

View File

@ -46,7 +46,7 @@ class Task:
sleep_start: float = 0.0 sleep_start: float = 0.0
next_deadline: float = 0.0 next_deadline: float = 0.0
def run(self, what=None): def run(self, what: object = None):
""" """
Simple abstraction layer over coroutines' ``send`` method Simple abstraction layer over coroutines' ``send`` method
""" """

View File

@ -22,6 +22,7 @@ import threading
from .core import AsyncScheduler from .core import AsyncScheduler
from .exceptions import GiambioError from .exceptions import GiambioError
from .context import TaskManager from .context import TaskManager
from timeit import default_timer
from .socket import AsyncSocket from .socket import AsyncSocket
from .util.debug import BaseDebugger from .util.debug import BaseDebugger
from types import FunctionType from types import FunctionType
@ -42,7 +43,7 @@ def get_event_loop():
raise GiambioError("giambio is not running") from None raise GiambioError("giambio is not running") from None
def new_event_loop(debugger: BaseDebugger): def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
""" """
Associates a new event loop to the current thread Associates a new event loop to the current thread
and deactivates the old one. This should not be and deactivates the old one. This should not be
@ -52,14 +53,15 @@ def new_event_loop(debugger: BaseDebugger):
""" """
try: try:
loop = thread_local.loop loop = get_event_loop()
except AttributeError: except GiambioError:
thread_local.loop = AsyncScheduler(debugger) thread_local.loop = AsyncScheduler(clock, debugger)
else: else:
if not loop.done(): if not loop.done():
raise GiambioError("cannot change event loop while running") raise GiambioError("cannot change event loop while running")
else: else:
thread_local.loop = AsyncScheduler(debugger) loop.close()
thread_local.loop = AsyncScheduler(clock, debugger)
def run(func: FunctionType, *args, **kwargs): def run(func: FunctionType, *args, **kwargs):
@ -72,7 +74,7 @@ def run(func: FunctionType, *args, **kwargs):
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)") "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)")
elif not isinstance(func, FunctionType): elif not isinstance(func, FunctionType):
raise GiambioError("giambio.run() requires an async function as parameter!") raise GiambioError("giambio.run() requires an async function as parameter!")
new_event_loop(kwargs.get("debugger", None)) new_event_loop(kwargs.get("debugger", None), kwargs.get("clock", default_timer))
get_event_loop().start(func, *args) get_event_loop().start(func, *args)

34
tests/error_stack.py Normal file
View File

@ -0,0 +1,34 @@
import giambio
from debugger import Debugger
# TODO: How to create a race condition of 2 exceptions at the same time?
async def child():
print("[child] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2)
print("[child] Had a nice nap!")
async def child1():
print("[child 1] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2)
print("[child 1] Had a nice nap!")
raise Exception("bruh")
async def main():
start = giambio.clock()
try:
async with giambio.create_pool() as pool:
pool.spawn(child)
pool.spawn(child1)
print("[main] Children spawned, awaiting completion")
except Exception as error:
# Because exceptions just *work*!
print(f"[main] Exception from child caught! {repr(error)}")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
if __name__ == "__main__":
giambio.run(main, debugger=())

View File

@ -2,7 +2,6 @@ import giambio
from debugger import Debugger from debugger import Debugger
async def child(): async def child():
print("[child] Child spawned!! Sleeping for 2 seconds") print("[child] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2) await giambio.sleep(2)
@ -39,4 +38,4 @@ async def main():
if __name__ == "__main__": if __name__ == "__main__":
giambio.run(main, debugger=None) giambio.run(main, debugger=())

View File

@ -50,5 +50,4 @@ if __name__ == "__main__":
if isinstance(error, KeyboardInterrupt): if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting") logging.info("Ctrl+C detected, exiting")
else: else:
raise
logging.error(f"Exiting due to a {type(error).__name__}: {error}") logging.error(f"Exiting due to a {type(error).__name__}: {error}")