diff --git a/structio/core/context.py b/structio/core/context.py index d226210..94280bc 100644 --- a/structio/core/context.py +++ b/structio/core/context.py @@ -126,7 +126,9 @@ class TaskPool: raise exc_val else: await suspend() - except Cancelled: + except Cancelled as e: + if e.scope is not self.scope: + self.error = e self.scope.cancelled = True except (Exception, KeyboardInterrupt) as e: self.error = e diff --git a/structio/thread.py b/structio/thread.py index 01dc8ab..5696d98 100644 --- a/structio/thread.py +++ b/structio/thread.py @@ -114,7 +114,8 @@ class AsyncThreadQueue(Queue): evt: AsyncThreadEvent | None = None with self._lock: if self.maxsize and self.maxsize == len(self.container): - self.putters.append(AsyncThreadEvent()) + evt = AsyncThreadEvent() + self.putters.append(evt) if self.getters: self.getters.popleft().set() if evt: @@ -124,7 +125,7 @@ class AsyncThreadQueue(Queue): @enable_ki_protection def get_sync(self): """ - Like get(), but asynchronous + Like get(), but synchronous """ evt: AsyncThreadEvent | None = None @@ -147,6 +148,7 @@ def _threaded_runner(f, q: AsyncThreadQueue, parent_loop: BaseKernel, *args): q.put_sync((False, e)) +@enable_ki_protection async def _async_runner(f, *args): queue = AsyncThreadQueue(1) th = threading.Thread(target=_threaded_runner, args=(f, queue, current_loop(), *args), @@ -175,11 +177,10 @@ async def run_in_worker(sync_func, if not hasattr(_storage, "parent_loop"): _storage.parent_loop = current_loop() async with _storage.max_workers: - async with structio.create_pool() as pool: - # This will automatically block once - # we run out of slots and proceed once - # we have more - return await pool.spawn(_async_runner, sync_func, *args) + # This will automatically block once + # we run out of slots and proceed once + # we have more + return await current_loop().current_pool.spawn(_async_runner, sync_func, *args) def set_max_worker_count(count: int): diff --git a/tests/queue.py b/tests/queue.py index daf6779..50b2719 100644 --- a/tests/queue.py +++ b/tests/queue.py @@ -1,5 +1,4 @@ import time -import threading import structio @@ -9,9 +8,9 @@ async def producer(q: structio.Queue, n: int): # queue is emptied by the # consumer await q.put(i) - print(f"Produced {i}") + print(f"[producer] Produced {i}") await q.put(None) - print("Producer done") + print("[producer] Producer done") async def consumer(q: structio.Queue): @@ -20,9 +19,9 @@ async def consumer(q: structio.Queue): # something on the queue item = await q.get() if item is None: - print("Consumer done") + print("[consumer] Consumer done") break - print(f"Consumed {item}") + print(f"[consumer] Consumed {item}") # Simulates some work so the # producer waits before putting # the next value @@ -35,29 +34,56 @@ def threaded_consumer(q: structio.thread.AsyncThreadQueue): # something on the queue item = q.get_sync() if item is None: - print("Consumer done") + print("[worker consumer] Consumer done") break - print(f"Consumed {item}") + print(f"[worker consumer] Consumed {item}") # Simulates some work so the # producer waits before putting # the next value time.sleep(1) + return 69 async def main(q: structio.Queue, n: int): - print("Starting consumer and producer") + print("[main] Starting consumer and producer") + t = structio.clock() async with structio.create_pool() as ctx: ctx.spawn(producer, q, n) ctx.spawn(consumer, q) - print("Bye!") + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +def threaded_producer(q: structio.thread.AsyncThreadQueue, n: int): + print("[worker producer] Producer started") + for i in range(n): + # This will wait until the + # queue is emptied by the + # consumer + q.put_sync(i) + print(f"[worker producer] Produced {i}") + q.put_sync(None) + print("[worker producer] Producer done") + return 42 async def main_threaded(q: structio.thread.AsyncThreadQueue, n: int): - print("Starting consumer and producer") + print("[main] Starting consumer and producer") + t = structio.clock() async with structio.create_pool() as pool: pool.spawn(producer, q, n) - await structio.thread.run_in_worker(threaded_consumer, q) - print("Bye!") + val = await structio.thread.run_in_worker(threaded_consumer, q) + print(f"[main] Thread returned {val!r}") + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") + + +async def main_threaded_2(q: structio.thread.AsyncThreadQueue, n: int): + print("[main] Starting consumer and producer") + t = structio.clock() + async with structio.create_pool() as pool: + pool.spawn(consumer, q) + val = await structio.thread.run_in_worker(threaded_producer, q, n) + print(f"[main] Thread returned {val!r}") + print(f"[main] Exited in {structio.clock() - t:.2f} seconds") if __name__ == "__main__": @@ -65,3 +91,4 @@ if __name__ == "__main__": structio.run(main, queue, 5) queue = structio.thread.AsyncThreadQueue(2) structio.run(main_threaded, queue, 5) + structio.run(main_threaded_2, queue, 5)