Fix cancellation bugs and add the ability to call some objects in remote Python processes

Mattia Giambirtone 2024-03-16 22:17:45 +01:00
parent e4f50fe95a
commit e647b7631d
8 changed files with 185 additions and 50 deletions

View File

@@ -1,2 +1,3 @@
sniffio==1.3.0
msgpack==1.0.8
multiprocessing_utils==0.4

View File

@@ -56,7 +56,6 @@ class FIFOKernel(BaseKernel):
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue()
self.pool = TaskPool()
self.pool.scope.shielded = True
self.current_scope = self.pool.scope
self._shutting_down = False
@@ -146,7 +145,7 @@ class FIFOKernel(BaseKernel):
pool = self.pool
else:
pool = self.current_pool
task = Task(name, func(*args), pool, is_system_task=system_task)
task = Task(name, func(*args), pool.scope, pool, is_system_task=system_task)
pool.scope.tasks.append(task)
# We inject our magic secret variable into the coroutine's stack frame, so
# we can look it up later
@@ -188,7 +187,8 @@ class FIFOKernel(BaseKernel):
self.current_task.coroutine.send, self.data.pop(self.current_task, None)
)
if self.current_task.pending_cancellation:
runner = partial(self.current_task.coroutine.throw, Cancelled())
self.cancel_task(self.current_task)
return
elif self._sigint_handled and not critical_section(self.current_task.coroutine.cr_frame):
self._sigint_handled = False
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
@@ -196,7 +196,7 @@ class FIFOKernel(BaseKernel):
self.current_task.state = TaskState.RUNNING
self.current_task.paused_when = 0
self.current_pool = self.current_task.pool
self.current_scope = self.current_pool.scope
self.current_scope = self.current_task.scope
data = self.handle_errors(runner, self.current_task)
if data is not None:
method, args, kwargs = data
@@ -302,9 +302,16 @@ class FIFOKernel(BaseKernel):
while not self.done():
self._tick()
self._shutting_down = True
# Ensure all system tasks have a chance to spin up
while any(task.state == TaskState.INIT for task in self.pool.scope.tasks):
self._tick()
# Cancel the system pool and wait for cancellation
# to be delivered
self.pool.scope.cancel()
while not self.done():
self._tick()
# Reset some stuff
self.pool.scope.attempted_cancel = False
if self.entry_point.state == TaskState.FINISHED:
while True:
# Spawn all the shutdown tasks that are currently registered
@@ -367,7 +374,7 @@ class FIFOKernel(BaseKernel):
except (Exception, KeyboardInterrupt) as err:
# Any other exception is caught here
task.exc = err
err.scope = task.pool.scope
err.scope = task.scope
task.state = TaskState.CRASHED
self.on_error(task)
finally:
@@ -417,7 +424,7 @@ class FIFOKernel(BaseKernel):
if task.is_system_task:
self.close(force=True)
raise task.exc from StructIOException(f"system task {task} crashed")
scope = task.pool.scope
scope = task.scope
self.release(task)
self.cancel_scope(scope)
if task is not scope.owner:
@@ -435,6 +442,9 @@ class FIFOKernel(BaseKernel):
self.release(task)
def init_scope(self, scope: TaskScope):
if self.current_task is not self.current_scope.owner:
self.current_scope.tasks.remove(self.current_task)
self.current_task.scope = scope
scope.deadline = self.clock.deadline(scope.timeout)
scope.owner = self.current_task
self.current_scope.inner.append(scope)
@@ -444,18 +454,19 @@ class FIFOKernel(BaseKernel):
def close_scope(self, scope: TaskScope):
self.current_scope = scope.outer
self.current_scope.inner = []
self.current_task.scope = self.current_scope
def cancel_task(self, task: Task):
if task.done():
return
if task.state == TaskState.RUNNING:
if task.state in [TaskState.RUNNING]:
# Can't cancel a task while it's
# running, will raise ValueError
# if we try. We defer it for later
# if we try, so we defer it for later
task.pending_cancellation = True
return
err = Cancelled()
err.scope = task.pool.scope
err.scope = task.scope
self.throw(task, err)
if task.state != TaskState.CANCELLED:
# Task is stubborn. But so are we,
@@ -479,15 +490,7 @@ class FIFOKernel(BaseKernel):
if task is self.current_task:
continue
self.cancel_task(task)
if (
scope is not self.current_task.pool.scope
and scope.owner is not self.current_task
and scope.owner is not self.entry_point
and scope.owner is not None
):
# Handles the case where the current task calls
# cancel() for a scope which it doesn't own, which
# is an entirely reasonable thing to do
if scope.owner is not self.entry_point:
self.cancel_task(scope.owner)
if scope.done():
scope.cancelled = True
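
The deferral above is forced by CPython itself: an exception cannot be thrown into a coroutine that is currently executing. A minimal standalone sketch (plain Python, no structio) of the behaviour that `pending_cancellation` works around:

```python
# Throwing into a coroutine that is mid-execution raises ValueError
# ("coroutine already executing"), which is why cancel_task() defers
# cancellation of RUNNING tasks instead of throwing immediately.
async def task():
    try:
        current.throw(Exception("cancel!"))  # we are running right now
    except ValueError as err:
        print(err)  # coroutine already executing

current = task()
try:
    current.send(None)  # start the coroutine
except StopIteration:
    pass  # the coroutine ran to completion
```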

View File

@@ -1,7 +1,8 @@
import inspect
import structio
import functools
from threading import local
# I *really* hate fork()
from multiprocessing_utils import local
from structio.abc import (
BaseKernel,
BaseDebugger,
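
For context on the `local` swap above: after a `fork()`, the child process inherits the forking thread's `threading.local` contents, so state like "the current event loop" would silently leak into child processes, while `multiprocessing_utils.local` behaves like a thread-local that is also process-local. A rough illustration of the problem (POSIX only; `loop_state` is a hypothetical name):

```python
import os
import threading

loop_state = threading.local()
loop_state.loop = "parent-loop"

if os.fork() == 0:
    # The child still sees the parent's thread-local value: exactly
    # what we don't want for per-process event loop state
    print(getattr(loop_state, "loop", None))  # prints "parent-loop"
    os._exit(0)
else:
    os.wait()
```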

View File

@@ -30,6 +30,8 @@ class Task:
# The underlying coroutine of this
# task
coroutine: Coroutine = field(repr=False)
# The task's scope
scope: "TaskScope"
# The task's pool
pool: "TaskPool" = field(repr=False)
# The state of the task

View File

@@ -12,6 +12,7 @@ import platform
import subprocess
from itertools import count
from structio.util.finder import ObjectReference
from structio.io import FileStream
from multiprocessing import cpu_count
from structio import Semaphore, Queue
@@ -315,7 +316,8 @@ class PythonProcess:
Run a separate python process asynchronously
"""
def __init__(self):
def __init__(self, target):
self.target = ObjectReference(target)
self._sock = structio.socket.socket()
self._remote: structio.AsyncSocket | None = None
self._started = structio.Event()
@@ -325,18 +327,28 @@ class PythonProcess:
"""
Sends a message to the remote process
"""
data = msgpack.dumps(data)
await self._remote.send_all(struct.pack("Q", len(data)))
await self._remote.send_all(data)
await self._ensure_ack()
async def _send_ref(self):
msg = {"msg": "EXEC", "ref": str(self.target), "file": None}
if not self.target.in_package:
msg["file"] = self.target.get_file()
await self.send_message(msg)
async def _do_setup(self):
if not self._remote:
with structio.TaskScope(shielded=True):
# We use a shielded task scope to ensure
# the setup always runs to completion
await self._sock.bind(("127.0.0.1", 0))
await self._sock.listen(1)
addr, port = self._sock.getsockname()
self.process = Process([sys.executable, "-m", "structio._child_process", addr, str(port)])
# If we didn't close the socket before calling wait(), we'd deadlock!
self.process = Process([sys.executable, "-m", "structio.util.child_process", addr, str(port)])
# If we didn't close the socket before calling wait(), we'd deadlock waiting for the
# process to exit while the process waits for us to send them a message
self.process.add_shutdown_handler(self.close, before_wait=True)
self.process.start()
await self.process.wait_started()
@@ -344,12 +356,16 @@ class PythonProcess:
sock, _addr = await self._sock.accept()
self._remote = sock
await self.send_sos()
await self._send_ref()
async def _ensure_ack(self):
payload = await self.receive_message()
try:
payload = await self.receive_message()
except StructIOException as e:
raise StructIOException("unable to get ACK from remote process") from e
if payload["msg"] != "ACK":
raise StructIOException(f"invalid message type {payload['msg']!r} received from process (expecting 'ACK')"
f": {payload}")
raise StructIOException(f"invalid message type {payload['msg']!r} received from process (expecting "
f"'ACK'): {payload}")
async def send_sos(self):
"""
@@ -372,7 +388,10 @@ class PythonProcess:
Terminate the remote process. If graceful is
True (the default), a graceful shutdown is attempted
"""
if not await self.is_running():
return
if graceful:
await self.send_eos()
await self._remote.close()
@@ -391,17 +410,15 @@ class PythonProcess:
await self._ensure_started()
data = await self._remote.receive(8)
if not data:
raise ResourceBroken("something went wrong when communicating with the remote process")
raise ResourceBroken("remote socket was closed abruptly")
size, *_ = struct.unpack("Q", data)
return msgpack.unpackb(await self._remote.receive_exactly(size))
message = msgpack.unpackb(await self._remote.receive_exactly(size))
if not message["ok"]:
raise StructIOException(f"got error response from remote process: {message}")
return message
def start(self):
# We both spawn this as a system task and schedule it as a shutdown
# handler so that the remote process always has a chance to connect
# to us if the entry point exits with no error but without waiting for
# the process to terminate
structio.current_loop().spawn_system_task(self._do_setup)
structio.current_loop().add_shutdown_task(self._do_setup)
async def is_running(self):
if not self.process:
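
A hypothetical end-to-end sketch of the new API, pieced together from this diff and the test file at the bottom; `wait_started()` and `terminate()` on `PythonProcess` are assumed to behave the way the (partly commented-out) test code suggests:

```python
import structio

def job():
    # Must be a plain, importable callable: bound methods and lambdas
    # are rejected by ObjectReference (see finder.py below)
    print("hello from the child interpreter")

async def main():
    p = structio.parallel.PythonProcess(target=job)
    p.start()                # schedules _do_setup as a system/shutdown task
    await p.wait_started()   # assumed: waits for the child to connect back
    await p.terminate()      # graceful: sends EOS, then closes the socket

structio.run(main)
```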

View File

@@ -2,12 +2,14 @@
Helper module to spawn asynchronous Python processes via
structio
"""
import os
import sys
import struct
import structio
from structio.exceptions import ResourceBroken
import msgpack
import structio
import importlib
from structio.util.finder import ObjectReference
from structio.exceptions import ResourceBroken
async def send_message(sock: structio.AsyncSocket, payload: dict):
@@ -32,6 +34,14 @@ async def receive_message(sock: structio.AsyncSocket) -> dict:
return msgpack.unpackb(await sock.receive_exactly(size))
async def send_ack(sock: structio.AsyncSocket):
"""
Sends an acknowledgement packet back to the remote
event loop on the given socket
"""
return await send_message(sock, {"ok": True, "msg": "ACK"})
async def dispatch(sock: structio.AsyncSocket, message: dict):
"""
Dispatches commands received from the remote event loop
@@ -45,16 +55,20 @@ async def dispatch(sock: structio.AsyncSocket, message: dict):
match message["msg"]:
case "HELO":
# SOS: Start of session
print("Received SOS from remote event loop")
await send_message(sock, {"ok": True, "msg": "ACK"})
await send_ack(sock)
case "EXEC":
await send_ack(sock)
if message["file"]:
sys.path.append(os.path.dirname(message["file"]))
obj = ObjectReference.load_ref(message["ref"])
obj()
case "CYA":
# EOS: End of session (aka the process can and should exit)
print("Received EOS from remote event, shutting down")
await send_message(sock, {"ok": True, "msg": "ACK"})
await send_ack(sock)
sys.exit(0)
case _:
print(f"Unknown message type {message['msg']!r}: {message}")
await send_message(sock, {"ok": False, "msg": "ACK", "error": f"unknown message type {message['msg']!r}"})
# IDK: I don't know (means the command is unknown)
await send_message(sock, {"ok": False, "msg": "IDK", "error": f"unknown message type {message['msg']!r}"})
async def main(addr: tuple[str, int]):
@@ -65,5 +79,5 @@ async def main(addr: tuple[str, int]):
while True:
await dispatch(socket, await receive_message(socket))
structio.run(main, (sys.argv[1], int(sys.argv[2])))
if __name__ == "__main__":
structio.run(main, (sys.argv[1], int(sys.argv[2])))
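
For reference, both ends of the link speak the same framing: an 8-byte struct-packed "Q" length prefix followed by a msgpack payload. A synchronous sketch of the protocol, using blocking sockets purely for illustration:

```python
import struct
import socket
import msgpack

def send_message(sock: socket.socket, payload: dict):
    data = msgpack.dumps(payload)
    sock.sendall(struct.pack("Q", len(data)))  # 8-byte length prefix
    sock.sendall(data)

def receive_message(sock: socket.socket) -> dict:
    header = sock.recv(8)
    if not header:
        raise ConnectionError("remote socket was closed abruptly")
    size, *_ = struct.unpack("Q", header)
    buf = b""
    while len(buf) < size:  # read exactly `size` bytes
        chunk = sock.recv(size - len(buf))
        if not chunk:
            raise ConnectionError("remote socket was closed abruptly")
        buf += chunk
    return msgpack.unpackb(buf)
```

A session is then HELO -> ACK, any number of EXEC -> ACK, and finally CYA -> ACK, after which the child exits; unknown commands are answered with IDK.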

structio/util/finder.py Normal file
View File

@@ -0,0 +1,89 @@
"""
Utility module to look up objects to be called in a Python subprocess spawned
by structio. Inspired by https://pikers.dev/goodboy/tractor/src/branch/mv_to_new_trio_py3.11/tractor/msg/ptr.py#L53
"""
import importlib
from pathlib import Path
from inspect import ismethod, isfunction, isbuiltin, getmodule, getmodulename
from pkgutil import resolve_name
def islambda(v) -> bool:
"""
Returns whether the passed object is a lambda
"""
f = lambda: ...
return isinstance(v, type(f)) and v.__name__ == f.__name__
def get_real_module_name(obj) -> tuple[bool, str]:
in_package = True
module = obj.__module__
if "__main__" in module:
# get parent modules of object
mod_obj = getmodule(obj)
# from the filename of the module, get its name
mod_suffix = getmodulename(mod_obj.__file__)
# join parent to child with a .
module = '.'.join(filter(bool, [mod_obj.__package__, mod_suffix]))
if mod_obj.__package__ is None:
in_package = False
return in_package, module
class ObjectReference:
"""
A reference to some arbitrary, named Python object located
in a module we have access to. This class serves the purpose
of serializing callables into unique string identifiers that
can be passed along to remote Python processes. This solution
is a lot cleaner than trying to serialize objects to bytes and
is still useful for the majority of cases; however, the approach
does have limitations, namely:
- lambdas aren't supported, since they have no name and would
need to be copied/serialized (which is against the intended
design)
- bound methods aren't supported either, because we'd have to copy
the instance they're bound to over to the remote process. If you
need to call instance methods in the remote process, pass a function
that constructs the object and calls the method instead
"""
def __init__(self, obj):
if ismethod(obj) or islambda(obj):
raise ValueError("bound methods and lambdas cannot be passed to a remote process")
self.obj = obj
self.in_package = True
self._make_ref()
def _make_ref(self) -> tuple[str, str]:
"""
Returns a tuple (module, name) that uniquely
represents the wrapped object `name` in its
`module`
"""
name: str
if isfunction(self.obj) or isbuiltin(self.obj):
name = self.obj.__name__
else:
name = type(self.obj).__name__
in_package, module = get_real_module_name(self.obj)
self.in_package = in_package
return module, name
def get_file(self):
return str(Path(getmodule(self.obj).__file__).as_posix())
def __str__(self):
return ":".join(self._make_ref())
@staticmethod
def load_ref(ident: str):
return resolve_name(ident)
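
A sketch of how `ObjectReference` round-trips a callable by name, including the documented workaround for bound methods. This assumes the code lives in an importable module; objects defined in `__main__` go through the file-path fallback above instead:

```python
from pkgutil import resolve_name
from structio.util.finder import ObjectReference

class Greeter:
    def greet(self):
        print("hello")

def run_greeter():
    # Workaround for the bound-method limitation: construct the
    # instance and call the method inside a module-level function
    Greeter().greet()

ref = ObjectReference(run_greeter)
ident = str(ref)                      # e.g. "my_module:run_greeter"
assert resolve_name(ident) is run_greeter

try:
    ObjectReference(Greeter().greet)  # bound method: rejected
except ValueError as err:
    print(err)
```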

View File

@@ -58,19 +58,27 @@ async def main_limiter():
print(f"Submitted {i + 1} processes")
def foo():
print("Called in the remote process!")
async def main_python():
print("[main] Starting python process test")
# Spawns a new Python process
p = structio.parallel.PythonProcess()
p = structio.parallel.PythonProcess(target=foo)
p.start()
# TODO: Allow for calling of arbitrary python objects in the spawned process, except bound methods
# (which are tricky due to the shared state they carry), lambdas (they cannot be looked up by name in
# the newly created process and would need to be serialized, which is not the intended design) and
# idk maybe a few others?
# await p.wait_started()
# await p.wait()
print("[main] Pyhon process test complete")
structio.run(main_simple, "owo")
structio.run(main_limiter)
structio.run(main_python)
if __name__ == "__main__":
# structio.run(main_simple, "owo")
# structio.run(main_limiter)
structio.run(main_python)