Fixed some issues with join() not properly rescheduling its caller when appropriate

This commit is contained in:
Mattia Giambirtone 2022-02-05 16:14:21 +01:00
parent 584f762d61
commit 5c05de495d
17 changed files with 112 additions and 110 deletions

View File

@ -45,5 +45,5 @@ __all__ = [
"skip_after",
"task",
"io",
"socket"
"socket",
]

View File

@ -16,9 +16,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import types
import giambio
from typing import List, Optional
from typing import List, Optional, Callable, Coroutine, Any
class TaskManager:
@ -55,7 +54,7 @@ class TaskManager:
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
self.raise_on_timeout: bool = raise_on_timeout
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
"""
Spawns a child task
"""

View File

@ -17,13 +17,12 @@ limitations under the License.
"""
# Import libraries and internal resources
import types
from giambio.task import Task
from collections import deque
from functools import partial
from timeit import default_timer
from giambio.context import TaskManager
from typing import Callable, List, Optional, Any, Dict
from typing import Callable, List, Optional, Any, Dict, Coroutine
from giambio.util.debug import BaseDebugger
from giambio.internal import TimeQueue, DeadlinesQueue
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
@ -56,7 +55,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: types.FunctionType
:type clock: :class: Callable
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger
@ -73,7 +72,7 @@ class AsyncScheduler:
def __init__(
self,
clock: types.FunctionType = default_timer,
clock: Callable = default_timer,
debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None,
@ -107,7 +106,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably
self.clock: types.FunctionType = clock
self.clock: Callable = clock
# Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever ran?
@ -246,8 +245,7 @@ class AsyncScheduler:
self.current_task.exc = err
self.join(self.current_task)
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also
@ -286,7 +284,6 @@ class AsyncScheduler:
account, that's self.run's job!
"""
data = None
# Sets the currently running task
self.current_task = self.run_ready.popleft()
if self.current_task.done():
@ -540,7 +537,7 @@ class AsyncScheduler:
self.run_ready.append(key.data) # Resource ready? Schedule its task
self.debugger.after_io(self.clock() - before_time)
def start(self, func: types.FunctionType, *args, loop: bool = True):
def start(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, loop: bool = True):
"""
Starts the event loop from a sync context. If the loop parameter
is false, the event loop will not start listening for events
@ -623,17 +620,22 @@ class AsyncScheduler:
given task, if any
"""
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
for t in task.joiners:
if t not in self.run_ready:
# Since a task can be the parent
# of multiple children, we need to
# make sure we reschedule it only
# once, otherwise a RuntimeError will
# occur
self.run_ready.append(t)
# noinspection PyMethodMayBeStatic
def is_pool_done(self, pool: Optional[TaskManager]):
"""
Returns True if a given pool has finished
executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
def join(self, task: Task):
"""
Joins a task to its callers (implicitly, the parent
@ -643,6 +645,8 @@ class AsyncScheduler:
task.joined = True
if task.finished or task.cancelled:
if task in self.tasks:
self.tasks.remove(task)
if not task.cancelled:
# This way join() returns the
# task's return value
@ -653,9 +657,12 @@ class AsyncScheduler:
self.io_release_task(task)
# If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s)
if not task.pool or task.pool.done():
if self.is_pool_done(task.pool):
self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc:
if task in self.tasks:
self.tasks.remove(task)
task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
@ -676,15 +683,11 @@ class AsyncScheduler:
# or just returned
for t in task.joiners.copy():
# Propagate the exception
try:
t.throw(task.exc)
except (StopIteration, CancelledError, RuntimeError):
# TODO: Need anything else?
self.handle_task_exit(t, partial(t.throw, task.exc))
if t.exc or t.finished or t.cancelled:
task.joiners.remove(t)
finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task)
self.reschedule_running()
def sleep(self, seconds: int or float):
"""

View File

@ -44,6 +44,13 @@ class TimeQueue:
self.sequence = 0
self.container: List[Tuple[float, int, Task]] = []
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def __contains__(self, item: Task):
"""
Implements item in self. This method behaves
@ -263,6 +270,13 @@ class DeadlinesQueue:
return f"DeadlinesQueue({self.container})"
def __len__(self):
"""
Returns len(self)
"""
return len(self.container)
def put(self, pool: "giambio.context.TaskManager"):
"""
Pushes a pool with its deadline onto the queue. The

View File

@ -18,12 +18,13 @@ limitations under the License.
import inspect
import threading
from typing import Callable, Coroutine, Any, Union
from giambio.core import AsyncScheduler
from giambio.exceptions import GiambioError
from giambio.context import TaskManager
from timeit import default_timer
from giambio.util.debug import BaseDebugger
from types import FunctionType
thread_local = threading.local()
@ -41,7 +42,7 @@ def get_event_loop():
raise GiambioError("giambio is not running") from None
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
def new_event_loop(debugger: BaseDebugger, clock: Callable):
"""
Associates a new event loop to the current thread
and deactivates the old one. This should not be
@ -62,7 +63,7 @@ def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
thread_local.loop = AsyncScheduler(clock, debugger)
def run(func: FunctionType, *args, **kwargs):
def run(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
"""
Starts the event loop from a synchronous entry point
"""
@ -95,23 +96,16 @@ def create_pool():
return TaskManager()
def with_timeout(timeout: int or float):
def with_timeout(timeout: Union[int, float]):
"""
Creates an async pool with an associated timeout
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
return TaskManager(timeout)
def skip_after(timeout: int or float):
def skip_after(timeout: Union[int, float]):
"""
Creates an async pool with an associated timeout, but
without raising a TooSlowError exception. The pool
@ -119,11 +113,4 @@ def skip_after(timeout: int or float):
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(timeout, False)
loop = get_event_loop()
if loop.current_task.pool is None:
loop.current_pool = mgr
loop.current_task.pool = mgr
loop.current_task.next_deadline = mgr.timeout or 0.0
loop.deadlines.put(mgr)
return mgr
return TaskManager(timeout, False)

View File

@ -73,7 +73,6 @@ class Queue:
self.putters = deque()
self.container = deque(maxlen=maxsize)
async def put(self, item: Any):
"""
Pushes an element onto the queue. If the
@ -89,7 +88,6 @@ class Queue:
self.putters.append(Event())
await self.putters[-1].wait()
async def get(self) -> Any:
"""
Pops an element off the queue. Blocks until

View File

@ -112,7 +112,6 @@ class Task:
self.joiners.add(task)
return await giambio.traps.join(self)
async def cancel(self):
"""
Cancels the task

View File

@ -24,8 +24,7 @@ limitations under the License.
import types
import inspect
from giambio.task import Task
from types import FunctionType
from typing import Any, Union, Iterable
from typing import Any, Union, Iterable, Coroutine, Callable
from giambio.exceptions import GiambioError
@ -49,7 +48,7 @@ async def suspend() -> Any:
return await create_trap("suspend")
async def create_task(coro: FunctionType, pool, *args):
async def create_task(coro: Callable[..., Coroutine[Any, Any, Any]], pool, *args):
"""
Spawns a new task in the current event loop from a bare coroutine
function. All extra positional arguments are passed to the function

View File

@ -15,7 +15,9 @@ async def child(ev: giambio.Event, pause: int):
await giambio.sleep(pause)
end_sleep = giambio.clock() - start_sleep
end_total = giambio.clock() - start_total
print(f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!")
print(
f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!"
)
async def parent(pause: int = 1):

View File

@ -20,13 +20,11 @@ async def consumer(q: giambio.Queue):
await giambio.sleep(1)
async def main(q: giambio.Queue, n: int):
async with giambio.create_pool() as pool:
await pool.spawn(consumer, q)
await pool.spawn(producer, q, n)
queue = giambio.Queue()
giambio.run(main, queue, 5, debugger=())

View File

@ -7,6 +7,7 @@ import time
_print = print
def print(*args, **kwargs):
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
_print(*args, **kwargs)
@ -24,7 +25,8 @@ async def test(host: str, port: int, bufsize: int = 4096):
# handshake upon connection, set the this
# parameter in the AsyncSocket class instead
do_handshake_on_connect=False,
server_hostname=host)
server_hostname=host,
)
)
print(f"Attempting a connection to {host}:{port}")
await socket.connect((host, port))
@ -34,7 +36,9 @@ async def test(host: str, port: int, bufsize: int = 4096):
async with socket:
# Closes the socket automatically
print("Entered socket context manager, sending request data")
await socket.send_all(b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""")
await socket.send_all(
b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n"""
)
print("Data sent")
buffer = b""
while not buffer.endswith(b"\r\n\r\n"):
@ -70,4 +74,3 @@ async def test(host: str, port: int, bufsize: int = 4096):
giambio.run(test, "google.com", 443, 256, debugger=())

View File

@ -14,8 +14,8 @@ async def main():
try:
async with giambio.with_timeout(5) as pool:
task = await pool.spawn(child, 2)
print(await task.join())
await giambio.sleep(5)
print(f"Child has returned: {await task.join()}")
await giambio.sleep(5) # This will trigger the timeout
except giambio.exceptions.TooSlowError:
print("[main] One or more children have timed out!")
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")