Fixed some issues with join() not properly rescheduling its caller when appropriate

Mattia Giambirtone 2022-02-05 16:14:21 +01:00
parent 584f762d61
commit 5c05de495d
17 changed files with 112 additions and 110 deletions
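
For context, a minimal sketch of the pattern this fix concerns, based on the public API used in the bundled tests (create_pool, spawn, join, sleep, run); the child coroutine is illustrative only:

import giambio


async def child(n: int) -> int:
    # Simulates some work before returning a value
    await giambio.sleep(n)
    return n * 2


async def main():
    async with giambio.create_pool() as pool:
        task = await pool.spawn(child, 1)
        # join() suspends the caller until the child exits; the scheduler
        # has to reschedule this coroutine once the child (and its whole
        # pool chain) is actually done, which is the behavior adjusted here
        result = await task.join()
        print(f"child returned {result}")


giambio.run(main)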

View File

@ -45,5 +45,5 @@ __all__ = [
"skip_after",
"task",
"io",
"socket"
"socket",
]

View File

@ -16,9 +16,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import types
import giambio
from typing import List, Optional
from typing import List, Optional, Callable, Coroutine, Any
class TaskManager:
@ -55,7 +54,7 @@ class TaskManager:
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
self.raise_on_timeout: bool = raise_on_timeout
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
"""
Spawns a child task
"""

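The new annotation spells out that spawn() takes an async function (something the pool will call to obtain a coroutine), not an already-created coroutine object. A minimal sketch of the intended call style, with echo as a purely illustrative worker:

import giambio


async def echo(msg: str):
    # A coroutine function: spawn() receives it uncalled,
    # together with the arguments to call it with
    print(msg)


async def main():
    async with giambio.create_pool() as pool:
        await pool.spawn(echo, "hello")    # function plus its arguments
        # pool.spawn(echo("hello")) is not the intended style under this signature


giambio.run(main)
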
View File

@ -17,13 +17,12 @@ limitations under the License.
"""
# Import libraries and internal resources
import types
from giambio.task import Task
from collections import deque
from functools import partial
from timeit import default_timer
from giambio.context import TaskManager
from typing import Callable, List, Optional, Any, Dict
from typing import Callable, List, Optional, Any, Dict, Coroutine
from giambio.util.debug import BaseDebugger
from giambio.internal import TimeQueue, DeadlinesQueue
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
@ -56,7 +55,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: types.FunctionType
:type clock: :class: Callable
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger
@ -73,7 +72,7 @@ class AsyncScheduler:
def __init__(
self,
clock: types.FunctionType = default_timer,
clock: Callable = default_timer,
debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None,
@ -107,7 +106,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably
self.clock: types.FunctionType = clock
self.clock: Callable = clock
# Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever run?
@ -246,8 +245,7 @@ class AsyncScheduler:
self.current_task.exc = err
self.join(self.current_task)
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also
@ -286,7 +284,6 @@ class AsyncScheduler:
account, that's self.run's job!
"""
data = None
# Sets the currently running task
self.current_task = self.run_ready.popleft()
if self.current_task.done():
@ -351,11 +348,11 @@ class AsyncScheduler:
a do-nothing method, since it will not reschedule the task
before returning. The task will stay suspended until
something else outside the loop calls a trap to reschedule it.
Any pending I/O for the task is temporarily unscheduled to
prevent a previous network operation from rescheduling the task
before it's due
"""
if self.current_task.last_io:
self.io_release_task(self.current_task)
self.suspended.append(self.current_task)
@ -540,7 +537,7 @@ class AsyncScheduler:
self.run_ready.append(key.data) # Resource ready? Schedule its task
self.debugger.after_io(self.clock() - before_time)
def start(self, func: types.FunctionType, *args, loop: bool = True):
def start(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, loop: bool = True):
"""
Starts the event loop from a sync context. If the loop parameter
is false, the event loop will not start listening for events
@ -623,16 +620,21 @@ class AsyncScheduler:
given task, if any
"""
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
for t in task.joiners:
if t not in self.run_ready:
# Since a task can be the parent
# of multiple children, we need to
# make sure we reschedule it only
# once, otherwise a RuntimeError will
# occur
self.run_ready.append(t)
self.run_ready.append(t)
# noinspection PyMethodMayBeStatic
def is_pool_done(self, pool: Optional[TaskManager]):
"""
Returns True if the given pool, and every pool
in its enclosed_pool chain, has finished executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
def join(self, task: Task):
"""
@ -643,6 +645,8 @@ class AsyncScheduler:
task.joined = True
if task.finished or task.cancelled:
if task in self.tasks:
self.tasks.remove(task)
if not task.cancelled:
# This way join() returns the
# task's return value
@ -653,9 +657,12 @@ class AsyncScheduler:
self.io_release_task(task)
# If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s)
if not task.pool or task.pool.done():
if self.is_pool_done(task.pool):
self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc:
if task in self.tasks:
self.tasks.remove(task)
task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
@ -676,15 +683,11 @@ class AsyncScheduler:
# or just returned
for t in task.joiners.copy():
# Propagate the exception
try:
t.throw(task.exc)
except (StopIteration, CancelledError, RuntimeError):
# TODO: Need anything else?
self.handle_task_exit(t, partial(t.throw, task.exc))
if t.exc or t.finished or t.cancelled:
task.joiners.remove(t)
finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task)
self.reschedule_running()
def sleep(self, seconds: int or float):
"""

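The substantive change here is that a task's joiners are only rescheduled once its entire chain of pools is finished, not just its immediate pool. A standalone sketch of that walk, using a stand-in Pool class since the real check lives on AsyncScheduler and operates on TaskManager objects:

from typing import Optional


class Pool:
    # Stand-in for giambio's TaskManager, keeping only what the check needs
    def __init__(self, finished: bool, enclosed_pool: Optional["Pool"] = None):
        self.finished = finished
        self.enclosed_pool = enclosed_pool

    def done(self) -> bool:
        return self.finished


def is_pool_done(pool: Optional[Pool]) -> bool:
    # Walk the enclosed_pool chain: every pool in it must be done
    # before the joiners of a task belonging to it are rescheduled
    while pool:
        if not pool.done():
            return False
        pool = pool.enclosed_pool
    return True


print(is_pool_done(None))                                      # True: no pool, nothing to wait for
print(is_pool_done(Pool(finished=True)))                       # True
print(is_pool_done(Pool(finished=True,
                        enclosed_pool=Pool(finished=False))))  # False: the chain isn't done yet
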
View File

@ -44,6 +44,13 @@ class TimeQueue:
self.sequence = 0
self.container: List[Tuple[float, int, Task]] = []
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def __contains__(self, item: Task):
"""
Implements item in self. This method behaves
@ -263,6 +270,13 @@ class DeadlinesQueue:
return f"DeadlinesQueue({self.container})"
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def put(self, pool: "giambio.context.TaskManager"):
"""
Pushes a pool with its deadline onto the queue. The

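With __len__ defined, both TimeQueue and DeadlinesQueue now support len() and plain truthiness checks. A tiny illustration with a stand-in class shaped like these queues (the real TimeQueue keeps (time, sequence, Task) tuples in its container):

from typing import List, Tuple


class MiniQueue:
    # Stand-in with the same container-backed __len__ as the real queues
    def __init__(self):
        self.container: List[Tuple[float, int, object]] = []

    def __len__(self):
        return len(self.container)


q = MiniQueue()
print(len(q), bool(q))                # 0 False
q.container.append((0.0, 0, "task"))
print(len(q), bool(q))                # 1 True
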
View File

@ -244,7 +244,7 @@ class AsyncSocket:
await want_write(self.sock)
except WantRead:
await want_read(self.sock)
async def getpeername(self):
"""
Wrapper socket method

View File

@ -18,12 +18,13 @@ limitations under the License.
import inspect
import threading
from typing import Callable, Coroutine, Any, Union
from giambio.core import AsyncScheduler
from giambio.exceptions import GiambioError
from giambio.context import TaskManager
from timeit import default_timer
from giambio.util.debug import BaseDebugger
from types import FunctionType
thread_local = threading.local()
@ -41,7 +42,7 @@ def get_event_loop():
raise GiambioError("giambio is not running") from None
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
def new_event_loop(debugger: BaseDebugger, clock: Callable):
"""
Associates a new event loop to the current thread
and deactivates the old one. This should not be
@ -62,7 +63,7 @@ def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
thread_local.loop = AsyncScheduler(clock, debugger)
def run(func: FunctionType, *args, **kwargs):
def run(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
"""
Starts the event loop from a synchronous entry point
"""
@ -95,23 +96,16 @@ def create_pool():
return TaskManager()
def with_timeout(timeout: int or float):
def with_timeout(timeout: Union[int, float]):
"""
Creates an async pool with an associated timeout
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
return TaskManager(timeout)
def skip_after(timeout: int or float):
def skip_after(timeout: Union[int, float]):
"""
Creates an async pool with an associated timeout, but
without raising a TooSlowError exception. The pool
@ -119,11 +113,4 @@ def skip_after(timeout: int or float):
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout, False)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
return TaskManager(timeout, False)

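After this change, with_timeout() and skip_after() are plain factories: they validate the timeout and return a TaskManager, with the loop-side deadline bookkeeping presumably handled elsewhere (for example when the pool is entered). A hedged usage sketch modeled on the bundled tests; slow_worker is illustrative only:

import giambio


async def slow_worker(seconds: int):
    await giambio.sleep(seconds)


async def main():
    try:
        # with_timeout() raises TooSlowError if the pool outlives its deadline
        async with giambio.with_timeout(2) as pool:
            await pool.spawn(slow_worker, 5)
    except giambio.exceptions.TooSlowError:
        print("with_timeout pool timed out")
    # skip_after() cancels silently instead of raising
    async with giambio.skip_after(2) as pool:
        await pool.spawn(slow_worker, 5)
    print(f"skip_after pool timed out: {pool.timed_out}")


giambio.run(main)
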
View File

@ -73,11 +73,10 @@ class Queue:
self.putters = deque()
self.container = deque(maxlen=maxsize)
async def put(self, item: Any):
"""
Pushes an element onto the queue. If the
queue is full, waits until there's
enough space in the queue
"""
@ -88,7 +87,6 @@ class Queue:
else:
self.putters.append(Event())
await self.putters[-1].wait()
async def get(self) -> Any:
"""
@ -103,4 +101,4 @@ class Queue:
return self.container.popleft()
else:
self.getters.append(Event())
return await self.getters[-1].wait()

View File

@ -54,8 +54,8 @@ class Task:
# when the task has been created but not started running yet--, "run"-- when
# the task is running synchronous code--, "io"-- when the task is waiting on
# an I/O resource--, "sleep"-- when the task is either asleep, waiting on
# an event or otherwise suspended--, "crashed"-- when the task has exited because
# of an exception and "cancelled"-- when the task has been explicitly cancelled
# with its cancel() method or as a result of an exception
status: str = "init"
# This attribute counts how many times the task's run() method has been called
@ -112,7 +112,6 @@ class Task:
self.joiners.add(task)
return await giambio.traps.join(self)
async def cancel(self):
"""
Cancels the task

View File

@ -24,8 +24,7 @@ limitations under the License.
import types
import inspect
from giambio.task import Task
from types import FunctionType
from typing import Any, Union, Iterable
from typing import Any, Union, Iterable, Coroutine, Callable
from giambio.exceptions import GiambioError
@ -49,7 +48,7 @@ async def suspend() -> Any:
return await create_trap("suspend")
async def create_task(coro: FunctionType, pool, *args):
async def create_task(coro: Callable[..., Coroutine[Any, Any, Any]], pool, *args):
"""
Spawns a new task in the current event loop from a bare coroutine
function. All extra positional arguments are passed to the function

View File

@ -19,7 +19,7 @@ async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
data, rest = data.split(b"\n", maxsplit=2)
buffer = b"".join(rest)
await q.put((1, data.decode()))
data = buffer
async def main(host: Tuple[str, int]):

View File

@ -23,15 +23,15 @@ async def serve(bind_address: tuple):
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with giambio.create_pool() as pool:
async with sock:
while True:
try:
conn, address_tuple = await sock.accept()
clients.append(conn)
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
while True:
try:
conn, address_tuple = await sock.accept()
clients.append(conn)
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple):

View File

@ -20,14 +20,14 @@ async def serve(bind_address: tuple):
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
async with giambio.create_pool() as pool:
async with sock:
while True:
try:
conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
while True:
try:
conn, address_tuple = await sock.accept()
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple):

View File

@ -15,7 +15,9 @@ async def child(ev: giambio.Event, pause: int):
await giambio.sleep(pause)
end_sleep = giambio.clock() - start_sleep
end_total = giambio.clock() - start_total
print(f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!")
print(
f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!"
)
async def parent(pause: int = 1):

View File

@ -18,14 +18,12 @@ async def consumer(q: giambio.Queue):
break
print(f"Consumed {item}")
await giambio.sleep(1)
async def main(q: giambio.Queue, n: int):
async with giambio.create_pool() as pool:
await pool.spawn(consumer, q)
await pool.spawn(producer, q, n)
queue = giambio.Queue()

View File

@ -7,6 +7,7 @@ import time
_print = print
def print(*args, **kwargs):
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
_print(*args, **kwargs)
@ -14,18 +15,19 @@ def print(*args, **kwargs):
async def test(host: str, port: int, bufsize: int = 4096):
socket = giambio.socket.wrap_socket(
ssl.create_default_context().wrap_socket(
sock=sock.socket(),
# Note: do_handshake_on_connect MUST
# be set to False on the synchronous socket!
# Giambio handles the TLS handshake asynchronously
# and making the SSL library handle it blocks
# the entire event loop. To perform the TLS
# handshake upon connection, set this
# parameter in the AsyncSocket class instead
do_handshake_on_connect=False,
server_hostname=host)
)
ssl.create_default_context().wrap_socket(
sock=sock.socket(),
# Note: do_handshake_on_connect MUST
# be set to False on the synchronous socket!
# Giambio handles the TLS handshake asynchronously
# and making the SSL library handle it blocks
# the entire event loop. To perform the TLS
# handshake upon connection, set this
# parameter in the AsyncSocket class instead
do_handshake_on_connect=False,
server_hostname=host,
)
)
print(f"Attempting a connection to {host}:{port}")
await socket.connect((host, port))
print("Connected")
@ -34,18 +36,20 @@ async def test(host: str, port: int, bufsize: int = 4096):
async with socket:
# Closes the socket automatically
print("Entered socket context manager, sending request data")
await socket.send_all(b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""")
await socket.send_all(
b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n"""
)
print("Data sent")
buffer = b""
while not buffer.endswith(b"\r\n\r\n"):
print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})")
data = await socket.receive(bufsize)
print(f"Received {len(data)} bytes")
if data:
buffer += data
else:
print("Received empty stream, closing connection")
break
print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})")
data = await socket.receive(bufsize)
print(f"Received {len(data)} bytes")
if data:
buffer += data
else:
print("Received empty stream, closing connection")
break
print(f"Request has{' not' if not p.timed_out else ''} timed out!")
if buffer:
data = buffer.decode().split("\r\n")
@ -70,4 +74,3 @@ async def test(host: str, port: int, bufsize: int = 4096):
giambio.run(test, "google.com", 443, 256, debugger=())

View File

@ -13,7 +13,7 @@ async def main():
try:
async with giambio.with_timeout(12) as pool:
await pool.spawn(child, 7) # This will complete
await giambio.sleep(2)  # This will make the code below wait 2 seconds
await pool.spawn(child, 15) # This will not complete
await giambio.sleep(50)
await child(20) # Neither will this

View File

@ -14,8 +14,8 @@ async def main():
try:
async with giambio.with_timeout(5) as pool:
task = await pool.spawn(child, 2)
print(await task.join())
await giambio.sleep(5)
print(f"Child has returned: {await task.join()}")
await giambio.sleep(5) # This will trigger the timeout
except giambio.exceptions.TooSlowError:
print("[main] One or more children have timed out!")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")