Several fixes to nested pools, cancellation, timeouts and more. Fixed SSL I/O (WIP)

This commit is contained in:
nocturn9x 2021-08-28 23:26:24 +02:00
parent b9ed99e3ee
commit 0b8e1487c7
17 changed files with 324 additions and 122 deletions

25
a.py Normal file
View File

@ -0,0 +1,25 @@
import giambio
import socket as sock
import ssl
async def test(host: str, port: int):
socket = giambio.socket.wrap_socket(ssl.wrap_socket(sock.socket()))
await socket.connect((host, port))
async with socket:
await socket.send_all(b"""GET / HTTP/1.1\r
Host: google.com\r
User-Agent: owo\r
Accept: text/html\r
Connection: keep-alive\r\n\r\n""")
buffer = b""
while True:
data = await socket.receive(4096)
if data:
buffer += data
else:
break
print("\n".join(buffer.decode().split("\r\n")))
giambio.run(test, "google.com", 443)

View File

@ -23,7 +23,7 @@ __version__ = (0, 0, 1)
from . import exceptions, socket, context, core, task, io
from .traps import sleep, current_task
from .sync import Event
from .run import run, clock, create_pool, get_event_loop, new_event_loop, with_timeout
from .run import run, clock, create_pool, get_event_loop, new_event_loop, with_timeout, skip_after
from .util import debug
@ -41,4 +41,5 @@ __all__ = [
"current_task",
"new_event_loop",
"debug",
"skip_after"
]

View File

@ -27,9 +27,12 @@ class TaskManager:
:param timeout: The pool's timeout length in seconds, if any, defaults to None
:type timeout: float, optional
:param raise_on_timeout: Whether to catch a TooSlowError exception when the pool's
timeout expires or not. Defaults to True
:type raise_on_timeout: bool, optional
"""
def __init__(self, timeout: float = None) -> None:
def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None:
"""
Object constructor
"""
@ -45,21 +48,21 @@ class TaskManager:
if timeout:
self.timeout: float = self.started + timeout
else:
self.timeout: None = None
self.timeout = None
# Whether our timeout expired or not
self.timed_out: bool = False
self._proper_init = False
self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# giambio.get_event_loop().current_pool = self
self.raise_on_timeout: bool = raise_on_timeout
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
"""
Spawns a child task
"""
assert self._proper_init, "Cannot use improperly initialized pool"
return await giambio.traps.create_task(func, self, *args)
self.tasks.append(await giambio.traps.create_task(func, self, *args, **kwargs))
return self.tasks[-1]
async def __aenter__(self):
"""
@ -80,7 +83,10 @@ class TaskManager:
# end of the block and wait for all
# children to exit
await task.join()
self.tasks.remove(task)
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
async def cancel(self):
"""
@ -91,6 +97,7 @@ class TaskManager:
# TODO: This breaks, somehow, investigation needed
for task in self.tasks:
await task.cancel()
self.tasks.remove(task)
def done(self) -> bool:
"""
@ -98,4 +105,4 @@ class TaskManager:
pool have exited, False otherwise
"""
return all([task.done() for task in self.tasks])
return self._proper_init and all([task.done() for task in self.tasks])

View File

@ -19,6 +19,7 @@ limitations under the License.
# Import libraries and internal resources
import types
from giambio.task import Task
from collections import deque
from timeit import default_timer
from giambio.context import TaskManager
from typing import List, Optional, Any, Dict
@ -99,7 +100,7 @@ class AsyncScheduler:
# All tasks the loop has
self.tasks: List[Task] = []
# Tasks that are ready to run
self.run_ready: List[Task] = []
self.run_ready: deque = deque()
# Selector object to perform I/O multiplexing
self.selector = selector or DefaultSelector()
# This will always point to the currently running coroutine (Task object)
@ -123,7 +124,7 @@ class AsyncScheduler:
# The I/O skip limit. TODO: Back up this value with euristics
self.io_skip_limit = io_skip_limit or 5
# The max. I/O timeout
self.io_max_timeout = io_max_timeout
self.io_max_timeout = io_max_timeout or 86400
def __repr__(self):
"""
@ -187,33 +188,42 @@ class AsyncScheduler:
"""
while True:
if self.done():
# If we're done, which means there are
# both no paused tasks and no running tasks, we
# simply tear us down and return to self.start
self.close()
break
elif not self.run_ready:
# Stores deadlines for tasks (deadlines are pool-specific).
# The deadlines queue will internally make sure not to store
# a deadline for the same pool twice. This makes the timeouts
# model less flexible, because one can't change the timeout
# after it is set, but it makes the implementation easier
if not self.current_pool and self.current_task.pool:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
# If there are no actively running tasks, we start by
# checking for I/O. This method will wait for I/O until
# the closest deadline to avoid starving sleeping tasks
# or missing deadlines
if self.selector.get_map():
self.check_io()
if self.deadlines:
# Deadline expiration is our next step
try:
self.prune_deadlines()
except TooSlowError as t:
task = t.args[0]
task.exc = t
self.join(task)
if self.paused:
# Next we try to (re)schedule the asleep tasks
self.awake_sleeping()
# Otherwise, while there are tasks ready to run, we run them!
try:
if self.done():
# If we're done, which means there are
# both no paused tasks and no running tasks, we
# simply tear us down and return to self.start
self.close()
break
elif not self.run_ready:
# If there are no actively running tasks, we start by
# checking for I/O. This method will wait for I/O until
# the closest deadline to avoid starving sleeping tasks
if self.selector.get_map():
self.check_io()
if self.deadlines:
# Then we start checking for deadlines, if there are any
self.expire_deadlines()
if self.paused:
# Next we try to (re)schedule the asleep tasks
self.awake_sleeping()
if self.current_pool and self.current_pool.timeout and not self.current_pool.timed_out:
# Stores deadlines for tasks (deadlines are pool-specific).
# The deadlines queue will internally make sure not to store
# a deadline for the same pool twice. This makes the timeouts
# model less flexible, because one can't change the timeout
# after it is set, but it makes the implementation easier
self.deadlines.put(self.current_pool)
# Otherwise, while there are tasks ready to run, we run them!
# This try/except block catches all runtime
# exceptions
while self.run_ready:
self.run_task_step()
except StopIteration as ret:
@ -236,17 +246,21 @@ class AsyncScheduler:
# self.join() work its magic
self.current_task.exc = err
self.join(self.current_task)
self.tasks.remove(self.current_task)
if self.current_task in self.tasks:
self.tasks.remove(self.current_task)
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. Any extra keyword or positional argument are then
passed to the function
to run. The associated pool that spawned said task is also
needed, while any extra keyword or positional arguments are
passed to the function itself
:param corofunc: The coroutine function (not a coroutine!) to
spawn
spawn
:type corofunc: function
:param pool: The giambio.context.TaskManager object that
spawned the task
"""
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
@ -256,11 +270,10 @@ class AsyncScheduler:
self.tasks.append(task)
self.run_ready.append(task)
self.debugger.on_task_spawn(task)
pool.tasks.append(task)
self.reschedule_running()
if self.current_pool and task.pool is not self.current_pool:
self.current_pool.enclosed_pool = task.pool
self.current_pool = task.pool
self.reschedule_running()
return task
def run_task_step(self):
@ -277,12 +290,18 @@ class AsyncScheduler:
data = None
# Sets the currently running task
self.current_task = self.run_ready.pop(0)
self.debugger.before_task_step(self.current_task)
self.current_task = self.run_ready.popleft()
if self.current_task.done():
# We need to make sure we don't try to execute
# exited tasks that are on the running queue
return
if not self.current_pool and self.current_task.pool:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here
self.current_task.status = "run"
self.current_task.steps += 1
if self.current_task.cancel_pending:
# We perform the deferred cancellation
# if it was previously scheduled
@ -291,9 +310,6 @@ class AsyncScheduler:
# somewhere)
method, *args = self.current_task.run(self._data.get(self.current_task))
self._data.pop(self.current_task, None)
# Some debugging and internal chatter here
self.current_task.status = "run"
self.current_task.steps += 1
if not hasattr(self, method) and not callable(getattr(self, method)):
# If this happens, that's quite bad!
# This if block is meant to be triggered by other async
@ -307,6 +323,16 @@ class AsyncScheduler:
getattr(self, method)(*args)
self.debugger.after_task_step(self.current_task)
def io_release(self, sock):
"""
Releases the given resource from our
selector.
:param sock: The resource to be released
"""
if self.selector.get_map() and sock in self.selector.get_map():
self.selector.unregister(sock)
def io_release_task(self, task: Task):
"""
Calls self.io_release in a loop
@ -321,16 +347,6 @@ class AsyncScheduler:
self.io_release(k.fileobj)
task.last_io = ()
def io_release(self, sock):
"""
Releases the given resource from our
selector.
:param sock: The resource to be released
"""
if self.selector.get_map() and sock in self.selector.get_map():
self.selector.unregister(sock)
def suspend(self, task: Task):
"""
Suspends execution of the given task. This is basically
@ -393,16 +409,22 @@ class AsyncScheduler:
self._data[self.current_task] = self
self.reschedule_running()
def expire_deadlines(self):
def prune_deadlines(self):
"""
Handles expiring deadlines by raising an exception
inside the correct pool if its timeout expired
Removes expired deadlines after their timeout
has expired
"""
while self.deadlines.get_closest_deadline() <= self.clock():
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
pool = self.deadlines.get()
if pool.done():
continue
pool.timed_out = True
self.cancel_pool(pool)
for task in pool.tasks:
if not task.done():
self.paused.discard(task)
self.io_release_task(task)
task.throw(TooSlowError(task))
def schedule_tasks(self, tasks: List[Task]):
"""
@ -420,6 +442,12 @@ class AsyncScheduler:
has elapsed
"""
for _, __, t in self.paused.container:
# This is to ensure that even when tasks are
# awaited instead of spawned, timeouts work as
# expected
if t.done() or t in self.run_ready or t is self.current_task:
self.paused.discard(t)
while self.paused and self.paused.get_closest_deadline() <= self.clock():
# Reschedules tasks when their deadline has elapsed
task = self.paused.get()
@ -569,6 +597,8 @@ class AsyncScheduler:
given task, if any
"""
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
for t in task.joiners:
if t not in self.run_ready:
# Since a task can be the parent
@ -586,8 +616,6 @@ class AsyncScheduler:
"""
task.joined = True
if task is not self.current_task:
task.joiners.add(self.current_task)
if task.finished or task.cancelled:
if not task.cancelled:
self.debugger.on_task_exit(task)
@ -603,8 +631,8 @@ class AsyncScheduler:
task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first one
# seems a sensible approach (it's us catching it so we don't care about that)
# frames from the exception call stack, but for now removing at least the first one
# seems a sensible approach (it's us catching it so we don't care about that)
task.exc.__traceback__ = task.exc.__traceback__.tb_next
if task.last_io:
self.io_release_task(task)

View File

@ -53,23 +53,20 @@ class ResourceBusy(GiambioError):
one task at a time
"""
...
class ResourceClosed(GiambioError):
"""
Raised when I/O is attempted on a closed resource
"""
...
class TooSlowError(GiambioError):
"""
This is raised if the timeout of a pool created using
giambio.with_timeout expires
"""
task: Task
class ErrorStack(GiambioError):
"""

View File

@ -162,7 +162,7 @@ class TimeQueue:
class DeadlinesQueue:
"""
An ordered queue for storing tasks deadlines
An ordered queue for storing task deadlines
"""
def __init__(self):
@ -211,7 +211,7 @@ class DeadlinesQueue:
"""
idx = self.index(item)
if idx != 1:
if idx != -1:
self.container.pop(idx)
heapify(self.container)
@ -267,12 +267,13 @@ class DeadlinesQueue:
"""
Pushes a pool with its deadline onto the queue. The
timeout amount will be inferred from the pool object
itself
itself. Completed or expired pools are not added to the
queue. Pools without a timeout are also ignored
:param pool: The pool object to store
"""
if pool not in self.pools:
if pool and pool not in self.pools and not pool.done() and not pool.timed_out and pool.timeout:
self.pools.add(pool)
heappush(self.container, (pool.timeout, self.sequence, pool))
self.sequence += 1

View File

@ -37,8 +37,9 @@ class AsyncSocket:
Abstraction layer for asynchronous sockets
"""
def __init__(self, sock):
def __init__(self, sock, do_handshake_on_connect: bool = True):
self.sock = sock
self.do_handshake_on_connect = do_handshake_on_connect
self._fd = sock.fileno()
self.sock.setblocking(False)
@ -58,6 +59,22 @@ class AsyncSocket:
except WantWrite:
await want_write(self.sock)
async def connect(self, address):
"""
Wrapper socket method
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
self.sock.connect(address)
if self.do_handshake_on_connect:
await self.do_handshake()
return
except WantWrite:
await want_write(self.sock)
async def accept(self):
"""
Accepts the socket, completing the 3-step TCP handshake asynchronously
@ -108,19 +125,6 @@ class AsyncSocket:
if self.sock:
self.sock.shutdown(how)
async def connect(self, addr: tuple):
"""
Connects the socket to an endpoint
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
try:
self.sock.connect(addr)
except WantWrite:
await want_write(self.sock)
self.sock.connect(addr)
async def bind(self, addr: tuple):
"""
Binds the socket to an address
@ -198,24 +202,6 @@ class AsyncSocket:
except WantWrite:
await want_write(self.sock)
async def connect(self, address):
"""
Wrapper socket method
"""
try:
result = self.sock.connect(address)
if getattr(self, "do_handshake_on_connect", False):
await self.do_handshake()
return result
except WantWrite:
await want_write(self.sock)
err = self.sock.getsockopt(SOL_SOCKET, SO_ERROR)
if err != 0:
raise OSError(err, f"Connect call failed {address}")
if getattr(self, "do_handshake_on_connect", False):
await self.do_handshake()
async def recvfrom(self, buffersize, flags=0):
"""
Wrapper socket method

View File

@ -42,8 +42,7 @@ def get_event_loop():
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
""" print(hex(id(pool)))
"""
Associates a new event loop to the current thread
and deactivates the old one. This should not be
called explicitly unless you know what you're doing.
@ -101,7 +100,31 @@ def with_timeout(timeout: int or float):
Creates an async pool with an associated timeout
"""
# We add 1 to make the timeout intuitive and inclusive (i.e.
# a 10 seconds timeout means the task is allowed to run 10
# whole seconds instead of cancelling at the tenth second)
return TaskManager(timeout + 1)
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
def skip_after(timeout: int or float):
"""
Creates an async pool with an associated timeout, but
without raising a TooSlowError exception. The pool
is simply cancelled and code execution moves on
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout, False)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr

View File

@ -15,7 +15,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from giambio.traps import event_wait, event_set
from giambio.traps import event_wait, event_set, current_task
from giambio.exceptions import GiambioError
@ -48,3 +48,19 @@ class Event:
"""
await event_wait(self)
class Queue:
"""
An asynchronous queue similar to asyncio.Queue.
NOT thread safe!
"""
def __init__(self):
"""
Object constructor
"""
self.events = {}
# async def put

View File

@ -19,4 +19,4 @@ async def main():
if __name__ == "__main__":
giambio.run(main, debugger=Debugger())
giambio.run(main, debugger=())

View File

@ -34,6 +34,8 @@ class Debugger(giambio.debug.BaseDebugger):
print(f"# Task '{task.name}' slept for {seconds:.2f} seconds")
def before_io(self, timeout):
if timeout is None:
timeout = float("inf")
print(f"!! About to check for I/O for up to {timeout:.2f} seconds")
def after_io(self, timeout):

View File

@ -30,4 +30,4 @@ async def main():
if __name__ == "__main__":
giambio.run(main, debugger=Debugger())
giambio.run(main, debugger=())

View File

@ -20,10 +20,14 @@ async def serve(bind_address: tuple):
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with giambio.create_pool() as pool:
async with sock:
while True:
conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
while True:
try:
conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple):
@ -46,7 +50,7 @@ async def handler(sock: AsyncSocket, client_address: tuple):
break
elif data == b"exit\n":
await sock.send_all(b"I'm dead dude\n")
raise TypeError("Oh, no, I'm gonna die!") # This kills the entire application!
raise TypeError("Oh, no, I'm gonna die!")
logging.info(f"Got: {data!r} from {address}")
await sock.send_all(b"Got: " + data)
logging.info(f"Echoed back {data!r} to {address}")

33
tests/socket_ssl.py Normal file
View File

@ -0,0 +1,33 @@
from debugger import Debugger
import giambio
import socket as sock
import ssl
async def test(host: str, port: int):
socket = giambio.socket.wrap_socket(
ssl.wrap_socket(
sock.socket(),
do_handshake_on_connect=False)
)
await socket.connect((host, port))
async with giambio.skip_after(2) as p:
async with socket:
await socket.send_all(b"""GET / HTTP/1.1\r
Host: google.com\r
User-Agent: owo\r
Accept: text/html\r
Connection: keep-alive\r\n\r\n""")
buffer = b""
while True:
data = await socket.receive(4096)
if data:
buffer += data
else:
break
print("\n".join(buffer.decode().split("\r\n")))
print(p.timed_out)
giambio.run(test, "google.com", 443, debugger=())

View File

@ -13,7 +13,8 @@ async def main():
try:
async with giambio.with_timeout(10) as pool:
await pool.spawn(child, 7) # This will complete
await child(20) # TODO: Broken
await pool.spawn(child, 15) # This will not
await child(20) # Neither will this
except giambio.exceptions.TooSlowError:
print("[main] One or more children have timed out!")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")

23
tests/timeout2.py Normal file
View File

@ -0,0 +1,23 @@
import giambio
from debugger import Debugger
async def child(name: int):
print(f"[child {name}] Child spawned!! Sleeping for {name} seconds")
await giambio.sleep(name)
print(f"[child {name}] Had a nice nap!")
async def main():
start = giambio.clock()
async with giambio.skip_after(10) as pool:
await pool.spawn(child, 7) # This will complete
await pool.spawn(child, 15) # This will not
await child(20) # Neither will this
if pool.timed_out:
print("[main] One or more children have timed out!")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
if __name__ == "__main__":
giambio.run(main, debugger=())

View File

@ -0,0 +1,55 @@
import giambio
from debugger import Debugger
async def child():
print("[child] Child spawned!! Sleeping for 5 seconds")
await giambio.sleep(5)
print("[child] Had a nice nap!")
async def child1():
print("[child 1] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2)
print("[child 1] Had a nice nap, suiciding now!")
raise TypeError("rip") # Watch the exception magically propagate!
async def child2():
print("[child 2] Child spawned!! Sleeping for 4 seconds")
await giambio.sleep(4)
print("[child 2] Had a nice nap!")
async def child3():
print("[child 3] Child spawned!! Sleeping for 6 seconds")
await giambio.sleep(6)
print("[child 3] Had a nice nap!")
async def main():
start = giambio.clock()
try:
async with giambio.create_pool() as pool:
# This pool will run until completion of its
# tasks and then propagate the exception
await pool.spawn(child)
await pool.spawn(child)
print("[main] First 2 children spawned, awaiting completion")
async with giambio.create_pool() as a_pool:
await a_pool.spawn(child1)
print("[main] Third children spawned, prepare for trouble in 2 seconds")
async with giambio.create_pool() as new_pool:
# This pool will be cancelled by the exception
# in the outer pool
await new_pool.spawn(child2)
await new_pool.spawn(child3)
print("[main] Fourth and fifth children spawned")
except Exception as error:
# Because exceptions just *work*!
print(f"[main] Exception from child caught! {repr(error)}")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
if __name__ == "__main__":
giambio.run(main, debugger=())