diff --git a/README.md b/README.md index 26733c5..b126a04 100644 --- a/README.md +++ b/README.md @@ -252,8 +252,8 @@ async def child1(): async def main(): start = giambio.clock() async with giambio.create_pool() as pool: - pool.spawn(child) - pool.spawn(child1) + await pool.spawn(child) + await pool.spawn(child1) print("[main] Children spawned, awaiting completion") print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") @@ -599,7 +599,7 @@ async def serve(bind_address: tuple): while True: conn, address_tuple = await sock.accept() logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - pool.spawn(handler, conn, address_tuple) + await pool.spawn(handler, conn, address_tuple) ``` diff --git a/giambio/__init__.py b/giambio/__init__.py index c8f5a9d..354b04c 100644 --- a/giambio/__init__.py +++ b/giambio/__init__.py @@ -22,7 +22,7 @@ __version__ = (0, 0, 1) from . import exceptions, socket, context, core from .traps import sleep, current_task -from .objects import Event +from .sync import Event from .run import run, clock, create_pool, get_event_loop, new_event_loop, with_timeout from .util import debug diff --git a/giambio/context.py b/giambio/context.py index 46d8dc1..cd94e33 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -16,8 +16,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -import giambio import types +import giambio from typing import List @@ -50,40 +50,22 @@ class TaskManager: self.timeout: None = None # Whether our timeout expired or not self.timed_out: bool = False + self._proper_init = False - def spawn(self, func: types.FunctionType, *args) -> "giambio.objects.Task": + async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task": """ Spawns a child task """ - task = giambio.objects.Task(func.__name__ or str(func), func(*args), self) - task.joiners = [self.loop.current_task] - task.next_deadline = self.timeout or 0.0 - self.loop.tasks.append(task) - self.loop.debugger.on_task_spawn(task) - self.tasks.append(task) - return task - - def spawn_after(self, func: types.FunctionType, n: int, *args) -> "giambio.objects.Task": - """ - Schedules a task for execution after n seconds - """ - - assert n >= 0, "The time delay can't be negative" - task = giambio.objects.Task(func.__name__ or str(func), func(*args), self) - task.joiners = [self.loop.current_task] - task.next_deadline = self.timeout or 0.0 - task.sleep_start = self.loop.clock() - self.loop.paused.put(task, n) - self.loop.debugger.on_task_schedule(task, n) - self.tasks.append(task) - return task + assert self._proper_init + return await giambio.traps.create_task(func, *args) async def __aenter__(self): """ Implements the asynchronous context manager interface, """ + self._proper_init = True return self async def __aexit__(self, exc_type: Exception, exc: Exception, tb): @@ -97,6 +79,7 @@ class TaskManager: # end of the block and wait for all # children to exit await task.join() + self._proper_init = False async def cancel(self): """ diff --git a/giambio/core.py b/giambio/core.py index 1d4e07d..4a42807 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -20,20 +20,22 @@ limitations under the License. import types import socket from itertools import chain +from giambio.task import Task +from giambio.sync import Event from timeit import default_timer from giambio.context import TaskManager from typing import List, Optional, Set, Any from giambio.util.debug import BaseDebugger from giambio.traps import want_read, want_write -from giambio.objects import Task, TimeQueue, DeadlinesQueue, Event +from giambio.internal import TimeQueue, DeadlinesQueue from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE -from giambio.exceptions import (InternalError, - CancelledError, - ResourceBusy, - GiambioError, - TooSlowError - ) - +from giambio.exceptions import ( + InternalError, + CancelledError, + ResourceBusy, + GiambioError, + TooSlowError, +) class AsyncScheduler: @@ -44,22 +46,21 @@ class AsyncScheduler: with its calculations. An attempt to fix the threaded model has been made without making the API unnecessarily complicated. - This loop only provides the most basic support for task scheduling, I/O - multiplexing, event delivery, task cancellation and exception propagation: - any other feature should therefore be implemented in higher-level object - wrappers (see socket.py and event.py for example). An object wrapper should + This loop only takes care of task scheduling, I/O multiplexing and basic + suspension: any other feature should therefore be implemented in object + wrappers (see io.py and sync.py for example). An object wrapper should not depend on the loop's implementation details such as internal state or - directly access its methods: traps should be used instead; This is to + directly access its methods: traps should be used instead. This is to ensure that the wrapper will keep working even if the scheduler giambio is using changes, which means it is entirely possible, and reasonable, to write your own event loop and run giambio on top of it, provided the required traps are correctly implemented. :param clock: A callable returning monotonically increasing values at each call, - defaults to timeit.default_timer + usually using seconds as units, but this is not enforced, defaults to timeit.default_timer :type clock: :class: types.FunctionType :param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output - is desired, defaults to None + is desired, defaults to None :type debugger: :class: giambio.util.BaseDebugger :param selector: The selector to use for I/O multiplexing, defaults to selectors.DefaultSelector :param io_skip_limit: The max. amount of times I/O checks can be skipped when @@ -72,7 +73,14 @@ class AsyncScheduler: :type io_max_timeout: int, optional """ - def __init__(self, clock: types.FunctionType = default_timer, debugger: Optional[BaseDebugger] = None, selector: Optional[Any] = None, io_skip_limit: Optional[int] = None, io_max_timeout: Optional[int] = None): + def __init__( + self, + clock: types.FunctionType = default_timer, + debugger: Optional[BaseDebugger] = None, + selector: Optional[Any] = None, + io_skip_limit: Optional[int] = None, + io_max_timeout: Optional[int] = None, + ): """ Object constructor """ @@ -81,11 +89,21 @@ class AsyncScheduler: # lambda which in turn returns None every time we access any of its attributes to avoid lots of # if self.debugger clauses if debugger: - assert issubclass(type(debugger), - BaseDebugger), "The debugger must be a subclass of giambio.util.BaseDebugger" - self.debugger = debugger or type("DumbDebugger", (object, ), {"__getattr__": lambda *args: lambda *arg: None})() - # Tasks that are ready to run + assert issubclass( + type(debugger), BaseDebugger + ), "The debugger must be a subclass of giambio.util.BaseDebugger" + self.debugger = ( + debugger + or type( + "DumbDebugger", + (object,), + {"__getattr__": lambda *args: lambda *arg: None}, + )() + ) + # All tasks the loop has self.tasks: List[Task] = [] + # Tasks that are ready to run + self.run_ready: List[Task] = [] # Selector object to perform I/O multiplexing self.selector: DefaultSelector = DefaultSelector() # This will always point to the currently running coroutine (Task object) @@ -94,8 +112,6 @@ class AsyncScheduler: self.clock: types.FunctionType = clock # Tasks that are asleep self.paused: TimeQueue = TimeQueue(self.clock) - # All active Event objects - self.events: Set[Event] = set() # Have we ever ran? self.has_ran: bool = False # The current pool @@ -113,13 +129,12 @@ class AsyncScheduler: # The max. I/O timeout self.io_max_timeout = io_max_timeout - def done(self) -> bool: """ Returns True if there is no work to do """ - if any([self.paused, self.tasks, self.events, self.selector.get_map()]): + if any([self.paused, self.run_ready, self.selector.get_map()]): return False return True @@ -157,7 +172,7 @@ class AsyncScheduler: # simply tear us down and return to self.start self.close() break - elif not self.tasks: + elif not self.run_ready: # If there are no actively running tasks, we start by # checking for I/O. This method will wait for I/O until # the closest deadline to avoid starving sleeping tasks @@ -169,10 +184,11 @@ class AsyncScheduler: if self.paused: # Next we try to (re)schedule the asleep tasks self.awake_sleeping() - # Then we try to awake event-waiting tasks - if self.events: - self.check_events() - if self.current_pool and self.current_pool.timeout and not self.current_pool.timed_out: + if ( + self.current_pool + and self.current_pool.timeout + and not self.current_pool.timed_out + ): # Stores deadlines for tasks (deadlines are pool-specific). # The deadlines queue will internally make sure not to store # a deadline for the same pool twice. This makes the timeouts @@ -180,7 +196,7 @@ class AsyncScheduler: # after it is set, but it makes the implementation easier self.deadlines.put(self.current_pool) # Otherwise, while there are tasks ready to run, we run them! - while self.tasks: + while self.run_ready: self.run_task_step() except StopIteration as ret: # At the end of the day, coroutines are generator functions with @@ -193,20 +209,30 @@ class AsyncScheduler: self.current_task.status = "end" self.current_task.result = ret.value self.current_task.finished = True - self.debugger.on_task_exit(self.current_task) - self.io_release_task(self.current_task) self.join(self.current_task) except BaseException as err: - # TODO: We might want to do a bit more complex traceback hacking to remove any extra - # frames from the exception call stack, but for now removing at least the first one - # seems a sensible approach (it's us catching it so we don't care about that) + # Our handy join mechanism will handle all the hassle of + # rescheduling joiners and propagating errors, so we + # just need to set the task's exception object and let + # self.join() work its magic self.current_task.exc = err - self.current_task.exc.__traceback__ = self.current_task.exc.__traceback__.tb_next - self.current_task.status = "crashed" - self.debugger.on_exception_raised(self.current_task, err) - self.io_release_task(self.current_task) self.join(self.current_task) + def create_task(self, coro, *args) -> Task: + """ + Creates a task + """ + + task = Task(coro.__name__ or str(coro), coro(*args), self.current_pool) + task.next_deadline = self.current_pool.timeout or 0.0 + task.joiners = {self.current_task} + self.tasks.append(task) + self.run_ready.append(task) + self.debugger.on_task_spawn(task) + self.current_pool.tasks.append(task) + self.reschedule_running() + return task + def run_task_step(self): """ Runs a single step for the current task. @@ -220,24 +246,25 @@ class AsyncScheduler: """ # Sets the currently running task - self.current_task = self.tasks.pop(0) + data = None + self.current_task = self.run_ready.pop(0) + self.debugger.before_task_step(self.current_task) if self.current_task.done(): # We need to make sure we don't try to execute # exited tasks that are on the running queue return - self.debugger.before_task_step(self.current_task) if self.current_task.cancel_pending: # We perform the deferred cancellation # if it was previously scheduled self.cancel(self.current_task) # Little boilerplate to send data back to an async trap - data = None if self.current_task.status != "init": data = self._data # Run a single step with the calculation (i.e. until a yield # somewhere) method, *args = self.current_task.run(data) - self._data = None + if data is self._data: + self._data = None # Some debugging and internal chatter here self.current_task.status = "run" self.current_task.steps += 1 @@ -248,13 +275,13 @@ class AsyncScheduler: # libraries, which most likely have different trap names and behaviors # compared to us. If you get this exception and you're 100% sure you're # not mixing async primitives from other libraries, then it's a bug! - raise InternalError("Uh oh! Something very bad just happened, did" - " you try to mix primitives from other async libraries?") from None - + raise InternalError( + "Uh oh! Something very bad just happened, did" + " you try to mix primitives from other async libraries?" + ) from None # Sneaky method call, thanks to David Beazley for this ;) getattr(self, method)(*args) - def io_release_task(self, task: Task): """ Calls self.io_release in a loop @@ -262,8 +289,10 @@ class AsyncScheduler: """ if self.selector.get_map(): - for k in filter(lambda o: o.data == self.current_task, - dict(self.selector.get_map()).values()): + for k in filter( + lambda o: o.data == self.current_task, + dict(self.selector.get_map()).values(), + ): self.io_release(k.fileobj) task.last_io = () @@ -271,13 +300,27 @@ class AsyncScheduler: """ Releases the given resource from our selector. - :param sock: The resource to be released """ if self.selector.get_map() and sock in self.selector.get_map(): self.selector.unregister(sock) + def suspend(self): + """ + Suspends execution of the current task + """ + + ... # TODO: Unschedule I/O? + + def reschedule_running(self): + """ + Reschedules the currently running task + """ + + if self.current_task: + self.run_ready.append(self.current_task) + def do_cancel(self, task: Task): """ Performs task cancellation by throwing CancelledError inside the given @@ -292,13 +335,31 @@ class AsyncScheduler: error.task = task task.throw(error) - def get_current(self) -> Task: + def get_current_task(self): """ 'Returns' the current task to an async caller """ self._data = self.current_task - self.tasks.append(self.current_task) + self.reschedule_running() + + + def get_current_pool(self): + """ + 'Returns' the current pool to an async caller + """ + + self._data = self.current_pool + self.reschedule_running() + + + def get_current_loop(self): + """ + 'Returns' self to an async caller + """ + + self._data = self + self.reschedule_running() def expire_deadlines(self): """ @@ -306,28 +367,19 @@ class AsyncScheduler: inside the correct pool if its timeout expired """ - while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): + while self.deadlines.get_closest_deadline() <= self.clock(): pool = self.deadlines.get() pool.timed_out = True - if not self.current_task.done(): - self.current_task.throw(TooSlowError()) + self.cancel_pool(pool) - def check_events(self): + def schedule_tasks(self, tasks: List[Task]): """ - Checks for ready/expired events and triggers them by - rescheduling all the tasks that called wait() on them + Schedules the given tasks for execution + + :param tasks: The list of task objects to schedule """ - for event in self.events.copy(): - if event.set: - # When an event is set, all the tasks - # that called wait() on it are waken up. - # Since events can only be triggered once, - # we discard the event object from our - # set after we've rescheduled its waiters. - event.event_caught = True - self.tasks.extend(event.waiters) - self.events.remove(event) + self.run_ready.extend(tasks) def awake_sleeping(self): """ @@ -339,7 +391,7 @@ class AsyncScheduler: # Reschedules tasks when their deadline has elapsed task = self.paused.get() slept = self.clock() - task.sleep_start - self.tasks.append(task) + self.run_ready.append(task) self.debugger.after_sleep(task, slept) def get_closest_deadline(self) -> float: @@ -360,7 +412,12 @@ class AsyncScheduler: # If there are both deadlines AND sleeping tasks scheduled, we calculate # the absolute closest deadline among the two sets and use that as a timeout clock = self.clock() - timeout = min([max(0.0, self.paused.get_closest_deadline() - clock), self.deadlines.get_closest_deadline() - clock]) + timeout = min( + [ + max(0.0, self.paused.get_closest_deadline() - clock), + self.deadlines.get_closest_deadline() - clock, + ] + ) return timeout def check_io(self): @@ -369,8 +426,8 @@ class AsyncScheduler: for the event loop """ - before_time = self.clock() # Used for the debugger - if self.tasks or self.events: + before_time = self.clock() # Used for the debugger + if self.run_ready: # If there is work to do immediately (tasks to run) we prefer to # do that first unless some conditions are met, see below self.io_skip += 1 @@ -394,7 +451,7 @@ class AsyncScheduler: io_ready = self.selector.select(timeout) # Get sockets that are ready and schedule their tasks for key, _ in io_ready: - self.tasks.append(key.data) # Resource ready? Schedule its task + self.run_ready.append(key.data) # Resource ready? Schedule its task self.debugger.after_io(self.clock() - before_time) def start(self, func: types.FunctionType, *args): @@ -404,12 +461,11 @@ class AsyncScheduler: entry = Task(func.__name__ or str(func), func(*args), None) self.tasks.append(entry) + self.run_ready.append(entry) self.debugger.on_start() self.run() self.has_ran = True self.debugger.on_exit() - if entry.exc: - raise entry.exc def cancel_pool(self, pool: TaskManager) -> bool: """ @@ -428,48 +484,18 @@ class AsyncScheduler: # tasks running, we wait for them to exit in order # to avoid orphaned tasks return pool.done() - else: # If we're at the main task, we're sure everything else exited + else: # If we're at the main task, we're sure everything else exited return True - def get_event_tasks(self) -> Task: - """ - Yields all tasks currently waiting on events - """ - - for evt in self.events: - for waiter in evt.waiters: - yield waiter - - def get_asleep_tasks(self) -> Task: - """ - Yields all tasks that are currently sleeping - """ - - for asleep in self.paused.container: - yield asleep[2] # Deadline, tiebreaker, task - - def get_io_tasks(self) -> Task: - """ - Yields all tasks currently waiting on I/O resources - """ - - if self.selector.get_map(): - for k in self.selector.get_map().values(): - yield k.data - def get_all_tasks(self) -> chain: """ - Returns a generator yielding all tasks which the loop is currently + Returns a list of all the tasks the loop is currently keeping track of: this includes both running and paused tasks. A paused task is a task which is either waiting on an I/O resource, sleeping, or waiting on an event to be triggered """ - return chain(self.tasks, - self.get_asleep_tasks(), - self.get_event_tasks(), - self.get_io_tasks(), - [self.current_task]) + return self.tasks def cancel_all(self) -> bool: """ @@ -495,7 +521,9 @@ class AsyncScheduler: if ensure_done: self.cancel_all() elif not self.done(): - raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit") + raise GiambioError( + "event loop not terminated, call this method with ensure_done=False to forcefully exit" + ) self.shutdown() def reschedule_joiners(self, task: Task): @@ -505,28 +533,13 @@ class AsyncScheduler: """ for t in task.joiners: - if t not in self.tasks: + if t not in self.run_ready: # Since a task can be the parent # of multiple children, we need to # make sure we reschedule it only # once, otherwise a RuntimeError will # occur - self.tasks.append(t) - - # noinspection PyMethodMayBeStatic - def is_pool_done(self, pool: TaskManager) -> bool: - """ - Returns true if the given pool has finished - running and can be safely terminated - - :return: Whether the pool finished running - :rtype: bool - """ - - if not pool: - # The parent task has no pool - return True - return pool.done() + self.run_ready.append(t) def join(self, task: Task): """ @@ -537,11 +550,28 @@ class AsyncScheduler: task.joined = True if task.finished or task.cancelled: - if self.is_pool_done(self.current_pool): + if not task.cancelled: + self.debugger.on_task_exit(task) + if task.last_io: + self.io_release_task(task) + if task.pool is None: + return + if self.current_pool and self.current_pool.done(): # If the current pool has finished executing or we're at the first parent # task that kicked the loop, we can safely reschedule the parent(s) self.reschedule_joiners(task) elif task.exc: + task.status = "crashed" + # TODO: We might want to do a bit more complex traceback hacking to remove any extra + # frames from the exception call stack, but for now removing at least the first one + # seems a sensible approach (it's us catching it so we don't care about that) + task.exc.__traceback__ = task.exc.__traceback__.tb_next + if task.last_io: + self.io_release_task(task) + self.debugger.on_exception_raised(task, task.exc) + if task.pool is None: + # Parent task has no pool, so we propagate + raise if self.cancel_pool(self.current_pool): # This will reschedule the parent(s) # only if all the tasks inside the current @@ -552,7 +582,7 @@ class AsyncScheduler: # Propagate the exception try: t.throw(task.exc) - except StopIteration: + except (StopIteration, CancelledError): # TODO: Need anything else? task.joiners.remove(t) self.reschedule_joiners(task) @@ -575,7 +605,7 @@ class AsyncScheduler: # for too long. It is recommended to put a couple of checkpoints like these # in your code if you see degraded concurrent performance in parts of your code # that block the loop - self.tasks.append(self.current_task) + self.reschedule_running() def cancel(self, task: Task): """ @@ -594,8 +624,9 @@ class AsyncScheduler: # or dangling resource open after being cancelled, so maybe we need # a different approach altogether if task.status == "io": - for k in filter(lambda o: o.data == task, - dict(self.selector.get_map()).values()): + for k in filter( + lambda o: o.data == task, dict(self.selector.get_map()).values() + ): self.selector.unregister(k.fileobj) elif task.status == "sleep": self.paused.discard(task) @@ -622,35 +653,6 @@ class AsyncScheduler: # defer this operation for later (check run() for more info) task.cancel_pending = True # Cancellation is deferred - def event_set(self, event: Event): - """ - Sets an event - - :param event: The event object to trigger - :type event: :class: Event - """ - - # When an event is set, we store the event object - # for later, set its attribute and reschedule the - # task that called this method. All tasks waiting - # on this event object will be waken up on the next - # iteration - self.events.add(event) - event.set = True - self.tasks.append(self.current_task) - - def event_wait(self, event): - """ - Pauses the current task on an event - - :param event: The event object to pause upon - :type event: :class: Event - """ - - event.waiters.append(self.current_task) - # Since we don't reschedule the task, it will - # not execute until check_events is called - def register_sock(self, sock, evt_type: str): """ Registers the given socket inside the @@ -663,7 +665,6 @@ class AsyncScheduler: :type evt_type: str """ - self.current_task.status = "io" evt = EVENT_READ if evt_type == "read" else EVENT_WRITE if self.current_task.last_io: @@ -695,7 +696,9 @@ class AsyncScheduler: self.selector.register(sock, evt, self.current_task) except KeyError: # The socket is already registered doing something else - raise ResourceBusy("The given socket is being read/written by another task") from None + raise ResourceBusy( + "The given socket is being read/written by another task" + ) from None # noinspection PyMethodMayBeStatic async def connect_sock(self, sock: socket.socket, address_tuple: tuple): diff --git a/giambio/exceptions.py b/giambio/exceptions.py index ee2cdf6..2d29fbe 100644 --- a/giambio/exceptions.py +++ b/giambio/exceptions.py @@ -99,4 +99,3 @@ class ErrorStack(GiambioError): else: tracebacks += f"\n{''.join(traceback.format_exception(type(err), err, err.__traceback__))}" return f"Multiple errors occurred:\n{tracebacks}" - diff --git a/giambio/objects.py b/giambio/internal.py similarity index 55% rename from giambio/objects.py rename to giambio/internal.py index 07c9546..4a9f82d 100644 --- a/giambio/objects.py +++ b/giambio/internal.py @@ -15,164 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ - -import giambio -from dataclasses import dataclass, field -from heapq import heappop, heappush, heapify -from typing import Union, Coroutine, List, Tuple - - -@dataclass -class Task: - - """ - A simple wrapper around a coroutine object - """ - - # The name of the task. Usually this equals self.coroutine.__name__, - # but in some cases it falls back to repr(self.coroutine) - name: str - # The underlying coroutine object to wrap around a giambio task - coroutine: Coroutine - # The async pool that spawned this task. The one and only task that hasn't - # an associated pool is the main entry point which is not available externally - pool: Union["giambio.context.TaskManager", None] = None - # Whether the task has been cancelled or not. This is True both when the task is - # explicitly cancelled via its cancel() method or when it is cancelled as a result - # of an exception in another task in the same pool - cancelled: bool = False - # This attribute will be None unless the task raised an error - exc: BaseException = None - # The return value of the coroutine - result: object = None - # This attribute signals that the task has exited normally (returned) - finished: bool = False - # This attribute represents what the task is doing and is updated in real - # time by the event loop, internally. Possible values for this are "init"-- - # when the task has been created but not started running yet--, "run"-- when - # the task is running synchronous code--, "io"-- when the task is waiting on - # an I/O resource--, "sleep"-- when the task is either asleep or waiting on - # an event, "crashed"-- when the task has exited because of an exception - # and "cancelled" when-- when the task has been explicitly cancelled with - # its cancel() method or as a result of an exception - status: str = "init" - # This attribute counts how many times the task's run() method has been called - steps: int = 0 - # Simple optimization to improve the selector's efficiency. Check AsyncScheduler.register_sock - # inside giambio.core to know more about it - last_io: tuple = () - # All the tasks waiting on this task's completion - joiners: list = field(default_factory=list) - # Whether this task has been waited for completion or not. The one and only task - # that will have this attribute set to False is the main program entry point, since - # the loop will implicitly wait for anything else to complete before returning - joined: bool = False - # Whether this task has a pending cancellation scheduled. Check AsyncScheduler.cancel - # inside giambio.core to know more about this attribute - cancel_pending: bool = False - # Absolute clock time that represents the date at which the task started sleeping, - # mainly used for internal purposes and debugging - sleep_start: float = 0.0 - # The next deadline, in terms of the absolute clock of the loop, associated to the task - next_deadline: float = 0.0 - - def run(self, what: object = None): - """ - Simple abstraction layer over coroutines' ``send`` method - - :param what: The object that has to be sent to the coroutine, - defaults to None - :type what: object, optional - """ - - return self.coroutine.send(what) - - def throw(self, err: Exception): - """ - Simple abstraction layer over coroutines ``throw`` method - - :param err: The exception that has to be raised inside - the task - :type err: Exception - """ - - return self.coroutine.throw(err) - - async def join(self): - """ - Pauses the caller until the task has finished running. - Any return value is passed to the caller and exceptions - are propagated as well - """ - - res = await giambio.traps.join(self) - if self.exc: - raise self.exc - return res - - async def cancel(self): - """ - Cancels the task - """ - - await giambio.traps.cancel(self) - - def __hash__(self): - """ - Implements hash(self) - """ - - return hash(self.coroutine) - - def done(self): - """ - Returns True if the task is not running, - False otherwise - """ - - return self.exc or self.finished or self.cancelled - - def __del__(self): - """ - Task destructor - """ - - try: - self.coroutine.close() - except RuntimeError: - pass # TODO: This is kinda bad - assert not self.last_io - - -class Event: - """ - A class designed similarly to threading.Event - """ - - def __init__(self): - """ - Object constructor - """ - - self.set = False - self.waiters = [] - - async def trigger(self): - """ - Sets the event, waking up all tasks that called - pause() on it - """ - - if self.set: # This is set by the event loop internally - raise giambio.exceptions.GiambioError("The event has already been set") - await giambio.traps.event_set(self) - - async def wait(self): - """ - Waits until the event is set - """ - - await giambio.traps.event_wait(self) +from giambio.task import Task +from heapq import heappush, heappop class TimeQueue: @@ -371,11 +215,11 @@ class DeadlinesQueue: def get_closest_deadline(self) -> float: """ Returns the closest deadline that is meant to expire - or raises IndexError if the queue is empty + or returns 0.0 if the queue is empty """ if not self: - raise IndexError("DeadlinesQueue is empty") + return 0.0 return self.container[0][0] def __iter__(self): diff --git a/giambio/io.py b/giambio/io.py new file mode 100644 index 0000000..07962e6 --- /dev/null +++ b/giambio/io.py @@ -0,0 +1,297 @@ +""" +Basic abstraction layers for all async I/O primitives + +Copyright (C) 2020 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import ssl +from socket import SOL_SOCKET, SO_ERROR +import socket as builtin_socket +from giambio.exceptions import ResourceClosed +from giambio.traps import want_write, want_read, io_release + +try: + from ssl import SSLWantReadError, SSLWantWriteError + + WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) + WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) +except ImportError: + WantRead = (BlockingIOError, InterruptedError) + WantWrite = (BlockingIOError, InterruptedError) + + +class AsyncSocket: + """ + Abstraction layer for asynchronous sockets + """ + + def __init__(self, sock): + self.sock = sock + self._fd = sock.fileno() + self.sock.setblocking(False) + + async def receive(self, max_size: int, flags: int = 0) -> bytes: + """ + Receives up to max_size bytes from a socket asynchronously + """ + + assert max_size >= 1, "max_size must be >= 1" + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while True: + try: + return self.sock.recv(max_size, flags) + except WantRead: + await want_read(self.sock) + except WantWrite: + await want_write(self.sock) + + async def accept(self): + """ + Accepts the socket, completing the 3-step TCP handshake asynchronously + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while True: + try: + remote, addr = self.sock.accept() + return type(self)(remote), addr + except WantRead: + await want_read(self.sock) + + async def send_all(self, data: bytes, flags: int = 0): + """ + Sends all data inside the buffer asynchronously until it is empty + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + while data: + try: + sent_no = self.sock.send(data, flags) + except WantRead: + await want_read(self.sock) + except WantWrite: + await want_write(self.sock) + data = data[sent_no:] + + async def close(self): + """ + Closes the socket asynchronously + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + await io_release(self.sock) + self.sock.close() + self._sock = None + self.sock = -1 + + async def shutdown(self, how): + """ + Wrapper socket method + """ + + if self.sock: + self.sock.shutdown(how) + + async def connect(self, addr: tuple): + """ + Connects the socket to an endpoint + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + try: + self.sock.connect(addr) + except WantWrite: + await want_write(self.sock) + self.sock.connect(addr) + + async def bind(self, addr: tuple): + """ + Binds the socket to an address + + :param addr: The address, port tuple to bind to + :type addr: tuple + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + self.sock.bind(addr) + + async def listen(self, backlog: int): + """ + Starts listening with the given backlog + + :param backlog: The socket's backlog + :type backlog: int + """ + + if self._fd == -1: + raise ResourceClosed("I/O operation on closed socket") + self.sock.listen(backlog) + + async def __aenter__(self): + self.sock.__enter__() + return self + + async def __aexit__(self, *args): + if self.sock: + self.sock.__exit__(*args) + + # Yes, I stole these from Curio because I could not be + # arsed to write a bunch of uninteresting simple socket + # methods from scratch, deal with it. + + def fileno(self): + """ + Wrapper socket method + """ + + return self._fd + + def settimeout(self, seconds): + """ + Wrapper socket method + """ + + raise RuntimeError('Use with_timeout() to set a timeout') + + def gettimeout(self): + """ + Wrapper socket method + """ + + return None + + def dup(self): + """ + Wrapper socket method + """ + + return type(self)(self._socket.dup()) + + async def do_handshake(self): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.do_handshake() + except WantRead: + await want_read(self.sock) + except WantWrite: + await want_write(self.sock) + + async def connect(self, address): + """ + Wrapper socket method + """ + + try: + result = self.sock.connect(address) + if getattr(self, 'do_handshake_on_connect', False): + await self.do_handshake() + return result + except WantWrite: + await want_write(self.sock) + err = self.sock.getsockopt(SOL_SOCKET, SO_ERROR) + if err != 0: + raise OSError(err, f'Connect call failed {address}') + if getattr(self, 'do_handshake_on_connect', False): + await self.do_handshake() + + async def recvfrom(self, buffersize, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.recvfrom(buffersize, flags) + except WantRead: + await want_read(self.sock) + except WantWrite: + await want_write(self.sock) + + async def recvfrom_into(self, buffer, bytes=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.recvfrom_into(buffer, bytes, flags) + except WantRead: + await want_read(self.sock) + except WantWrite: + await want_write(self.sock) + + async def sendto(self, bytes, flags_or_address, address=None): + """ + Wrapper socket method + """ + + if address: + flags = flags_or_address + else: + address = flags_or_address + flags = 0 + while True: + try: + return self.sock.sendto(bytes, flags, address) + except WantWrite: + await want_write(self.sock) + except WantRead: + await want_read(self.sock) + + async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.recvmsg(bufsize, ancbufsize, flags) + except WantRead: + await want_read(self.sock) + + async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.recvmsg_into(buffers, ancbufsize, flags) + except WantRead: + await want_read(self.sock) + + async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): + """ + Wrapper socket method + """ + + while True: + try: + return self.sock.sendmsg(buffers, ancdata, flags, address) + except WantRead: + await want_write(self.sock) + + def __repr__(self): + return f"AsyncSocket({self.sock})" diff --git a/giambio/run.py b/giambio/run.py index babee34..0b070ba 100644 --- a/giambio/run.py +++ b/giambio/run.py @@ -68,8 +68,10 @@ def run(func: FunctionType, *args, **kwargs): """ if inspect.iscoroutine(func): - raise GiambioError("Looks like you tried to call giambio.run(your_func(arg1, arg2, ...)), that is wrong!" - "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)") + raise GiambioError( + "Looks like you tried to call giambio.run(your_func(arg1, arg2, ...)), that is wrong!" + "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)" + ) elif not inspect.iscoroutinefunction(func): raise GiambioError("giambio.run() requires an async function as parameter!") new_event_loop(kwargs.get("debugger", None), kwargs.get("clock", default_timer)) diff --git a/giambio/socket.py b/giambio/socket.py index d3f4a22..3dd0bb2 100644 --- a/giambio/socket.py +++ b/giambio/socket.py @@ -1,240 +1,38 @@ -""" Basic abstraction layer for giambio asynchronous sockets - -Copyright (C) 2020 nocturn9x - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import ssl -from socket import SOL_SOCKET, SO_ERROR -import socket as builtin_socket -from giambio.exceptions import ResourceClosed -from giambio.traps import want_write, want_read - -try: - from ssl import SSLWantReadError, SSLWantWriteError - WantRead = (BlockingIOError, InterruptedError, SSLWantReadError) - WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError) -except ImportError: - WantRead = (BlockingIOError, InterruptedError) - WantWrite = (BlockingIOError, InterruptedError) - - -class AsyncSocket: - """ - Abstraction layer for asynchronous sockets - """ - - def __init__(self, sock): - self.sock = sock - self._fd = sock.fileno() - self.sock.setblocking(False) - - - async def receive(self, max_size: int, flags: int = 0) -> bytes: - """ - Receives up to max_size bytes from a socket asynchronously - """ - - assert max_size >= 1, "max_size must be >= 1" - data = b"" - if self._fd == -1: - raise ResourceClosed("I/O operation on closed socket") - while True: - try: - return self.sock.recv(max_size, flags) - except WantRead: - await want_read(self.sock) - except WantWrite: - await want_write(self.sock) - - async def accept(self): - """ - Accepts the socket, completing the 3-step TCP handshake asynchronously - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - while True: - try: - remote, addr = self.sock.accept() - return wrap_socket(remote), addr - except WantRead: - await want_read(self.sock) - - async def send_all(self, data: bytes, flags: int = 0): - """ - Sends all data inside the buffer asynchronously until it is empty - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - while data: - try: - sent_no = self.sock.send(data, flags) - except WantRead: - await want_read(self.sock) - except WantWrite: - await want_write(self.sock) - data = data[sent_no:] - - async def close(self): - """ - Closes the socket asynchronously - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - await release_sock(self.sock) - self.sock.close() - self._sock = None - self.sock = -1 - - async def connect(self, addr: tuple): - """ - Connects the socket to an endpoint - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - try: - self.sock.connect(addr) - except WantWrite: - await want_write(self.sock) - self.sock.connect(addr) - - - async def bind(self, addr: tuple): - """ - Binds the socket to an address - - :param addr: The address, port tuple to bind to - :type addr: tuple - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - self.sock.bind(addr) - - async def listen(self, backlog: int): - """ - Starts listening with the given backlog - - :param backlog: The socket's backlog - :type backlog: int - """ - - if self.sock == -1: - raise ResourceClosed("I/O operation on closed socket") - self.sock.listen(backlog) - - async def __aenter__(self): - self.sock.__enter__() - return self - - async def __aexit__(self, *args): - if self.sock: - self.sock.__exit__(*args) - - # Yes, I stole these from Curio because I could not be - # arsed to write a bunch of uninteresting simple socket - # methods from scratch, deal with it. - - async def connect(self, address): - try: - result = self.sock.connect(address) - if getattr(self, 'do_handshake_on_connect', False): - await self.do_handshake() - return result - except WantWrite: - await want_write(self.sock) - err = self.sock.getsockopt(SOL_SOCKET, SO_ERROR) - if err != 0: - raise OSError(err, f'Connect call failed {address}') - if getattr(self, 'do_handshake_on_connect', False): - await self.do_handshake() - - async def recvfrom(self, buffersize, flags=0): - while True: - try: - return self.sock.recvfrom(buffersize, flags) - except WantRead: - await want_read(self.sock) - except WantWrite: - await want_write(self.sock) - - async def recvfrom_into(self, buffer, bytes=0, flags=0): - while True: - try: - return self.sock.recvfrom_into(buffer, bytes, flags) - except WantRead: - await want_read(self.sock) - except WantWrite: - await want_write(self.sock) - - async def sendto(self, bytes, flags_or_address, address=None): - if address: - flags = flags_or_address - else: - address = flags_or_address - flags = 0 - while True: - try: - return self.sock.sendto(bytes, flags, address) - except WantWrite: - await want_write(self.sock) - except WantRead: - await want_read(self.sock) - - async def recvmsg(self, bufsize, ancbufsize=0, flags=0): - while True: - try: - return self.sock.recvmsg(bufsize, ancbufsize, flags) - except WantRead: - await want_read(self.sock) - - async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): - while True: - try: - return self.sock.recvmsg_into(buffers, ancbufsize, flags) - except WantRead: - await want_read(self.sock) - - async def sendmsg(self, buffers, ancdata=(), flags=0, address=None): - while True: - try: - return self.sock.sendmsg(buffers, ancdata, flags, address) - except WantRead: - await want_write(self.sock) - - def __repr__(self): - return f"giambio.socket.AsyncSocket({self.sock}, {self.loop})" - - -def wrap_socket(sock: builtin_socket.socket) -> AsyncSocket: - """ - Wraps a standard socket into an async socket - """ - - return AsyncSocket(sock) - - -def socket(*args, **kwargs): - """ - Creates a new giambio socket, taking in the same positional and - keyword arguments as the standard library's socket.socket - constructor - """ - - return AsyncSocket(builtin_socket.socket(*args, **kwargs)) - +""" +Socket and networking utilities + +Copyright (C) 2020 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket as _socket +from .io import AsyncSocket + + +def wrap_socket(sock: _socket.socket) -> AsyncSocket: + """ + Wraps a standard socket into an async socket + """ + + return AsyncSocket(sock) + + +def socket(*args, **kwargs): + """ + Creates a new giambio socket, taking in the same positional and + keyword arguments as the standard library's socket.socket + constructor + """ + + return AsyncSocket(_socket.socket(*args, **kwargs)) diff --git a/giambio/sync.py b/giambio/sync.py new file mode 100644 index 0000000..c752386 --- /dev/null +++ b/giambio/sync.py @@ -0,0 +1,49 @@ +""" +Task synchronization primitives + +Copyright (C) 2020 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from giambio.traps import event_wait, event_set + + +class Event: + """ + A class designed similarly to threading.Event + """ + + def __init__(self): + """ + Object constructor + """ + + self.set = False + self.waiters = set() + + async def trigger(self): + """ + Sets the event, waking up all tasks that called + pause() on it + """ + + if self.set: + raise giambio.exceptions.GiambioError("The event has already been set") + await event_set(self) + + async def wait(self): + """ + Waits until the event is set + """ + + await event_wait(self) diff --git a/giambio/task.py b/giambio/task.py new file mode 100644 index 0000000..f2c7e35 --- /dev/null +++ b/giambio/task.py @@ -0,0 +1,146 @@ +""" +Object wrapper for asynchronous tasks + +Copyright (C) 2020 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import giambio +from dataclasses import dataclass, field +from typing import Union, Coroutine, List, Tuple, Set + + +@dataclass +class Task: + + """ + A simple wrapper around a coroutine object + """ + + # The name of the task. Usually this equals self.coroutine.__name__, + # but in some cases it falls back to repr(self.coroutine) + name: str + # The underlying coroutine object to wrap around a giambio task + coroutine: Coroutine + # The async pool that spawned this task. The one and only task that hasn't + # an associated pool is the main entry point which is not available externally + pool: Union["giambio.context.TaskManager", None] = None + # Whether the task has been cancelled or not. This is True both when the task is + # explicitly cancelled via its cancel() method or when it is cancelled as a result + # of an exception in another task in the same pool + cancelled: bool = False + # This attribute will be None unless the task raised an error + exc: BaseException = None + # The return value of the coroutine + result: object = None + # This attribute signals that the task has exited normally (returned) + finished: bool = False + # This attribute represents what the task is doing and is updated in real + # time by the event loop, internally. Possible values for this are "init"-- + # when the task has been created but not started running yet--, "run"-- when + # the task is running synchronous code--, "io"-- when the task is waiting on + # an I/O resource--, "sleep"-- when the task is either asleep or waiting on + # an event, "crashed"-- when the task has exited because of an exception + # and "cancelled" when-- when the task has been explicitly cancelled with + # its cancel() method or as a result of an exception + status: str = "init" + # This attribute counts how many times the task's run() method has been called + steps: int = 0 + # Simple optimization to improve the selector's efficiency. Check AsyncScheduler.register_sock + # inside giambio.core to know more about it + last_io: tuple = () + # All the tasks waiting on this task's completion + joiners: Set = field(default_factory=set) + # Whether this task has been waited for completion or not. The one and only task + # that will have this attribute set to False is the main program entry point, since + # the loop will implicitly wait for anything else to complete before returning + joined: bool = False + # Whether this task has a pending cancellation scheduled. Check AsyncScheduler.cancel + # inside giambio.core to know more about this attribute + cancel_pending: bool = False + # Absolute clock time that represents the date at which the task started sleeping, + # mainly used for internal purposes and debugging + sleep_start: float = 0.0 + # The next deadline, in terms of the absolute clock of the loop, associated to the task + next_deadline: float = 0.0 + + def run(self, what: object = None): + """ + Simple abstraction layer over coroutines' ``send`` method + + :param what: The object that has to be sent to the coroutine, + defaults to None + :type what: object, optional + """ + + return self.coroutine.send(what) + + def throw(self, err: Exception): + """ + Simple abstraction layer over coroutines ``throw`` method + + :param err: The exception that has to be raised inside + the task + :type err: Exception + """ + + return self.coroutine.throw(err) + + async def join(self): + """ + Pauses the caller until the task has finished running. + Any return value is passed to the caller and exceptions + are propagated as well + """ + + self.joiners.add(await giambio.traps.current_task()) + print(self.joiners) + res = await giambio.traps.join(self) + if self.exc: + raise self.exc + return res + + async def cancel(self): + """ + Cancels the task + """ + + await giambio.traps.cancel(self) + + def __hash__(self): + """ + Implements hash(self) + """ + + return hash(self.coroutine) + + def done(self): + """ + Returns True if the task is not running, + False otherwise + """ + + return self.exc or self.finished or self.cancelled + + def __del__(self): + """ + Task destructor + """ + + try: + self.coroutine.close() + except RuntimeError: + pass # TODO: This is kinda bad + assert not self.last_io + diff --git a/giambio/traps.py b/giambio/traps.py index f908fe1..7a35fec 100644 --- a/giambio/traps.py +++ b/giambio/traps.py @@ -22,6 +22,11 @@ limitations under the License. import types +import inspect +from giambio.task import Task +from types import FunctionType +from typing import List, Union, Iterable +from giambio.exceptions import GiambioError @types.coroutine @@ -36,7 +41,27 @@ def create_trap(method, *args): return data -async def sleep(seconds: int): +async def create_task(coro: FunctionType, *args): + """ + Spawns a new task in the current event loop from a bare coroutine + function. All extra positional arguments are passed to the function + + This trap should *NOT* be used on its own, it is meant to be + called from internal giambio machinery + """ + + if inspect.iscoroutine(coro): + raise GiambioError( + "Looks like you tried to call giambio.run(your_func(arg1, arg2, ...)), that is wrong!" + "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)" + ) + elif inspect.iscoroutinefunction(coro): + return await create_trap("create_task", coro, *args) + else: + raise TypeError("coro must be a coroutine or coroutine function") + + +async def sleep(seconds: Union[int, float]): """ Pause the execution of an async function for a given amount of seconds. This function is functionally equivalent to time.sleep, but can be used @@ -73,7 +98,23 @@ async def current_task(): Gets the currently running task in an asynchronous fashion """ - return await create_trap("get_current") + return await create_trap("get_current_task") + + +async def current_loop(): + """ + Gets the currently running loop in an asynchronous fashion + """ + + return await create_trap("get_current_loop") + + +async def current_pool(): + """ + Gets the currently active task pool in an asynchronous fashion + """ + + return await create_trap("get_current_pool") async def join(task): @@ -126,16 +167,6 @@ async def want_write(stream): await create_trap("register_sock", stream, "write") -async def event_set(event): - """ - Communicates to the loop that the given event object - must be set. This is important as the loop constantly - checks for active events to deliver them - """ - - await create_trap("event_set", event) - - async def event_wait(event): """ Notifies the event loop that the current task has to wait @@ -145,5 +176,32 @@ async def event_wait(event): if event.set: return - await create_trap("event_wait", event) + event.waiters.add(await current_task()) + await create_trap("suspend") + +async def event_set(event): + """ + Sets the given event and reawakens its + waiters + """ + + event.set = True + await reschedule_running() + await schedule_tasks(event.waiters) + + +async def schedule_tasks(tasks: Iterable[Task]): + """ + Schedules a list of tasks for execution + """ + + await create_trap("schedule_tasks", tasks) + + +async def reschedule_running(): + """ + Reschedules the current task for execution + """ + + await create_trap("reschedule_running") diff --git a/giambio/util/__init__.py b/giambio/util/__init__.py index d6f0e96..77c5f82 100644 --- a/giambio/util/__init__.py +++ b/giambio/util/__init__.py @@ -12,4 +12,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -""" \ No newline at end of file +""" diff --git a/giambio/util/debug.py b/giambio/util/debug.py index 8d6988e..f911e20 100644 --- a/giambio/util/debug.py +++ b/giambio/util/debug.py @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. """ from abc import ABC, abstractmethod -from giambio.objects import Task +from giambio.task import Task class BaseDebugger(ABC): diff --git a/tests/cancel.py b/tests/cancel.py index d55a80f..6981960 100644 --- a/tests/cancel.py +++ b/tests/cancel.py @@ -11,11 +11,15 @@ async def child(name: int): async def main(): start = giambio.clock() async with giambio.create_pool() as pool: - pool.spawn(child, 1) # If you comment this line, the pool will exit immediately! - task = pool.spawn(child, 2) + await pool.spawn( + child, 1 + ) # If you comment this line, the pool will exit immediately! + task = await pool.spawn(child, 2) await task.cancel() print("[main] Children spawned, awaiting completion") - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": diff --git a/tests/debugger.py b/tests/debugger.py index 7447fc3..d2619f0 100644 --- a/tests/debugger.py +++ b/tests/debugger.py @@ -13,7 +13,9 @@ class Debugger(giambio.debug.BaseDebugger): print("## Finished running") def on_task_schedule(self, task, delay: int): - print(f">> A task named '{task.name}' was scheduled to run in {delay:.2f} seconds") + print( + f">> A task named '{task.name}' was scheduled to run in {delay:.2f} seconds" + ) def on_task_spawn(self, task): print(f">> A task named '{task.name}' was spawned") @@ -47,4 +49,3 @@ class Debugger(giambio.debug.BaseDebugger): def on_exception_raised(self, task, exc): print(f"== '{task.name}' raised {repr(exc)}") - diff --git a/tests/events.py b/tests/events.py index cb63d67..c8175c7 100644 --- a/tests/events.py +++ b/tests/events.py @@ -1,3 +1,4 @@ +from debugger import Debugger import giambio @@ -14,14 +15,16 @@ async def child(ev: giambio.Event, pause: int): await giambio.sleep(pause) end_sleep = giambio.clock() - start_sleep end_total = giambio.clock() - start_total - print(f"[child] Done! Slept for {end_total} seconds total ({end_pause} paused, {end_sleep} sleeping), nice nap!") + print( + f"[child] Done! Slept for {end_total} seconds total ({end_pause} paused, {end_sleep} sleeping), nice nap!" + ) async def parent(pause: int = 1): async with giambio.create_pool() as pool: event = giambio.Event() print("[parent] Spawning child task") - pool.spawn(child, event, pause + 2) + await pool.spawn(child, event, pause + 2) start = giambio.clock() print(f"[parent] Sleeping {pause} second(s) before setting the event") await giambio.sleep(pause) @@ -32,4 +35,4 @@ async def parent(pause: int = 1): if __name__ == "__main__": - giambio.run(parent, 3) + giambio.run(parent, 3, debugger=()) diff --git a/tests/exceptions.py b/tests/exceptions.py index 439e29d..66b8c7b 100644 --- a/tests/exceptions.py +++ b/tests/exceptions.py @@ -20,13 +20,15 @@ async def main(): start = giambio.clock() try: async with giambio.create_pool() as pool: - pool.spawn(child) - pool.spawn(child1) + await pool.spawn(child) + await pool.spawn(child1) print("[main] Children spawned, awaiting completion") except Exception as error: # Because exceptions just *work*! print(f"[main] Exception from child caught! {repr(error)}") - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": diff --git a/tests/nested_exception.py b/tests/nested_exception.py index 28f299f..5ae10c8 100644 --- a/tests/nested_exception.py +++ b/tests/nested_exception.py @@ -31,19 +31,21 @@ async def main(): start = giambio.clock() try: async with giambio.create_pool() as pool: - pool.spawn(child) - pool.spawn(child1) + await pool.spawn(child) + await pool.spawn(child1) print("[main] Children spawned, awaiting completion") async with giambio.create_pool() as new_pool: # This pool will be cancelled by the exception # in the other pool - new_pool.spawn(child2) - new_pool.spawn(child3) + await new_pool.spawn(child2) + await new_pool.spawn(child3) print("[main] 3rd child spawned") except Exception as error: # Because exceptions just *work*! print(f"[main] Exception from child caught! {repr(error)}") - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": diff --git a/tests/nested_pool.py b/tests/nested_pool.py index 0845438..e8ef641 100644 --- a/tests/nested_pool.py +++ b/tests/nested_pool.py @@ -11,15 +11,17 @@ async def child(name: int): async def main(): start = giambio.clock() async with giambio.create_pool() as pool: - pool.spawn(child, 1) - pool.spawn(child, 2) + await pool.spawn(child, 1) + await pool.spawn(child, 2) async with giambio.create_pool() as a_pool: - a_pool.spawn(child, 3) - a_pool.spawn(child, 4) + await a_pool.spawn(child, 3) + await a_pool.spawn(child, 4) print("[main] Children spawned, awaiting completion") # This will *only* execute when everything inside the async with block # has ran, including any other pool - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": diff --git a/tests/server.py b/tests/server.py index 849b17c..998ec40 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,7 +22,7 @@ async def serve(bind_address: tuple): while True: conn, address_tuple = await sock.accept() logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - pool.spawn(handler, conn, address_tuple) + await pool.spawn(handler, conn, address_tuple) async def handler(sock: AsyncSocket, client_address: tuple): @@ -38,8 +38,10 @@ async def handler(sock: AsyncSocket, client_address: tuple): """ address = f"{client_address[0]}:{client_address[1]}" - async with sock: # Closes the socket automatically - await sock.send_all(b"Welcome to the server pal, feel free to send me something!\n") + async with sock: # Closes the socket automatically + await sock.send_all( + b"Welcome to the server pal, feel free to send me something!\n" + ) while True: await sock.send_all(b"-> ") data = await sock.receive(1024) @@ -47,7 +49,9 @@ async def handler(sock: AsyncSocket, client_address: tuple): break elif data == b"exit\n": await sock.send_all(b"I'm dead dude\n") - raise TypeError("Oh, no, I'm gonna die!") # This kills the entire application! + raise TypeError( + "Oh, no, I'm gonna die!" + ) # This kills the entire application! logging.info(f"Got: {data!r} from {address}") await sock.send_all(b"Got: " + data) logging.info(f"Echoed back {data!r} to {address}") @@ -56,12 +60,15 @@ async def handler(sock: AsyncSocket, client_address: tuple): if __name__ == "__main__": port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501 - logging.basicConfig(level=20, format="[%(levelname)s] %(asctime)s %(message)s", datefmt="%d/%m/%Y %p") + logging.basicConfig( + level=20, + format="[%(levelname)s] %(asctime)s %(message)s", + datefmt="%d/%m/%Y %p", + ) try: giambio.run(serve, ("localhost", port)) except (Exception, KeyboardInterrupt) as error: # Exceptions propagate! if isinstance(error, KeyboardInterrupt): logging.info("Ctrl+C detected, exiting") else: - raise logging.error(f"Exiting due to a {type(error).__name__}: {error}") diff --git a/tests/sleep.py b/tests/sleep.py index 4153a16..67ef90d 100644 --- a/tests/sleep.py +++ b/tests/sleep.py @@ -16,10 +16,12 @@ async def child1(): async def main(): start = giambio.clock() async with giambio.create_pool() as pool: - pool.spawn(child) - pool.spawn(child1) + await pool.spawn(child) + await pool.spawn(child1) print("[main] Children spawned, awaiting completion") - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": diff --git a/tests/timeout.py b/tests/timeout.py index 4c700de..01c89ca 100644 --- a/tests/timeout.py +++ b/tests/timeout.py @@ -12,12 +12,14 @@ async def main(): start = giambio.clock() try: async with giambio.with_timeout(10) as pool: - pool.spawn(child, 7) # This will complete - await child(20) # TODO: Broken + await pool.spawn(child, 7) # This will complete + await child(20) # TODO: Broken except giambio.exceptions.TooSlowError: print("[main] One or more children have timed out!") - print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") + print( + f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds" + ) if __name__ == "__main__": - giambio.run(main, debugger=Debugger()) + giambio.run(main, debugger=())