Initial work on proper Ctrl+C handling. Minor fixes and additions
This commit is contained in:
parent
d10ae9c55b
commit
a86e0afbbd
|
@ -39,23 +39,27 @@ class TaskScope:
|
||||||
self.silent = silent
|
self.silent = silent
|
||||||
self.inner: TaskScope | None = None
|
self.inner: TaskScope | None = None
|
||||||
self.outer: TaskScope | None = None
|
self.outer: TaskScope | None = None
|
||||||
self.pools: list[TaskPool] = list()
|
|
||||||
self.waiter: Task | None = None
|
self.waiter: Task | None = None
|
||||||
self.entry_point: Task | None = None
|
self.entry_point: Task | None = None
|
||||||
self.timed_out: bool = False
|
self.timed_out: bool = False
|
||||||
|
# Can we be cancelled?
|
||||||
|
self.cancellable: bool = True
|
||||||
|
# Task scope of our timeout worker
|
||||||
|
self.timeout_scope: TaskScope = None
|
||||||
|
|
||||||
async def _timeout_worker(self):
|
async def _timeout_worker(self):
|
||||||
await sleep(self.timeout)
|
async with TaskScope() as scope:
|
||||||
for pool in self.pools:
|
self.timeout_scope = scope
|
||||||
if not pool.done():
|
# We can't let this task be cancelled
|
||||||
|
# or interrupted by Ctrl+C because this
|
||||||
|
# is the only safeguard of our timeouts:
|
||||||
|
# if this crashes, then timeouts don't work
|
||||||
|
# at all!
|
||||||
|
scope.cancellable = False
|
||||||
|
await sleep(self.timeout)
|
||||||
|
if not self.entry_point.done():
|
||||||
self.timed_out = True
|
self.timed_out = True
|
||||||
await pool.cancel()
|
await throw(self.entry_point, TimeoutError("timed out"))
|
||||||
if pool.entry_point is not self.entry_point:
|
|
||||||
await cancel(pool.entry_point, block=True)
|
|
||||||
if not self.entry_point.done():
|
|
||||||
self.timed_out = True
|
|
||||||
# raise TimeoutError("timed out")
|
|
||||||
await throw(self.entry_point, TimeoutError("timed out"))
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.entry_point = await current_task()
|
self.entry_point = await current_task()
|
||||||
|
@ -66,9 +70,15 @@ class TaskScope:
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: type, exception: Exception, tb):
|
async def __aexit__(self, exc_type: type, exception: Exception, tb):
|
||||||
await close_scope(self)
|
await close_scope(self)
|
||||||
if not self.waiter.done():
|
if self.timeout and not self.waiter.done():
|
||||||
|
# Well, looks like we finished before our worker.
|
||||||
|
# Thanks for your help! Now die.
|
||||||
|
self.timeout_scope.cancellable = True
|
||||||
await cancel(self.waiter, block=True)
|
await cancel(self.waiter, block=True)
|
||||||
if exception is not None:
|
# Task scopes are sick: Nathaniel, you're an effing genius.
|
||||||
|
if isinstance(exception, TimeoutError) and self.timed_out:
|
||||||
|
# This way we only silence our own timeouts and not
|
||||||
|
# someone else's!
|
||||||
return self.silent
|
return self.silent
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,10 +118,7 @@ class TaskPool:
|
||||||
Spawns a child task
|
Spawns a child task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = await spawn(func, *args, **kwargs)
|
return await spawn(func, *args, **kwargs)
|
||||||
task.context = self
|
|
||||||
self.tasks.append(task)
|
|
||||||
return task
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""
|
"""
|
||||||
|
@ -119,9 +126,6 @@ class TaskPool:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.entry_point = await current_task()
|
self.entry_point = await current_task()
|
||||||
scope = await get_current_scope()
|
|
||||||
if scope:
|
|
||||||
scope.pools.append(self)
|
|
||||||
await set_context(self)
|
await set_context(self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -141,10 +145,11 @@ class TaskPool:
|
||||||
if self.inner:
|
if self.inner:
|
||||||
# We wait for inner contexts to terminate
|
# We wait for inner contexts to terminate
|
||||||
await self.event.wait()
|
await self.event.wait()
|
||||||
except (Exception, KeyboardInterrupt) as exc:
|
except (Exception, KeyboardInterrupt) as err:
|
||||||
|
print(f"ctx: {err!r}")
|
||||||
if not self.cancelled:
|
if not self.cancelled:
|
||||||
await self.cancel()
|
await self.cancel()
|
||||||
self.error = exc
|
self.error = err
|
||||||
finally:
|
finally:
|
||||||
self.entry_point.propagate = True
|
self.entry_point.propagate = True
|
||||||
await close_context(self)
|
await close_context(self)
|
||||||
|
|
|
@ -58,6 +58,7 @@ class AsyncStream:
|
||||||
if open_fd:
|
if open_fd:
|
||||||
self.stream = os.fdopen(self._fd, **kwargs)
|
self.stream = os.fdopen(self._fd, **kwargs)
|
||||||
os.set_blocking(self._fd, False)
|
os.set_blocking(self._fd, False)
|
||||||
|
# Do we close ourselves upon the end of a context manager?
|
||||||
self.close_on_context_exit = close_on_context_exit
|
self.close_on_context_exit = close_on_context_exit
|
||||||
|
|
||||||
async def read(self, size: int = -1):
|
async def read(self, size: int = -1):
|
||||||
|
@ -152,18 +153,16 @@ class AsyncSocket(AsyncStream):
|
||||||
close_on_context_exit: bool = True,
|
close_on_context_exit: bool = True,
|
||||||
do_handshake_on_connect: bool = True,
|
do_handshake_on_connect: bool = True,
|
||||||
):
|
):
|
||||||
|
super().__init__(fd=sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
|
||||||
# Do we perform the TCP handshake automatically
|
# Do we perform the TCP handshake automatically
|
||||||
# upon connection? This is mostly needed for SSL
|
# upon connection? This is mostly needed for SSL
|
||||||
# sockets
|
# sockets
|
||||||
self.do_handshake_on_connect = do_handshake_on_connect
|
self.do_handshake_on_connect = do_handshake_on_connect
|
||||||
# Do we close ourselves upon the end of a context manager?
|
# As sockets, we have different methods compared to a
|
||||||
self.close_on_context_exit = close_on_context_exit
|
# generic asynchronous stream, so we need to set these
|
||||||
# The socket.fromfd function copies the file descriptor
|
# fields manually. It's also why we passed open_fd=False
|
||||||
# instead of using the same one, so we'd be trying to close
|
# to our constructor call earlier
|
||||||
# a different resource if we used sock.fileno() instead
|
self.stream = sock
|
||||||
# of self.stream.fileno() as our file descriptor
|
|
||||||
self.stream = socket.fromfd(sock.fileno(), sock.family, sock.type, sock.proto)
|
|
||||||
self._fd = self.stream.fileno()
|
|
||||||
self.stream.setblocking(False)
|
self.stream.setblocking(False)
|
||||||
# A socket that isn't connected doesn't
|
# A socket that isn't connected doesn't
|
||||||
# need to be closed
|
# need to be closed
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
import signal
|
import signal
|
||||||
import itertools
|
import itertools
|
||||||
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from aiosched.task import Task, TaskState
|
from aiosched.task import Task, TaskState
|
||||||
|
@ -32,6 +33,7 @@ from aiosched.errors import (
|
||||||
ResourceBroken,
|
ResourceBroken,
|
||||||
)
|
)
|
||||||
from aiosched.context import TaskPool, TaskScope
|
from aiosched.context import TaskPool, TaskScope
|
||||||
|
from aiosched.sigint import CTRLC_PROTECTION_ENABLED
|
||||||
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
|
from selectors import DefaultSelector, BaseSelector, EVENT_READ, EVENT_WRITE
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,8 +93,8 @@ class FIFOKernel:
|
||||||
self._sigint_handled: bool = False
|
self._sigint_handled: bool = False
|
||||||
# Are we executing any task code?
|
# Are we executing any task code?
|
||||||
self._running: bool = False
|
self._running: bool = False
|
||||||
# The current context we're in
|
# The current task pool we're in
|
||||||
self.current_context: TaskPool | None = None
|
self.current_pool: TaskPool | None = None
|
||||||
# The current task scope we're in
|
# The current task scope we're in
|
||||||
self.current_scope: TaskScope | None = None
|
self.current_scope: TaskScope | None = None
|
||||||
|
|
||||||
|
@ -116,7 +118,7 @@ class FIFOKernel:
|
||||||
)
|
)
|
||||||
return f"{type(self).__name__}({data})"
|
return f"{type(self).__name__}({data})"
|
||||||
|
|
||||||
def _sigint_handler(self, *_args):
|
def _sigint_handler(self, _sig: int, frame):
|
||||||
"""
|
"""
|
||||||
Handles SIGINT
|
Handles SIGINT
|
||||||
|
|
||||||
|
@ -124,10 +126,7 @@ class FIFOKernel:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._sigint_handled = True
|
self._sigint_handled = True
|
||||||
# We reschedule the current task
|
# Poke the event loop with a stick ;)
|
||||||
# immediately no matter what it's
|
|
||||||
# doing so that we process the
|
|
||||||
# exception right away
|
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
|
@ -161,29 +160,34 @@ class FIFOKernel:
|
||||||
self.current_task.throw(
|
self.current_task.throw(
|
||||||
InternalError("cannot shut down a running event loop")
|
InternalError("cannot shut down a running event loop")
|
||||||
)
|
)
|
||||||
for task in self.all():
|
for task in self.all(copy=True):
|
||||||
self.cancel(task)
|
self.cancel(task)
|
||||||
|
self.selector.close()
|
||||||
|
|
||||||
def all(self) -> Task:
|
def all(self, copy: bool = False) -> Task:
|
||||||
"""
|
"""
|
||||||
Yields all the tasks the event loop is keeping track of
|
Yields a ll the tasks the event loop is keeping track of.
|
||||||
|
This is an internal undocumented method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for task in itertools.chain(self.run_ready, self.paused):
|
sources = []
|
||||||
|
if self.paused:
|
||||||
|
sources.append([])
|
||||||
|
for _, __, task, ___ in self.paused.container:
|
||||||
|
sources[-1].append(task)
|
||||||
|
if copy:
|
||||||
|
sources.append(self.run_ready.copy())
|
||||||
|
else:
|
||||||
|
sources.append(self.run_ready)
|
||||||
|
if self.selector.get_map():
|
||||||
|
sources.append([])
|
||||||
|
for key in (self.selector.get_map() or {}).values():
|
||||||
|
for task in key.data.values():
|
||||||
|
sources[-1].append(task)
|
||||||
|
for task in itertools.chain(*sources):
|
||||||
task: Task
|
task: Task
|
||||||
yield task
|
yield task
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
"""
|
|
||||||
Shuts down the event loop
|
|
||||||
"""
|
|
||||||
|
|
||||||
for task in self.all():
|
|
||||||
self.io_release_task(task)
|
|
||||||
self.paused.discard(task)
|
|
||||||
self.selector.close()
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def wait_io(self):
|
def wait_io(self):
|
||||||
"""
|
"""
|
||||||
Waits for I/O and schedules tasks when their
|
Waits for I/O and schedules tasks when their
|
||||||
|
@ -265,12 +269,12 @@ class FIFOKernel:
|
||||||
|
|
||||||
self.debugger.on_context_creation(ctx)
|
self.debugger.on_context_creation(ctx)
|
||||||
self.current_task.context = ctx
|
self.current_task.context = ctx
|
||||||
if not self.current_context:
|
if self.current_pool is None:
|
||||||
self.current_context = ctx
|
self.current_pool = ctx
|
||||||
else:
|
else:
|
||||||
self.current_context.inner = ctx
|
self.current_pool.inner = ctx
|
||||||
ctx.outer = self.current_context
|
ctx.outer = self.current_pool
|
||||||
self.current_context = ctx
|
self.current_pool = ctx
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def close_context(self, ctx: TaskPool):
|
def close_context(self, ctx: TaskPool):
|
||||||
|
@ -281,7 +285,7 @@ class FIFOKernel:
|
||||||
ctx.inner = None
|
ctx.inner = None
|
||||||
self.debugger.on_context_exit(ctx)
|
self.debugger.on_context_exit(ctx)
|
||||||
ctx.entry_point.context = None
|
ctx.entry_point.context = None
|
||||||
self.current_context = ctx.outer
|
self.current_pool = ctx.outer
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def set_scope(self, scope: TaskScope):
|
def set_scope(self, scope: TaskScope):
|
||||||
|
@ -289,7 +293,7 @@ class FIFOKernel:
|
||||||
Sets the current task scope
|
Sets the current task scope
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self.current_scope:
|
if self.current_scope is None:
|
||||||
self.current_scope = scope
|
self.current_scope = scope
|
||||||
else:
|
else:
|
||||||
self.current_scope.inner = scope
|
self.current_scope.inner = scope
|
||||||
|
@ -338,9 +342,9 @@ class FIFOKernel:
|
||||||
self._running = True
|
self._running = True
|
||||||
if self._sigint_handled:
|
if self._sigint_handled:
|
||||||
self._sigint_handled = False
|
self._sigint_handled = False
|
||||||
self.reschedule_running()
|
self.run_ready.appendleft(self.current_task)
|
||||||
self.current_task.throw(KeyboardInterrupt())
|
self.current_task.throw(KeyboardInterrupt())
|
||||||
elif self.current_task.pending_cancellation:
|
if self.current_task.pending_cancellation:
|
||||||
# We perform the deferred cancellation
|
# We perform the deferred cancellation
|
||||||
# if it was previously scheduled
|
# if it was previously scheduled
|
||||||
self.cancel(self.current_task)
|
self.cancel(self.current_task)
|
||||||
|
@ -368,6 +372,25 @@ class FIFOKernel:
|
||||||
getattr(self, method)(*args, **kwargs)
|
getattr(self, method)(*args, **kwargs)
|
||||||
self.debugger.after_task_step(self.current_task)
|
self.debugger.after_task_step(self.current_task)
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
"""
|
||||||
|
Configures the event loop
|
||||||
|
"""
|
||||||
|
|
||||||
|
if signal.getsignal(signal.SIGINT) == signal.default_int_handler:
|
||||||
|
signal.signal(signal.SIGINT, self._sigint_handler)
|
||||||
|
else:
|
||||||
|
warnings.warn("aiosched detected a custom signal handler for SIGINT and it won't touch it, but"
|
||||||
|
" keep in mind that Ctrl+C is likely to break!")
|
||||||
|
|
||||||
|
def teardown(self):
|
||||||
|
"""
|
||||||
|
Undoes any modification made by setup()
|
||||||
|
"""
|
||||||
|
|
||||||
|
if signal.getsignal(signal.SIGINT) == self._sigint_handled:
|
||||||
|
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
The event loop's runner function. This method drives
|
The event loop's runner function. This method drives
|
||||||
|
@ -388,7 +411,7 @@ class FIFOKernel:
|
||||||
# If we're done, which means there are
|
# If we're done, which means there are
|
||||||
# both no paused tasks and no running tasks, we
|
# both no paused tasks and no running tasks, we
|
||||||
# simply tear us down and return to self.start
|
# simply tear us down and return to self.start
|
||||||
self.shutdown()
|
self.close()
|
||||||
break
|
break
|
||||||
elif self._sigint_handled:
|
elif self._sigint_handled:
|
||||||
# We got Ctrl+C-ed while not running a task! We pick
|
# We got Ctrl+C-ed while not running a task! We pick
|
||||||
|
@ -404,7 +427,7 @@ class FIFOKernel:
|
||||||
task, *_ = self.paused.get()
|
task, *_ = self.paused.get()
|
||||||
else:
|
else:
|
||||||
task = self.current_task
|
task = self.current_task
|
||||||
self.run_ready.append(task)
|
self.run_ready.appendleft(task)
|
||||||
self.handle_errors(self.run_task_step)
|
self.handle_errors(self.run_task_step)
|
||||||
elif not self.run_ready:
|
elif not self.run_ready:
|
||||||
# If there are no actively running tasks, we start by
|
# If there are no actively running tasks, we start by
|
||||||
|
@ -427,15 +450,14 @@ class FIFOKernel:
|
||||||
Starts the event loop from a synchronous context
|
Starts the event loop from a synchronous context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
old = signal.getsignal(signal.SIGINT)
|
self.setup()
|
||||||
signal.signal(signal.SIGINT, self._sigint_handler)
|
|
||||||
self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
|
self.entry_point = Task(func.__name__ or str(func), func(*args, **kwargs))
|
||||||
self.run_ready.append(self.entry_point)
|
self.run_ready.append(self.entry_point)
|
||||||
self.debugger.on_start()
|
self.debugger.on_start()
|
||||||
try:
|
try:
|
||||||
self.run()
|
self.run()
|
||||||
finally:
|
finally:
|
||||||
signal.signal(signal.SIGINT, old)
|
self.teardown()
|
||||||
if (
|
if (
|
||||||
self.entry_point.exc
|
self.entry_point.exc
|
||||||
and self.entry_point.context is None
|
and self.entry_point.context is None
|
||||||
|
@ -448,7 +470,7 @@ class FIFOKernel:
|
||||||
self.debugger.on_exit()
|
self.debugger.on_exit()
|
||||||
return self.entry_point.result
|
return self.entry_point.result
|
||||||
|
|
||||||
def io_release(self, resource):
|
def io_release(self, resource, internal: bool = False):
|
||||||
"""
|
"""
|
||||||
Releases the given resource from our
|
Releases the given resource from our
|
||||||
selector
|
selector
|
||||||
|
@ -458,9 +480,10 @@ class FIFOKernel:
|
||||||
if resource in self.selector.get_map():
|
if resource in self.selector.get_map():
|
||||||
self.selector.unregister(resource)
|
self.selector.unregister(resource)
|
||||||
self.debugger.on_io_unschedule(resource)
|
self.debugger.on_io_unschedule(resource)
|
||||||
if resource is self.current_task.last_io[1]:
|
if self.current_task.last_io and resource is self.current_task.last_io[1]:
|
||||||
self.current_task.last_io = None
|
self.current_task.last_io = None
|
||||||
self.reschedule_running()
|
if not internal:
|
||||||
|
self.reschedule_running()
|
||||||
|
|
||||||
def io_release_task(self, task: Task):
|
def io_release_task(self, task: Task):
|
||||||
"""
|
"""
|
||||||
|
@ -468,14 +491,15 @@ class FIFOKernel:
|
||||||
for each I/O resource the given task owns
|
for each I/O resource the given task owns
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for key in dict(self.selector.get_map()).values():
|
for key in dict(self.selector.get_map() or {}).values():
|
||||||
if task not in key.data.values():
|
if task not in key.data.values():
|
||||||
continue
|
continue
|
||||||
if len(key.data.values()) == 2:
|
if len(key.data.values()) == 2:
|
||||||
if key.data.values()[0] != task or key.data.values[1] != task:
|
a, b = key.data.values()
|
||||||
|
if a is not task or b is not task:
|
||||||
continue
|
continue
|
||||||
self.notify_closing(key.fileobj, broken=True)
|
self.notify_closing(key.fileobj, broken=True)
|
||||||
self.selector.unregister(key.fileobj)
|
self.io_release(key.fileobj, internal=True)
|
||||||
task.last_io = None
|
task.last_io = None
|
||||||
|
|
||||||
def get_active_io_count(self) -> int:
|
def get_active_io_count(self) -> int:
|
||||||
|
@ -523,16 +547,20 @@ class FIFOKernel:
|
||||||
it fails
|
it fails
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.paused.discard(task)
|
if not task.scope or task.scope.cancellable:
|
||||||
self.io_release_task(task)
|
self.paused.discard(task)
|
||||||
self.handle_errors(partial(task.throw, Cancelled(task)), task)
|
self.io_release_task(task)
|
||||||
if task.state != TaskState.CANCELLED:
|
self.handle_errors(partial(task.throw, Cancelled(task)), task)
|
||||||
task.pending_cancellation = True
|
if task.state != TaskState.CANCELLED:
|
||||||
self.run_ready.append(task)
|
task.pending_cancellation = True
|
||||||
if self.current_task not in self.run_ready:
|
self.run_ready.append(task)
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
||||||
def throw(self, task, error):
|
def throw(self, task, error):
|
||||||
|
"""
|
||||||
|
Throws the given exception into the given task
|
||||||
|
"""
|
||||||
|
|
||||||
self.paused.discard(task)
|
self.paused.discard(task)
|
||||||
self.io_release_task(task)
|
self.io_release_task(task)
|
||||||
self.handle_errors(partial(task.throw, error), task)
|
self.handle_errors(partial(task.throw, error), task)
|
||||||
|
@ -624,6 +652,13 @@ class FIFOKernel:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = Task(func.__name__ or repr(func), func(*args, **kwargs))
|
task = Task(func.__name__ or repr(func), func(*args, **kwargs))
|
||||||
|
# We inject our magic secret variable into the coroutine's stack frame so
|
||||||
|
# we can look it up later
|
||||||
|
task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False)
|
||||||
|
task.scope = self.current_scope
|
||||||
|
if self.current_pool:
|
||||||
|
task.context = self.current_pool
|
||||||
|
self.current_pool.tasks.append(task)
|
||||||
self.data[self.current_task] = task
|
self.data[self.current_task] = task
|
||||||
self.run_ready.append(task)
|
self.run_ready.append(task)
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""
|
||||||
|
aiosched: Yet another Python async scheduler
|
||||||
|
|
||||||
|
Copyright (C) 2022 nocturn9x
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https:www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
||||||
|
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
||||||
|
|
||||||
|
|
||||||
|
# Just a funny variable name that is not a valid
|
||||||
|
# identifier (but still a string so tools that hack
|
||||||
|
# into frames don't freak out when they look at the
|
||||||
|
# local variables) which will get injected silently
|
||||||
|
# into every frame to enable/disable the safeguards
|
||||||
|
# for Ctrl+C/KeyboardInterrupt
|
||||||
|
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
||||||
|
|
|
@ -19,6 +19,9 @@ import socket as _socket
|
||||||
from aiosched.io import AsyncSocket
|
from aiosched.io import AsyncSocket
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_HE_DELAY = 0.250
|
||||||
|
|
||||||
|
|
||||||
def wrap_socket(sock: _socket.socket) -> AsyncSocket:
|
def wrap_socket(sock: _socket.socket) -> AsyncSocket:
|
||||||
"""
|
"""
|
||||||
Wraps a standard socket into an async socket
|
Wraps a standard socket into an async socket
|
||||||
|
@ -29,9 +32,9 @@ def wrap_socket(sock: _socket.socket) -> AsyncSocket:
|
||||||
|
|
||||||
def socket(*args, **kwargs):
|
def socket(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a new giambio socket, taking in the same positional and
|
Creates a new aiosched socket, taking in the same positional and
|
||||||
keyword arguments as the standard library's socket.socket
|
keyword arguments as the standard library's socket.socket
|
||||||
constructor
|
constructor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return wrap_socket(_socket.socket(*args, **kwargs))
|
return wrap_socket(_socket.socket(*args, **kwargs))
|
|
@ -83,6 +83,8 @@ class Task:
|
||||||
context: "TaskPool" = field(default=None, repr=False)
|
context: "TaskPool" = field(default=None, repr=False)
|
||||||
# We propagate exception only at the first call to wait()
|
# We propagate exception only at the first call to wait()
|
||||||
propagate: bool = True
|
propagate: bool = True
|
||||||
|
# The task's scope
|
||||||
|
scope: "TaskScope" = field(default=None, repr=False)
|
||||||
|
|
||||||
def run(self, what: Any | None = None):
|
def run(self, what: Any | None = None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -10,9 +10,9 @@ names: set[str] = set()
|
||||||
|
|
||||||
async def serve(bind_address: tuple):
|
async def serve(bind_address: tuple):
|
||||||
"""
|
"""
|
||||||
Serves asynchronously forever
|
Serves asynchronously forever (or until Ctrl+C ;))
|
||||||
|
|
||||||
:param bind_address: The address to bind the server to represented as a tuple
|
:param bind_address: The address to bind the server to, represented as a tuple
|
||||||
(address, port) where address is a string and port is an integer
|
(address, port) where address is a string and port is an integer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -20,17 +20,18 @@ async def serve(bind_address: tuple):
|
||||||
await sock.bind(bind_address)
|
await sock.bind(bind_address)
|
||||||
await sock.listen(5)
|
await sock.listen(5)
|
||||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||||
async with aiosched.create_pool() as ctx:
|
async with sock:
|
||||||
async with sock:
|
while True:
|
||||||
while True:
|
try:
|
||||||
try:
|
async with aiosched.create_pool() as pool:
|
||||||
conn, address_tuple = await sock.accept()
|
conn, address_tuple = await sock.accept()
|
||||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||||
await ctx.spawn(handler, conn)
|
await pool.spawn(handler, conn)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
# Because exceptions just *work*
|
raise
|
||||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
# Because exceptions just *work*
|
||||||
|
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||||
|
|
||||||
|
|
||||||
async def handler(sock: aiosched.socket.AsyncSocket):
|
async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
|
@ -87,6 +88,10 @@ async def handler(sock: aiosched.socket.AsyncSocket):
|
||||||
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
||||||
logging.info(f"Sent {data!r} to {i} clients")
|
logging.info(f"Sent {data!r} to {i} clients")
|
||||||
logging.info(f"Connection from {address} closed")
|
logging.info(f"Connection from {address} closed")
|
||||||
|
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
||||||
|
for i, client_sock in enumerate(clients):
|
||||||
|
if client_sock != sock and clients[client_sock][0]:
|
||||||
|
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
|
||||||
clients.pop(sock)
|
clients.pop(sock)
|
||||||
names.discard(name)
|
names.discard(name)
|
||||||
logging.info("Handler shutting down")
|
logging.info("Handler shutting down")
|
||||||
|
@ -105,4 +110,5 @@ if __name__ == "__main__":
|
||||||
if isinstance(error, KeyboardInterrupt):
|
if isinstance(error, KeyboardInterrupt):
|
||||||
logging.info("Ctrl+C detected, exiting")
|
logging.info("Ctrl+C detected, exiting")
|
||||||
else:
|
else:
|
||||||
|
raise
|
||||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def sender(c: aiosched.MemoryChannel, n: int):
|
async def sender(c: aiosched.MemoryChannel, n: int):
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
await c.write(str(i))
|
await c.write(str(i))
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
import aiosched
|
||||||
|
import socket
|
||||||
|
|
||||||
|
# TODO: This is borked
|
||||||
|
|
||||||
|
# Pls notice me njsmith senpai :>
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_one_way(source: aiosched.socket.AsyncSocket, sink: aiosched.socket.AsyncSocket):
|
||||||
|
while True:
|
||||||
|
data = await source.receive(1024)
|
||||||
|
if not data:
|
||||||
|
await sink.shutdown(socket.SHUT_WR)
|
||||||
|
break
|
||||||
|
await sink.send_all(data)
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_two_way(a: aiosched.socket.AsyncSocket, b: aiosched.socket.AsyncSocket):
|
||||||
|
async with aiosched.create_pool() as pool:
|
||||||
|
await pool.spawn(proxy_one_way, a, b)
|
||||||
|
await pool.spawn(proxy_one_way, b, a)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with aiosched.skip_after(10):
|
||||||
|
a = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
b = aiosched.socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
await a.connect(("localhost", 12345))
|
||||||
|
await b.connect(("localhost", 54321))
|
||||||
|
async with a, b:
|
||||||
|
await proxy_two_way(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
aiosched.run(main)
|
|
@ -26,17 +26,21 @@ async def test(host: str, port: int, bufsize: int = 4096):
|
||||||
# in the AsyncSocket class instead
|
# in the AsyncSocket class instead
|
||||||
do_handshake_on_connect=False,
|
do_handshake_on_connect=False,
|
||||||
server_hostname=host,
|
server_hostname=host,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
# You have the option to do the handshake yourself
|
||||||
|
# socket.do_handshake_on_connect = False
|
||||||
print(f"Attempting a connection to {host}:{port}")
|
print(f"Attempting a connection to {host}:{port}")
|
||||||
await socket.connect((host, port))
|
await socket.connect((host, port))
|
||||||
|
# await socket.do_handshake()
|
||||||
print("Connected")
|
print("Connected")
|
||||||
|
# Ensures the code below doesn't run for more than 5 seconds
|
||||||
async with aiosched.skip_after(5) as scope:
|
async with aiosched.skip_after(5) as scope:
|
||||||
|
# Closes the socket automatically
|
||||||
async with socket:
|
async with socket:
|
||||||
# Closes the socket automatically
|
|
||||||
print("Entered socket context manager, sending request data")
|
print("Entered socket context manager, sending request data")
|
||||||
await socket.send_all(
|
await socket.send_all(
|
||||||
f"GET / HTTP/1.1\r\nHost: {host}\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n".encode()
|
f"GET / HTTP/1.1\r\nUser-Agent: PostmanRuntime/7.32.2\r\nAccept: */*\r\nHost: {host}\r\nAccept-Encoding: gzip, deflate, br\r\nConnection: keep-alive\r\n\r\n".encode()
|
||||||
)
|
)
|
||||||
print("Data sent")
|
print("Data sent")
|
||||||
buffer = b""
|
buffer = b""
|
||||||
|
@ -60,8 +64,6 @@ async def test(host: str, port: int, bufsize: int = 4096):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if not element.strip() and not content:
|
if not element.strip() and not content:
|
||||||
# This only works because google sends a newline
|
|
||||||
# before the content
|
|
||||||
sys.stdout.write("\nContent:")
|
sys.stdout.write("\nContent:")
|
||||||
content = True
|
content = True
|
||||||
if not content:
|
if not content:
|
||||||
|
@ -72,4 +74,4 @@ async def test(host: str, port: int, bufsize: int = 4096):
|
||||||
_print("Done!")
|
_print("Done!")
|
||||||
|
|
||||||
|
|
||||||
aiosched.run(test, "debian.org", 443, 256, debugger=())
|
aiosched.run(test, "nocturn9x.space", 443, 256, debugger=())
|
||||||
|
|
Reference in New Issue