Added more task synchronization primitives (Queue, Channel) and related tests
This commit is contained in:
parent
c2bb63149b
commit
4a974ab06d
|
@ -20,7 +20,7 @@ from aiosched.internals.syscalls import spawn, wait, sleep, cancel
|
||||||
import aiosched.task
|
import aiosched.task
|
||||||
import aiosched.errors
|
import aiosched.errors
|
||||||
import aiosched.context
|
import aiosched.context
|
||||||
from aiosched.sync import Event
|
from aiosched.sync import Event, Queue, Channel, MemoryChannel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"run",
|
"run",
|
||||||
|
@ -34,4 +34,7 @@ __all__ = [
|
||||||
"cancel",
|
"cancel",
|
||||||
"with_context",
|
"with_context",
|
||||||
"Event",
|
"Event",
|
||||||
|
"Queue",
|
||||||
|
"Channel",
|
||||||
|
"MemoryChannel"
|
||||||
]
|
]
|
||||||
|
|
203
aiosched/sync.py
203
aiosched/sync.py
|
@ -15,6 +15,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
from collections import deque
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from aiosched.errors import SchedulerError
|
from aiosched.errors import SchedulerError
|
||||||
from aiosched.internals.syscalls import (
|
from aiosched.internals.syscalls import (
|
||||||
|
@ -63,3 +65,204 @@ class Event:
|
||||||
|
|
||||||
self.waiters.add(await current_task())
|
self.waiters.add(await current_task())
|
||||||
await suspend() # We get unsuspended by trigger()
|
await suspend() # We get unsuspended by trigger()
|
||||||
|
|
||||||
|
|
||||||
|
class Queue:
|
||||||
|
"""
|
||||||
|
An asynchronous FIFO queue. Not thread safe
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int | None = None):
|
||||||
|
"""
|
||||||
|
Object constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.maxsize = maxsize
|
||||||
|
# Stores event objects for tasks wanting to
|
||||||
|
# get items from the queue
|
||||||
|
self.getters = deque()
|
||||||
|
# Stores event objects for tasks wanting to
|
||||||
|
# put items on the queue
|
||||||
|
self.putters = deque()
|
||||||
|
self.container = deque()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Returns the length of the queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
return len(self.container)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{type(self).__name__}({f', '.join(map(str, self.container))})"
|
||||||
|
|
||||||
|
async def __aiter__(self):
|
||||||
|
"""
|
||||||
|
Implements the asynchronous iterator protocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
"""
|
||||||
|
Implements the asynchronous iterator protocol
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self.get()
|
||||||
|
|
||||||
|
async def put(self, item: Any):
|
||||||
|
"""
|
||||||
|
Pushes an element onto the queue. If the
|
||||||
|
queue is full, waits until there's
|
||||||
|
enough space for the queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.maxsize or len(self.container) < self.maxsize:
|
||||||
|
self.container.append(item)
|
||||||
|
if self.getters:
|
||||||
|
await self.getters.popleft().trigger(self.container.popleft())
|
||||||
|
else:
|
||||||
|
ev = Event()
|
||||||
|
self.putters.append(ev)
|
||||||
|
await ev.wait()
|
||||||
|
self.container.append(item)
|
||||||
|
|
||||||
|
async def get(self) -> Any:
|
||||||
|
"""
|
||||||
|
Pops an element off the queue. Blocks until
|
||||||
|
an element is put onto it again if the queue
|
||||||
|
is empty
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.container:
|
||||||
|
if self.putters:
|
||||||
|
await self.putters.popleft().trigger()
|
||||||
|
return self.container.popleft()
|
||||||
|
else:
|
||||||
|
ev = Event()
|
||||||
|
self.getters.append(ev)
|
||||||
|
return await ev.wait()
|
||||||
|
|
||||||
|
async def clear(self):
|
||||||
|
"""
|
||||||
|
Clears the queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.container.clear()
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""
|
||||||
|
Resets the queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.clear()
|
||||||
|
self.getters.clear()
|
||||||
|
self.putters.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class Channel(ABC):
|
||||||
|
"""
|
||||||
|
A generic, two-way, full-duplex communication channel
|
||||||
|
between tasks. This is just an abstract base class and
|
||||||
|
should not be instantiated directly
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int | None = None):
|
||||||
|
"""
|
||||||
|
Public object constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.maxsize = maxsize
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def write(self, data: str):
|
||||||
|
"""
|
||||||
|
Writes data to the channel. Blocks if the internal
|
||||||
|
queue is full until a spot is available. Does nothing
|
||||||
|
if the channel has been closed
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def read(self):
|
||||||
|
"""
|
||||||
|
Reads data from the channel. Blocks until
|
||||||
|
a message arrives or returns immediately if
|
||||||
|
one is already waiting
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
Closes the memory channel. Any underlying
|
||||||
|
data is left for other tasks to read
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def pending(self):
|
||||||
|
"""
|
||||||
|
Returns if there's pending
|
||||||
|
data to be read
|
||||||
|
"""
|
||||||
|
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryChannel(Channel):
|
||||||
|
"""
|
||||||
|
A two-way communication channel between tasks.
|
||||||
|
Operations on this object do not perform any I/O
|
||||||
|
or other system call and are therefore extremely
|
||||||
|
efficient. Not thread safe
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int | None = None):
|
||||||
|
"""
|
||||||
|
Public object constructor
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__(maxsize)
|
||||||
|
# We use a queue as our buffer
|
||||||
|
self.buffer = Queue(maxsize=maxsize)
|
||||||
|
|
||||||
|
async def write(self, data: str):
|
||||||
|
"""
|
||||||
|
Writes data to the channel. Blocks if the internal
|
||||||
|
queue is full until a spot is available. Does nothing
|
||||||
|
if the channel has been closed
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
await self.buffer.put(data)
|
||||||
|
|
||||||
|
async def read(self):
|
||||||
|
"""
|
||||||
|
Reads data from the channel. Blocks until
|
||||||
|
a message arrives or returns immediately if
|
||||||
|
one is already waiting
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self.buffer.get()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""
|
||||||
|
Closes the memory channel. Any underlying
|
||||||
|
data is left for other tasks to read
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
async def pending(self):
|
||||||
|
"""
|
||||||
|
Returns if there's pending
|
||||||
|
data to be read
|
||||||
|
"""
|
||||||
|
|
||||||
|
return bool(len(self.buffer))
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from catch import child
|
from raw_catch import child
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from catch import child
|
from raw_catch import child
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from wait import child
|
from raw_wait import child
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
import aiosched
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
async def sender(c: aiosched.MemoryChannel, n: int):
|
||||||
|
for i in range(n):
|
||||||
|
await c.write(str(i))
|
||||||
|
print(f"Sent {i}")
|
||||||
|
await c.close()
|
||||||
|
print("Sender done")
|
||||||
|
|
||||||
|
|
||||||
|
async def receiver(c: aiosched.MemoryChannel):
|
||||||
|
while True:
|
||||||
|
if not await c.pending() and c.closed:
|
||||||
|
print("Receiver done")
|
||||||
|
break
|
||||||
|
item = await c.read()
|
||||||
|
print(f"Received {item}")
|
||||||
|
await aiosched.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(channel: aiosched.MemoryChannel, n: int):
|
||||||
|
print("Starting sender and receiver")
|
||||||
|
async with aiosched.with_context() as ctx:
|
||||||
|
await ctx.spawn(sender, channel, n)
|
||||||
|
await ctx.spawn(receiver, channel)
|
||||||
|
print("All done!")
|
||||||
|
|
||||||
|
|
||||||
|
aiosched.run(main, aiosched.MemoryChannel(2), 5, debugger=()) # 2 is the max size of the channel
|
|
@ -1,6 +1,6 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from catch import child as errorer
|
from raw_catch import child as errorer
|
||||||
from wait import child as successful
|
from raw_wait import child as successful
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from catch import child
|
from raw_catch import child
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import aiosched
|
import aiosched
|
||||||
from wait import child
|
from raw_wait import child
|
||||||
from debugger import Debugger
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
import aiosched
|
||||||
|
from debugger import Debugger
|
||||||
|
|
||||||
|
|
||||||
|
async def producer(q: aiosched.Queue, n: int):
|
||||||
|
for i in range(n):
|
||||||
|
# This will wait until the
|
||||||
|
# queue is emptied by the
|
||||||
|
# consumer
|
||||||
|
await q.put(i)
|
||||||
|
print(f"Produced {i}")
|
||||||
|
await q.put(None)
|
||||||
|
print("Producer done")
|
||||||
|
|
||||||
|
|
||||||
|
async def consumer(q: aiosched.Queue):
|
||||||
|
while True:
|
||||||
|
# Hangs until there is
|
||||||
|
# something on the queue
|
||||||
|
item = await q.get()
|
||||||
|
if item is None:
|
||||||
|
print("Consumer done")
|
||||||
|
break
|
||||||
|
print(f"Consumed {item}")
|
||||||
|
# Simulates some work so the
|
||||||
|
# producer waits before putting
|
||||||
|
# the next value
|
||||||
|
await aiosched.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(q: aiosched.Queue, n: int):
|
||||||
|
print("Starting consumer and producer")
|
||||||
|
async with aiosched.with_context() as ctx:
|
||||||
|
await ctx.spawn(producer, q, n)
|
||||||
|
await ctx.spawn(consumer, q)
|
||||||
|
print("Bye!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
queue = aiosched.Queue(2) # Queue has size limit of 2
|
||||||
|
aiosched.run(main, queue, 5, debugger=None)
|
Reference in New Issue