mirror of https://github.com/nocturn9x/giambio.git
Fixed nested pools (sorta?)
This commit is contained in:
parent
2cdaa231b4
commit
9bb091cdef
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||
|
||||
import types
|
||||
import giambio
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class TaskManager:
|
||||
|
@ -49,6 +49,9 @@ class TaskManager:
|
|||
# Whether our timeout expired or not
|
||||
self.timed_out: bool = False
|
||||
self._proper_init = False
|
||||
self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool
|
||||
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
||||
# giambio.get_event_loop().current_pool = self
|
||||
|
||||
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
|
||||
"""
|
||||
|
@ -56,7 +59,7 @@ class TaskManager:
|
|||
"""
|
||||
|
||||
assert self._proper_init, "Cannot use improperly initialized pool"
|
||||
return await giambio.traps.create_task(func, *args)
|
||||
return await giambio.traps.create_task(func, self, *args)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""
|
||||
|
|
|
@ -236,7 +236,7 @@ class AsyncScheduler:
|
|||
self.current_task.exc = err
|
||||
self.join(self.current_task)
|
||||
|
||||
def create_task(self, corofunc: types.FunctionType, *args, **kwargs) -> Task:
|
||||
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
|
||||
"""
|
||||
Creates a task from a coroutine function and schedules it
|
||||
to run. Any extra keyword or positional argument are then
|
||||
|
@ -247,15 +247,18 @@ class AsyncScheduler:
|
|||
:type corofunc: function
|
||||
"""
|
||||
|
||||
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), self.current_pool)
|
||||
task.next_deadline = self.current_pool.timeout or 0.0
|
||||
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
|
||||
task.next_deadline = pool.timeout or 0.0
|
||||
task.joiners = {self.current_task}
|
||||
self._data = task
|
||||
self.tasks.append(task)
|
||||
self.run_ready.append(task)
|
||||
self.debugger.on_task_spawn(task)
|
||||
self.current_pool.tasks.append(task)
|
||||
pool.tasks.append(task)
|
||||
self.reschedule_running()
|
||||
if self.current_pool and task.pool is not self.current_pool:
|
||||
self.current_pool.enclosed_pool = task.pool
|
||||
self.current_pool = task.pool
|
||||
return task
|
||||
|
||||
def run_task_step(self):
|
||||
|
@ -270,8 +273,8 @@ class AsyncScheduler:
|
|||
account, that's self.run's job!
|
||||
"""
|
||||
|
||||
# Sets the currently running task
|
||||
data = None
|
||||
# Sets the currently running task
|
||||
self.current_task = self.run_ready.pop(0)
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
if self.current_task.done():
|
||||
|
@ -293,8 +296,7 @@ class AsyncScheduler:
|
|||
# Some debugging and internal chatter here
|
||||
self.current_task.status = "run"
|
||||
self.current_task.steps += 1
|
||||
self.debugger.after_task_step(self.current_task)
|
||||
if not hasattr(self, method):
|
||||
if not hasattr(self, method) and not callable(getattr(self, method)):
|
||||
# If this happens, that's quite bad!
|
||||
# This if block is meant to be triggered by other async
|
||||
# libraries, which most likely have different trap names and behaviors
|
||||
|
@ -305,6 +307,7 @@ class AsyncScheduler:
|
|||
) from None
|
||||
# Sneaky method call, thanks to David Beazley for this ;)
|
||||
getattr(self, method)(*args)
|
||||
self.debugger.after_task_step(self.current_task)
|
||||
|
||||
def io_release_task(self, task: Task):
|
||||
"""
|
||||
|
@ -518,7 +521,10 @@ class AsyncScheduler:
|
|||
# current pool. If, however, there are still some
|
||||
# tasks running, we wait for them to exit in order
|
||||
# to avoid orphaned tasks
|
||||
return pool.done()
|
||||
if pool.enclosed_pool and self.cancel_pool(pool.enclosed_pool):
|
||||
return True
|
||||
else:
|
||||
return pool.done()
|
||||
else: # If we're at the main task, we're sure everything else exited
|
||||
return True
|
||||
|
||||
|
@ -582,6 +588,8 @@ class AsyncScheduler:
|
|||
"""
|
||||
|
||||
task.joined = True
|
||||
if task is not self.current_task:
|
||||
task.joiners.add(self.current_task)
|
||||
if task.finished or task.cancelled:
|
||||
if not task.cancelled:
|
||||
self.debugger.on_task_exit(task)
|
||||
|
@ -589,33 +597,34 @@ class AsyncScheduler:
|
|||
self.io_release_task(task)
|
||||
if task.pool is None:
|
||||
return
|
||||
if self.current_pool and self.current_pool.done():
|
||||
# If the current pool has finished executing or we're at the first parent
|
||||
if task.pool.done():
|
||||
# If the pool has finished executing or we're at the first parent
|
||||
# task that kicked the loop, we can safely reschedule the parent(s)
|
||||
self.reschedule_joiners(task)
|
||||
elif task.exc:
|
||||
task.status = "crashed"
|
||||
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
||||
# frames from the exception call stack, but for now removing at least the first one
|
||||
# seems a sensible approach (it's us catching it so we don't care about that)
|
||||
task.exc.__traceback__ = task.exc.__traceback__.tb_next
|
||||
if task.exc.__traceback__:
|
||||
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
||||
# frames from the exception call stack, but for now removing at least the first one
|
||||
# seems a sensible approach (it's us catching it so we don't care about that)
|
||||
task.exc.__traceback__ = task.exc.__traceback__.tb_next
|
||||
if task.last_io:
|
||||
self.io_release_task(task)
|
||||
self.debugger.on_exception_raised(task, task.exc)
|
||||
if task.pool is None:
|
||||
# Parent task has no pool, so we propagate
|
||||
raise
|
||||
if self.cancel_pool(self.current_pool):
|
||||
if self.cancel_pool(task.pool):
|
||||
# This will reschedule the parent(s)
|
||||
# only if all the tasks inside the current
|
||||
# only if all the tasks inside the task's
|
||||
# pool have finished executing, either
|
||||
# by cancellation, an exception
|
||||
# or just returned
|
||||
for t in task.joiners:
|
||||
for t in task.joiners.copy():
|
||||
# Propagate the exception
|
||||
try:
|
||||
t.throw(task.exc)
|
||||
except (StopIteration, CancelledError):
|
||||
except (StopIteration, CancelledError, RuntimeError):
|
||||
# TODO: Need anything else?
|
||||
task.joiners.remove(t)
|
||||
self.reschedule_joiners(task)
|
||||
|
@ -657,8 +666,7 @@ class AsyncScheduler:
|
|||
# or dangling resource open after being cancelled, so maybe we need
|
||||
# a different approach altogether
|
||||
if task.status == "io":
|
||||
for k in filter(lambda o: o.data == task, dict(self.selector.get_map()).values()):
|
||||
self.selector.unregister(k.fileobj)
|
||||
self.io_release_task(task)
|
||||
elif task.status == "sleep":
|
||||
self.paused.discard(task)
|
||||
try:
|
||||
|
@ -679,6 +687,10 @@ class AsyncScheduler:
|
|||
task.status = "cancelled"
|
||||
self.io_release_task(self.current_task)
|
||||
self.debugger.after_cancel(task)
|
||||
else:
|
||||
# If the task ignores our exception, we'll
|
||||
# raise it later again
|
||||
task.cancel_pending = True
|
||||
else:
|
||||
# If we can't cancel in a somewhat "graceful" way, we just
|
||||
# defer this operation for later (check run() for more info)
|
||||
|
|
|
@ -42,7 +42,8 @@ def get_event_loop():
|
|||
|
||||
|
||||
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
||||
"""
|
||||
""" print(hex(id(pool)))
|
||||
|
||||
Associates a new event loop to the current thread
|
||||
and deactivates the old one. This should not be
|
||||
called explicitly unless you know what you're doing.
|
||||
|
@ -92,10 +93,7 @@ def create_pool():
|
|||
Creates an async pool
|
||||
"""
|
||||
|
||||
loop = get_event_loop()
|
||||
pool = TaskManager()
|
||||
loop.current_pool = pool
|
||||
return pool
|
||||
return TaskManager()
|
||||
|
||||
|
||||
def with_timeout(timeout: int or float):
|
||||
|
@ -103,10 +101,7 @@ def with_timeout(timeout: int or float):
|
|||
Creates an async pool with an associated timeout
|
||||
"""
|
||||
|
||||
loop = get_event_loop()
|
||||
# We add 1 to make the timeout intuitive and inclusive (i.e.
|
||||
# a 10 seconds timeout means the task is allowed to run 10
|
||||
# whole seconds instead of cancelling at the tenth second)
|
||||
pool = TaskManager(timeout + 1)
|
||||
loop.current_pool = pool
|
||||
return pool
|
||||
return TaskManager(timeout + 1)
|
||||
|
|
|
@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
|||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import giambio
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Union, Coroutine, Set
|
||||
|
||||
|
@ -95,6 +96,7 @@ class Task:
|
|||
:type err: Exception
|
||||
"""
|
||||
|
||||
# self.exc = err
|
||||
return self.coroutine.throw(err)
|
||||
|
||||
async def join(self):
|
||||
|
@ -143,4 +145,5 @@ class Task:
|
|||
self.coroutine.close()
|
||||
except RuntimeError:
|
||||
pass # TODO: This is kinda bad
|
||||
assert not self.last_io, f"task {self.name} was destroyed, but has pending I/O"
|
||||
if self.last_io:
|
||||
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")
|
||||
|
|
|
@ -41,7 +41,7 @@ def create_trap(method, *args):
|
|||
return data
|
||||
|
||||
|
||||
async def create_task(coro: FunctionType, *args):
|
||||
async def create_task(coro: FunctionType, pool, *args):
|
||||
"""
|
||||
Spawns a new task in the current event loop from a bare coroutine
|
||||
function. All extra positional arguments are passed to the function
|
||||
|
@ -56,7 +56,7 @@ async def create_task(coro: FunctionType, *args):
|
|||
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)"
|
||||
)
|
||||
elif inspect.iscoroutinefunction(coro):
|
||||
return await create_trap("create_task", coro, *args)
|
||||
return await create_trap("create_task", coro, pool, *args)
|
||||
else:
|
||||
raise TypeError("coro must be a coroutine function")
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from debugger import Debugger
|
|||
async def child():
|
||||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||
await giambio.sleep(2)
|
||||
print("[child] Had a nice nap!")
|
||||
print("[child] Had a nice nap, suiciding now!")
|
||||
raise TypeError("rip") # Watch the exception magically propagate!
|
||||
|
||||
|
||||
|
@ -33,13 +33,13 @@ async def main():
|
|||
async with giambio.create_pool() as pool:
|
||||
await pool.spawn(child)
|
||||
await pool.spawn(child1)
|
||||
print("[main] Children spawned, awaiting completion")
|
||||
print("[main] First 2 children spawned, awaiting completion")
|
||||
async with giambio.create_pool() as new_pool:
|
||||
# This pool will be cancelled by the exception
|
||||
# in the other pool
|
||||
await new_pool.spawn(child2)
|
||||
await new_pool.spawn(child3)
|
||||
print("[main] 3rd child spawned")
|
||||
print("[main] Third and fourth children spawned")
|
||||
except Exception as error:
|
||||
# Because exceptions just *work*!
|
||||
print(f"[main] Exception from child caught! {repr(error)}")
|
||||
|
|
|
@ -16,6 +16,7 @@ async def main():
|
|||
async with giambio.create_pool() as a_pool:
|
||||
await a_pool.spawn(child, 3)
|
||||
await a_pool.spawn(child, 4)
|
||||
# This executes after spawning all 4 tasks
|
||||
print("[main] Children spawned, awaiting completion")
|
||||
# This will *only* execute when everything inside the async with block
|
||||
# has ran, including any other pool
|
||||
|
|
Loading…
Reference in New Issue