diff --git a/aiosched/context.py b/aiosched/context.py index a10ba05..c2d3e98 100644 --- a/aiosched/context.py +++ b/aiosched/context.py @@ -39,23 +39,27 @@ class TaskScope: self.silent = silent self.inner: TaskScope | None = None self.outer: TaskScope | None = None - self.pools: list[TaskPool] = list() self.waiter: Task | None = None self.entry_point: Task | None = None self.timed_out: bool = False + # Can we be cancelled? + self.cancellable: bool = True + # Task scope of our timeout worker + self.timeout_scope: TaskScope = None async def _timeout_worker(self): - await sleep(self.timeout) - for pool in self.pools: - if not pool.done(): + async with TaskScope() as scope: + self.timeout_scope = scope + # We can't let this task be cancelled + # or interrupted by Ctrl+C because this + # is the only safeguard of our timeouts: + # if this crashes, then timeouts don't work + # at all! + scope.cancellable = False + await sleep(self.timeout) + if not self.entry_point.done(): self.timed_out = True - await pool.cancel() - if pool.entry_point is not self.entry_point: - await cancel(pool.entry_point, block=True) - if not self.entry_point.done(): - self.timed_out = True - # raise TimeoutError("timed out") - await throw(self.entry_point, TimeoutError("timed out")) + await throw(self.entry_point, TimeoutError("timed out")) async def __aenter__(self): self.entry_point = await current_task() @@ -66,9 +70,15 @@ class TaskScope: async def __aexit__(self, exc_type: type, exception: Exception, tb): await close_scope(self) - if not self.waiter.done(): + if self.timeout and not self.waiter.done(): + # Well, looks like we finished before our worker. + # Thanks for your help! Now die. + self.timeout_scope.cancellable = True await cancel(self.waiter, block=True) - if exception is not None: + # Task scopes are sick: Nathaniel, you're an effing genius. + if isinstance(exception, TimeoutError) and self.timed_out: + # This way we only silence our own timeouts and not + # someone else's! return self.silent @@ -108,10 +118,7 @@ class TaskPool: Spawns a child task """ - task = await spawn(func, *args, **kwargs) - task.context = self - self.tasks.append(task) - return task + return await spawn(func, *args, **kwargs) async def __aenter__(self): """ @@ -119,9 +126,6 @@ class TaskPool: """ self.entry_point = await current_task() - scope = await get_current_scope() - if scope: - scope.pools.append(self) await set_context(self) return self @@ -141,10 +145,11 @@ class TaskPool: if self.inner: # We wait for inner contexts to terminate await self.event.wait() - except (Exception, KeyboardInterrupt) as exc: + except (Exception, KeyboardInterrupt) as err: + print(f"ctx: {err!r}") if not self.cancelled: await self.cancel() - self.error = exc + self.error = err finally: self.entry_point.propagate = True await close_context(self) diff --git a/aiosched/io.py b/aiosched/io.py index 88900a9..d48c79e 100644 --- a/aiosched/io.py +++ b/aiosched/io.py @@ -58,6 +58,7 @@ class AsyncStream: if open_fd: self.stream = os.fdopen(self._fd, **kwargs) os.set_blocking(self._fd, False) + # Do we close ourselves upon the end of a context manager? self.close_on_context_exit = close_on_context_exit async def read(self, size: int = -1): @@ -152,18 +153,16 @@ class AsyncSocket(AsyncStream): close_on_context_exit: bool = True, do_handshake_on_connect: bool = True, ): + super().__init__(fd=sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit) # Do we perform the TCP handshake automatically # upon connection? This is mostly needed for SSL # sockets self.do_handshake_on_connect = do_handshake_on_connect - # Do we close ourselves upon the end of a context manager? - self.close_on_context_exit = close_on_context_exit - # The socket.fromfd function copies the file descriptor - # instead of using the same one, so we'd be trying to close - # a different resource if we used sock.fileno() instead - # of self.stream.fileno() as our file descriptor - self.stream = socket.fromfd(sock.fileno(), sock.family, sock.type, sock.proto) - self._fd = self.stream.fileno() + # As sockets, we have different methods compared to a + # generic asynchronous stream, so we need to set these + # fields manually. It's also why we passed open_fd=False + # to our constructor call earlier + self.stream = sock self.stream.setblocking(False) # A socket that isn't connected doesn't # need to be closed diff --git a/aiosched/kernel.py b/aiosched/kernel.py index 0502fef..70ab149 100644 --- a/aiosched/kernel.py +++ b/aiosched/kernel.py @@ -17,6 +17,7 @@ limitations under the License. """ import signal import itertools +import warnings from collections import deque from functools import partial from aiosched.task import Task, TaskState @@ -32,6 +33,7 @@ from aiosched.errors import ( ResourceBroken, ) from aiosched.context import TaskPool, TaskScope +from aiosched.sigint import CTRLC_PROTECTION_ENABLED from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE @@ -91,8 +93,8 @@ class FIFOKernel: self._sigint_handled: bool = False # Are we executing any task code? self._running: bool = False - # The current context we're in - self.current_context: TaskPool | None = None + # The current task pool we're in + self.current_pool: TaskPool | None = None # The current task scope we're in self.current_scope: TaskScope | None = None @@ -116,7 +118,7 @@ class FIFOKernel: ) return f"{type(self).__name__}({data})" - def _sigint_handler(self, *_args): + def _sigint_handler(self, _sig: int, frame): """ Handles SIGINT @@ -124,10 +126,7 @@ class FIFOKernel: """ self._sigint_handled = True - # We reschedule the current task - # immediately no matter what it's - # doing so that we process the - # exception right away + # Poke the event loop with a stick ;) self.reschedule_running() def done(self) -> bool: @@ -161,29 +160,34 @@ class FIFOKernel: self.current_task.throw( InternalError("cannot shut down a running event loop") ) - for task in self.all(): + for task in self.all(copy=True): self.cancel(task) + self.selector.close() - def all(self) -> Task: + def all(self, copy: bool = False) -> Task: """ - Yields all the tasks the event loop is keeping track of + Yields a ll the tasks the event loop is keeping track of. + This is an internal undocumented method """ - for task in itertools.chain(self.run_ready, self.paused): + sources = [] + if self.paused: + sources.append([]) + for _, __, task, ___ in self.paused.container: + sources[-1].append(task) + if copy: + sources.append(self.run_ready.copy()) + else: + sources.append(self.run_ready) + if self.selector.get_map(): + sources.append([]) + for key in (self.selector.get_map() or {}).values(): + for task in key.data.values(): + sources[-1].append(task) + for task in itertools.chain(*sources): task: Task yield task - def shutdown(self): - """ - Shuts down the event loop - """ - - for task in self.all(): - self.io_release_task(task) - self.paused.discard(task) - self.selector.close() - self.close() - def wait_io(self): """ Waits for I/O and schedules tasks when their @@ -265,12 +269,12 @@ class FIFOKernel: self.debugger.on_context_creation(ctx) self.current_task.context = ctx - if not self.current_context: - self.current_context = ctx + if self.current_pool is None: + self.current_pool = ctx else: - self.current_context.inner = ctx - ctx.outer = self.current_context - self.current_context = ctx + self.current_pool.inner = ctx + ctx.outer = self.current_pool + self.current_pool = ctx self.reschedule_running() def close_context(self, ctx: TaskPool): @@ -281,7 +285,7 @@ class FIFOKernel: ctx.inner = None self.debugger.on_context_exit(ctx) ctx.entry_point.context = None - self.current_context = ctx.outer + self.current_pool = ctx.outer self.reschedule_running() def set_scope(self, scope: TaskScope): @@ -289,7 +293,7 @@ class FIFOKernel: Sets the current task scope """ - if not self.current_scope: + if self.current_scope is None: self.current_scope = scope else: self.current_scope.inner = scope @@ -338,9 +342,9 @@ class FIFOKernel: self._running = True if self._sigint_handled: self._sigint_handled = False - self.reschedule_running() + self.run_ready.appendleft(self.current_task) self.current_task.throw(KeyboardInterrupt()) - elif self.current_task.pending_cancellation: + if self.current_task.pending_cancellation: # We perform the deferred cancellation # if it was previously scheduled self.cancel(self.current_task) @@ -368,6 +372,25 @@ class FIFOKernel: getattr(self, method)(*args, **kwargs) self.debugger.after_task_step(self.current_task) + def setup(self): + """ + Configures the event loop + """ + + if signal.getsignal(signal.SIGINT) == signal.default_int_handler: + signal.signal(signal.SIGINT, self._sigint_handler) + else: + warnings.warn("aiosched detected a custom signal handler for SIGINT and it won't touch it, but" + " keep in mind that Ctrl+C is likely to break!") + + def teardown(self): + """ + Undoes any modification made by setup() + """ + + if signal.getsignal(signal.SIGINT) == self._sigint_handled: + signal.signal(signal.SIGINT, signal.default_int_handler) + def run(self): """ The event loop's runner function. This method drives @@ -388,7 +411,7 @@ class FIFOKernel: # If we're done, which means there are # both no paused tasks and no running tasks, we # simply tear us down and return to self.start - self.shutdown() + self.close() break elif self._sigint_handled: # We got Ctrl+C-ed while not running a task! We pick @@ -404,7 +427,7 @@ class FIFOKernel: task, *_ = self.paused.get() else: task = self.current_task - self.run_ready.append(task) + self.run_ready.appendleft(task) self.handle_errors(self.run_task_step) elif not self.run_ready: # If there are no actively running tasks, we start by @@ -427,15 +450,14 @@ class FIFOKernel: Starts the event loop from a synchronous context """ - old = signal.getsignal(signal.SIGINT) - signal.signal(signal.SIGINT, self._sigint_handler) + self.setup() self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs)) self.run_ready.append(self.entry_point) self.debugger.on_start() try: self.run() finally: - signal.signal(signal.SIGINT, old) + self.teardown() if ( self.entry_point.exc and self.entry_point.context is None @@ -448,7 +470,7 @@ class FIFOKernel: self.debugger.on_exit() return self.entry_point.result - def io_release(self, resource): + def io_release(self, resource, internal: bool = False): """ Releases the given resource from our selector @@ -458,9 +480,10 @@ class FIFOKernel: if resource in self.selector.get_map(): self.selector.unregister(resource) self.debugger.on_io_unschedule(resource) - if resource is self.current_task.last_io[1]: + if self.current_task.last_io and resource is self.current_task.last_io[1]: self.current_task.last_io = None - self.reschedule_running() + if not internal: + self.reschedule_running() def io_release_task(self, task: Task): """ @@ -468,14 +491,15 @@ class FIFOKernel: for each I/O resource the given task owns """ - for key in dict(self.selector.get_map()).values(): + for key in dict(self.selector.get_map() or {}).values(): if task not in key.data.values(): continue if len(key.data.values()) == 2: - if key.data.values()[0] != task or key.data.values[1] != task: + a, b = key.data.values() + if a is not task or b is not task: continue self.notify_closing(key.fileobj, broken=True) - self.selector.unregister(key.fileobj) + self.io_release(key.fileobj, internal=True) task.last_io = None def get_active_io_count(self) -> int: @@ -523,16 +547,20 @@ class FIFOKernel: it fails """ - self.paused.discard(task) - self.io_release_task(task) - self.handle_errors(partial(task.throw, Cancelled(task)), task) - if task.state != TaskState.CANCELLED: - task.pending_cancellation = True - self.run_ready.append(task) - if self.current_task not in self.run_ready: - self.reschedule_running() + if not task.scope or task.scope.cancellable: + self.paused.discard(task) + self.io_release_task(task) + self.handle_errors(partial(task.throw, Cancelled(task)), task) + if task.state != TaskState.CANCELLED: + task.pending_cancellation = True + self.run_ready.append(task) + self.reschedule_running() def throw(self, task, error): + """ + Throws the given exception into the given task + """ + self.paused.discard(task) self.io_release_task(task) self.handle_errors(partial(task.throw, error), task) @@ -624,6 +652,13 @@ class FIFOKernel: """ task = Task(func.__name__ or repr(func), func(*args, **kwargs)) + # We inject our magic secret variable into the coroutine's stack frame so + # we can look it up later + task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False) + task.scope = self.current_scope + if self.current_pool: + task.context = self.current_pool + self.current_pool.tasks.append(task) self.data[self.current_task] = task self.run_ready.append(task) self.reschedule_running() diff --git a/aiosched/sigint.py b/aiosched/sigint.py new file mode 100644 index 0000000..83085c2 --- /dev/null +++ b/aiosched/sigint.py @@ -0,0 +1,29 @@ +""" +aiosched: Yet another Python async scheduler + +Copyright (C) 2022 nocturn9x + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https:www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# Special magic module half-stolen from Trio (thanks njsmith I love you) +# that makes Ctrl+C work. P.S.: Please Python, get your signals straight. + + +# Just a funny variable name that is not a valid +# identifier (but still a string so tools that hack +# into frames don't freak out when they look at the +# local variables) which will get injected silently +# into every frame to enable/disable the safeguards +# for Ctrl+C/KeyboardInterrupt +CTRLC_PROTECTION_ENABLED = "|yes-it-is|" + diff --git a/aiosched/socket.py b/aiosched/socket.py index 9a260f7..e08688a 100644 --- a/aiosched/socket.py +++ b/aiosched/socket.py @@ -19,6 +19,9 @@ import socket as _socket from aiosched.io import AsyncSocket +DEFAULT_HE_DELAY = 0.250 + + def wrap_socket(sock: _socket.socket) -> AsyncSocket: """ Wraps a standard socket into an async socket @@ -29,9 +32,9 @@ def wrap_socket(sock: _socket.socket) -> AsyncSocket: def socket(*args, **kwargs): """ - Creates a new giambio socket, taking in the same positional and + Creates a new aiosched socket, taking in the same positional and keyword arguments as the standard library's socket.socket constructor """ - return wrap_socket(_socket.socket(*args, **kwargs)) + return wrap_socket(_socket.socket(*args, **kwargs)) \ No newline at end of file diff --git a/aiosched/task.py b/aiosched/task.py index 28943b6..01d842e 100644 --- a/aiosched/task.py +++ b/aiosched/task.py @@ -83,6 +83,8 @@ class Task: context: "TaskPool" = field(default=None, repr=False) # We propagate exception only at the first call to wait() propagate: bool = True + # The task's scope + scope: "TaskScope" = field(default=None, repr=False) def run(self, what: Any | None = None): """ diff --git a/tests/chatroom_server.py b/tests/chatroom_server.py index a612c67..ddf6292 100644 --- a/tests/chatroom_server.py +++ b/tests/chatroom_server.py @@ -10,9 +10,9 @@ names: set[str] = set() async def serve(bind_address: tuple): """ - Serves asynchronously forever + Serves asynchronously forever (or until Ctrl+C ;)) - :param bind_address: The address to bind the server to represented as a tuple + :param bind_address: The address to bind the server to, represented as a tuple (address, port) where address is a string and port is an integer """ @@ -20,17 +20,18 @@ async def serve(bind_address: tuple): await sock.bind(bind_address) await sock.listen(5) logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") - async with aiosched.create_pool() as ctx: - async with sock: - while True: - try: + async with sock: + while True: + try: + async with aiosched.create_pool() as pool: conn, address_tuple = await sock.accept() clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"] logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") - await ctx.spawn(handler, conn) - except Exception as err: - # Because exceptions just *work* - logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") + await pool.spawn(handler, conn) + except Exception as err: + raise + # Because exceptions just *work* + logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") async def handler(sock: aiosched.socket.AsyncSocket): @@ -87,6 +88,10 @@ async def handler(sock: aiosched.socket.AsyncSocket): await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode()) logging.info(f"Sent {data!r} to {i} clients") logging.info(f"Connection from {address} closed") + logging.info(f"{name} has left the chatroom ({address}), informing clients") + for i, client_sock in enumerate(clients): + if client_sock != sock and clients[client_sock][0]: + await client_sock.send_all(f"{name} has left the chatroom\n> ".encode()) clients.pop(sock) names.discard(name) logging.info("Handler shutting down") @@ -105,4 +110,5 @@ if __name__ == "__main__": if isinstance(error, KeyboardInterrupt): logging.info("Ctrl+C detected, exiting") else: + raise logging.error(f"Exiting due to a {type(error).__name__}: {error}") diff --git a/tests/memory_channel.py b/tests/memory_channel.py index 3a36bf7..8d8793e 100644 --- a/tests/memory_channel.py +++ b/tests/memory_channel.py @@ -1,7 +1,6 @@ import aiosched - async def sender(c: aiosched.MemoryChannel, n: int): for i in range(n): await c.write(str(i)) diff --git a/tests/proxy.py b/tests/proxy.py new file mode 100644 index 0000000..69edc1c --- /dev/null +++ b/tests/proxy.py @@ -0,0 +1,34 @@ +import aiosched +import socket + +# TODO: This is borked + +# Pls notice me njsmith senpai :> + + +async def proxy_one_way(source: aiosched.socket.AsyncSocket, sink: aiosched.socket.AsyncSocket): + while True: + data = await source.receive(1024) + if not data: + await sink.shutdown(socket.SHUT_WR) + break + await sink.send_all(data) + + +async def proxy_two_way(a: aiosched.socket.AsyncSocket, b: aiosched.socket.AsyncSocket): + async with aiosched.create_pool() as pool: + await pool.spawn(proxy_one_way, a, b) + await pool.spawn(proxy_one_way, b, a) + + +async def main(): + async with aiosched.skip_after(10): + a = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM) + b = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM) + await a.connect(("localhost", 12345)) + await b.connect(("localhost", 54321)) + async with a, b: + await proxy_two_way(a, b) + + +aiosched.run(main) diff --git a/tests/socket_ssl.py b/tests/socket_ssl.py index e8238b6..c6168ce 100644 --- a/tests/socket_ssl.py +++ b/tests/socket_ssl.py @@ -26,17 +26,21 @@ async def test(host: str, port: int, bufsize: int = 4096): # in the AsyncSocket class instead do_handshake_on_connect=False, server_hostname=host, - ) + ), ) + # You have the option to do the handshake yourself + # socket.do_handshake_on_connect = False print(f"Attempting a connection to {host}:{port}") await socket.connect((host, port)) + # await socket.do_handshake() print("Connected") + # Ensures the code below doesn't run for more than 5 seconds async with aiosched.skip_after(5) as scope: + # Closes the socket automatically async with socket: - # Closes the socket automatically print("Entered socket context manager, sending request data") await socket.send_all( - f"GET / HTTP/1.1\r\nHost: {host}\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n".encode() + f"GET / HTTP/1.1\r\nUser-Agent: PostmanRuntime/7.32.2\r\nAccept: */*\r\nHost: {host}\r\nAccept-Encoding: gzip, deflate, br\r\nConnection: keep-alive\r\n\r\n".encode() ) print("Data sent") buffer = b"" @@ -60,8 +64,6 @@ async def test(host: str, port: int, bufsize: int = 4096): continue else: if not element.strip() and not content: - # This only works because google sends a newline - # before the content sys.stdout.write("\nContent:") content = True if not content: @@ -72,4 +74,4 @@ async def test(host: str, port: int, bufsize: int = 4096): _print("Done!") -aiosched.run(test, "debian.org", 443, 256, debugger=()) +aiosched.run(test, "nocturn9x.space", 443, 256, debugger=())