Added more task synchronization primitives (Queue, Channel) and related tests

This commit is contained in:
Nocturn9x 2022-10-19 12:02:40 +02:00
parent c2bb63149b
commit 4a974ab06d
12 changed files with 286 additions and 8 deletions

View File

@ -20,7 +20,7 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel
import aiosched.task
import aiosched.errors
import aiosched.context
from aiosched.sync import Event
from aiosched.sync import Event, Queue, Channel, MemoryChannel
__all__ = [
"run",
@ -34,4 +34,7 @@ __all__ = [
"cancel",
"with_context",
"Event",
"Queue",
"Channel",
"MemoryChannel"
]

View File

@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import deque
from abc import ABC, abstractmethod
from typing import Any
from aiosched.errors import SchedulerError
from aiosched.internals.syscalls import (
@ -63,3 +65,204 @@ class Event:
self.waiters.add(await current_task())
await suspend() # We get unsuspended by trigger()
class Queue:
"""
An asynchronous FIFO queue. Not thread safe
"""
def __init__(self, maxsize: int | None = None):
"""
Object constructor
"""
self.maxsize = maxsize
# Stores event objects for tasks wanting to
# get items from the queue
self.getters = deque()
# Stores event objects for tasks wanting to
# put items on the queue
self.putters = deque()
self.container = deque()
def __len__(self):
"""
Returns the length of the queue
"""
return len(self.container)
def __repr__(self) -> str:
return f"{type(self).__name__}({f', '.join(map(str, self.container))})"
async def __aiter__(self):
"""
Implements the asynchronous iterator protocol
"""
return self
async def __anext__(self):
"""
Implements the asynchronous iterator protocol
"""
return await self.get()
async def put(self, item: Any):
"""
Pushes an element onto the queue. If the
queue is full, waits until there's
enough space for the queue
"""
if not self.maxsize or len(self.container) < self.maxsize:
self.container.append(item)
if self.getters:
await self.getters.popleft().trigger(self.container.popleft())
else:
ev = Event()
self.putters.append(ev)
await ev.wait()
self.container.append(item)
async def get(self) -> Any:
"""
Pops an element off the queue. Blocks until
an element is put onto it again if the queue
is empty
"""
if self.container:
if self.putters:
await self.putters.popleft().trigger()
return self.container.popleft()
else:
ev = Event()
self.getters.append(ev)
return await ev.wait()
async def clear(self):
"""
Clears the queue
"""
self.container.clear()
async def reset(self):
"""
Resets the queue
"""
await self.clear()
self.getters.clear()
self.putters.clear()
class Channel(ABC):
"""
A generic, two-way, full-duplex communication channel
between tasks. This is just an abstract base class and
should not be instantiated directly
"""
def __init__(self, maxsize: int | None = None):
"""
Public object constructor
"""
self.maxsize = maxsize
self.closed = False
@abstractmethod
async def write(self, data: str):
"""
Writes data to the channel. Blocks if the internal
queue is full until a spot is available. Does nothing
if the channel has been closed
"""
return NotImplemented
@abstractmethod
async def read(self):
"""
Reads data from the channel. Blocks until
a message arrives or returns immediately if
one is already waiting
"""
return NotImplemented
@abstractmethod
async def close(self):
"""
Closes the memory channel. Any underlying
data is left for other tasks to read
"""
return NotImplemented
@abstractmethod
async def pending(self):
"""
Returns if there's pending
data to be read
"""
return NotImplemented
class MemoryChannel(Channel):
"""
A two-way communication channel between tasks.
Operations on this object do not perform any I/O
or other system call and are therefore extremely
efficient. Not thread safe
"""
def __init__(self, maxsize: int | None = None):
"""
Public object constructor
"""
super().__init__(maxsize)
# We use a queue as our buffer
self.buffer = Queue(maxsize=maxsize)
async def write(self, data: str):
"""
Writes data to the channel. Blocks if the internal
queue is full until a spot is available. Does nothing
if the channel has been closed
"""
if self.closed:
return
await self.buffer.put(data)
async def read(self):
"""
Reads data from the channel. Blocks until
a message arrives or returns immediately if
one is already waiting
"""
return await self.buffer.get()
async def close(self):
"""
Closes the memory channel. Any underlying
data is left for other tasks to read
"""
self.closed = True
async def pending(self):
"""
Returns if there's pending
data to be read
"""
return bool(len(self.buffer))

View File

@ -1,5 +1,5 @@
import aiosched
from catch import child
from raw_catch import child
from debugger import Debugger

View File

@ -1,5 +1,5 @@
import aiosched
from catch import child
from raw_catch import child
from debugger import Debugger

View File

@ -1,5 +1,5 @@
import aiosched
from wait import child
from raw_wait import child
from debugger import Debugger

31
tests/memory_channel.py Normal file
View File

@ -0,0 +1,31 @@
import aiosched
from debugger import Debugger
async def sender(c: aiosched.MemoryChannel, n: int):
for i in range(n):
await c.write(str(i))
print(f"Sent {i}")
await c.close()
print("Sender done")
async def receiver(c: aiosched.MemoryChannel):
while True:
if not await c.pending() and c.closed:
print("Receiver done")
break
item = await c.read()
print(f"Received {item}")
await aiosched.sleep(1)
async def main(channel: aiosched.MemoryChannel, n: int):
print("Starting sender and receiver")
async with aiosched.with_context() as ctx:
await ctx.spawn(sender, channel, n)
await ctx.spawn(receiver, channel)
print("All done!")
aiosched.run(main, aiosched.MemoryChannel(2), 5, debugger=()) # 2 is the max size of the channel

View File

@ -1,6 +1,6 @@
import aiosched
from catch import child as errorer
from wait import child as successful
from raw_catch import child as errorer
from raw_wait import child as successful
from debugger import Debugger

View File

@ -1,5 +1,5 @@
import aiosched
from catch import child
from raw_catch import child
from debugger import Debugger

View File

@ -1,5 +1,5 @@
import aiosched
from wait import child
from raw_wait import child
from debugger import Debugger

41
tests/queue.py Normal file
View File

@ -0,0 +1,41 @@
import aiosched
from debugger import Debugger
async def producer(q: aiosched.Queue, n: int):
for i in range(n):
# This will wait until the
# queue is emptied by the
# consumer
await q.put(i)
print(f"Produced {i}")
await q.put(None)
print("Producer done")
async def consumer(q: aiosched.Queue):
while True:
# Hangs until there is
# something on the queue
item = await q.get()
if item is None:
print("Consumer done")
break
print(f"Consumed {item}")
# Simulates some work so the
# producer waits before putting
# the next value
await aiosched.sleep(1)
async def main(q: aiosched.Queue, n: int):
print("Starting consumer and producer")
async with aiosched.with_context() as ctx:
await ctx.spawn(producer, q, n)
await ctx.spawn(consumer, q)
print("Bye!")
if __name__ == "__main__":
queue = aiosched.Queue(2) # Queue has size limit of 2
aiosched.run(main, queue, 5, debugger=None)