mirror of https://github.com/nocturn9x/giambio.git
Fixed nested pools (sorta?)
This commit is contained in:
parent
2cdaa231b4
commit
9bb091cdef
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
import types
|
import types
|
||||||
import giambio
|
import giambio
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
|
@ -49,6 +49,9 @@ class TaskManager:
|
||||||
# Whether our timeout expired or not
|
# Whether our timeout expired or not
|
||||||
self.timed_out: bool = False
|
self.timed_out: bool = False
|
||||||
self._proper_init = False
|
self._proper_init = False
|
||||||
|
self.enclosing_pool: Optional["giambio.context.TaskManager"] = giambio.get_event_loop().current_pool
|
||||||
|
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
||||||
|
# giambio.get_event_loop().current_pool = self
|
||||||
|
|
||||||
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
|
async def spawn(self, func: types.FunctionType, *args) -> "giambio.task.Task":
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +59,7 @@ class TaskManager:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self._proper_init, "Cannot use improperly initialized pool"
|
assert self._proper_init, "Cannot use improperly initialized pool"
|
||||||
return await giambio.traps.create_task(func, *args)
|
return await giambio.traps.create_task(func, self, *args)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -236,7 +236,7 @@ class AsyncScheduler:
|
||||||
self.current_task.exc = err
|
self.current_task.exc = err
|
||||||
self.join(self.current_task)
|
self.join(self.current_task)
|
||||||
|
|
||||||
def create_task(self, corofunc: types.FunctionType, *args, **kwargs) -> Task:
|
def create_task(self, corofunc: types.FunctionType, pool, *args, **kwargs) -> Task:
|
||||||
"""
|
"""
|
||||||
Creates a task from a coroutine function and schedules it
|
Creates a task from a coroutine function and schedules it
|
||||||
to run. Any extra keyword or positional argument are then
|
to run. Any extra keyword or positional argument are then
|
||||||
|
@ -247,15 +247,18 @@ class AsyncScheduler:
|
||||||
:type corofunc: function
|
:type corofunc: function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), self.current_pool)
|
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
|
||||||
task.next_deadline = self.current_pool.timeout or 0.0
|
task.next_deadline = pool.timeout or 0.0
|
||||||
task.joiners = {self.current_task}
|
task.joiners = {self.current_task}
|
||||||
self._data = task
|
self._data = task
|
||||||
self.tasks.append(task)
|
self.tasks.append(task)
|
||||||
self.run_ready.append(task)
|
self.run_ready.append(task)
|
||||||
self.debugger.on_task_spawn(task)
|
self.debugger.on_task_spawn(task)
|
||||||
self.current_pool.tasks.append(task)
|
pool.tasks.append(task)
|
||||||
self.reschedule_running()
|
self.reschedule_running()
|
||||||
|
if self.current_pool and task.pool is not self.current_pool:
|
||||||
|
self.current_pool.enclosed_pool = task.pool
|
||||||
|
self.current_pool = task.pool
|
||||||
return task
|
return task
|
||||||
|
|
||||||
def run_task_step(self):
|
def run_task_step(self):
|
||||||
|
@ -270,8 +273,8 @@ class AsyncScheduler:
|
||||||
account, that's self.run's job!
|
account, that's self.run's job!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sets the currently running task
|
|
||||||
data = None
|
data = None
|
||||||
|
# Sets the currently running task
|
||||||
self.current_task = self.run_ready.pop(0)
|
self.current_task = self.run_ready.pop(0)
|
||||||
self.debugger.before_task_step(self.current_task)
|
self.debugger.before_task_step(self.current_task)
|
||||||
if self.current_task.done():
|
if self.current_task.done():
|
||||||
|
@ -293,8 +296,7 @@ class AsyncScheduler:
|
||||||
# Some debugging and internal chatter here
|
# Some debugging and internal chatter here
|
||||||
self.current_task.status = "run"
|
self.current_task.status = "run"
|
||||||
self.current_task.steps += 1
|
self.current_task.steps += 1
|
||||||
self.debugger.after_task_step(self.current_task)
|
if not hasattr(self, method) and not callable(getattr(self, method)):
|
||||||
if not hasattr(self, method):
|
|
||||||
# If this happens, that's quite bad!
|
# If this happens, that's quite bad!
|
||||||
# This if block is meant to be triggered by other async
|
# This if block is meant to be triggered by other async
|
||||||
# libraries, which most likely have different trap names and behaviors
|
# libraries, which most likely have different trap names and behaviors
|
||||||
|
@ -305,6 +307,7 @@ class AsyncScheduler:
|
||||||
) from None
|
) from None
|
||||||
# Sneaky method call, thanks to David Beazley for this ;)
|
# Sneaky method call, thanks to David Beazley for this ;)
|
||||||
getattr(self, method)(*args)
|
getattr(self, method)(*args)
|
||||||
|
self.debugger.after_task_step(self.current_task)
|
||||||
|
|
||||||
def io_release_task(self, task: Task):
|
def io_release_task(self, task: Task):
|
||||||
"""
|
"""
|
||||||
|
@ -518,7 +521,10 @@ class AsyncScheduler:
|
||||||
# current pool. If, however, there are still some
|
# current pool. If, however, there are still some
|
||||||
# tasks running, we wait for them to exit in order
|
# tasks running, we wait for them to exit in order
|
||||||
# to avoid orphaned tasks
|
# to avoid orphaned tasks
|
||||||
return pool.done()
|
if pool.enclosed_pool and self.cancel_pool(pool.enclosed_pool):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return pool.done()
|
||||||
else: # If we're at the main task, we're sure everything else exited
|
else: # If we're at the main task, we're sure everything else exited
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -582,6 +588,8 @@ class AsyncScheduler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task.joined = True
|
task.joined = True
|
||||||
|
if task is not self.current_task:
|
||||||
|
task.joiners.add(self.current_task)
|
||||||
if task.finished or task.cancelled:
|
if task.finished or task.cancelled:
|
||||||
if not task.cancelled:
|
if not task.cancelled:
|
||||||
self.debugger.on_task_exit(task)
|
self.debugger.on_task_exit(task)
|
||||||
|
@ -589,33 +597,34 @@ class AsyncScheduler:
|
||||||
self.io_release_task(task)
|
self.io_release_task(task)
|
||||||
if task.pool is None:
|
if task.pool is None:
|
||||||
return
|
return
|
||||||
if self.current_pool and self.current_pool.done():
|
if task.pool.done():
|
||||||
# If the current pool has finished executing or we're at the first parent
|
# If the pool has finished executing or we're at the first parent
|
||||||
# task that kicked the loop, we can safely reschedule the parent(s)
|
# task that kicked the loop, we can safely reschedule the parent(s)
|
||||||
self.reschedule_joiners(task)
|
self.reschedule_joiners(task)
|
||||||
elif task.exc:
|
elif task.exc:
|
||||||
task.status = "crashed"
|
task.status = "crashed"
|
||||||
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
if task.exc.__traceback__:
|
||||||
# frames from the exception call stack, but for now removing at least the first one
|
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
||||||
# seems a sensible approach (it's us catching it so we don't care about that)
|
# frames from the exception call stack, but for now removing at least the first one
|
||||||
task.exc.__traceback__ = task.exc.__traceback__.tb_next
|
# seems a sensible approach (it's us catching it so we don't care about that)
|
||||||
|
task.exc.__traceback__ = task.exc.__traceback__.tb_next
|
||||||
if task.last_io:
|
if task.last_io:
|
||||||
self.io_release_task(task)
|
self.io_release_task(task)
|
||||||
self.debugger.on_exception_raised(task, task.exc)
|
self.debugger.on_exception_raised(task, task.exc)
|
||||||
if task.pool is None:
|
if task.pool is None:
|
||||||
# Parent task has no pool, so we propagate
|
# Parent task has no pool, so we propagate
|
||||||
raise
|
raise
|
||||||
if self.cancel_pool(self.current_pool):
|
if self.cancel_pool(task.pool):
|
||||||
# This will reschedule the parent(s)
|
# This will reschedule the parent(s)
|
||||||
# only if all the tasks inside the current
|
# only if all the tasks inside the task's
|
||||||
# pool have finished executing, either
|
# pool have finished executing, either
|
||||||
# by cancellation, an exception
|
# by cancellation, an exception
|
||||||
# or just returned
|
# or just returned
|
||||||
for t in task.joiners:
|
for t in task.joiners.copy():
|
||||||
# Propagate the exception
|
# Propagate the exception
|
||||||
try:
|
try:
|
||||||
t.throw(task.exc)
|
t.throw(task.exc)
|
||||||
except (StopIteration, CancelledError):
|
except (StopIteration, CancelledError, RuntimeError):
|
||||||
# TODO: Need anything else?
|
# TODO: Need anything else?
|
||||||
task.joiners.remove(t)
|
task.joiners.remove(t)
|
||||||
self.reschedule_joiners(task)
|
self.reschedule_joiners(task)
|
||||||
|
@ -657,8 +666,7 @@ class AsyncScheduler:
|
||||||
# or dangling resource open after being cancelled, so maybe we need
|
# or dangling resource open after being cancelled, so maybe we need
|
||||||
# a different approach altogether
|
# a different approach altogether
|
||||||
if task.status == "io":
|
if task.status == "io":
|
||||||
for k in filter(lambda o: o.data == task, dict(self.selector.get_map()).values()):
|
self.io_release_task(task)
|
||||||
self.selector.unregister(k.fileobj)
|
|
||||||
elif task.status == "sleep":
|
elif task.status == "sleep":
|
||||||
self.paused.discard(task)
|
self.paused.discard(task)
|
||||||
try:
|
try:
|
||||||
|
@ -679,6 +687,10 @@ class AsyncScheduler:
|
||||||
task.status = "cancelled"
|
task.status = "cancelled"
|
||||||
self.io_release_task(self.current_task)
|
self.io_release_task(self.current_task)
|
||||||
self.debugger.after_cancel(task)
|
self.debugger.after_cancel(task)
|
||||||
|
else:
|
||||||
|
# If the task ignores our exception, we'll
|
||||||
|
# raise it later again
|
||||||
|
task.cancel_pending = True
|
||||||
else:
|
else:
|
||||||
# If we can't cancel in a somewhat "graceful" way, we just
|
# If we can't cancel in a somewhat "graceful" way, we just
|
||||||
# defer this operation for later (check run() for more info)
|
# defer this operation for later (check run() for more info)
|
||||||
|
|
|
@ -42,7 +42,8 @@ def get_event_loop():
|
||||||
|
|
||||||
|
|
||||||
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
def new_event_loop(debugger: BaseDebugger, clock: FunctionType):
|
||||||
"""
|
""" print(hex(id(pool)))
|
||||||
|
|
||||||
Associates a new event loop to the current thread
|
Associates a new event loop to the current thread
|
||||||
and deactivates the old one. This should not be
|
and deactivates the old one. This should not be
|
||||||
called explicitly unless you know what you're doing.
|
called explicitly unless you know what you're doing.
|
||||||
|
@ -92,10 +93,7 @@ def create_pool():
|
||||||
Creates an async pool
|
Creates an async pool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loop = get_event_loop()
|
return TaskManager()
|
||||||
pool = TaskManager()
|
|
||||||
loop.current_pool = pool
|
|
||||||
return pool
|
|
||||||
|
|
||||||
|
|
||||||
def with_timeout(timeout: int or float):
|
def with_timeout(timeout: int or float):
|
||||||
|
@ -103,10 +101,7 @@ def with_timeout(timeout: int or float):
|
||||||
Creates an async pool with an associated timeout
|
Creates an async pool with an associated timeout
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loop = get_event_loop()
|
|
||||||
# We add 1 to make the timeout intuitive and inclusive (i.e.
|
# We add 1 to make the timeout intuitive and inclusive (i.e.
|
||||||
# a 10 seconds timeout means the task is allowed to run 10
|
# a 10 seconds timeout means the task is allowed to run 10
|
||||||
# whole seconds instead of cancelling at the tenth second)
|
# whole seconds instead of cancelling at the tenth second)
|
||||||
pool = TaskManager(timeout + 1)
|
return TaskManager(timeout + 1)
|
||||||
loop.current_pool = pool
|
|
||||||
return pool
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import giambio
|
import giambio
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Union, Coroutine, Set
|
from typing import Union, Coroutine, Set
|
||||||
|
|
||||||
|
@ -95,6 +96,7 @@ class Task:
|
||||||
:type err: Exception
|
:type err: Exception
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# self.exc = err
|
||||||
return self.coroutine.throw(err)
|
return self.coroutine.throw(err)
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
|
@ -143,4 +145,5 @@ class Task:
|
||||||
self.coroutine.close()
|
self.coroutine.close()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass # TODO: This is kinda bad
|
pass # TODO: This is kinda bad
|
||||||
assert not self.last_io, f"task {self.name} was destroyed, but has pending I/O"
|
if self.last_io:
|
||||||
|
warnings.warn(f"task '{self.name}' was destroyed, but has pending I/O")
|
||||||
|
|
|
@ -41,7 +41,7 @@ def create_trap(method, *args):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def create_task(coro: FunctionType, *args):
|
async def create_task(coro: FunctionType, pool, *args):
|
||||||
"""
|
"""
|
||||||
Spawns a new task in the current event loop from a bare coroutine
|
Spawns a new task in the current event loop from a bare coroutine
|
||||||
function. All extra positional arguments are passed to the function
|
function. All extra positional arguments are passed to the function
|
||||||
|
@ -56,7 +56,7 @@ async def create_task(coro: FunctionType, *args):
|
||||||
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)"
|
"\nWhat you wanna do, instead, is this: giambio.run(your_func, arg1, arg2, ...)"
|
||||||
)
|
)
|
||||||
elif inspect.iscoroutinefunction(coro):
|
elif inspect.iscoroutinefunction(coro):
|
||||||
return await create_trap("create_task", coro, *args)
|
return await create_trap("create_task", coro, pool, *args)
|
||||||
else:
|
else:
|
||||||
raise TypeError("coro must be a coroutine function")
|
raise TypeError("coro must be a coroutine function")
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from debugger import Debugger
|
||||||
async def child():
|
async def child():
|
||||||
print("[child] Child spawned!! Sleeping for 2 seconds")
|
print("[child] Child spawned!! Sleeping for 2 seconds")
|
||||||
await giambio.sleep(2)
|
await giambio.sleep(2)
|
||||||
print("[child] Had a nice nap!")
|
print("[child] Had a nice nap, suiciding now!")
|
||||||
raise TypeError("rip") # Watch the exception magically propagate!
|
raise TypeError("rip") # Watch the exception magically propagate!
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,13 +33,13 @@ async def main():
|
||||||
async with giambio.create_pool() as pool:
|
async with giambio.create_pool() as pool:
|
||||||
await pool.spawn(child)
|
await pool.spawn(child)
|
||||||
await pool.spawn(child1)
|
await pool.spawn(child1)
|
||||||
print("[main] Children spawned, awaiting completion")
|
print("[main] First 2 children spawned, awaiting completion")
|
||||||
async with giambio.create_pool() as new_pool:
|
async with giambio.create_pool() as new_pool:
|
||||||
# This pool will be cancelled by the exception
|
# This pool will be cancelled by the exception
|
||||||
# in the other pool
|
# in the other pool
|
||||||
await new_pool.spawn(child2)
|
await new_pool.spawn(child2)
|
||||||
await new_pool.spawn(child3)
|
await new_pool.spawn(child3)
|
||||||
print("[main] 3rd child spawned")
|
print("[main] Third and fourth children spawned")
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
# Because exceptions just *work*!
|
# Because exceptions just *work*!
|
||||||
print(f"[main] Exception from child caught! {repr(error)}")
|
print(f"[main] Exception from child caught! {repr(error)}")
|
||||||
|
|
|
@ -16,6 +16,7 @@ async def main():
|
||||||
async with giambio.create_pool() as a_pool:
|
async with giambio.create_pool() as a_pool:
|
||||||
await a_pool.spawn(child, 3)
|
await a_pool.spawn(child, 3)
|
||||||
await a_pool.spawn(child, 4)
|
await a_pool.spawn(child, 4)
|
||||||
|
# This executes after spawning all 4 tasks
|
||||||
print("[main] Children spawned, awaiting completion")
|
print("[main] Children spawned, awaiting completion")
|
||||||
# This will *only* execute when everything inside the async with block
|
# This will *only* execute when everything inside the async with block
|
||||||
# has ran, including any other pool
|
# has ran, including any other pool
|
||||||
|
|
Loading…
Reference in New Issue