mirror of https://github.com/nocturn9x/giambio.git
Minor bug fixes, need to fix I/O
This commit is contained in:
parent
899e12ead7
commit
2661a153e9
|
@ -36,7 +36,7 @@ from .exceptions import (InternalError,
|
|||
|
||||
class AsyncScheduler:
|
||||
"""
|
||||
An asynchronous scheduler implementation. Tries to mimic the threaded
|
||||
A simple asynchronous scheduler implementation. Tries to mimic the threaded
|
||||
model in its simplicity, without using actual threads, but rather alternating
|
||||
across coroutines execution to let more than one thing at a time to proceed
|
||||
with its calculations. An attempt to fix the threaded model has been made
|
||||
|
@ -44,7 +44,7 @@ class AsyncScheduler:
|
|||
A few examples are tasks cancellation and exception propagation.
|
||||
"""
|
||||
|
||||
def __init__(self, debugger: BaseDebugger = None):
|
||||
def __init__(self, clock: types.FunctionType = default_timer, debugger: BaseDebugger = None):
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
@ -54,7 +54,7 @@ class AsyncScheduler:
|
|||
if debugger:
|
||||
assert issubclass(type(debugger),
|
||||
BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger"
|
||||
self.debugger = debugger or type("DumbDebugger", (object,), {"__getattr__": lambda *args: lambda *arg: None})()
|
||||
self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *arg: None})()
|
||||
# Tasks that are ready to run
|
||||
self.tasks = []
|
||||
# Selector object to perform I/O multiplexing
|
||||
|
@ -62,7 +62,7 @@ class AsyncScheduler:
|
|||
# This will always point to the currently running coroutine (Task object)
|
||||
self.current_task = None
|
||||
# Monotonic clock to keep track of elapsed time reliably
|
||||
self.clock = default_timer
|
||||
self.clock = clock
|
||||
# Tasks that are asleep
|
||||
self.paused = TimeQueue(self.clock)
|
||||
# All active Event objects
|
||||
|
@ -88,7 +88,6 @@ class AsyncScheduler:
|
|||
Shuts down the event loop
|
||||
"""
|
||||
|
||||
# TODO: See if other teardown is required (massive join()?)
|
||||
self.selector.close()
|
||||
|
||||
def run(self):
|
||||
|
@ -102,7 +101,7 @@ class AsyncScheduler:
|
|||
while True:
|
||||
try:
|
||||
if self.done():
|
||||
# If we're done, which means there is no
|
||||
# If we're done, which means there are no
|
||||
# sleeping tasks, no events to deliver,
|
||||
# no I/O to do and no running tasks, we
|
||||
# simply tear us down and return to self.start
|
||||
|
@ -123,6 +122,7 @@ class AsyncScheduler:
|
|||
while self.tasks:
|
||||
# Sets the currently running task
|
||||
self.current_task = self.tasks.pop(0)
|
||||
# Sets the current pool (for nested pools)
|
||||
self.current_pool = self.current_task.pool
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
if self.current_task.cancel_pending:
|
||||
|
@ -150,21 +150,15 @@ class AsyncScheduler:
|
|||
except AttributeError: # If this happens, that's quite bad!
|
||||
raise InternalError("Uh oh! Something very bad just happened, did"
|
||||
" you try to mix primitives from other async libraries?") from None
|
||||
except CancelledError:
|
||||
self.current_task.status = "cancelled"
|
||||
self.current_task.cancelled = True
|
||||
self.current_task.cancel_pending = False
|
||||
self.debugger.after_cancel(self.current_task)
|
||||
self.join(self.current_task)
|
||||
except StopIteration as ret:
|
||||
# Coroutine ends
|
||||
# Task finished executing
|
||||
self.current_task.status = "end"
|
||||
self.current_task.result = ret.value
|
||||
self.current_task.finished = True
|
||||
self.debugger.on_task_exit(self.current_task)
|
||||
self.join(self.current_task)
|
||||
except BaseException as err:
|
||||
# Coroutine raised
|
||||
# Task raised an exception
|
||||
self.current_task.exc = err
|
||||
self.current_task.status = "crashed"
|
||||
self.debugger.on_exception_raised(self.current_task, err)
|
||||
|
@ -269,7 +263,14 @@ class AsyncScheduler:
|
|||
# occur
|
||||
self.tasks.append(t)
|
||||
|
||||
def cancel_all(self):
|
||||
def get_event_tasks(self):
|
||||
"""
|
||||
Returns all tasks currently waiting on events
|
||||
"""
|
||||
|
||||
return set(waiter for waiter in (evt.waiters for evt in self.events))
|
||||
|
||||
def cancel_all_from_current_pool(self):
|
||||
"""
|
||||
Cancels all tasks in the current pool,
|
||||
preparing for the exception throwing
|
||||
|
@ -277,30 +278,47 @@ class AsyncScheduler:
|
|||
"""
|
||||
|
||||
to_reschedule = []
|
||||
for to_cancel in chain(self.tasks, self.paused):
|
||||
try:
|
||||
if to_cancel.pool is self.current_pool:
|
||||
for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
|
||||
if to_cancel.pool is self.current_pool:
|
||||
try:
|
||||
self.cancel(to_cancel)
|
||||
elif to_cancel.status == "sleep":
|
||||
deadline = to_cancel.next_deadline - self.clock()
|
||||
to_reschedule.append((to_cancel, deadline))
|
||||
else:
|
||||
to_reschedule.append((to_cancel, None))
|
||||
except CancelledError:
|
||||
to_cancel.status = "cancelled"
|
||||
to_cancel.cancelled = True
|
||||
to_cancel.cancel_pending = False
|
||||
self.debugger.after_cancel(to_cancel)
|
||||
self.tasks.remove(to_cancel)
|
||||
except CancelledError:
|
||||
# Task was cancelled
|
||||
self.current_task.status = "cancelled"
|
||||
self.current_task.cancelled = True
|
||||
self.current_task.cancel_pending = False
|
||||
self.debugger.after_cancel(self.current_task)
|
||||
elif to_cancel.status == "sleep":
|
||||
deadline = to_cancel.next_deadline - self.clock()
|
||||
to_reschedule.append((to_cancel, deadline))
|
||||
else:
|
||||
to_reschedule.append((to_cancel, None))
|
||||
for task, deadline in to_reschedule:
|
||||
if deadline is not None:
|
||||
self.paused.put(task, deadline)
|
||||
else:
|
||||
self.tasks.append(task)
|
||||
# If there is other work to do (nested pools)
|
||||
# we tell so to our caller
|
||||
return bool(to_reschedule)
|
||||
|
||||
def cancel_all(self):
|
||||
"""
|
||||
Cancels ALL tasks, this method is called as a result
|
||||
of self.close()
|
||||
"""
|
||||
|
||||
for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
|
||||
self.cancel(to_cancel)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Closes the event loop, terminating all tasks
|
||||
inside it and tearing down any extra machinery
|
||||
"""
|
||||
|
||||
# self.cancel_all()
|
||||
# self.shutdown()
|
||||
...
|
||||
|
||||
def join(self, task: Task):
|
||||
"""
|
||||
Joins a task to its callers (implicitly, the parent
|
||||
|
@ -312,7 +330,10 @@ class AsyncScheduler:
|
|||
if task.finished or task.cancelled:
|
||||
self.reschedule_joinee(task)
|
||||
elif task.exc:
|
||||
if not self.cancel_all():
|
||||
if not self.cancel_all_from_current_pool():
|
||||
# This will reschedule the parent
|
||||
# only if any enclosed pool has
|
||||
# already exited, which is what we want
|
||||
self.reschedule_joinee(task)
|
||||
|
||||
def sleep(self, seconds: int or float):
|
||||
|
@ -361,8 +382,7 @@ class AsyncScheduler:
|
|||
# Since we don't reschedule the task, it will
|
||||
# not execute until check_events is called
|
||||
|
||||
# TODO: More generic I/O rather than just sockets
|
||||
# Best way to do so? Probably threads
|
||||
# TODO: More generic I/O rather than just sockets (threads)
|
||||
def read_or_write(self, sock: socket.socket, evt_type: str):
|
||||
"""
|
||||
Registers the given socket inside the
|
||||
|
@ -411,7 +431,7 @@ class AsyncScheduler:
|
|||
|
||||
async def sock_sendall(self, sock: socket.socket, data: bytes):
|
||||
"""
|
||||
Sends all the passed data, as bytes, trough the socket asynchronously
|
||||
Sends all the passed bytes trough a socket asynchronously
|
||||
"""
|
||||
|
||||
while data:
|
||||
|
@ -419,9 +439,10 @@ class AsyncScheduler:
|
|||
sent_no = sock.send(data)
|
||||
data = data[sent_no:]
|
||||
|
||||
# TODO: This method seems to cause issues
|
||||
async def close_sock(self, sock: socket.socket):
|
||||
"""
|
||||
Closes the socket asynchronously
|
||||
Closes a socket asynchronously
|
||||
"""
|
||||
|
||||
await want_write(sock)
|
||||
|
@ -440,4 +461,4 @@ class AsyncScheduler:
|
|||
await want_write(sock)
|
||||
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
|
||||
if err != 0:
|
||||
raise OSError(err, f"Connect call failed: {addr}")
|
||||
raise OSError(err, f"Connect call failed: {addr}")
|
|
@ -17,6 +17,10 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
|
||||
from typing import List
|
||||
import traceback
|
||||
|
||||
|
||||
class GiambioError(Exception):
|
||||
"""
|
||||
Base class for giambio exceptions
|
||||
|
@ -58,3 +62,35 @@ class ResourceClosed(GiambioError):
|
|||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ErrorStack(GiambioError):
|
||||
"""
|
||||
This exception wraps multiple exceptions
|
||||
and shows each individual traceback of them when
|
||||
printed. This is to ensure that no exception is
|
||||
ever lost even if 2 or more tasks raise at the
|
||||
same time
|
||||
"""
|
||||
|
||||
def __init__(self, errors: List[BaseException]):
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.errors = errors
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Returns str(self)
|
||||
"""
|
||||
|
||||
tracebacks = ""
|
||||
for i, err in enumerate(self.errors):
|
||||
if i not in (1, len(self.errors)):
|
||||
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n"
|
||||
else:
|
||||
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}"
|
||||
return f"Multiple errors occurred:\n{tracebacks}"
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class Task:
|
|||
sleep_start: float = 0.0
|
||||
next_deadline: float = 0.0
|
||||
|
||||
def run(self, what=None):
|
||||
def run(self, what: object = None):
|
||||
"""
|
||||
Simple abstraction layer over coroutines' ``send`` method
|
||||
"""
|
||||
|
|
|
@ -22,6 +22,7 @@ import threading
|
|||
from .core import AsyncScheduler
|
||||
from .exceptions import GiambioError
|
||||
from .context import TaskManager
|
||||
from timeit import default_timer
|
||||
from .socket import AsyncSocket
|
||||
from .util.debug import BaseDebugger
|
||||
from types import FunctionType
|
||||
|
@ -42,7 +43,7 @@ def get_event_loop():
|
|||
raise GiambioError("giambio is not running") from None
|
||||
|
||||
|
||||
def new_event_loop(debugger: BaseDebugger):
|
||||
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
||||
"""
|
||||
Associates a new event loop to the current thread
|
||||
and deactivates the old one. This should not be
|
||||
|
@ -52,14 +53,15 @@ def new_event_loop(debugger: BaseDebugger):
|
|||
"""
|
||||
|
||||
try:
|
||||
loop = thread_local.loop
|
||||
except AttributeError:
|
||||
thread_local.loop = AsyncScheduler(debugger)
|
||||
loop = get_event_loop()
|
||||
except GiambioError:
|
||||
thread_local.loop = AsyncScheduler(clock, debugger)
|
||||
else:
|
||||
if not loop.done():
|
||||
raise GiambioError("cannot change event loop while running")
|
||||
else:
|
||||
thread_local.loop = AsyncScheduler(debugger)
|
||||
loop.close()
|
||||
thread_local.loop = AsyncScheduler(clock, debugger)
|
||||
|
||||
|
||||
def run(func: FunctionType, *args, **kwargs):
|
||||
|
@ -72,7 +74,7 @@ def run(func: FunctionType, *args, **kwargs):
|
|||
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)")
|
||||
elif not isinstance(func, FunctionType):
|
||||
raise GiambioError("giambio.run() requires an async function as parameter!")
|
||||
new_event_loop(kwargs.get("debugger", None))
|
||||
new_event_loop(kwargs.get("debugger", None), kwargs.get("clock", default_timer))
|
||||
get_event_loop().start(func, *args)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import giambio
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
# TODO: How to create a race condition of 2 exceptions at the same time?
|
||||
|
||||
async def child():
|
||||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
print("[child] Had a nice nap!")
|
||||
|
||||
|
||||
async def child1():
|
||||
print("[child 1] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
print("[child 1] Had a nice nap!")
|
||||
raise Exception("bruh")
|
||||
|
||||
|
||||
async def main():
|
||||
start = giambio.clock()
|
||||
try:
|
||||
async with giambio.create_pool() as pool:
|
||||
pool.spawn(child)
|
||||
pool.spawn(child1)
|
||||
print("[main] Children spawned, awaiting completion")
|
||||
except Exception as error:
|
||||
# Because exceptions just *work*!
|
||||
print(f"[main] Exception from child caught! {repr(error)}")
|
||||
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
giambio.run(main, debugger=())
|
|
@ -2,7 +2,6 @@ import giambio
|
|||
from debugger import Debugger
|
||||
|
||||
|
||||
|
||||
async def child():
|
||||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
|
@ -39,4 +38,4 @@ async def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
giambio.run(main, debugger=None)
|
||||
giambio.run(main, debugger=())
|
||||
|
|
|
@ -50,5 +50,4 @@ if __name__ == "__main__":
|
|||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
raise
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||
|
|
Loading…
Reference in New Issue