Fixed nested pools (sorta?)

This commit is contained in:
nocturn9x 2021-08-26 16:19:40 +02:00
parent 2cdaa231b4
commit 9bb091cdef
7 changed files with 52 additions and 38 deletions

View File

@ -18,7 +18,7 @@ limitations under the License.
import types import types
import giambio import giambio
from typing import List from typing import List, Optional
class TaskManager: class TaskManager:
@ -49,6 +49,9 @@ class TaskManager:
# Whether our timeout expired or not # Whether our timeout expired or not
self.timed_out: bool = False self.timed_out: bool = False
self._proper_init = False self._proper_init = False
self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# giambio.get_event_loop().current_pool = self
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task": async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
""" """
@ -56,7 +59,7 @@ class TaskManager:
""" """
assert self._proper_init, "Cannot use improperly initialized pool" assert self._proper_init, "Cannot use improperly initialized pool"
return await giambio.traps.create_task(func, *args) return await giambio.traps.create_task(func, self, *args)
async def __aenter__(self): async def __aenter__(self):
""" """

View File

@ -236,7 +236,7 @@ class AsyncScheduler:
self.current_task.exc = err self.current_task.exc = err
self.join(self.current_task) self.join(self.current_task)
def create_task(self, corofunc: types.FunctionType, *args, **kwargs) -> Task: def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
""" """
Creates a task from a coroutine function and schedules it Creates a task from a coroutine function and schedules it
to run. Any extra keyword or positional argument are then to run. Any extra keyword or positional argument are then
@ -247,15 +247,18 @@ class AsyncScheduler:
:type corofunc: function :type corofunc: function
""" """
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), self.current_pool) task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
task.next_deadline = self.current_pool.timeout or 0.0 task.next_deadline = pool.timeout or 0.0
task.joiners = {self.current_task} task.joiners = {self.current_task}
self._data = task self._data = task
self.tasks.append(task) self.tasks.append(task)
self.run_ready.append(task) self.run_ready.append(task)
self.debugger.on_task_spawn(task) self.debugger.on_task_spawn(task)
self.current_pool.tasks.append(task) pool.tasks.append(task)
self.reschedule_running() self.reschedule_running()
if self.current_pool and task.pool is not self.current_pool:
self.current_pool.enclosed_pool = task.pool
self.current_pool = task.pool
return task return task
def run_task_step(self): def run_task_step(self):
@ -270,8 +273,8 @@ class AsyncScheduler:
account, that's self.run's job! account, that's self.run's job!
""" """
# Sets the currently running task
data = None data = None
# Sets the currently running task
self.current_task = self.run_ready.pop(0) self.current_task = self.run_ready.pop(0)
self.debugger.before_task_step(self.current_task) self.debugger.before_task_step(self.current_task)
if self.current_task.done(): if self.current_task.done():
@ -293,8 +296,7 @@ class AsyncScheduler:
# Some debugging and internal chatter here # Some debugging and internal chatter here
self.current_task.status = "run" self.current_task.status = "run"
self.current_task.steps += 1 self.current_task.steps += 1
self.debugger.after_task_step(self.current_task) if not hasattr(self, method) and not callable(getattr(self, method)):
if not hasattr(self, method):
# If this happens, that's quite bad! # If this happens, that's quite bad!
# This if block is meant to be triggered by other async # This if block is meant to be triggered by other async
# libraries, which most likely have different trap names and behaviors # libraries, which most likely have different trap names and behaviors
@ -305,6 +307,7 @@ class AsyncScheduler:
) from None ) from None
# Sneaky method call, thanks to David Beazley for this ;) # Sneaky method call, thanks to David Beazley for this ;)
getattr(self, method)(*args) getattr(self, method)(*args)
self.debugger.after_task_step(self.current_task)
def io_release_task(self, task: Task): def io_release_task(self, task: Task):
""" """
@ -518,7 +521,10 @@ class AsyncScheduler:
# current pool. If, however, there are still some # current pool. If, however, there are still some
# tasks running, we wait for them to exit in order # tasks running, we wait for them to exit in order
# to avoid orphaned tasks # to avoid orphaned tasks
return pool.done() if pool.enclosed_pool and self.cancel_pool(pool.enclosed_pool):
return True
else:
return pool.done()
else: # If we're at the main task, we're sure everything else exited else: # If we're at the main task, we're sure everything else exited
return True return True
@ -582,6 +588,8 @@ class AsyncScheduler:
""" """
task.joined = True task.joined = True
if task is not self.current_task:
task.joiners.add(self.current_task)
if task.finished or task.cancelled: if task.finished or task.cancelled:
if not task.cancelled: if not task.cancelled:
self.debugger.on_task_exit(task) self.debugger.on_task_exit(task)
@ -589,33 +597,34 @@ class AsyncScheduler:
self.io_release_task(task) self.io_release_task(task)
if task.pool is None: if task.pool is None:
return return
if self.current_pool and self.current_pool.done(): if task.pool.done():
# If the current pool has finished executing or we're at the first parent # If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s) # task that kicked the loop, we can safely reschedule the parent(s)
self.reschedule_joiners(task) self.reschedule_joiners(task)
elif task.exc: elif task.exc:
task.status = "crashed" task.status = "crashed"
# TODO: We might want to do a bit more complex traceback hacking to remove any extra if task.exc.__traceback__:
# frames from the exception call stack, but for now removing at least the first one # TODO: We might want to do a bit more complex traceback hacking to remove any extra
# seems a sensible approach (it's us catching it so we don't care about that) # frames from the exception call stack, but for now removing at least the first one
task.exc.__traceback__ = task.exc.__traceback__.tb_next # seems a sensible approach (it's us catching it so we don't care about that)
task.exc.__traceback__ = task.exc.__traceback__.tb_next
if task.last_io: if task.last_io:
self.io_release_task(task) self.io_release_task(task)
self.debugger.on_exception_raised(task, task.exc) self.debugger.on_exception_raised(task, task.exc)
if task.pool is None: if task.pool is None:
# Parent task has no pool, so we propagate # Parent task has no pool, so we propagate
raise raise
if self.cancel_pool(self.current_pool): if self.cancel_pool(task.pool):
# This will reschedule the parent(s) # This will reschedule the parent(s)
# only if all the tasks inside the current # only if all the tasks inside the task's
# pool have finished executing, either # pool have finished executing, either
# by cancellation, an exception # by cancellation, an exception
# or just returned # or just returned
for t in task.joiners: for t in task.joiners.copy():
# Propagate the exception # Propagate the exception
try: try:
t.throw(task.exc) t.throw(task.exc)
except (StopIteration, CancelledError): except (StopIteration, CancelledError, RuntimeError):
# TODO: Need anything else? # TODO: Need anything else?
task.joiners.remove(t) task.joiners.remove(t)
self.reschedule_joiners(task) self.reschedule_joiners(task)
@ -657,8 +666,7 @@ class AsyncScheduler:
# or dangling resource open after being cancelled, so maybe we need # or dangling resource open after being cancelled, so maybe we need
# a different approach altogether # a different approach altogether
if task.status == "io": if task.status == "io":
for k in filter(lambda o: o.data == task, dict(self.selector.get_map()).values()): self.io_release_task(task)
self.selector.unregister(k.fileobj)
elif task.status == "sleep": elif task.status == "sleep":
self.paused.discard(task) self.paused.discard(task)
try: try:
@ -679,6 +687,10 @@ class AsyncScheduler:
task.status = "cancelled" task.status = "cancelled"
self.io_release_task(self.current_task) self.io_release_task(self.current_task)
self.debugger.after_cancel(task) self.debugger.after_cancel(task)
else:
# If the task ignores our exception, we'll
# raise it later again
task.cancel_pending = True
else: else:
# If we can't cancel in a somewhat "graceful" way, we just # If we can't cancel in a somewhat "graceful" way, we just
# defer this operation for later (check run() for more info) # defer this operation for later (check run() for more info)

View File

@ -42,7 +42,8 @@ def get_event_loop():
def new_event_loop(debugger: BaseDebugger, clock: FunctionType): def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
""" """ print(hex(id(pool)))
Associates a new event loop to the current thread Associates a new event loop to the current thread
and deactivates the old one. This should not be and deactivates the old one. This should not be
called explicitly unless you know what you're doing. called explicitly unless you know what you're doing.
@ -92,10 +93,7 @@ def create_pool():
Creates an async pool Creates an async pool
""" """
loop = get_event_loop() return TaskManager()
pool = TaskManager()
loop.current_pool = pool
return pool
def with_timeout(timeout: int or float): def with_timeout(timeout: int or float):
@ -103,10 +101,7 @@ def with_timeout(timeout: int or float):
Creates an async pool with an associated timeout Creates an async pool with an associated timeout
""" """
loop = get_event_loop()
# We add 1 to make the timeout intuitive and inclusive (i.e. # We add 1 to make the timeout intuitive and inclusive (i.e.
# a 10 seconds timeout means the task is allowed to run 10 # a 10 seconds timeout means the task is allowed to run 10
# whole seconds instead of cancelling at the tenth second) # whole seconds instead of cancelling at the tenth second)
pool = TaskManager(timeout + 1) return TaskManager(timeout + 1)
loop.current_pool = pool
return pool

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
@ -17,6 +17,7 @@ limitations under the License.
""" """
import giambio import giambio
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Union, Coroutine, Set from typing import Union, Coroutine, Set
@ -95,6 +96,7 @@ class Task:
:type err: Exception :type err: Exception
""" """
# self.exc = err
return self.coroutine.throw(err) return self.coroutine.throw(err)
async def join(self): async def join(self):
@ -143,4 +145,5 @@ class Task:
self.coroutine.close() self.coroutine.close()
except RuntimeError: except RuntimeError:
pass # TODO: This is kinda bad pass # TODO: This is kinda bad
assert not self.last_io, f"task {self.name} was destroyed, but has pending I/O" if self.last_io:
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")

View File

@ -41,7 +41,7 @@ def create_trap(method, *args):
return data return data
async def create_task(coro: FunctionType, *args): async def create_task(coro: FunctionType, pool, *args):
""" """
Spawns a new task in the current event loop from a bare coroutine Spawns a new task in the current event loop from a bare coroutine
function. All extra positional arguments are passed to the function function. All extra positional arguments are passed to the function
@ -56,7 +56,7 @@ async def create_task(coro: FunctionType, *args):
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)" "\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)"
) )
elif inspect.iscoroutinefunction(coro): elif inspect.iscoroutinefunction(coro):
return await create_trap("create_task", coro, *args) return await create_trap("create_task", coro, pool, *args)
else: else:
raise TypeError("coro must be a coroutine function") raise TypeError("coro must be a coroutine function")

View File

@ -5,7 +5,7 @@ from debugger import Debugger
async def child(): async def child():
print("[child] Child spawned!! Sleeping for 2 seconds") print("[child] Child spawned!! Sleeping for 2 seconds")
await giambio.sleep(2) await giambio.sleep(2)
print("[child] Had a nice nap!") print("[child] Had a nice nap, suiciding now!")
raise TypeError("rip") # Watch the exception magically propagate! raise TypeError("rip") # Watch the exception magically propagate!
@ -33,13 +33,13 @@ async def main():
async with giambio.create_pool() as pool: async with giambio.create_pool() as pool:
await pool.spawn(child) await pool.spawn(child)
await pool.spawn(child1) await pool.spawn(child1)
print("[main] Children spawned, awaiting completion") print("[main] First 2 children spawned, awaiting completion")
async with giambio.create_pool() as new_pool: async with giambio.create_pool() as new_pool:
# This pool will be cancelled by the exception # This pool will be cancelled by the exception
# in the other pool # in the other pool
await new_pool.spawn(child2) await new_pool.spawn(child2)
await new_pool.spawn(child3) await new_pool.spawn(child3)
print("[main] 3rd child spawned") print("[main] Third and fourth children spawned")
except Exception as error: except Exception as error:
# Because exceptions just *work*! # Because exceptions just *work*!
print(f"[main] Exception from child caught! {repr(error)}") print(f"[main] Exception from child caught! {repr(error)}")

View File

@ -16,6 +16,7 @@ async def main():
async with giambio.create_pool() as a_pool: async with giambio.create_pool() as a_pool:
await a_pool.spawn(child, 3) await a_pool.spawn(child, 3)
await a_pool.spawn(child, 4) await a_pool.spawn(child, 4)
# This executes after spawning all 4 tasks
print("[main] Children spawned, awaiting completion") print("[main] Children spawned, awaiting completion")
# This will *only* execute when everything inside the async with block # This will *only* execute when everything inside the async with block
# has ran, including any other pool # has ran, including any other pool