Fixes and additions to the I/O mechanism; bugs still exist in the network_channel test

This commit is contained in:
Nocturn9x 2022-11-11 17:39:11 +01:00
parent 7b134f9a1d
commit acc436d518
10 changed files with 267 additions and 89 deletions

View File

@ -23,6 +23,7 @@ from aiosched.internals.syscalls import (
set_context,
close_context,
join,
current_task,
)
from typing import Any, Coroutine, Callable
@ -34,13 +35,13 @@ class TaskContext(Task):
an exception occurs. A TaskContext object behaves like
a regular task and the event loop treats it like a single
unit rather than a collection of tasks (in fact, the event
loop doesn't even know whether the current task is a task
context or not, which is by design). TaskContexts can be
nested and will cancel inner ones if an exception is raised
inside them
loop doesn't even know, nor care about, whether the current
task is a task context or not, which is by design). Contexts
can be nested and will cancel inner ones if an exception is
raised inside them
"""
def __init__(self, silent: bool = False, gather: bool = True) -> None:
def __init__(self, silent: bool = False, gather: bool = True, timeout: int | float = 0.0) -> None:
"""
Object constructor
"""
@ -49,13 +50,16 @@ class TaskContext(Task):
self.tasks: list[Task] = []
# Whether we have been cancelled or not
self.cancelled: bool = False
# The context's entry point (needed to forward run() calls and the like)
# The context's entry point (needed to disguise ourselves as a task ;))
self.entry_point: Task | TaskContext | None = None
# Do we ignore exceptions?
self.silent: bool = silent
# Do we gather multiple exceptions from
# child tasks?
self.gather: bool = gather
self.gather: bool = gather # TODO: Implement
# For how long do we allow tasks inside us
# to run?
self.timeout: int | float = timeout # TODO: Implement
async def spawn(
self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs
@ -78,6 +82,17 @@ class TaskContext(Task):
await set_context(self)
return self
def __eq__(self, other):
"""
Implements self == other
"""
if isinstance(other, TaskContext):
return super().__eq__(other)
elif isinstance(other, Task):
return other == self.entry_point
return False
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
"""
Implements the asynchronous context manager interface, waiting
@ -91,6 +106,11 @@ class TaskContext(Task):
# end of the block and wait for all
# children to exit
if task is self.entry_point:
# We don't wait on the entry
# point because that's us!
# Besides, even if we tried,
# wait() would raise an error
# to avoid a deadlock
continue
await wait(task)
except BaseException as exc:
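A minimal usage sketch of the context machinery this file implements, pieced together from the with_context()/spawn() calls in the test file at the bottom of this diff (illustration only, not part of the change):

import aiosched

async def child(n: int):
    await aiosched.sleep(n)

async def main():
    async with aiosched.with_context() as ctx:
        await ctx.spawn(child, 1)
        await ctx.spawn(child, 2)
    # Both children have exited by now: __aexit__ waits on every
    # task in the context except the entry point (i.e. main itself)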

View File

@ -103,10 +103,9 @@ async def checkpoint():
async def suspend():
"""
Suspends the current task. The task is not
rescheduled until some other event (for example
a timer, an event or an I/O operation) reschedules
it
Suspends the calling task indefinitely.
The task can be unsuspended by a timer,
an event or an incoming I/O operation
"""
await syscall("suspend")
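To illustrate the suspend()/schedule() pairing described above, here is a hypothetical one-shot flag, a stripped-down version of what the Event class in this codebase presumably does (not part of this commit; schedule(task) is assumed to reschedule the given task, as its import alongside suspend and current_task suggests):

from aiosched.internals.syscalls import suspend, schedule, current_task

class OneShotFlag:
    def __init__(self):
        self.set = False
        self.waiters = set()

    async def wait(self):
        if not self.set:
            self.waiters.add(await current_task())
            await suspend()  # parked until trigger() reschedules us

    async def trigger(self):
        self.set = True
        for waiter in self.waiters:
            await schedule(waiter)  # assumed: reschedules the task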
@ -125,7 +124,9 @@ async def join(task: Task):
"""
Tells the event loop that the current task
wants to wait on the given one, but without
waiting for its completion
waiting for its completion. This is a low
level trap and should not be used on its
own
"""
await syscall("join", task)
@ -140,7 +141,8 @@ async def wait(task: Task) -> Any | None:
Returns immediately if the task has
completed already, but exceptions are
propagated only once. Returns the task's
return value, if it has one
return value, if it has one (the value is
returned on every call, unlike exceptions).
:param task: The task to wait for
:type task: :class: Task
@ -148,7 +150,10 @@ async def wait(task: Task) -> Any | None:
"""
current = await current_task()
if task is current:
if task == current:
# We don't do an "x is y" check because
# tasks and task contexts can compare equal
# despite having different memory addresses
raise SchedulerError("a task cannot join itself")
if current not in task.joiners:
# Luckily we use a set, so this has O(1)
@ -156,6 +161,8 @@ async def wait(task: Task) -> Any | None:
await join(task) # Waiting implies joining!
await syscall("wait", task)
if task.exc and task.state != TaskState.CANCELLED and task.propagate:
# Task raised an error that wasn't directly caused by a cancellation:
# raise it, but do so only the first time wait was called
task.propagate = False
raise task.exc
return task.result
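A hypothetical illustration of the "propagated only once" behavior documented above, using the wait() defined here (and assuming ctx.spawn() returns the spawned Task, as its use in the test file suggests):

import aiosched

async def fails():
    raise ValueError("oops")

async def main():
    async with aiosched.with_context() as ctx:
        task = await ctx.spawn(fails)
        try:
            await wait(task)   # first call: propagates ValueError
        except ValueError:
            pass
        await wait(task)       # second call: just returns task.result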

View File

@ -17,11 +17,10 @@ limitations under the License.
"""
import socket
import ssl
import warnings
import os
import aiosched
from aiosched.errors import ResourceClosed
from aiosched.errors import ResourceClosed, ResourceBroken
from aiosched.internals.syscalls import (
wait_writable,
wait_readable,
@ -99,8 +98,8 @@ class AsyncStream:
await io_release(self.stream)
self.stream.close()
self.stream = None
await aiosched.checkpoint()
@property
async def fileno(self):
"""
Wrapper socket method
@ -132,7 +131,7 @@ class AsyncStream:
this directly: stuff will break
"""
if self._fd != -1:
if self._fd != -1 and self.stream.fileno() != -1:
try:
os.set_blocking(self._fd, False)
os.close(self._fd)
@ -153,11 +152,18 @@ class AsyncSocket(AsyncStream):
close_on_context_exit: bool = True,
do_handshake_on_connect: bool = True,
):
super().__init__(
sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit
)
# Do we perform the TCP handshake automatically
# upon connection? This is mostly needed for SSL
# sockets
self.do_handshake_on_connect = do_handshake_on_connect
self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto)
# Do we close ourselves upon the end of a context manager?
self.close_on_context_exit = close_on_context_exit
# The socket.fromfd function copies the file descriptor
# instead of using the same one, so we'd be trying to close
# a different resource if we used sock.fileno() instead
# of self.stream.fileno() as our file descriptor
self.stream = socket.fromfd(sock.fileno(), sock.family, sock.type, sock.proto)
self._fd = self.stream.fileno()
self.stream.setblocking(False)
# A socket that isn't connected doesn't
# need to be closed
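The fromfd behavior the comment above describes can be verified with the standard library alone:

import socket

a, b = socket.socketpair()
c = socket.fromfd(a.fileno(), a.family, a.type, a.proto)
assert c.fileno() != a.fileno()  # fromfd dup()s the descriptor
for s in (a, b, c):
    s.close()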
@ -179,6 +185,21 @@ class AsyncSocket(AsyncStream):
except WriteBlock:
await wait_writable(self.stream)
async def receive_exactly(self, size: int, flags: int = 0) -> bytes:
"""
Receives exactly size bytes from a socket asynchronously.
"""
# https://stackoverflow.com/questions/55825905/how-can-i-reliably-read-exactly-n-bytes-from-a-tcp-socket
buf = bytearray(size)
pos = 0
while pos < size:
n = await self.recv_into(memoryview(buf)[pos:], flags=flags)
if n == 0:
raise ResourceBroken("incomplete read detected")
pos += n
return bytes(buf)
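A hypothetical use of the new receive_exactly() helper, implementing a simple length-prefixed framing scheme (the protocol here is made up for illustration):

import struct

async def read_frame(sock) -> bytes:
    header = await sock.receive_exactly(4)     # 4-byte big-endian length
    (length,) = struct.unpack(">I", header)
    return await sock.receive_exactly(length)  # then the payload itself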
async def connect(self, address):
"""
Wrapper socket method
@ -240,6 +261,8 @@ class AsyncSocket(AsyncStream):
Wrapper socket method
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
if self.stream:
self.stream.shutdown(how)
await aiosched.checkpoint()
@ -320,6 +343,19 @@ class AsyncSocket(AsyncStream):
except WriteBlock:
await wait_writable(self.stream)
async def recv_into(self, buffer, nbytes=0, flags=0):
"""
Wrapper socket method
"""
while True:
try:
return self.stream.recv_into(buffer, nbytes, flags)
except ReadBlock:
await wait_readable(self.stream)
except WriteBlock:
await wait_writable(self.stream)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
"""
Wrapper socket method

View File

@ -113,7 +113,25 @@ class FIFOKernel:
to do
"""
return not any([self.paused, self.run_ready, self.selector.get_map()])
if self.current_task and not self.current_task.done():
# Current task isn't done yet!
return False
if any([self.paused, self.run_ready]):
# There's tasks sleeping and/or on the
# ready queue!
return False
for key in self.selector.get_map().values():
# We don't just do any([self.paused, self.run_ready, self.selector.get_map()])
# because we don't want to just know if there's any resources we're waiting on,
# but if there's at least one non-terminated task that owns a resource we're
# waiting on. This avoids issues such as the event loop never exiting if the
# user forgets to close a socket, for example
key.data: Task
if key.data.done():
continue
elif self.get_task_io(key.data):
return False
return True
def close(self, force: bool = False):
"""
@ -159,16 +177,30 @@ class FIFOKernel:
timeout = 0.0
if self.run_ready:
# If there is work to do immediately (tasks to run) we
# can't wait
# can't wait.
# TODO: This could cause I/O starvation in highly concurrent
# environments: maybe a more convoluted scheduling strategy
# where I/O timeouts can only be skipped n times before a
# mandatory x-second timeout occurs is needed? It should of
# course take deadlines into account so that timeouts are
always delivered in a timely manner and tasks wake from
sleep at the right moment
timeout = 0.0
elif self.paused:
# If there are asleep tasks or deadlines, wait until the closest date
timeout = self.paused.get_closest_deadline()
timeout = self.paused.get_closest_deadline() - self.clock()
self.debugger.before_io(timeout)
io_ready = self.selector.select(timeout)
# Get sockets that are ready and schedule their tasks
for key, _ in io_ready:
self.run_ready.append(key.data) # Resource ready? Schedule its task
for key, _ in self.selector.select(timeout):
key.data: Task
if key.data.state == TaskState.IO:
# We don't reschedule a task that wasn't
# blocking on I/O before: this way if a
# task waits on a socket and then goes to
# sleep, it won't be woken up early if the
# resource becomes available before its
# deadline expires
self.run_ready.append(key.data) # Resource ready? Schedule its task
self.debugger.after_io(self.clock() - before_time)
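The TODO above is not implemented in this commit; one possible shape for it, as a method sketch for FIFOKernel with all names and thresholds hypothetical (and assuming a self._io_skips counter initialized to 0 in the constructor), might be:

MAX_IO_SKIPS = 5          # hypothetical: consecutive zero-timeout polls allowed
FORCED_IO_TIMEOUT = 0.05  # hypothetical: mandatory poll window

def compute_io_timeout(self) -> float:
    timeout = 0.0
    if self.run_ready and self._io_skips < MAX_IO_SKIPS:
        self._io_skips += 1          # fast path: don't block on I/O
    elif self.run_ready:
        self._io_skips = 0
        timeout = FORCED_IO_TIMEOUT  # force a poll to avoid starvation
    elif self.paused:
        timeout = self.paused.get_closest_deadline() - self.clock()
    if self.paused:
        # Deadlines always win over the forced timeout
        timeout = min(timeout, self.paused.get_closest_deadline() - self.clock())
    return max(timeout, 0.0)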
def awake_tasks(self):
@ -220,9 +252,9 @@ class FIFOKernel:
our primitives or async methods.
Note that this method does NOT catch any
exception arising from tasks, nor does it
take StopIteration or CancelledError into
account: that's the job for run()!
errors arising from tasks, nor does it take
StopIteration or Cancelled exceptions into
account
"""
# Sets the currently running task
@ -253,12 +285,12 @@ class FIFOKernel:
)
if not hasattr(self, method) or not callable(getattr(self, method)):
# This if block is meant to be triggered by other async
# libraries, which most likely have different trap names and behaviors
# libraries, which most likely have different method names and behaviors
# compared to us. If you get this exception, and you're 100% sure you're
# not mixing async primitives from other libraries, then it's a bug!
self.current_task.throw(
InternalError(
"Uh oh! Something very bad just happened, did you try to mix primitives from other async libraries?"
"Uh oh! Something bad just happened: did you try to mix primitives from other async libraries?"
)
)
# Sneaky method call, thanks to David Beazley for this ;)
@ -321,7 +353,8 @@ class FIFOKernel:
and self.entry_point.propagate
):
# Contexts already manage exceptions for us,
# no need to raise it manually
# no need to raise it manually. If a context
# is not used, *then* we can raise the error
raise self.entry_point.exc
return self.entry_point.result
@ -334,6 +367,7 @@ class FIFOKernel:
if self.selector.get_map() and resource in self.selector.get_map():
self.selector.unregister(resource)
self.debugger.on_io_unschedule(resource)
def io_release_task(self, task: Task):
"""
@ -348,6 +382,14 @@ class FIFOKernel:
self.selector.unregister(key.fileobj)
task.last_io = ()
def get_task_io(self, task: Task) -> list:
"""
Returns the streams currently in use by
the given task
"""
return [key.fileobj for key in self.selector.get_map().values() if key.data == task]
def notify_closing(self, stream, broken: bool = False):
"""
Notifies paused tasks that a stream
@ -452,6 +494,7 @@ class FIFOKernel:
self.paused.discard(task)
self.io_release_task(task)
self.run_ready.extend(task.joiners)
self.reschedule_running()
def join(self, task: Task):
"""
@ -491,6 +534,7 @@ class FIFOKernel:
ctx.tasks.append(ctx.entry_point)
self.current_task.context = ctx
self.current_task = ctx
self.debugger.on_context_creation(ctx)
self.reschedule_running()
def close_context(self, ctx: TaskContext):
@ -498,6 +542,7 @@ class FIFOKernel:
Closes the given context
"""
self.debugger.on_context_exit(ctx)
task = ctx.entry_point
task.context = None
self.current_task = task
@ -547,12 +592,14 @@ class FIFOKernel:
# If the event to listen for has changed we just modify it
self.selector.modify(resource, evt_type, self.current_task)
self.current_task.last_io = (evt_type, resource)
self.debugger.on_io_schedule(resource, evt_type)
elif not self.current_task.last_io or self.current_task.last_io[1] != resource:
# The task has either registered a new resource or is doing
# I/O for the first time
self.current_task.last_io = evt_type, resource
try:
self.selector.register(resource, evt_type, self.current_task)
self.debugger.on_io_schedule(resource, evt_type)
except KeyError:
# The stream is already being used
key = self.selector.get_key(resource)
@ -565,6 +612,7 @@ class FIFOKernel:
# off a given stream while another one is
# writing to it
self.selector.modify(resource, evt_type, self.current_task)
self.debugger.on_io_schedule(resource, evt_type)
else:
# One task reading and one writing on the same
# resource is fine (think producer-consumer),

View File

@ -18,11 +18,12 @@ limitations under the License.
from collections import deque
from abc import ABC, abstractmethod
from typing import Any
from aiosched.errors import SchedulerError
from aiosched.errors import SchedulerError, ResourceClosed
from aiosched.internals.syscalls import (
suspend,
schedule,
current_task,
wait_readable,
)
from aiosched.task import Task
from aiosched.socket import wrap_socket
@ -72,7 +73,8 @@ class Event:
class Queue:
"""
An asynchronous FIFO queue. Not thread safe
An asynchronous FIFO queue. As it is based
on events, it is not thread safe
"""
def __init__(self, maxsize: int | None = None):
@ -167,7 +169,12 @@ class Channel(ABC):
"""
A generic, two-way, full-duplex communication channel
between tasks. This is just an abstract base class and
should not be instantiated directly
should not be instantiated directly. Please also note
that the read() and write() methods are not implemented
here because their signatures vary across subclasses
depending on the underlying communication mechanism
that is used. Implementors must provide those two methods
when subclassing Channel
"""
def __init__(self, maxsize: int | None = None):
@ -178,26 +185,6 @@ class Channel(ABC):
self.maxsize = maxsize
self.closed = False
@abstractmethod
async def write(self, data: str):
"""
Writes data to the channel. Blocks if the internal
queue is full until a spot is available. Does nothing
if the channel has been closed
"""
return NotImplemented
@abstractmethod
async def read(self):
"""
Reads data from the channel. Blocks until
a message arrives or returns immediately if
one is already waiting
"""
return NotImplemented
@abstractmethod
async def close(self):
"""
@ -220,9 +207,11 @@ class Channel(ABC):
class MemoryChannel(Channel):
"""
A two-way communication channel between tasks.
Operations on this object do not perform any I/O
or other system call and are therefore extremely
efficient. Not thread safe
Operations on this object are based on the Queue
class and do not involve any I/O, making this
an extremely efficient way to pass data between
tasks. Since it is based on queues, this channel
is not thread safe
"""
def __init__(self, maxsize: int | None = None):
@ -288,7 +277,8 @@ class NetworkChannel(Channel):
sockets = socketpair()
self.reader = wrap_socket(sockets[0])
self.writer = wrap_socket(sockets[1])
self._internal_buffer = b""
self.reader.needs_closing = True
self.writer.needs_closing = True
async def write(self, data: bytes):
"""
@ -298,7 +288,7 @@ class NetworkChannel(Channel):
"""
if self.closed:
return
raise ValueError("I/O operation on closed channel")
await self.writer.send_all(data)
async def read(self, size: int):
@ -308,12 +298,9 @@ class NetworkChannel(Channel):
next read
"""
data = self._internal_buffer
while len(data) < size:
data += await self.reader.receive(size)
self._internal_buffer = data[size:]
data = data[:size]
return data
if self.closed:
raise ValueError("I/O operation on closed channel")
return await self.reader.receive_exactly(size)
async def close(self):
"""
@ -332,13 +319,15 @@ class NetworkChannel(Channel):
data to be read
"""
# TODO: Ugly!
if self.closed:
return False
try:
self._internal_buffer += self.reader.stream.recv(1)
except BlockingIOError:
elif await self.reader.fileno == -1:
return False
else:
try:
await wait_readable(self.reader.stream)
except ResourceClosed:
return False
return True
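To make the Channel contract described above concrete, a hypothetical minimal subclass (illustration only; names are made up, and MemoryChannel in this file is the real in-memory implementation):

from collections import deque

class DequeChannel(Channel):
    def __init__(self, maxsize: int | None = None):
        super().__init__(maxsize)
        self.buffer = deque(maxlen=maxsize)

    async def write(self, data):
        if self.closed:
            raise ValueError("I/O operation on closed channel")
        self.buffer.append(data)

    async def read(self):
        return self.buffer.popleft()

    async def close(self):
        self.closed = True

    async def pending(self):
        return bool(self.buffer)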

View File

@ -135,5 +135,7 @@ class Task:
Task destructor
"""
if not self.done():
warnings.warn(f"task '{self.name}' was destroyed, but it has not completed yet")
if self.last_io:
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")
warnings.warn(f"task '{self.name}' was destroyed, but it has pending I/O")

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from aiosched.task import Task
from aiosched.context import TaskContext
class BaseDebugger(ABC):
@ -192,3 +193,52 @@ class BaseDebugger(ABC):
"""
return NotImplemented
@abstractmethod
def on_context_creation(self, ctx: TaskContext):
"""
This method is called right after a task
context is initialized, i.e. when set_context
in the event loop is called
:param ctx: The context object
:type ctx: TaskContext
:return:
"""
return NotImplemented
@abstractmethod
def on_context_exit(self, ctx: TaskContext):
"""
This method is called right before a task
context is closed, i.e. when close_context
in the event loop is called
:param ctx: The context object
:type ctx: TaskContext
:return:
"""
return NotImplemented
@abstractmethod
def on_io_schedule(self, stream, event: int):
"""
This method is called whenever the
perform_io primitive is called within
the aiosched event loop with the stream
to be registered in the selector and the
chosen event mask
"""
return NotImplemented
@abstractmethod
def on_io_unschedule(self, stream):
"""
This method is called whenever a stream
is unregistered from the loop's I/O selector
"""
return NotImplemented

View File

@ -75,6 +75,7 @@ async def handler(sock: aiosched.socket.AsyncSocket):
logging.info(f"Connection from {address} closed")
clients.pop(sock)
names.discard(name)
logging.info("Handler shutting down")
if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
from aiosched.util.debugging import BaseDebugger
from selectors import EVENT_READ, EVENT_WRITE
class Debugger(BaseDebugger):
@ -51,3 +52,22 @@ class Debugger(BaseDebugger):
def on_exception_raised(self, task, exc):
print(f"== '{task.name}' raised {repr(exc)}")
def on_context_creation(self, ctx):
print(f"=> A new context was created by {ctx.entry_point.name!r}")
def on_context_exit(self, ctx):
print(f"=> A context was closed by {ctx.entry_point.name}")
def on_io_schedule(self, stream, event: int):
evt = ""
if event == EVENT_READ:
evt = "reading"
elif event == EVENT_WRITE:
evt = "writing"
elif event == EVENT_WRITE | EVENT_READ:
evt = "reading or writing"
print(f"|| Stream {stream!r} was scheduled for {evt}")
def on_io_unschedule(self, stream):
print(f"|| Stream {stream!r} was unscheduled")

View File

@ -2,30 +2,35 @@ import aiosched
from debugger import Debugger
async def sender(c: aiosched.NetworkChannel, n: int):
async def producer(c: aiosched.NetworkChannel, n: int):
print("[producer] Started")
for i in range(n):
await c.write(str(i).encode())
print(f"Sent {i}")
await c.close()
print("Sender done")
print(f"[producer] Sent {i}")
await aiosched.sleep(0.5) # This makes the receiver wait on us!
# await c.close()
print("[producer] Done")
async def receiver(c: aiosched.NetworkChannel):
while True:
if not await c.pending() and c.closed:
print("Receiver done")
break
item = (await c.read(1)).decode()
print(f"Received {item}")
await aiosched.sleep(1)
async def consumer(c: aiosched.NetworkChannel):
print("[receiver] Started")
try:
while await c.pending():
item = await c.read(1)
print(f"[consumer] Received {item.decode()}")
# await aiosched.sleep(2) # If you uncomment this, the except block will be triggered
except aiosched.errors.ResourceClosed:
print("[consumer] Stream has been closed early!")
print("[consumer] Done")
async def main(channel: aiosched.NetworkChannel, n: int):
print("Starting sender and receiver")
t = aiosched.clock()
print("[main] Starting children")
async with aiosched.with_context() as ctx:
await ctx.spawn(sender, channel, n)
await ctx.spawn(receiver, channel)
print("All done!")
await ctx.spawn(consumer, channel)
await ctx.spawn(producer, channel, n)
print(f"[main] All done in {aiosched.clock() - t:.2f} seconds")
aiosched.run(