Fix various bugs with I/O and timeouts
This commit is contained in:
parent
5071399431
commit
060d61dc32
|
@ -46,6 +46,16 @@ class SchedulingPolicy(ABC):
|
|||
event loop
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_scheduled(self, task: Task) -> bool:
|
||||
"""
|
||||
Returns whether the given task is
|
||||
scheduled to run. This doesn't
|
||||
necessarily mean that the task will
|
||||
actually get executed, just that the
|
||||
policy knows about this task
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def has_next_task(self) -> bool:
|
||||
"""
|
||||
|
|
|
@ -161,6 +161,7 @@ class TaskPool:
|
|||
self.error = e
|
||||
self.scope.cancel()
|
||||
finally:
|
||||
self.scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
current_loop().close_pool(self)
|
||||
self._closed = True
|
||||
if self.error:
|
||||
|
@ -168,7 +169,10 @@ class TaskPool:
|
|||
|
||||
def done(self):
|
||||
"""
|
||||
Returns whether the task pool has finished executing
|
||||
Returns whether the task pool's internal
|
||||
task scope has finished executing. Note
|
||||
that this does not take the scope's owner
|
||||
into account!
|
||||
"""
|
||||
|
||||
return self.scope.done()
|
||||
|
@ -184,4 +188,7 @@ class TaskPool:
|
|||
executing until it is picked by the scheduler later on
|
||||
"""
|
||||
|
||||
if self._closed:
|
||||
raise StructIOException("task pool is closed")
|
||||
|
||||
return current_loop().spawn(func, *args)
|
||||
|
|
|
@ -67,12 +67,10 @@ class DefaultKernel(BaseKernel):
|
|||
return self.policy.get_closest_deadline()
|
||||
|
||||
def wait_readable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.current_task: Task
|
||||
self.io_manager.request_read(resource, self.current_task)
|
||||
|
||||
def wait_writable(self, resource: FdWrapper):
|
||||
self.current_task.state = TaskState.IO
|
||||
self.current_task: Task
|
||||
self.io_manager.request_write(resource, self.current_task)
|
||||
|
||||
|
@ -198,8 +196,7 @@ class DefaultKernel(BaseKernel):
|
|||
data = self.handle(runner, self.current_task)
|
||||
if data is not None:
|
||||
method, args, kwargs = data
|
||||
self.current_task.state = TaskState.PAUSED
|
||||
self.current_task.paused_when = self.clock.current_time()
|
||||
self.suspend()
|
||||
if not callable(getattr(self, method, None)):
|
||||
# This if block is meant to be triggered by other async
|
||||
# libraries, which most likely have different method names and behaviors
|
||||
|
@ -234,7 +231,6 @@ class DefaultKernel(BaseKernel):
|
|||
self.cancel_task(self.current_task)
|
||||
elif schedule:
|
||||
self.current_task: Task
|
||||
self.current_task.state = TaskState.READY
|
||||
# We reschedule the caller immediately!
|
||||
self.policy.schedule(self.current_task, front=True)
|
||||
|
||||
|
@ -272,6 +268,8 @@ class DefaultKernel(BaseKernel):
|
|||
error.scope = scope
|
||||
scope.cancel()
|
||||
self.throw(scope.owner, error)
|
||||
if not self.policy.is_scheduled(scope.owner):
|
||||
self.reschedule(scope.owner)
|
||||
|
||||
def wakeup(self):
|
||||
while (
|
||||
|
@ -288,10 +286,9 @@ class DefaultKernel(BaseKernel):
|
|||
def _tick(self):
|
||||
if self._sigint_handled and not self.restrict_ki_to_checkpoints:
|
||||
self.throw(self.entry_point, KeyboardInterrupt())
|
||||
if self.policy.has_next_task():
|
||||
self.step()
|
||||
self.wakeup()
|
||||
self.check_scopes()
|
||||
self.step()
|
||||
if self.io_manager.pending():
|
||||
self.io_manager.wait_io()
|
||||
|
||||
|
@ -450,9 +447,6 @@ class DefaultKernel(BaseKernel):
|
|||
self.release(task)
|
||||
|
||||
def init_scope(self, scope: TaskScope):
|
||||
if self.current_task is not self.current_scope.owner:
|
||||
self.current_task: Task
|
||||
self.current_scope.tasks.remove(self.current_task)
|
||||
self.current_task.scope = scope
|
||||
scope.deadline = self.clock.deadline(scope.timeout)
|
||||
scope.owner = self.current_task
|
||||
|
|
|
@ -73,12 +73,12 @@ class SimpleIOManager(BaseIOManager):
|
|||
current_time = kernel.clock.current_time()
|
||||
deadline = kernel.get_closest_deadline()
|
||||
if deadline == float("inf"):
|
||||
# FIXME: This delay seems to help throttle the calls
|
||||
# to this method. Should we be calling it this often?
|
||||
deadline = 0.01
|
||||
deadline = 0
|
||||
elif deadline > 0:
|
||||
deadline -= current_time
|
||||
deadline = max(0, deadline)
|
||||
# FIXME: This delay seems to help throttle the calls
|
||||
# to this method. Should we be calling it this often?
|
||||
deadline = max(0.01, deadline)
|
||||
readers = self._collect_readers()
|
||||
writers = self._collect_writers()
|
||||
kernel.event("before_io", deadline)
|
||||
|
@ -95,15 +95,17 @@ class SimpleIOManager(BaseIOManager):
|
|||
writable.extend(exceptional)
|
||||
del exceptional
|
||||
for read_ready in readable:
|
||||
for resource, task in self.readers.copy().items():
|
||||
if resource.fileno() == read_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.readers.pop(resource)
|
||||
wrapper = FdWrapper(read_ready)
|
||||
task = self.readers[wrapper]
|
||||
kernel.reschedule(task)
|
||||
# We don't want to listen for read events on
|
||||
# this resource anymore, so we release it
|
||||
self.release(wrapper)
|
||||
for write_ready in writable:
|
||||
for resource, task in self.writers.copy().items():
|
||||
if resource.fileno() == write_ready and task.state == TaskState.IO:
|
||||
kernel.reschedule(task)
|
||||
self.writers.pop(resource)
|
||||
wrapper = FdWrapper(write_ready)
|
||||
task = self.writers[wrapper]
|
||||
kernel.reschedule(task)
|
||||
self.release(wrapper)
|
||||
|
||||
def request_read(self, rsc: FdWrapper, task: Task):
|
||||
self._check_closed()
|
||||
|
|
|
@ -16,6 +16,10 @@ class FIFOPolicy(SchedulingPolicy):
|
|||
# Paused tasks along with their deadlines
|
||||
self.paused: TimeQueue = TimeQueue()
|
||||
|
||||
def is_scheduled(self, task: Task) -> bool:
|
||||
# TODO: This should be fine, make sure of it
|
||||
return task.state == TaskState.READY
|
||||
|
||||
def has_next_task(self) -> bool:
|
||||
return bool(self.run_queue)
|
||||
|
||||
|
|
|
@ -45,10 +45,18 @@ class FdWrapper:
|
|||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
def __hash__(self):
|
||||
return self.fd.__hash__()
|
||||
|
||||
# Can be converted to an int
|
||||
def __int__(self):
|
||||
return self.fd
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, FdWrapper):
|
||||
return False
|
||||
return self.fileno() == other.fileno()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<fd={self.fd!r}>"
|
||||
|
||||
|
|
|
@ -385,18 +385,14 @@ class AsyncSocket(AsyncResource):
|
|||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
await checkpoint()
|
||||
with self.write_lock, self.read_lock:
|
||||
while True:
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
await checkpoint()
|
||||
break
|
||||
except WantRead:
|
||||
await wait_readable(self._fd)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
try:
|
||||
self.socket.connect(address)
|
||||
except WantWrite:
|
||||
await wait_writable(self._fd)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
|
|
|
@ -230,14 +230,12 @@ class MemoryReceiveChannel(ChannelReader):
|
|||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._read_lock = ThereCanBeOnlyOne("another task is reading from this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def receive(self):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._read_lock:
|
||||
return await self._buffer.get()
|
||||
return await self._buffer.get()
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
|
@ -260,14 +258,12 @@ class MemorySendChannel(ChannelWriter):
|
|||
def __init__(self, buffer):
|
||||
self._buffer = buffer
|
||||
self._closed = False
|
||||
self._write_lock = ThereCanBeOnlyOne("another task is writing to this channel")
|
||||
|
||||
@enable_ki_protection
|
||||
async def send(self, item):
|
||||
if self._closed:
|
||||
raise ResourceClosed("cannot operate on a closed channel")
|
||||
with self._write_lock:
|
||||
return await self._buffer.put(item)
|
||||
return await self._buffer.put(item)
|
||||
|
||||
@enable_ki_protection
|
||||
async def close(self):
|
||||
|
|
|
@ -6,13 +6,15 @@ _print = print
|
|||
|
||||
|
||||
def print(*args, **kwargs):
|
||||
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
|
||||
_print(*args, **kwargs)
|
||||
_print(f"[{time.strftime('%H:%M:%S')}]", *args, **kwargs)
|
||||
|
||||
|
||||
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False):
|
||||
async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = False, secure: bool = False):
|
||||
print(f"Attempting a connection to {host}:{port} {'in keep-alive mode' if keepalive else ''}")
|
||||
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
|
||||
if secure:
|
||||
socket = await structio.socket.connect_tcp_ssl_socket(host, port)
|
||||
else:
|
||||
socket = await structio.socket.connect_tcp_socket(host, port)
|
||||
buffer = b""
|
||||
print("Connected")
|
||||
# Ensures the code below doesn't run for more than 5 seconds
|
||||
|
@ -41,7 +43,7 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
|
|||
print("Received empty stream, closing connection")
|
||||
break
|
||||
if buffer:
|
||||
data = buffer.decode().split("\r\n")
|
||||
data = buffer.decode(errors="ignore").split("\r\n")
|
||||
print(
|
||||
f"HTTP Response below {'(might be incomplete)' if scope.timed_out else ''}:"
|
||||
)
|
||||
|
@ -63,6 +65,6 @@ async def test(host: str, port: int, bufsize: int = 4096, keepalive: bool = Fals
|
|||
_print("Done!")
|
||||
|
||||
|
||||
structio.run(test, "google.com", 443, 256)
|
||||
structio.run(test, "google.com", 80, 256)
|
||||
# With keep-alive on, our timeout will kick in
|
||||
structio.run(test, "google.com", 443, 256, True)
|
||||
structio.run(test, "google.com", 80, 256, True)
|
Loading…
Reference in New Issue