From 5c05de495d6fe028e17ba6632632b3251716cee7 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Sat, 5 Feb 2022 16:14:21 +0100 Subject: [PATCH] Fixed some issues with join() not properly rescheduling its caller when appropriate --- giambio/__init__.py | 2 +- giambio/context.py | 5 ++-- giambio/core.py | 59 +++++++++++++++++++++------------------- giambio/internal.py | 14 ++++++++++ giambio/io.py | 2 +- giambio/runtime.py | 29 ++++++-------------- giambio/sync.py | 6 ++-- giambio/task.py | 5 ++-- giambio/traps.py | 5 ++-- tests/chatroom_client.py | 2 +- tests/chatroom_server.py | 18 ++++++------ tests/echo_server.py | 16 +++++------ tests/events.py | 4 ++- tests/queue.py | 2 -- tests/socket_ssl.py | 47 +++++++++++++++++--------------- tests/timeout.py | 2 +- tests/timeout3.py | 4 +-- 17 files changed, 112 insertions(+), 110 deletions(-) diff --git a/giambio/__init__.py b/giambio/__init__.py index 0ced5c2..9f24333 100644 --- a/giambio/__init__.py +++ b/giambio/__init__.py @@ -45,5 +45,5 @@ __all__ = [ "skip_after", "task", "io", - "socket" + "socket", ] diff --git a/giambio/context.py b/giambio/context.py index abc12f9..024f688 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -16,9 +16,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -import types import giambio -from typing import List, Optional +from typing import List, Optional, Callable, Coroutine, Any class TaskManager: @@ -55,7 +54,7 @@ class TaskManager: self.enclosed_pool: Optional["giambio.context.TaskManager"] = None self.raise_on_timeout: bool = raise_on_timeout - async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task": + async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task": """ Spawns a child task """ diff --git a/giambio/core.py b/giambio/core.py index 2ceb3f8..0ad9210 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -17,13 +17,12 @@ limitations under the License. """ # Import libraries and internal resources -import types from giambio.task import Task from collections import deque from functools import partial from timeit import default_timer from giambio.context import TaskManager -from typing import Callable, List, Optional, Any, Dict +from typing import Callable, List, Optional, Any, Dict, Coroutine from giambio.util.debug import BaseDebugger from giambio.internal import TimeQueue, DeadlinesQueue from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE @@ -56,7 +55,7 @@ class AsyncScheduler: :param clock: A callable returning monotonically increasing values at each call, usually using seconds as units, but this is not enforced, defaults to timeit.default_timer - :type clock: :class: types.FunctionType + :type clock: :class: Callable :param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output is desired, defaults to None :type debugger: :class: giambio.util.BaseDebugger @@ -73,7 +72,7 @@ class AsyncScheduler: def __init__( self, - clock: types.FunctionType = default_timer, + clock: Callable = default_timer, debugger: Optional[BaseDebugger] = None, selector: Optional[Any] = None, io_skip_limit: Optional[int] = None, @@ -107,7 +106,7 @@ class AsyncScheduler: # This will always point to the currently running coroutine (Task object) self.current_task: Optional[Task] = None # Monotonic clock to keep track of elapsed time reliably - self.clock: types.FunctionType = clock + self.clock: Callable = clock # Tasks that are asleep self.paused: TimeQueue = TimeQueue(self.clock) # Have we ever ran? @@ -246,8 +245,7 @@ class AsyncScheduler: self.current_task.exc = err self.join(self.current_task) - - def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task: + def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task: """ Creates a task from a coroutine function and schedules it to run. The associated pool that spawned said task is also @@ -286,7 +284,6 @@ class AsyncScheduler: account, that's self.run's job! """ - data = None # Sets the currently running task self.current_task = self.run_ready.popleft() if self.current_task.done(): @@ -351,11 +348,11 @@ class AsyncScheduler: a do-nothing method, since it will not reschedule the task before returning. The task will stay suspended as long as something else outside the loop calls a trap to reschedule it. - Any pending I/O for the task is temporarily unscheduled to + Any pending I/O for the task is temporarily unscheduled to avoid some previous network operation to reschedule the task before it's due """ - + if self.current_task.last_io: self.io_release_task(self.current_task) self.suspended.append(self.current_task) @@ -540,7 +537,7 @@ class AsyncScheduler: self.run_ready.append(key.data) # Resource ready? Schedule its task self.debugger.after_io(self.clock() - before_time) - def start(self, func: types.FunctionType, *args, loop: bool = True): + def start(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, loop: bool = True): """ Starts the event loop from a sync context. If the loop parameter is false, the event loop will not start listening for events @@ -623,16 +620,21 @@ class AsyncScheduler: given task, if any """ - if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done(): - return for t in task.joiners: - if t not in self.run_ready: - # Since a task can be the parent - # of multiple children, we need to - # make sure we reschedule it only - # once, otherwise a RuntimeError will - # occur - self.run_ready.append(t) + self.run_ready.append(t) + + # noinspection PyMethodMayBeStatic + def is_pool_done(self, pool: Optional[TaskManager]): + """ + Returns True if a given pool has finished + executing + """ + + while pool: + if not pool.done(): + return False + pool = pool.enclosed_pool + return True def join(self, task: Task): """ @@ -643,6 +645,8 @@ class AsyncScheduler: task.joined = True if task.finished or task.cancelled: + if task in self.tasks: + self.tasks.remove(task) if not task.cancelled: # This way join() returns the # task's return value @@ -653,9 +657,12 @@ class AsyncScheduler: self.io_release_task(task) # If the pool has finished executing or we're at the first parent # task that kicked the loop, we can safely reschedule the parent(s) - if not task.pool or task.pool.done(): + if self.is_pool_done(task.pool): self.reschedule_joiners(task) + self.reschedule_running() elif task.exc: + if task in self.tasks: + self.tasks.remove(task) task.status = "crashed" if task.exc.__traceback__: # TODO: We might want to do a bit more complex traceback hacking to remove any extra @@ -676,15 +683,11 @@ class AsyncScheduler: # or just returned for t in task.joiners.copy(): # Propagate the exception - try: - t.throw(task.exc) - except (StopIteration, CancelledError, RuntimeError): - # TODO: Need anything else? + self.handle_task_exit(t, partial(t.throw, task.exc)) + if t.exc or t.finished or t.cancelled: task.joiners.remove(t) - finally: - if t in self.tasks: - self.tasks.remove(t) self.reschedule_joiners(task) + self.reschedule_running() def sleep(self, seconds: int or float): """ diff --git a/giambio/internal.py b/giambio/internal.py index 0449e03..6dbcf02 100644 --- a/giambio/internal.py +++ b/giambio/internal.py @@ -44,6 +44,13 @@ class TimeQueue: self.sequence = 0 self.container: List[Tuple[float, int, Task]] = [] + def __len__(self): + """ + Returns len(self) + """ + + return len(self.container) + def __contains__(self, item: Task): """ Implements item in self. This method behaves @@ -263,6 +270,13 @@ class DeadlinesQueue: return f"DeadlinesQueue({self.container})" + def __len__(self): + """ + Returns len(self) + """ + + return len(self.container) + def put(self, pool: "giambio.context.TaskManager"): """ Pushes a pool with its deadline onto the queue. The diff --git a/giambio/io.py b/giambio/io.py index a5cb425..9ddd84e 100644 --- a/giambio/io.py +++ b/giambio/io.py @@ -244,7 +244,7 @@ class AsyncSocket: await want_write(self.sock) except WantRead: await want_read(self.sock) - + async def getpeername(self): """ Wrapper socket method diff --git a/giambio/runtime.py b/giambio/runtime.py index bf14e28..f5124b4 100644 --- a/giambio/runtime.py +++ b/giambio/runtime.py @@ -18,12 +18,13 @@ limitations under the License. import inspect import threading +from typing import Callable, Coroutine, Any, Union + from giambio.core import AsyncScheduler from giambio.exceptions import GiambioError from giambio.context import TaskManager from timeit import default_timer from giambio.util.debug import BaseDebugger -from types import FunctionType thread_local = threading.local() @@ -41,7 +42,7 @@ def get_event_loop(): raise GiambioError("giambio is not running") from None -def new_event_loop(debugger: BaseDebugger, clock: FunctionType): +def new_event_loop(debugger: BaseDebugger, clock: Callable): """ Associates a new event loop to the current thread and deactivates the old one. This should not be @@ -62,7 +63,7 @@ def new_event_loop(debugger: BaseDebugger, clock: FunctionType): thread_local.loop = AsyncScheduler(clock, debugger) -def run(func: FunctionType, *args, **kwargs): +def run(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs): """ Starts the event loop from a synchronous entry point """ @@ -95,23 +96,16 @@ def create_pool(): return TaskManager() -def with_timeout(timeout: int or float): +def with_timeout(timeout: Union[int, float]): """ Creates an async pool with an associated timeout """ assert timeout > 0, "The timeout must be greater than 0" - mgr = TaskManager(timeout) - loop = get_event_loop() - if loop.current_task.pool is None: - loop.current_pool = mgr - loop.current_task.pool = mgr - loop.current_task.next_deadline = mgr.timeout or 0.0 - loop.deadlines.put(mgr) - return mgr + return TaskManager(timeout) -def skip_after(timeout: int or float): +def skip_after(timeout: Union[int, float]): """ Creates an async pool with an associated timeout, but without raising a TooSlowError exception. The pool @@ -119,11 +113,4 @@ def skip_after(timeout: int or float): """ assert timeout > 0, "The timeout must be greater than 0" - mgr = TaskManager(timeout, False) - loop = get_event_loop() - if loop.current_task.pool is None: - loop.current_pool = mgr - loop.current_task.pool = mgr - loop.current_task.next_deadline = mgr.timeout or 0.0 - loop.deadlines.put(mgr) - return mgr + return TaskManager(timeout, False) diff --git a/giambio/sync.py b/giambio/sync.py index f6babab..0dab79c 100644 --- a/giambio/sync.py +++ b/giambio/sync.py @@ -73,11 +73,10 @@ class Queue: self.putters = deque() self.container = deque(maxlen=maxsize) - async def put(self, item: Any): """ Pushes an element onto the queue. If the - queue is full, waits until there's + queue is full, waits until there's enough space for the queue """ @@ -88,7 +87,6 @@ class Queue: else: self.putters.append(Event()) await self.putters[-1].wait() - async def get(self) -> Any: """ @@ -103,4 +101,4 @@ class Queue: return self.container.popleft() else: self.getters.append(Event()) - return await self.getters[-1].wait() \ No newline at end of file + return await self.getters[-1].wait() diff --git a/giambio/task.py b/giambio/task.py index 899fc42..64b6515 100644 --- a/giambio/task.py +++ b/giambio/task.py @@ -54,8 +54,8 @@ class Task: # when the task has been created but not started running yet--, "run"-- when # the task is running synchronous code--, "io"-- when the task is waiting on # an I/O resource--, "sleep"-- when the task is either asleep, waiting on - # an event or otherwise suspended, "crashed"-- when the task has exited because - # of an exception and "cancelled" when-- when the task has been explicitly cancelled + # an event or otherwise suspended, "crashed"-- when the task has exited because + # of an exception and "cancelled" when-- when the task has been explicitly cancelled # with its cancel() method or as a result of an exception status: str = "init" # This attribute counts how many times the task's run() method has been called @@ -112,7 +112,6 @@ class Task: self.joiners.add(task) return await giambio.traps.join(self) - async def cancel(self): """ Cancels the task diff --git a/giambio/traps.py b/giambio/traps.py index 240917c..6756fe0 100644 --- a/giambio/traps.py +++ b/giambio/traps.py @@ -24,8 +24,7 @@ limitations under the License. import types import inspect from giambio.task import Task -from types import FunctionType -from typing import Any, Union, Iterable +from typing import Any, Union, Iterable, Coroutine, Callable from giambio.exceptions import GiambioError @@ -49,7 +48,7 @@ async def suspend() -> Any: return await create_trap("suspend") -async def create_task(coro: FunctionType, pool, *args): +async def create_task(coro: Callable[..., Coroutine[Any, Any, Any]], pool, *args): """ Spawns a new task in the current event loop from a bare coroutine function. All extra positional arguments are passed to the function diff --git a/tests/chatroom_client.py b/tests/chatroom_client.py index 80ac22e..0c1d2f0 100644 --- a/tests/chatroom_client.py +++ b/tests/chatroom_client.py @@ -19,7 +19,7 @@ async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue): data, rest = data.split(b"\n", maxsplit=2) buffer = b"".join(rest) await q.put((1, data.decode())) - data = buffer + data = buffer async def main(host: Tuple[str, int]): diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index b3e31bb..441b16d 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -23,15 +23,15 @@ async def serve(bind_address: tuple): logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") async with giambio.create_pool() as pool: async with sock: - while True: - try: - conn, address_tuple = await sock.accept() - clients.append(conn) - logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - await pool.spawn(handler, conn, address_tuple) - except Exception as err: - # Because exceptions just *work* - logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") + while True: + try: + conn, address_tuple = await sock.accept() + clients.append(conn) + logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") + await pool.spawn(handler, conn, address_tuple) + except Exception as err: + # Because exceptions just *work* + logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") async def handler(sock: AsyncSocket, client_address: tuple): diff --git a/tests/echo_server.py b/tests/echo_server.py index a42a1a6..7a79a60 100644 --- a/tests/echo_server.py +++ b/tests/echo_server.py @@ -20,14 +20,14 @@ async def serve(bind_address: tuple): logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") async with giambio.create_pool() as pool: async with sock: - while True: - try: - conn, address_tuple = await sock.accept() - logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - await pool.spawn(handler, conn, address_tuple) - except Exception as err: - # Because exceptions just *work* - logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") + while True: + try: + conn, address_tuple = await sock.accept() + logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") + await pool.spawn(handler, conn, address_tuple) + except Exception as err: + # Because exceptions just *work* + logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") async def handler(sock: AsyncSocket, client_address: tuple): diff --git a/tests/events.py b/tests/events.py index f056b2f..87d4ee9 100644 --- a/tests/events.py +++ b/tests/events.py @@ -15,7 +15,9 @@ async def child(ev: giambio.Event, pause: int): await giambio.sleep(pause) end_sleep = giambio.clock() - start_sleep end_total = giambio.clock() - start_total - print(f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!") + print( + f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!" + ) async def parent(pause: int = 1): diff --git a/tests/queue.py b/tests/queue.py index d8dc4bb..68375ae 100644 --- a/tests/queue.py +++ b/tests/queue.py @@ -18,14 +18,12 @@ async def consumer(q: giambio.Queue): break print(f"Consumed {item}") await giambio.sleep(1) - async def main(q: giambio.Queue, n: int): async with giambio.create_pool() as pool: await pool.spawn(consumer, q) await pool.spawn(producer, q, n) - queue = giambio.Queue() diff --git a/tests/socket_ssl.py b/tests/socket_ssl.py index b74cf74..6eb9a26 100644 --- a/tests/socket_ssl.py +++ b/tests/socket_ssl.py @@ -7,6 +7,7 @@ import time _print = print + def print(*args, **kwargs): sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ") _print(*args, **kwargs) @@ -14,18 +15,19 @@ def print(*args, **kwargs): async def test(host: str, port: int, bufsize: int = 4096): socket = giambio.socket.wrap_socket( - ssl.create_default_context().wrap_socket( - sock=sock.socket(), - # Note: do_handshake_on_connect MUST - # be set to False on the synchronous socket! - # Giambio handles the TLS handshake asynchronously - # and making the SSL library handle it blocks - # the entire event loop. To perform the TLS - # handshake upon connection, set the this - # parameter in the AsyncSocket class instead - do_handshake_on_connect=False, - server_hostname=host) - ) + ssl.create_default_context().wrap_socket( + sock=sock.socket(), + # Note: do_handshake_on_connect MUST + # be set to False on the synchronous socket! + # Giambio handles the TLS handshake asynchronously + # and making the SSL library handle it blocks + # the entire event loop. To perform the TLS + # handshake upon connection, set the this + # parameter in the AsyncSocket class instead + do_handshake_on_connect=False, + server_hostname=host, + ) + ) print(f"Attempting a connection to {host}:{port}") await socket.connect((host, port)) print("Connected") @@ -34,18 +36,20 @@ async def test(host: str, port: int, bufsize: int = 4096): async with socket: # Closes the socket automatically print("Entered socket context manager, sending request data") - await socket.send_all(b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""") + await socket.send_all( + b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""" + ) print("Data sent") buffer = b"" while not buffer.endswith(b"\r\n\r\n"): - print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})") - data = await socket.receive(bufsize) - print(f"Received {len(data)} bytes") - if data: - buffer += data - else: - print("Received empty stream, closing connection") - break + print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})") + data = await socket.receive(bufsize) + print(f"Received {len(data)} bytes") + if data: + buffer += data + else: + print("Received empty stream, closing connection") + break print(f"Request has{' not' if not p.timed_out else ''} timed out!") if buffer: data = buffer.decode().split("\r\n") @@ -70,4 +74,3 @@ async def test(host: str, port: int, bufsize: int = 4096): giambio.run(test, "google.com", 443, 256, debugger=()) - diff --git a/tests/timeout.py b/tests/timeout.py index 01d4e50..f697fd7 100644 --- a/tests/timeout.py +++ b/tests/timeout.py @@ -13,7 +13,7 @@ async def main(): try: async with giambio.with_timeout(12) as pool: await pool.spawn(child, 7) # This will complete - await giambio.sleep(2) # This will make the code below wait 2 seconds + await giambio.sleep(2) # This will make the code below wait 2 seconds await pool.spawn(child, 15) # This will not complete await giambio.sleep(50) await child(20) # Neither will this diff --git a/tests/timeout3.py b/tests/timeout3.py index a8496d6..0c31057 100644 --- a/tests/timeout3.py +++ b/tests/timeout3.py @@ -14,8 +14,8 @@ async def main(): try: async with giambio.with_timeout(5) as pool: task = await pool.spawn(child, 2) - print(await task.join()) - await giambio.sleep(5) + print(f"Child has returned: {await task.join()}") + await giambio.sleep(5) # This will trigger the timeout except giambio.exceptions.TooSlowError: print("[main] One or more children have timed out!") print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")