Initial work on Ctrl+C safety + added a message handler to the chatroom

This commit is contained in:
Nocturn9x 2023-05-10 12:05:33 +02:00
parent 6524fafa7c
commit 3f0daece7e
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
4 changed files with 165 additions and 41 deletions

View File

@ -33,8 +33,9 @@ from aiosched.errors import (
ResourceBroken,
)
from aiosched.context import TaskPool, TaskScope
from aiosched.sigint import CTRLC_PROTECTION_ENABLED
from aiosched.util.sigint import CTRLC_PROTECTION_ENABLED, currently_protected, enable_ki_protection
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
from types import FrameType
class FIFOKernel:
@ -118,16 +119,17 @@ class FIFOKernel:
)
return f"{type(self).__name__}({data})"
def _sigint_handler(self, _sig: int, frame):
def _sigint_handler(self, _sig: int, _frame: FrameType):
"""
Handles SIGINT
:return:
"""
self._sigint_handled = True
# Poke the event loop with a stick ;)
self.reschedule_running()
if currently_protected():
self._sigint_handled = True
else:
raise KeyboardInterrupt
def done(self) -> bool:
"""
@ -146,6 +148,8 @@ class FIFOKernel:
# waiting on. This avoids issues such as the event loop never exiting if the
# user forgets to close a socket, for example
return False
if self.current_task:
return self.current_task.done()
return True
def close(self, force: bool = False):
@ -346,14 +350,14 @@ class FIFOKernel:
# We perform the deferred cancellation
# if it was previously scheduled
self.cancel(self.current_task)
if self._sigint_handled:
self._sigint_handled = False
_runner = partial(self.current_task.throw, KeyboardInterrupt())
_data = []
# Some debugging and internal chatter here
self.current_task.steps += 1
self.current_task.state = TaskState.RUN
self.debugger.before_task_step(self.current_task)
if self._sigint_handled:
self._sigint_handled = False
_runner = partial(self.current_task.throw, KeyboardInterrupt)
_data = []
# Run a single step with the calculation (i.e. until a yield
# somewhere)
method, args, kwargs = _runner(*_data)
@ -387,7 +391,7 @@ class FIFOKernel:
Undoes any modification made by setup()
"""
if signal.getsignal(signal.SIGINT) is self._sigint_handled:
if signal.getsignal(signal.SIGINT) is self._sigint_handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
def run(self):
@ -442,6 +446,7 @@ class FIFOKernel:
# Otherwise, while there are tasks ready to run, we run them!
self.handle_errors(self.run_task_step)
@enable_ki_protection
def start(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
) -> Any:

View File

@ -1,29 +0,0 @@
"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https:www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Special magic module half-stolen from Trio (thanks njsmith I love you)
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
# Just a funny variable name that is not a valid
# identifier (but still a string so tools that hack
# into frames don't freak out when they look at the
# local variables) which will get injected silently
# into every frame to enable/disable the safeguards
# for Ctrl+C/KeyboardInterrupt
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"

125
aiosched/util/sigint.py Normal file
View File

@ -0,0 +1,125 @@
"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https:www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import sys
import inspect
from functools import wraps
from typing import Callable
from types import FrameType
# Special magic module half-stolen from Trio (thanks njsmith I love you)
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
# Just a funny variable name that is not a valid
# identifier (but still a string so tools that hack
# into frames don't freak out when they look at the
# local variables) which will get injected silently
# into every frame to enable/disable the safeguards
# for Ctrl+C/KeyboardInterrupt
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
def critical_section(frame: FrameType) -> bool:
"""
Returns whether Ctrl+C protection is currently
enabled in the given frame or in any of its children.
Stolen from Trio
"""
while frame is not None:
if CTRLC_PROTECTION_ENABLED in frame.f_locals:
return frame.f_locals[CTRLC_PROTECTION_ENABLED]
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
def currently_protected() -> bool:
"""
Returns whether Ctrl+C protection is currently
enabled in the current context
"""
return critical_section(sys._getframe())
def legacy_isasyncgenfunction(obj):
return getattr(obj, "_async_gen_function", None) == id(obj)
def _ki_protection_decorator(enabled):
def decorator(fn):
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators below
coro = fn(*args, **kwargs)
coro.cr_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
return coro
return wrapper
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
# throw()n into, then it may magically pop to the top of the
# stack. And @contextmanager generators in particular are a
# case where we often want KI protection, and which are often
# thrown into! See:
# https://bugs.python.org/issue29590
gen = fn(*args, **kwargs)
gen.gi_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
return gen
return wrapper
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
return agen
return wrapper
else:
@wraps(fn)
def wrapper(*args, **kwargs):
locals()[CTRLC_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
return wrapper
return decorator
enable_ki_protection = _ki_protection_decorator(True)
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection = _ki_protection_decorator(False)
disable_ki_protection.__name__ = "disable_ki_protection"

View File

@ -8,6 +8,21 @@ clients: dict[aiosched.socket.AsyncSocket, list[str, str]] = {}
names: set[str] = set()
async def message_handler(q: aiosched.Queue):
"""
Reads data submitted onto the queue
"""
try:
logging.info("Message handler spawned")
while True:
msg, payload = await q.get()
logging.info(f"Got message {msg!r} with the following payload: {payload}")
except Exception as e:
logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
raise
async def serve(bind_address: tuple):
"""
Serves asynchronously forever (or until Ctrl+C ;))
@ -17,23 +32,26 @@ async def serve(bind_address: tuple):
"""
sock = aiosched.socket.socket()
queue = aiosched.Queue()
await sock.bind(bind_address)
await sock.listen(5)
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with aiosched.create_pool() as pool:
await pool.spawn(message_handler, queue)
async with sock:
while True:
try:
conn, address_tuple = await sock.accept()
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
await queue.put(("connect", clients[conn]))
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn)
await pool.spawn(handler, conn, queue)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: aiosched.socket.AsyncSocket):
async def handler(sock: aiosched.socket.AsyncSocket, q: aiosched.Queue):
"""
Handles a single client connection
@ -59,6 +77,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
else:
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
await q.put(("join", (address, name)))
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
@ -71,6 +90,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
decoded = data.decode().rstrip("\n")
if decoded.startswith("/"):
logging.info(f"{name} issued server command {decoded}")
await q.put(("cmd", (name, decoded[1:])))
match decoded[1:]:
case "bye":
await sock.send_all(b"Bye!\n")
@ -78,6 +98,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
case _:
await sock.send_all(b"Unknown command\n")
else:
await q.put(("msg", (name, data)))
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
@ -86,6 +107,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
await q.put(("leave", name))
logging.info(f"Connection from {address} closed")
logging.info(f"{name} has left the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
@ -109,4 +131,5 @@ if __name__ == "__main__":
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")
else:
raise
logging.error(f"Exiting due to a {type(error).__name__}: {error}")