Fix various bugs with I/O and timeouts

This commit is contained in:
Mattia Giambirtone 2024-03-22 18:16:41 +01:00
parent 5071399431
commit 060d61dc32
9 changed files with 66 additions and 47 deletions

View File

@ -46,6 +46,16 @@ class SchedulingPolicy(ABC):
event loop
"""
@abstractmethod
def is_scheduled(self, task: Task) -> bool:
"""
Returns whether the given task is
scheduled to run. This doesn't
necessarily mean that the task will
actually get executed, just that the
policy knows about this task
"""
@abstractmethod
def has_next_task(self) -> bool:
"""

View File

@ -161,6 +161,7 @@ class TaskPool:
self.error = e
self.scope.cancel()
finally:
self.scope.__exit__(exc_type, exc_val, exc_tb)
current_loop().close_pool(self)
self._closed = True
if self.error:
@ -168,7 +169,10 @@ class TaskPool:
def done(self):
"""
Returns whether the task pool has finished executing
Returns whether the task pool's internal
task scope has finished executing. Note
that this does not take the scope's owner
into account!
"""
return self.scope.done()
@ -184,4 +188,7 @@ class TaskPool:
executing until it is picked by the scheduler later on
"""
if self._closed:
raise StructIOException("task pool is closed")
return current_loop().spawn(func, *args)

View File

@ -67,12 +67,10 @@ class DefaultKernel(BaseKernel):
return self.policy.get_closest_deadline()
def wait_readable(self, resource: FdWrapper):
self.current_task.state = TaskState.IO
self.current_task: Task
self.io_manager.request_read(resource, self.current_task)
def wait_writable(self, resource: FdWrapper):
self.current_task.state = TaskState.IO
self.current_task: Task
self.io_manager.request_write(resource, self.current_task)
@ -198,8 +196,7 @@ class DefaultKernel(BaseKernel):
data = self.handle(runner, self.current_task)
if data is not None:
method, args, kwargs = data
self.current_task.state = TaskState.PAUSED
self.current_task.paused_when = self.clock.current_time()
self.suspend()
if not callable(getattr(self, method, None)):
# This if block is meant to be triggered by other async
# libraries, which most likely have different method names and behaviors
@ -234,7 +231,6 @@ class DefaultKernel(BaseKernel):
self.cancel_task(self.current_task)
elif schedule:
self.current_task: Task
self.current_task.state = TaskState.READY
# We reschedule the caller immediately!
self.policy.schedule(self.current_task, front=True)
@ -272,6 +268,8 @@ class DefaultKernel(BaseKernel):
error.scope = scope
scope.cancel()
self.throw(scope.owner, error)
if not self.policy.is_scheduled(scope.owner):
self.reschedule(scope.owner)
def wakeup(self):
while (
@ -288,10 +286,9 @@ class DefaultKernel(BaseKernel):
def _tick(self):
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
self.throw(self.entry_point, KeyboardInterrupt())
if self.policy.has_next_task():
self.step()
self.wakeup()
self.check_scopes()
self.step()
if self.io_manager.pending():
self.io_manager.wait_io()
@ -450,9 +447,6 @@ class DefaultKernel(BaseKernel):
self.release(task)
def init_scope(self, scope: TaskScope):
if self.current_task is not self.current_scope.owner:
self.current_task: Task
self.current_scope.tasks.remove(self.current_task)
self.current_task.scope = scope
scope.deadline = self.clock.deadline(scope.timeout)
scope.owner = self.current_task

View File

@ -73,12 +73,12 @@ class SimpleIOManager(BaseIOManager):
current_time = kernel.clock.current_time()
deadline = kernel.get_closest_deadline()
if deadline == float("inf"):
# FIXME: This delay seems to help throttle the calls
# to this method. Should we be calling it this often?
deadline = 0.01
deadline = 0
elif deadline > 0:
deadline -= current_time
deadline = max(0, deadline)
# FIXME: This delay seems to help throttle the calls
# to this method. Should we be calling it this often?
deadline = max(0.01, deadline)
readers = self._collect_readers()
writers = self._collect_writers()
kernel.event("before_io", deadline)
@ -95,15 +95,17 @@ class SimpleIOManager(BaseIOManager):
writable.extend(exceptional)
del exceptional
for read_ready in readable:
for resource, task in self.readers.copy().items():
if resource.fileno() == read_ready and task.state == TaskState.IO:
kernel.reschedule(task)
self.readers.pop(resource)
wrapper = FdWrapper(read_ready)
task = self.readers[wrapper]
kernel.reschedule(task)
# We don't want to listen for read events on
# this resource anymore, so we release it
self.release(wrapper)
for write_ready in writable:
for resource, task in self.writers.copy().items():
if resource.fileno() == write_ready and task.state == TaskState.IO:
kernel.reschedule(task)
self.writers.pop(resource)
wrapper = FdWrapper(write_ready)
task = self.writers[wrapper]
kernel.reschedule(task)
self.release(wrapper)
def request_read(self, rsc: FdWrapper, task: Task):
self._check_closed()

View File

@ -16,6 +16,10 @@ class FIFOPolicy(SchedulingPolicy):
# Paused tasks along with their deadlines
self.paused: TimeQueue = TimeQueue()
def is_scheduled(self, task: Task) -> bool:
# TODO: This should be fine, make sure of it
return task.state == TaskState.READY
def has_next_task(self) -> bool:
return bool(self.run_queue)

View File

@ -45,10 +45,18 @@ class FdWrapper:
def fileno(self):
return self.fd
def __hash__(self):
return self.fd.__hash__()
# Can be converted to an int
def __int__(self):
return self.fd
def __eq__(self, other):
if not isinstance(other, FdWrapper):
return False
return self.fileno() == other.fileno()
def __repr__(self):
return f"<fd={self.fd!r}>"

View File

@ -385,18 +385,14 @@ class AsyncSocket(AsyncResource):
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
await checkpoint()
with self.write_lock, self.read_lock:
while True:
try:
self.socket.connect(address)
if self.do_handshake_on_connect:
await self.do_handshake()
await checkpoint()
break
except WantRead:
await wait_readable(self._fd)
except WantWrite:
await wait_writable(self._fd)
try:
self.socket.connect(address)
except WantWrite:
await wait_writable(self._fd)
if self.do_handshake_on_connect:
await self.do_handshake()
async def close(self):
"""

View File

@ -230,14 +230,12 @@ class MemoryReceiveChannel(ChannelReader):
def __init__(self, buffer):
self._buffer = buffer
self._closed = False
self._read_lock = ThereCanBeOnlyOne("another task is reading from this channel")
@enable_ki_protection
async def receive(self):
if self._closed:
raise ResourceClosed("cannot operate on a closed channel")
with self._read_lock:
return await self._buffer.get()
return await self._buffer.get()
@enable_ki_protection
async def close(self):
@ -260,14 +258,12 @@ class MemorySendChannel(ChannelWriter):
def __init__(self, buffer):
self._buffer = buffer
self._closed = False
self._write_lock = ThereCanBeOnlyOne("another task is writing to this channel")
@enable_ki_protection
async def send(self, item):
if self._closed:
raise ResourceClosed("cannot operate on a closed channel")
with self._write_lock:
return await self._buffer.put(item)
return await self._buffer.put(item)
@enable_ki_protection
async def close(self):

View File

@ -6,13 +6,15 @@ _print = print
def print(*args, **kwargs):
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
_print(*args, **kwargs)
_print(f"[{time.strftime('%H:%M:%S')}]", *args, **kwargs)
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False):
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False, secure: bool = False):
print(f"Attempting a connection to {host}:{port} {'in keep-alive mode' if keepalive else ''}")
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
if secure:
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
else:
socket = await structio.socket.connect_tcp_socket(host, port)
buffer = b""
print("Connected")
# Ensures the code below doesn't run for more than 5 seconds
@ -41,7 +43,7 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
print("Received empty stream, closing connection")
break
if buffer:
data = buffer.decode().split("\r\n")
data = buffer.decode(errors="ignore").split("\r\n")
print(
f"HTTP Response below {'(might be incomplete)' if scope.timed_out else ''}:"
)
@ -63,6 +65,6 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
_print("Done!")
structio.run(test, "google.com", 443, 256)
structio.run(test, "google.com", 80, 256)
# With keep-alive on, our timeout will kick in
structio.run(test, "google.com", 443, 256, True)
structio.run(test, "google.com", 80, 256, True)