mirror of https://github.com/nocturn9x/giambio.git
Fixed some issues with join() not properly rescheduling its caller when appropriate
This commit is contained in:
parent
584f762d61
commit
5c05de495d
|
@ -45,5 +45,5 @@ __all__ = [
|
|||
"skip_after",
|
||||
"task",
|
||||
"io",
|
||||
"socket"
|
||||
"socket",
|
||||
]
|
||||
|
|
|
@ -16,9 +16,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
import types
|
||||
import giambio
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Callable, Coroutine, Any
|
||||
|
||||
|
||||
class TaskManager:
|
||||
|
@ -55,7 +54,7 @@ class TaskManager:
|
|||
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
||||
self.raise_on_timeout: bool = raise_on_timeout
|
||||
|
||||
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
|
||||
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
|
||||
"""
|
||||
Spawns a child task
|
||||
"""
|
||||
|
|
|
@ -17,13 +17,12 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
# Import libraries and internal resources
|
||||
import types
|
||||
from giambio.task import Task
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from timeit import default_timer
|
||||
from giambio.context import TaskManager
|
||||
from typing import Callable, List, Optional, Any, Dict
|
||||
from typing import Callable, List, Optional, Any, Dict, Coroutine
|
||||
from giambio.util.debug import BaseDebugger
|
||||
from giambio.internal import TimeQueue, DeadlinesQueue
|
||||
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
|
||||
|
@ -56,7 +55,7 @@ class AsyncScheduler:
|
|||
|
||||
:param clock: A callable returning monotonically increasing values at each call,
|
||||
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
|
||||
:type clock: :class: types.FunctionType
|
||||
:type clock: :class: Callable
|
||||
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
|
||||
is desired, defaults to None
|
||||
:type debugger: :class: giambio.util.BaseDebugger
|
||||
|
@ -73,7 +72,7 @@ class AsyncScheduler:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
clock: types.FunctionType = default_timer,
|
||||
clock: Callable = default_timer,
|
||||
debugger: Optional[BaseDebugger] = None,
|
||||
selector: Optional[Any] = None,
|
||||
io_skip_limit: Optional[int] = None,
|
||||
|
@ -107,7 +106,7 @@ class AsyncScheduler:
|
|||
# This will always point to the currently running coroutine (Task object)
|
||||
self.current_task: Optional[Task] = None
|
||||
# Monotonic clock to keep track of elapsed time reliably
|
||||
self.clock: types.FunctionType = clock
|
||||
self.clock: Callable = clock
|
||||
# Tasks that are asleep
|
||||
self.paused: TimeQueue = TimeQueue(self.clock)
|
||||
# Have we ever ran?
|
||||
|
@ -246,8 +245,7 @@ class AsyncScheduler:
|
|||
self.current_task.exc = err
|
||||
self.join(self.current_task)
|
||||
|
||||
|
||||
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
|
||||
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
|
||||
"""
|
||||
Creates a task from a coroutine function and schedules it
|
||||
to run. The associated pool that spawned said task is also
|
||||
|
@ -286,7 +284,6 @@ class AsyncScheduler:
|
|||
account, that's self.run's job!
|
||||
"""
|
||||
|
||||
data = None
|
||||
# Sets the currently running task
|
||||
self.current_task = self.run_ready.popleft()
|
||||
if self.current_task.done():
|
||||
|
@ -540,7 +537,7 @@ class AsyncScheduler:
|
|||
self.run_ready.append(key.data) # Resource ready? Schedule its task
|
||||
self.debugger.after_io(self.clock() - before_time)
|
||||
|
||||
def start(self, func: types.FunctionType, *args, loop: bool = True):
|
||||
def start(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, loop: bool = True):
|
||||
"""
|
||||
Starts the event loop from a sync context. If the loop parameter
|
||||
is false, the event loop will not start listening for events
|
||||
|
@ -623,16 +620,21 @@ class AsyncScheduler:
|
|||
given task, if any
|
||||
"""
|
||||
|
||||
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
|
||||
return
|
||||
for t in task.joiners:
|
||||
if t not in self.run_ready:
|
||||
# Since a task can be the parent
|
||||
# of multiple children, we need to
|
||||
# make sure we reschedule it only
|
||||
# once, otherwise a RuntimeError will
|
||||
# occur
|
||||
self.run_ready.append(t)
|
||||
self.run_ready.append(t)
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def is_pool_done(self, pool: Optional[TaskManager]):
|
||||
"""
|
||||
Returns True if a given pool has finished
|
||||
executing
|
||||
"""
|
||||
|
||||
while pool:
|
||||
if not pool.done():
|
||||
return False
|
||||
pool = pool.enclosed_pool
|
||||
return True
|
||||
|
||||
def join(self, task: Task):
|
||||
"""
|
||||
|
@ -643,6 +645,8 @@ class AsyncScheduler:
|
|||
|
||||
task.joined = True
|
||||
if task.finished or task.cancelled:
|
||||
if task in self.tasks:
|
||||
self.tasks.remove(task)
|
||||
if not task.cancelled:
|
||||
# This way join() returns the
|
||||
# task's return value
|
||||
|
@ -653,9 +657,12 @@ class AsyncScheduler:
|
|||
self.io_release_task(task)
|
||||
# If the pool has finished executing or we're at the first parent
|
||||
# task that kicked the loop, we can safely reschedule the parent(s)
|
||||
if not task.pool or task.pool.done():
|
||||
if self.is_pool_done(task.pool):
|
||||
self.reschedule_joiners(task)
|
||||
self.reschedule_running()
|
||||
elif task.exc:
|
||||
if task in self.tasks:
|
||||
self.tasks.remove(task)
|
||||
task.status = "crashed"
|
||||
if task.exc.__traceback__:
|
||||
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
||||
|
@ -676,15 +683,11 @@ class AsyncScheduler:
|
|||
# or just returned
|
||||
for t in task.joiners.copy():
|
||||
# Propagate the exception
|
||||
try:
|
||||
t.throw(task.exc)
|
||||
except (StopIteration, CancelledError, RuntimeError):
|
||||
# TODO: Need anything else?
|
||||
self.handle_task_exit(t, partial(t.throw, task.exc))
|
||||
if t.exc or t.finished or t.cancelled:
|
||||
task.joiners.remove(t)
|
||||
finally:
|
||||
if t in self.tasks:
|
||||
self.tasks.remove(t)
|
||||
self.reschedule_joiners(task)
|
||||
self.reschedule_running()
|
||||
|
||||
def sleep(self, seconds: int or float):
|
||||
"""
|
||||
|
|
|
@ -44,6 +44,13 @@ class TimeQueue:
|
|||
self.sequence = 0
|
||||
self.container: List[Tuple[float, int, Task]] = []
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns len(self)
|
||||
"""
|
||||
|
||||
return len(self.container)
|
||||
|
||||
def __contains__(self, item: Task):
|
||||
"""
|
||||
Implements item in self. This method behaves
|
||||
|
@ -263,6 +270,13 @@ class DeadlinesQueue:
|
|||
|
||||
return f"DeadlinesQueue({self.container})"
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Returns len(self)
|
||||
"""
|
||||
|
||||
return len(self.container)
|
||||
|
||||
def put(self, pool: "giambio.context.TaskManager"):
|
||||
"""
|
||||
Pushes a pool with its deadline onto the queue. The
|
||||
|
|
|
@ -18,12 +18,13 @@ limitations under the License.
|
|||
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Callable, Coroutine, Any, Union
|
||||
|
||||
from giambio.core import AsyncScheduler
|
||||
from giambio.exceptions import GiambioError
|
||||
from giambio.context import TaskManager
|
||||
from timeit import default_timer
|
||||
from giambio.util.debug import BaseDebugger
|
||||
from types import FunctionType
|
||||
|
||||
|
||||
thread_local = threading.local()
|
||||
|
@ -41,7 +42,7 @@ def get_event_loop():
|
|||
raise GiambioError("giambio is not running") from None
|
||||
|
||||
|
||||
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
||||
def new_event_loop(debugger: BaseDebugger, clock: Callable):
|
||||
"""
|
||||
Associates a new event loop to the current thread
|
||||
and deactivates the old one. This should not be
|
||||
|
@ -62,7 +63,7 @@ def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
|||
thread_local.loop = AsyncScheduler(clock, debugger)
|
||||
|
||||
|
||||
def run(func: FunctionType, *args, **kwargs):
|
||||
def run(func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs):
|
||||
"""
|
||||
Starts the event loop from a synchronous entry point
|
||||
"""
|
||||
|
@ -95,23 +96,16 @@ def create_pool():
|
|||
return TaskManager()
|
||||
|
||||
|
||||
def with_timeout(timeout: int or float):
|
||||
def with_timeout(timeout: Union[int, float]):
|
||||
"""
|
||||
Creates an async pool with an associated timeout
|
||||
"""
|
||||
|
||||
assert timeout > 0, "The timeout must be greater than 0"
|
||||
mgr = TaskManager(timeout)
|
||||
loop = get_event_loop()
|
||||
if loop.current_task.pool is None:
|
||||
loop.current_pool = mgr
|
||||
loop.current_task.pool = mgr
|
||||
loop.current_task.next_deadline = mgr.timeout or 0.0
|
||||
loop.deadlines.put(mgr)
|
||||
return mgr
|
||||
return TaskManager(timeout)
|
||||
|
||||
|
||||
def skip_after(timeout: int or float):
|
||||
def skip_after(timeout: Union[int, float]):
|
||||
"""
|
||||
Creates an async pool with an associated timeout, but
|
||||
without raising a TooSlowError exception. The pool
|
||||
|
@ -119,11 +113,4 @@ def skip_after(timeout: int or float):
|
|||
"""
|
||||
|
||||
assert timeout > 0, "The timeout must be greater than 0"
|
||||
mgr = TaskManager(timeout, False)
|
||||
loop = get_event_loop()
|
||||
if loop.current_task.pool is None:
|
||||
loop.current_pool = mgr
|
||||
loop.current_task.pool = mgr
|
||||
loop.current_task.next_deadline = mgr.timeout or 0.0
|
||||
loop.deadlines.put(mgr)
|
||||
return mgr
|
||||
return TaskManager(timeout, False)
|
||||
|
|
|
@ -73,7 +73,6 @@ class Queue:
|
|||
self.putters = deque()
|
||||
self.container = deque(maxlen=maxsize)
|
||||
|
||||
|
||||
async def put(self, item: Any):
|
||||
"""
|
||||
Pushes an element onto the queue. If the
|
||||
|
@ -89,7 +88,6 @@ class Queue:
|
|||
self.putters.append(Event())
|
||||
await self.putters[-1].wait()
|
||||
|
||||
|
||||
async def get(self) -> Any:
|
||||
"""
|
||||
Pops an element off the queue. Blocks until
|
||||
|
|
|
@ -112,7 +112,6 @@ class Task:
|
|||
self.joiners.add(task)
|
||||
return await giambio.traps.join(self)
|
||||
|
||||
|
||||
async def cancel(self):
|
||||
"""
|
||||
Cancels the task
|
||||
|
|
|
@ -24,8 +24,7 @@ limitations under the License.
|
|||
import types
|
||||
import inspect
|
||||
from giambio.task import Task
|
||||
from types import FunctionType
|
||||
from typing import Any, Union, Iterable
|
||||
from typing import Any, Union, Iterable, Coroutine, Callable
|
||||
from giambio.exceptions import GiambioError
|
||||
|
||||
|
||||
|
@ -49,7 +48,7 @@ async def suspend() -> Any:
|
|||
return await create_trap("suspend")
|
||||
|
||||
|
||||
async def create_task(coro: FunctionType, pool, *args):
|
||||
async def create_task(coro: Callable[..., Coroutine[Any, Any, Any]], pool, *args):
|
||||
"""
|
||||
Spawns a new task in the current event loop from a bare coroutine
|
||||
function. All extra positional arguments are passed to the function
|
||||
|
|
|
@ -23,15 +23,15 @@ async def serve(bind_address: tuple):
|
|||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with giambio.create_pool() as pool:
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients.append(conn)
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients.append(conn)
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: AsyncSocket, client_address: tuple):
|
||||
|
|
|
@ -20,14 +20,14 @@ async def serve(bind_address: tuple):
|
|||
logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
|
||||
async with giambio.create_pool() as pool:
|
||||
async with sock:
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, address_tuple)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: AsyncSocket, client_address: tuple):
|
||||
|
|
|
@ -15,7 +15,9 @@ async def child(ev: giambio.Event, pause: int):
|
|||
await giambio.sleep(pause)
|
||||
end_sleep = giambio.clock() - start_sleep
|
||||
end_total = giambio.clock() - start_total
|
||||
print(f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!")
|
||||
print(
|
||||
f"[child] Done! Slept for {end_total:.2f} seconds total ({end_pause:.2f} waiting, {end_sleep:.2f} sleeping), nice nap!"
|
||||
)
|
||||
|
||||
|
||||
async def parent(pause: int = 1):
|
||||
|
|
|
@ -20,13 +20,11 @@ async def consumer(q: giambio.Queue):
|
|||
await giambio.sleep(1)
|
||||
|
||||
|
||||
|
||||
async def main(q: giambio.Queue, n: int):
|
||||
async with giambio.create_pool() as pool:
|
||||
await pool.spawn(consumer, q)
|
||||
await pool.spawn(producer, q, n)
|
||||
|
||||
|
||||
|
||||
queue = giambio.Queue()
|
||||
giambio.run(main, queue, 5, debugger=())
|
||||
|
|
|
@ -7,6 +7,7 @@ import time
|
|||
|
||||
_print = print
|
||||
|
||||
|
||||
def print(*args, **kwargs):
|
||||
sys.stdout.write(f"[{time.strftime('%H:%M:%S')}] ")
|
||||
_print(*args, **kwargs)
|
||||
|
@ -14,18 +15,19 @@ def print(*args, **kwargs):
|
|||
|
||||
async def test(host: str, port: int, bufsize: int = 4096):
|
||||
socket = giambio.socket.wrap_socket(
|
||||
ssl.create_default_context().wrap_socket(
|
||||
sock=sock.socket(),
|
||||
# Note: do_handshake_on_connect MUST
|
||||
# be set to False on the synchronous socket!
|
||||
# Giambio handles the TLS handshake asynchronously
|
||||
# and making the SSL library handle it blocks
|
||||
# the entire event loop. To perform the TLS
|
||||
# handshake upon connection, set the this
|
||||
# parameter in the AsyncSocket class instead
|
||||
do_handshake_on_connect=False,
|
||||
server_hostname=host)
|
||||
)
|
||||
ssl.create_default_context().wrap_socket(
|
||||
sock=sock.socket(),
|
||||
# Note: do_handshake_on_connect MUST
|
||||
# be set to False on the synchronous socket!
|
||||
# Giambio handles the TLS handshake asynchronously
|
||||
# and making the SSL library handle it blocks
|
||||
# the entire event loop. To perform the TLS
|
||||
# handshake upon connection, set the this
|
||||
# parameter in the AsyncSocket class instead
|
||||
do_handshake_on_connect=False,
|
||||
server_hostname=host,
|
||||
)
|
||||
)
|
||||
print(f"Attempting a connection to {host}:{port}")
|
||||
await socket.connect((host, port))
|
||||
print("Connected")
|
||||
|
@ -34,18 +36,20 @@ async def test(host: str, port: int, bufsize: int = 4096):
|
|||
async with socket:
|
||||
# Closes the socket automatically
|
||||
print("Entered socket context manager, sending request data")
|
||||
await socket.send_all(b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n""")
|
||||
await socket.send_all(
|
||||
b"""GET / HTTP/1.1\r\nHost: google.com\r\nUser-Agent: owo\r\nAccept: text/html\r\nConnection: keep-alive\r\nAccept: */*\r\n\r\n"""
|
||||
)
|
||||
print("Data sent")
|
||||
buffer = b""
|
||||
while not buffer.endswith(b"\r\n\r\n"):
|
||||
print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})")
|
||||
data = await socket.receive(bufsize)
|
||||
print(f"Received {len(data)} bytes")
|
||||
if data:
|
||||
buffer += data
|
||||
else:
|
||||
print("Received empty stream, closing connection")
|
||||
break
|
||||
print(f"Requesting up to {bufsize} bytes (current response size: {len(buffer)})")
|
||||
data = await socket.receive(bufsize)
|
||||
print(f"Received {len(data)} bytes")
|
||||
if data:
|
||||
buffer += data
|
||||
else:
|
||||
print("Received empty stream, closing connection")
|
||||
break
|
||||
print(f"Request has{' not' if not p.timed_out else ''} timed out!")
|
||||
if buffer:
|
||||
data = buffer.decode().split("\r\n")
|
||||
|
@ -70,4 +74,3 @@ async def test(host: str, port: int, bufsize: int = 4096):
|
|||
|
||||
|
||||
giambio.run(test, "google.com", 443, 256, debugger=())
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ async def main():
|
|||
try:
|
||||
async with giambio.with_timeout(12) as pool:
|
||||
await pool.spawn(child, 7) # This will complete
|
||||
await giambio.sleep(2) # This will make the code below wait 2 seconds
|
||||
await giambio.sleep(2) # This will make the code below wait 2 seconds
|
||||
await pool.spawn(child, 15) # This will not complete
|
||||
await giambio.sleep(50)
|
||||
await child(20) # Neither will this
|
||||
|
|
|
@ -14,8 +14,8 @@ async def main():
|
|||
try:
|
||||
async with giambio.with_timeout(5) as pool:
|
||||
task = await pool.spawn(child, 2)
|
||||
print(await task.join())
|
||||
await giambio.sleep(5)
|
||||
print(f"Child has returned: {await task.join()}")
|
||||
await giambio.sleep(5) # This will trigger the timeout
|
||||
except giambio.exceptions.TooSlowError:
|
||||
print("[main] One or more children have timed out!")
|
||||
print(f"[main] Children execution complete in {giambio.clock() - start:.2f} seconds")
|
||||
|
|
Loading…
Reference in New Issue