
364 lines
9.5 KiB

# Task synchronization primitives
from structio.core.syscalls import suspend, checkpoint
from structio.exceptions import ResourceClosed
from import current_task, current_loop
from import ChannelReader, ChannelWriter, Channel
from import enable_ki_protection
from structio.core.task import Task
from collections import deque
from typing import Any
class Event:
    """
    A wrapper around a boolean value that can be waited
    on asynchronously. The majority of structio's API is
    designed on top of/around this class, as it constitutes
    the simplest synchronization primitive there is
    """

    def __init__(self):
        """
        Public object constructor
        """

        self._set = False
        # Tasks currently suspended in wait(), in arrival order
        self._tasks: deque[Task] = deque()

    def is_set(self):
        """
        Returns whether set() has already been
        called on this event
        """

        return self._set

    async def wait(self):
        """
        Wait until someone else calls set() on
        this event. If the event has already been
        set, this method returns immediately (it
        still goes through a checkpoint)
        """

        if self.is_set():
            await checkpoint()
            return
        # Register ourselves so that set() knows whom to wake up
        self._tasks.append(current_task())
        await suspend()  # We get re-scheduled by set()

    def set(self):
        """
        Sets the event, awaking all tasks
        that called wait() on it

        :raises RuntimeError: if the event was already set
        """

        if self.is_set():
            raise RuntimeError("the event has already been set")
        self._set = True
        # Drain the waiter list so that no task is woken up twice.
        # NOTE(review): assumes the event loop exposes reschedule(task)
        # to put a suspended task back on its run queue — confirm
        while self._tasks:
            current_loop().reschedule(self._tasks.popleft())
class Queue:
An asynchronous FIFO queue
def __init__(self, maxsize: int | None = None):
Object constructor
self.maxsize = maxsize
# Stores event objects for tasks wanting to
# get items from the queue
self.getters: deque[Event] = deque()
# Stores event objects for tasks wanting to
# put items on the queue
self.putters: deque[Event] = deque()
self.container: deque[Event] = deque()
def __len__(self):
Returns the length of the queue
return len(self.container)
def __repr__(self) -> str:
return f"{type(self).__name__}({f', '.join(map(str, self.container))})"
async def __aiter__(self):
Implements the asynchronous iterator protocol
return self
async def __anext__(self):
Implements the asynchronous iterator protocol
return await self.get()
async def put(self, item: Any):
Pushes an element onto the queue. If the
queue is full, waits until there's
enough space for the queue
if self.maxsize and len(self.container) == self.maxsize:
await self.putters[-1].wait()
if self.getters:
await checkpoint()
async def get(self) -> Any:
Pops an element off the queue. Blocks until
an element is put onto it again if the queue
is empty
if not self.container:
await self.getters[-1].wait()
if self.putters:
result = self.container.popleft()
await checkpoint()
return result
def clear(self):
Clears the queue
def reset(self):
Resets the queue
class MemorySendChannel(ChannelWriter):
    """
    An in-memory one-way channel used to send
    data into a buffer (MemoryChannel hands it the
    same Queue it gives to its MemoryReceiveChannel)
    """

    def __init__(self, buffer):
        self._closed = False
        self._buffer = buffer

    def _ensure_open(self):
        # Any operation on a closed channel is an error
        if self._closed:
            raise ResourceClosed("cannot operate on a closed channel")

    async def send(self, value):
        """
        Pushes value into the buffer via its put() method

        :raises ResourceClosed: if this channel was closed
        """

        self._ensure_open()
        await self._buffer.put(value)

    async def close(self):
        """
        Closes this end of the channel. Closing an
        already closed channel raises ResourceClosed
        """

        self._ensure_open()
        self._closed = True
        await checkpoint()

    def writers(self):
        """
        Returns the length of the buffer's putters deque,
        i.e. the number of tasks waiting to put items
        """

        pending_putters = self._buffer.putters
        return len(pending_putters)
class MemoryReceiveChannel(ChannelReader):
    """
    An in-memory one-way channel used to read
    data out of a buffer (MemoryChannel hands it the
    same Queue it gives to its MemorySendChannel)
    """

    def __init__(self, buffer):
        self._closed = False
        self._buffer = buffer

    def _raise_if_closed(self):
        # Any operation on a closed channel is an error
        if self._closed:
            raise ResourceClosed("cannot operate on a closed channel")

    async def receive(self):
        """
        Returns the next item obtained from the
        buffer's get() method

        :raises ResourceClosed: if this channel was closed
        """

        self._raise_if_closed()
        item = await self._buffer.get()
        return item

    async def close(self):
        """
        Closes this end of the channel. Closing an
        already closed channel raises ResourceClosed
        """

        self._raise_if_closed()
        self._closed = True
        await checkpoint()

    def pending(self):
        """
        Returns the truthiness of the buffer, which for
        a Queue means whether it holds any items
        """

        return bool(self._buffer)

    def readers(self):
        """
        Returns the length of the buffer's getters deque,
        i.e. the number of tasks waiting for data
        """

        pending_getters = self._buffer.getters
        return len(pending_getters)
class MemoryChannel(Channel, MemorySendChannel, MemoryReceiveChannel):
    """
    An in-memory two-way channel between
    tasks with optional buffering
    """

    def __init__(self, buffer_size):
        """
        :param buffer_size: used as the maxsize of the underlying
            Queue; a falsy value (0 or None) makes it unbounded
        """

        self._buffer = Queue(buffer_size)
        # Both halves share the same buffer, so whatever goes in
        # through the writer comes out through the reader
        self.reader = MemoryReceiveChannel(self._buffer)
        self.writer = MemorySendChannel(self._buffer)

    async def send(self, value):
        """
        Sends value through the writing half of the channel.
        Delegating (instead of inheriting MemorySendChannel.send)
        is required: this object never initializes its own
        _closed flag, so the inherited method would fail

        :raises ResourceClosed: if the writing half was closed
        """

        await self.writer.send(value)

    async def receive(self):
        """
        Receives a value through the reading half of the channel

        :raises ResourceClosed: if the reading half was closed
        """

        return await self.reader.receive()

    async def close(self):
        """
        Closes both halves of the channel. The writer is closed
        even if closing the reader fails (e.g. because it was
        already closed)
        """

        try:
            await self.reader.close()
        finally:
            await self.writer.close()
class Semaphore:
    """
    An asynchronous integer semaphore. The use of initial_size
    is for semaphores which we know that can grow up to max_size
    but that can't right now, say because there's too much load on
    the application and resources are constrained. If it is None,
    initial_size equals max_size
    """

    def __init__(self, max_size: int, initial_size: int | None = None):
        """
        :param max_size: the maximum number of slots
        :param initial_size: how many slots are available right
            away (defaults to max_size)
        :raises ValueError: if initial_size is negative or
            greater than max_size
        """

        if initial_size is None:
            initial_size = max_size
        # A real check instead of an assert, which -O would strip
        if not 0 <= initial_size <= max_size:
            raise ValueError(
                f"initial_size must be between 0 and max_size ({max_size}), "
                f"got {initial_size}"
            )
        self.max_size = max_size
        # Tasks that find no free slot block on this channel until
        # release() sends them a wake-up token. The underlying Queue
        # treats a maxsize of 0 as "no limit", so sending a token
        # never blocks the releasing task
        self._channel: MemoryChannel = MemoryChannel(0)
        # Number of slots currently available
        self._counter: int = initial_size

    async def acquire(self):
        """
        Acquires the semaphore, possibly
        blocking if the task counter is
        exhausted
        """

        # Yield first: if this raises (e.g. on cancellation), the
        # counter has not been touched yet and no slot is leaked
        await checkpoint()
        # Loop rather than trust a single wake-up: another task may
        # take the freed slot before we get to run again
        while self._counter == 0:
            await self._channel.reader.receive()
        self._counter -= 1

    async def release(self):
        """
        Releases the semaphore if it was previously
        acquired by the caller. Raises RuntimeError
        if the semaphore is not acquired by anyone
        """

        if self._counter == self.max_size:
            raise RuntimeError("semaphore is not acquired")
        self._counter += 1
        if self._channel.readers():
            # Wake up one of the tasks blocked in acquire()
            await self._channel.writer.send(None)
        else:
            await checkpoint()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()
class Lock:
    """
    An asynchronous single-owner task lock
    """

    def __init__(self):
        # The task currently holding the lock, or None if it is free
        self.owner: Task | None = None
        # A single-slot semaphore does the actual blocking
        self._sem: Semaphore = Semaphore(1)

    async def acquire(self):
        """
        Acquires the lock, possibly
        blocking until it is available

        :raises RuntimeError: if the calling task already owns
            the lock. It would otherwise wait forever, since only
            the owner may release it
        """

        if self.owner is not None and self.owner is current_task():
            raise RuntimeError("lock is already acquired by the current task")
        await self._sem.acquire()
        self.owner = current_task()

    async def release(self):
        """
        Releases the lock if it was previously
        acquired by the caller. If the lock is
        not currently acquired or if it is not
        acquired by the calling task, RuntimeError
        is raised
        """

        # Explicit None check: don't rely on the truthiness of Task objects
        if self.owner is None:
            raise RuntimeError("lock is not acquired")
        if current_task() is not self.owner:
            raise RuntimeError("lock can only be released by the owner")
        self.owner = None
        await self._sem.release()

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()
class RLock(Lock):
    """
    An asynchronous single-owner recursive lock.
    Recursive locks have the property that their
    acquire() method can be called multiple times
    by the owner without deadlocking: each call
    increments an internal counter, which is decremented
    at every call to release(). The lock is released only
    when the internal counter reaches zero
    """

    def __init__(self):
        # Sets up owner and the underlying semaphore
        super().__init__()
        # How many times the current owner has acquired the lock
        # without releasing it yet
        self._acquire_count = 0

    async def acquire(self):
        """
        Acquires the lock, blocking only if it is held
        by a different task
        """

        if self.owner is not current_task():
            # The lock is free or held by someone else: go through
            # the regular, possibly blocking, acquisition path
            await super().acquire()
        else:
            await checkpoint()
        # Only count the acquisition once it has fully succeeded
        self._acquire_count += 1

    async def release(self):
        """
        Undoes one acquire() call. The lock itself is released
        once every acquire() by the owner has been matched

        :raises RuntimeError: if the lock is not acquired, or is
            acquired by a different task
        """

        if self.owner is None:
            raise RuntimeError("lock is not acquired")
        if current_task() is not self.owner:
            raise RuntimeError("lock can only be released by the owner")
        self._acquire_count -= 1
        if self._acquire_count == 0:
            await super().release()
        else:
            await checkpoint()