Initial work on proper Ctrl+C handling. Minor fixes and additions

This commit is contained in:
Mattia Giambirtone 2023-05-01 14:41:05 +02:00
parent d10ae9c55b
commit a86e0afbbd
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
10 changed files with 213 additions and 99 deletions

View File

@ -39,23 +39,27 @@ class TaskScope:
self.silent = silent
self.inner: TaskScope | None = None
self.outer: TaskScope | None = None
self.pools: list[TaskPool] = list()
self.waiter: Task | None = None
self.entry_point: Task | None = None
self.timed_out: bool = False
# Can we be cancelled?
self.cancellable: bool = True
# Task scope of our timeout worker
self.timeout_scope: TaskScope = None
async def _timeout_worker(self):
await sleep(self.timeout)
for pool in self.pools:
if not pool.done():
async with TaskScope() as scope:
self.timeout_scope = scope
# We can't let this task be cancelled
# or interrupted by Ctrl+C because this
# is the only safeguard of our timeouts:
# if this crashes, then timeouts don't work
# at all!
scope.cancellable = False
await sleep(self.timeout)
if not self.entry_point.done():
self.timed_out = True
await pool.cancel()
if pool.entry_point is not self.entry_point:
await cancel(pool.entry_point, block=True)
if not self.entry_point.done():
self.timed_out = True
# raise TimeoutError("timed out")
await throw(self.entry_point, TimeoutError("timed out"))
await throw(self.entry_point, TimeoutError("timed out"))
async def __aenter__(self):
self.entry_point = await current_task()
@ -66,9 +70,15 @@ class TaskScope:
async def __aexit__(self, exc_type: type, exception: Exception, tb):
await close_scope(self)
if not self.waiter.done():
if self.timeout and not self.waiter.done():
# Well, looks like we finished before our worker.
# Thanks for your help! Now die.
self.timeout_scope.cancellable = True
await cancel(self.waiter, block=True)
if exception is not None:
# Task scopes are sick: Nathaniel, you're an effing genius.
if isinstance(exception, TimeoutError) and self.timed_out:
# This way we only silence our own timeouts and not
# someone else's!
return self.silent
@ -108,10 +118,7 @@ class TaskPool:
Spawns a child task
"""
task = await spawn(func, *args, **kwargs)
task.context = self
self.tasks.append(task)
return task
return await spawn(func, *args, **kwargs)
async def __aenter__(self):
"""
@ -119,9 +126,6 @@ class TaskPool:
"""
self.entry_point = await current_task()
scope = await get_current_scope()
if scope:
scope.pools.append(self)
await set_context(self)
return self
@ -141,10 +145,11 @@ class TaskPool:
if self.inner:
# We wait for inner contexts to terminate
await self.event.wait()
except (Exception, KeyboardInterrupt) as exc:
except (Exception, KeyboardInterrupt) as err:
print(f"ctx: {err!r}")
if not self.cancelled:
await self.cancel()
self.error = exc
self.error = err
finally:
self.entry_point.propagate = True
await close_context(self)

View File

@ -58,6 +58,7 @@ class AsyncStream:
if open_fd:
self.stream = os.fdopen(self._fd, **kwargs)
os.set_blocking(self._fd, False)
# Do we close ourselves upon the end of a context manager?
self.close_on_context_exit = close_on_context_exit
async def read(self, size: int = -1):
@ -152,18 +153,16 @@ class AsyncSocket(AsyncStream):
close_on_context_exit: bool = True,
do_handshake_on_connect: bool = True,
):
super().__init__(fd=sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
# Do we perform the TCP handshake automatically
# upon connection? This is mostly needed for SSL
# sockets
self.do_handshake_on_connect = do_handshake_on_connect
# Do we close ourselves upon the end of a context manager?
self.close_on_context_exit = close_on_context_exit
# The socket.fromfd function copies the file descriptor
# instead of using the same one, so we'd be trying to close
# a different resource if we used sock.fileno() instead
# of self.stream.fileno() as our file descriptor
self.stream = socket.fromfd(sock.fileno(), sock.family, sock.type, sock.proto)
self._fd = self.stream.fileno()
# As sockets, we have different methods compared to a
# generic asynchronous stream, so we need to set these
# fields manually. It's also why we passed open_fd=False
# to our constructor call earlier
self.stream = sock
self.stream.setblocking(False)
# A socket that isn't connected doesn't
# need to be closed

View File

@ -17,6 +17,7 @@ limitations under the License.
"""
import signal
import itertools
import warnings
from collections import deque
from functools import partial
from aiosched.task import Task, TaskState
@ -32,6 +33,7 @@ from aiosched.errors import (
ResourceBroken,
)
from aiosched.context import TaskPool, TaskScope
from aiosched.sigint import CTRLC_PROTECTION_ENABLED
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
@ -91,8 +93,8 @@ class FIFOKernel:
self._sigint_handled: bool = False
# Are we executing any task code?
self._running: bool = False
# The current context we're in
self.current_context: TaskPool | None = None
# The current task pool we're in
self.current_pool: TaskPool | None = None
# The current task scope we're in
self.current_scope: TaskScope | None = None
@ -116,7 +118,7 @@ class FIFOKernel:
)
return f"{type(self).__name__}({data})"
def _sigint_handler(self, *_args):
def _sigint_handler(self, _sig: int, frame):
"""
Handles SIGINT
@ -124,10 +126,7 @@ class FIFOKernel:
"""
self._sigint_handled = True
# We reschedule the current task
# immediately no matter what it's
# doing so that we process the
# exception right away
# Poke the event loop with a stick ;)
self.reschedule_running()
def done(self) -> bool:
@ -161,29 +160,34 @@ class FIFOKernel:
self.current_task.throw(
InternalError("cannot shut down a running event loop")
)
for task in self.all():
for task in self.all(copy=True):
self.cancel(task)
self.selector.close()
def all(self) -> Task:
def all(self, copy: bool = False) -> Task:
"""
Yields all the tasks the event loop is keeping track of
Yields a ll the tasks the event loop is keeping track of.
This is an internal undocumented method
"""
for task in itertools.chain(self.run_ready, self.paused):
sources = []
if self.paused:
sources.append([])
for _, __, task, ___ in self.paused.container:
sources[-1].append(task)
if copy:
sources.append(self.run_ready.copy())
else:
sources.append(self.run_ready)
if self.selector.get_map():
sources.append([])
for key in (self.selector.get_map() or {}).values():
for task in key.data.values():
sources[-1].append(task)
for task in itertools.chain(*sources):
task: Task
yield task
def shutdown(self):
"""
Shuts down the event loop
"""
for task in self.all():
self.io_release_task(task)
self.paused.discard(task)
self.selector.close()
self.close()
def wait_io(self):
"""
Waits for I/O and schedules tasks when their
@ -265,12 +269,12 @@ class FIFOKernel:
self.debugger.on_context_creation(ctx)
self.current_task.context = ctx
if not self.current_context:
self.current_context = ctx
if self.current_pool is None:
self.current_pool = ctx
else:
self.current_context.inner = ctx
ctx.outer = self.current_context
self.current_context = ctx
self.current_pool.inner = ctx
ctx.outer = self.current_pool
self.current_pool = ctx
self.reschedule_running()
def close_context(self, ctx: TaskPool):
@ -281,7 +285,7 @@ class FIFOKernel:
ctx.inner = None
self.debugger.on_context_exit(ctx)
ctx.entry_point.context = None
self.current_context = ctx.outer
self.current_pool = ctx.outer
self.reschedule_running()
def set_scope(self, scope: TaskScope):
@ -289,7 +293,7 @@ class FIFOKernel:
Sets the current task scope
"""
if not self.current_scope:
if self.current_scope is None:
self.current_scope = scope
else:
self.current_scope.inner = scope
@ -338,9 +342,9 @@ class FIFOKernel:
self._running = True
if self._sigint_handled:
self._sigint_handled = False
self.reschedule_running()
self.run_ready.appendleft(self.current_task)
self.current_task.throw(KeyboardInterrupt())
elif self.current_task.pending_cancellation:
if self.current_task.pending_cancellation:
# We perform the deferred cancellation
# if it was previously scheduled
self.cancel(self.current_task)
@ -368,6 +372,25 @@ class FIFOKernel:
getattr(self, method)(*args, **kwargs)
self.debugger.after_task_step(self.current_task)
def setup(self):
"""
Configures the event loop
"""
if signal.getsignal(signal.SIGINT) == signal.default_int_handler:
signal.signal(signal.SIGINT, self._sigint_handler)
else:
warnings.warn("aiosched detected a custom signal handler for SIGINT and it won't touch it, but"
" keep in mind that Ctrl+C is likely to break!")
def teardown(self):
"""
Undoes any modification made by setup()
"""
if signal.getsignal(signal.SIGINT) == self._sigint_handled:
signal.signal(signal.SIGINT, signal.default_int_handler)
def run(self):
"""
The event loop's runner function. This method drives
@ -388,7 +411,7 @@ class FIFOKernel:
# If we're done, which means there are
# both no paused tasks and no running tasks, we
# simply tear us down and return to self.start
self.shutdown()
self.close()
break
elif self._sigint_handled:
# We got Ctrl+C-ed while not running a task! We pick
@ -404,7 +427,7 @@ class FIFOKernel:
task, *_ = self.paused.get()
else:
task = self.current_task
self.run_ready.append(task)
self.run_ready.appendleft(task)
self.handle_errors(self.run_task_step)
elif not self.run_ready:
# If there are no actively running tasks, we start by
@ -427,15 +450,14 @@ class FIFOKernel:
Starts the event loop from a synchronous context
"""
old = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, self._sigint_handler)
self.setup()
self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
self.run_ready.append(self.entry_point)
self.debugger.on_start()
try:
self.run()
finally:
signal.signal(signal.SIGINT, old)
self.teardown()
if (
self.entry_point.exc
and self.entry_point.context is None
@ -448,7 +470,7 @@ class FIFOKernel:
self.debugger.on_exit()
return self.entry_point.result
def io_release(self, resource):
def io_release(self, resource, internal: bool = False):
"""
Releases the given resource from our
selector
@ -458,9 +480,10 @@ class FIFOKernel:
if resource in self.selector.get_map():
self.selector.unregister(resource)
self.debugger.on_io_unschedule(resource)
if resource is self.current_task.last_io[1]:
if self.current_task.last_io and resource is self.current_task.last_io[1]:
self.current_task.last_io = None
self.reschedule_running()
if not internal:
self.reschedule_running()
def io_release_task(self, task: Task):
"""
@ -468,14 +491,15 @@ class FIFOKernel:
for each I/O resource the given task owns
"""
for key in dict(self.selector.get_map()).values():
for key in dict(self.selector.get_map() or {}).values():
if task not in key.data.values():
continue
if len(key.data.values()) == 2:
if key.data.values()[0] != task or key.data.values[1] != task:
a, b = key.data.values()
if a is not task or b is not task:
continue
self.notify_closing(key.fileobj, broken=True)
self.selector.unregister(key.fileobj)
self.io_release(key.fileobj, internal=True)
task.last_io = None
def get_active_io_count(self) -> int:
@ -523,16 +547,20 @@ class FIFOKernel:
it fails
"""
self.paused.discard(task)
self.io_release_task(task)
self.handle_errors(partial(task.throw, Cancelled(task)), task)
if task.state != TaskState.CANCELLED:
task.pending_cancellation = True
self.run_ready.append(task)
if self.current_task not in self.run_ready:
self.reschedule_running()
if not task.scope or task.scope.cancellable:
self.paused.discard(task)
self.io_release_task(task)
self.handle_errors(partial(task.throw, Cancelled(task)), task)
if task.state != TaskState.CANCELLED:
task.pending_cancellation = True
self.run_ready.append(task)
self.reschedule_running()
def throw(self, task, error):
"""
Throws the given exception into the given task
"""
self.paused.discard(task)
self.io_release_task(task)
self.handle_errors(partial(task.throw, error), task)
@ -624,6 +652,13 @@ class FIFOKernel:
"""
task = Task(func.__name__ or repr(func), func(*args, **kwargs))
# We inject our magic secret variable into the coroutine's stack frame so
# we can look it up later
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
task.scope = self.current_scope
if self.current_pool:
task.context = self.current_pool
self.current_pool.tasks.append(task)
self.data[self.current_task] = task
self.run_ready.append(task)
self.reschedule_running()

29
aiosched/sigint.py Normal file
View File

@ -0,0 +1,29 @@
"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https:www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Special magic module half-stolen from Trio (thanks njsmith I love you)
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
# Just a funny variable name that is not a valid
# identifier (but still a string so tools that hack
# into frames don't freak out when they look at the
# local variables) which will get injected silently
# into every frame to enable/disable the safeguards
# for Ctrl+C/KeyboardInterrupt
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"

View File

@ -19,6 +19,9 @@ import socket as _socket
from aiosched.io import AsyncSocket
DEFAULT_HE_DELAY = 0.250
def wrap_socket(sock: _socket.socket) -> AsyncSocket:
"""
Wraps a standard socket into an async socket
@ -29,9 +32,9 @@ def wrap_socket(sock: _socket.socket) -> AsyncSocket:
def socket(*args, **kwargs):
"""
Creates a new giambio socket, taking in the same positional and
Creates a new aiosched socket, taking in the same positional and
keyword arguments as the standard library's socket.socket
constructor
"""
return wrap_socket(_socket.socket(*args, **kwargs))
return wrap_socket(_socket.socket(*args, **kwargs))

View File

@ -83,6 +83,8 @@ class Task:
context: "TaskPool" = field(default=None, repr=False)
# We propagate exception only at the first call to wait()
propagate: bool = True
# The task's scope
scope: "TaskScope" = field(default=None, repr=False)
def run(self, what: Any | None = None):
"""

View File

@ -10,9 +10,9 @@ names: set[str] = set()
async def serve(bind_address: tuple):
"""
Serves asynchronously forever
Serves asynchronously forever (or until Ctrl+C ;))
:param bind_address: The address to bind the server to represented as a tuple
:param bind_address: The address to bind the server to, represented as a tuple
(address, port) where address is a string and port is an integer
"""
@ -20,17 +20,18 @@ async def serve(bind_address: tuple):
await sock.bind(bind_address)
await sock.listen(5)
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with aiosched.create_pool() as ctx:
async with sock:
while True:
try:
async with sock:
while True:
try:
async with aiosched.create_pool() as pool:
conn, address_tuple = await sock.accept()
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await ctx.spawn(handler, conn)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
await pool.spawn(handler, conn)
except Exception as err:
raise
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: aiosched.socket.AsyncSocket):
@ -87,6 +88,10 @@ async def handler(sock: aiosched.socket.AsyncSocket):
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
logging.info(f"Connection from {address} closed")
logging.info(f"{name} has left the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
clients.pop(sock)
names.discard(name)
logging.info("Handler shutting down")
@ -105,4 +110,5 @@ if __name__ == "__main__":
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")
else:
raise
logging.error(f"Exiting due to a {type(error).__name__}: {error}")

View File

@ -1,7 +1,6 @@
import aiosched
async def sender(c: aiosched.MemoryChannel, n: int):
for i in range(n):
await c.write(str(i))

34
tests/proxy.py Normal file
View File

@ -0,0 +1,34 @@
import aiosched
import socket
# TODO: This is borked
# Pls notice me njsmith senpai :>
async def proxy_one_way(source: aiosched.socket.AsyncSocket, sink: aiosched.socket.AsyncSocket):
while True:
data = await source.receive(1024)
if not data:
await sink.shutdown(socket.SHUT_WR)
break
await sink.send_all(data)
async def proxy_two_way(a: aiosched.socket.AsyncSocket, b: aiosched.socket.AsyncSocket):
async with aiosched.create_pool() as pool:
await pool.spawn(proxy_one_way, a, b)
await pool.spawn(proxy_one_way, b, a)
async def main():
async with aiosched.skip_after(10):
a = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM)
b = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await a.connect(("localhost", 12345))
await b.connect(("localhost", 54321))
async with a, b:
await proxy_two_way(a, b)
aiosched.run(main)

View File

@ -26,17 +26,21 @@ async def test(host: str, port: int, bufsize: int = 4096):
# in the AsyncSocket class instead
do_handshake_on_connect=False,
server_hostname=host,
)
),
)
# You have the option to do the handshake yourself
# socket.do_handshake_on_connect = False
print(f"Attempting a connection to {host}:{port}")
await socket.connect((host, port))
# await socket.do_handshake()
print("Connected")
# Ensures the code below doesn't run for more than 5 seconds
async with aiosched.skip_after(5) as scope:
# Closes the socket automatically
async with socket:
# Closes the socket automatically
print("Entered socket context manager, sending request data")
await socket.send_all(
f"GET / HTTP/1.1\r\nHost: {host}\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n".encode()
f"GET / HTTP/1.1\r\nUser-Agent: PostmanRuntime/7.32.2\r\nAccept: */*\r\nHost: {host}\r\nAccept-Encoding: gzip, deflate, br\r\nConnection: keep-alive\r\n\r\n".encode()
)
print("Data sent")
buffer = b""
@ -60,8 +64,6 @@ async def test(host: str, port: int, bufsize: int = 4096):
continue
else:
if not element.strip() and not content:
# This only works because google sends a newline
# before the content
sys.stdout.write("\nContent:")
content = True
if not content:
@ -72,4 +74,4 @@ async def test(host: str, port: int, bufsize: int = 4096):
_print("Done!")
aiosched.run(test, "debian.org", 443, 256, debugger=())
aiosched.run(test, "nocturn9x.space", 443, 256, debugger=())