From e7cb6a72f5ff3193223c379fa7666d5aa53b68d1 Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Fri, 2 Jun 2023 10:58:46 +0200 Subject: [PATCH] Fixed bug with scopes and added 'smart' events --- structio/__init__.py | 2 +- structio/core/kernels/fifo.py | 17 +++++++++---- structio/sync.py | 48 +++++++++++++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/structio/__init__.py b/structio/__init__.py index aa64ccd..5a34fbf 100644 --- a/structio/__init__.py +++ b/structio/__init__.py @@ -9,7 +9,7 @@ from structio.core.context import TaskPool, TaskScope from structio.exceptions import Cancelled, TimedOut, ResourceClosed from structio.core import task from structio.core.task import Task, TaskState -from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock +from structio.sync import Event, Queue, MemoryChannel, Semaphore, Lock, RLock, emit, on_event, register_event from structio.abc import Channel, Stream, ChannelReader, ChannelWriter from structio.io.files import ( open_file, diff --git a/structio/core/kernels/fifo.py b/structio/core/kernels/fifo.py index 1857367..a1f83f4 100644 --- a/structio/core/kernels/fifo.py +++ b/structio/core/kernels/fifo.py @@ -82,17 +82,21 @@ class FIFOKernel(BaseKernel): traceback.print_tb(e.__traceback__) def done(self): - if self.entry_point.done(): - return True if any([self.run_queue, self.paused, self.io_manager.pending()]): return False for scope in self.scopes: if not scope.done(): return False + if self.entry_point.done(): + return True return True def spawn(self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args): - task = Task(func.__name__ or repr(func), func(*args), self.current_pool) + if isinstance(func, partial): + name = func.func.__name__ or repr(func.func) + else: + name = func.__name__ or repr(func) + task = Task(name, func(*args), self.current_pool) # We inject our magic secret variable into the coroutine's stack frame, so # we can look it up later task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, False) @@ -104,7 +108,11 @@ class FIFOKernel(BaseKernel): def spawn_system_task( self, func: Callable[[Any, Any], Coroutine[Any, Any, Any]], *args ): - task = Task(func.__name__ or repr(func), func(*args), self.pool) + if isinstance(func, partial): + name = func.func.__name__ or repr(func.func) + else: + name = func.__name__ or repr(func) + task = Task(name, func(*args), self.pool) task.coroutine.cr_frame.f_locals.setdefault(CTRLC_PROTECTION_ENABLED, True) task.pool.scope.tasks.append(task) self.run_queue.append(task) @@ -346,7 +354,6 @@ class FIFOKernel(BaseKernel): def close_scope(self, scope: TaskScope): self.current_scope = scope.outer - self.reschedule(scope.owner) self.scopes.pop() def cancel_task(self, task: Task): diff --git a/structio/sync.py b/structio/sync.py index 1fd801b..4656346 100644 --- a/structio/sync.py +++ b/structio/sync.py @@ -1,12 +1,14 @@ # Task synchronization primitives +import structio from structio.core.syscalls import suspend, checkpoint from structio.exceptions import ResourceClosed from structio.core.run import current_task, current_loop from structio.abc import ChannelReader, ChannelWriter, Channel from structio.util.ki import enable_ki_protection from structio.core.task import Task -from collections import deque -from typing import Any +from collections import deque, defaultdict +from typing import Any, Callable, Coroutine +from functools import partial, wraps class Event: @@ -361,3 +363,45 @@ class RLock(Lock): await super().release() else: await checkpoint() + + +_events: dict[str, list[Callable[[Any, Any], Coroutine[Any, Any, Any]]]] = defaultdict(list) + + +async def emit(evt: str, *args, **kwargs): + """ + Fire the event and call all of its handlers with + the event name as the first argument and all other + positional and keyword arguments passed to this + function after that. Returns once all events have + completed execution + """ + + async with structio.create_pool() as pool: + for func in _events[evt]: + pool.spawn(partial(func, evt, *args, **kwargs)) + + +def register_event(evt: str, func: Callable[[Any, Any], Coroutine[Any, Any, Any]]): + """ + Register the given async function for the given event name + """ + + _events[evt].append(func) + + +def on_event(evt: str): + """ + Convenience decorator to + register async functions + to events + """ + + def decorator(f): + @wraps + def wrapper(*args, **kwargs): + f(*args, **kwargs) + register_event(evt, f) + return wrapper + + return decorator