From ed6aba490fd087c27e3e4c2b7785862e897a6b74 Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Sun, 27 Feb 2022 18:14:12 +0100 Subject: [PATCH] Added two-way proxy example stolen from njsmith and fixed bug with io_release_task being fucking dumb --- giambio/context.py | 10 ++++++--- giambio/core.py | 41 +++++++++++++++++---------------- giambio/io.py | 2 ++ tests/proxy.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 22 deletions(-) create mode 100644 tests/proxy.py diff --git a/giambio/context.py b/giambio/context.py index f610bf4..c6fdd21 100644 --- a/giambio/context.py +++ b/giambio/context.py @@ -16,6 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from lib2to3.pgen2.token import OP import types import giambio from typing import List, Optional @@ -54,6 +55,7 @@ class TaskManager: self._proper_init = False self.enclosed_pool: Optional["giambio.context.TaskManager"] = None self.raise_on_timeout: bool = raise_on_timeout + self.entry_point: Optional[Task] = None async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task": """ @@ -70,6 +72,7 @@ class TaskManager: """ self._proper_init = True + self.entry_point = await giambio.traps.current_task() return self async def __aexit__(self, exc_type: Exception, exc: Exception, tb): @@ -89,8 +92,9 @@ class TaskManager: if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout: return True except giambio.exceptions.TooSlowError: - return True - + if not self.raise_on_timeout: + raise + async def cancel(self): """ Cancels the pool entirely, iterating over all @@ -108,4 +112,4 @@ class TaskManager: pool have exited, False otherwise """ - return self._proper_init and all([task.done() for task in self.tasks]) + return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done()) diff --git a/giambio/core.py b/giambio/core.py index 58058fb..d2de041 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -341,7 +341,7 @@ class AsyncScheduler: if self.selector.get_map(): for k in filter( - lambda o: o.data == self.current_task, + lambda o: o.data == task, dict(self.selector.get_map()).values(), ): self.io_release(k.fileobj) @@ -358,7 +358,7 @@ class AsyncScheduler: before it's due """ - if self.current_task.last_io: + if self.current_task.last_io or self.current_task.status == "io": self.io_release_task(self.current_task) self.current_task.status = "sleep" self.suspended.append(self.current_task) @@ -441,13 +441,11 @@ class AsyncScheduler: while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): pool = self.deadlines.get() pool.timed_out = True - if not pool.tasks and self.current_task is self.entry_point: - self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point))) + self.cancel_pool(pool) for task in pool.tasks: - if not task.done(): - self.paused.discard(task) - self.io_release_task(task) - self.handle_task_exit(task, partial(task.throw, TooSlowError(task))) + self.join(task) + self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point))) + def schedule_tasks(self, tasks: List[Task]): """ @@ -554,9 +552,12 @@ class AsyncScheduler: self.run_ready.append(entry) self.debugger.on_start() if loop: - self.run() - self.has_ran = True - self.debugger.on_exit() + try: + self.run() + finally: + self.has_ran = True + self.close() + self.debugger.on_exit() def cancel_pool(self, pool: TaskManager) -> bool: """ @@ -729,8 +730,8 @@ class AsyncScheduler: self.io_release_task(task) elif task.status == "sleep": self.paused.discard(task) - if task in self.suspended: - self.suspended.remove(task) + if task in self.suspended: + self.suspended.remove(task) try: self.do_cancel(task) except CancelledError as cancel: @@ -747,7 +748,6 @@ class AsyncScheduler: task.cancel_pending = False task.cancelled = True task.status = "cancelled" - self.io_release_task(self.current_task) self.debugger.after_cancel(task) self.tasks.remove(task) else: @@ -758,12 +758,12 @@ class AsyncScheduler: def register_sock(self, sock, evt_type: str): """ Registers the given socket inside the - selector to perform I/0 multiplexing + selector to perform I/O multiplexing :param sock: The socket on which a read or write operation - has to be performed + has to be performed :param evt_type: The type of event to perform on the given - socket, either "read" or "write" + socket, either "read" or "write" :type evt_type: str """ @@ -797,5 +797,8 @@ class AsyncScheduler: try: self.selector.register(sock, evt, self.current_task) except KeyError: - # The socket is already registered doing something else - raise ResourceBusy("The given socket is being read/written by another task") from None + # The socket is already registered doing something else, we + # modify the socket instead (or maybe not?) + self.selector.modify(sock, evt, self.current_task) + # TODO: Does this break stuff? + # raise ResourceBusy("The given socket is being read/written by another task") from None diff --git a/giambio/io.py b/giambio/io.py index a5cb425..95ff009 100644 --- a/giambio/io.py +++ b/giambio/io.py @@ -16,6 +16,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import giambio from giambio.exceptions import ResourceClosed from giambio.traps import want_write, want_read, io_release @@ -121,6 +122,7 @@ class AsyncSocket: if self.sock: self.sock.shutdown(how) + await giambio.sleep(0) # Checkpoint async def bind(self, addr: tuple): """ diff --git a/tests/proxy.py b/tests/proxy.py new file mode 100644 index 0000000..8d5ff8d --- /dev/null +++ b/tests/proxy.py @@ -0,0 +1,56 @@ +from debugger import Debugger +import giambio +import socket + + +async def proxy_one_way(source: giambio.socket.AsyncSocket, sink: giambio.socket.AsyncSocket): + """ + Sends data from source to sink + """ + + sink_addr = ":".join(map(str, await sink.getpeername())) + source_addr = ":".join(map(str, await source.getpeername())) + while True: + data = await source.receive(1024) + if not data: + print(f"{source_addr} has exited, closing connection to {sink_addr}") + await sink.shutdown(socket.SHUT_WR) + break + print(f"Got {data.decode('utf8', errors='ignore')!r} from {source_addr}, forwarding it to {sink_addr}") + await sink.send_all(data) + + +async def proxy_two_way(a: giambio.socket.AsyncSocket, b: giambio.socket.AsyncSocket): + """ + Sets up a two-way proxy from a to b and from b to a + """ + + async with giambio.create_pool() as pool: + await pool.spawn(proxy_one_way, a, b) + await pool.spawn(proxy_one_way, b, a) + + +async def main(delay: int, a: tuple, b: tuple): + """ + Sets up the proxy + """ + + start = giambio.clock() + print(f"Starting two-way proxy from {a[0]}:{a[1]} to {b[0]}:{b[1]}, lasting for {delay} seconds") + async with giambio.skip_after(delay) as p: + sock_a = giambio.socket.socket() + sock_b = giambio.socket.socket() + await sock_a.connect(a) + await sock_b.connect(b) + async with sock_a, sock_b: + await proxy_two_way(sock_a, sock_b) + print(f"Proxy has exited after {giambio.clock() - start:.2f} seconds") + + +try: + giambio.run(main, 60, ("localhost", 12345), ("localhost", 54321), debugger=()) +except (Exception, KeyboardInterrupt) as error: # Exceptions propagate! + if isinstance(error, KeyboardInterrupt): + print("Ctrl+C detected, exiting") + else: + print(f"Exiting due to a {type(error).__name__}: {error}") \ No newline at end of file