mirror of https://github.com/nocturn9x/giambio.git
Minor bug fixes, need to fix I/O
This commit is contained in:
parent
899e12ead7
commit
2661a153e9
|
@ -36,7 +36,7 @@ from .exceptions import (InternalError,
|
||||||
|
|
||||||
class AsyncScheduler:
|
class AsyncScheduler:
|
||||||
"""
|
"""
|
||||||
An asynchronous scheduler implementation. Tries to mimic the threaded
|
A simple asynchronous scheduler implementation. Tries to mimic the threaded
|
||||||
model in its simplicity, without using actual threads, but rather alternating
|
model in its simplicity, without using actual threads, but rather alternating
|
||||||
across coroutines execution to let more than one thing at a time to proceed
|
across coroutines execution to let more than one thing at a time to proceed
|
||||||
with its calculations. An attempt to fix the threaded model has been made
|
with its calculations. An attempt to fix the threaded model has been made
|
||||||
|
@ -44,7 +44,7 @@ class AsyncScheduler:
|
||||||
A few examples are tasks cancellation and exception propagation.
|
A few examples are tasks cancellation and exception propagation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, debugger: BaseDebugger = None):
|
def __init__(self, clock: types.FunctionType = default_timer, debugger: BaseDebugger = None):
|
||||||
"""
|
"""
|
||||||
Object constructor
|
Object constructor
|
||||||
"""
|
"""
|
||||||
|
@ -54,7 +54,7 @@ class AsyncScheduler:
|
||||||
if debugger:
|
if debugger:
|
||||||
assert issubclass(type(debugger),
|
assert issubclass(type(debugger),
|
||||||
BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger"
|
BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger"
|
||||||
self.debugger = debugger or type("DumbDebugger", (object,), {"__getattr__": lambda *args: lambda *arg: None})()
|
self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *arg: None})()
|
||||||
# Tasks that are ready to run
|
# Tasks that are ready to run
|
||||||
self.tasks = []
|
self.tasks = []
|
||||||
# Selector object to perform I/O multiplexing
|
# Selector object to perform I/O multiplexing
|
||||||
|
@ -62,7 +62,7 @@ class AsyncScheduler:
|
||||||
# This will always point to the currently running coroutine (Task object)
|
# This will always point to the currently running coroutine (Task object)
|
||||||
self.current_task = None
|
self.current_task = None
|
||||||
# Monotonic clock to keep track of elapsed time reliably
|
# Monotonic clock to keep track of elapsed time reliably
|
||||||
self.clock = default_timer
|
self.clock = clock
|
||||||
# Tasks that are asleep
|
# Tasks that are asleep
|
||||||
self.paused = TimeQueue(self.clock)
|
self.paused = TimeQueue(self.clock)
|
||||||
# All active Event objects
|
# All active Event objects
|
||||||
|
@ -88,7 +88,6 @@ class AsyncScheduler:
|
||||||
Shuts down the event loop
|
Shuts down the event loop
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: See if other teardown is required (massive join()?)
|
|
||||||
self.selector.close()
|
self.selector.close()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -102,7 +101,7 @@ class AsyncScheduler:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if self.done():
|
if self.done():
|
||||||
# If we're done, which means there is no
|
# If we're done, which means there are no
|
||||||
# sleeping tasks, no events to deliver,
|
# sleeping tasks, no events to deliver,
|
||||||
# no I/O to do and no running tasks, we
|
# no I/O to do and no running tasks, we
|
||||||
# simply tear us down and return to self.start
|
# simply tear us down and return to self.start
|
||||||
|
@ -123,6 +122,7 @@ class AsyncScheduler:
|
||||||
while self.tasks:
|
while self.tasks:
|
||||||
# Sets the currently running task
|
# Sets the currently running task
|
||||||
self.current_task = self.tasks.pop(0)
|
self.current_task = self.tasks.pop(0)
|
||||||
|
# Sets the current pool (for nested pools)
|
||||||
self.current_pool = self.current_task.pool
|
self.current_pool = self.current_task.pool
|
||||||
self.debugger.before_task_step(self.current_task)
|
self.debugger.before_task_step(self.current_task)
|
||||||
if self.current_task.cancel_pending:
|
if self.current_task.cancel_pending:
|
||||||
|
@ -150,21 +150,15 @@ class AsyncScheduler:
|
||||||
except AttributeError: # If this happens, that's quite bad!
|
except AttributeError: # If this happens, that's quite bad!
|
||||||
raise InternalError("Uh oh! Something very bad just happened, did"
|
raise InternalError("Uh oh! Something very bad just happened, did"
|
||||||
" you try to mix primitives from other async libraries?") from None
|
" you try to mix primitives from other async libraries?") from None
|
||||||
except CancelledError:
|
|
||||||
self.current_task.status = "cancelled"
|
|
||||||
self.current_task.cancelled = True
|
|
||||||
self.current_task.cancel_pending = False
|
|
||||||
self.debugger.after_cancel(self.current_task)
|
|
||||||
self.join(self.current_task)
|
|
||||||
except StopIteration as ret:
|
except StopIteration as ret:
|
||||||
# Coroutine ends
|
# Task finished executing
|
||||||
self.current_task.status = "end"
|
self.current_task.status = "end"
|
||||||
self.current_task.result = ret.value
|
self.current_task.result = ret.value
|
||||||
self.current_task.finished = True
|
self.current_task.finished = True
|
||||||
self.debugger.on_task_exit(self.current_task)
|
self.debugger.on_task_exit(self.current_task)
|
||||||
self.join(self.current_task)
|
self.join(self.current_task)
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
# Coroutine raised
|
# Task raised an exception
|
||||||
self.current_task.exc = err
|
self.current_task.exc = err
|
||||||
self.current_task.status = "crashed"
|
self.current_task.status = "crashed"
|
||||||
self.debugger.on_exception_raised(self.current_task, err)
|
self.debugger.on_exception_raised(self.current_task, err)
|
||||||
|
@ -269,7 +263,14 @@ class AsyncScheduler:
|
||||||
# occur
|
# occur
|
||||||
self.tasks.append(t)
|
self.tasks.append(t)
|
||||||
|
|
||||||
def cancel_all(self):
|
def get_event_tasks(self):
|
||||||
|
"""
|
||||||
|
Returns all tasks currently waiting on events
|
||||||
|
"""
|
||||||
|
|
||||||
|
return set(waiter for waiter in (evt.waiters for evt in self.events))
|
||||||
|
|
||||||
|
def cancel_all_from_current_pool(self):
|
||||||
"""
|
"""
|
||||||
Cancels all tasks in the current pool,
|
Cancels all tasks in the current pool,
|
||||||
preparing for the exception throwing
|
preparing for the exception throwing
|
||||||
|
@ -277,30 +278,47 @@ class AsyncScheduler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
to_reschedule = []
|
to_reschedule = []
|
||||||
for to_cancel in chain(self.tasks, self.paused):
|
for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
|
||||||
try:
|
if to_cancel.pool is self.current_pool:
|
||||||
if to_cancel.pool is self.current_pool:
|
try:
|
||||||
self.cancel(to_cancel)
|
self.cancel(to_cancel)
|
||||||
elif to_cancel.status == "sleep":
|
except CancelledError:
|
||||||
deadline = to_cancel.next_deadline - self.clock()
|
# Task was cancelled
|
||||||
to_reschedule.append((to_cancel, deadline))
|
self.current_task.status = "cancelled"
|
||||||
else:
|
self.current_task.cancelled = True
|
||||||
to_reschedule.append((to_cancel, None))
|
self.current_task.cancel_pending = False
|
||||||
except CancelledError:
|
self.debugger.after_cancel(self.current_task)
|
||||||
to_cancel.status = "cancelled"
|
elif to_cancel.status == "sleep":
|
||||||
to_cancel.cancelled = True
|
deadline = to_cancel.next_deadline - self.clock()
|
||||||
to_cancel.cancel_pending = False
|
to_reschedule.append((to_cancel, deadline))
|
||||||
self.debugger.after_cancel(to_cancel)
|
else:
|
||||||
self.tasks.remove(to_cancel)
|
to_reschedule.append((to_cancel, None))
|
||||||
for task, deadline in to_reschedule:
|
for task, deadline in to_reschedule:
|
||||||
if deadline is not None:
|
if deadline is not None:
|
||||||
self.paused.put(task, deadline)
|
self.paused.put(task, deadline)
|
||||||
else:
|
|
||||||
self.tasks.append(task)
|
|
||||||
# If there is other work to do (nested pools)
|
# If there is other work to do (nested pools)
|
||||||
# we tell so to our caller
|
# we tell so to our caller
|
||||||
return bool(to_reschedule)
|
return bool(to_reschedule)
|
||||||
|
|
||||||
|
def cancel_all(self):
|
||||||
|
"""
|
||||||
|
Cancels ALL tasks, this method is called as a result
|
||||||
|
of self.close()
|
||||||
|
"""
|
||||||
|
|
||||||
|
for to_cancel in chain(self.tasks, self.paused, self.get_event_tasks()):
|
||||||
|
self.cancel(to_cancel)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Closes the event loop, terminating all tasks
|
||||||
|
inside it and tearing down any extra machinery
|
||||||
|
"""
|
||||||
|
|
||||||
|
# self.cancel_all()
|
||||||
|
# self.shutdown()
|
||||||
|
...
|
||||||
|
|
||||||
def join(self, task: Task):
|
def join(self, task: Task):
|
||||||
"""
|
"""
|
||||||
Joins a task to its callers (implicitly, the parent
|
Joins a task to its callers (implicitly, the parent
|
||||||
|
@ -312,7 +330,10 @@ class AsyncScheduler:
|
||||||
if task.finished or task.cancelled:
|
if task.finished or task.cancelled:
|
||||||
self.reschedule_joinee(task)
|
self.reschedule_joinee(task)
|
||||||
elif task.exc:
|
elif task.exc:
|
||||||
if not self.cancel_all():
|
if not self.cancel_all_from_current_pool():
|
||||||
|
# This will reschedule the parent
|
||||||
|
# only if any enclosed pool has
|
||||||
|
# already exited, which is what we want
|
||||||
self.reschedule_joinee(task)
|
self.reschedule_joinee(task)
|
||||||
|
|
||||||
def sleep(self, seconds: int or float):
|
def sleep(self, seconds: int or float):
|
||||||
|
@ -361,8 +382,7 @@ class AsyncScheduler:
|
||||||
# Since we don't reschedule the task, it will
|
# Since we don't reschedule the task, it will
|
||||||
# not execute until check_events is called
|
# not execute until check_events is called
|
||||||
|
|
||||||
# TODO: More generic I/O rather than just sockets
|
# TODO: More generic I/O rather than just sockets (threads)
|
||||||
# Best way to do so? Probably threads
|
|
||||||
def read_or_write(self, sock: socket.socket, evt_type: str):
|
def read_or_write(self, sock: socket.socket, evt_type: str):
|
||||||
"""
|
"""
|
||||||
Registers the given socket inside the
|
Registers the given socket inside the
|
||||||
|
@ -411,7 +431,7 @@ class AsyncScheduler:
|
||||||
|
|
||||||
async def sock_sendall(self, sock: socket.socket, data: bytes):
|
async def sock_sendall(self, sock: socket.socket, data: bytes):
|
||||||
"""
|
"""
|
||||||
Sends all the passed data, as bytes, trough the socket asynchronously
|
Sends all the passed bytes trough a socket asynchronously
|
||||||
"""
|
"""
|
||||||
|
|
||||||
while data:
|
while data:
|
||||||
|
@ -419,9 +439,10 @@ class AsyncScheduler:
|
||||||
sent_no = sock.send(data)
|
sent_no = sock.send(data)
|
||||||
data = data[sent_no:]
|
data = data[sent_no:]
|
||||||
|
|
||||||
|
# TODO: This method seems to cause issues
|
||||||
async def close_sock(self, sock: socket.socket):
|
async def close_sock(self, sock: socket.socket):
|
||||||
"""
|
"""
|
||||||
Closes the socket asynchronously
|
Closes a socket asynchronously
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await want_write(sock)
|
await want_write(sock)
|
||||||
|
@ -440,4 +461,4 @@ class AsyncScheduler:
|
||||||
await want_write(sock)
|
await want_write(sock)
|
||||||
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
|
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise OSError(err, f"Connect call failed: {addr}")
|
raise OSError(err, f"Connect call failed: {addr}")
|
|
@ -17,6 +17,10 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
class GiambioError(Exception):
|
class GiambioError(Exception):
|
||||||
"""
|
"""
|
||||||
Base class for giambio exceptions
|
Base class for giambio exceptions
|
||||||
|
@ -58,3 +62,35 @@ class ResourceClosed(GiambioError):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorStack(GiambioError):
|
||||||
|
"""
|
||||||
|
This exception wraps multiple exceptions
|
||||||
|
and shows each individual traceback of them when
|
||||||
|
printed. This is to ensure that no exception is
|
||||||
|
ever lost even if 2 or more tasks raise at the
|
||||||
|
same time
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, errors: List[BaseException]):
|
||||||
|
"""
|
||||||
|
Object constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.errors = errors
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""
|
||||||
|
Returns str(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
tracebacks = ""
|
||||||
|
for i, err in enumerate(self.errors):
|
||||||
|
if i not in (1, len(self.errors)):
|
||||||
|
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}\n{'-' * 32}\n"
|
||||||
|
else:
|
||||||
|
tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}"
|
||||||
|
return f"Multiple errors occurred:\n{tracebacks}"
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Task:
|
||||||
sleep_start: float = 0.0
|
sleep_start: float = 0.0
|
||||||
next_deadline: float = 0.0
|
next_deadline: float = 0.0
|
||||||
|
|
||||||
def run(self, what=None):
|
def run(self, what: object = None):
|
||||||
"""
|
"""
|
||||||
Simple abstraction layer over coroutines' ``send`` method
|
Simple abstraction layer over coroutines' ``send`` method
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -22,6 +22,7 @@ import threading
|
||||||
from .core import AsyncScheduler
|
from .core import AsyncScheduler
|
||||||
from .exceptions import GiambioError
|
from .exceptions import GiambioError
|
||||||
from .context import TaskManager
|
from .context import TaskManager
|
||||||
|
from timeit import default_timer
|
||||||
from .socket import AsyncSocket
|
from .socket import AsyncSocket
|
||||||
from .util.debug import BaseDebugger
|
from .util.debug import BaseDebugger
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
|
@ -42,7 +43,7 @@ def get_event_loop():
|
||||||
raise GiambioError("giambio is not running") from None
|
raise GiambioError("giambio is not running") from None
|
||||||
|
|
||||||
|
|
||||||
def new_event_loop(debugger: BaseDebugger):
|
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
||||||
"""
|
"""
|
||||||
Associates a new event loop to the current thread
|
Associates a new event loop to the current thread
|
||||||
and deactivates the old one. This should not be
|
and deactivates the old one. This should not be
|
||||||
|
@ -52,14 +53,15 @@ def new_event_loop(debugger: BaseDebugger):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = thread_local.loop
|
loop = get_event_loop()
|
||||||
except AttributeError:
|
except GiambioError:
|
||||||
thread_local.loop = AsyncScheduler(debugger)
|
thread_local.loop = AsyncScheduler(clock, debugger)
|
||||||
else:
|
else:
|
||||||
if not loop.done():
|
if not loop.done():
|
||||||
raise GiambioError("cannot change event loop while running")
|
raise GiambioError("cannot change event loop while running")
|
||||||
else:
|
else:
|
||||||
thread_local.loop = AsyncScheduler(debugger)
|
loop.close()
|
||||||
|
thread_local.loop = AsyncScheduler(clock, debugger)
|
||||||
|
|
||||||
|
|
||||||
def run(func: FunctionType, *args, **kwargs):
|
def run(func: FunctionType, *args, **kwargs):
|
||||||
|
@ -72,7 +74,7 @@ def run(func: FunctionType, *args, **kwargs):
|
||||||
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)")
|
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)")
|
||||||
elif not isinstance(func, FunctionType):
|
elif not isinstance(func, FunctionType):
|
||||||
raise GiambioError("giambio.run() requires an async function as parameter!")
|
raise GiambioError("giambio.run() requires an async function as parameter!")
|
||||||
new_event_loop(kwargs.get("debugger", None))
|
new_event_loop(kwargs.get("debugger", None), kwargs.get("clock", default_timer))
|
||||||
get_event_loop().start(func, *args)
|
get_event_loop().start(func, *args)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
import giambio
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: How to create a race condition of 2 exceptions at the same time?
|
||||||
|
|
||||||
|
async def child():
|
||||||
|
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||||
|
await giambio.sleep(2)
|
||||||
|
print("[child] Had a nice nap!")
|
||||||
|
|
||||||
|
|
||||||
|
async def child1():
|
||||||
|
print("[child 1] Child spawned!! Sleeping for 2 seconds")
|
||||||
|
await giambio.sleep(2)
|
||||||
|
print("[child 1] Had a nice nap!")
|
||||||
|
raise Exception("bruh")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
start = giambio.clock()
|
||||||
|
try:
|
||||||
|
async with giambio.create_pool() as pool:
|
||||||
|
pool.spawn(child)
|
||||||
|
pool.spawn(child1)
|
||||||
|
print("[main] Children spawned, awaiting completion")
|
||||||
|
except Exception as error:
|
||||||
|
# Because exceptions just *work*!
|
||||||
|
print(f"[main] Exception from child caught! {repr(error)}")
|
||||||
|
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
giambio.run(main, debugger=())
|
|
@ -2,7 +2,6 @@ import giambio
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def child():
|
async def child():
|
||||||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||||
await giambio.sleep(2)
|
await giambio.sleep(2)
|
||||||
|
@ -39,4 +38,4 @@ async def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
giambio.run(main, debugger=None)
|
giambio.run(main, debugger=())
|
||||||
|
|
|
@ -50,5 +50,4 @@ if __name__ == "__main__":
|
||||||
if isinstance(error, KeyboardInterrupt):
|
if isinstance(error, KeyboardInterrupt):
|
||||||
logging.info("Ctrl+C detected, exiting")
|
logging.info("Ctrl+C detected, exiting")
|
||||||
else:
|
else:
|
||||||
raise
|
|
||||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||||
|
|
Loading…
Reference in New Issue