Initial work on Ctrl+C safety + added a message handler to the chatroom
This commit is contained in:
parent
6524fafa7c
commit
3f0daece7e
|
@ -33,8 +33,9 @@ from aiosched.errors import (
|
||||||
ResourceBroken,
|
ResourceBroken,
|
||||||
)
|
)
|
||||||
from aiosched.context import TaskPool, TaskScope
|
from aiosched.context import TaskPool, TaskScope
|
||||||
from aiosched.sigint import CTRLC_PROTECTION_ENABLED
|
from aiosched.util.sigint import CTRLC_PROTECTION_ENABLED, currently_protected, enable_ki_protection
|
||||||
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
|
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
|
||||||
|
from types import FrameType
|
||||||
|
|
||||||
|
|
||||||
class FIFOKernel:
|
class FIFOKernel:
|
||||||
|
@ -118,16 +119,17 @@ class FIFOKernel:
|
||||||
)
|
)
|
||||||
return f"{type(self).__name__}({data})"
|
return f"{type(self).__name__}({data})"
|
||||||
|
|
||||||
def _sigint_handler(self, _sig: int, frame):
|
def _sigint_handler(self, _sig: int, _frame: FrameType):
|
||||||
"""
|
"""
|
||||||
Handles SIGINT
|
Handles SIGINT
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._sigint_handled = True
|
if currently_protected():
|
||||||
# Poke the event loop with a stick ;)
|
self._sigint_handled = True
|
||||||
self.reschedule_running()
|
else:
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -146,6 +148,8 @@ class FIFOKernel:
|
||||||
# waiting on. This avoids issues such as the event loop never exiting if the
|
# waiting on. This avoids issues such as the event loop never exiting if the
|
||||||
# user forgets to close a socket, for example
|
# user forgets to close a socket, for example
|
||||||
return False
|
return False
|
||||||
|
if self.current_task:
|
||||||
|
return self.current_task.done()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def close(self, force: bool = False):
|
def close(self, force: bool = False):
|
||||||
|
@ -346,14 +350,14 @@ class FIFOKernel:
|
||||||
# We perform the deferred cancellation
|
# We perform the deferred cancellation
|
||||||
# if it was previously scheduled
|
# if it was previously scheduled
|
||||||
self.cancel(self.current_task)
|
self.cancel(self.current_task)
|
||||||
if self._sigint_handled:
|
|
||||||
self._sigint_handled = False
|
|
||||||
_runner = partial(self.current_task.throw, KeyboardInterrupt())
|
|
||||||
_data = []
|
|
||||||
# Some debugging and internal chatter here
|
# Some debugging and internal chatter here
|
||||||
self.current_task.steps += 1
|
self.current_task.steps += 1
|
||||||
self.current_task.state = TaskState.RUN
|
self.current_task.state = TaskState.RUN
|
||||||
self.debugger.before_task_step(self.current_task)
|
self.debugger.before_task_step(self.current_task)
|
||||||
|
if self._sigint_handled:
|
||||||
|
self._sigint_handled = False
|
||||||
|
_runner = partial(self.current_task.throw, KeyboardInterrupt)
|
||||||
|
_data = []
|
||||||
# Run a single step with the calculation (i.e. until a yield
|
# Run a single step with the calculation (i.e. until a yield
|
||||||
# somewhere)
|
# somewhere)
|
||||||
method, args, kwargs = _runner(*_data)
|
method, args, kwargs = _runner(*_data)
|
||||||
|
@ -387,7 +391,7 @@ class FIFOKernel:
|
||||||
Undoes any modification made by setup()
|
Undoes any modification made by setup()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if signal.getsignal(signal.SIGINT) is self._sigint_handled:
|
if signal.getsignal(signal.SIGINT) is self._sigint_handler:
|
||||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
@ -442,6 +446,7 @@ class FIFOKernel:
|
||||||
# Otherwise, while there are tasks ready to run, we run them!
|
# Otherwise, while there are tasks ready to run, we run them!
|
||||||
self.handle_errors(self.run_task_step)
|
self.handle_errors(self.run_task_step)
|
||||||
|
|
||||||
|
@enable_ki_protection
|
||||||
def start(
|
def start(
|
||||||
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
|
@ -1,29 +0,0 @@
|
||||||
"""
|
|
||||||
aiosched: Yet another Python async scheduler
|
|
||||||
|
|
||||||
Copyright (C) 2022 nocturn9x
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
https:www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
|
||||||
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
|
||||||
|
|
||||||
|
|
||||||
# Just a funny variable name that is not a valid
|
|
||||||
# identifier (but still a string so tools that hack
|
|
||||||
# into frames don't freak out when they look at the
|
|
||||||
# local variables) which will get injected silently
|
|
||||||
# into every frame to enable/disable the safeguards
|
|
||||||
# for Ctrl+C/KeyboardInterrupt
|
|
||||||
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""
|
||||||
|
aiosched: Yet another Python async scheduler
|
||||||
|
|
||||||
|
Copyright (C) 2022 nocturn9x
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https:www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import inspect
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Callable
|
||||||
|
from types import FrameType
|
||||||
|
|
||||||
|
|
||||||
|
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
||||||
|
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
||||||
|
|
||||||
|
|
||||||
|
# Just a funny variable name that is not a valid
|
||||||
|
# identifier (but still a string so tools that hack
|
||||||
|
# into frames don't freak out when they look at the
|
||||||
|
# local variables) which will get injected silently
|
||||||
|
# into every frame to enable/disable the safeguards
|
||||||
|
# for Ctrl+C/KeyboardInterrupt
|
||||||
|
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
||||||
|
|
||||||
|
|
||||||
|
def critical_section(frame: FrameType) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether Ctrl+C protection is currently
|
||||||
|
enabled in the given frame or in any of its children.
|
||||||
|
Stolen from Trio
|
||||||
|
"""
|
||||||
|
|
||||||
|
while frame is not None:
|
||||||
|
if CTRLC_PROTECTION_ENABLED in frame.f_locals:
|
||||||
|
return frame.f_locals[CTRLC_PROTECTION_ENABLED]
|
||||||
|
if frame.f_code.co_name == "__del__":
|
||||||
|
return True
|
||||||
|
frame = frame.f_back
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def currently_protected() -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether Ctrl+C protection is currently
|
||||||
|
enabled in the current context
|
||||||
|
"""
|
||||||
|
|
||||||
|
return critical_section(sys._getframe())
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_isasyncgenfunction(obj):
|
||||||
|
return getattr(obj, "_async_gen_function", None) == id(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _ki_protection_decorator(enabled):
|
||||||
|
def decorator(fn):
|
||||||
|
# In some version of Python, isgeneratorfunction returns true for
|
||||||
|
# coroutine functions, so we have to check for coroutine functions
|
||||||
|
# first.
|
||||||
|
if inspect.iscoroutinefunction(fn):
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# See the comment for regular generators below
|
||||||
|
coro = fn(*args, **kwargs)
|
||||||
|
coro.cr_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||||
|
return coro
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
elif inspect.isgeneratorfunction(fn):
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# It's important that we inject this directly into the
|
||||||
|
# generator's locals, as opposed to setting it here and then
|
||||||
|
# doing 'yield from'. The reason is, if a generator is
|
||||||
|
# throw()n into, then it may magically pop to the top of the
|
||||||
|
# stack. And @contextmanager generators in particular are a
|
||||||
|
# case where we often want KI protection, and which are often
|
||||||
|
# thrown into! See:
|
||||||
|
# https://bugs.python.org/issue29590
|
||||||
|
gen = fn(*args, **kwargs)
|
||||||
|
gen.gi_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||||
|
return gen
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# See the comment for regular generators above
|
||||||
|
agen = fn(*args, **kwargs)
|
||||||
|
agen.ag_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||||
|
return agen
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
else:
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
locals()[CTRLC_PROTECTION_ENABLED] = enabled
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
enable_ki_protection = _ki_protection_decorator(True)
|
||||||
|
enable_ki_protection.__name__ = "enable_ki_protection"
|
||||||
|
|
||||||
|
disable_ki_protection = _ki_protection_decorator(False)
|
||||||
|
disable_ki_protection.__name__ = "disable_ki_protection"
|
|
@ -8,6 +8,21 @@ clients: dict[aiosched.socket.AsyncSocket, list[str, str]] = {}
|
||||||
names: set[str] = set()
|
names: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
async def message_handler(q: aiosched.Queue):
|
||||||
|
"""
|
||||||
|
Reads data submitted onto the queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info("Message handler spawned")
|
||||||
|
while True:
|
||||||
|
msg, payload = await q.get()
|
||||||
|
logging.info(f"Got message {msg!r} with the following payload: {payload}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def serve(bind_address: tuple):
|
async def serve(bind_address: tuple):
|
||||||
"""
|
"""
|
||||||
Serves asynchronously forever (or until Ctrl+C ;))
|
Serves asynchronously forever (or until Ctrl+C ;))
|
||||||
|
@ -17,23 +32,26 @@ async def serve(bind_address: tuple):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sock = aiosched.socket.socket()
|
sock = aiosched.socket.socket()
|
||||||
|
queue = aiosched.Queue()
|
||||||
await sock.bind(bind_address)
|
await sock.bind(bind_address)
|
||||||
await sock.listen(5)
|
await sock.listen(5)
|
||||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||||
async with aiosched.create_pool() as pool:
|
async with aiosched.create_pool() as pool:
|
||||||
|
await pool.spawn(message_handler, queue)
|
||||||
async with sock:
|
async with sock:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
conn, address_tuple = await sock.accept()
|
conn, address_tuple = await sock.accept()
|
||||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||||
|
await queue.put(("connect", clients[conn]))
|
||||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||||
await pool.spawn(handler, conn)
|
await pool.spawn(handler, conn, queue)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
# Because exceptions just *work*
|
# Because exceptions just *work*
|
||||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||||
|
|
||||||
|
|
||||||
async def handler(sock: aiosched.socket.AsyncSocket):
|
async def handler(sock: aiosched.socket.AsyncSocket, q: aiosched.Queue):
|
||||||
"""
|
"""
|
||||||
Handles a single client connection
|
Handles a single client connection
|
||||||
|
|
||||||
|
@ -59,6 +77,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
else:
|
else:
|
||||||
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
||||||
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
||||||
|
await q.put(("join", (address, name)))
|
||||||
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
||||||
for i, client_sock in enumerate(clients):
|
for i, client_sock in enumerate(clients):
|
||||||
if client_sock != sock and clients[client_sock][0]:
|
if client_sock != sock and clients[client_sock][0]:
|
||||||
|
@ -71,6 +90,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
decoded = data.decode().rstrip("\n")
|
decoded = data.decode().rstrip("\n")
|
||||||
if decoded.startswith("/"):
|
if decoded.startswith("/"):
|
||||||
logging.info(f"{name} issued server command {decoded}")
|
logging.info(f"{name} issued server command {decoded}")
|
||||||
|
await q.put(("cmd", (name, decoded[1:])))
|
||||||
match decoded[1:]:
|
match decoded[1:]:
|
||||||
case "bye":
|
case "bye":
|
||||||
await sock.send_all(b"Bye!\n")
|
await sock.send_all(b"Bye!\n")
|
||||||
|
@ -78,6 +98,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
case _:
|
case _:
|
||||||
await sock.send_all(b"Unknown command\n")
|
await sock.send_all(b"Unknown command\n")
|
||||||
else:
|
else:
|
||||||
|
await q.put(("msg", (name, data)))
|
||||||
logging.info(f"Got: {data!r} from {address}")
|
logging.info(f"Got: {data!r} from {address}")
|
||||||
for i, client_sock in enumerate(clients):
|
for i, client_sock in enumerate(clients):
|
||||||
if client_sock != sock and clients[client_sock][0]:
|
if client_sock != sock and clients[client_sock][0]:
|
||||||
|
@ -86,6 +107,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
data += b"\n"
|
data += b"\n"
|
||||||
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
||||||
logging.info(f"Sent {data!r} to {i} clients")
|
logging.info(f"Sent {data!r} to {i} clients")
|
||||||
|
await q.put(("leave", name))
|
||||||
logging.info(f"Connection from {address} closed")
|
logging.info(f"Connection from {address} closed")
|
||||||
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
||||||
for i, client_sock in enumerate(clients):
|
for i, client_sock in enumerate(clients):
|
||||||
|
@ -109,4 +131,5 @@ if __name__ == "__main__":
|
||||||
if isinstance(error, KeyboardInterrupt):
|
if isinstance(error, KeyboardInterrupt):
|
||||||
logging.info("Ctrl+C detected, exiting")
|
logging.info("Ctrl+C detected, exiting")
|
||||||
else:
|
else:
|
||||||
|
raise
|
||||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||||
|
|
Reference in New Issue