Fixed some issues with join() not properly rescheduling its caller when appropriate

This commit is contained in:
Mattia Giambirtone 2022-02-05 16:14:21 +01:00
parent 584f762d61
commit 5c05de495d
17 changed files with 112 additions and 110 deletions

View File

@ -45,5 +45,5 @@ __all__ = [
"skip_after", "skip_after",
"task", "task",
"io", "io",
"socket" "socket",
] ]

View File

@ -16,9 +16,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import types
import giambio import giambio
from typing import List, Optional from typing import List, Optional, Callable, Coroutine, Any
class TaskManager: class TaskManager:
@ -55,7 +54,7 @@ class TaskManager:
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
self.raise_on_timeout: bool = raise_on_timeout self.raise_on_timeout: bool = raise_on_timeout
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task": async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
""" """
Spawns a child task Spawns a child task
""" """

View File

@ -17,13 +17,12 @@ limitations under the License.
""" """
# Import libraries and internal resources # Import libraries and internal resources
import types
from giambio.task import Task from giambio.task import Task
from collections import deque from collections import deque
from functools import partial from functools import partial
from timeit import default_timer from timeit import default_timer
from giambio.context import TaskManager from giambio.context import TaskManager
from typing import Callable, List, Optional, Any, Dict from typing import Callable, List, Optional, Any, Dict, Coroutine
from giambio.util.debug import BaseDebugger from giambio.util.debug import BaseDebugger
from giambio.internal import TimeQueue, DeadlinesQueue from giambio.internal import TimeQueue, DeadlinesQueue
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
@ -56,7 +55,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call, :param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: types.FunctionType :type clock: :class: Callable
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output :param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger :type debugger: :class: giambio.util.BaseDebugger
@ -73,7 +72,7 @@ class AsyncScheduler:
def __init__( def __init__(
self, self,
clock: types.FunctionType = default_timer, clock: Callable = default_timer,
debugger: Optional[BaseDebugger] = None, debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None, selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None, io_skip_limit: Optional[int] = None,
@ -107,7 +106,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object) # This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably # Monotonic clock to keep track of elapsed time reliably
self.clock: types.FunctionType = clock self.clock: Callable = clock
# Tasks that are asleep # Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock) self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever run? # Have we ever run?
@ -246,8 +245,7 @@ class AsyncScheduler:
self.current_task.exc = err self.current_task.exc = err
self.join(self.current_task) self.join(self.current_task)
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
""" """
Creates a task from a coroutine function and schedules it Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also to run. The associated pool that spawned said task is also
@ -286,7 +284,6 @@ class AsyncScheduler:
account, that's self.run's job! account, that's self.run's job!
""" """
data = None
# Sets the currently running task # Sets the currently running task
self.current_task = self.run_ready.popleft() self.current_task = self.run_ready.popleft()
if self.current_task.done(): if self.current_task.done():
@ -351,11 +348,11 @@ class AsyncScheduler:
a do-nothing method, since it will not reschedule the task a do-nothing method, since it will not reschedule the task
before returning. The task will stay suspended until before returning. The task will stay suspended until
something else outside the loop calls a trap to reschedule it. something else outside the loop calls a trap to reschedule it.
Any pending I/O for the task is temporarily unscheduled to Any pending I/O for the task is temporarily unscheduled to
prevent some previous network operation from rescheduling the task prevent some previous network operation from rescheduling the task
before it's due before it's due
""" """
if self.current_task.last_io: if self.current_task.last_io:
self.io_release_task(self.current_task) self.io_release_task(self.current_task)
self.suspended.append(self.current_task) self.suspended.append(self.current_task)
@ -540,7 +537,7 @@ class AsyncScheduler:
self.run_ready.append(key.data) # Resource ready? Schedule its task self.run_ready.append(key.data) # Resource ready? Schedule its task
self.debugger.after_io(self.clock() - before_time) self.debugger.after_io(self.clock() - before_time)
def start(self, func: types.FunctionType, *args, loop: bool = True): def start(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, loop: bool = True):
""" """
Starts the event loop from a sync context. If the loop parameter Starts the event loop from a sync context. If the loop parameter
is false, the event loop will not start listening for events is false, the event loop will not start listening for events
@ -623,16 +620,21 @@ class AsyncScheduler:
given task, if any given task, if any
""" """
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
for t in task.joiners: for t in task.joiners:
if t not in self.run_ready: self.run_ready.append(t)
# Since a task can be the parent
# of multiple children, we need to # noinspection PyMethodMayBeStatic
# make sure we reschedule it only def is_pool_done(self, pool: Optional[TaskManager]):
# once, otherwise a RuntimeError will """
# occur Returns True if a given pool has finished
self.run_ready.append(t) executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
def join(self, task: Task): def join(self, task: Task):
""" """
@ -643,6 +645,8 @@ class AsyncScheduler:
task.joined = True task.joined = True
if task.finished or task.cancelled: if task.finished or task.cancelled:
if task in self.tasks:
self.tasks.remove(task)
if not task.cancelled: if not task.cancelled:
# This way join() returns the # This way join() returns the
# task's return value # task's return value
@ -653,9 +657,12 @@ class AsyncScheduler:
self.io_release_task(task) self.io_release_task(task)
# If the pool has finished executing or we're at the first parent # If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s) # task that kicked the loop, we can safely reschedule the parent(s)
if not task.pool or task.pool.done(): if self.is_pool_done(task.pool):
self.reschedule_joiners(task) self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc: elif task.exc:
if task in self.tasks:
self.tasks.remove(task)
task.status = "crashed" task.status = "crashed"
if task.exc.__traceback__: if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra # TODO: We might want to do a bit more complex traceback hacking to remove any extra
@ -676,15 +683,11 @@ class AsyncScheduler:
# or just returned # or just returned
for t in task.joiners.copy(): for t in task.joiners.copy():
# Propagate the exception # Propagate the exception
try: self.handle_task_exit(t, partial(t.throw, task.exc))
t.throw(task.exc) if t.exc or t.finished or t.cancelled:
except (StopIteration, CancelledError, RuntimeError):
# TODO: Need anything else?
task.joiners.remove(t) task.joiners.remove(t)
finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task) self.reschedule_joiners(task)
self.reschedule_running()
def sleep(self, seconds: int or float): def sleep(self, seconds: int or float):
""" """

View File

@ -44,6 +44,13 @@ class TimeQueue:
self.sequence = 0 self.sequence = 0
self.container: List[Tuple[float, int, Task]] = [] self.container: List[Tuple[float, int, Task]] = []
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def __contains__(self, item: Task): def __contains__(self, item: Task):
""" """
Implements item in self. This method behaves Implements item in self. This method behaves
@ -263,6 +270,13 @@ class DeadlinesQueue:
return f"DeadlinesQueue({self.container})" return f"DeadlinesQueue({self.container})"
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def put(self, pool: "giambio.context.TaskManager"): def put(self, pool: "giambio.context.TaskManager"):
""" """
Pushes a pool with its deadline onto the queue. The Pushes a pool with its deadline onto the queue. The

View File

@ -244,7 +244,7 @@ class AsyncSocket:
await want_write(self.sock) await want_write(self.sock)
except WantRead: except WantRead:
await want_read(self.sock) await want_read(self.sock)
async def getpeername(self): async def getpeername(self):
""" """
Wrapper socket method Wrapper socket method

View File

@ -18,12 +18,13 @@ limitations under the License.
import inspect import inspect
import threading import threading
from typing import Callable, Coroutine, Any, Union
from giambio.core import AsyncScheduler from giambio.core import AsyncScheduler
from giambio.exceptions import GiambioError from giambio.exceptions import GiambioError
from giambio.context import TaskManager from giambio.context import TaskManager
from timeit import default_timer from timeit import default_timer
from giambio.util.debug import BaseDebugger from giambio.util.debug import BaseDebugger
from types import FunctionType
thread_local = threading.local() thread_local = threading.local()
@ -41,7 +42,7 @@ def get_event_loop():
raise GiambioError("giambio is not running") from None raise GiambioError("giambio is not running") from None
def new_event_loop(debugger: BaseDebugger, clock: FunctionType): def new_event_loop(debugger: BaseDebugger, clock: Callable):
""" """
Associates a new event loop to the current thread Associates a new event loop to the current thread
and deactivates the old one. This should not be and deactivates the old one. This should not be
@ -62,7 +63,7 @@ def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
thread_local.loop = AsyncScheduler(clock, debugger) thread_local.loop = AsyncScheduler(clock, debugger)
def run(func: FunctionType, *args, **kwargs): def run(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
""" """
Starts the event loop from a synchronous entry point Starts the event loop from a synchronous entry point
""" """
@ -95,23 +96,16 @@ def create_pool():
return TaskManager() return TaskManager()
def with_timeout(timeout: int or float): def with_timeout(timeout: Union[int, float]):
""" """
Creates an async pool with an associated timeout Creates an async pool with an associated timeout
""" """
assert timeout > 0, "The timeout must be greater than 0" assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout) return TaskManager(timeout)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
def skip_after(timeout: int or float): def skip_after(timeout: Union[int, float]):
""" """
Creates an async pool with an associated timeout, but Creates an async pool with an associated timeout, but
without raising a TooSlowError exception. The pool without raising a TooSlowError exception. The pool
@ -119,11 +113,4 @@ def skip_after(timeout: int or float):
""" """
assert timeout > 0, "The timeout must be greater than 0" assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout, False) return TaskManager(timeout, False)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr

View File

@ -73,11 +73,10 @@ class Queue:
self.putters = deque() self.putters = deque()
self.container = deque(maxlen=maxsize) self.container = deque(maxlen=maxsize)
async def put(self, item: Any): async def put(self, item: Any):
""" """
Pushes an element onto the queue. If the Pushes an element onto the queue. If the
queue is full, waits until there's queue is full, waits until there's
enough space in the queue enough space in the queue
""" """
@ -88,7 +87,6 @@ class Queue:
else: else:
self.putters.append(Event()) self.putters.append(Event())
await self.putters[-1].wait() await self.putters[-1].wait()
async def get(self) -> Any: async def get(self) -> Any:
""" """
@ -103,4 +101,4 @@ class Queue:
return self.container.popleft() return self.container.popleft()
else: else:
self.getters.append(Event()) self.getters.append(Event())
return await self.getters[-1].wait() return await self.getters[-1].wait()

View File

@ -54,8 +54,8 @@ class Task:
# when the task has been created but not started running yet--, "run"-- when # when the task has been created but not started running yet--, "run"-- when
# the task is running synchronous code--, "io"-- when the task is waiting on # the task is running synchronous code--, "io"-- when the task is waiting on
# an I/O resource--, "sleep"-- when the task is either asleep, waiting on # an I/O resource--, "sleep"-- when the task is either asleep, waiting on
# an event or otherwise suspended, "crashed"-- when the task has exited because # an event or otherwise suspended, "crashed"-- when the task has exited because
# of an exception and "cancelled"-- when the task has been explicitly cancelled # of an exception and "cancelled"-- when the task has been explicitly cancelled
# with its cancel() method or as a result of an exception # with its cancel() method or as a result of an exception
status: str = "init" status: str = "init"
# This attribute counts how many times the task's run() method has been called # This attribute counts how many times the task's run() method has been called
@ -112,7 +112,6 @@ class Task:
self.joiners.add(task) self.joiners.add(task)
return await giambio.traps.join(self) return await giambio.traps.join(self)
async def cancel(self): async def cancel(self):
""" """
Cancels the task Cancels the task

View File

@ -24,8 +24,7 @@ limitations under the License.
import types import types
import inspect import inspect
from giambio.task import Task from giambio.task import Task
from types import FunctionType from typing import Any, Union, Iterable, Coroutine, Callable
from typing import Any, Union, Iterable
from giambio.exceptions import GiambioError from giambio.exceptions import GiambioError
@ -49,7 +48,7 @@ async def suspend() -> Any:
return await create_trap("suspend") return await create_trap("suspend")
async def create_task(coro: FunctionType, pool, *args): async def create_task(coro: Callable[..., Coroutine[Any, Any, Any]], pool, *args):
""" """
Spawns a new task in the current event loop from a bare coroutine Spawns a new task in the current event loop from a bare coroutine
function. All extra positional arguments are passed to the function function. All extra positional arguments are passed to the function

View File

@ -19,7 +19,7 @@ async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
data, rest = data.split(b"\n", maxsplit=2) data, rest = data.split(b"\n", maxsplit=2)
buffer = b"".join(rest) buffer = b"".join(rest)
await q.put((1, data.decode())) await q.put((1, data.decode()))
data = buffer data = buffer
async def main(host: Tuple[str, int]): async def main(host: Tuple[str, int]):

View File

@ -23,15 +23,15 @@ async def serve(bind_address: tuple):
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with giambio.create_pool() as pool: async with giambio.create_pool() as pool:
async with sock: async with sock:
while True: while True:
try: try:
conn, address_tuple = await sock.accept() conn, address_tuple = await sock.accept()
clients.append(conn) clients.append(conn)
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple) await pool.spawn(handler, conn, address_tuple)
except Exception as err: except Exception as err:
# Because exceptions just *work* # Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple): async def handler(sock: AsyncSocket, client_address: tuple):

View File

@ -20,14 +20,14 @@ async def serve(bind_address: tuple):
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}") logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with giambio.create_pool() as pool: async with giambio.create_pool() as pool:
async with sock: async with sock:
while True: while True:
try: try:
conn, address_tuple = await sock.accept() conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected") logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple) await pool.spawn(handler, conn, address_tuple)
except Exception as err: except Exception as err:
# Because exceptions just *work* # Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}") logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple): async def handler(sock: AsyncSocket, client_address: tuple):

View File

@ -15,7 +15,9 @@ async def child(ev: giambio.Event, pause: int):
await giambio.sleep(pause) await giambio.sleep(pause)
end_sleep = giambio.clock() - start_sleep end_sleep = giambio.clock() - start_sleep
end_total = giambio.clock() - start_total end_total = giambio.clock() - start_total
print(f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!") print(
f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!"
)
async def parent(pause: int = 1): async def parent(pause: int = 1):

View File

@ -18,14 +18,12 @@ async def consumer(q: giambio.Queue):
break break
print(f"Consumed {item}") print(f"Consumed {item}")
await giambio.sleep(1) await giambio.sleep(1)
async def main(q: giambio.Queue, n: int): async def main(q: giambio.Queue, n: int):
async with giambio.create_pool() as pool: async with giambio.create_pool() as pool:
await pool.spawn(consumer, q) await pool.spawn(consumer, q)
await pool.spawn(producer, q, n) await pool.spawn(producer, q, n)
queue = giambio.Queue() queue = giambio.Queue()

View File

@ -7,6 +7,7 @@ import time
_print = print _print = print
def print(*args, **kwargs): def print(*args, **kwargs):
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ") sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
_print(*args, **kwargs) _print(*args, **kwargs)
@ -14,18 +15,19 @@ def print(*args, **kwargs):
async def test(host: str, port: int, bufsize: int = 4096): async def test(host: str, port: int, bufsize: int = 4096):
socket = giambio.socket.wrap_socket( socket = giambio.socket.wrap_socket(
ssl.create_default_context().wrap_socket( ssl.create_default_context().wrap_socket(
sock=sock.socket(), sock=sock.socket(),
# Note: do_handshake_on_connect MUST # Note: do_handshake_on_connect MUST
# be set to False on the synchronous socket! # be set to False on the synchronous socket!
# Giambio handles the TLS handshake asynchronously # Giambio handles the TLS handshake asynchronously
# and making the SSL library handle it blocks # and making the SSL library handle it blocks
# the entire event loop. To perform the TLS # the entire event loop. To perform the TLS
# handshake upon connection, set this # handshake upon connection, set this
# parameter in the AsyncSocket class instead # parameter in the AsyncSocket class instead
do_handshake_on_connect=False, do_handshake_on_connect=False,
server_hostname=host) server_hostname=host,
) )
)
print(f"Attempting a connection to {host}:{port}") print(f"Attempting a connection to {host}:{port}")
await socket.connect((host, port)) await socket.connect((host, port))
print("Connected") print("Connected")
@ -34,18 +36,20 @@ async def test(host: str, port: int, bufsize: int = 4096):
async with socket: async with socket:
# Closes the socket automatically # Closes the socket automatically
print("Entered socket context manager, sending request data") print("Entered socket context manager, sending request data")
await socket.send_all(b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""") await socket.send_all(
b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n"""
)
print("Data sent") print("Data sent")
buffer = b"" buffer = b""
while not buffer.endswith(b"\r\n\r\n"): while not buffer.endswith(b"\r\n\r\n"):
print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})") print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})")
data = await socket.receive(bufsize) data = await socket.receive(bufsize)
print(f"Received {len(data)} bytes") print(f"Received {len(data)} bytes")
if data: if data:
buffer += data buffer += data
else: else:
print("Received empty stream, closing connection") print("Received empty stream, closing connection")
break break
print(f"Request has{' not' if not p.timed_out else ''} timed out!") print(f"Request has{' not' if not p.timed_out else ''} timed out!")
if buffer: if buffer:
data = buffer.decode().split("\r\n") data = buffer.decode().split("\r\n")
@ -70,4 +74,3 @@ async def test(host: str, port: int, bufsize: int = 4096):
giambio.run(test, "google.com", 443, 256, debugger=()) giambio.run(test, "google.com", 443, 256, debugger=())

View File

@ -13,7 +13,7 @@ async def main():
try: try:
async with giambio.with_timeout(12) as pool: async with giambio.with_timeout(12) as pool:
await pool.spawn(child, 7) # This will complete await pool.spawn(child, 7) # This will complete
await giambio.sleep(2) # This will make the code below wait 2 seconds await giambio.sleep(2) # This will make the code below wait 2 seconds
await pool.spawn(child, 15) # This will not complete await pool.spawn(child, 15) # This will not complete
await giambio.sleep(50) await giambio.sleep(50)
await child(20) # Neither will this await child(20) # Neither will this

View File

@ -14,8 +14,8 @@ async def main():
try: try:
async with giambio.with_timeout(5) as pool: async with giambio.with_timeout(5) as pool:
task = await pool.spawn(child, 2) task = await pool.spawn(child, 2)
print(await task.join()) print(f"Child has returned: {await task.join()}")
await giambio.sleep(5) await giambio.sleep(5) # This will trigger the timeout
except giambio.exceptions.TooSlowError: except giambio.exceptions.TooSlowError:
print("[main] One or more children have timed out!") print("[main] One or more children have timed out!")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds") print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")