This repository has been archived on 2023-05-12. You can view files and clone it, but cannot push or open issues or pull requests.
aiosched/aiosched/sync.py

384 lines
9.5 KiB
Python

"""
aiosched: Yet another Python async scheduler
Copyright (C) 2022 nocturn9x
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import deque
from abc import ABC, abstractmethod
from typing import Any
from aiosched.errors import SchedulerError, ResourceClosed
from aiosched.internals.syscalls import (
suspend,
schedule,
current_task,
wait_readable,
)
from aiosched.task import Task
from aiosched.socket import wrap_socket
from socket import socketpair
class Event:
    """
    An asynchronous, non thread-safe event
    """

    def __init__(self):
        """
        Object constructor
        """
        # Whether trigger() has been called since creation (or the last reset)
        self.set = False
        # Tasks suspended in wait(), rescheduled by trigger()
        self.waiters = set()

    def reset(self):
        """
        Resets the event's state: it becomes unset and forgets
        any task that was registered as a waiter
        """
        self.__init__()

    async def trigger(self):
        """
        Sets the event, waking up all tasks that called
        wait() on it

        :raises SchedulerError: if the event has already been set
        """
        if self.set:
            raise SchedulerError("The event has already been set")
        self.set = True
        # Iterate over a snapshot: awaiting schedule() may hand control to
        # the scheduler, and a set must not change size while being iterated.
        # Woken tasks no longer need to be referenced by the event afterwards
        waiters = tuple(self.waiters)
        self.waiters.clear()
        for waiter in waiters:
            await schedule(waiter)

    async def wait(self) -> Any:
        """
        Waits until the event is set. Returns immediately
        if the event is already set
        """
        if self.set:
            # trigger() has already run and will not run again until a
            # reset, so suspending here would block this task forever
            return
        self.waiters.add(await current_task())
        await suspend()  # We get unsuspended by trigger()
class Queue:
"""
An asynchronous FIFO queue. As it is based
on events, it is not thread safe
"""
def __init__(self, maxsize: int | None = None):
"""
Object constructor
"""
self.maxsize = maxsize
# Stores event objects for tasks wanting to
# get items from the queue
self.getters = deque()
# Stores event objects for tasks wanting to
# put items on the queue
self.putters = deque()
self.container = deque()
def __len__(self):
"""
Returns the length of the queue
"""
return len(self.container)
def __repr__(self) -> str:
return f"{type(self).__name__}({f', '.join(map(str, self.container))})"
async def __aiter__(self):
"""
Implements the asynchronous iterator protocol
"""
return self
async def __anext__(self):
"""
Implements the asynchronous iterator protocol
"""
return await self.get()
async def put(self, item: Any):
"""
Pushes an element onto the queue. If the
queue is full, waits until there's
enough space for the queue
"""
if not self.maxsize or len(self.container) < self.maxsize:
self.container.append(item)
if self.getters:
await self.getters.popleft().trigger()
else:
ev = Event()
self.putters.append(ev)
await ev.wait()
self.container.append(item)
async def get(self) -> Any:
"""
Pops an element off the queue. Blocks until
an element is put onto it again if the queue
is empty
"""
if self.container:
if self.putters:
await self.putters.popleft().trigger()
else:
ev = Event()
self.getters.append(ev)
await ev.wait()
return self.container.popleft()
async def clear(self):
"""
Clears the queue
"""
self.container.clear()
async def reset(self):
"""
Resets the queue
"""
await self.clear()
self.getters.clear()
self.putters.clear()
class Channel(ABC):
"""
A generic, two-way, full-duplex communication channel
between tasks. This is just an abstract base class and
should not be instantiated directly. Please also note
that the read() and write() methods are not implemented
here because their signatures vary across subclasses
depending on the underlying communication mechanism
that is used. Implementors must provide those two methods
when subclassing Channel
"""
def __init__(self, maxsize: int | None = None):
"""
Public object constructor
"""
self.maxsize = maxsize
self.closed = False
@abstractmethod
async def close(self):
"""
Closes the memory channel. Any underlying
data is left for other tasks to read
"""
return NotImplemented
@abstractmethod
async def pending(self):
"""
Returns if there's pending
data to be read
"""
return NotImplemented
class MemoryChannel(Channel):
    """
    A two-way communication channel between tasks that
    never touches I/O: everything goes through an in-memory
    Queue, which makes it a very cheap way to hand data
    from one task to another. Not thread safe, just like
    the Queue it is built on
    """

    def __init__(self, maxsize: int | None = None):
        """
        Public object constructor

        :param maxsize: capacity of the underlying queue
            (None means unbounded)
        """
        super().__init__(maxsize)
        # The queue doubles as the channel's buffer
        self.buffer = Queue(maxsize=maxsize)

    async def write(self, data: str):
        """
        Puts data into the channel, blocking while the
        underlying queue is full. Once the channel has
        been closed, writes are silently ignored
        """
        if not self.closed:
            await self.buffer.put(data)

    async def read(self):
        """
        Takes the next message out of the channel,
        blocking until one is available
        """
        message = await self.buffer.get()
        return message

    async def close(self):
        """
        Marks the channel as closed. Data already in
        the buffer stays there and can still be read
        """
        self.closed = True

    async def pending(self):
        """
        Returns whether the buffer currently holds
        data waiting to be read
        """
        return len(self.buffer) > 0
class NetworkChannel(Channel):
    """
    A two-way communication channel between tasks
    that uses an underlying socket pair to communicate
    instead of in-memory queues. Not thread safe
    """

    def __init__(self):
        """
        Public object constructor
        """
        super().__init__(None)
        # We use a connected socket pair as our buffer instead of a
        # queue: data sent on self.writer is received on self.reader
        sockets = socketpair()
        self.reader = wrap_socket(sockets[0])
        self.writer = wrap_socket(sockets[1])
        self.reader.needs_closing = True
        self.writer.needs_closing = True

    async def write(self, data: bytes):
        """
        Writes data to the channel. Blocks if the internal
        socket is not currently available

        :raises ValueError: if the channel has been closed
        """
        if self.closed:
            raise ValueError("I/O operation on closed channel")
        await self.writer.send_all(data)

    async def read(self, size: int):
        """
        Reads exactly size bytes from the channel. Blocks until
        enough data arrives. Extra data is cached and used on the
        next read

        :raises ValueError: if the channel has been closed
        """
        if self.closed:
            raise ValueError("I/O operation on closed channel")
        return await self.reader.receive_exactly(size)

    async def close(self):
        """
        Closes the network channel and both ends of its
        internal socket pair. Any data still buffered in
        the socket pair is lost
        """
        self.closed = True
        try:
            await self.reader.close()
        finally:
            # Close the writing end even if closing the reader failed,
            # otherwise its socket would be leaked
            await self.writer.close()

    async def pending(self):
        """
        Returns whether there's pending data to be read.
        Returns False if the channel has been closed or its
        reading end is no longer usable
        """
        if self.closed:
            return False
        # NOTE(review): confirm `fileno` is an attribute on the wrapped
        # socket; if it is a method, this comparison is always False
        elif self.reader.fileno == -1:
            return False
        else:
            # NOTE(review): wait_readable presumably suspends until data
            # arrives rather than polling, so this call may block - confirm
            try:
                await wait_readable(self.reader.stream)
            except ResourceClosed:
                return False
            return True
class Lock:
    """
    A simple asynchronous single-owner lock.
    Not thread safe
    """

    def __init__(self):
        """
        Public constructor
        """
        # The task currently holding the lock, or None if it is free
        self.owner: Task | None = None
        # One event per task blocked in acquire(), in FIFO order
        self.tasks: deque[Event] = deque()
        # The tasks blocked in acquire(), kept in the same order as their
        # events in self.tasks so that release() knows who the next owner is
        self._waiting: deque[Task] = deque()

    async def acquire(self):
        """
        Acquires the lock, waiting in FIFO order
        if another task currently holds it

        :raises RuntimeError: if the current task already holds the lock
        """
        task = await current_task()
        if self.owner is None:
            self.owner = task
        elif task is self.owner:
            raise RuntimeError("lock is already acquired by current task")
        else:
            event = Event()
            self.tasks.append(event)
            self._waiting.append(task)
            # release() makes us the owner before triggering our event
            await event.wait()

    async def release(self):
        """
        Releases the lock, handing it directly to the
        longest-waiting task if there is one

        :raises RuntimeError: if the lock is not held, or is
            held by a task other than the current one
        """
        task = await current_task()
        if self.owner is None:
            raise RuntimeError("lock is not acquired")
        elif self.owner is not task:
            raise RuntimeError("lock can only be released by its owner")
        elif self.tasks:
            # Transfer ownership *before* waking the waiter: otherwise the
            # releasing task would still be recorded as the owner until the
            # waiter resumes, and could spuriously fail to re-acquire or
            # "release" again (waking a second waiter while one holds the lock)
            self.owner = self._waiting.popleft()
            await self.tasks.popleft().trigger()
        else:
            self.owner = None

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, *args):
        await self.release()