Fix cancellation bugs and add ability to call some objects in remote python processes
This commit is contained in:
parent
e4f50fe95a
commit
e647b7631d
|
@ -1,2 +1,3 @@
|
|||
sniffio==1.3.0
|
||||
msgpack==1.0.8
|
||||
msgpack==1.0.8
|
||||
multiprocessing_utils==0.4
|
|
@ -56,7 +56,6 @@ class FIFOKernel(BaseKernel):
|
|||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
self.pool = TaskPool()
|
||||
self.pool.scope.shielded = True
|
||||
self.current_scope = self.pool.scope
|
||||
self._shutting_down = False
|
||||
|
||||
|
@ -146,7 +145,7 @@ class FIFOKernel(BaseKernel):
|
|||
pool = self.pool
|
||||
else:
|
||||
pool = self.current_pool
|
||||
task = Task(name, func(*args), pool, is_system_task=system_task)
|
||||
task = Task(name, func(*args), pool.scope, pool, is_system_task=system_task)
|
||||
pool.scope.tasks.append(task)
|
||||
# We inject our magic secret variable into the coroutine's stack frame, so
|
||||
# we can look it up later
|
||||
|
@ -188,7 +187,8 @@ class FIFOKernel(BaseKernel):
|
|||
self.current_task.coroutine.send, self.data.pop(self.current_task, None)
|
||||
)
|
||||
if self.current_task.pending_cancellation:
|
||||
runner = partial(self.current_task.coroutine.throw, Cancelled())
|
||||
self.cancel_task(self.current_task)
|
||||
return
|
||||
elif self._sigint_handled and not critical_section(self.current_task.coroutine.cr_frame):
|
||||
self._sigint_handled = False
|
||||
runner = partial(self.current_task.coroutine.throw, KeyboardInterrupt())
|
||||
|
@ -196,7 +196,7 @@ class FIFOKernel(BaseKernel):
|
|||
self.current_task.state = TaskState.RUNNING
|
||||
self.current_task.paused_when = 0
|
||||
self.current_pool = self.current_task.pool
|
||||
self.current_scope = self.current_pool.scope
|
||||
self.current_scope = self.current_task.scope
|
||||
data = self.handle_errors(runner, self.current_task)
|
||||
if data is not None:
|
||||
method, args, kwargs = data
|
||||
|
@ -302,9 +302,16 @@ class FIFOKernel(BaseKernel):
|
|||
while not self.done():
|
||||
self._tick()
|
||||
self._shutting_down = True
|
||||
# Ensure all system tasks have a chance to spin up
|
||||
while any(task.state == TaskState.INIT for task in self.pool.scope.tasks):
|
||||
self._tick()
|
||||
# Cancel the system pool and wait for cancellation
|
||||
# to be delivered
|
||||
self.pool.scope.cancel()
|
||||
while not self.done():
|
||||
self._tick()
|
||||
# Reset some stuff
|
||||
self.pool.scope.attempted_cancel = False
|
||||
if self.entry_point.state == TaskState.FINISHED:
|
||||
while True:
|
||||
# Spawn all the shutdown tasks that are currently registered
|
||||
|
@ -367,7 +374,7 @@ class FIFOKernel(BaseKernel):
|
|||
except (Exception, KeyboardInterrupt) as err:
|
||||
# Any other exception is caught here
|
||||
task.exc = err
|
||||
err.scope = task.pool.scope
|
||||
err.scope = task.scope
|
||||
task.state = TaskState.CRASHED
|
||||
self.on_error(task)
|
||||
finally:
|
||||
|
@ -417,7 +424,7 @@ class FIFOKernel(BaseKernel):
|
|||
if task.is_system_task:
|
||||
self.close(force=True)
|
||||
raise task.exc from StructIOException(f"system task {task} crashed")
|
||||
scope = task.pool.scope
|
||||
scope = task.scope
|
||||
self.release(task)
|
||||
self.cancel_scope(scope)
|
||||
if task is not scope.owner:
|
||||
|
@ -435,6 +442,9 @@ class FIFOKernel(BaseKernel):
|
|||
self.release(task)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
if self.current_task is not self.current_scope.owner:
|
||||
self.current_scope.tasks.remove(self.current_task)
|
||||
self.current_task.scope = scope
|
||||
scope.deadline = self.clock.deadline(scope.timeout)
|
||||
scope.owner = self.current_task
|
||||
self.current_scope.inner.append(scope)
|
||||
|
@ -444,18 +454,19 @@ class FIFOKernel(BaseKernel):
|
|||
def close_scope(self, scope: TaskScope):
|
||||
self.current_scope = scope.outer
|
||||
self.current_scope.inner = []
|
||||
self.current_task.scope = self.current_scope
|
||||
|
||||
def cancel_task(self, task: Task):
|
||||
if task.done():
|
||||
return
|
||||
if task.state == TaskState.RUNNING:
|
||||
if task.state in [TaskState.RUNNING]:
|
||||
# Can't cancel a task while it's
|
||||
# running, will raise ValueError
|
||||
# if we try. We defer it for later
|
||||
# if we try, so we defer it for later
|
||||
task.pending_cancellation = True
|
||||
return
|
||||
err = Cancelled()
|
||||
err.scope = task.pool.scope
|
||||
err.scope = task.scope
|
||||
self.throw(task, err)
|
||||
if task.state != TaskState.CANCELLED:
|
||||
# Task is stubborn. But so are we,
|
||||
|
@ -479,15 +490,7 @@ class FIFOKernel(BaseKernel):
|
|||
if task is self.current_task:
|
||||
continue
|
||||
self.cancel_task(task)
|
||||
if (
|
||||
scope is not self.current_task.pool.scope
|
||||
and scope.owner is not self.current_task
|
||||
and scope.owner is not self.entry_point
|
||||
and scope.owner is not None
|
||||
):
|
||||
# Handles the case where the current task calls
|
||||
# cancel() for a scope which it doesn't own, which
|
||||
# is an entirely reasonable thing to do
|
||||
if scope.owner is not self.entry_point:
|
||||
self.cancel_task(scope.owner)
|
||||
if scope.done():
|
||||
scope.cancelled = True
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import inspect
|
||||
import structio
|
||||
import functools
|
||||
from threading import local
|
||||
# I *really* hate fork()
|
||||
from multiprocessing_utils import local
|
||||
from structio.abc import (
|
||||
BaseKernel,
|
||||
BaseDebugger,
|
||||
|
|
|
@ -30,6 +30,8 @@ class Task:
|
|||
# The underlying coroutine of this
|
||||
# task
|
||||
coroutine: Coroutine = field(repr=False)
|
||||
# The task's scope
|
||||
scope: "TaskScope"
|
||||
# The task's pool
|
||||
pool: "TaskPool" = field(repr=False)
|
||||
# The state of the task
|
||||
|
|
|
@ -12,6 +12,7 @@ import platform
|
|||
import subprocess
|
||||
from itertools import count
|
||||
|
||||
from structio.util.finder import ObjectReference
|
||||
from structio.io import FileStream
|
||||
from multiprocessing import cpu_count
|
||||
from structio import Semaphore, Queue
|
||||
|
@ -315,7 +316,8 @@ class PythonProcess:
|
|||
Run a separate python process asynchronously
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, target):
|
||||
self.target = ObjectReference(target)
|
||||
self._sock = structio.socket.socket()
|
||||
self._remote: structio.AsyncSocket | None = None
|
||||
self._started = structio.Event()
|
||||
|
@ -325,18 +327,28 @@ class PythonProcess:
|
|||
"""
|
||||
Sends a length-prefixed, msgpack-encoded message to the remote process
|
||||
"""
|
||||
|
||||
data = msgpack.dumps(data)
|
||||
await self._remote.send_all(struct.pack("Q", len(data)))
|
||||
await self._remote.send_all(data)
|
||||
await self._ensure_ack()
|
||||
|
||||
async def _send_ref(self):
    """
    Sends an EXEC message carrying the string reference
    of the target object. The "file" field is the path
    returned by the target's get_file() when the target
    is not located in a package, and None otherwise
    """

    # The reference is rendered first, before in_package is read
    reference = str(self.target)
    location = self.target.get_file() if not self.target.in_package else None
    await self.send_message({"msg": "EXEC", "ref": reference, "file": location})
|
||||
|
||||
async def _do_setup(self):
|
||||
if not self._remote:
|
||||
with structio.TaskScope(shielded=True):
|
||||
# We use a shielded task scope to ensure
|
||||
# the setup always runs to completion
|
||||
await self._sock.bind(("127.0.0.1", 0))
|
||||
await self._sock.listen(1)
|
||||
addr, port = self._sock.getsockname()
|
||||
self.process = Process([sys.executable, "-m", "structio._child_process", addr, str(port)])
|
||||
# If we didn't close the socket before calling wait(), we'd deadlock!
|
||||
self.process = Process([sys.executable, "-m", "structio.util.child_process", addr, str(port)])
|
||||
# If we didn't close the socket before calling wait(), we'd deadlock waiting for the
|
||||
# process to exit while the process waits for us to send them a message
|
||||
self.process.add_shutdown_handler(self.close, before_wait=True)
|
||||
self.process.start()
|
||||
await self.process.wait_started()
|
||||
|
@ -344,12 +356,16 @@ class PythonProcess:
|
|||
sock, _addr = await self._sock.accept()
|
||||
self._remote = sock
|
||||
await self.send_sos()
|
||||
await self._send_ref()
|
||||
|
||||
async def _ensure_ack(self):
|
||||
payload = await self.receive_message()
|
||||
try:
|
||||
payload = await self.receive_message()
|
||||
except StructIOException as e:
|
||||
raise StructIOException("unable to get ACK from remote process") from e
|
||||
if payload["msg"] != "ACK":
|
||||
raise StructIOException(f"invalid message type {payload['msg']!r} received from process (expecting 'ACK')"
|
||||
f": {payload}")
|
||||
raise StructIOException(f"invalid message type {payload['msg']!r} received from process (expecting "
|
||||
f"'ACK'): {payload}")
|
||||
|
||||
async def send_sos(self):
|
||||
"""
|
||||
|
@ -372,7 +388,10 @@ class PythonProcess:
|
|||
Terminate the remote process. If graceful equals
|
||||
True, the default, a graceful shutdown is attempted
|
||||
"""
|
||||
|
||||
|
||||
if not await self.is_running():
|
||||
return
|
||||
|
||||
if graceful:
|
||||
await self.send_eos()
|
||||
await self._remote.close()
|
||||
|
@ -391,17 +410,15 @@ class PythonProcess:
|
|||
await self._ensure_started()
|
||||
data = await self._remote.receive(8)
|
||||
if not data:
|
||||
raise ResourceBroken("something went wrong when communicating with the remote process")
|
||||
raise ResourceBroken("remote socket was closed abruptly")
|
||||
size, *_ = struct.unpack("Q", data)
|
||||
return msgpack.unpackb(await self._remote.receive_exactly(size))
|
||||
message = msgpack.unpackb(await self._remote.receive_exactly(size))
|
||||
if not message["ok"]:
|
||||
raise StructIOException(f"got error response from remote process: {message}")
|
||||
return message
|
||||
|
||||
def start(self):
|
||||
# We both spawn this as a system task and schedule it as a shutdown
|
||||
# handler so that the remote process always has a chance to connect
|
||||
# to us if the entry point exits with no error but without waiting for
|
||||
# the process to terminate
|
||||
structio.current_loop().spawn_system_task(self._do_setup)
|
||||
structio.current_loop().add_shutdown_task(self._do_setup)
|
||||
|
||||
async def is_running(self):
|
||||
if not self.process:
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
Helper module to spawn asynchronous Python processes via
|
||||
structio
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import struct
|
||||
import structio
|
||||
from structio.exceptions import ResourceBroken
|
||||
import msgpack
|
||||
import structio
|
||||
import importlib
|
||||
from structio.util.finder import ObjectReference
|
||||
from structio.exceptions import ResourceBroken
|
||||
|
||||
|
||||
async def send_message(sock: structio.AsyncSocket, payload: dict):
|
||||
|
@ -32,6 +34,14 @@ async def receive_message(sock: structio.AsyncSocket) -> dict:
|
|||
return msgpack.unpackb(await sock.receive_exactly(size))
|
||||
|
||||
|
||||
async def send_ack(sock: structio.AsyncSocket):
    """
    Replies on the given socket with a successful
    acknowledgement (ACK) packet and returns whatever
    send_message returns
    """

    acknowledgement = {"ok": True, "msg": "ACK"}
    return await send_message(sock, acknowledgement)
|
||||
|
||||
|
||||
async def dispatch(sock: structio.AsyncSocket, message: dict):
|
||||
"""
|
||||
Dispatches commands received from the remote event loop
|
||||
|
@ -45,16 +55,20 @@ async def dispatch(sock: structio.AsyncSocket, message: dict):
|
|||
match message["msg"]:
|
||||
case "HELO":
|
||||
# SOS: Start of session
|
||||
print("Received SOS from remote event loop")
|
||||
await send_message(sock, {"ok": True, "msg": "ACK"})
|
||||
await send_ack(sock)
|
||||
case "EXEC":
|
||||
await send_ack(sock)
|
||||
if message["file"]:
|
||||
sys.path.append(os.path.dirname(message["file"]))
|
||||
obj = ObjectReference.load_ref(message["ref"])
|
||||
obj()
|
||||
case "CYA":
|
||||
# EOS: End of session (aka the process can and should exit)
|
||||
print("Received EOS from remote event, shutting down")
|
||||
await send_message(sock, {"ok": True, "msg": "ACK"})
|
||||
await send_ack(sock)
|
||||
sys.exit(0)
|
||||
case _:
|
||||
print(f"Unknown message type {message['msg']!r}: {message}")
|
||||
await send_message(sock, {"ok": False, "msg": "ACK", "error": f"unknown message type {message['msg']!r}"})
|
||||
# IDK: I don't know (means the command is unknown)
|
||||
await send_message(sock, {"ok": False, "msg": "IDK", "error": f"unknown message type {message['msg']!r}"})
|
||||
|
||||
|
||||
async def main(addr: tuple[str, int]):
|
||||
|
@ -65,5 +79,5 @@ async def main(addr: tuple[str, int]):
|
|||
while True:
|
||||
await dispatch(socket, await receive_message(socket))
|
||||
|
||||
|
||||
structio.run(main, (sys.argv[1], int(sys.argv[2])))
|
||||
if __name__ == "__main__":
|
||||
structio.run(main, (sys.argv[1], int(sys.argv[2])))
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Utility module to look up objects to be called in a Python subprocess spawned
|
||||
by structio. Inspired by https://pikers.dev/goodboy/tractor/src/branch/mv_to_new_trio_py3.11/tractor/msg/ptr.py#L53
|
||||
"""
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from inspect import ismethod, isfunction, isbuiltin, getmodule, getmodulename
|
||||
from pkgutil import resolve_name
|
||||
|
||||
|
||||
def islambda(v) -> bool:
    """
    Checks whether the given object is a lambda, i.e. a
    function object whose name is the one Python assigns
    to lambda expressions
    """

    # A throwaway lambda gives us both the function type and
    # the special name to compare against
    reference = lambda: ...
    if not isinstance(v, type(reference)):
        return False
    return v.__name__ == reference.__name__
|
||||
|
||||
|
||||
def get_real_module_name(obj) -> tuple[bool, str]:
    """
    Returns a tuple (in_package, module) for the given object.

    When the object's __module__ does not mention "__main__", it
    is returned unchanged and in_package is True. Otherwise the
    name is rebuilt from the package of the defining module and
    the name derived from the module's file, and in_package is
    False only if that module's __package__ is None
    """

    declared = obj.__module__
    if "__main__" not in declared:
        return True, declared

    # Find the module object the target actually lives in
    owner = getmodule(obj)
    # Derive the module's own name from its filename
    stem = getmodulename(owner.__file__)
    # Join package and module name with a dot, skipping empty parts
    pieces = [piece for piece in (owner.__package__, stem) if piece]
    return owner.__package__ is not None, ".".join(pieces)
|
||||
|
||||
|
||||
class ObjectReference:
    """
    A reference to some arbitrary, named Python object located
    in a module we have access to. This class serves the purpose
    of serializing callables into unique string identifiers that
    can be passed along to remote Python processes. This solution
    is a lot cleaner than trying to serialize objects to bytes and
    is still useful for the majority of cases, however the approach
    does have limitations, namely:
    - lambdas aren't supported, since they have no name and would
      need to be copied/serialized (which is against the intended
      design)
    - bound methods aren't supported either, because we'd have to copy
      the instance they're bound to over to the remote process. If you
      need to call instance methods in the remote process, pass a function
      that constructs the object and calls the method instead

    The string form of a reference is "module:qualified.name", which
    is the format pkgutil.resolve_name() understands
    """

    def __init__(self, obj):
        """
        Wraps the given object.

        :param obj: The object to reference
        :raises ValueError: If obj is a bound method or a lambda
        """

        if ismethod(obj) or islambda(obj):
            raise ValueError("bound methods and lambdas cannot be passed to a remote process")
        self.obj = obj
        self.in_package = True
        # Called only for its side effect of updating self.in_package
        self._make_ref()

    def _make_ref(self) -> tuple[str, str]:
        """
        Returns a tuple (module, name) that uniquely
        represents the wrapped object `name` in its
        `module`. Also refreshes self.in_package
        """

        name: str
        if isfunction(self.obj) or isbuiltin(self.obj) or isinstance(self.obj, type):
            # Functions, builtins and classes are referenced by their own
            # qualified name. Classes used to fall through to the branch
            # below and were serialized as "type", and the unqualified
            # __name__ lost the enclosing class of nested objects
            name = self.obj.__qualname__
        else:
            # Any other object is referenced through its class
            name = type(self.obj).__qualname__
        in_package, module = get_real_module_name(self.obj)
        self.in_package = in_package
        return module, name

    def get_file(self) -> str:
        """
        Returns the POSIX-style path of the file of the
        module the wrapped object is defined in
        """

        return str(Path(getmodule(self.obj).__file__).as_posix())

    def __str__(self):
        """
        Returns the reference as a "module:name" string
        """

        return ":".join(self._make_ref())

    @staticmethod
    def load_ref(ident: str):
        """
        Resolves a reference string to the object it names
        by means of pkgutil.resolve_name()
        """

        return resolve_name(ident)
|
|
@ -58,19 +58,27 @@ async def main_limiter():
|
|||
print(f"Submitted {i + 1} processes")
|
||||
|
||||
|
||||
def foo():
    """
    Demo target for the PythonProcess test: prints a
    fixed message to standard output
    """

    message = "Called in the remote process!"
    print(message)
|
||||
|
||||
|
||||
async def main_python():
|
||||
print("[main] Starting python process test")
|
||||
# Spawns a new Python process
|
||||
p = structio.parallel.PythonProcess()
|
||||
p = structio.parallel.PythonProcess(target=foo)
|
||||
p.start()
|
||||
# TODO: Allow for calling of arbitrary python objects in the spawned process, except bound methods
|
||||
# (which are tricky due to the shared state they carry), lambdas (they cannot be looked up by name in
|
||||
# the newly created process and would need to be serialized, which is not the intended design) and
|
||||
# idk maybe a few others?
|
||||
# await p.wait_started()
|
||||
# await p.wait()
|
||||
print("[main] Pyhon process test complete")
|
||||
|
||||
|
||||
structio.run(main_simple, "owo")
|
||||
structio.run(main_limiter)
|
||||
structio.run(main_python)
|
||||
if __name__ == "__main__":
|
||||
# structio.run(main_simple, "owo")
|
||||
# structio.run(main_limiter)
|
||||
structio.run(main_python)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue