Fixed nested pools (partial fix; enclosed-pool tracking still needs review)

This commit is contained in:
nocturn9x 2021-08-26 16:19:40 +02:00
parent 2cdaa231b4
commit 9bb091cdef
7 changed files with 52 additions and 38 deletions

View File

@ -18,7 +18,7 @@ limitations under the License.
import types
import giambio
from typing import List
from typing import List, Optional
class TaskManager:
@ -49,6 +49,9 @@ class TaskManager:
# Whether our timeout expired or not
self.timed_out: bool = False
self._proper_init = False
self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# giambio.get_event_loop().current_pool = self
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
"""
@ -56,7 +59,7 @@ class TaskManager:
"""
assert self._proper_init, "Cannot use improperly initialized pool"
return await giambio.traps.create_task(func, *args)
return await giambio.traps.create_task(func, self, *args)
async def __aenter__(self):
"""

View File

@ -236,7 +236,7 @@ class AsyncScheduler:
self.current_task.exc = err
self.join(self.current_task)
def create_task(self, corofunc: types.FunctionType, *args, **kwargs) -> Task:
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. Any extra keyword or positional argument are then
@ -247,15 +247,18 @@ class AsyncScheduler:
:type corofunc: function
"""
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), self.current_pool)
task.next_deadline = self.current_pool.timeout or 0.0
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
task.next_deadline = pool.timeout or 0.0
task.joiners = {self.current_task}
self._data = task
self.tasks.append(task)
self.run_ready.append(task)
self.debugger.on_task_spawn(task)
self.current_pool.tasks.append(task)
pool.tasks.append(task)
self.reschedule_running()
if self.current_pool and task.pool is not self.current_pool:
self.current_pool.enclosed_pool = task.pool
self.current_pool = task.pool
return task
def run_task_step(self):
@ -270,8 +273,8 @@ class AsyncScheduler:
account, that's self.run's job!
"""
# Sets the currently running task
data = None
# Sets the currently running task
self.current_task = self.run_ready.pop(0)
self.debugger.before_task_step(self.current_task)
if self.current_task.done():
@ -293,8 +296,7 @@ class AsyncScheduler:
# Some debugging and internal chatter here
self.current_task.status = "run"
self.current_task.steps += 1
self.debugger.after_task_step(self.current_task)
if not hasattr(self, method):
if not hasattr(self, method) and not callable(getattr(self, method)):
# If this happens, that's quite bad!
# This if block is meant to be triggered by other async
# libraries, which most likely have different trap names and behaviors
@ -305,6 +307,7 @@ class AsyncScheduler:
) from None
# Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args)
self.debugger.after_task_step(self.current_task)
def io_release_task(self, task: Task):
"""
@ -518,7 +521,10 @@ class AsyncScheduler:
# current pool. If, however, there are still some
# tasks running, we wait for them to exit in order
# to avoid orphaned tasks
return pool.done()
if pool.enclosed_pool and self.cancel_pool(pool.enclosed_pool):
return True
else:
return pool.done()
else: # If we're at the main task, we're sure everything else exited
return True
@ -582,6 +588,8 @@ class AsyncScheduler:
"""
task.joined = True
if task is not self.current_task:
task.joiners.add(self.current_task)
if task.finished or task.cancelled:
if not task.cancelled:
self.debugger.on_task_exit(task)
@ -589,33 +597,34 @@ class AsyncScheduler:
self.io_release_task(task)
if task.pool is None:
return
if self.current_pool and self.current_pool.done():
# If the current pool has finished executing or we're at the first parent
if task.pool.done():
# If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s)
self.reschedule_joiners(task)
elif task.exc:
task.status = "crashed"
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first one
# seems a sensible approach (it's us catching it so we don't care about that)
task.exc.__traceback__ = task.exc.__traceback__.tb_next
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first one
# seems a sensible approach (it's us catching it so we don't care about that)
task.exc.__traceback__ = task.exc.__traceback__.tb_next
if task.last_io:
self.io_release_task(task)
self.debugger.on_exception_raised(task, task.exc)
if task.pool is None:
# Parent task has no pool, so we propagate
raise
if self.cancel_pool(self.current_pool):
if self.cancel_pool(task.pool):
# This will reschedule the parent(s)
# only if all the tasks inside the current
# only if all the tasks inside the task's
# pool have finished executing, either
# by cancellation, an exception
# or just returned
for t in task.joiners:
for t in task.joiners.copy():
# Propagate the exception
try:
t.throw(task.exc)
except (StopIteration, CancelledError):
except (StopIteration, CancelledError, RuntimeError):
# TODO: Need anything else?
task.joiners.remove(t)
self.reschedule_joiners(task)
@ -657,8 +666,7 @@ class AsyncScheduler:
# or dangling resource open after being cancelled, so maybe we need
# a different approach altogether
if task.status == "io":
for k in filter(lambda o: o.data == task, dict(self.selector.get_map()).values()):
self.selector.unregister(k.fileobj)
self.io_release_task(task)
elif task.status == "sleep":
self.paused.discard(task)
try:
@ -679,6 +687,10 @@ class AsyncScheduler:
task.status = "cancelled"
self.io_release_task(self.current_task)
self.debugger.after_cancel(task)
else:
# If the task ignores our exception, we'll
# raise it later again
task.cancel_pending = True
else:
# If we can't cancel in a somewhat "graceful" way, we just
# defer this operation for later (check run() for more info)

View File

@ -42,7 +42,8 @@ def get_event_loop():
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
"""
""" print(hex(id(pool)))
Associates a new event loop to the current thread
and deactivates the old one. This should not be
called explicitly unless you know what you're doing.
@ -92,10 +93,7 @@ def create_pool():
Creates an async pool
"""
loop = get_event_loop()
pool = TaskManager()
loop.current_pool = pool
return pool
return TaskManager()
def with_timeout(timeout: int or float):
@ -103,10 +101,7 @@ def with_timeout(timeout: int or float):
Creates an async pool with an associated timeout
"""
loop = get_event_loop()
# We add 1 to make the timeout intuitive and inclusive (i.e.
# a 10 seconds timeout means the task is allowed to run 10
# whole seconds instead of cancelling at the tenth second)
pool = TaskManager(timeout + 1)
loop.current_pool = pool
return pool
return TaskManager(timeout + 1)

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@ -17,6 +17,7 @@ limitations under the License.
"""
import giambio
import warnings
from dataclasses import dataclass, field
from typing import Union, Coroutine, Set
@ -95,6 +96,7 @@ class Task:
:type err: Exception
"""
# self.exc = err
return self.coroutine.throw(err)
async def join(self):
@ -143,4 +145,5 @@ class Task:
self.coroutine.close()
except RuntimeError:
pass # TODO: This is kinda bad
assert not self.last_io, f"task {self.name} was destroyed, but has pending I/O"
if self.last_io:
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")

View File

@ -41,7 +41,7 @@ def create_trap(method, *args):
return data
async def create_task(coro: FunctionType, *args):
async def create_task(coro: FunctionType, pool, *args):
"""
Spawns a new task in the current event loop from a bare coroutine
function. All extra positional arguments are passed to the function
@ -56,7 +56,7 @@ async def create_task(coro: FunctionType, *args):
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)"
)
elif inspect.iscoroutinefunction(coro):
return await create_trap("create_task", coro, *args)
return await create_trap("create_task", coro, pool, *args)
else:
raise TypeError("coro must be a coroutine function")

View File

@ -5,7 +5,7 @@ from debugger import Debugger
async def child():
print("[child] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2)
print("[child] Had a nice nap!")
print("[child] Had a nice nap, suiciding now!")
raise TypeError("rip") # Watch the exception magically propagate!
@ -33,13 +33,13 @@ async def main():
async with giambio.create_pool() as pool:
await pool.spawn(child)
await pool.spawn(child1)
print("[main] Children spawned, awaiting completion")
print("[main] First 2 children spawned, awaiting completion")
async with giambio.create_pool() as new_pool:
# This pool will be cancelled by the exception
# in the other pool
await new_pool.spawn(child2)
await new_pool.spawn(child3)
print("[main] 3rd child spawned")
print("[main] Third and fourth children spawned")
except Exception as error:
# Because exceptions just *work*!
print(f"[main] Exception from child caught! {repr(error)}")

View File

@ -16,6 +16,7 @@ async def main():
async with giambio.create_pool() as a_pool:
await a_pool.spawn(child, 3)
await a_pool.spawn(child, 4)
# This executes after spawning all 4 tasks
print("[main] Children spawned, awaiting completion")
# This will *only* execute when everything inside the async with block
# has ran, including any other pool