Compare commits
95 Commits
279a2a3d3f
...
e1485d9317
Author | SHA1 | Date |
---|---|---|
Mattia Giambirtone | e1485d9317 | |
Mattia Giambirtone | 52a09307ae | |
Mattia Giambirtone | 2601ebb514 | |
Mattia Giambirtone | d77ddcf6a6 | |
Mattia Giambirtone | 2aecb7f440 | |
Mattia Giambirtone | d5b9564d7a | |
Mattia Giambirtone | 12f4e5c0bf | |
Mattia Giambirtone | a95be6d26c | |
Mattia Giambirtone | a03c01b74c | |
Mattia Giambirtone | 1c870da111 | |
Mattia Giambirtone | b39d0ff809 | |
Mattia Giambirtone | 4287296efc | |
Mattia Giambirtone | 6b098b7c46 | |
Mattia Giambirtone | 51a5cd072a | |
Mattia Giambirtone | 95e265aca1 | |
Mattia Giambirtone | 7f790051b2 | |
Mattia Giambirtone | 7cd087307a | |
Mattia Giambirtone | f15ff2224b | |
Mattia Giambirtone | 38f9a22ae1 | |
Mattia Giambirtone | 09a4e2f576 | |
Mattia Giambirtone | c5d55e6ea6 | |
Mattia Giambirtone | 0db7c2e4d3 | |
Mattia Giambirtone | 5bf46de096 | |
Mattia Giambirtone | 4d50130d53 | |
Mattia Giambirtone | 351a212ccd | |
Mattia Giambirtone | 09ad7e12e3 | |
Mattia Giambirtone | 4c969d7827 | |
Mattia Giambirtone | 993cb118e3 | |
Mattia Giambirtone | 4f2c4979fd | |
Mattia Giambirtone | fd6037ba88 | |
Mattia Giambirtone | 9e6ee1e104 | |
Mattia Giambirtone | e0f2e87cad | |
Mattia Giambirtone | e452ab4a25 | |
Mattia Giambirtone | e2f2abf026 | |
Mattia Giambirtone | f4c72e40e2 | |
Mattia Giambirtone | dd6cb509e7 | |
Mattia Giambirtone | 28c8b01554 | |
Mattia Giambirtone | 140a1bca8f | |
Mattia Giambirtone | cd2d810b5a | |
Mattia Giambirtone | ee7451014b | |
Mattia Giambirtone | a09babae53 | |
Mattia Giambirtone | 81bdd64a7e | |
Mattia Giambirtone | a91fb8eb8f | |
Mattia Giambirtone | 0c3cc11f79 | |
Mattia Giambirtone | 50de381033 | |
Mattia Giambirtone | 5910d8574b | |
Mattia Giambirtone | b51d91b11c | |
Mattia Giambirtone | 8d22b348e0 | |
Mattia Giambirtone | 6f3394d7d6 | |
Mattia Giambirtone | 422304fcd9 | |
Mattia Giambirtone | 5aad2f666e | |
Mattia Giambirtone | 156b3c6fc8 | |
Mattia Giambirtone | f9e56cffc4 | |
Mattia Giambirtone | 2da89cf138 | |
Mattia Giambirtone | 77df81fa4b | |
Mattia Giambirtone | 382f6993cc | |
Mattia Giambirtone | fb0881ae0f | |
Mattia Giambirtone | 26ee910ba0 | |
Mattia Giambirtone | 26a43d5f84 | |
Mattia Giambirtone | 3ea159c858 | |
Mattia Giambirtone | e77c154bcd | |
Mattia Giambirtone | feda2708ef | |
Mattia Giambirtone | 2f0f6f82bd | |
Mattia Giambirtone | 9cec11dcc6 | |
Mattia Giambirtone | 4d35957d17 | |
Mattia Giambirtone | e3f7af28f2 | |
Mattia Giambirtone | 59096f34b8 | |
Mattia Giambirtone | f1a20be126 | |
Mattia Giambirtone | 1819ef844c | |
Mattia Giambirtone | 7904afb985 | |
Mattia Giambirtone | 400b0fa04c | |
Mattia Giambirtone | 3e49f71f00 | |
Mattia Giambirtone | 45d3e308d9 | |
Mattia Giambirtone | dd56232250 | |
Mattia Giambirtone | 5a18314dcc | |
Mattia Giambirtone | 42bf9d5daf | |
Mattia Giambirtone | 8a16bb41d6 | |
Mattia Giambirtone | 0abd2c2364 | |
Mattia Giambirtone | 1b4193ce79 | |
Mattia Giambirtone | f5ec5beab3 | |
Mattia Giambirtone | a0acce3ed3 | |
Mattia Giambirtone | 3e33f2732e | |
Mattia Giambirtone | 568f27534b | |
Mattia Giambirtone | 4a7e4cb732 | |
Mattia Giambirtone | e46a41dd8f | |
Mattia Giambirtone | 242d4818bb | |
Mattia Giambirtone | 52daf54ee3 | |
Mattia Giambirtone | cbbe8cc114 | |
Mattia Giambirtone | 654add480d | |
Mattia Giambirtone | d8b2066126 | |
Mattia Giambirtone | 15d0a0674f | |
Mattia Giambirtone | f7dedeeb6c | |
Mattia Giambirtone | f10796ac53 | |
Mattia Giambirtone | 1947b9ddd4 | |
Mattia Giambirtone | 758e50b6e3 |
|
@ -138,3 +138,4 @@ dmypy.json
|
|||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
pyvenv.cfg
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,13 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/dist" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10 (structio)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (structio)" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/StructuredIO.iml" filepath="$PROJECT_DIR$/.idea/StructuredIO.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
24
README.md
24
README.md
|
@ -1,3 +1,23 @@
|
|||
# structio
|
||||
# structio - What am I even doing?
|
||||
|
||||
A proof of concept for an experimental structured concurrency framework written in Python
|
||||
A proof of concept for an experimental structured concurrency framework written in Python
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This library is highly experimental and currently in alpha stage (it doesn't even have a proper version
|
||||
number yet, that's how alpha it is), so it's not production ready (and probably never will be). If you
|
||||
want the fancy structured concurrency paradigm in a library that works today, consider [trio](https://trio.readthedocs.org),
|
||||
from which structio is heavily inspired ([curio](https://github.com/dabeaz/curio) is also worth looking into, although
|
||||
technically it doesn't implement SC).
|
||||
|
||||
## Why?
|
||||
|
||||
This library (and [its](https://git.nocturn9x.space/nocturn9x/giambio) [predecessors](https://git.nocturn9x.space/nocturn9x/aiosched)) is just a way for me to test my knowledge and make sure I understand the basics of structured concurrency
|
||||
and building solid coroutine runners so that I can implement the paradigm in my own programming language. For more info, see [here](https://git.nocturn9x.space/nocturn9x/peon).
|
||||
|
||||
**P.S.**: structio is only thoroughly tested for Linux: While Windows/macOS support is one of the goals
|
||||
of the project, I currently don't have enough time to dedicate to the quirks of the I/O subsystem of each OS.
|
||||
All features that don't rely on I/O (timeouts, events, queues, memory channels, etc.) are cross-platform, but
|
||||
things like sockets behave very differently depending on the platform, and it'll take some time for me to
|
||||
apply the necessary fixes for each of them. File I/O (in its current form using threads), as well as asynchronous
|
||||
threads and processes _should_ work, but there's no guarantee
|
|
@ -0,0 +1 @@
|
|||
sniffio==1.3.0
|
|
@ -0,0 +1,23 @@
|
|||
import setuptools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("README.md", "r") as readme:
|
||||
long_description = readme.read()
|
||||
setuptools.setup(
|
||||
name="structured-io",
|
||||
version="0.1.0",
|
||||
author="nocturn9x",
|
||||
author_email="nocturn9x@nocturn9x.space",
|
||||
description="An experimental structured concurrency framework",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://git.nocturn9x.space/nocturn9x/structio",
|
||||
packages=setuptools.find_packages(),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"Operating System :: OS Independent",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
],
|
||||
python_requires=">=3.10",
|
||||
)
|
|
@ -0,0 +1,178 @@
|
|||
from structio.core import run as _run
|
||||
from typing import Coroutine, Any, Callable
|
||||
from structio.core.kernels.fifo import FIFOKernel
|
||||
from structio.core.managers.io.simple import SimpleIOManager
|
||||
from structio.core.managers.signals.sigint import SigIntManager
|
||||
from structio.core.time.clock import DefaultClock
|
||||
from structio.core.syscalls import sleep, suspend as _suspend
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.exceptions import (
|
||||
Cancelled,
|
||||
TimedOut,
|
||||
ResourceClosed,
|
||||
ResourceBroken,
|
||||
ResourceBusy,
|
||||
WouldBlock,
|
||||
)
|
||||
from structio.core import task
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.sync import (
|
||||
Event,
|
||||
Queue,
|
||||
MemoryChannel,
|
||||
Semaphore,
|
||||
Lock,
|
||||
RLock,
|
||||
emit,
|
||||
on_event,
|
||||
register_event,
|
||||
)
|
||||
from structio.abc import Channel, Stream, ChannelReader, ChannelWriter
|
||||
from structio.io import socket
|
||||
from structio.io.socket import AsyncSocket
|
||||
from structio.io.files import (
|
||||
open_file,
|
||||
wrap_file,
|
||||
aprint,
|
||||
stdout,
|
||||
stderr,
|
||||
stdin,
|
||||
ainput,
|
||||
)
|
||||
from structio.core.run import current_loop, current_task
|
||||
from structio import thread, parallel
|
||||
from structio.path import Path
|
||||
from structio.signals import set_signal_handler, get_signal_handler
|
||||
from structio import signals as _signals
|
||||
from structio import util
|
||||
|
||||
|
||||
def run(
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
*args,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
tools: list | None = None,
|
||||
):
|
||||
result = None
|
||||
try:
|
||||
result = _run.run(
|
||||
func,
|
||||
FIFOKernel,
|
||||
SimpleIOManager(),
|
||||
[SigIntManager()],
|
||||
DefaultClock(),
|
||||
tools,
|
||||
restrict_ki_to_checkpoints,
|
||||
*args,
|
||||
)
|
||||
finally:
|
||||
# Bunch of cleanup
|
||||
_signals._sig_handlers.clear()
|
||||
_signals._sig_data.clear()
|
||||
return result
|
||||
|
||||
|
||||
run.__doc__ = _run.run.__doc__
|
||||
|
||||
|
||||
def create_pool() -> TaskPool:
|
||||
"""
|
||||
Creates a new task pool
|
||||
"""
|
||||
|
||||
return TaskPool()
|
||||
|
||||
|
||||
def skip_after(timeout) -> TaskScope:
|
||||
"""
|
||||
Creates a new task scope with the
|
||||
specified timeout. No error is raised
|
||||
when the timeout expires
|
||||
"""
|
||||
|
||||
return TaskScope(timeout=timeout, silent=True)
|
||||
|
||||
|
||||
def with_timeout(timeout) -> TaskScope:
|
||||
"""
|
||||
Creates a new task scope with the
|
||||
specified timeout. TimeoutError is raised
|
||||
when the timeout expires
|
||||
"""
|
||||
|
||||
return TaskScope(timeout=timeout)
|
||||
|
||||
|
||||
def clock():
|
||||
"""
|
||||
Returns the current clock time of
|
||||
the event loop
|
||||
"""
|
||||
|
||||
return _run.current_loop().clock.current_time()
|
||||
|
||||
|
||||
async def _join(self: Task):
|
||||
if self.done():
|
||||
return self.result
|
||||
await _suspend()
|
||||
assert self.done()
|
||||
if self.state == TaskState.CRASHED:
|
||||
raise self.exc
|
||||
return self.result
|
||||
|
||||
|
||||
def _cancel(self: Task):
|
||||
_run.current_loop().cancel_task(self)
|
||||
|
||||
|
||||
task._joiner = _join
|
||||
|
||||
_cancel.__name__ = Task.cancel.__name__
|
||||
_cancel.__doc__ = Task.cancel.__doc__
|
||||
Task.cancel = _cancel
|
||||
|
||||
|
||||
__all__ = [
|
||||
"run",
|
||||
"sleep",
|
||||
"create_pool",
|
||||
"clock",
|
||||
"Cancelled",
|
||||
"skip_after",
|
||||
"with_timeout",
|
||||
"Event",
|
||||
"Queue",
|
||||
"MemoryChannel",
|
||||
"Channel",
|
||||
"Stream",
|
||||
"ChannelReader",
|
||||
"ChannelWriter",
|
||||
"Semaphore",
|
||||
"TimedOut",
|
||||
"Task",
|
||||
"TaskState",
|
||||
"TaskScope",
|
||||
"TaskPool",
|
||||
"ResourceClosed",
|
||||
"Lock",
|
||||
"RLock",
|
||||
"thread",
|
||||
"open_file",
|
||||
"wrap_file",
|
||||
"aprint",
|
||||
"stderr",
|
||||
"stdin",
|
||||
"stdout",
|
||||
"ainput",
|
||||
"current_loop",
|
||||
"current_task",
|
||||
"Path",
|
||||
"parallel",
|
||||
"get_signal_handler",
|
||||
"set_signal_handler",
|
||||
"util",
|
||||
"ResourceBusy",
|
||||
"ResourceBroken",
|
||||
"WouldBlock",
|
||||
]
|
|
@ -0,0 +1,739 @@
|
|||
import io
|
||||
import os
|
||||
from abc import abstractmethod, ABC
|
||||
from types import FrameType
|
||||
|
||||
from structio.core.task import Task
|
||||
from structio.exceptions import StructIOException
|
||||
from typing import Callable, Any, Coroutine
|
||||
|
||||
|
||||
class BaseClock(ABC):
|
||||
"""
|
||||
Abstract base clock class
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def setup(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def current_time(self):
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def deadline(self, deadline):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class AsyncResource(ABC):
|
||||
"""
|
||||
A generic asynchronous resource which needs to
|
||||
be closed properly, possibly blocking. Can be
|
||||
used as a context manager (note that only the
|
||||
__aexit__ method actually blocks!)
|
||||
"""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
return NotImplemented
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
|
||||
class StreamWriter(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for writing binary data to
|
||||
a byte stream
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class StreamReader(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for reading binary data from
|
||||
a byte stream. The stream implements the
|
||||
asynchronous iterator protocol and can
|
||||
therefore be used with "async for" loops
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _read(self, size: int = -1):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class Stream(StreamReader, StreamWriter, ABC):
|
||||
"""
|
||||
A generic, asynchronous, readable/writable binary stream
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
if isinstance(f, io.TextIOBase):
|
||||
raise TypeError("only binary files can be streamed")
|
||||
self.fileobj = f
|
||||
self.buf = bytearray()
|
||||
os.set_blocking(self.fileobj.fileno(), False)
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self):
|
||||
"""
|
||||
Flushes the underlying resource asynchronously
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class WriteCloseableStream(Stream, ABC):
|
||||
"""
|
||||
Extension to the Stream class that allows
|
||||
shutting down the write end of the stream
|
||||
without closing the read side on our end
|
||||
nor the read/write side on the other one
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def eof(self):
|
||||
"""
|
||||
Send an end-of-file on this stream, if possible.
|
||||
The resource can still be read from (and the
|
||||
other end can still read/write to it), but no more
|
||||
data can be written after an EOF has been sent. If an
|
||||
EOF has already been sent, this method is a no-op
|
||||
"""
|
||||
|
||||
|
||||
class ChannelReader(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for reading data from a
|
||||
channel
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self):
|
||||
"""
|
||||
Receive an object from the channel,
|
||||
possibly blocking
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def pending(self):
|
||||
"""
|
||||
Returns if there is any data waiting
|
||||
to be read
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def readers(self):
|
||||
"""
|
||||
Returns how many tasks are waiting to
|
||||
read from the channel
|
||||
"""
|
||||
|
||||
|
||||
class ChannelWriter(AsyncResource, ABC):
|
||||
"""
|
||||
Interface for writing data to a
|
||||
channel
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, value):
|
||||
"""
|
||||
Send the given object on the channel,
|
||||
possibly blocking
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def writers(self):
|
||||
"""
|
||||
Returns how many tasks are waiting
|
||||
to write to the channel
|
||||
"""
|
||||
|
||||
|
||||
class Channel(ChannelWriter, ChannelReader, ABC):
|
||||
"""
|
||||
A generic, two-way channel
|
||||
"""
|
||||
|
||||
|
||||
class BaseDebugger(ABC):
|
||||
"""
|
||||
The base for all debugger objects
|
||||
"""
|
||||
|
||||
def on_start(self):
|
||||
"""
|
||||
This method is called when the event
|
||||
loop starts executing
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_exit(self):
|
||||
"""
|
||||
This method is called when the event
|
||||
loop exits entirely (all tasks completed)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_task_schedule(self, task: Task, delay: float):
|
||||
"""
|
||||
This method is called when a new task is
|
||||
scheduled (not spawned)
|
||||
|
||||
:param task: The Task that was (re)scheduled
|
||||
:type task: :class: structio.objects.Task
|
||||
:param delay: The delay, in seconds, after which
|
||||
the task will start executing
|
||||
:type delay: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_task_spawn(self, task: Task):
|
||||
"""
|
||||
This method is called when a new task is
|
||||
spawned
|
||||
|
||||
:param task: The Task that was spawned
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_task_exit(self, task: Task):
|
||||
"""
|
||||
This method is called when a task exits
|
||||
|
||||
:param task: The Task that exited
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def before_task_step(self, task: Task):
|
||||
"""
|
||||
This method is called right before
|
||||
calling a task's run() method
|
||||
|
||||
:param task: The Task that is about to run
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def after_task_step(self, task: Task):
|
||||
"""
|
||||
This method is called right after
|
||||
calling a task's run() method
|
||||
|
||||
:param task: The Task that has ran
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def before_sleep(self, task: Task, seconds: float):
|
||||
"""
|
||||
This method is called before a task goes
|
||||
to sleep
|
||||
|
||||
:param task: The Task that is about to sleep
|
||||
:type task: :class: structio.objects.Task
|
||||
:param seconds: The amount of seconds the
|
||||
task wants to sleep for
|
||||
:type seconds: int
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def after_sleep(self, task: Task, seconds: float):
|
||||
"""
|
||||
This method is called after a tasks
|
||||
awakes from sleeping
|
||||
|
||||
:param task: The Task that has just slept
|
||||
:type task: :class: structio.objects.Task
|
||||
:param seconds: The amount of seconds the
|
||||
task slept for
|
||||
:type seconds: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def before_io(self, timeout: float):
|
||||
"""
|
||||
This method is called right before
|
||||
the event loop checks for I/O events
|
||||
|
||||
:param timeout: The max. amount of seconds
|
||||
that the loop will hang for while waiting
|
||||
for I/O events
|
||||
:type timeout: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def after_io(self, timeout: float):
|
||||
"""
|
||||
This method is called right after
|
||||
the event loop has checked for I/O events
|
||||
|
||||
:param timeout: The actual amount of seconds
|
||||
that the loop has hung for while waiting
|
||||
for I/O events
|
||||
:type timeout: float
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def before_cancel(self, task: Task):
|
||||
"""
|
||||
This method is called right before a task
|
||||
gets cancelled
|
||||
|
||||
:param task: The Task that is about to be cancelled
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def after_cancel(self, task: Task) -> object:
|
||||
"""
|
||||
This method is called right after a task
|
||||
gets successfully cancelled
|
||||
|
||||
:param task: The Task that was cancelled
|
||||
:type task: :class: structio.objects.Task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_exception_raised(self, task: Task, exc: BaseException):
|
||||
"""
|
||||
This method is called right after a task
|
||||
has raised an exception
|
||||
|
||||
:param task: The Task that raised the error
|
||||
:type task: :class: structio.objects.Task
|
||||
:param exc: The exception that was raised
|
||||
:type exc: BaseException
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_io_schedule(self, stream, event: int):
|
||||
"""
|
||||
This method is called whenever an
|
||||
I/O resource is scheduled for listening
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def on_io_unschedule(self, stream):
|
||||
"""
|
||||
This method is called whenever a stream
|
||||
is unregistered from the loop's I/O selector
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class BaseIOManager(ABC):
|
||||
"""
|
||||
Base class for all I/O managers
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait_io(self, current_time):
|
||||
"""
|
||||
Waits for I/O and reschedules tasks
|
||||
when data is ready to be read/written
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_read(self, rsc, task: Task):
|
||||
"""
|
||||
"Requests" a read operation on the given
|
||||
resource to the I/O manager from the given
|
||||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def request_write(self, rsc, task: Task):
|
||||
"""
|
||||
"Requests" a write operation on the given
|
||||
resource to the I/O manager from the given
|
||||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def pending(self):
|
||||
"""
|
||||
Returns whether there's any tasks waiting
|
||||
to read from/write to a resource registered
|
||||
in the manager
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def release(self, resource):
|
||||
"""
|
||||
Releases the given async resource from the
|
||||
manager. Note that the resource is *not*
|
||||
closed!
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def release_task(self, task: Task):
|
||||
"""
|
||||
Releases ownership of the given
|
||||
resource from the given task. Note
|
||||
that if the resource is being used by
|
||||
other tasks that this method will
|
||||
not unschedule it for those as well
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_reader(self, rsc):
|
||||
"""
|
||||
Returns the task reading from the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_writer(self, rsc):
|
||||
"""
|
||||
Returns the task writing to the given
|
||||
resource, if any (None otherwise)
|
||||
"""
|
||||
|
||||
|
||||
class SignalManager(ABC):
|
||||
"""
|
||||
A signal manager
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def install(self):
|
||||
"""
|
||||
Installs the signal handler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def uninstall(self):
|
||||
"""
|
||||
Uninstalls the signal handler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class BaseKernel(ABC):
|
||||
"""
|
||||
Abstract kernel base class
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
clock: BaseClock,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
tools: list[BaseDebugger] | None = None,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
):
|
||||
self.clock = clock
|
||||
self.current_task: Task | None = None
|
||||
self.current_pool: "TaskPool" = None
|
||||
self.current_scope: "TaskScope" = None
|
||||
self.tools: list[BaseDebugger] = tools or []
|
||||
self.restrict_ki_to_checkpoints: bool = restrict_ki_to_checkpoints
|
||||
self.io_manager = io_manager
|
||||
self.signal_managers = signal_managers
|
||||
self.entry_point: Task | None = None
|
||||
# Pool for system tasks
|
||||
self.pool: "TaskPool" = None
|
||||
|
||||
@abstractmethod
|
||||
def wait_readable(self, resource: AsyncResource):
|
||||
"""
|
||||
Schedule the given resource for reading from
|
||||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def wait_writable(self, resource: AsyncResource):
|
||||
"""
|
||||
Schedule the given resource for reading from
|
||||
the current task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def release_resource(self, resource: AsyncResource):
|
||||
"""
|
||||
Releases the given resource from the scheduler
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def notify_closing(
|
||||
self, resource: AsyncResource, broken: bool = False, owner: Task | None = None
|
||||
):
|
||||
"""
|
||||
Notifies the event loop that a given resource
|
||||
is about to be closed and can be unscheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def cancel_task(self, task: Task):
|
||||
"""
|
||||
Cancels the given task individually
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def signal_notify(self, sig: int, frame: FrameType):
|
||||
"""
|
||||
Notifies the event loop that a signal was
|
||||
received
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
"""
|
||||
Readies a task for execution. All positional arguments are passed
|
||||
to the given coroutine (for keyword arguments, use functools.partial)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def spawn_system_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
):
|
||||
"""
|
||||
Spawns a system task. System tasks run in a special internal
|
||||
task pool and begin execution in a scope shielded by cancellations
|
||||
and with Ctrl+C protection enabled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def get_closest_deadline(self):
|
||||
"""
|
||||
Returns the closest deadline to be satisfied
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def setup(self):
|
||||
"""
|
||||
This method is called right before startup and can
|
||||
be used by implementors to perform extra setup before
|
||||
starting the event loop
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def teardown(self):
|
||||
"""
|
||||
This method is called right before exiting, even
|
||||
if an error occurred, and can be used by implementors
|
||||
to perform extra cleanup before terminating the event loop
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def throw(self, task: Task, err: BaseException):
|
||||
"""
|
||||
Throws the given exception into the given
|
||||
task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def reschedule(self, task: Task):
|
||||
"""
|
||||
Reschedules the given task for further
|
||||
execution
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def event(self, evt_name, *args):
|
||||
"""
|
||||
Fires the specified event for every registered tool
|
||||
in the event loop
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
"""
|
||||
This is the actual "loop" part
|
||||
of the "event loop"
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def sleep(self, amount):
|
||||
"""
|
||||
Puts the current task to sleep for the given amount of
|
||||
time as defined by our current clock
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def suspend(self):
|
||||
"""
|
||||
Suspends the current task until it is rescheduled
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def init_scope(self, scope):
|
||||
"""
|
||||
Initializes the given task scope (called by
|
||||
TaskScope.__enter__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def close_scope(self, scope):
|
||||
"""
|
||||
Closes the given task scope (called by
|
||||
TaskScope.__exit__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def init_pool(self, pool):
|
||||
"""
|
||||
Initializes the given task pool (called by
|
||||
TaskPool.__aenter__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def close_pool(self, pool):
|
||||
"""
|
||||
Closes the given task pool (called by
|
||||
TaskPool.__aexit__)
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def cancel_scope(self, scope):
|
||||
"""
|
||||
Cancels the given scope
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def start(self, entry_point: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args):
|
||||
"""
|
||||
Starts the event loop from a synchronous entry
|
||||
point. This method only returns once execution
|
||||
has finished. Normally, this method doesn't need
|
||||
to be overridden: consider using setup() and teardown()
|
||||
if you need to do some operations before startup/teardown
|
||||
"""
|
||||
|
||||
self.setup()
|
||||
self.event("on_start")
|
||||
self.current_pool = self.pool
|
||||
self.entry_point = self.spawn(entry_point, *args)
|
||||
self.current_pool.scope.owner = self.entry_point
|
||||
self.entry_point.pool = self.current_pool
|
||||
self.current_pool.entry_point = self.entry_point
|
||||
self.current_scope = self.current_pool.scope
|
||||
try:
|
||||
self.run()
|
||||
finally:
|
||||
self.teardown()
|
||||
self.close(force=True)
|
||||
if self.entry_point.exc:
|
||||
raise self.entry_point.exc
|
||||
self.event("on_exit")
|
||||
return self.entry_point.result
|
||||
|
||||
@abstractmethod
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the loop has work to do
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
def close(self, force: bool = False):
|
||||
"""
|
||||
Terminates and shuts down the event loop
|
||||
This method is meant to be extended by
|
||||
implementations to do their own cleanup
|
||||
|
||||
:param force: When force equals false,
|
||||
the default, and the event loop is
|
||||
not done, this function raises a
|
||||
StructIOException
|
||||
"""
|
||||
|
||||
if not self.done() and not force:
|
||||
raise StructIOException("the event loop is running")
|
|
@ -0,0 +1,195 @@
|
|||
import structio
|
||||
from structio.core.task import Task
|
||||
from structio.core.run import current_loop
|
||||
from typing import Callable, Coroutine, Any
|
||||
from structio.core.time.queue import TimeQueue
|
||||
from structio.core.syscalls import suspend, checkpoint
|
||||
from structio.exceptions import Cancelled, StructIOException
|
||||
|
||||
|
||||
class TaskScope:
|
||||
"""
|
||||
A task scope
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int | float | None = None,
|
||||
silent: bool = False,
|
||||
shielded: bool = False,
|
||||
):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
# When do we expire?
|
||||
self._timeout = timeout or float("inf")
|
||||
# This is updated with the actual wall clock
|
||||
# time of when we expire each time that the
|
||||
# timeout is modified
|
||||
self.deadline = -1
|
||||
# Do we raise an error on timeout?
|
||||
self.silent = silent
|
||||
# Has a cancellation attempt been done?
|
||||
self.attempted_cancel: bool = False
|
||||
# Have we been cancelled?
|
||||
self.cancelled: bool = False
|
||||
# Have we timed out?
|
||||
self.timed_out: bool = False
|
||||
# Can we be indirectly cancelled? Note that this
|
||||
# does not affect explicit cancellations via the
|
||||
# cancel() method
|
||||
self.shielded: bool = shielded
|
||||
# Data about inner and outer scopes.
|
||||
# This is used internally to make sure
|
||||
# nesting task scopes works as expected
|
||||
self.inner: list[TaskScope] = []
|
||||
self.outer: TaskScope | None = None
|
||||
# Which tasks do we contain?
|
||||
self.tasks: list[Task] = []
|
||||
self.owner: Task | None = None
|
||||
|
||||
@property
|
||||
def timeout(self):
|
||||
return self._timeout
|
||||
|
||||
@timeout.setter
|
||||
def timeout(self, value):
|
||||
self._timeout = value
|
||||
self.deadline = current_loop().clock.deadline(self.timeout)
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Cancels the task scope and all the work
|
||||
that belongs to it
|
||||
"""
|
||||
|
||||
current_loop().cancel_scope(self)
|
||||
|
||||
def get_effective_deadline(self) -> tuple[float, "TaskScope"]:
|
||||
"""
|
||||
Returns the effective deadline of the whole
|
||||
cancel scope
|
||||
"""
|
||||
|
||||
queue = TimeQueue()
|
||||
if self.shielded:
|
||||
return float("inf"), self
|
||||
times = queue.put(self, self.deadline)
|
||||
for child in self.children:
|
||||
if child.shielded:
|
||||
return float("inf"), self
|
||||
deadline, scope = child.get_effective_deadline()
|
||||
queue.put(scope, deadline)
|
||||
return queue.get_closest_deadline(), queue.get()[0]
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(owner={self.owner}, timeout={self.timeout})"
|
||||
|
||||
def __enter__(self):
|
||||
current_loop().init_scope(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
||||
current_loop().close_scope(self)
|
||||
if exc_val and isinstance(exc_val, structio.TimedOut):
|
||||
if exc_val.scope is self:
|
||||
return self.silent
|
||||
return True
|
||||
return False
|
||||
|
||||
# Just a recursive helper
|
||||
def _get_children(self, lst=None):
|
||||
if lst is None:
|
||||
lst = []
|
||||
for child in self.inner:
|
||||
lst.append(child)
|
||||
child._get_children(lst)
|
||||
return lst
|
||||
|
||||
@property
|
||||
def children(self) -> list["TaskScope"]:
|
||||
"""
|
||||
Gets all the scopes contained within this one
|
||||
"""
|
||||
|
||||
return self._get_children()
|
||||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task scope has finished executing
|
||||
"""
|
||||
|
||||
if not all(child.done() for child in self.children):
|
||||
return False
|
||||
return all(task.done() for task in self.tasks)
|
||||
|
||||
|
||||
class TaskPool:
|
||||
"""
|
||||
A task pool
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
self.entry_point: Task | None = None
|
||||
self.scope: TaskScope = TaskScope(timeout=float("inf"))
|
||||
# This pool's parent
|
||||
self.outer: TaskPool | None = None
|
||||
# Have we errored out?
|
||||
self.error: BaseException | None = None
|
||||
# Have we exited? This is so we can forbid reuse of
|
||||
# dead task pools
|
||||
self._closed: bool = False
|
||||
|
||||
async def __aenter__(self):
|
||||
if self._closed:
|
||||
raise StructIOException("task pool is closed")
|
||||
self.scope.__enter__()
|
||||
current_loop().init_pool(self)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type, exc_val: BaseException, exc_tb):
|
||||
try:
|
||||
if exc_val:
|
||||
await checkpoint()
|
||||
raise exc_val.with_traceback(exc_tb)
|
||||
elif not self.done():
|
||||
await suspend()
|
||||
else:
|
||||
await checkpoint()
|
||||
except Cancelled as e:
|
||||
self.error = e
|
||||
self.scope.cancelled = True
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.error = e
|
||||
self.scope.cancel()
|
||||
finally:
|
||||
current_loop().close_pool(self)
|
||||
self._closed = True
|
||||
if self.error:
|
||||
raise self.error
|
||||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task pool has finished executing
|
||||
"""
|
||||
|
||||
return self.scope.done()
|
||||
|
||||
def spawn(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
) -> Task:
|
||||
"""
|
||||
Schedule a new concurrent task for execution in the task pool from the given
|
||||
async function. All positional arguments are passed to the underlying coroutine
|
||||
(for keyword arguments, consider using functools.partial). A Task object is
|
||||
returned. Note that the coroutine is merely scheduled to run and does not begin
|
||||
executing until it is picked by the scheduler later on
|
||||
"""
|
||||
|
||||
self.scope.tasks.append(current_loop().spawn(func, *args))
|
||||
return self.scope.tasks[-1]
|
|
@ -0,0 +1,477 @@
|
|||
import traceback
|
||||
import warnings
|
||||
from types import FrameType
|
||||
from structio.abc import (
|
||||
BaseKernel,
|
||||
BaseClock,
|
||||
BaseDebugger,
|
||||
BaseIOManager,
|
||||
SignalManager,
|
||||
)
|
||||
from structio.io import FdWrapper
|
||||
from structio.core.context import TaskPool, TaskScope
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||
from structio.core.time.queue import TimeQueue
|
||||
from structio.exceptions import (
|
||||
StructIOException,
|
||||
Cancelled,
|
||||
TimedOut,
|
||||
ResourceClosed,
|
||||
ResourceBroken,
|
||||
)
|
||||
from collections import deque
|
||||
from typing import Callable, Coroutine, Any
|
||||
from functools import partial
|
||||
import signal
|
||||
import sniffio
|
||||
|
||||
|
||||
class FIFOKernel(BaseKernel):
|
||||
"""
|
||||
An asynchronous event loop implementation
|
||||
with a FIFO scheduling policy
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
clock: BaseClock,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
tools: list[BaseDebugger] | None = None,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
clock, io_manager, signal_managers, tools, restrict_ki_to_checkpoints
|
||||
)
|
||||
# Tasks that are ready to run
|
||||
self.run_queue: deque[Task] = deque()
|
||||
# Data to send back to tasks
|
||||
self.data: dict[Task, Any] = {}
|
||||
# Have we handled SIGINT?
|
||||
self._sigint_handled: bool = False
|
||||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
self.pool = TaskPool()
|
||||
self.pool.scope.shielded = True
|
||||
self.current_scope = self.pool.scope
|
||||
self.current_scope.shielded = False
|
||||
|
||||
def get_closest_deadline(self):
|
||||
if self.run_queue:
|
||||
# We absolutely cannot block while other
|
||||
# tasks have things to do!
|
||||
return self.clock.current_time()
|
||||
deadlines = []
|
||||
for scope in self.pool.scope.children:
|
||||
deadlines.append(scope.get_effective_deadline()[0])
|
||||
if not deadlines:
|
||||
deadlines.append(float("inf"))
|
||||
return min(
|
||||
[
|
||||
min(deadlines),
|
||||
self.paused.get_closest_deadline(),
|
||||
]
|
||||
)
|
||||
|
||||
def wait_readable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.io_manager.request_read(resource, self.current_task)
|
||||
|
||||
def wait_writable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.io_manager.request_write(resource, self.current_task)
|
||||
|
||||
def notify_closing(
|
||||
self, resource: FdWrapper, broken: bool = False, owner: Task | None = None
|
||||
):
|
||||
if not broken:
|
||||
exc = ResourceClosed("stream has been closed")
|
||||
else:
|
||||
exc = ResourceBroken("stream might be corrupted")
|
||||
owner = owner or self.current_task
|
||||
reader = self.io_manager.get_reader(resource)
|
||||
writer = self.io_manager.get_writer(resource)
|
||||
if reader and reader is not owner:
|
||||
self.throw(reader, exc)
|
||||
if writer and writer is not owner:
|
||||
self.throw(writer, exc)
|
||||
self.reschedule_running()
|
||||
|
||||
def event(self, evt_name: str, *args):
|
||||
if not hasattr(BaseDebugger, evt_name):
|
||||
warnings.warn(f"Invalid debugging event fired: {evt_name!r}")
|
||||
return
|
||||
for tool in self.tools:
|
||||
if f := getattr(tool, evt_name, None):
|
||||
try:
|
||||
f(*args)
|
||||
except BaseException as e:
|
||||
# We really can't afford to have our internals explode,
|
||||
# sorry!
|
||||
warnings.warn(
|
||||
f"Exception during debugging event delivery in {f!r} ({evt_name!r}): {type(e).__name__} -> {e}",
|
||||
)
|
||||
traceback.print_tb(e.__traceback__)
|
||||
# We disable the tool, so it can't raise at the next debugging
|
||||
# event
|
||||
self.tools.remove(tool)
|
||||
|
||||
def done(self):
|
||||
if self.entry_point.done():
|
||||
return True
|
||||
if any([self.run_queue, self.paused, self.io_manager.pending()]):
|
||||
return False
|
||||
if not self.pool.done():
|
||||
return False
|
||||
return True
|
||||
|
||||
def spawn(
|
||||
self,
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
*args,
|
||||
ki_protected: bool = False,
|
||||
pool: TaskPool = None,
|
||||
):
|
||||
if isinstance(func, partial):
|
||||
name = func.func.__name__ or repr(func.func)
|
||||
else:
|
||||
name = func.__name__ or repr(func)
|
||||
if pool is None:
|
||||
pool = self.current_pool
|
||||
task = Task(name, func(*args), pool)
|
||||
# We inject our magic secret variable into the coroutine's stack frame, so
|
||||
# we can look it up later
|
||||
task.coroutine.cr_frame.f_locals.setdefault(
|
||||
CTRLC_PROTECTION_ENABLED, ki_protected
|
||||
)
|
||||
self.run_queue.append(task)
|
||||
self.event("on_task_spawn")
|
||||
return task
|
||||
|
||||
def spawn_system_task(
|
||||
self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args
|
||||
):
|
||||
return self.spawn(func, *args, ki_protected=True, pool=self.pool)
|
||||
|
||||
def signal_notify(self, sig: int, frame: FrameType):
|
||||
match sig:
|
||||
case signal.SIGINT:
|
||||
self._sigint_handled = True
|
||||
# Poke the event loop with a stick ;)
|
||||
self.run_queue.append(self.entry_point)
|
||||
case _:
|
||||
pass
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Run a single task step (i.e. until an "await" to our
|
||||
primitives somewhere)
|
||||
"""
|
||||
|
||||
self.current_task = self.run_queue.popleft()
|
||||
while self.current_task.done():
|
||||
if not self.run_queue:
|
||||
return
|
||||
self.current_task = self.run_queue.popleft()
|
||||
runner = partial(
|
||||
self.current_task.coroutine.send, self.data.pop(self.current_task, None)
|
||||
)
|
||||
if self.current_task.pending_cancellation:
|
||||
runner = partial(self.current_task.coroutine.throw, Cancelled())
|
||||
elif self._sigint_handled:
|
||||
self._sigint_handled = False
|
||||
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
|
||||
self.event("before_task_step", self.current_task)
|
||||
self.current_task.state = TaskState.RUNNING
|
||||
self.current_task.paused_when = 0
|
||||
self.current_pool = self.current_task.pool
|
||||
self.current_scope = self.current_pool.scope
|
||||
method, args, kwargs = runner()
|
||||
self.current_task.state = TaskState.PAUSED
|
||||
self.current_task.paused_when = self.clock.current_time()
|
||||
if not callable(getattr(self, method, None)):
|
||||
# This if block is meant to be triggered by other async
|
||||
# libraries, which most likely have different method names and behaviors
|
||||
# compared to us. If you get this exception, and you're 100% sure you're
|
||||
# not mixing async primitives from other libraries, then it's a bug!
|
||||
self.throw(
|
||||
self.current_task,
|
||||
StructIOException(
|
||||
"Uh oh! Something bad just happened: did you try to mix "
|
||||
"primitives from other async libraries?"
|
||||
),
|
||||
)
|
||||
# Sneaky method call, thanks to David Beazley for this ;)
|
||||
getattr(self, method)(*args, **kwargs)
|
||||
self.event("after_task_step", self.current_task)
|
||||
|
||||
def throw(self, task: Task, err: BaseException):
|
||||
if task.done():
|
||||
return
|
||||
self.release(task)
|
||||
self.handle_errors(partial(task.coroutine.throw, err), task)
|
||||
|
||||
def reschedule(self, task: Task):
|
||||
if task.done():
|
||||
return
|
||||
self.run_queue.append(task)
|
||||
|
||||
def check_cancelled(self):
|
||||
if self._sigint_handled:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
elif self.current_task.pending_cancellation:
|
||||
self.cancel_task(self.current_task)
|
||||
else:
|
||||
# We reschedule the caller immediately!
|
||||
self.run_queue.appendleft(self.current_task)
|
||||
|
||||
def schedule_point(self):
|
||||
self.reschedule_running()
|
||||
|
||||
def sleep(self, amount):
|
||||
"""
|
||||
Puts the current task to sleep for the given amount of
|
||||
time as defined by our current clock
|
||||
"""
|
||||
|
||||
# Just to avoid code duplication, you know
|
||||
self.suspend()
|
||||
if amount > 0:
|
||||
self.event("before_sleep", self.current_task, amount)
|
||||
self.current_task.next_deadline = self.clock.deadline(amount)
|
||||
self.paused.put(self.current_task, self.clock.deadline(amount))
|
||||
else:
|
||||
# If sleep is called with 0 as argument,
|
||||
# then it's just a checkpoint!
|
||||
self.schedule_point()
|
||||
self.check_cancelled()
|
||||
|
||||
def check_scopes(self):
|
||||
expired = set()
|
||||
for scope in self.pool.scope.children:
|
||||
deadline, actual = scope.get_effective_deadline()
|
||||
if deadline <= self.clock.current_time() and not actual.timed_out:
|
||||
expired.add(actual)
|
||||
for scope in expired:
|
||||
scope.timed_out = True
|
||||
error = TimedOut("timed out")
|
||||
error.scope = scope
|
||||
self.throw(scope.owner, error)
|
||||
self.reschedule(scope.owner)
|
||||
|
||||
def wakeup(self):
|
||||
while (
|
||||
self.paused
|
||||
and self.paused.peek().next_deadline <= self.clock.current_time()
|
||||
):
|
||||
task, _ = self.paused.get()
|
||||
task.next_deadline = 0
|
||||
self.event(
|
||||
"after_sleep", task, task.paused_when - self.clock.current_time()
|
||||
)
|
||||
self.reschedule(task)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
This is the actual "loop" part
|
||||
of the "event loop"
|
||||
"""
|
||||
|
||||
while not self.done():
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
if self.run_queue:
|
||||
self.handle_errors(self.step)
|
||||
self.wakeup()
|
||||
self.check_scopes()
|
||||
if self.io_manager.pending():
|
||||
self.io_manager.wait_io(self.clock.current_time())
|
||||
self.close()
|
||||
|
||||
def reschedule_running(self):
|
||||
"""
|
||||
Reschedules the currently running task
|
||||
"""
|
||||
|
||||
self.reschedule(self.current_task)
|
||||
|
||||
def handle_errors(self, func: Callable, task: Task | None = None):
|
||||
"""
|
||||
Convenience method for handling various exceptions
|
||||
from tasks
|
||||
"""
|
||||
|
||||
old_name, sniffio.thread_local.name = sniffio.thread_local.name, "structured-io"
|
||||
try:
|
||||
func()
|
||||
except StopIteration as ret:
|
||||
# We re-define it because we call step() with
|
||||
# this method and that changes the current task
|
||||
task = task or self.current_task
|
||||
# At the end of the day, coroutines are generator functions with
|
||||
# some tricky behaviors, and this is one of them. When a coroutine
|
||||
# hits a return statement (either explicit or implicit), it raises
|
||||
# a StopIteration exception, which has an attribute named value that
|
||||
# represents the return value of the coroutine, if it has one. Of course
|
||||
# this exception is not an error, and we should happily keep going after it:
|
||||
# most of this code below is just useful for internal/debugging purposes
|
||||
task.state = TaskState.FINISHED
|
||||
task.result = ret.value
|
||||
self.on_success(task)
|
||||
self.event("on_task_exit", task)
|
||||
except Cancelled:
|
||||
# When a task needs to be cancelled, we try to do it gracefully first:
|
||||
# if the task is paused in either I/O or sleeping, that's perfect.
|
||||
# But we also need to cancel a task if it was not sleeping or waiting on
|
||||
# any I/O because it could never do so (therefore blocking everything
|
||||
# forever). So, when cancellation can't be done right away, we schedule
|
||||
# it for the next execution step of the task. We will also make sure
|
||||
# to re-raise cancellations at every checkpoint until the task lets the
|
||||
# exception propagate into us, because we *really* want the task to be
|
||||
# cancelled
|
||||
task = task or self.current_task
|
||||
task.state = TaskState.CANCELLED
|
||||
task.pending_cancellation = False
|
||||
self.on_cancel(task)
|
||||
self.event("after_cancel", task)
|
||||
except (Exception, KeyboardInterrupt) as err:
|
||||
# Any other exception is caught here
|
||||
task = task or self.current_task
|
||||
task.exc = err
|
||||
err.scope = task.pool.scope
|
||||
task.state = TaskState.CRASHED
|
||||
self.on_error(task)
|
||||
self.event("on_exception_raised", task)
|
||||
finally:
|
||||
sniffio.thread_local.name = old_name
|
||||
|
||||
def release_resource(self, resource: FdWrapper):
|
||||
self.io_manager.release(resource)
|
||||
self.reschedule_running()
|
||||
|
||||
def release(self, task: Task):
|
||||
"""
|
||||
Releases the timeouts and associated
|
||||
I/O resourced that the given task owns
|
||||
"""
|
||||
|
||||
self.io_manager.release_task(task)
|
||||
self.paused.discard(task)
|
||||
|
||||
def on_success(self, task: Task):
|
||||
"""
|
||||
The given task has exited gracefully: hooray!
|
||||
"""
|
||||
|
||||
assert task.state == TaskState.FINISHED
|
||||
# Walk up the scope tree and reschedule all necessary
|
||||
# tasks
|
||||
scope = task.pool.scope
|
||||
while scope.done() and scope is not self.pool.scope:
|
||||
if scope.done():
|
||||
self.reschedule(scope.owner)
|
||||
scope = scope.outer
|
||||
self.event("on_task_exit", task)
|
||||
self.release(task)
|
||||
|
||||
def on_error(self, task: Task):
|
||||
"""
|
||||
The given task raised an exception
|
||||
"""
|
||||
|
||||
assert task.state == TaskState.CRASHED
|
||||
self.event("on_exception_raised", task, task.exc)
|
||||
scope = task.pool.scope
|
||||
if task is not scope.owner:
|
||||
self.reschedule(scope.owner)
|
||||
self.throw(scope.owner, task.exc)
|
||||
self.release(task)
|
||||
|
||||
def on_cancel(self, task: Task):
|
||||
"""
|
||||
The given task crashed because of a
|
||||
cancellation exception
|
||||
"""
|
||||
|
||||
assert task.state == TaskState.CANCELLED
|
||||
self.event("after_cancel", task)
|
||||
self.release(task)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
scope.deadline = self.clock.deadline(scope.timeout)
|
||||
scope.owner = self.current_task
|
||||
self.current_scope.inner.append(scope)
|
||||
scope.outer = self.current_scope
|
||||
self.current_scope = scope
|
||||
|
||||
def close_scope(self, scope: TaskScope):
|
||||
self.current_scope = scope.outer
|
||||
self.current_scope.inner = []
|
||||
|
||||
def cancel_task(self, task: Task):
|
||||
if task.done():
|
||||
return
|
||||
if task.state == TaskState.RUNNING:
|
||||
# Can't cancel a task while it's
|
||||
# running, will raise ValueError
|
||||
# if we try. We defer it for later
|
||||
task.pending_cancellation = True
|
||||
return
|
||||
err = Cancelled()
|
||||
err.scope = task.pool.scope
|
||||
self.throw(task, err)
|
||||
if task.state != TaskState.CANCELLED:
|
||||
# Task is stubborn. But so are we,
|
||||
# so we'll redeliver the cancellation
|
||||
# every time said task tries to call any
|
||||
# event loop primitive
|
||||
task.pending_cancellation = True
|
||||
|
||||
def cancel_scope(self, scope: TaskScope):
|
||||
scope.attempted_cancel = True
|
||||
# We can't just immediately cancel the
|
||||
# current task because this method is
|
||||
# called synchronously by TaskScope.cancel(),
|
||||
# so there is nowhere to throw an exception
|
||||
# to
|
||||
if self.current_task in scope.tasks and self.current_task is not scope.owner:
|
||||
self.current_task.pending_cancellation = True
|
||||
for child in filter(lambda c: not c.shielded, scope.children):
|
||||
self.cancel_scope(child)
|
||||
for task in scope.tasks:
|
||||
if task is self.current_task:
|
||||
continue
|
||||
self.cancel_task(task)
|
||||
if (
|
||||
scope is not self.current_task.pool.scope
|
||||
and scope.owner is not self.current_task
|
||||
and scope.owner is not self.entry_point
|
||||
):
|
||||
# Handles the case where the current task calls
|
||||
# cancel() for a scope which it doesn't own, which
|
||||
# is an entirely reasonable thing to do
|
||||
self.cancel_task(scope.owner)
|
||||
if scope.done():
|
||||
scope.cancelled = True
|
||||
|
||||
def init_pool(self, pool: TaskPool):
|
||||
pool.outer = self.current_pool
|
||||
pool.entry_point = self.current_task
|
||||
self.current_pool = pool
|
||||
|
||||
def close_pool(self, pool: TaskPool):
|
||||
self.current_pool = pool.outer
|
||||
self.close_scope(pool.scope)
|
||||
|
||||
def suspend(self):
|
||||
self.current_task.state = TaskState.PAUSED
|
||||
self.current_task.paused_when = self.clock.current_time()
|
||||
|
||||
def setup(self):
|
||||
for manager in self.signal_managers:
|
||||
manager.install()
|
||||
|
||||
def teardown(self):
|
||||
for manager in self.signal_managers:
|
||||
manager.uninstall()
|
|
@ -0,0 +1,109 @@
|
|||
from structio.abc import BaseIOManager, BaseKernel
|
||||
from structio.core.task import Task, TaskState
|
||||
from structio.util.ki import CTRLC_PROTECTION_ENABLED
|
||||
from structio.core.run import current_loop, current_task
|
||||
from structio.io import FdWrapper
|
||||
import select
|
||||
import signal
|
||||
|
||||
|
||||
class SimpleIOManager(BaseIOManager):
|
||||
"""
|
||||
A simple, cross-platform, select()-based
|
||||
I/O manager. This class is only meant to
|
||||
be used as a default fallback and is quite
|
||||
inefficient and slower compared to more ad-hoc
|
||||
alternatives such as epoll or kqueue (it should
|
||||
work on most platforms though)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
# Maps resources to tasks
|
||||
self.readers: dict[FdWrapper, Task] = {}
|
||||
self.writers: dict[FdWrapper, Task] = {}
|
||||
|
||||
def pending(self) -> bool:
|
||||
return bool(self.readers or self.writers)
|
||||
|
||||
def get_reader(self, rsc: FdWrapper):
|
||||
return self.readers.get(rsc)
|
||||
|
||||
def get_writer(self, rsc: FdWrapper):
|
||||
return self.writers.get(rsc)
|
||||
|
||||
def _collect_readers(self) -> list[int]:
|
||||
"""
|
||||
Collects all resources that need to be read from,
|
||||
so we can select() on them later
|
||||
"""
|
||||
|
||||
result = []
|
||||
for reader in self.readers:
|
||||
result.append(reader.fileno())
|
||||
return result
|
||||
|
||||
def _collect_writers(self) -> list[int]:
|
||||
"""
|
||||
Collects all resources that need to be written to,
|
||||
so we can select() on them later
|
||||
"""
|
||||
|
||||
result = []
|
||||
for writer in self.writers:
|
||||
result.append(writer.fileno())
|
||||
return result
|
||||
|
||||
def wait_io(self, current_time):
|
||||
kernel: BaseKernel = current_loop()
|
||||
deadline = kernel.get_closest_deadline()
|
||||
if deadline == float("inf"):
|
||||
deadline = 0
|
||||
elif deadline > 0:
|
||||
deadline -= current_time
|
||||
deadline = max(0, deadline)
|
||||
writers = self._collect_writers()
|
||||
readable, writable, exceptional = select.select(
|
||||
self._collect_readers(),
|
||||
writers,
|
||||
writers,
|
||||
deadline,
|
||||
)
|
||||
# On Windows, a successful connection is marked
|
||||
# as an exceptional event rather than a write
|
||||
# one
|
||||
writable.extend(exceptional)
|
||||
del exceptional
|
||||
for read_ready in readable:
|
||||
for resource, task in self.readers.copy().items():
|
||||
if resource.fileno() == read_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.readers.pop(resource)
|
||||
for write_ready in writable:
|
||||
for resource, task in self.writers.copy().items():
|
||||
if resource.fileno() == write_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.writers.pop(resource)
|
||||
|
||||
def request_read(self, rsc: FdWrapper, task: Task):
|
||||
current_task().state = TaskState.IO
|
||||
self.readers[rsc] = task
|
||||
|
||||
def request_write(self, rsc: FdWrapper, task: Task):
|
||||
current_task().state = TaskState.IO
|
||||
self.writers[rsc] = task
|
||||
|
||||
def release(self, resource: FdWrapper):
|
||||
self.readers.pop(resource, None)
|
||||
self.writers.pop(resource, None)
|
||||
|
||||
def release_task(self, task: Task):
|
||||
for resource, owner in self.readers.copy().items():
|
||||
if owner == task:
|
||||
self.readers.pop(resource)
|
||||
for resource, owner in self.writers.copy().items():
|
||||
if owner == task:
|
||||
self.writers.pop(resource)
|
|
@ -0,0 +1,39 @@
|
|||
from structio.abc import SignalManager
|
||||
from structio.util.ki import currently_protected
|
||||
from structio.signals import set_signal_handler
|
||||
from structio.core.run import current_loop
|
||||
from types import FrameType
|
||||
import warnings
|
||||
import signal
|
||||
|
||||
|
||||
class SigIntManager(SignalManager):
|
||||
"""
|
||||
Handles Ctrl+C
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.installed = False
|
||||
|
||||
@staticmethod
|
||||
async def _handle(sig: int, frame: FrameType):
|
||||
if currently_protected():
|
||||
current_loop().signal_notify(sig, frame)
|
||||
else:
|
||||
current_loop().reschedule(current_loop().entry_point)
|
||||
current_loop().throw(current_loop().entry_point, KeyboardInterrupt())
|
||||
|
||||
def install(self):
|
||||
if signal.getsignal(signal.SIGINT) != signal.default_int_handler:
|
||||
warnings.warn(
|
||||
f"structio has detected a custom SIGINT handler and won't touch it: keep in mind"
|
||||
f" this is likely to break KeyboardInterrupt delivery!"
|
||||
)
|
||||
return
|
||||
set_signal_handler(signal.SIGINT, self._handle)
|
||||
self.installed = True
|
||||
|
||||
def uninstall(self):
|
||||
if self.installed:
|
||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
self.installed = False
|
|
@ -0,0 +1,104 @@
|
|||
import inspect
|
||||
import structio
|
||||
import functools
|
||||
from threading import local
|
||||
from structio.abc import (
|
||||
BaseKernel,
|
||||
BaseDebugger,
|
||||
BaseClock,
|
||||
SignalManager,
|
||||
BaseIOManager,
|
||||
)
|
||||
from structio.exceptions import StructIOException
|
||||
from structio.core.task import Task
|
||||
from typing import Callable, Any, Coroutine
|
||||
|
||||
_RUN = local()
|
||||
|
||||
|
||||
def current_loop() -> BaseKernel:
|
||||
"""
|
||||
Returns the current event loop in the calling
|
||||
thread. Raises a StructIOException if no async
|
||||
context exists
|
||||
"""
|
||||
try:
|
||||
return _RUN.kernel
|
||||
except AttributeError:
|
||||
raise StructIOException("must be called from async context") from None
|
||||
|
||||
|
||||
def current_task() -> Task:
|
||||
"""
|
||||
Shorthand for current_loop().current_task
|
||||
"""
|
||||
|
||||
return current_loop().current_task
|
||||
|
||||
|
||||
def new_event_loop(kernel: BaseKernel):
|
||||
"""
|
||||
Initializes a new event loop using the
|
||||
given kernel implementation. Cannot be
|
||||
called from an asynchronous context
|
||||
"""
|
||||
|
||||
try:
|
||||
current_loop()
|
||||
except StructIOException:
|
||||
_RUN.kernel = kernel
|
||||
else:
|
||||
if not _RUN.kernel.done():
|
||||
raise StructIOException(
|
||||
"cannot be called from running async context"
|
||||
) from None
|
||||
_RUN.kernel = kernel
|
||||
|
||||
|
||||
def run(
|
||||
func: Callable[[Any, Any], Coroutine[Any, Any, Any]],
|
||||
kernel: type,
|
||||
io_manager: BaseIOManager,
|
||||
signal_managers: list[SignalManager],
|
||||
clock: BaseClock,
|
||||
tools: list[BaseDebugger] | None = None,
|
||||
restrict_ki_to_checkpoints: bool = False,
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
Starts the event loop from a synchronous entry point. All
|
||||
positional arguments are passed to the given coroutine
|
||||
function. If you want to pass keyword arguments, consider
|
||||
using functools.partial()
|
||||
"""
|
||||
|
||||
if not issubclass(kernel, BaseKernel):
|
||||
raise TypeError(
|
||||
f"kernel must be a subclass of {BaseKernel.__module__}.{BaseKernel.__qualname__}!"
|
||||
)
|
||||
check = func
|
||||
if isinstance(func, functools.partial):
|
||||
check = func.func
|
||||
if inspect.iscoroutine(check):
|
||||
raise StructIOException(
|
||||
"Looks like you tried to call structio.run(your_func(arg1, arg2, ...)), that is wrong!"
|
||||
"\nWhat you wanna do, instead, is this: structio.run(your_func, arg1, arg2, ...)"
|
||||
)
|
||||
elif not inspect.iscoroutinefunction(check):
|
||||
raise StructIOException(
|
||||
"structio.run() requires an async function as its first argument!"
|
||||
)
|
||||
waker = structio.util.wakeup_fd.WakeupFd()
|
||||
watcher = structio.signals.signal_watcher
|
||||
waker.set_wakeup_fd()
|
||||
new_event_loop(
|
||||
kernel(
|
||||
clock=clock,
|
||||
restrict_ki_to_checkpoints=restrict_ki_to_checkpoints,
|
||||
io_manager=io_manager,
|
||||
signal_managers=signal_managers,
|
||||
tools=tools,
|
||||
)
|
||||
)
|
||||
current_loop().spawn_system_task(watcher, waker.reader)
|
||||
return current_loop().start(func, *args)
|
|
@ -0,0 +1,86 @@
|
|||
from types import coroutine
|
||||
from typing import Any
|
||||
|
||||
|
||||
@coroutine
|
||||
def syscall(method: str, *args, **kwargs) -> Any | None:
|
||||
"""
|
||||
Lowest-level primitive to interact with the event loop:
|
||||
calls a loop method with the provided arguments. This
|
||||
function should not be used directly, but through abstraction
|
||||
layers. All positional and keyword arguments are passed to
|
||||
the method itself and its return value is provided once the
|
||||
loop yields control back to us
|
||||
|
||||
:param method: The loop method to call
|
||||
:type method: str
|
||||
:returns: The result of the method call, if any
|
||||
"""
|
||||
|
||||
result = yield method, args, kwargs
|
||||
return result
|
||||
|
||||
|
||||
async def sleep(amount):
|
||||
"""
|
||||
Puts the caller asleep for the given amount of
|
||||
time which is, by default, measured in seconds,
|
||||
although this is not enforced: if a custom clock
|
||||
implementation is being used, the values passed
|
||||
to this function may have a different meaning
|
||||
"""
|
||||
|
||||
await syscall("sleep", amount)
|
||||
|
||||
|
||||
async def suspend():
|
||||
"""
|
||||
Pauses the caller indefinitely
|
||||
until it is rescheduled
|
||||
"""
|
||||
|
||||
await syscall("suspend")
|
||||
|
||||
|
||||
async def check_cancelled():
|
||||
"""
|
||||
Introduce a cancellation point, but
|
||||
not a schedule point
|
||||
"""
|
||||
|
||||
return await syscall("check_cancelled")
|
||||
|
||||
|
||||
async def schedule_point():
|
||||
"""
|
||||
Introduce a schedule point, but not a
|
||||
cancellation point
|
||||
"""
|
||||
|
||||
return await syscall("schedule_point")
|
||||
|
||||
|
||||
async def checkpoint():
|
||||
"""
|
||||
Introduce a cancellation point and a
|
||||
schedule point
|
||||
"""
|
||||
|
||||
await check_cancelled()
|
||||
await schedule_point()
|
||||
|
||||
|
||||
async def wait_readable(rsc):
|
||||
return await syscall("wait_readable", rsc)
|
||||
|
||||
|
||||
async def wait_writable(rsc):
|
||||
return await syscall("wait_writable", rsc)
|
||||
|
||||
|
||||
async def closing(rsc):
|
||||
return await syscall("notify_closing", rsc)
|
||||
|
||||
|
||||
async def release(rsc):
|
||||
return await syscall("release_resource", rsc)
|
|
@ -0,0 +1,81 @@
|
|||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Coroutine, Any, Callable
|
||||
from itertools import count
|
||||
|
||||
_counter = count()
|
||||
|
||||
|
||||
class TaskState(Enum):
|
||||
INIT: int = auto()
|
||||
RUNNING: int = auto()
|
||||
PAUSED: int = auto()
|
||||
FINISHED: int = auto()
|
||||
CRASHED: int = auto()
|
||||
CANCELLED: int = auto()
|
||||
IO: int = auto()
|
||||
|
||||
|
||||
_joiner: Callable[[Any, Any], Coroutine[Any, Any, Any]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
An asynchronous task wrapper
|
||||
"""
|
||||
|
||||
# The task's name
|
||||
name: str
|
||||
# The underlying coroutine of this
|
||||
# task
|
||||
coroutine: Coroutine = field(repr=False)
|
||||
# The task's pool
|
||||
pool: "TaskPool" = field(repr=False)
|
||||
# The state of the task
|
||||
state: TaskState = field(default=TaskState.INIT)
|
||||
# Used for debugging
|
||||
id: int = field(default_factory=lambda: next(_counter))
|
||||
# What error did the task raise, if any?
|
||||
exc: BaseException | None = None
|
||||
# The task's return value, if any
|
||||
result: Any | None = None
|
||||
# When did the task pause last time?
|
||||
paused_when: Any = -1
|
||||
# When is the task's next deadline?
|
||||
next_deadline: Any = -1
|
||||
# Is cancellation pending?
|
||||
pending_cancellation: bool = False
|
||||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task is running
|
||||
"""
|
||||
|
||||
return self.state in [
|
||||
TaskState.CRASHED,
|
||||
TaskState.FINISHED,
|
||||
TaskState.CANCELLED,
|
||||
]
|
||||
|
||||
def __hash__(self):
|
||||
"""
|
||||
Implements hash(self)
|
||||
"""
|
||||
|
||||
return self.coroutine.__hash__()
|
||||
|
||||
# These are patched later at import time!
|
||||
def __await__(self):
|
||||
"""
|
||||
Wait for the task to complete and return/raise appropriately (returns when cancelled)
|
||||
"""
|
||||
|
||||
return _joiner(self).__await__()
|
||||
|
||||
def cancel(self):
|
||||
"""
|
||||
Cancels the given task
|
||||
"""
|
||||
|
||||
return NotImplemented
|
|
@ -0,0 +1,27 @@
|
|||
import random
|
||||
from timeit import default_timer
|
||||
from structio.abc import BaseClock
|
||||
|
||||
|
||||
class DefaultClock(BaseClock):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# We add a large random offset to our timer value
|
||||
# so users notice the problem if they try to compare
|
||||
# them across different runs
|
||||
self.offset: int = random.randint(100_000, 1_000_000)
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
|
||||
def current_time(self) -> float:
|
||||
return default_timer() + self.offset
|
||||
|
||||
def deadline(self, deadline):
|
||||
return self.current_time() + deadline
|
|
@ -0,0 +1,150 @@
|
|||
from typing import Any
|
||||
from structio.core.task import Task, TaskState
|
||||
from heapq import heappush, heappop, heapify
|
||||
|
||||
|
||||
class TimeQueue:
|
||||
"""
|
||||
An abstraction layer over a heap queue based on time
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
||||
# The sequence float handles the race condition
|
||||
# of two items with identical deadlines, acting
|
||||
# as a tiebreaker
|
||||
self.sequence = 0
|
||||
self.container: list[tuple[float, int, Task, dict[str, Any]]] = []
|
||||
|
||||
def peek(self) -> Task:
|
||||
"""
|
||||
Returns the first task in the queue
|
||||
"""
|
||||
|
||||
return self.container[0][2]
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns len(self)
|
||||
"""
|
||||
|
||||
return len(self.container)
|
||||
|
||||
def __contains__(self, item):
|
||||
"""
|
||||
Implements item in self. This method ignores
|
||||
timeouts and tiebreakers
|
||||
"""
|
||||
|
||||
for i in self.container:
|
||||
if i[2] == item:
|
||||
return True
|
||||
return False
|
||||
|
||||
def index(self, item):
|
||||
"""
|
||||
Returns the index of the given item in the list
|
||||
or -1 if it is not present
|
||||
"""
|
||||
|
||||
for i, e in enumerate(self.container):
|
||||
if e[2] == item:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def discard(self, item):
|
||||
"""
|
||||
Discards an item from the queue and
|
||||
calls heapify(self.container) to keep
|
||||
the heap invariant if an element is removed.
|
||||
This method does nothing if the item is not
|
||||
in the queue, but note that in this case the
|
||||
operation would still take at least O(n)
|
||||
iterations to complete
|
||||
|
||||
:param item: The item to be discarded
|
||||
"""
|
||||
|
||||
idx = self.index(item)
|
||||
if idx != -1:
|
||||
self.container.pop(idx)
|
||||
heapify(self.container)
|
||||
|
||||
def get_closest_deadline(self) -> float:
|
||||
"""
|
||||
Returns the closest deadline that is meant to expire
|
||||
"""
|
||||
|
||||
if not self:
|
||||
return float("inf")
|
||||
return self.container[0][0]
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Implements iter(self)
|
||||
"""
|
||||
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""
|
||||
Implements next(self)
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.get()
|
||||
except IndexError:
|
||||
raise StopIteration from None
|
||||
|
||||
def __getitem__(self, item: int):
|
||||
"""
|
||||
Implements self[n]
|
||||
"""
|
||||
|
||||
return self.container.__getitem__(item)
|
||||
|
||||
def __bool__(self):
|
||||
"""
|
||||
Implements bool(self)
|
||||
"""
|
||||
|
||||
return bool(self.container)
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Implements repr(self) and str(self)
|
||||
"""
|
||||
|
||||
return f"TimeQueue({self.container})"
|
||||
|
||||
def put(self, item, delay: float, metadata: dict[str, Any] | None = None):
|
||||
"""
|
||||
Pushes an item onto the queue together with its
|
||||
delay and optional metadata
|
||||
|
||||
:param item: The item to be pushed
|
||||
:param delay: The delay associated with the item
|
||||
:type delay: float
|
||||
:param metadata: A dictionary representing additional
|
||||
metadata. Defaults to None
|
||||
:type metadata: dict[str, Any], optional
|
||||
"""
|
||||
|
||||
heappush(self.container, (delay, self.sequence, item, metadata))
|
||||
self.sequence += 1
|
||||
|
||||
def get(self) -> tuple[Any, dict[str, Any] | None]:
|
||||
"""
|
||||
Gets the first item on the queue along with
|
||||
its metadata
|
||||
|
||||
:raises: IndexError if the queue is empty
|
||||
"""
|
||||
|
||||
if not self.container:
|
||||
raise IndexError("get from empty TimeQueue")
|
||||
_, __, item, meta = heappop(self.container)
|
||||
return item, meta
|
|
@ -0,0 +1,9 @@
|
|||
from structio.abc import BaseDebugger
|
||||
|
||||
|
||||
class SimpleDebugger(BaseDebugger):
|
||||
def on_start(self):
|
||||
print(">> Started")
|
||||
|
||||
def on_exit(self):
|
||||
print(f"<< Stopped")
|
|
@ -0,0 +1,54 @@
|
|||
class StructIOException(Exception):
|
||||
"""
|
||||
A generic StructIO error
|
||||
"""
|
||||
|
||||
|
||||
class Cancelled(BaseException):
|
||||
# We inherit from BaseException
|
||||
# so that users don't accidentally
|
||||
# ignore cancellations
|
||||
"""
|
||||
A cancellation exception
|
||||
"""
|
||||
|
||||
scope: "TaskScope"
|
||||
|
||||
|
||||
class TimedOut(StructIOException):
|
||||
"""
|
||||
Raised when a task scope times out.
|
||||
The scope attribute can be used to
|
||||
know which scope originally timed
|
||||
out
|
||||
"""
|
||||
|
||||
scope: "TaskScope"
|
||||
|
||||
|
||||
class ResourceClosed(StructIOException):
|
||||
"""
|
||||
Raised when an asynchronous resource is
|
||||
closed and no longer usable
|
||||
"""
|
||||
|
||||
|
||||
class ResourceBusy(StructIOException):
|
||||
"""
|
||||
Raised when an attempt is made to use an
|
||||
asynchronous resource that is currently busy
|
||||
"""
|
||||
|
||||
|
||||
class ResourceBroken(StructIOException):
|
||||
"""
|
||||
Raised when an asynchronous resource gets
|
||||
corrupted and is no longer usable
|
||||
"""
|
||||
|
||||
|
||||
class WouldBlock(StructIOException):
|
||||
"""
|
||||
Raised when a non-blocking operation
|
||||
cannot be carried out immediately
|
||||
"""
|
|
@ -0,0 +1,196 @@
|
|||
# This is, ahem, inspired by Curio and Trio. See https://github.com/dabeaz/curio/issues/104
|
||||
import io
|
||||
import os
|
||||
from structio.core.syscalls import (
|
||||
checkpoint,
|
||||
wait_readable,
|
||||
wait_writable,
|
||||
closing,
|
||||
release,
|
||||
)
|
||||
from structio.exceptions import ResourceClosed
|
||||
from structio.abc import AsyncResource
|
||||
|
||||
try:
|
||||
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
|
||||
|
||||
WantRead = (BlockingIOError, SSLWantReadError, InterruptedError)
|
||||
WantWrite = (BlockingIOError, SSLWantWriteError, InterruptedError)
|
||||
except ImportError:
|
||||
WantWrite = (BlockingIOError, InterruptedError)
|
||||
WantRead = (BlockingIOError, InterruptedError)
|
||||
SSLSocket = None
|
||||
|
||||
|
||||
class FdWrapper:
|
||||
"""
|
||||
A simple wrapper around a file descriptor that
|
||||
allows the event loop to perform an optimization
|
||||
regarding I/O event registration safely. This is
|
||||
because while integer file descriptors can be reused
|
||||
by the operating system, instances of this class will
|
||||
not (hence if the event loop keeps around a dead instance
|
||||
of an FdWrapper, it at least won't accidentally register
|
||||
a new file with that same file descriptor). A bonus is
|
||||
that this also allows us to always assume that we can call
|
||||
fileno() on all objects registered in our selector, regardless
|
||||
of whether the wrapped fd is an int or something else entirely
|
||||
"""
|
||||
|
||||
__slots__ = ("fd",)
|
||||
|
||||
def __init__(self, fd):
|
||||
self.fd = fd
|
||||
|
||||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
# Can be converted to an int
|
||||
def __int__(self):
|
||||
return self.fd
|
||||
|
||||
def __repr__(self):
|
||||
return f"<fd={self.fd!r}>"
|
||||
|
||||
|
||||
class AsyncStream(AsyncResource):
|
||||
"""
|
||||
A generic asynchronous stream over
|
||||
a file-like object, with buffering
|
||||
"""
|
||||
|
||||
def __init__(self, fileobj):
|
||||
self.fileobj = fileobj
|
||||
self._fd = FdWrapper(self.fileobj.fileno())
|
||||
self._buf = bytearray()
|
||||
|
||||
async def _read(self, size: int = -1) -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def write(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def read(self, size: int = -1):
|
||||
"""
|
||||
Reads up to size bytes from the
|
||||
given stream. If size == -1, read
|
||||
as much as possible
|
||||
"""
|
||||
|
||||
if size < 0 and size < -1:
|
||||
raise ValueError("size must be -1 or a positive integer")
|
||||
if size == -1:
|
||||
size = len(self._buf)
|
||||
buf = self._buf
|
||||
if not buf:
|
||||
return await self._read(size)
|
||||
if len(buf) <= size:
|
||||
data = bytes(buf)
|
||||
buf.clear()
|
||||
else:
|
||||
data = bytes(buf[:size])
|
||||
del buf[:size]
|
||||
return data
|
||||
|
||||
# Yes I stole this from curio. Sue me.
|
||||
async def readall(self):
|
||||
chunks = []
|
||||
maxread = 65536
|
||||
if self._buf:
|
||||
chunks.append(bytes(self._buf))
|
||||
self._buf.clear()
|
||||
while True:
|
||||
chunk = await self.read(maxread)
|
||||
if not chunk:
|
||||
return b"".join(chunks)
|
||||
chunks.append(chunk)
|
||||
if len(chunk) == maxread:
|
||||
maxread *= 2
|
||||
|
||||
async def flush(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the stream asynchronously
|
||||
"""
|
||||
|
||||
if self.fileno() == -1:
|
||||
return
|
||||
await self.flush()
|
||||
await closing(self._fd)
|
||||
await release(self._fd)
|
||||
self.fileobj.close()
|
||||
self.fileobj = None
|
||||
self._fd = -1
|
||||
await checkpoint()
|
||||
|
||||
def fileno(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return int(self._fd)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self.fileno() != -1:
|
||||
await self.close()
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncStream({self.fileobj})"
|
||||
|
||||
|
||||
class FileStream(AsyncStream):
|
||||
"""
|
||||
A stream wrapper around a binary file-like object.
|
||||
The underlying file descriptor is put into non-blocking
|
||||
mode
|
||||
"""
|
||||
|
||||
async def _read(self, size: int = -1) -> bytes:
|
||||
while True:
|
||||
try:
|
||||
data = self.fileobj.read(size)
|
||||
if data is None:
|
||||
# Files in non-blocking mode don't always
|
||||
# raise a blocking I/O exception and can
|
||||
# return None instead, so we account for
|
||||
# that here
|
||||
raise BlockingIOError()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def write(self, data):
|
||||
# We use a memory view so that
|
||||
# slicing doesn't copy any memory
|
||||
mem = memoryview(data)
|
||||
while mem:
|
||||
try:
|
||||
written = self.fileobj.write(data)
|
||||
if written is None:
|
||||
raise BlockingIOError()
|
||||
mem = mem[written:]
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def flush(self):
|
||||
if self.fileno() == -1:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
return self.fileobj.flush()
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
def __init__(self, fileobj):
|
||||
if isinstance(fileobj, io.TextIOBase):
|
||||
raise TypeError("only binary mode files can be streamed")
|
||||
super().__init__(fileobj)
|
||||
if hasattr(os, "set_blocking"):
|
||||
os.set_blocking(self.fileno(), False)
|
|
@ -0,0 +1,187 @@
|
|||
import io
|
||||
import sys
|
||||
import structio
|
||||
from functools import partial
|
||||
from structio.abc import AsyncResource
|
||||
from structio.core.syscalls import check_cancelled
|
||||
|
||||
# Stolen from Trio
|
||||
_FILE_SYNC_ATTRS = {
|
||||
"closed",
|
||||
"encoding",
|
||||
"errors",
|
||||
"fileno",
|
||||
"isatty",
|
||||
"newlines",
|
||||
"readable",
|
||||
"seekable",
|
||||
"writable",
|
||||
"buffer",
|
||||
"raw",
|
||||
"line_buffering",
|
||||
"closefd",
|
||||
"name",
|
||||
"mode",
|
||||
"getvalue",
|
||||
"getbuffer",
|
||||
}
|
||||
|
||||
_FILE_ASYNC_METHODS = {
|
||||
"flush",
|
||||
"read",
|
||||
"read1",
|
||||
"readall",
|
||||
"readinto",
|
||||
"readline",
|
||||
"readlines",
|
||||
"seek",
|
||||
"tell",
|
||||
"truncate",
|
||||
"write",
|
||||
"writelines",
|
||||
"readinto1",
|
||||
"peek",
|
||||
}
|
||||
|
||||
|
||||
class AsyncFile(AsyncResource):
|
||||
"""
|
||||
Asynchronous wrapper around regular file-like objects.
|
||||
Blocking operations are turned into async ones using threads.
|
||||
Note that this class can wrap pretty much anything with a fileno()
|
||||
and read/write methods
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return self.handle.fileno()
|
||||
|
||||
def __init__(self, f):
|
||||
self._file = f
|
||||
|
||||
@property
|
||||
def handle(self) -> io.IOBase:
|
||||
"""
|
||||
Returns the underlying (synchronous!) OS-specific
|
||||
handle for the given resource
|
||||
"""
|
||||
|
||||
return self._file
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
line = await self.readline()
|
||||
if line:
|
||||
return line
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def readall(self):
|
||||
chunks = []
|
||||
maxread = 65536
|
||||
sep = "" if hasattr(self._file, "encoding") else b""
|
||||
while True:
|
||||
chunk = await self.read(maxread)
|
||||
if not chunk:
|
||||
return sep.join(chunks)
|
||||
chunks.append(chunk)
|
||||
if len(chunk) == maxread:
|
||||
maxread *= 2
|
||||
|
||||
# Look, I get it, I can't keep stealing stuff from Trio,
|
||||
# but come on, it's so good!
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in _FILE_SYNC_ATTRS:
|
||||
return getattr(self.handle, name)
|
||||
if name in _FILE_ASYNC_METHODS:
|
||||
meth = getattr(self.handle, name)
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
func = partial(meth, *args, **kwargs)
|
||||
return await structio.thread.run_in_worker(func)
|
||||
|
||||
# cache the generated method
|
||||
setattr(self, name, wrapper)
|
||||
return wrapper
|
||||
raise AttributeError(name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"structio.AsyncFile({self.handle})"
|
||||
|
||||
def __dir__(self):
|
||||
attrs = set(super().__dir__())
|
||||
attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.handle, a))
|
||||
attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.handle, a))
|
||||
return attrs
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the file asynchronously. If the operation
|
||||
is cancelled, the underlying file object is *still*
|
||||
closed!
|
||||
"""
|
||||
|
||||
# This operation is non-cancellable, meaning it'll run
|
||||
# no matter what our event loop has to say about it.
|
||||
# After we're done, we'll obviously re-raise the cancellation
|
||||
# if necessary. This ensures files are always closed even when
|
||||
# the operation gets cancelled
|
||||
await structio.thread.run_in_worker(self.handle.close)
|
||||
# If we were cancelled, here is where we raise
|
||||
await check_cancelled()
|
||||
|
||||
|
||||
async def open_file(
|
||||
file,
|
||||
mode="r",
|
||||
buffering=-1,
|
||||
encoding=None,
|
||||
errors=None,
|
||||
newline=None,
|
||||
closefd=True,
|
||||
opener=None,
|
||||
) -> AsyncFile:
|
||||
"""
|
||||
Like io.open(), but async
|
||||
"""
|
||||
|
||||
return wrap_file(
|
||||
await structio.thread.run_in_worker(
|
||||
io.open, file, mode, buffering, encoding, errors, newline, closefd, opener
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def wrap_file(file) -> AsyncFile:
|
||||
"""
|
||||
Wraps a file-like object into an async
|
||||
wrapper
|
||||
"""
|
||||
|
||||
return AsyncFile(file)
|
||||
|
||||
|
||||
stdin = wrap_file(sys.stdin)
|
||||
stdout = wrap_file(sys.stdout)
|
||||
stderr = wrap_file(sys.stderr)
|
||||
|
||||
|
||||
async def aprint(*args, sep=" ", end="\n", file=stdout, flush=False):
|
||||
"""
|
||||
Like print(), but asynchronous
|
||||
"""
|
||||
|
||||
await file.write(f"{sep.join(map(str, args))}{end}")
|
||||
if flush:
|
||||
await file.flush()
|
||||
|
||||
|
||||
async def ainput(prompt=None, /):
|
||||
"""
|
||||
Like input(), but asynchronous
|
||||
"""
|
||||
|
||||
await aprint(prompt, end="", flush=True)
|
||||
return (await stdin.readline()).rstrip("\n")
|
|
@ -0,0 +1,659 @@
|
|||
import warnings
|
||||
import platform
|
||||
from typing import Any
|
||||
import structio
|
||||
from structio.abc import AsyncResource
|
||||
from structio.io import FdWrapper, WantRead, WantWrite, SSLSocket
|
||||
from structio.thread import run_in_worker
|
||||
from structio.exceptions import ResourceClosed, ResourceBroken
|
||||
from structio.core.syscalls import (
|
||||
wait_readable,
|
||||
wait_writable,
|
||||
checkpoint,
|
||||
closing,
|
||||
release,
|
||||
)
|
||||
from functools import wraps
|
||||
import socket as _socket
|
||||
|
||||
try:
|
||||
import ssl as _ssl
|
||||
except ImportError:
|
||||
_ssl = None
|
||||
|
||||
|
||||
@wraps(_socket.socket)
|
||||
def socket(*args, **kwargs):
|
||||
return AsyncSocket(_socket.socket(*args, **kwargs))
|
||||
|
||||
|
||||
@wraps(_socket.fromfd)
|
||||
async def fromfd(
|
||||
fd: Any,
|
||||
family: _socket.AddressFamily | int,
|
||||
type: _socket.SocketKind | int,
|
||||
proto: int = 0,
|
||||
) -> "AsyncSocket":
|
||||
return AsyncSocket(_socket.fromfd(fd, family, type, proto))
|
||||
|
||||
|
||||
async def wrap_socket_with_ssl(
|
||||
sock, *args, context, do_handshake_on_connect=True, **kwargs
|
||||
):
|
||||
"""
|
||||
Wraps a regular unencrypted socket or a structio async socket into a
|
||||
TLS-capable asynchronous socket. All positional and keyword arguments
|
||||
(aside from context and do_handshake_on_connect) are passed to context.wrap_socket()
|
||||
(if context is None, one with reasonable defaults is created using ssl.create_default_context()).
|
||||
Note that the do_handshake_on_connect parameter passed to the given SSL context is always False,
|
||||
because structio handles TLS handshaking on its own (this means that you mostly don't need to care
|
||||
about where do_handshake_on_connect is set: it'll just work)
|
||||
"""
|
||||
|
||||
if not _ssl:
|
||||
raise RuntimeError("SSL is not supported on this platform")
|
||||
if isinstance(sock, AsyncSocket):
|
||||
sock = sock.socket
|
||||
sock: _socket.socket
|
||||
context: _ssl.SSLContext
|
||||
if context is None:
|
||||
context = _ssl.create_default_context()
|
||||
# do_handshake_on_connect MUST be set to False on the
|
||||
# synchronous socket! Structio performs the TLS handshake
|
||||
# asynchronously, and letting the SSL library handle it
|
||||
# blocks the entire event loop
|
||||
raw_ssl = context.wrap_socket(sock, *args, do_handshake_on_connect=False, **kwargs)
|
||||
wrapped = AsyncSocket(raw_ssl, do_handshake_on_connect=do_handshake_on_connect)
|
||||
if raw_ssl._connected:
|
||||
wrapped.connected = True
|
||||
if wrapped.do_handshake_on_connect and wrapped.connected:
|
||||
await wrapped.do_handshake()
|
||||
return wrapped
|
||||
|
||||
|
||||
# Wrappers of the socket module
|
||||
|
||||
|
||||
@wraps(_socket.socketpair)
|
||||
def socketpair(
|
||||
family=None, type=_socket.SOCK_STREAM, proto=0
|
||||
) -> tuple["AsyncSocket", "AsyncSocket"]:
|
||||
if family is None and platform.system() == "Windows":
|
||||
family = _socket.AF_INET
|
||||
a, b = _socket.socketpair(family, type, proto)
|
||||
return AsyncSocket(a), AsyncSocket(b)
|
||||
|
||||
|
||||
@wraps(_socket.getaddrinfo)
|
||||
async def getaddrinfo(
|
||||
host: bytearray | bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
):
|
||||
return await run_in_worker(
|
||||
_socket.getaddrinfo, host, port, family, type, proto, flags, cancellable=True
|
||||
)
|
||||
|
||||
|
||||
@wraps(_socket.getfqdn)
|
||||
async def getfqdn(name: str) -> str:
|
||||
return await run_in_worker(_socket.getfqdn, name, cancellable=True)
|
||||
|
||||
|
||||
@wraps(_socket.getnameinfo)
|
||||
async def getnameinfo(
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
|
||||
) -> tuple[str, str]:
|
||||
return await run_in_worker(_socket.getnameinfo, sockaddr, flags, cancellable=True)
|
||||
|
||||
|
||||
@wraps(_socket.gethostname)
|
||||
async def gethostname() -> str:
|
||||
return await run_in_worker(_socket.gethostname, cancellable=True)
|
||||
|
||||
|
||||
@wraps(_socket.gethostbyaddr)
|
||||
async def gethostbyaddr(ip_address: str) -> tuple[str, list[str], list[str]]:
|
||||
return await run_in_worker(_socket.gethostbyaddr, ip_address, cancellable=True)
|
||||
|
||||
|
||||
@wraps(_socket.gethostbyname)
|
||||
async def gethostbyname(hostname: str) -> str:
|
||||
return await run_in_worker(_socket.gethostbyname, hostname, cancellable=True)
|
||||
|
||||
|
||||
@wraps(_socket.gethostbyname_ex)
|
||||
async def gethostbyname_ex(hostname: str) -> tuple[str, list[str], list[str]]:
|
||||
return await run_in_worker(_socket.gethostbyname_ex, hostname, cancellable=True)
|
||||
|
||||
|
||||
# As per RFC 8305, the default connection delay is set to
|
||||
# 250 milliseconds
|
||||
CONNECT_DELAY: float = 0.250
|
||||
|
||||
|
||||
async def connect_tcp_socket(
|
||||
host: str | bytes,
|
||||
port: int,
|
||||
*,
|
||||
source_address=None,
|
||||
happy_eyeballs_delay: float = CONNECT_DELAY,
|
||||
) -> "AsyncSocket":
|
||||
"""
|
||||
Resolve the given (non-numeric) host and attempt to connect to it, at the chosen port.
|
||||
Connection attempts are made according to the "Happy eyeballs" algorithm as per RFC
|
||||
8305. If source_address is provided, the connection is established through that,
|
||||
otherwise we let the OS pick one. The happy_eyeballs_delay parameter controls
|
||||
how much time (in seconds) we wait for a connection attempt to stall before
|
||||
attempting the next one (by default it's set to 250ms)
|
||||
"""
|
||||
|
||||
# Trio states these behaviors are technically accepted by
|
||||
# (some versions of) getaddrinfo, but they're non-portable,
|
||||
# broken or not useful. And so we follow
|
||||
if host is None:
|
||||
raise ValueError("host can't be None")
|
||||
if not isinstance(port, int):
|
||||
raise TypeError(f"port must be int, got {type(port)} instead")
|
||||
|
||||
hosts = await getaddrinfo(host, port, type=_socket.SOCK_STREAM)
|
||||
# RFC 8305 specifies that if we get addresses of different families,
|
||||
# our first two connection attempts should be using different ones
|
||||
# (in english: if getaddrinfo() returns, say, 2 IPV4 addresses and one IPV6
|
||||
# address, then we have to make sure our first and second attempt use one
|
||||
# of each type)
|
||||
for i in range(1, len(hosts)):
|
||||
# If the family of the ith socket (skipping
|
||||
# the first one) is different from that of
|
||||
# our very first socket (and if it isn't already
|
||||
# in second place) then we pick it and shift it
|
||||
# all the way to second place
|
||||
if hosts[i][0] != hosts[0][0] and i != 1:
|
||||
hosts.insert(1, hosts.pop(i))
|
||||
break
|
||||
if not hosts:
|
||||
# Trio is paranoid, and so are we. This should never happen
|
||||
# by the way, getaddrinfo will just raise OSError on its own
|
||||
raise OSError(f"name resolution failed for {host!r}")
|
||||
# Store all sockets we create when attempting to connect so
|
||||
# that we can shut them down later
|
||||
sockets: list[AsyncSocket] = []
|
||||
# The socket that managed to connect
|
||||
successful: AsyncSocket | None = None
|
||||
# We chain exceptions via __cause__ so that we can
|
||||
# provide information about all failed attempts if
|
||||
# we don't manage to connect
|
||||
exc_obj: Exception | None = None
|
||||
|
||||
async def attempt(sock_data, addr, evt: structio.Event, scope: structio.TaskScope):
|
||||
nonlocal successful
|
||||
try:
|
||||
attempt_sock = socket(*sock_data)
|
||||
sockets.append(attempt_sock)
|
||||
if source_address:
|
||||
# This trick (again stolen from Trio), lets us
|
||||
# bind to a given address without actually busying
|
||||
# up a local port up until the moment where we actually
|
||||
# need to connect. That way, we can perform as many connection
|
||||
# attempts as we want from a given source address without ever
|
||||
# worrying about running out of local ports
|
||||
try:
|
||||
attempt_sock.setsockopt(
|
||||
_socket.IPPROTO_IP, _socket.IP_BIND_ADDRESS_NO_PORT, 1
|
||||
)
|
||||
except (OSError, AttributeError):
|
||||
# Not all platforms support this option (for example,
|
||||
# my Linux/Windows installations don't seem to have IP_BIND_ADDRESS_NO_PORT
|
||||
# defined in the socket module), so if setting it fails
|
||||
# we just ignore it and hope for the best
|
||||
pass
|
||||
# This makes sure users don't try to send IPv6
|
||||
# traffic with an IPv4 source address
|
||||
try:
|
||||
await attempt_sock.bind((source_address, 0))
|
||||
except OSError:
|
||||
# Almost hit the 120 character line, phew...
|
||||
raise OSError(
|
||||
f"Source address {source_address!r} is incompatible with remote address {addr!r}"
|
||||
)
|
||||
await attempt_sock.connect(addr)
|
||||
# Hooray! Connection was successful. Record the socket
|
||||
# and cancel the rest of the attempts (either future or
|
||||
# currently running)
|
||||
successful = attempt_sock
|
||||
sockets.remove(attempt_sock)
|
||||
scope.cancel()
|
||||
except OSError as exc:
|
||||
# Well, this attempt failed. Right now, we just ignore the error (we'll
|
||||
# fail with OSError later if all connection attempts fail), but we should
|
||||
# really have support for ExceptionGroups (coming soon btw), so we can
|
||||
# keep track of all the errors and use fancy stuff like the new except*
|
||||
# syntax introduced in Python 3.11
|
||||
|
||||
# Oh, and we also notify our next attempt that they can start early
|
||||
evt.set()
|
||||
|
||||
nonlocal exc_obj
|
||||
if exc_obj:
|
||||
exc_obj.__cause__ = exc
|
||||
else:
|
||||
exc_obj = exc
|
||||
|
||||
try:
|
||||
async with structio.create_pool() as pool:
|
||||
for *sock_args, _, address in hosts:
|
||||
# This event notifies us if a connection attempt
|
||||
# fails, so we can start early
|
||||
event = structio.Event()
|
||||
pool.spawn(attempt, sock_args, address, event, pool.scope)
|
||||
with structio.skip_after(happy_eyeballs_delay):
|
||||
# We'll wait for the event to be triggered or for at
|
||||
# most happy_eyeballs_delay seconds before moving on,
|
||||
# whichever happens first
|
||||
await event.wait()
|
||||
finally:
|
||||
for sock in sockets:
|
||||
if sock.fileno() == -1:
|
||||
# Socket is already dead
|
||||
continue
|
||||
try:
|
||||
# FIXME: Could this block forever? I mean, maybe it's not
|
||||
# such a huge deal since you can always wrap the call in
|
||||
# a timeout or something, but it may be something worth
|
||||
#
|
||||
await sock.close()
|
||||
except BaseException as e:
|
||||
# Again, we shouldn't be ignoring
|
||||
# errors willy-nilly like that, but
|
||||
# hey beta software am I right?
|
||||
warnings.warn(
|
||||
f"Failed to close {sock!r} in call to connect_socket -> {type(e).__name__}: {e}"
|
||||
)
|
||||
continue
|
||||
if not successful:
|
||||
# All connection attempts failed
|
||||
err = OSError(f"all connection attempts to {host}:{port} failed")
|
||||
if exc_obj:
|
||||
raise exc_obj from err
|
||||
raise err
|
||||
return successful
|
||||
|
||||
|
||||
async def connect_tcp_ssl_socket(
|
||||
host: str | bytes,
|
||||
port: int,
|
||||
*,
|
||||
ssl_context=None,
|
||||
source_address=None,
|
||||
happy_eyeballs_delay: float = CONNECT_DELAY,
|
||||
) -> "AsyncSocket":
|
||||
"""
|
||||
Convenience wrapper over connect_socket with SSL/TLS functionality
|
||||
"""
|
||||
|
||||
if not _ssl:
|
||||
raise RuntimeError("SSL is not supported on the current platform")
|
||||
return await wrap_socket_with_ssl(
|
||||
await connect_tcp_socket(
|
||||
host,
|
||||
port,
|
||||
source_address=source_address,
|
||||
happy_eyeballs_delay=happy_eyeballs_delay,
|
||||
),
|
||||
context=ssl_context,
|
||||
server_hostname=host,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSocket(AsyncResource):
|
||||
"""
|
||||
Abstraction layer for asynchronous sockets
|
||||
"""
|
||||
|
||||
def fileno(self):
|
||||
return int(self._fd)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sock: _socket.socket,
|
||||
do_handshake_on_connect: bool = True,
|
||||
):
|
||||
self._fd = FdWrapper(sock.fileno())
|
||||
# Do we perform the TCP handshake automatically
|
||||
# upon connection? This is only needed for SSL
|
||||
# sockets
|
||||
self.do_handshake_on_connect = do_handshake_on_connect
|
||||
self.socket = sock
|
||||
self.socket.setblocking(False)
|
||||
self.connected: bool = False
|
||||
self.write_lock = structio.util.misc.ThereCanBeOnlyOne(
|
||||
"another task is writing on this socket"
|
||||
)
|
||||
self.read_lock = structio.util.misc.ThereCanBeOnlyOne(
|
||||
"another task is reading from this socket"
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def receive(self, max_size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
Receives up to max_size bytes from a socket asynchronously
|
||||
"""
|
||||
|
||||
assert max_size >= 1, "max_size must be >= 1"
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv(max_size, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
Receives exactly size bytes from a socket asynchronously
|
||||
"""
|
||||
|
||||
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
|
||||
buf = bytearray(size)
|
||||
pos = 0
|
||||
while pos < size:
|
||||
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
|
||||
if n == 0:
|
||||
raise ResourceBroken("incomplete read detected")
|
||||
pos += n
|
||||
return bytes(buf)
|
||||
|
||||
async def connect(self, address):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
with self.write_lock, self.read_lock:
|
||||
while not self.connected:
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
self.connected = True
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.connected:
|
||||
# We set our own fd to -1
|
||||
# before calling any async
|
||||
# primitive so that any further
|
||||
# I/O operations on this object
|
||||
# fail before they even start.
|
||||
# Of course, we still want to
|
||||
# release the actual fd, so we
|
||||
# save it separately
|
||||
fd = self._fd
|
||||
self._fd = -1
|
||||
await closing(fd)
|
||||
await release(fd)
|
||||
self.socket.close()
|
||||
else:
|
||||
await checkpoint()
|
||||
|
||||
async def accept(self):
|
||||
"""
|
||||
Accepts the socket, completing the 3-step TCP handshake asynchronously
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
remote, addr = self.socket.accept()
|
||||
await checkpoint()
|
||||
return type(self)(remote), addr
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def send_all(self, data: bytes, flags: int = 0):
|
||||
"""
|
||||
Sends all the provided data asynchronously
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
with self.write_lock:
|
||||
sent_no = 0
|
||||
while data:
|
||||
try:
|
||||
sent_no = self.socket.send(data, flags)
|
||||
await checkpoint()
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
data = data[sent_no:]
|
||||
|
||||
async def shutdown(self, how):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.fileno() == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
if self.socket:
|
||||
self.socket.shutdown(how)
|
||||
await checkpoint()
|
||||
|
||||
async def bind(self, addr: tuple):
|
||||
"""
|
||||
Binds the socket to an address
|
||||
:param addr: The address, port tuple to bind to
|
||||
:type addr: tuple
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.socket.bind(addr)
|
||||
await checkpoint()
|
||||
|
||||
async def listen(self, backlog: int):
|
||||
"""
|
||||
Starts listening with the given backlog
|
||||
:param backlog: The socket's backlog
|
||||
:type backlog: int
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.socket.listen(backlog)
|
||||
await checkpoint()
|
||||
|
||||
def settimeout(self, seconds):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
raise RuntimeError("Use with_timeout() to set a timeout")
|
||||
|
||||
def gettimeout(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return None
|
||||
|
||||
def dup(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return type(self)(self.socket.dup(), self.do_handshake_on_connect)
|
||||
|
||||
def setsockopt(self, *args, **kwargs):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return self.socket.setsockopt(*args, **kwargs)
|
||||
|
||||
async def do_handshake(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if not hasattr(self.socket, "do_handshake"):
|
||||
# Regular sockets don't have a do_handshake method
|
||||
return
|
||||
with self.read_lock, self.write_lock:
|
||||
while True:
|
||||
try:
|
||||
self.socket: SSLSocket # Silences pycharm warnings
|
||||
self.socket.do_handshake()
|
||||
await checkpoint()
|
||||
break
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recvfrom(self, buffersize, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom(buffersize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recv_into(self, buffer, nbytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recv_into(buffer, nbytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvfrom_into(buffer, bytes, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def sendto(self, bytes, flags_or_address, address=None):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if address:
|
||||
flags = flags_or_address
|
||||
else:
|
||||
address = flags_or_address
|
||||
flags = 0
|
||||
with self.write_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendto(bytes, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
def getpeername(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return self.socket.getpeername()
|
||||
|
||||
def getsockname(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return self.socket.getsockname()
|
||||
|
||||
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg(bufsize, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.recvmsg_into(buffers, ancbufsize, flags)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
|
||||
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
with self.write_lock:
|
||||
while True:
|
||||
try:
|
||||
data = self.socket.sendmsg(buffers, ancdata, flags, address)
|
||||
await checkpoint()
|
||||
return data
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncSocket({self.socket})"
|
|
@ -0,0 +1,124 @@
|
|||
"""Module inspired by subprocess which allows for asynchronous
|
||||
multiprocessing"""
|
||||
|
||||
import os
|
||||
import structio
|
||||
import platform
|
||||
import subprocess
|
||||
from subprocess import CalledProcessError, CompletedProcess, DEVNULL, PIPE
|
||||
from structio.io import FileStream
|
||||
|
||||
if platform.system() == "Windows":
|
||||
# Windows doesn't really support non-blocking file
|
||||
# descriptors (except sockets), so we just use threads
|
||||
from structio.io.files import AsyncFile as FileStream
|
||||
|
||||
|
||||
class Popen:
|
||||
"""
|
||||
Wrapper around subprocess.Popen, but async
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "universal_newlines" in kwargs:
|
||||
# Not sure why? But everyone else is doing it so :shrug:
|
||||
raise RuntimeError("universal_newlines is not supported")
|
||||
if stdin := kwargs.get("stdin"):
|
||||
if stdin not in {PIPE, DEVNULL}:
|
||||
# Curio mentions stuff breaking if the child process
|
||||
# is passed a stdin fd that is set to non-blocking mode
|
||||
if hasattr(os, "set_blocking"):
|
||||
os.set_blocking(stdin.fileno(), True)
|
||||
# Delegate to Popen's constructor
|
||||
self._process: subprocess.Popen = subprocess.Popen(*args, **kwargs)
|
||||
self.stdin = None
|
||||
self.stdout = None
|
||||
self.stderr = None
|
||||
if self._process.stdin:
|
||||
self.stdin = FileStream(self._process.stdin)
|
||||
if self._process.stdout:
|
||||
self.stdout = FileStream(self._process.stdout)
|
||||
if self._process.stderr:
|
||||
self.stderr = FileStream(self._process.stderr)
|
||||
|
||||
async def wait(self):
|
||||
status = self._process.poll()
|
||||
if status is None:
|
||||
status = await structio.thread.run_in_worker(
|
||||
self._process.wait, cancellable=True
|
||||
)
|
||||
return status
|
||||
|
||||
async def communicate(self, input=b"") -> tuple[bytes, bytes]:
|
||||
async with structio.create_pool() as pool:
|
||||
stdout = pool.spawn(self.stdout.readall) if self.stdout else None
|
||||
stderr = pool.spawn(self.stderr.readall) if self.stderr else None
|
||||
if input:
|
||||
await self.stdin.write(input)
|
||||
await self.stdin.close()
|
||||
# Awaiting a task object waits for its completion and
|
||||
# returns its return value!
|
||||
out = b""
|
||||
err = b""
|
||||
if stdout:
|
||||
out = await stdout
|
||||
if stderr:
|
||||
err = await stderr
|
||||
return out, err
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self.stdin:
|
||||
await self.stdin.close()
|
||||
if self.stdout:
|
||||
await self.stdout.close()
|
||||
if self.stderr:
|
||||
await self.stderr.close()
|
||||
await self.wait()
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Delegate to internal process object
|
||||
return getattr(self._process, item)
|
||||
|
||||
|
||||
async def run(
|
||||
args, *, stdin=None, input=None, stdout=None, stderr=None, shell=False, check=False
|
||||
):
|
||||
"""
|
||||
Async version of subprocess.run()
|
||||
"""
|
||||
|
||||
if input:
|
||||
stdin = subprocess.PIPE
|
||||
async with Popen(
|
||||
args, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell
|
||||
) as process:
|
||||
try:
|
||||
stdout, stderr = await process.communicate(input)
|
||||
except:
|
||||
process.kill()
|
||||
raise
|
||||
|
||||
status = process.poll()
|
||||
if check and status:
|
||||
raise CalledProcessError(status, process.args, output=stdout, stderr=stderr)
|
||||
return CompletedProcess(process.args, status, stdout, stderr)
|
||||
|
||||
|
||||
async def check_output(args, *, stdin=None, stderr=None, shell=False, input=None):
|
||||
"""
|
||||
Async version of subprocess.check_output
|
||||
"""
|
||||
|
||||
out = await run(
|
||||
args,
|
||||
stdout=PIPE,
|
||||
stdin=stdin,
|
||||
stderr=stderr,
|
||||
shell=shell,
|
||||
check=True,
|
||||
input=input,
|
||||
)
|
||||
return out.stdout
|
|
@ -0,0 +1,169 @@
|
|||
# Async wrapper for pathlib.Path (blocking calls are run in threads)
|
||||
import os
|
||||
from functools import partial, wraps
|
||||
import structio
|
||||
import pathlib
|
||||
from structio.core.syscalls import checkpoint
|
||||
|
||||
|
||||
_SYNC = {
|
||||
"as_posix",
|
||||
"as_uri",
|
||||
"is_absolute",
|
||||
"is_reserved",
|
||||
"joinpath",
|
||||
"match",
|
||||
"relative_to",
|
||||
"with_name",
|
||||
"with_suffix",
|
||||
}
|
||||
|
||||
_ASYNC = {
|
||||
"chmod",
|
||||
"exists",
|
||||
"expanduser",
|
||||
"glob",
|
||||
"group",
|
||||
"is_block_device",
|
||||
"is_char_device",
|
||||
"is_dir",
|
||||
"is_fifo",
|
||||
"is_file",
|
||||
"is_mount",
|
||||
"is_socket",
|
||||
"is_symlink",
|
||||
"lchmod",
|
||||
"lstat",
|
||||
"mkdir",
|
||||
"owner",
|
||||
"read_bytes",
|
||||
"read_text",
|
||||
"rename",
|
||||
"replace",
|
||||
"resolve",
|
||||
"rglob",
|
||||
"rmdir",
|
||||
"samefile",
|
||||
"stat",
|
||||
"symlink_to",
|
||||
"touch",
|
||||
"unlink",
|
||||
"rmdir",
|
||||
"write_text",
|
||||
"write_bytes",
|
||||
}
|
||||
|
||||
|
||||
def _wrap(v):
|
||||
if isinstance(v, pathlib.Path):
|
||||
return Path(v)
|
||||
return v
|
||||
|
||||
|
||||
class Path:
|
||||
"""
|
||||
A wrapper to pathlib.Path which executes
|
||||
blocking calls using structio.thread.run_in_worker
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self._sync_path: pathlib.Path = pathlib.Path(*args)
|
||||
|
||||
@classmethod
|
||||
@wraps(pathlib.Path.cwd)
|
||||
async def cwd(*args, **kwargs):
|
||||
"""
|
||||
Like pathlib.Path.cwd(), but async
|
||||
"""
|
||||
|
||||
# This method is special and can't be just forwarded like the others because
|
||||
# it is a class method and I don't feel like doing all the wild metaprogramming
|
||||
# stuff that Trio did (which is cool but gooood luck debugging that), so here ya go.
|
||||
return _wrap(
|
||||
await structio.thread.run_in_worker(pathlib.Path.cwd, *args, **kwargs)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@wraps(pathlib.Path.home)
|
||||
async def home(*args, **kwargs):
|
||||
"""
|
||||
Like pathlib.Path.home(), but async
|
||||
"""
|
||||
|
||||
return _wrap(
|
||||
await structio.thread.run_in_worker(pathlib.Path.home, *args, **kwargs)
|
||||
)
|
||||
|
||||
@wraps(pathlib.Path.open)
|
||||
async def open(self, *args, **kwargs):
|
||||
"""
|
||||
Like pathlib.Path.open(), but async
|
||||
"""
|
||||
|
||||
f = await structio.thread.run_in_worker(self._sync_path.open, *args, **kwargs)
|
||||
return structio.wrap_file(f)
|
||||
|
||||
def __repr__(self):
|
||||
return f"structio.Path({repr(str(self._sync_path))})"
|
||||
|
||||
def __dir__(self):
|
||||
return super().__dir__()
|
||||
|
||||
async def iterdir(self):
|
||||
"""
|
||||
Like pathlib.Path.iterdir(), but async
|
||||
"""
|
||||
|
||||
# Inspired by https://github.com/python-trio/trio/issues/501#issuecomment-381724137
|
||||
func = partial(os.listdir, self._sync_path)
|
||||
files = await structio.thread.run_in_worker(func)
|
||||
|
||||
async def agen():
|
||||
if not files:
|
||||
await checkpoint()
|
||||
for name in files:
|
||||
yield self._sync_path._make_child_relpath(name)
|
||||
await checkpoint()
|
||||
|
||||
return agen()
|
||||
|
||||
def __fspath__(self):
|
||||
return os.fspath(self._wrapped)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return _wrap(self._sync_path.__truediv__(other))
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return _wrap(self._sync_path.__rtruediv__(other))
|
||||
|
||||
def __getattr__(self, attr: str):
|
||||
# We use a similar trick to the one we stole from
|
||||
# Trio for async files, except we also wrap sync
|
||||
# methods because we want them to return our own
|
||||
# Path objects, not pathlib.Path!
|
||||
|
||||
if attr in _SYNC:
|
||||
# We duplicate the code here because we only
|
||||
# want to forward the stuff in _SYNC and _ASYNC,
|
||||
# not everything (like our cwd() classmethod above)
|
||||
m = getattr(self._sync_path, attr)
|
||||
|
||||
@wraps(m)
|
||||
def wrapper(*args, **kwargs):
|
||||
return _wrap(m(*args, **kwargs))
|
||||
|
||||
setattr(self, attr, wrapper)
|
||||
return wrapper
|
||||
if attr in _ASYNC:
|
||||
m = getattr(self._sync_path, attr)
|
||||
|
||||
@wraps(m)
|
||||
async def wrapper(*args, **kwargs):
|
||||
f = partial(m, *args, **kwargs)
|
||||
return _wrap(await structio.thread.run_in_worker(f))
|
||||
|
||||
setattr(self, attr, wrapper)
|
||||
return wrapper
|
||||
# Falls down to __getattribute__, which may find a cached
|
||||
# method we generated earlier!
|
||||
raise AttributeError(attr)
|
|
@ -0,0 +1,91 @@
|
|||
# Signal handling module
|
||||
import platform
|
||||
import signal
|
||||
from collections import defaultdict
|
||||
from types import FrameType
|
||||
from structio.io.socket import AsyncSocket
|
||||
from typing import Callable, Any, Coroutine
|
||||
from structio.thread import AsyncThreadQueue
|
||||
from structio.core.run import current_loop
|
||||
|
||||
|
||||
_sig_data = AsyncThreadQueue(float("inf"))
|
||||
_sig_handlers: dict[
|
||||
signal.Signals, Callable[[Any, Any], Coroutine[Any, Any, Any]] | None
|
||||
] = defaultdict(lambda: None)
|
||||
|
||||
|
||||
def _handle(sig: int, frame: FrameType):
|
||||
_sig_data.put_sync((sig, frame))
|
||||
|
||||
|
||||
def get_signal_handler(
|
||||
sig: int,
|
||||
) -> Callable[[Any, Any], Coroutine[Any, Any, Any]] | None:
|
||||
"""
|
||||
Returns the currently installed async signal handler for the
|
||||
given signal or None if it is not set
|
||||
"""
|
||||
|
||||
return _sig_handlers[signal.Signals(sig)]
|
||||
|
||||
|
||||
def set_signal_handler(
|
||||
sig: int, handler: Callable[[Any, Any], Coroutine[Any, Any, Any]]
|
||||
) -> Callable[[Any, Any], Coroutine[Any, Any, Any]] | None:
|
||||
"""
|
||||
Sets the given coroutine to handle the given signal asynchronously. The
|
||||
previous async signal handler is returned if any was set, otherwise
|
||||
None is returned
|
||||
"""
|
||||
|
||||
# Raises an appropriate error
|
||||
sig = signal.Signals(sig)
|
||||
illegal_signals = []
|
||||
if platform.system() in {"Linux", "Darwin"}:
|
||||
# Linux/MacOS
|
||||
illegal_signals.append(signal.SIGKILL)
|
||||
illegal_signals.append(signal.SIGSTOP)
|
||||
match sig:
|
||||
case sig if sig in illegal_signals:
|
||||
raise ValueError(f"signal {sig!r} does not support custom handlers")
|
||||
case _:
|
||||
prev = _sig_handlers[sig]
|
||||
signal.signal(sig, _handle)
|
||||
_sig_handlers[sig] = handler
|
||||
return prev
|
||||
|
||||
|
||||
async def signal_watcher(sock: AsyncSocket):
|
||||
while True:
|
||||
# Even though we use set_wakeup_fd (which makes sure
|
||||
# our I/O manager is signal-aware and exits cleanly
|
||||
# when they arrive), it turns out that actually using the
|
||||
# data Python sends over our socket is trickier than it
|
||||
# sounds at first. That is because if we receive a bunch
|
||||
# of signals and the socket buffer gets filled, we are going
|
||||
# to lose all signals after that. Python can raise a warning
|
||||
# about this, but it's 1) Not ideal, we're still losing signals,
|
||||
# which is bad if we can do better and 2) It can be confusing,
|
||||
# because now we're leaking details about the way signals are
|
||||
# implemented, and that sucks too; So instead, we use set_wakeup_fd
|
||||
# merely as a notification mechanism to wake up our watcher and
|
||||
# register a custom signal handler that stores all the information
|
||||
# about incoming signals in an unbuffered queue (which means that even
|
||||
# if the socket's buffer gets filled, we are still going to deliver all
|
||||
# signals when we do our first call to read()). I'm a little uneasy about
|
||||
# using an unbounded queue, but realistically I doubt that one would face
|
||||
# memory problems because their code is receiving thousands of signals and
|
||||
# the event loop is not handling them fast enough (right?)
|
||||
await sock.receive(1)
|
||||
async for (sig, frame) in _sig_data:
|
||||
if _sig_handlers[sig]:
|
||||
try:
|
||||
await _sig_handlers[sig](sig, frame)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
# We try to mimic the behavior of native signal
|
||||
# handlers by propagating errors into the program's
|
||||
# entry point when an exception occurs. This is far
|
||||
# from ideal, but I don't honestly know what else to
|
||||
# do with this exception
|
||||
current_loop().throw(current_loop().entry_point, e)
|
|
@ -0,0 +1,525 @@
|
|||
# Task synchronization primitives
|
||||
import structio
|
||||
from structio.core.syscalls import suspend, checkpoint
|
||||
from structio.exceptions import ResourceClosed, WouldBlock
|
||||
from structio.core.run import current_task, current_loop
|
||||
from structio.abc import ChannelReader, ChannelWriter, Channel
|
||||
from structio.util.ki import enable_ki_protection
|
||||
from structio.util.misc import ThereCanBeOnlyOne
|
||||
from structio.core.task import Task
|
||||
from collections import deque, defaultdict
|
||||
from typing import Any, Callable, Coroutine
|
||||
from functools import partial, wraps
|
||||
|
||||
|
||||
class Event:
|
||||
"""
|
||||
A wrapper around a boolean value that can be waited
|
||||
on asynchronously. The majority of structio's API is
|
||||
designed on top of/around this class, as it constitutes
|
||||
the simplest synchronization primitive there is
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
self._set = False
|
||||
self._tasks: deque[Task] = deque()
|
||||
|
||||
def is_set(self):
|
||||
return self._set
|
||||
|
||||
@enable_ki_protection
|
||||
async def wait(self):
|
||||
"""
|
||||
Wait until someone else calls set() on
|
||||
this event. If the event has already been
|
||||
set, this method returns immediately
|
||||
"""
|
||||
|
||||
if self.is_set():
|
||||
await checkpoint()
|
||||
return
|
||||
self._tasks.append(current_task())
|
||||
await suspend() # We get re-scheduled by set()
|
||||
|
||||
@enable_ki_protection
|
||||
def set(self):
|
||||
"""
|
||||
Sets the event, awaking all tasks
|
||||
that called wait() on it
|
||||
"""
|
||||
|
||||
if self.is_set():
|
||||
raise RuntimeError(
|
||||
"this event has already been set: create a new Event object instead"
|
||||
)
|
||||
self._set = True
|
||||
for waiter in self._tasks:
|
||||
current_loop().reschedule(waiter)
|
||||
self._tasks.clear()
|
||||
|
||||
|
||||
class Queue:
|
||||
"""
|
||||
An asynchronous FIFO queue
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int | None = None):
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
||||
self.maxsize = maxsize
|
||||
# Stores event objects for tasks wanting to
|
||||
# get items from the queue
|
||||
self.getters: deque[Event] = deque()
|
||||
# Stores event objects for tasks wanting to
|
||||
# put items on the queue
|
||||
self.putters: deque[Event] = deque()
|
||||
self.container: deque = deque()
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the length of the queue
|
||||
"""
|
||||
|
||||
return len(self.container)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({f', '.join(map(str, self.container))})"
|
||||
|
||||
def __aiter__(self):
|
||||
"""
|
||||
Implements the asynchronous iterator protocol
|
||||
"""
|
||||
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
"""
|
||||
Implements the asynchronous iterator protocol
|
||||
"""
|
||||
|
||||
if self:
|
||||
return await self.get()
|
||||
else:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
@enable_ki_protection
|
||||
async def put(self, item: Any):
|
||||
"""
|
||||
Pushes an element onto the queue. If the
|
||||
queue is full, waits until a slot is
|
||||
available
|
||||
"""
|
||||
|
||||
if self.maxsize and len(self.container) == self.maxsize:
|
||||
self.putters.append(Event())
|
||||
await self.putters[-1].wait()
|
||||
if self.getters:
|
||||
self.getters.popleft().set()
|
||||
self.container.append(item)
|
||||
await checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
async def get(self) -> Any:
|
||||
"""
|
||||
Pops an element off the queue. Blocks until
|
||||
an element is put onto it if the queue is empty
|
||||
"""
|
||||
|
||||
if not self.container:
|
||||
self.getters.append(Event())
|
||||
await self.getters[-1].wait()
|
||||
if self.putters:
|
||||
self.putters.popleft().set()
|
||||
result = self.container.popleft()
|
||||
await checkpoint()
|
||||
return result
|
||||
|
||||
@enable_ki_protection
|
||||
def get_noblock(self) -> Any:
|
||||
"""
|
||||
Equivalent of get(), but it raises
|
||||
structio.WouldBlock if there's no
|
||||
elements on the queue instead of
|
||||
blocking
|
||||
"""
|
||||
|
||||
if not self.container:
|
||||
raise WouldBlock()
|
||||
if self.putters:
|
||||
self.putters.popleft().set()
|
||||
return self.container.popleft()
|
||||
|
||||
@enable_ki_protection
|
||||
def put_noblock(self, item: Any):
|
||||
"""
|
||||
Equivalent of put(), but it raises
|
||||
structio.WouldBlock if there's not
|
||||
enough space on the queue instead
|
||||
of blocking
|
||||
"""
|
||||
|
||||
if self.maxsize and len(self.container) == self.maxsize:
|
||||
raise WouldBlock()
|
||||
if self.getters:
|
||||
self.getters.popleft().set()
|
||||
self.container.append(item)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clears the queue
|
||||
"""
|
||||
|
||||
self.container.clear()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the queue
|
||||
"""
|
||||
|
||||
self.clear()
|
||||
self.getters.clear()
|
||||
self.putters.clear()
|
||||
|
||||
|
||||
class MemoryReceiveChannel(ChannelReader):
|
||||
"""
|
||||
An in-memory one-way channel to read
|
||||
data
|
||||
"""
|
||||
|
||||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._read_lock = ThereCanBeOnlyOne("another task is reading from this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def receive(self):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._read_lock:
|
||||
return await self._buffer.get()
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
self._closed = True
|
||||
await checkpoint()
|
||||
|
||||
def pending(self):
|
||||
return bool(self._buffer)
|
||||
|
||||
def readers(self):
|
||||
return len(self._buffer.getters)
|
||||
|
||||
|
||||
class MemorySendChannel(ChannelWriter):
|
||||
"""
|
||||
An in-memory one-way channel to send
|
||||
data
|
||||
"""
|
||||
|
||||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._write_lock = ThereCanBeOnlyOne("another task is writing to this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def send(self, item):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._write_lock:
|
||||
return await self._buffer.put(item)
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
self._closed = True
|
||||
await checkpoint()
|
||||
|
||||
def pending(self):
|
||||
return bool(self._buffer)
|
||||
|
||||
def writers(self):
|
||||
return len(self._buffer.putters)
|
||||
|
||||
|
||||
class MemoryChannel(Channel):
|
||||
"""
|
||||
An in-memory, two-way channel between
|
||||
tasks with optional buffering
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size):
|
||||
self._buffer = Queue(buffer_size)
|
||||
self.reader = MemoryReceiveChannel(self._buffer)
|
||||
self.writer = MemorySendChannel(self._buffer)
|
||||
|
||||
def pending(self):
|
||||
return self.reader.pending()
|
||||
|
||||
def readers(self):
|
||||
return self.reader.readers()
|
||||
|
||||
def writers(self):
|
||||
return self.writer.writers()
|
||||
|
||||
async def send(self, value):
|
||||
await self.writer.send(value)
|
||||
|
||||
async def receive(self):
|
||||
return await self.reader.receive()
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
await self.reader.close()
|
||||
await self.writer.close()
|
||||
|
||||
|
||||
class Semaphore:
|
||||
"""
|
||||
An asynchronous integer semaphore. The use of initial_size
|
||||
is for semaphores which we know that can grow up to max_size
|
||||
but that can't right now, say because there's too much load on
|
||||
the application and resources are constrained. If it is None,
|
||||
initial_size equals max_size
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int, initial_size: int | None = None):
|
||||
if initial_size is None:
|
||||
initial_size = max_size
|
||||
assert initial_size <= max_size
|
||||
self.max_size = max_size
|
||||
# We use an unbuffered memory channel to pause
|
||||
# as necessary, kinda like socket.set_wakeup_fd
|
||||
# or something? Anyway I think it's pretty nifty
|
||||
# because we're doing no I/O whatsoever so things
|
||||
# stay pretty damn efficient (and cheap!)
|
||||
self.channel: MemoryChannel = MemoryChannel(0)
|
||||
self._counter: int = initial_size
|
||||
|
||||
def __repr__(self):
|
||||
return f"<structio.Semaphore max_size={self.max_size} size={self._counter}>"
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self._counter
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self):
|
||||
"""
|
||||
Acquires the semaphore, possibly
|
||||
blocking if the task counter is
|
||||
exhausted
|
||||
"""
|
||||
|
||||
if self._counter == 0:
|
||||
await self.channel.receive()
|
||||
self._counter -= 1
|
||||
await checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
async def release(self):
|
||||
"""
|
||||
Releases a slot in the semaphore. Raises RuntimeError
|
||||
if there are no occupied slots to release
|
||||
"""
|
||||
|
||||
if self._counter == self.max_size:
|
||||
raise RuntimeError("semaphore is not acquired")
|
||||
self._counter += 1
|
||||
if self.channel.readers():
|
||||
await self.channel.send(None)
|
||||
else:
|
||||
await checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
@enable_ki_protection
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.release()
|
||||
|
||||
|
||||
class Lock:
|
||||
"""
|
||||
An asynchronous single-owner task lock. Unlike
|
||||
the lock in threading.Thread, only the lock's
|
||||
owner can release it
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._owner: Task | None = None
|
||||
self._sem: Semaphore = Semaphore(1)
|
||||
|
||||
@property
|
||||
def owner(self) -> Task | None:
|
||||
"""
|
||||
Returns the current owner of the lock,
|
||||
or None if the lock is not being held
|
||||
"""
|
||||
|
||||
return self._owner
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""
|
||||
Returns whether the lock is currently
|
||||
held
|
||||
"""
|
||||
|
||||
return self._sem.size == 0
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self):
|
||||
"""
|
||||
Acquires the lock, possibly
|
||||
blocking until it is available
|
||||
"""
|
||||
|
||||
await self._sem.acquire()
|
||||
self._owner = current_task()
|
||||
|
||||
@enable_ki_protection
|
||||
async def release(self):
|
||||
"""
|
||||
Releases the lock if it was previously
|
||||
acquired by the caller. If the lock is
|
||||
not currently acquired or if it is not
|
||||
acquired by the calling task, RuntimeError
|
||||
is raised
|
||||
"""
|
||||
|
||||
if not self.owner:
|
||||
raise RuntimeError("lock is not acquired")
|
||||
if current_task() is not self.owner:
|
||||
raise RuntimeError("lock can only be released by the owner")
|
||||
self._owner = None
|
||||
await self._sem.release()
|
||||
|
||||
@enable_ki_protection
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
@enable_ki_protection
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.release()
|
||||
|
||||
|
||||
class RLock(Lock):
|
||||
"""
|
||||
An asynchronous, single-owner recursive lock.
|
||||
Recursive locks have the property that their
|
||||
acquire() method can be called multiple times
|
||||
by the owner without deadlocking: each call
|
||||
increments an internal counter, which is decremented
|
||||
at every call to release(). The lock is released only
|
||||
when the internal counter reaches zero
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._acquire_count = 0
|
||||
|
||||
@enable_ki_protection
|
||||
async def acquire(self):
|
||||
current = current_task()
|
||||
if self.owner != current:
|
||||
await super().acquire()
|
||||
else:
|
||||
await checkpoint()
|
||||
self._acquire_count += 1
|
||||
|
||||
@property
|
||||
def acquire_count(self) -> int:
|
||||
"""
|
||||
Returns the number of times acquire()
|
||||
was called by the owner (note that it
|
||||
may be zero if the lock is not being
|
||||
held)
|
||||
"""
|
||||
|
||||
return self._acquire_count
|
||||
|
||||
@enable_ki_protection
|
||||
async def release(self):
|
||||
# I hate the repetition, but it's the
|
||||
# only way to make sure that a task can't
|
||||
# decrement the counter of a lock it does
|
||||
# not own
|
||||
current = current_task()
|
||||
if self.owner != current:
|
||||
await super().release()
|
||||
else:
|
||||
self._acquire_count -= 1
|
||||
if self._acquire_count == 0:
|
||||
await super().release()
|
||||
else:
|
||||
await checkpoint()
|
||||
|
||||
|
||||
_events: dict[str, list[Callable[[Any, Any], Coroutine[Any, Any, Any]]]] = defaultdict(
|
||||
list
|
||||
)
|
||||
|
||||
|
||||
async def emit(evt: str, *args, **kwargs):
|
||||
"""
|
||||
Fire the event and call all of its handlers with
|
||||
the event name as the first argument and all other
|
||||
positional and keyword arguments passed to this
|
||||
function after that. Returns once all events have
|
||||
completed execution
|
||||
"""
|
||||
|
||||
async with structio.create_pool() as pool:
|
||||
for func in _events[evt]:
|
||||
pool.spawn(partial(func, evt, *args, **kwargs))
|
||||
|
||||
|
||||
def register_event(evt: str, func: Callable[[Any, Any], Coroutine[Any, Any, Any]]):
|
||||
"""
|
||||
Register the given async function for the given event name.
|
||||
Note that if the given async function is already registered
|
||||
for the chosen event, it will be called once for each time
|
||||
this function is called once the associated event is fired
|
||||
"""
|
||||
|
||||
_events[evt].append(func)
|
||||
|
||||
|
||||
def unregister_event(evt: str, func: Callable[[Any, Any], Coroutine[Any, Any, Any]]):
|
||||
"""
|
||||
Unregisters the given async function from the given event.
|
||||
Nothing happens if the given event or async functions are
|
||||
not registered yet
|
||||
"""
|
||||
|
||||
try:
|
||||
_events[evt].remove(func)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
||||
def on_event(evt: str):
|
||||
"""
|
||||
Convenience decorator to
|
||||
register async functions
|
||||
to events
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
@wraps
|
||||
def wrapper(*args, **kwargs):
|
||||
f(*args, **kwargs)
|
||||
|
||||
register_event(evt, f)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -0,0 +1,384 @@
|
|||
# Support module for running synchronous functions as
|
||||
# coroutines into worker threads and to submit asynchronous
|
||||
# work to the event loop from a synchronous thread
|
||||
from functools import partial
|
||||
|
||||
import structio
|
||||
import threading
|
||||
from collections import deque
|
||||
from structio.abc import BaseKernel
|
||||
from structio.core.run import current_loop
|
||||
from typing import Callable, Any, Coroutine
|
||||
from structio.core.syscalls import checkpoint
|
||||
from structio.sync import Event, Semaphore, Queue
|
||||
from structio.util.ki import enable_ki_protection
|
||||
from structio.exceptions import StructIOException
|
||||
from itertools import count as _count
|
||||
|
||||
|
||||
_storage = threading.local()
|
||||
# Max number of concurrent threads that can
|
||||
# be spawned by run_in_worker before blocking
|
||||
_storage.max_workers = Semaphore(50)
|
||||
_worker_id = _count()
|
||||
|
||||
|
||||
def is_async_thread() -> bool:
|
||||
return hasattr(_storage, "parent_loop")
|
||||
|
||||
|
||||
class AsyncThreadEvent(Event):
|
||||
"""
|
||||
An extension of the regular event
|
||||
class that is safe to utilize both
|
||||
from threads and from async code
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
self._workers: deque[threading.Event] = deque()
|
||||
|
||||
@enable_ki_protection
|
||||
def wait_sync(self):
|
||||
"""
|
||||
Like wait(), but synchronous
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
if self.is_set():
|
||||
return
|
||||
ev = threading.Event()
|
||||
self._workers.append(ev)
|
||||
ev.wait()
|
||||
|
||||
@enable_ki_protection
|
||||
async def wait(self):
|
||||
with self._lock:
|
||||
if self.is_set():
|
||||
return
|
||||
await super().wait()
|
||||
|
||||
@enable_ki_protection
|
||||
def set(self):
|
||||
with self._lock:
|
||||
if self.is_set():
|
||||
return
|
||||
# We can't just call super().set() because that
|
||||
# will call current_loop(), and we may have been
|
||||
# called from an async thread that doesn't have a
|
||||
# loop
|
||||
loop: BaseKernel = _storage.parent_loop
|
||||
for task in self._tasks:
|
||||
loop.reschedule(task)
|
||||
# Awakes all threads
|
||||
for evt in self._workers:
|
||||
evt.set()
|
||||
self._set = True
|
||||
|
||||
|
||||
class AsyncThreadQueue(Queue):
|
||||
"""
|
||||
An extension of the regular queue
|
||||
class that is safe to use both from
|
||||
threaded and asynchronous code
|
||||
"""
|
||||
|
||||
def __init__(self, max_size):
|
||||
super().__init__(max_size)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@enable_ki_protection
|
||||
async def get(self):
|
||||
evt: AsyncThreadEvent | None = None
|
||||
with self._lock:
|
||||
if not self.container:
|
||||
self.getters.append(AsyncThreadEvent())
|
||||
evt = self.getters[-1]
|
||||
if self.putters:
|
||||
self.putters.popleft().set()
|
||||
if evt:
|
||||
await evt.wait()
|
||||
await checkpoint()
|
||||
return self.container.popleft()
|
||||
|
||||
@enable_ki_protection
|
||||
async def put(self, item):
|
||||
evt: AsyncThreadEvent | None = None
|
||||
with self._lock:
|
||||
if self.maxsize and self.maxsize == len(self.container):
|
||||
self.putters.append(AsyncThreadEvent())
|
||||
evt = self.putters[-1]
|
||||
if self.getters:
|
||||
self.getters.popleft().set()
|
||||
if evt:
|
||||
await evt.wait()
|
||||
self.container.append(item)
|
||||
await checkpoint()
|
||||
|
||||
@enable_ki_protection
|
||||
def get_noblock(self) -> Any:
|
||||
return super().get_noblock()
|
||||
|
||||
@enable_ki_protection
|
||||
def put_noblock(self, item: Any):
|
||||
return super().put_noblock(item)
|
||||
|
||||
@enable_ki_protection
|
||||
def put_sync(self, item):
|
||||
"""
|
||||
Like put(), but synchronous
|
||||
"""
|
||||
|
||||
evt: AsyncThreadEvent | None = None
|
||||
with self._lock:
|
||||
if self.maxsize and self.maxsize == len(self.container):
|
||||
evt = AsyncThreadEvent()
|
||||
self.putters.append(evt)
|
||||
if self.getters:
|
||||
self.getters.popleft().set()
|
||||
if evt:
|
||||
evt.wait_sync()
|
||||
self.container.append(item)
|
||||
|
||||
@enable_ki_protection
|
||||
def get_sync(self):
|
||||
"""
|
||||
Like get(), but synchronous
|
||||
"""
|
||||
|
||||
evt: AsyncThreadEvent | None = None
|
||||
with self._lock:
|
||||
if not self.container:
|
||||
self.getters.append(AsyncThreadEvent())
|
||||
evt = self.getters[-1]
|
||||
if self.putters:
|
||||
self.putters.popleft().set()
|
||||
if evt:
|
||||
evt.wait_sync()
|
||||
return self.container.popleft()
|
||||
|
||||
|
||||
# Just a bunch of private helpers to run sync/async functions
|
||||
|
||||
|
||||
def _threaded_runner(
|
||||
f,
|
||||
parent_loop: BaseKernel,
|
||||
rq: AsyncThreadQueue,
|
||||
rsq: AsyncThreadQueue,
|
||||
evt: AsyncThreadEvent,
|
||||
coro_runner: "structio.util.wakeup_fd.WakeupFd",
|
||||
supervisor: "structio.util.wakeup_fd.WakeupFd",
|
||||
*args,
|
||||
):
|
||||
"""
|
||||
This is the actual function where our worker thread "lives"
|
||||
"""
|
||||
|
||||
try:
|
||||
# Setup thread-local storage so future calls
|
||||
# to run_coro() can find this stuff
|
||||
_storage.parent_loop = parent_loop
|
||||
_storage.rq = rq
|
||||
_storage.rsq = rsq
|
||||
_storage.coro_runner = coro_runner
|
||||
_storage.supervisor = supervisor
|
||||
result = f(*args)
|
||||
except BaseException as e:
|
||||
rsq.put_sync((False, e))
|
||||
else:
|
||||
rsq.put_sync((True, result))
|
||||
finally:
|
||||
# Wakeup the event loop
|
||||
_storage.supervisor.wakeup()
|
||||
# Notify run_in_worker that the thread
|
||||
# has exited
|
||||
evt.set()
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
async def _coroutine_request_handler(
|
||||
coroutines: AsyncThreadQueue,
|
||||
results: AsyncThreadQueue,
|
||||
reader: "structio.socket.AsyncSocket",
|
||||
):
|
||||
"""
|
||||
Runs coroutines on behalf of a thread spawned by structio and
|
||||
submits the outcome back to the thread
|
||||
"""
|
||||
|
||||
while True:
|
||||
await reader.receive(1)
|
||||
coro = await coroutines.get()
|
||||
try:
|
||||
result = await coro
|
||||
except BaseException as e:
|
||||
await results.put((False, e))
|
||||
else:
|
||||
await results.put((True, result))
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
async def run_in_worker(
|
||||
sync_func,
|
||||
*args,
|
||||
cancellable: bool = False,
|
||||
):
|
||||
"""
|
||||
Call the given synchronous function in a separate
|
||||
worker thread, turning it into an async operation.
|
||||
Must be called from an asynchronous context (a
|
||||
StructIOException is raised otherwise). The result
|
||||
of the call is returned, and any exceptions that occur
|
||||
are propagated back to the caller. This is semantically
|
||||
identical to just calling the function itself from within
|
||||
the async context, but it has the added benefit of 1) Being
|
||||
partially cancellable (with a catch, see below) and 2) If
|
||||
the function performs some long-running blocking operation,
|
||||
calling it in the main thread is not advisable, as it would
|
||||
cause structio's event loop to grind to a halt, meaning that
|
||||
timeouts and cancellations don't work, I/O doesn't get scheduled,
|
||||
and all sorts of nasty things happen (or rather, don't happen,
|
||||
since no work is getting done). In short, don't do long-running
|
||||
sync calls in the main thread, use a worker. Also, don't do any
|
||||
CPU-bound work in it, or you're likely to negatively affect the main
|
||||
thread anyway because CPython is weird and likes to starve-out I/O
|
||||
bound threads if there's some CPU-bound workers running (for that kind
|
||||
of work, you might want to spawn an entire separate process instead).
|
||||
Now, onto cancellations: If cancellable equals False, then the operation
|
||||
cannot be canceled in any way (this is the default option). This means
|
||||
that even if you set a task scope with a timeout or explicitly cancel
|
||||
the pool where this function is awaited, its effects won't be visible
|
||||
until after the thread has exited. If cancellable equals True, cancellation
|
||||
will cause this function to return early and to abruptly drop the thread:
|
||||
keep in mind that it is likely to keep running in the background, as
|
||||
structio doesn't make any effort to stop it (it can't). If you call this
|
||||
with cancellable=True, make sure the operation you're performing is side-effect-free,
|
||||
or you might get nasty deadlocks or race conditions happening.
|
||||
|
||||
Note: If the number of current active thread workers is equal to the value of get_max_worker_count(),
|
||||
this function blocks until a slot is available and then proceeds normally.
|
||||
|
||||
"""
|
||||
|
||||
if not hasattr(_storage, "parent_loop"):
|
||||
_storage.parent_loop = current_loop()
|
||||
else:
|
||||
try:
|
||||
current_loop()
|
||||
except StructIOException:
|
||||
raise StructIOException("cannot be called from sync context")
|
||||
# This will automatically block once
|
||||
# we run out of slots and proceed once
|
||||
# we have more
|
||||
async with _storage.max_workers:
|
||||
# Thread termination event
|
||||
terminate = AsyncThreadEvent()
|
||||
# Request queue. This is where the thread
|
||||
# sends coroutines to run
|
||||
rq = AsyncThreadQueue(0)
|
||||
# Results queue. This is where we put the result
|
||||
# of the coroutines in the request queue
|
||||
rsq = AsyncThreadQueue(0)
|
||||
# This looks like a lot of bookkeeping to do synchronization, but it all has a purpose.
|
||||
# The termination event is necessary so that we can know when the thread has terminated,
|
||||
# no surprises there I'd say. The request and result queues are used to send coroutines
|
||||
# and their results back and forth when using run_coro from within the "asynchronous thread"
|
||||
async with structio.create_pool() as pool:
|
||||
# If the operation is cancellable, then we're not
|
||||
# shielded
|
||||
pool.scope.shielded = not cancellable
|
||||
worker_id = next(_worker_id)
|
||||
wakeup = structio.util.wakeup_fd.WakeupFd()
|
||||
wakeup2 = structio.util.wakeup_fd.WakeupFd()
|
||||
# Spawn a coroutine to process incoming requests from
|
||||
# the new async thread. We can't await it because it
|
||||
# needs to run in the background
|
||||
handler = pool.spawn(_coroutine_request_handler, rq, rsq, wakeup.reader)
|
||||
# Start the worker thread
|
||||
threading.Thread(
|
||||
target=_threaded_runner,
|
||||
args=(
|
||||
sync_func,
|
||||
current_loop(),
|
||||
rq,
|
||||
rsq,
|
||||
terminate,
|
||||
wakeup,
|
||||
wakeup2,
|
||||
*args,
|
||||
),
|
||||
name=f"structio-worker-thread-{worker_id}",
|
||||
# We start cancellable threads in daemonic mode so that
|
||||
# the main thread doesn't get stuck waiting on them forever
|
||||
# when their associated async counterpart gets cancelled. This
|
||||
# is due to the fact that there's really no way to "kill" a thread
|
||||
# (and for good reason!), so we just pretend nothing happened and go
|
||||
# about our merry way, hoping the thread dies eventually I guess
|
||||
daemon=cancellable,
|
||||
).start()
|
||||
# Ensure we get poked by the worker thread
|
||||
await wakeup2.reader.receive(1)
|
||||
# Wait for the thread to terminate
|
||||
await terminate.wait()
|
||||
# Worker thread has exited: we no longer need to process
|
||||
# any requests, so we shut our request handler down
|
||||
handler.cancel()
|
||||
# Fetch for the final result from the thread. We use get_noblock()
|
||||
# because we know the result should already be there, so the operation
|
||||
# should not block (and if this raises WouldBlock, then it's a bug)
|
||||
success, data = rsq.get_noblock()
|
||||
if success:
|
||||
return data
|
||||
raise data
|
||||
|
||||
|
||||
@enable_ki_protection
|
||||
def run_coro(
|
||||
async_func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Submits a coroutine for execution to the event loop from another thread,
|
||||
passing any arguments along the way. Return values and exceptions are
|
||||
propagated, and from the point of view of the calling thread this call
|
||||
blocks until the coroutine returns. The thread must be async flavored,
|
||||
meaning it must be able to communicate back and forth with the event
|
||||
loop running in the main thread (in practice, this means only threads
|
||||
spawned with run_in_worker are able to call this)
|
||||
"""
|
||||
|
||||
try:
|
||||
current_loop()
|
||||
except StructIOException:
|
||||
pass
|
||||
else:
|
||||
raise StructIOException("cannot be called from async context")
|
||||
if not is_async_thread() or _storage.parent_loop.done():
|
||||
raise StructIOException("run_coro requires a running loop in another thread!")
|
||||
# Wake up the event loop if it's blocked in a call to select() or similar I/O routine
|
||||
_storage.coro_runner.wakeup()
|
||||
_storage.rq.put_sync(async_func(*args, **kwargs))
|
||||
success, data = _storage.rsq.get_sync()
|
||||
if success:
|
||||
return data
|
||||
raise data
|
||||
|
||||
|
||||
def set_max_worker_count(count: int):
|
||||
"""
|
||||
Sets a new value for the maximum number of concurrent
|
||||
worker threads structio is allowed to spawn
|
||||
"""
|
||||
|
||||
# Everything, to avoid the unholy "global"
|
||||
_storage.max_workers = Semaphore(count)
|
||||
|
||||
|
||||
def get_max_worker_count() -> int:
|
||||
"""
|
||||
Gets the maximum number of concurrent worker
|
||||
threads structio is allowed to spawn
|
||||
"""
|
||||
|
||||
return _storage.max_workers.max_size
|
|
@ -0,0 +1,4 @@
|
|||
from . import misc, ki, wakeup_fd
|
||||
|
||||
|
||||
__all__ = ["misc", "ki", "wakeup_fd"]
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
aiosched: Yet another Python async scheduler
|
||||
|
||||
Copyright (C) 2022 nocturn9x
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https:www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import sys
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from types import FrameType
|
||||
|
||||
|
||||
# Special magic module half-stolen from Trio (thanks njsmith I love you)
|
||||
# that makes Ctrl+C work. P.S.: Please Python, get your signals straight.
|
||||
|
||||
|
||||
# Just a funny variable name that is not a valid
|
||||
# identifier (but still a string so tools that hack
|
||||
# into frames don't freak out when they look at the
|
||||
# local variables) which will get injected silently
|
||||
# into every frame to enable/disable the safeguards
|
||||
# for Ctrl+C/KeyboardInterrupt
|
||||
CTRLC_PROTECTION_ENABLED = "|yes-it-is|"
|
||||
|
||||
|
||||
def critical_section(frame: FrameType) -> bool:
|
||||
"""
|
||||
Returns whether Ctrl+C protection is currently
|
||||
enabled in the given frame or in any of its children.
|
||||
Stolen from Trio
|
||||
"""
|
||||
|
||||
while frame is not None:
|
||||
if CTRLC_PROTECTION_ENABLED in frame.f_locals:
|
||||
return frame.f_locals[CTRLC_PROTECTION_ENABLED]
|
||||
if frame.f_code.co_name == "__del__":
|
||||
return True
|
||||
frame = frame.f_back
|
||||
return True
|
||||
|
||||
|
||||
def currently_protected() -> bool:
|
||||
"""
|
||||
Returns whether Ctrl+C protection is currently
|
||||
enabled in the current context
|
||||
"""
|
||||
|
||||
return critical_section(sys._getframe())
|
||||
|
||||
|
||||
def legacy_isasyncgenfunction(obj):
|
||||
return getattr(obj, "_async_gen_function", None) == id(obj)
|
||||
|
||||
|
||||
def _ki_protection_decorator(enabled):
|
||||
def decorator(fn):
|
||||
# In some version of Python, isgeneratorfunction returns true for
|
||||
# coroutine functions, so we have to check for coroutine functions
|
||||
# first.
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# See the comment for regular generators below
|
||||
coro = fn(*args, **kwargs)
|
||||
coro.cr_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return coro
|
||||
|
||||
return wrapper
|
||||
elif inspect.isgeneratorfunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# It's important that we inject this directly into the
|
||||
# generator's locals, as opposed to setting it here and then
|
||||
# doing 'yield from'. The reason is, if a generator is
|
||||
# throw()n into, then it may magically pop to the top of the
|
||||
# stack. And @contextmanager generators in particular are a
|
||||
# case where we often want KI protection, and which are often
|
||||
# thrown into! See:
|
||||
# https://bugs.python.org/issue29590
|
||||
gen = fn(*args, **kwargs)
|
||||
gen.gi_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return gen
|
||||
|
||||
return wrapper
|
||||
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
# See the comment for regular generators above
|
||||
agen = fn(*args, **kwargs)
|
||||
agen.ag_frame.f_locals[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return agen
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
locals()[CTRLC_PROTECTION_ENABLED] = enabled
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
enable_ki_protection = _ki_protection_decorator(True)
|
||||
enable_ki_protection.__name__ = "enable_ki_protection"
|
||||
enable_ki_protection.__doc__ = "Decorator to enable keyboard interrupt protection"
|
||||
|
||||
disable_ki_protection = _ki_protection_decorator(False)
|
||||
disable_ki_protection.__name__ = "disable_ki_protection"
|
||||
disable_ki_protection.__doc__ = "Decorator to disable keyboard interrupt protection"
|
|
@ -0,0 +1,30 @@
|
|||
from structio.exceptions import ResourceBusy
|
||||
|
||||
|
||||
# Yes, I stole trio's idea of the ConflictDetector class. Shut up
|
||||
class ThereCanBeOnlyOne:
|
||||
"""
|
||||
A simple context manager that raises an error when
|
||||
an attempt is made to acquire it from more than one
|
||||
task at a time. Can be used to protect sections of
|
||||
code handling some async resource that would need locking
|
||||
if they were allowed to be called from more than one task
|
||||
at a time, but that should never happen (for example, if you
|
||||
try to do call await send() on a socket from two different
|
||||
tasks at the same time)
|
||||
"""
|
||||
|
||||
def __init__(self, msg: str):
|
||||
self._acquired = False
|
||||
self.msg = msg
|
||||
|
||||
def __enter__(self):
|
||||
if self._acquired:
|
||||
raise ResourceBusy(self.msg)
|
||||
self._acquired = True
|
||||
|
||||
def __exit__(self, *args):
|
||||
self._acquired = False
|
||||
|
||||
|
||||
__all__ = ["ThereCanBeOnlyOne"]
|
|
@ -0,0 +1,21 @@
|
|||
from structio.io.socket import socketpair
|
||||
import signal
|
||||
|
||||
|
||||
class WakeupFd:
|
||||
"""
|
||||
A thin wrapper over a socket pair used in signal.set_wakeup_fd
|
||||
and for thread wakeup events
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reader, self.writer = socketpair()
|
||||
|
||||
def set_wakeup_fd(self):
|
||||
signal.set_wakeup_fd(self.writer.socket.fileno())
|
||||
|
||||
def wakeup(self):
|
||||
try:
|
||||
self.writer.socket.send(b"\x00")
|
||||
except BlockingIOError:
|
||||
pass
|
|
@ -0,0 +1,148 @@
|
|||
import structio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
# An asynchronous chatroom
|
||||
|
||||
clients: dict[structio.socket.AsyncSocket, list[str, str]] = {}
|
||||
names: set[str] = set()
|
||||
|
||||
|
||||
async def event_handler(q: structio.Queue):
|
||||
"""
|
||||
Reads data submitted onto the queue
|
||||
"""
|
||||
|
||||
try:
|
||||
logging.info("Event handler spawned")
|
||||
while True:
|
||||
msg, payload = await q.get()
|
||||
logging.info(f"Caught event {msg!r} with the following payload: {payload}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"An exception occurred in the message handler -> {type(e).__name__}: {e}"
|
||||
)
|
||||
except structio.exceptions.Cancelled:
|
||||
logging.warning(f"Cancellation detected, message handler shutting down")
|
||||
# Propagate the cancellation
|
||||
raise
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
"""
|
||||
Serves asynchronously forever (or until Ctrl+C ;))
|
||||
:param bind_address: The address to bind the server to, represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
"""
|
||||
|
||||
sock = structio.socket.socket()
|
||||
queue = structio.Queue()
|
||||
await sock.bind(bind_address)
|
||||
await sock.listen(5)
|
||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(event_handler, queue)
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||
await queue.put(("connect", clients[conn]))
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
pool.spawn(handler, conn, queue)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(
|
||||
f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}"
|
||||
)
|
||||
|
||||
|
||||
async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue):
|
||||
"""
|
||||
Handles a single client connection
|
||||
:param sock: The AsyncSocket object connected to the client
|
||||
"""
|
||||
|
||||
address = clients[sock][1]
|
||||
name = ""
|
||||
async with sock: # Closes the socket automatically
|
||||
await sock.send_all(
|
||||
b"Welcome to the chatroom pal, may you tell me your name?\n> "
|
||||
)
|
||||
cond = True
|
||||
while cond:
|
||||
while not name.endswith("\n"):
|
||||
name = (await sock.receive(64)).decode()
|
||||
if name == "":
|
||||
cond = False
|
||||
break
|
||||
name = name.rstrip("\n")
|
||||
if name not in names:
|
||||
names.add(name)
|
||||
clients[sock][0] = name
|
||||
break
|
||||
else:
|
||||
await sock.send_all(
|
||||
b"Sorry, but that name is already taken. Try again!\n> "
|
||||
)
|
||||
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
||||
await q.put(("join", (address, name)))
|
||||
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
|
||||
while True:
|
||||
await sock.send_all(b"> ")
|
||||
data = await sock.receive(1024)
|
||||
if not data:
|
||||
break
|
||||
decoded = data.decode().rstrip("\n")
|
||||
if decoded.startswith("/"):
|
||||
logging.info(f"{name} issued server command {decoded}")
|
||||
await q.put(("cmd", (name, decoded[1:])))
|
||||
match decoded[1:]:
|
||||
case "bye":
|
||||
await sock.send_all(b"Bye!\n")
|
||||
break
|
||||
case _:
|
||||
await sock.send_all(b"Unknown command\n")
|
||||
else:
|
||||
await q.put(("msg", (name, data)))
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
logging.info(
|
||||
f"Sending {data!r} to {':'.join(map(str, client_sock.getpeername()))}"
|
||||
)
|
||||
if not data.endswith(b"\n"):
|
||||
data += b"\n"
|
||||
await client_sock.send_all(
|
||||
f"[{name}] ({address}): {data.decode()}> ".encode()
|
||||
)
|
||||
logging.info(f"Sent {data!r} to {i} clients")
|
||||
await q.put(("leave", name))
|
||||
logging.info(f"Connection from {address} closed")
|
||||
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
|
||||
clients.pop(sock)
|
||||
names.discard(name)
|
||||
logging.info("Handler shutting down")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
|
||||
logging.basicConfig(
|
||||
level=20,
|
||||
format="[%(levelname)s] %(asctime)s %(message)s",
|
||||
datefmt="%d/%m/%Y %p",
|
||||
)
|
||||
try:
|
||||
structio.run(serve, ("0.0.0.0", port))
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
|
@ -0,0 +1,78 @@
|
|||
import sys
|
||||
import logging
|
||||
import structio
|
||||
|
||||
|
||||
# A test to check for asynchronous I/O
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
"""
|
||||
Serves asynchronously forever
|
||||
:param bind_address: The address to bind the server to represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
"""
|
||||
|
||||
sock = structio.socket.socket()
|
||||
await sock.bind(bind_address)
|
||||
await sock.listen(5)
|
||||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with structio.create_pool() as ctx:
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await ctx.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(
|
||||
f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}"
|
||||
)
|
||||
|
||||
|
||||
async def handler(sock: structio.socket.AsyncSocket, client_address: tuple):
|
||||
"""
|
||||
Handles a single client connection
|
||||
:param sock: The AsyncSocket object connected to the client
|
||||
:param client_address: The client's address represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
:type client_address: tuple
|
||||
"""
|
||||
|
||||
address = f"{client_address[0]}:{client_address[1]}"
|
||||
async with sock: # Closes the socket automatically
|
||||
await sock.send_all(
|
||||
b"Welcome to the server pal, feel free to send me something!\n"
|
||||
)
|
||||
while True:
|
||||
await sock.send_all(b"-> ")
|
||||
data = await sock.receive(1024)
|
||||
if not data:
|
||||
break
|
||||
elif data == b"exit\n":
|
||||
await sock.send_all(b"I'm dead dude\n")
|
||||
raise TypeError("Oh, no, I'm gonna die!")
|
||||
elif data == b"fatal\n":
|
||||
await sock.send_all(b"What a dick\n")
|
||||
raise KeyboardInterrupt("He told me to do it!")
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
await sock.send_all(b"Got: " + data)
|
||||
logging.info(f"Echoed back {data!r} to {address}")
|
||||
logging.info(f"Connection from {address} closed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
|
||||
logging.basicConfig(
|
||||
level=20,
|
||||
format="[%(levelname)s] %(asctime)s %(message)s",
|
||||
datefmt="%d/%m/%Y %H:%M:%S %p",
|
||||
)
|
||||
try:
|
||||
structio.run(serve, ("localhost", port))
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
else:
|
||||
logging.error(f"Exiting due to a {type(error).__name__}: {error}")
|
|
@ -0,0 +1,52 @@
|
|||
import structio
|
||||
import random
|
||||
|
||||
|
||||
async def waiter(ch: structio.ChannelReader):
|
||||
print("[waiter] Waiter is alive!")
|
||||
while True:
|
||||
print("[waiter] Awaiting events")
|
||||
try:
|
||||
evt: structio.Event = await ch.receive()
|
||||
except structio.ResourceClosed:
|
||||
break
|
||||
print("[waiter] Received event, waiting to be triggered")
|
||||
t = structio.clock()
|
||||
await evt.wait()
|
||||
print(f"[waiter] Event triggered after {structio.clock() - t:.2f} seconds")
|
||||
print("[waiter] Done!")
|
||||
|
||||
|
||||
async def sender(ch: structio.Channel, n: int):
|
||||
print("[sender] Sender is alive!")
|
||||
async with ch:
|
||||
# Channel is automatically closed when exiting
|
||||
# the async with block
|
||||
for _ in range(n):
|
||||
print("[sender] Sending event")
|
||||
ev = structio.Event()
|
||||
await ch.send(ev)
|
||||
t = random.random()
|
||||
print(f"[sender] Sent event, sleeping {t:.2f} seconds")
|
||||
await structio.sleep(t)
|
||||
print("[sender] Setting the event")
|
||||
ev.set()
|
||||
print("[sender] Done!")
|
||||
|
||||
|
||||
async def main(n: int):
|
||||
print("[main] Parent is alive")
|
||||
channel = structio.MemoryChannel(1)
|
||||
async with structio.create_pool() as pool:
|
||||
# Each end of the channel can be used independently,
|
||||
# and closing one does not also close the other (which
|
||||
# is why we pass the full channel object to our sender
|
||||
# so it can close both ends and cause the reader to catch
|
||||
# the closing exception and exit cleanly)
|
||||
pool.spawn(waiter, channel.reader)
|
||||
pool.spawn(sender, channel, n)
|
||||
print("[main] Children spawned")
|
||||
print("[main] Done!")
|
||||
|
||||
|
||||
structio.run(main, 3)
|
|
@ -0,0 +1,85 @@
|
|||
import structio
|
||||
import time
|
||||
import threading
|
||||
|
||||
|
||||
async def child(ev: structio.Event, n):
|
||||
print(f"[child] I'm alive! Waiting {n} seconds before setting the event")
|
||||
await structio.sleep(n)
|
||||
print("[child] Slept! Setting the event")
|
||||
ev.set()
|
||||
assert ev.is_set()
|
||||
|
||||
|
||||
async def main(i):
|
||||
print("[main] Parent is alive")
|
||||
j = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
evt = structio.Event()
|
||||
print("[main] Spawning child")
|
||||
pool.spawn(child, evt, i)
|
||||
print("[main] Child spawned, waiting on the event")
|
||||
await evt.wait()
|
||||
assert evt.is_set()
|
||||
print(f"[main] Exited in {structio.clock() - j:.2f} seconds")
|
||||
|
||||
|
||||
def thread_worker(ev: structio.thread.AsyncThreadEvent):
|
||||
print("[worker] Worker thread spawned, waiting for event")
|
||||
t = time.time()
|
||||
ev.wait_sync()
|
||||
print(f"[worker] Event was fired after {time.time() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_async_thread(i):
|
||||
print("[main] Parent is alive")
|
||||
j = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
# Identical to structio.Event, but this event
|
||||
# can talk to threads too
|
||||
evt = structio.thread.AsyncThreadEvent()
|
||||
print("[main] Spawning child")
|
||||
pool.spawn(child, evt, i)
|
||||
print("[main] Child spawned, calling worker thread")
|
||||
await structio.thread.run_in_worker(thread_worker, evt)
|
||||
assert evt.is_set()
|
||||
print(f"[main] Exited in {structio.clock() - j:.2f} seconds")
|
||||
|
||||
|
||||
# Of course, threaded events work both ways: coroutines and threads
|
||||
# can set/wait on them from either side. Isn't that neat?
|
||||
|
||||
|
||||
def thread_worker_2(n, ev: structio.thread.AsyncThreadEvent):
|
||||
print(
|
||||
f"[worker] Worker thread spawned, sleeping {n} seconds before setting the event"
|
||||
)
|
||||
time.sleep(n)
|
||||
print("[worker] Setting the event")
|
||||
ev.set()
|
||||
|
||||
|
||||
async def child_2(ev: structio.Event):
|
||||
print(f"[child] I'm alive! Waiting on the event")
|
||||
t = structio.clock()
|
||||
await ev.wait()
|
||||
print(f"[child] Slept for {structio.clock() - t:.2f} seconds")
|
||||
assert ev.is_set()
|
||||
|
||||
|
||||
async def main_async_thread_2(i):
|
||||
print("[main] Parent is alive")
|
||||
j = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
evt = structio.thread.AsyncThreadEvent()
|
||||
print("[main] Spawning child")
|
||||
pool.spawn(child_2, evt)
|
||||
print("[main] Child spawned, calling worker thread")
|
||||
await structio.thread.run_in_worker(thread_worker_2, i, evt)
|
||||
assert evt.is_set()
|
||||
print(f"[main] Exited in {structio.clock() - j:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main, 5)
|
||||
structio.run(main_async_thread, 5)
|
||||
structio.run(main_async_thread_2, 5)
|
|
@ -0,0 +1,48 @@
|
|||
import structio
|
||||
import tempfile
|
||||
import os
|
||||
from structio import aprint
|
||||
|
||||
|
||||
async def main_2(data: bytes):
|
||||
t = structio.clock()
|
||||
await aprint("[main] Using low level os module")
|
||||
async with await structio.open_file(
|
||||
os.path.join(tempfile.gettempdir(), "structio_test.txt"), "wb+"
|
||||
) as f:
|
||||
await aprint(
|
||||
f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes"
|
||||
)
|
||||
await f.write(data)
|
||||
await f.seek(0)
|
||||
assert await f.read(len(data)) == data
|
||||
await f.flush()
|
||||
await aprint(f"[main] Deleting {f.name!r}")
|
||||
await structio.thread.run_in_worker(os.unlink, f.name)
|
||||
assert not await structio.thread.run_in_worker(os.path.isfile, f.name)
|
||||
await aprint(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_3(data: bytes):
|
||||
t = structio.clock()
|
||||
await aprint("[main] Using high level pathlib wrapper")
|
||||
path = structio.Path(tempfile.gettempdir()) / "structio_test.txt"
|
||||
async with await path.open("wb+") as f:
|
||||
await aprint(
|
||||
f"[main] Opened async file {f.name!r}, writing payload of {len(data)} bytes"
|
||||
)
|
||||
await f.write(data)
|
||||
await f.seek(0)
|
||||
assert await f.read(len(data)) == data
|
||||
await f.flush()
|
||||
await aprint(f"[main] Deleting {f.name!r}")
|
||||
await path.unlink()
|
||||
assert not await path.exists()
|
||||
await aprint(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
MB = 1048576
|
||||
payload = b"a" * MB * 100
|
||||
# Write 100MiB of data (too much?)
|
||||
structio.run(main_2, payload)
|
||||
structio.run(main_3, payload)
|
|
@ -0,0 +1,68 @@
|
|||
import structio
|
||||
import sys
|
||||
import time
|
||||
|
||||
_print = print
|
||||
|
||||
|
||||
def print(*args, **kwargs):
|
||||
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
|
||||
_print(*args, **kwargs)
|
||||
|
||||
|
||||
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False):
|
||||
print(f"Attempting a connection to {host}:{port} {'in keep-alive mode' if keepalive else ''}")
|
||||
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
|
||||
buffer = b""
|
||||
print("Connected")
|
||||
# Ensures the code below doesn't run for more than 5 seconds
|
||||
with structio.skip_after(5) as scope:
|
||||
# Closes the socket automatically
|
||||
async with socket:
|
||||
print("Entered socket context manager, sending HTTP request data")
|
||||
await socket.send_all(
|
||||
f"GET / HTTP/1.1\r\nUser-Agent: Structio/0.1.0\r\nAccept: */*\r\nHost: {host}"
|
||||
f"\r\nAccept-Encoding: gzip, deflate, br\r\nConnection: {'close' if not keepalive else 'keep-alive'}"
|
||||
f"\r\n\r\n".encode()
|
||||
)
|
||||
print("Data sent")
|
||||
while True:
|
||||
# We purposely do NOT check for the end of the response (\r\n) so that when the
|
||||
# connection is in keep-alive mode we hang and let our timeout expire the whole
|
||||
# block
|
||||
print(
|
||||
f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})"
|
||||
)
|
||||
data = await socket.receive(bufsize)
|
||||
if data:
|
||||
print(f"Received {len(data)} bytes")
|
||||
buffer += data
|
||||
else:
|
||||
print("Received empty stream, closing connection")
|
||||
break
|
||||
if buffer:
|
||||
data = buffer.decode().split("\r\n")
|
||||
print(
|
||||
f"HTTP Response below {'(might be incomplete)' if scope.timed_out else ''}:"
|
||||
)
|
||||
_print(f"Response: {data[0]}")
|
||||
_print("Headers:")
|
||||
content = False
|
||||
for i, element in enumerate(data):
|
||||
if i == 0:
|
||||
continue
|
||||
else:
|
||||
if not element.strip() and not content:
|
||||
sys.stdout.write("\nContent:")
|
||||
content = True
|
||||
if not content:
|
||||
_print(f"\t{element}")
|
||||
else:
|
||||
for line in element.split("\n"):
|
||||
_print(f"\t{line}")
|
||||
_print("Done!")
|
||||
|
||||
|
||||
structio.run(test, "google.com", 443, 256)
|
||||
# With keep-alive on, our timeout will kick in
|
||||
structio.run(test, "google.com", 443, 256, True)
|
|
@ -0,0 +1,34 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def child(n: int):
|
||||
print(f"Going to sleep for {n} seconds!")
|
||||
i = structio.clock()
|
||||
try:
|
||||
await structio.sleep(n)
|
||||
except structio.Cancelled:
|
||||
slept = structio.clock() - i
|
||||
print(
|
||||
f"Oh no, I've been cancelled! (was gonna sleep {n - slept:.2f} more seconds)"
|
||||
)
|
||||
raise
|
||||
print(f"Slept for {structio.clock() - i:.2f} seconds!")
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
print("Parent is alive. Spawning children")
|
||||
t = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(child, 5)
|
||||
pool.spawn(child, 3)
|
||||
pool.spawn(child, 8)
|
||||
print(f"Children spawned, awaiting completion")
|
||||
except KeyboardInterrupt:
|
||||
print("Ctrl+C caught")
|
||||
print(f"Children have completed in {structio.clock() - t:.2f} seconds")
|
||||
return 0
|
||||
|
||||
|
||||
assert structio.run(main) == 0
|
||||
print("Execution complete")
|
|
@ -0,0 +1,46 @@
|
|||
import structio
|
||||
from typing import Any
|
||||
|
||||
|
||||
async def reader(ch: structio.ChannelReader):
|
||||
print("[reader] Reader is alive!")
|
||||
async with ch:
|
||||
while True:
|
||||
print(f"[reader] Awaiting messages")
|
||||
data = await ch.receive()
|
||||
if not data:
|
||||
break
|
||||
print(f"[reader] Got: {data}")
|
||||
# Simulates some work
|
||||
await structio.sleep(1)
|
||||
print("[reader] Done!")
|
||||
|
||||
|
||||
async def writer(ch: structio.ChannelWriter, objects: list[Any]):
|
||||
print("[writer] Writer is alive!")
|
||||
async with ch:
|
||||
for obj in objects:
|
||||
print(f"[writer] Sending {obj!r}")
|
||||
await ch.send(obj)
|
||||
print(f"[writer] Sent {obj!r}")
|
||||
# Let's make the writer twice as fast as the receiver
|
||||
# to test backpressure :)
|
||||
await structio.sleep(0.5)
|
||||
await ch.send(None)
|
||||
print("[writer] Done!")
|
||||
|
||||
|
||||
async def main(objects: list[Any]):
|
||||
print("[main] Parent is alive")
|
||||
# We construct a new memory channel...
|
||||
channel = structio.MemoryChannel(1) # 1 is the size of the internal buffer
|
||||
async with structio.create_pool() as pool:
|
||||
# ... and dispatch the two ends to different
|
||||
# tasks. Isn't this neat?
|
||||
pool.spawn(reader, channel.reader)
|
||||
pool.spawn(writer, channel.writer, objects)
|
||||
print("[main] Children spawned")
|
||||
print("[main] Done!")
|
||||
|
||||
|
||||
structio.run(main, [1, 2, 3, 4])
|
|
@ -0,0 +1,77 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def successful(name: str, n):
|
||||
before = structio.clock()
|
||||
print(f"[child {name}] Sleeping for {n} seconds")
|
||||
await structio.sleep(n)
|
||||
print(f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds")
|
||||
return n
|
||||
|
||||
|
||||
async def failing(name: str, n):
|
||||
before = structio.clock()
|
||||
print(f"[child {name}] Sleeping for {n} seconds")
|
||||
await structio.sleep(n)
|
||||
print(
|
||||
f"[child {name}] Done! Slept for {structio.clock() - before:.2f} seconds, raising now!"
|
||||
)
|
||||
raise TypeError("waa")
|
||||
|
||||
|
||||
async def main(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as p1:
|
||||
print(f"[main] Spawning children in first context ({hex(id(p1))})")
|
||||
for name, delay in children_outer:
|
||||
p1.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p2:
|
||||
print(f"[main] Spawning children in second context ({hex(id(p2))})")
|
||||
for name, delay in children_inner:
|
||||
p2.spawn(failing, name, delay)
|
||||
print("[main] Children spawned")
|
||||
except TypeError:
|
||||
print("[main] TypeError caught!")
|
||||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
async def main_nested(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as p1:
|
||||
print(f"[main] Spawning children in first context ({hex(id(p1))})")
|
||||
for name, delay in children_outer:
|
||||
p1.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p2:
|
||||
print(f"[main] Spawning children in second context ({hex(id(p2))})")
|
||||
for name, delay in children_outer:
|
||||
p2.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p3:
|
||||
print(f"[main] Spawning children in third context ({hex(id(p3))})")
|
||||
for name, delay in children_inner:
|
||||
p3.spawn(failing, name, delay)
|
||||
print("[main] Children spawned")
|
||||
except TypeError:
|
||||
print("[main] TypeError caught!")
|
||||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
structio.run(
|
||||
main,
|
||||
[("first", 1), ("third", 3)],
|
||||
[("second", 2), ("fourth", 4)],
|
||||
)
|
||||
structio.run(
|
||||
main_nested,
|
||||
[("first", 1), ("third", 3)],
|
||||
[("second", 2), ("fourth", 4)],
|
||||
)
|
|
@ -0,0 +1,60 @@
|
|||
import structio
|
||||
from nested_pool_inner_raises import successful, failing
|
||||
|
||||
|
||||
async def main_simple(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as p1:
|
||||
print(f"[main] Spawning children in first context ({hex(id(p1))})")
|
||||
for name, delay in children_outer:
|
||||
p1.spawn(failing, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p2:
|
||||
print(f"[main] Spawning children in second context ({hex(id(p2))})")
|
||||
for name, delay in children_inner:
|
||||
p2.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
except TypeError:
|
||||
print("[main] TypeError caught!")
|
||||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
async def main_nested(
|
||||
children_outer: list[tuple[str, int]], children_inner: list[tuple[str, int]]
|
||||
):
|
||||
before = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as p1:
|
||||
print(f"[main] Spawning children in first context ({hex(id(p1))})")
|
||||
for name, delay in children_outer:
|
||||
p1.spawn(failing, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p2:
|
||||
print(f"[main] Spawning children in second context ({hex(id(p2))})")
|
||||
for name, delay in children_inner:
|
||||
p2.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
async with structio.create_pool() as p3:
|
||||
print(f"[main] Spawning children in third context ({hex(id(p3))})")
|
||||
for name, delay in children_inner:
|
||||
p3.spawn(successful, name, delay)
|
||||
print("[main] Children spawned")
|
||||
except TypeError:
|
||||
print("[main] TypeError caught!")
|
||||
print(f"[main] Children exited in {structio.clock() - before:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
structio.run(
|
||||
main_simple,
|
||||
[("second", 2), ("third", 3)],
|
||||
[("first", 1), ("fourth", 4)],
|
||||
)
|
||||
structio.run(
|
||||
main_nested,
|
||||
[("second", 2), ("third", 3)],
|
||||
[("first", 1), ("fourth", 4)],
|
||||
)
|
|
@ -0,0 +1,10 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def main():
|
||||
async with structio.create_pool():
|
||||
pass
|
||||
print("[main] Done")
|
||||
|
||||
|
||||
structio.run(main)
|
|
@ -0,0 +1,33 @@
|
|||
import structio
|
||||
import subprocess
|
||||
import shlex
|
||||
|
||||
# In the interest of compatibility, structio.parallel
|
||||
# tries to mirror the subprocess module. You can even
|
||||
# pass the constants such as DEVNULL, PIPE, etc. to it
|
||||
# and it'll work
|
||||
|
||||
|
||||
async def main(data: str):
|
||||
cmd = shlex.split("python -c 'print(input())'")
|
||||
to_send = data.encode(errors="ignore")
|
||||
# This will print data to stdout
|
||||
await structio.parallel.run(cmd, input=to_send)
|
||||
# Other option
|
||||
out = await structio.parallel.check_output(cmd, input=to_send)
|
||||
# Thanks to Linux, Mac OS X and Windows all using different
|
||||
# line endings, we have to do this abomination
|
||||
out = out.decode().rstrip("\r").rstrip("\r\n").rstrip("\n")
|
||||
assert out == data
|
||||
# Other, other option :D
|
||||
process = structio.parallel.Popen(
|
||||
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE
|
||||
)
|
||||
# Note that the process is spawned as soon as the object is
|
||||
# created!
|
||||
out, _ = await process.communicate(to_send)
|
||||
out = out.decode().rstrip("\r").rstrip("\r\n").rstrip("\n")
|
||||
assert out == data
|
||||
|
||||
|
||||
structio.run(main, "owo")
|
|
@ -0,0 +1,94 @@
|
|||
import time
|
||||
import structio
|
||||
|
||||
|
||||
async def producer(q: structio.Queue, n: int):
|
||||
for i in range(n):
|
||||
# This will wait until the
|
||||
# queue is emptied by the
|
||||
# consumer
|
||||
await q.put(i)
|
||||
print(f"[producer] Produced {i}")
|
||||
await q.put(None)
|
||||
print("[producer] Producer done")
|
||||
|
||||
|
||||
async def consumer(q: structio.Queue):
|
||||
while True:
|
||||
# Hangs until there is
|
||||
# something on the queue
|
||||
item = await q.get()
|
||||
if item is None:
|
||||
print("[consumer] Consumer done")
|
||||
break
|
||||
print(f"[consumer] Consumed {item}")
|
||||
# Simulates some work so the
|
||||
# producer waits before putting
|
||||
# the next value
|
||||
await structio.sleep(1)
|
||||
|
||||
|
||||
def threaded_consumer(q: structio.thread.AsyncThreadQueue):
|
||||
while True:
|
||||
# Hangs until there is
|
||||
# something on the queue
|
||||
item = q.get_sync()
|
||||
if item is None:
|
||||
print("[worker consumer] Consumer done")
|
||||
break
|
||||
print(f"[worker consumer] Consumed {item}")
|
||||
# Simulates some work so the
|
||||
# producer waits before putting
|
||||
# the next value
|
||||
time.sleep(1)
|
||||
return 69
|
||||
|
||||
|
||||
async def main(q: structio.Queue, n: int):
|
||||
print("[main] Starting consumer and producer")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as ctx:
|
||||
ctx.spawn(producer, q, n)
|
||||
ctx.spawn(consumer, q)
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
def threaded_producer(q: structio.thread.AsyncThreadQueue, n: int):
|
||||
print("[worker producer] Producer started")
|
||||
for i in range(n):
|
||||
# This will wait until the
|
||||
# queue is emptied by the
|
||||
# consumer
|
||||
q.put_sync(i)
|
||||
print(f"[worker producer] Produced {i}")
|
||||
q.put_sync(None)
|
||||
print("[worker producer] Producer done")
|
||||
return 42
|
||||
|
||||
|
||||
async def main_threaded(q: structio.thread.AsyncThreadQueue, n: int):
|
||||
print("[main] Starting consumer and producer")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(producer, q, n)
|
||||
val = await structio.thread.run_in_worker(threaded_consumer, q)
|
||||
print(f"[main] Thread returned {val!r}")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_threaded_2(q: structio.thread.AsyncThreadQueue, n: int):
|
||||
print("[main] Starting consumer and producer")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(consumer, q)
|
||||
val = await structio.thread.run_in_worker(threaded_producer, q, n)
|
||||
print(f"[main] Thread returned {val!r}")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
queue = structio.Queue(2) # Queue has size limit of 2
|
||||
structio.run(main, queue, 5)
|
||||
queue = structio.thread.AsyncThreadQueue(2)
|
||||
structio.run(main_threaded, queue, 5)
|
||||
structio.run(main_threaded_2, queue, 5)
|
|
@ -0,0 +1,23 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def child(k):
|
||||
print("[child] I'm alive! Spawning sleeper")
|
||||
async with structio.create_pool() as p:
|
||||
p.spawn(structio.sleep, k)
|
||||
print("[child] I'm done sleeping!")
|
||||
|
||||
|
||||
async def main(n: int, k):
|
||||
print(
|
||||
f"[main] Spawning {n} children in their own pools, each sleeping for {k} seconds"
|
||||
)
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as p:
|
||||
for _ in range(n):
|
||||
p.spawn(child, k)
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
# Should exit in ~2 seconds
|
||||
structio.run(main, 10, 2)
|
|
@ -0,0 +1,68 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def sleeper(n):
|
||||
print(f"[sleeper] Going to sleep for {n} seconds!")
|
||||
i = structio.clock()
|
||||
try:
|
||||
await structio.sleep(n)
|
||||
except structio.Cancelled:
|
||||
print(
|
||||
f"[sleeper] Oh no, I've been cancelled! (was gonna sleep {structio.clock() - i:.2f} more seconds)"
|
||||
)
|
||||
raise
|
||||
print("[sleeper] Woke up!")
|
||||
|
||||
|
||||
async def main_simple(n, o, p):
|
||||
print(f"[main] Parent is alive, spawning {o} children sleeping {n} seconds each")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
for i in range(o):
|
||||
pool.spawn(sleeper, n)
|
||||
print(f"[main] Children spawned, sleeping {p} seconds before cancelling")
|
||||
await structio.sleep(p)
|
||||
# Note that cancellations propagate to all inner task scopes!
|
||||
pool.scope.cancel()
|
||||
print(f"[main] Parent exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_nested(n, o, p):
|
||||
print(
|
||||
f"[main] Parent is alive, spawning {o} children in two contexts sleeping {n} seconds each"
|
||||
)
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as p1:
|
||||
for i in range(o):
|
||||
p1.spawn(sleeper, n)
|
||||
async with structio.create_pool() as p2:
|
||||
for i in range(o):
|
||||
p2.spawn(sleeper, n)
|
||||
print(f"[main] Children spawned, sleeping {p} seconds before cancelling")
|
||||
await structio.sleep(p)
|
||||
# Note that cancellations propagate to all inner task scopes!
|
||||
p1.scope.cancel()
|
||||
print(f"[main] Parent exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def child(scope: structio.TaskScope, x: float):
|
||||
print(f"[main] Child is alive! Canceling scope in {x} seconds")
|
||||
await structio.sleep(x)
|
||||
scope.cancel()
|
||||
|
||||
|
||||
async def main_child(x: float):
|
||||
print("[main] Parent is alive")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as p:
|
||||
print("[main] Spawning child")
|
||||
p.spawn(child, p.scope, x / 2)
|
||||
print(f"[main] Child spawned, sleeping for {x} seconds")
|
||||
await structio.sleep(x)
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
# Should take about 5 seconds
|
||||
structio.run(main_simple, 5, 2, 2)
|
||||
structio.run(main_nested, 5, 2, 2)
|
||||
structio.run(main_child, 2)
|
|
@ -0,0 +1,73 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def child(i: int, sem: structio.Semaphore):
|
||||
async with sem:
|
||||
print(f"[child {i}] Entered critical section")
|
||||
await structio.sleep(1)
|
||||
print(f"[child {i}] Exited critical section")
|
||||
|
||||
|
||||
async def main_sem(n: int, k: int):
|
||||
assert isinstance(n, int) and n > 0
|
||||
assert isinstance(k, int) and k > 1
|
||||
print(f"[main] Parent is alive, creating semaphore of size {n}")
|
||||
semaphore = structio.Semaphore(n)
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
print(f"[main] Spawning {n * k} children")
|
||||
for i in range(1, (n * k) + 1):
|
||||
pool.spawn(child, i, semaphore)
|
||||
print("[main] All children spawned, waiting for completion")
|
||||
# Since our semaphore has a limit of n tasks that
|
||||
# can acquire it concurrently, we should see at most
|
||||
# n instances of child running at any given time,
|
||||
# like in batches
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_lock(k: int):
|
||||
assert isinstance(k, int) and k > 1
|
||||
print(f"[main] Parent is alive, creating lock")
|
||||
lock = structio.Lock()
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
print(f"[main] Spawning {k} children")
|
||||
for i in range(1, k + 1):
|
||||
# Locks are implemented as simple binary semaphores
|
||||
# and have an identical API, so they can be used
|
||||
# interchangeably
|
||||
pool.spawn(child, i, lock)
|
||||
print("[main] All children spawned, waiting for completion")
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def recursive_child(i: int, sem: structio.Semaphore, kapow: bool = False):
|
||||
async with sem:
|
||||
print(f"[{'copy of ' if kapow else ''}child {i}] Entered critical section")
|
||||
if kapow:
|
||||
await recursive_child(i, sem)
|
||||
await structio.sleep(1)
|
||||
print(f"[{'copy of ' if kapow else ''}child {i}] Exited critical section")
|
||||
|
||||
|
||||
async def main_rlock(k: int):
|
||||
assert isinstance(k, int) and k > 1
|
||||
print(f"[main] Parent is alive, creating recursive lock")
|
||||
lock = structio.RLock()
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
print(f"[main] Spawning {k} children")
|
||||
for i in range(1, k + 1):
|
||||
pool.spawn(recursive_child, i, lock, True)
|
||||
print("[main] All children spawned, waiting for completion")
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
# Should exit in about k seconds
|
||||
structio.run(main_sem, 3, 5)
|
||||
# Same here, should exit in k seconds, but
|
||||
# it'll run one task at a time (also fewer tasks)
|
||||
structio.run(main_lock, 10)
|
||||
# This should exit in about 2k seconds
|
||||
structio.run(main_rlock, 5)
|
|
@ -0,0 +1,45 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def shielded(i):
|
||||
print("[shielded] Entering shielded section")
|
||||
with structio.TaskScope(shielded=True) as s:
|
||||
await structio.sleep(i)
|
||||
print(f"[shielded] Slept {i} seconds")
|
||||
s.shielded = False
|
||||
print(f"[shielded] Exited shielded section, sleeping {i} more seconds")
|
||||
await structio.sleep(i)
|
||||
|
||||
|
||||
async def main(i):
|
||||
print(f"[main] Parent has started, finishing in {i} seconds")
|
||||
t = structio.clock()
|
||||
with structio.skip_after(i):
|
||||
await shielded(i)
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def canceller(s, i):
|
||||
print("[canceller] Entering shielded section")
|
||||
with s:
|
||||
await structio.sleep(i)
|
||||
|
||||
|
||||
async def main_cancel(i, j):
|
||||
print(f"[main] Parent has started, finishing in {j} seconds")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as p:
|
||||
s = structio.TaskScope(shielded=True)
|
||||
task = p.spawn(canceller, s, i)
|
||||
await structio.sleep(j)
|
||||
assert not task.done()
|
||||
print("[main] Canceling scope")
|
||||
# Shields only protect from indirect cancellations
|
||||
# coming from outer scopes: they are still cancellable
|
||||
# explicitly!
|
||||
s.cancel()
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main, 2)
|
||||
structio.run(main_cancel, 5, 2)
|
|
@ -0,0 +1,49 @@
|
|||
import structio
|
||||
import signal
|
||||
from types import FrameType
|
||||
|
||||
ev = structio.Event()
|
||||
|
||||
|
||||
async def handler(sig: signal.Signals, _frame: FrameType):
|
||||
print(
|
||||
f"[handler] Handling signal {signal.Signals(sig).name!r}, waiting 1 second before setting event"
|
||||
)
|
||||
# Just to show off the async part
|
||||
await structio.sleep(1)
|
||||
ev.set()
|
||||
|
||||
|
||||
async def main(n):
|
||||
print("[main] Main is alive, setting signal handler")
|
||||
assert structio.get_signal_handler(signal.SIGALRM) is None
|
||||
structio.set_signal_handler(signal.SIGALRM, handler)
|
||||
assert structio.get_signal_handler(signal.SIGALRM) is handler
|
||||
print(f"[main] Signal handler set, calling signal.alarm({n})")
|
||||
signal.alarm(n)
|
||||
print("[main] Alarm scheduled, waiting on global event")
|
||||
t = structio.clock()
|
||||
await ev.wait()
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def handler_2(sig: signal.Signals, _frame: FrameType):
|
||||
print(
|
||||
f"[handler] Handling signal {signal.Signals(sig).name!r}, waiting 1 second before exiting"
|
||||
)
|
||||
await structio.sleep(1)
|
||||
|
||||
|
||||
async def main_2(n):
|
||||
structio.set_signal_handler(signal.SIGHUP, handler_2)
|
||||
t = structio.clock()
|
||||
while n:
|
||||
signal.raise_signal(signal.SIGHUP)
|
||||
print(f"[main] Sleeping half a second ({n=})")
|
||||
await structio.sleep(0.5)
|
||||
n -= 1
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main_2, 5)
|
||||
structio.run(main, 5)
|
|
@ -0,0 +1,32 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def main(n):
|
||||
print(f"[main] Starting sliding timer with timeout {n}")
|
||||
i = structio.clock()
|
||||
with structio.skip_after(1.5) as scope:
|
||||
while n:
|
||||
# This looks weird, but it allows us to
|
||||
# handle floating point values (basically
|
||||
# if n equals say, 7.5, then this loop will
|
||||
# sleep 7.5 seconds instead of 8), which would
|
||||
# otherwise cause this loop to run forever and
|
||||
# the deadline to shift indefinitely into the
|
||||
# future (because n would never reach zero, getting
|
||||
# immediately negative instead)
|
||||
shift = min(n, 1)
|
||||
print(f"[main] Waiting {shift:.2f} second{'' if shift == 1 else 's'}")
|
||||
await structio.sleep(shift)
|
||||
print(f"[main] Shifting deadline")
|
||||
# Updating the scope's timeout causes
|
||||
# its deadline to shift accordingly!
|
||||
scope.timeout += shift
|
||||
n -= shift
|
||||
print("[main] Deadline shifting complete")
|
||||
# Should take about n seconds to run, because we shift
|
||||
# the deadline of the cancellation n times and wait at most
|
||||
# 1 second after every shift
|
||||
print(f"[main] Exited in {structio.clock() - i:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main, 7.5)
|
|
@ -0,0 +1,25 @@
|
|||
import structio
|
||||
from functools import partial
|
||||
|
||||
|
||||
@structio.on_event("on_message")
|
||||
async def test(evt, *args, **kwargs):
|
||||
print(f"[test] New event {evt!r} with arguments: {args}, {kwargs}")
|
||||
# Simulate some work
|
||||
await structio.sleep(1)
|
||||
|
||||
|
||||
async def main():
|
||||
print("[main] Firing two events synchronously")
|
||||
t = structio.clock()
|
||||
await structio.emit("on_message", 1, 2, 3, **{"foo": "bar"})
|
||||
await structio.emit("on_message", 1, 2, 4, **{"foo": "baz"})
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds. Firing two events in parallel")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
pool.spawn(partial(structio.emit, "on_message", 1, 2, 3, **{"foo": "bar"}))
|
||||
pool.spawn(partial(structio.emit, "on_message", 1, 2, 4, **{"foo": "baz"}))
|
||||
print(f"[main] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main)
|
|
@ -0,0 +1,11 @@
|
|||
import sniffio
|
||||
import structio
|
||||
|
||||
|
||||
async def main():
|
||||
backend = sniffio.current_async_library()
|
||||
assert backend == "structured-io"
|
||||
print(f"[main] Current async backend: {backend}")
|
||||
|
||||
|
||||
structio.run(main)
|
|
@ -0,0 +1,28 @@
|
|||
import datetime as dtt
|
||||
import structio
|
||||
|
||||
|
||||
async def task():
|
||||
for i in range(100):
|
||||
await structio.sleep(0.01)
|
||||
|
||||
|
||||
async def main(tests: list[int]):
|
||||
print("[main] Starting stress test, aggregate results will be printed at the end")
|
||||
results = []
|
||||
for N in tests:
|
||||
print(f"[main] Starting test with {N} tasks")
|
||||
start = dtt.datetime.utcnow()
|
||||
async with structio.create_pool() as p:
|
||||
for _ in range(N):
|
||||
p.spawn(task)
|
||||
end = dtt.datetime.utcnow()
|
||||
results.append((end - start).total_seconds())
|
||||
print(f"[main] Test with {N} tasks completed in {results[-1]:.2f} seconds")
|
||||
results = " ".join(
|
||||
(f"\n\t- {tests[i]} tasks: {r:.2f}" for i, r in enumerate(results))
|
||||
)
|
||||
print(f"[main] Results (values are in seconds): {results}")
|
||||
|
||||
|
||||
structio.run(main, [10, 100, 1000, 10_000])
|
|
@ -0,0 +1,52 @@
|
|||
import structio
|
||||
|
||||
from nested_pool_inner_raises import successful, failing
|
||||
|
||||
|
||||
async def main_cancel(i):
|
||||
print("[main] Parent is alive, spawning child")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
task: structio.Task = pool.spawn(successful, "test", i * 2)
|
||||
print(f"[main] Child spawned, waiting {i} seconds before canceling it")
|
||||
await structio.sleep(i)
|
||||
print("[main] Cancelling child")
|
||||
# Tasks can be cancelled individually, if necessary.
|
||||
# Cancellation is not a checkpoint, by the way (as
|
||||
# evidenced by the lack of the 'await' keyword!)
|
||||
task.cancel()
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_wait_successful(i):
|
||||
print("[main] Parent is alive, spawning (and explicitly waiting for) child")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
# The spawn() method returns a Task object that can be
|
||||
# independently managed if necessary. Awaiting the Task
|
||||
# will wait for it to complete and return its return value,
|
||||
# as well as propagate any exceptions it may raise. Note that
|
||||
# in this example we could've just awaited the coroutine directly,
|
||||
# so it's a bad show for the feature, but you could theoretically
|
||||
# pass the object around somewhere else and do the awaiting there
|
||||
print(f"[main] Child has returned: {await pool.spawn(successful, 'test', i)}")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_wait_failing(i):
|
||||
print("[main] Parent is alive, spawning (and explicitly waiting for) child")
|
||||
t = structio.clock()
|
||||
try:
|
||||
async with structio.create_pool() as pool:
|
||||
# This never completes
|
||||
await pool.spawn(failing, "test", i)
|
||||
print("This is never executed")
|
||||
except TypeError:
|
||||
print(f"[main] TypeError caught!")
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
# Total time should be about 15s
|
||||
structio.run(main_cancel, 5)
|
||||
structio.run(main_wait_successful, 5)
|
||||
structio.run(main_wait_failing, 5)
|
|
@ -0,0 +1,55 @@
|
|||
import structio
|
||||
import time
|
||||
|
||||
|
||||
def fake_async_sleeper(n, name: str = ""):
|
||||
print(f"[thread{f' {name}' if name else ''}] About to sleep for {n} seconds")
|
||||
t = time.time()
|
||||
if structio.thread.is_async_thread():
|
||||
print(f"[thread{f' {name}' if name else ''}] I have async superpowers!")
|
||||
structio.thread.run_coro(structio.sleep, n)
|
||||
else:
|
||||
print(f"[thread{f' {name}' if name else ''}] Using old boring time.sleep :(")
|
||||
time.sleep(n)
|
||||
print(
|
||||
f"[thread{f' {name}' if name else ''}] Slept for {time.time() - t:.2f} seconds"
|
||||
)
|
||||
return n
|
||||
|
||||
|
||||
async def main(n):
|
||||
print(f"[main] Spawning worker thread, exiting in {n} seconds")
|
||||
t = structio.clock()
|
||||
d = await structio.thread.run_in_worker(fake_async_sleeper, n)
|
||||
assert d == n
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_timeout(n, k):
|
||||
print(f"[main] Spawning worker thread, exiting in {k} seconds")
|
||||
t = structio.clock()
|
||||
with structio.skip_after(k):
|
||||
# We need to make the operation explicitly cancellable if we want
|
||||
# to be able to move on!
|
||||
await structio.thread.run_in_worker(fake_async_sleeper, n, cancellable=True)
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def main_multiple(n, k):
|
||||
print(f"[main] Spawning {n} worker threads each sleeping for {k} seconds")
|
||||
t = structio.clock()
|
||||
async with structio.create_pool() as pool:
|
||||
for i in range(n):
|
||||
pool.spawn(structio.thread.run_in_worker, fake_async_sleeper, k, str(i))
|
||||
print(f"[main] Workers spawned")
|
||||
# Keep in mind that there is some overhead associated with running worker threads,
|
||||
# not to mention that it gets tricky with how the OS schedules them and whatnot. So,
|
||||
# it's unlikely that all threads finish exactly at the same time and that we exit in
|
||||
# k seconds, even just because there's a lot of back and forth going on under the hood
|
||||
# between structio and the worker threads themselves
|
||||
print(f"[main] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(main, 2)
|
||||
structio.run(main_timeout, 5, 3)
|
||||
structio.run(main_multiple, 10, 2)
|
|
@ -0,0 +1,70 @@
|
|||
import structio
|
||||
|
||||
|
||||
async def test_silent(i, j):
|
||||
print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
|
||||
k = structio.clock()
|
||||
with structio.skip_after(i) as scope:
|
||||
print(f"[test] Sleeping for {j} seconds")
|
||||
await structio.sleep(j)
|
||||
print(
|
||||
f"[test] Finished in {structio.clock() - k:.2f} seconds (timed out: {scope.timed_out})"
|
||||
)
|
||||
|
||||
|
||||
async def test_loud(i, j):
|
||||
print(f"[test] Parent is alive, exiting after {i:.2f} seconds")
|
||||
k = structio.clock()
|
||||
try:
|
||||
with structio.with_timeout(i):
|
||||
print(f"[test] Sleeping for {j} seconds")
|
||||
await structio.sleep(j)
|
||||
except structio.TimedOut:
|
||||
print("[test] Timed out!")
|
||||
print(f"[test] Finished in {structio.clock() - k:.2f} seconds")
|
||||
|
||||
|
||||
async def deadlock():
|
||||
await structio.Event().wait()
|
||||
|
||||
|
||||
async def test_deadlock(i):
|
||||
print(f"[test] Parent is alive, will exit in {i} seconds")
|
||||
t = structio.clock()
|
||||
with structio.skip_after(i):
|
||||
print("[test] Entering deadlock")
|
||||
await deadlock()
|
||||
print(f"[test] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def test_nested(i):
|
||||
print(f"[test] Parent is alive, will exit in {i} seconds")
|
||||
t = structio.clock()
|
||||
with structio.skip_after(i):
|
||||
print("[test] Entered first scope")
|
||||
with structio.skip_after(i * 2):
|
||||
# Even though this scope's timeout is
|
||||
# larger than its parent, structio will
|
||||
# still cancel it when its containing
|
||||
# scope expires
|
||||
print("[test] Entered second scope")
|
||||
await deadlock()
|
||||
print(f"[test] Done in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
async def nested_mess(n):
|
||||
print(f"[test] Exiting in {n} seconds")
|
||||
t = structio.clock()
|
||||
with structio.TaskScope() as r:
|
||||
with structio.skip_after(n) as s:
|
||||
assert r.get_effective_deadline()[1] is s
|
||||
# This never completes
|
||||
await structio.sleep(n * 10)
|
||||
print(f"[test] Exited in {structio.clock() - t:.2f} seconds")
|
||||
|
||||
|
||||
structio.run(nested_mess, 3)
|
||||
structio.run(test_silent, 3, 5)
|
||||
structio.run(test_loud, 3, 5)
|
||||
structio.run(test_deadlock, 5)
|
||||
structio.run(test_nested, 5)
|
Loading…
Reference in New Issue