Added more task synchronization primitives (Queue, Channel) and related tests
This commit is contained in:
parent
c2bb63149b
commit
4a974ab06d
|
@ -20,7 +20,7 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel
|
|||
import aiosched.task
|
||||
import aiosched.errors
|
||||
import aiosched.context
|
||||
from aiosched.sync import Event
|
||||
from aiosched.sync import Event, Queue, Channel, MemoryChannel
|
||||
|
||||
__all__ = [
|
||||
"run",
|
||||
|
@ -34,4 +34,7 @@ __all__ = [
|
|||
"cancel",
|
||||
"with_context",
|
||||
"Event",
|
||||
"Queue",
|
||||
"Channel",
|
||||
"MemoryChannel"
|
||||
]
|
||||
|
|
203
aiosched/sync.py
203
aiosched/sync.py
|
@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from collections import deque
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from aiosched.errors import SchedulerError
|
||||
from aiosched.internals.syscalls import (
|
||||
|
@ -63,3 +65,204 @@ class Event:
|
|||
|
||||
self.waiters.add(await current_task())
|
||||
await suspend() # We get unsuspended by trigger()
|
||||
|
||||
|
||||
class Queue:
|
||||
"""
|
||||
An asynchronous FIFO queue. Not thread safe
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int | None = None):
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
||||
self.maxsize = maxsize
|
||||
# Stores event objects for tasks wanting to
|
||||
# get items from the queue
|
||||
self.getters = deque()
|
||||
# Stores event objects for tasks wanting to
|
||||
# put items on the queue
|
||||
self.putters = deque()
|
||||
self.container = deque()
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns the length of the queue
|
||||
"""
|
||||
|
||||
return len(self.container)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}({f', '.join(map(str, self.container))})"
|
||||
|
||||
async def __aiter__(self):
|
||||
"""
|
||||
Implements the asynchronous iterator protocol
|
||||
"""
|
||||
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
"""
|
||||
Implements the asynchronous iterator protocol
|
||||
"""
|
||||
|
||||
return await self.get()
|
||||
|
||||
async def put(self, item: Any):
|
||||
"""
|
||||
Pushes an element onto the queue. If the
|
||||
queue is full, waits until there's
|
||||
enough space for the queue
|
||||
"""
|
||||
|
||||
if not self.maxsize or len(self.container) < self.maxsize:
|
||||
self.container.append(item)
|
||||
if self.getters:
|
||||
await self.getters.popleft().trigger(self.container.popleft())
|
||||
else:
|
||||
ev = Event()
|
||||
self.putters.append(ev)
|
||||
await ev.wait()
|
||||
self.container.append(item)
|
||||
|
||||
async def get(self) -> Any:
|
||||
"""
|
||||
Pops an element off the queue. Blocks until
|
||||
an element is put onto it again if the queue
|
||||
is empty
|
||||
"""
|
||||
|
||||
if self.container:
|
||||
if self.putters:
|
||||
await self.putters.popleft().trigger()
|
||||
return self.container.popleft()
|
||||
else:
|
||||
ev = Event()
|
||||
self.getters.append(ev)
|
||||
return await ev.wait()
|
||||
|
||||
async def clear(self):
|
||||
"""
|
||||
Clears the queue
|
||||
"""
|
||||
|
||||
self.container.clear()
|
||||
|
||||
async def reset(self):
|
||||
"""
|
||||
Resets the queue
|
||||
"""
|
||||
|
||||
await self.clear()
|
||||
self.getters.clear()
|
||||
self.putters.clear()
|
||||
|
||||
|
||||
class Channel(ABC):
|
||||
"""
|
||||
A generic, two-way, full-duplex communication channel
|
||||
between tasks. This is just an abstract base class and
|
||||
should not be instantiated directly
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int | None = None):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
self.maxsize = maxsize
|
||||
self.closed = False
|
||||
|
||||
@abstractmethod
|
||||
async def write(self, data: str):
|
||||
"""
|
||||
Writes data to the channel. Blocks if the internal
|
||||
queue is full until a spot is available. Does nothing
|
||||
if the channel has been closed
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
async def read(self):
|
||||
"""
|
||||
Reads data from the channel. Blocks until
|
||||
a message arrives or returns immediately if
|
||||
one is already waiting
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the memory channel. Any underlying
|
||||
data is left for other tasks to read
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
async def pending(self):
|
||||
"""
|
||||
Returns if there's pending
|
||||
data to be read
|
||||
"""
|
||||
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class MemoryChannel(Channel):
|
||||
"""
|
||||
A two-way communication channel between tasks.
|
||||
Operations on this object do not perform any I/O
|
||||
or other system call and are therefore extremely
|
||||
efficient. Not thread safe
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int | None = None):
|
||||
"""
|
||||
Public object constructor
|
||||
"""
|
||||
|
||||
super().__init__(maxsize)
|
||||
# We use a queue as our buffer
|
||||
self.buffer = Queue(maxsize=maxsize)
|
||||
|
||||
async def write(self, data: str):
|
||||
"""
|
||||
Writes data to the channel. Blocks if the internal
|
||||
queue is full until a spot is available. Does nothing
|
||||
if the channel has been closed
|
||||
"""
|
||||
|
||||
if self.closed:
|
||||
return
|
||||
await self.buffer.put(data)
|
||||
|
||||
async def read(self):
|
||||
"""
|
||||
Reads data from the channel. Blocks until
|
||||
a message arrives or returns immediately if
|
||||
one is already waiting
|
||||
"""
|
||||
|
||||
return await self.buffer.get()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the memory channel. Any underlying
|
||||
data is left for other tasks to read
|
||||
"""
|
||||
|
||||
self.closed = True
|
||||
|
||||
async def pending(self):
|
||||
"""
|
||||
Returns if there's pending
|
||||
data to be read
|
||||
"""
|
||||
|
||||
return bool(len(self.buffer))
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from catch import child
|
||||
from raw_catch import child
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from catch import child
|
||||
from raw_catch import child
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from wait import child
|
||||
from raw_wait import child
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import aiosched
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
async def sender(c: aiosched.MemoryChannel, n: int):
|
||||
for i in range(n):
|
||||
await c.write(str(i))
|
||||
print(f"Sent {i}")
|
||||
await c.close()
|
||||
print("Sender done")
|
||||
|
||||
|
||||
async def receiver(c: aiosched.MemoryChannel):
|
||||
while True:
|
||||
if not await c.pending() and c.closed:
|
||||
print("Receiver done")
|
||||
break
|
||||
item = await c.read()
|
||||
print(f"Received {item}")
|
||||
await aiosched.sleep(1)
|
||||
|
||||
|
||||
async def main(channel: aiosched.MemoryChannel, n: int):
|
||||
print("Starting sender and receiver")
|
||||
async with aiosched.with_context() as ctx:
|
||||
await ctx.spawn(sender, channel, n)
|
||||
await ctx.spawn(receiver, channel)
|
||||
print("All done!")
|
||||
|
||||
|
||||
aiosched.run(main, aiosched.MemoryChannel(2), 5, debugger=()) # 2 is the max size of the channel
|
|
@ -1,6 +1,6 @@
|
|||
import aiosched
|
||||
from catch import child as errorer
|
||||
from wait import child as successful
|
||||
from raw_catch import child as errorer
|
||||
from raw_wait import child as successful
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from catch import child
|
||||
from raw_catch import child
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import aiosched
|
||||
from wait import child
|
||||
from raw_wait import child
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
import aiosched
|
||||
from debugger import Debugger
|
||||
|
||||
|
||||
async def producer(q: aiosched.Queue, n: int):
|
||||
for i in range(n):
|
||||
# This will wait until the
|
||||
# queue is emptied by the
|
||||
# consumer
|
||||
await q.put(i)
|
||||
print(f"Produced {i}")
|
||||
await q.put(None)
|
||||
print("Producer done")
|
||||
|
||||
|
||||
async def consumer(q: aiosched.Queue):
|
||||
while True:
|
||||
# Hangs until there is
|
||||
# something on the queue
|
||||
item = await q.get()
|
||||
if item is None:
|
||||
print("Consumer done")
|
||||
break
|
||||
print(f"Consumed {item}")
|
||||
# Simulates some work so the
|
||||
# producer waits before putting
|
||||
# the next value
|
||||
await aiosched.sleep(1)
|
||||
|
||||
|
||||
async def main(q: aiosched.Queue, n: int):
|
||||
print("Starting consumer and producer")
|
||||
async with aiosched.with_context() as ctx:
|
||||
await ctx.spawn(producer, q, n)
|
||||
await ctx.spawn(consumer, q)
|
||||
print("Bye!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
queue = aiosched.Queue(2) # Queue has size limit of 2
|
||||
aiosched.run(main, queue, 5, debugger=None)
|
Reference in New Issue