From 3f0daece7ef5f056799b8c0450f3cb9827acac0c Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Wed, 10 May 2023 12:05:33 +0200 Subject: [PATCH] Initial work on Ctrl+C safety + added a message handler to the chatroom --- aiosched/kernel.py | 25 ++++---- aiosched/sigint.py | 29 --------- aiosched/util/sigint.py | 125 +++++++++++++++++++++++++++++++++++++++ tests/chatroom_server.py | 27 ++++++++- 4 files changed, 165 insertions(+), 41 deletions(-) delete mode 100644 aiosched/sigint.py create mode 100644 aiosched/util/sigint.py diff --git a/aiosched/kernel.py b/aiosched/kernel.py index 013c7dc..d6bc2fa 100644 --- a/aiosched/kernel.py +++ b/aiosched/kernel.py @@ -33,8 +33,9 @@ from aiosched.errors import ( ResourceBroken, ) from aiosched.context import TaskPool, TaskScope -from aiosched.sigint import CTRLC_PROTECTION_ENABLED +from aiosched.util.sigint import CTRLC_PROTECTION_ENABLED, currently_protected, enable_ki_protection from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE +from types import FrameType class FIFOKernel: @@ -118,16 +119,17 @@ class FIFOKernel: ) return f"{type(self).__name__}({data})" - def _sigint_handler(self, _sig: int, frame): + def _sigint_handler(self, _sig: int, _frame: FrameType): """ Handles SIGINT :return: """ - self._sigint_handled = True - # Poke the event loop with a stick ;) - self.reschedule_running() + if currently_protected(): + self._sigint_handled = True + else: + raise KeyboardInterrupt def done(self) -> bool: """ @@ -146,6 +148,8 @@ class FIFOKernel: # waiting on. This avoids issues such as the event loop never exiting if the # user forgets to close a socket, for example return False + if self.current_task: + return self.current_task.done() return True def close(self, force: bool = False): @@ -346,14 +350,14 @@ class FIFOKernel: # We perform the deferred cancellation # if it was previously scheduled self.cancel(self.current_task) - if self._sigint_handled: - self._sigint_handled = False - _runner = partial(self.current_task.throw, KeyboardInterrupt()) - _data = [] # Some debugging and internal chatter here self.current_task.steps += 1 self.current_task.state = TaskState.RUN self.debugger.before_task_step(self.current_task) + if self._sigint_handled: + self._sigint_handled = False + _runner = partial(self.current_task.throw, KeyboardInterrupt) + _data = [] # Run a single step with the calculation (i.e. until a yield # somewhere) method, args, kwargs = _runner(*_data) @@ -387,7 +391,7 @@ class FIFOKernel: Undoes any modification made by setup() """ - if signal.getsignal(signal.SIGINT) is self._sigint_handled: + if signal.getsignal(signal.SIGINT) is self._sigint_handler: signal.signal(signal.SIGINT, signal.default_int_handler) def run(self): @@ -442,6 +446,7 @@ class FIFOKernel: # Otherwise, while there are tasks ready to run, we run them! self.handle_errors(self.run_task_step) + @enable_ki_protection def start( self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs ) -> Any: diff --git a/aiosched/sigint.py b/aiosched/sigint.py deleted file mode 100644 index 83085c2..0000000 --- a/aiosched/sigint.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -aiosched: Yet another Python async scheduler - -Copyright (C) 2022 nocturn9x - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https:www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -# Special magic module half-stolen from Trio (thanks njsmith I love you) -# that makes Ctrl+C work. P.S.: Please Python, get your signals straight. - - -# Just a funny variable name that is not a valid -# identifier (but still a string so tools that hack -# into frames don't freak out when they look at the -# local variables) which will get injected silently -# into every frame to enable/disable the safeguards -# for Ctrl+C/KeyboardInterrupt -CTRLC_PROTECTION_ENABLED = "|yes-it-is|" - diff --git a/aiosched/util/sigint.py b/aiosched/util/sigint.py new file mode 100644 index 0000000..0e0ce46 --- /dev/null +++ b/aiosched/util/sigint.py @@ -0,0 +1,125 @@ +""" +aiosched: Yet another Python async scheduler + +Copyright (C) 2022 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https:www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import sys +import inspect +from functools import wraps +from typing import Callable +from types import FrameType + + +# Special magic module half-stolen from Trio (thanks njsmith I love you) +# that makes Ctrl+C work. P.S.: Please Python, get your signals straight. + + +# Just a funny variable name that is not a valid +# identifier (but still a string so tools that hack +# into frames don't freak out when they look at the +# local variables) which will get injected silently +# into every frame to enable/disable the safeguards +# for Ctrl+C/KeyboardInterrupt +CTRLC_PROTECTION_ENABLED = "|yes-it-is|" + + +def critical_section(frame: FrameType) -> bool: + """ + Returns whether Ctrl+C protection is currently + enabled in the given frame or in any of its children. + Stolen from Trio + """ + + while frame is not None: + if CTRLC_PROTECTION_ENABLED in frame.f_locals: + return frame.f_locals[CTRLC_PROTECTION_ENABLED] + if frame.f_code.co_name == "__del__": + return True + frame = frame.f_back + return True + + +def currently_protected() -> bool: + """ + Returns whether Ctrl+C protection is currently + enabled in the current context + """ + + return critical_section(sys._getframe()) + + +def legacy_isasyncgenfunction(obj): + return getattr(obj, "_async_gen_function", None) == id(obj) + + +def _ki_protection_decorator(enabled): + def decorator(fn): + # In some version of Python, isgeneratorfunction returns true for + # coroutine functions, so we have to check for coroutine functions + # first. + if inspect.iscoroutinefunction(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + # See the comment for regular generators below + coro = fn(*args, **kwargs) + coro.cr_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled + return coro + + return wrapper + elif inspect.isgeneratorfunction(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + # It's important that we inject this directly into the + # generator's locals, as opposed to setting it here and then + # doing 'yield from'. The reason is, if a generator is + # throw()n into, then it may magically pop to the top of the + # stack. And @contextmanager generators in particular are a + # case where we often want KI protection, and which are often + # thrown into! See: + # https://bugs.python.org/issue29590 + gen = fn(*args, **kwargs) + gen.gi_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled + return gen + + return wrapper + elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + # See the comment for regular generators above + agen = fn(*args, **kwargs) + agen.ag_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled + return agen + + return wrapper + else: + + @wraps(fn) + def wrapper(*args, **kwargs): + locals()[CTRLC_PROTECTION_ENABLED] = enabled + return fn(*args, **kwargs) + + return wrapper + + return decorator + + +enable_ki_protection = _ki_protection_decorator(True) +enable_ki_protection.__name__ = "enable_ki_protection" + +disable_ki_protection = _ki_protection_decorator(False) +disable_ki_protection.__name__ = "disable_ki_protection" diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index 3846eab..2de28aa 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -8,6 +8,21 @@ clients: dict[aiosched.socket.AsyncSocket, list[str, str]] = {} names: set[str] = set() +async def message_handler(q: aiosched.Queue): + """ + Reads data submitted onto the queue + """ + + try: + logging.info("Message handler spawned") + while True: + msg, payload = await q.get() + logging.info(f"Got message {msg!r} with the following payload: {payload}") + except Exception as e: + logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}") + raise + + async def serve(bind_address: tuple): """ Serves asynchronously forever (or until Ctrl+C ;)) @@ -17,23 +32,26 @@ async def serve(bind_address: tuple): """ sock = aiosched.socket.socket() + queue = aiosched.Queue() await sock.bind(bind_address) await sock.listen(5) logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") async with aiosched.create_pool() as pool: + await pool.spawn(message_handler, queue) async with sock: while True: try: conn, address_tuple = await sock.accept() clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"] + await queue.put(("connect", clients[conn])) logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - await pool.spawn(handler, conn) + await pool.spawn(handler, conn, queue) except Exception as err: # Because exceptions just *work* logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") -async def handler(sock: aiosched.socket.AsyncSocket): +async def handler(sock: aiosched.socket.AsyncSocket, q: aiosched.Queue): """ Handles a single client connection @@ -59,6 +77,7 @@ async def handler(sock: aiosched.socket.AsyncSocket): else: await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ") await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode()) + await q.put(("join", (address, name))) logging.info(f"{name} has joined the chatroom ({address}), informing clients") for i, client_sock in enumerate(clients): if client_sock != sock and clients[client_sock][0]: @@ -71,6 +90,7 @@ async def handler(sock: aiosched.socket.AsyncSocket): decoded = data.decode().rstrip("\n") if decoded.startswith("/"): logging.info(f"{name} issued server command {decoded}") + await q.put(("cmd", (name, decoded[1:]))) match decoded[1:]: case "bye": await sock.send_all(b"Bye!\n") @@ -78,6 +98,7 @@ async def handler(sock: aiosched.socket.AsyncSocket): case _: await sock.send_all(b"Unknown command\n") else: + await q.put(("msg", (name, data))) logging.info(f"Got: {data!r} from {address}") for i, client_sock in enumerate(clients): if client_sock != sock and clients[client_sock][0]: @@ -86,6 +107,7 @@ async def handler(sock: aiosched.socket.AsyncSocket): data += b"\n" await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode()) logging.info(f"Sent {data!r} to {i} clients") + await q.put(("leave", name)) logging.info(f"Connection from {address} closed") logging.info(f"{name} has left the chatroom ({address}), informing clients") for i, client_sock in enumerate(clients): @@ -109,4 +131,5 @@ if __name__ == "__main__": if isinstance(error, KeyboardInterrupt): logging.info("Ctrl+C detected, exiting") else: + raise logging.error(f"Exiting due to a {type(error).__name__}: {error}")