From 4a974ab06d5a6317bb48c96db35d866b4001b498 Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Wed, 19 Oct 2022 12:02:40 +0200 Subject: [PATCH] Added more task synchronization primitives (Queue, Channel) and related tests --- aiosched/__init__.py | 5 +- aiosched/sync.py | 203 ++++++++++++++++++++++++++++ tests/context_catch.py | 2 +- tests/context_silent_catch.py | 2 +- tests/context_wait.py | 2 +- tests/memory_channel.py | 31 +++++ tests/nested_context_catch_inner.py | 4 +- tests/nested_context_catch_outer.py | 2 +- tests/nested_context_wait.py | 2 +- tests/queue.py | 41 ++++++ tests/{catch.py => raw_catch.py} | 0 tests/{wait.py => raw_wait.py} | 0 12 files changed, 286 insertions(+), 8 deletions(-) create mode 100644 tests/memory_channel.py create mode 100644 tests/queue.py rename tests/{catch.py => raw_catch.py} (100%) rename tests/{wait.py => raw_wait.py} (100%) diff --git a/aiosched/__init__.py b/aiosched/__init__.py index f6db51a..b32f1a3 100644 --- a/aiosched/__init__.py +++ b/aiosched/__init__.py @@ -20,7 +20,7 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel import aiosched.task import aiosched.errors import aiosched.context -from aiosched.sync import Event +from aiosched.sync import Event, Queue, Channel, MemoryChannel __all__ = [ "run", @@ -34,4 +34,7 @@ __all__ = [ "cancel", "with_context", "Event", + "Queue", + "Channel", + "MemoryChannel" ] diff --git a/aiosched/sync.py b/aiosched/sync.py index a24d7a9..133ddf8 100644 --- a/aiosched/sync.py +++ b/aiosched/sync.py @@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +from collections import deque +from abc import ABC, abstractmethod from typing import Any from aiosched.errors import SchedulerError from aiosched.internals.syscalls import ( @@ -63,3 +65,204 @@ class Event: self.waiters.add(await current_task()) await suspend() # We get unsuspended by trigger() + + +class Queue: + """ + An asynchronous FIFO queue. Not thread safe + """ + + def __init__(self, maxsize: int | None = None): + """ + Object constructor + """ + + self.maxsize = maxsize + # Stores event objects for tasks wanting to + # get items from the queue + self.getters = deque() + # Stores event objects for tasks wanting to + # put items on the queue + self.putters = deque() + self.container = deque() + + def __len__(self): + """ + Returns the length of the queue + """ + + return len(self.container) + + def __repr__(self) -> str: + return f"{type(self).__name__}({f', '.join(map(str, self.container))})" + + async def __aiter__(self): + """ + Implements the asynchronous iterator protocol + """ + + return self + + async def __anext__(self): + """ + Implements the asynchronous iterator protocol + """ + + return await self.get() + + async def put(self, item: Any): + """ + Pushes an element onto the queue. If the + queue is full, waits until there's + enough space for the queue + """ + + if not self.maxsize or len(self.container) < self.maxsize: + self.container.append(item) + if self.getters: + await self.getters.popleft().trigger(self.container.popleft()) + else: + ev = Event() + self.putters.append(ev) + await ev.wait() + self.container.append(item) + + async def get(self) -> Any: + """ + Pops an element off the queue. Blocks until + an element is put onto it again if the queue + is empty + """ + + if self.container: + if self.putters: + await self.putters.popleft().trigger() + return self.container.popleft() + else: + ev = Event() + self.getters.append(ev) + return await ev.wait() + + async def clear(self): + """ + Clears the queue + """ + + self.container.clear() + + async def reset(self): + """ + Resets the queue + """ + + await self.clear() + self.getters.clear() + self.putters.clear() + + +class Channel(ABC): + """ + A generic, two-way, full-duplex communication channel + between tasks. This is just an abstract base class and + should not be instantiated directly + """ + + def __init__(self, maxsize: int | None = None): + """ + Public object constructor + """ + + self.maxsize = maxsize + self.closed = False + + @abstractmethod + async def write(self, data: str): + """ + Writes data to the channel. Blocks if the internal + queue is full until a spot is available. Does nothing + if the channel has been closed + """ + + return NotImplemented + + @abstractmethod + async def read(self): + """ + Reads data from the channel. Blocks until + a message arrives or returns immediately if + one is already waiting + """ + + return NotImplemented + + @abstractmethod + async def close(self): + """ + Closes the memory channel. Any underlying + data is left for other tasks to read + """ + + return NotImplemented + + @abstractmethod + async def pending(self): + """ + Returns if there's pending + data to be read + """ + + return NotImplemented + + +class MemoryChannel(Channel): + """ + A two-way communication channel between tasks. + Operations on this object do not perform any I/O + or other system call and are therefore extremely + efficient. Not thread safe + """ + + def __init__(self, maxsize: int | None = None): + """ + Public object constructor + """ + + super().__init__(maxsize) + # We use a queue as our buffer + self.buffer = Queue(maxsize=maxsize) + + async def write(self, data: str): + """ + Writes data to the channel. Blocks if the internal + queue is full until a spot is available. Does nothing + if the channel has been closed + """ + + if self.closed: + return + await self.buffer.put(data) + + async def read(self): + """ + Reads data from the channel. Blocks until + a message arrives or returns immediately if + one is already waiting + """ + + return await self.buffer.get() + + async def close(self): + """ + Closes the memory channel. Any underlying + data is left for other tasks to read + """ + + self.closed = True + + async def pending(self): + """ + Returns if there's pending + data to be read + """ + + return bool(len(self.buffer)) diff --git a/tests/context_catch.py b/tests/context_catch.py index 730cd50..e53d904 100644 --- a/tests/context_catch.py +++ b/tests/context_catch.py @@ -1,5 +1,5 @@ import aiosched -from catch import child +from raw_catch import child from debugger import Debugger diff --git a/tests/context_silent_catch.py b/tests/context_silent_catch.py index e5f4e03..142502f 100644 --- a/tests/context_silent_catch.py +++ b/tests/context_silent_catch.py @@ -1,5 +1,5 @@ import aiosched -from catch import child +from raw_catch import child from debugger import Debugger diff --git a/tests/context_wait.py b/tests/context_wait.py index c9b3d01..af7f93b 100644 --- a/tests/context_wait.py +++ b/tests/context_wait.py @@ -1,5 +1,5 @@ import aiosched -from wait import child +from raw_wait import child from debugger import Debugger diff --git a/tests/memory_channel.py b/tests/memory_channel.py new file mode 100644 index 0000000..5326d53 --- /dev/null +++ b/tests/memory_channel.py @@ -0,0 +1,31 @@ +import aiosched +from debugger import Debugger + + +async def sender(c: aiosched.MemoryChannel, n: int): + for i in range(n): + await c.write(str(i)) + print(f"Sent {i}") + await c.close() + print("Sender done") + + +async def receiver(c: aiosched.MemoryChannel): + while True: + if not await c.pending() and c.closed: + print("Receiver done") + break + item = await c.read() + print(f"Received {item}") + await aiosched.sleep(1) + + +async def main(channel: aiosched.MemoryChannel, n: int): + print("Starting sender and receiver") + async with aiosched.with_context() as ctx: + await ctx.spawn(sender, channel, n) + await ctx.spawn(receiver, channel) + print("All done!") + + +aiosched.run(main, aiosched.MemoryChannel(2), 5, debugger=()) # 2 is the max size of the channel diff --git a/tests/nested_context_catch_inner.py b/tests/nested_context_catch_inner.py index f9bbf95..b1f67d6 100644 --- a/tests/nested_context_catch_inner.py +++ b/tests/nested_context_catch_inner.py @@ -1,6 +1,6 @@ import aiosched -from catch import child as errorer -from wait import child as successful +from raw_catch import child as errorer +from raw_wait import child as successful from debugger import Debugger diff --git a/tests/nested_context_catch_outer.py b/tests/nested_context_catch_outer.py index effa85b..774192b 100644 --- a/tests/nested_context_catch_outer.py +++ b/tests/nested_context_catch_outer.py @@ -1,5 +1,5 @@ import aiosched -from catch import child +from raw_catch import child from debugger import Debugger diff --git a/tests/nested_context_wait.py b/tests/nested_context_wait.py index fdad03d..17e2635 100644 --- a/tests/nested_context_wait.py +++ b/tests/nested_context_wait.py @@ -1,5 +1,5 @@ import aiosched -from wait import child +from raw_wait import child from debugger import Debugger diff --git a/tests/queue.py b/tests/queue.py new file mode 100644 index 0000000..6928465 --- /dev/null +++ b/tests/queue.py @@ -0,0 +1,41 @@ +import aiosched +from debugger import Debugger + + +async def producer(q: aiosched.Queue, n: int): + for i in range(n): + # This will wait until the + # queue is emptied by the + # consumer + await q.put(i) + print(f"Produced {i}") + await q.put(None) + print("Producer done") + + +async def consumer(q: aiosched.Queue): + while True: + # Hangs until there is + # something on the queue + item = await q.get() + if item is None: + print("Consumer done") + break + print(f"Consumed {item}") + # Simulates some work so the + # producer waits before putting + # the next value + await aiosched.sleep(1) + + +async def main(q: aiosched.Queue, n: int): + print("Starting consumer and producer") + async with aiosched.with_context() as ctx: + await ctx.spawn(producer, q, n) + await ctx.spawn(consumer, q) + print("Bye!") + + +if __name__ == "__main__": + queue = aiosched.Queue(2) # Queue has size limit of 2 + aiosched.run(main, queue, 5, debugger=None) diff --git a/tests/catch.py b/tests/raw_catch.py similarity index 100% rename from tests/catch.py rename to tests/raw_catch.py diff --git a/tests/wait.py b/tests/raw_wait.py similarity index 100% rename from tests/wait.py rename to tests/raw_wait.py