Initial work on Ctrl+C safety + added a message handler to the chatroom
This commit is contained in:
parent
6524fafa7c
commit
3f0daece7e
|
@ -33,8 +33,9 @@ from aiosched.errors import (
|
|||
ResourceBroken,
|
||||
)
|
||||
from aiosched.context import TaskPool, TaskScope
|
||||
from aiosched.sigint import CTRLC_PROTECTION_ENABLED
|
||||
from aiosched.util.sigint import CTRLC_PROTECTION_ENABLED, currently_protected, enable_ki_protection
|
||||
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
|
||||
from types import FrameType
|
||||
|
||||
|
||||
class FIFOKernel:
|
||||
|
@ -118,16 +119,17 @@ class FIFOKernel:
|
|||
)
|
||||
return f"{type(self).__name__}({data})"
|
||||
|
||||
def _sigint_handler(self, _sig: int, frame):
|
||||
def _sigint_handler(self, _sig: int, _frame: FrameType):
|
||||
"""
|
||||
Handles SIGINT
|
||||
|
||||
:return:
|
||||
"""
|
||||
|
||||
self._sigint_handled = True
|
||||
# Poke the event loop with a stick ;)
|
||||
self.reschedule_running()
|
||||
if currently_protected():
|
||||
self._sigint_handled = True
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
|
@ -146,6 +148,8 @@ class FIFOKernel:
|
|||
# waiting on. This avoids issues such as the event loop never exiting if the
|
||||
# user forgets to close a socket, for example
|
||||
return False
|
||||
if self.current_task:
|
||||
return self.current_task.done()
|
||||
return True
|
||||
|
||||
def close(self, force: bool = False):
|
||||
|
@ -346,14 +350,14 @@ class FIFOKernel:
|
|||
# We perform the deferred cancellation
|
||||
# if it was previously scheduled
|
||||
self.cancel(self.current_task)
|
||||
if self._sigint_handled:
|
||||
self._sigint_handled = False
|
||||
_runner = partial(self.current_task.throw, KeyboardInterrupt())
|
||||
_data = []
|
||||
# Some debugging and internal chatter here
|
||||
self.current_task.steps += 1
|
||||
self.current_task.state = TaskState.RUN
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
if self._sigint_handled:
|
||||
self._sigint_handled = False
|
||||
_runner = partial(self.current_task.throw, KeyboardInterrupt)
|
||||
_data = []
|
||||
# Run a single step with the calculation (i.e. until a yield
|
||||
# somewhere)
|
||||
method, args, kwargs = _runner(*_data)
|
||||
|
@ -387,7 +391,7 @@ class FIFOKernel:
|
|||
Undoes any modification made by setup()
|
||||
"""
|
||||
|
||||
if signal.getsignal(signal.SIGINT) is self._sigint_handled:
|
||||
if signal.getsignal(signal.SIGINT) is self._sigint_handler:
|
||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
|
||||
def run(self):
|
||||
|
@ -442,6 +446,7 @@ class FIFOKernel:
|
|||
# Otherwise, while there are tasks ready to run, we run them!
|
||||
self.handle_errors(self.run_task_step)
|
||||
|
||||
@enable_ki_protection
|
||||
def start(
|
||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||
) -> Any:
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
"""
|
||||
aiosched: Yet another Python async scheduler
|
||||
|
||||
Copyright (C) 2022 nocturn9x
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https:www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
||||
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
||||
|
||||
|
||||
# Just a funny variable name that is not a valid
|
||||
# identifier (but still a string so tools that hack
|
||||
# into frames don't freak out when they look at the
|
||||
# local variables) which will get injected silently
|
||||
# into every frame to enable/disable the safeguards
|
||||
# for Ctrl+C/KeyboardInterrupt
|
||||
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""
|
||||
aiosched: Yet another Python async scheduler
|
||||
|
||||
Copyright (C) 2022 nocturn9x
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https:www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import sys
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
from types import FrameType
|
||||
|
||||
|
||||
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
||||
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
||||
|
||||
|
||||
# Just a funny variable name that is not a valid
|
||||
# identifier (but still a string so tools that hack
|
||||
# into frames don't freak out when they look at the
|
||||
# local variables) which will get injected silently
|
||||
# into every frame to enable/disable the safeguards
|
||||
# for Ctrl+C/KeyboardInterrupt
|
||||
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
||||
|
||||
|
||||
def critical_section(frame: FrameType) -> bool:
|
||||
"""
|
||||
Returns whether Ctrl+C protection is currently
|
||||
enabled in the given frame or in any of its children.
|
||||
Stolen from Trio
|
||||
"""
|
||||
|
||||
while frame is not None:
|
||||
if CTRLC_PROTECTION_ENABLED in frame.f_locals:
|
||||
return frame.f_locals[CTRLC_PROTECTION_ENABLED]
|
||||
if frame.f_code.co_name == "__del__":
|
||||
return True
|
||||
frame = frame.f_back
|
||||
return True
|
||||
|
||||
|
||||
def currently_protected() -> bool:
|
||||
"""
|
||||
Returns whether Ctrl+C protection is currently
|
||||
enabled in the current context
|
||||
"""
|
||||
|
||||
return critical_section(sys._getframe())
|
||||
|
||||
|
||||
def legacy_isasyncgenfunction(obj):
|
||||
return getattr(obj, "_async_gen_function", None) == id(obj)
|
||||
|
||||
|
||||
def _ki_protection_decorator(enabled):
|
||||
def decorator(fn):
|
||||
# In some version of Python, isgeneratorfunction returns true for
|
||||
# coroutine functions, so we have to check for coroutine functions
|
||||
# first.
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# See the comment for regular generators below
|
||||
coro = fn(*args, **kwargs)
|
||||
coro.cr_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return coro
|
||||
|
||||
return wrapper
|
||||
elif inspect.isgeneratorfunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# It's important that we inject this directly into the
|
||||
# generator's locals, as opposed to setting it here and then
|
||||
# doing 'yield from'. The reason is, if a generator is
|
||||
# throw()n into, then it may magically pop to the top of the
|
||||
# stack. And @contextmanager generators in particular are a
|
||||
# case where we often want KI protection, and which are often
|
||||
# thrown into! See:
|
||||
# https://bugs.python.org/issue29590
|
||||
gen = fn(*args, **kwargs)
|
||||
gen.gi_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return gen
|
||||
|
||||
return wrapper
|
||||
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# See the comment for regular generators above
|
||||
agen = fn(*args, **kwargs)
|
||||
agen.ag_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return agen
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
locals()[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
enable_ki_protection = _ki_protection_decorator(True)
|
||||
enable_ki_protection.__name__ = "enable_ki_protection"
|
||||
|
||||
disable_ki_protection = _ki_protection_decorator(False)
|
||||
disable_ki_protection.__name__ = "disable_ki_protection"
|
|
@ -8,6 +8,21 @@ clients: dict[aiosched.socket.AsyncSocket, list[str, str]] = {}
|
|||
names: set[str] = set()
|
||||
|
||||
|
||||
async def message_handler(q: aiosched.Queue):
|
||||
"""
|
||||
Reads data submitted onto the queue
|
||||
"""
|
||||
|
||||
try:
|
||||
logging.info("Message handler spawned")
|
||||
while True:
|
||||
msg, payload = await q.get()
|
||||
logging.info(f"Got message {msg!r} with the following payload: {payload}")
|
||||
except Exception as e:
|
||||
logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
"""
|
||||
Serves asynchronously forever (or until Ctrl+C ;))
|
||||
|
@ -17,23 +32,26 @@ async def serve(bind_address: tuple):
|
|||
"""
|
||||
|
||||
sock = aiosched.socket.socket()
|
||||
queue = aiosched.Queue()
|
||||
await sock.bind(bind_address)
|
||||
await sock.listen(5)
|
||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with aiosched.create_pool() as pool:
|
||||
await pool.spawn(message_handler, queue)
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||
await queue.put(("connect", clients[conn]))
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn)
|
||||
await pool.spawn(handler, conn, queue)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: aiosched.socket.AsyncSocket):
|
||||
async def handler(sock: aiosched.socket.AsyncSocket, q: aiosched.Queue):
|
||||
"""
|
||||
Handles a single client connection
|
||||
|
||||
|
@ -59,6 +77,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
|||
else:
|
||||
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
||||
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
||||
await q.put(("join", (address, name)))
|
||||
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
|
@ -71,6 +90,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
|||
decoded = data.decode().rstrip("\n")
|
||||
if decoded.startswith("/"):
|
||||
logging.info(f"{name} issued server command {decoded}")
|
||||
await q.put(("cmd", (name, decoded[1:])))
|
||||
match decoded[1:]:
|
||||
case "bye":
|
||||
await sock.send_all(b"Bye!\n")
|
||||
|
@ -78,6 +98,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
|||
case _:
|
||||
await sock.send_all(b"Unknown command\n")
|
||||
else:
|
||||
await q.put(("msg", (name, data)))
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
|
@ -86,6 +107,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
|||
data += b"\n"
|
||||
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
||||
logging.info(f"Sent {data!r} to {i} clients")
|
||||
await q.put(("leave", name))
|
||||
logging.info(f"Connection from {address} closed")
|
||||
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
|
@ -109,4 +131,5 @@ if __name__ == "__main__":
|
|||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
raise
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||
|
|
Reference in New Issue