Added two-way proxy example stolen from njsmith and fixed bug with io_release_task being fucking dumb

This commit is contained in:
Nocturn9x 2022-02-27 18:14:12 +01:00
parent b8ee9945c1
commit ed6aba490f
4 changed files with 87 additions and 22 deletions

View File

@ -16,6 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from lib2to3.pgen2.token import OP
import types import types
import giambio import giambio
from typing import List, Optional from typing import List, Optional
@ -54,6 +55,7 @@ class TaskManager:
self._proper_init = False self._proper_init = False
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
self.raise_on_timeout: bool = raise_on_timeout self.raise_on_timeout: bool = raise_on_timeout
self.entry_point: Optional[Task] = None
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task": async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
""" """
@ -70,6 +72,7 @@ class TaskManager:
""" """
self._proper_init = True self._proper_init = True
self.entry_point = await giambio.traps.current_task()
return self return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb): async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
@ -89,8 +92,9 @@ class TaskManager:
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout: if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True return True
except giambio.exceptions.TooSlowError: except giambio.exceptions.TooSlowError:
return True if not self.raise_on_timeout:
raise
async def cancel(self): async def cancel(self):
""" """
Cancels the pool entirely, iterating over all Cancels the pool entirely, iterating over all
@ -108,4 +112,4 @@ class TaskManager:
pool have exited, False otherwise pool have exited, False otherwise
""" """
return self._proper_init and all([task.done() for task in self.tasks]) return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())

View File

@ -341,7 +341,7 @@ class AsyncScheduler:
if self.selector.get_map(): if self.selector.get_map():
for k in filter( for k in filter(
lambda o: o.data == self.current_task, lambda o: o.data == task,
dict(self.selector.get_map()).values(), dict(self.selector.get_map()).values(),
): ):
self.io_release(k.fileobj) self.io_release(k.fileobj)
@ -358,7 +358,7 @@ class AsyncScheduler:
before it's due before it's due
""" """
if self.current_task.last_io: if self.current_task.last_io or self.current_task.status == "io":
self.io_release_task(self.current_task) self.io_release_task(self.current_task)
self.current_task.status = "sleep" self.current_task.status = "sleep"
self.suspended.append(self.current_task) self.suspended.append(self.current_task)
@ -441,13 +441,11 @@ class AsyncScheduler:
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock(): while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
pool = self.deadlines.get() pool = self.deadlines.get()
pool.timed_out = True pool.timed_out = True
if not pool.tasks and self.current_task is self.entry_point: self.cancel_pool(pool)
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
for task in pool.tasks: for task in pool.tasks:
if not task.done(): self.join(task)
self.paused.discard(task) self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
self.io_release_task(task)
self.handle_task_exit(task, partial(task.throw, TooSlowError(task)))
def schedule_tasks(self, tasks: List[Task]): def schedule_tasks(self, tasks: List[Task]):
""" """
@ -554,9 +552,12 @@ class AsyncScheduler:
self.run_ready.append(entry) self.run_ready.append(entry)
self.debugger.on_start() self.debugger.on_start()
if loop: if loop:
self.run() try:
self.has_ran = True self.run()
self.debugger.on_exit() finally:
self.has_ran = True
self.close()
self.debugger.on_exit()
def cancel_pool(self, pool: TaskManager) -> bool: def cancel_pool(self, pool: TaskManager) -> bool:
""" """
@ -729,8 +730,8 @@ class AsyncScheduler:
self.io_release_task(task) self.io_release_task(task)
elif task.status == "sleep": elif task.status == "sleep":
self.paused.discard(task) self.paused.discard(task)
if task in self.suspended: if task in self.suspended:
self.suspended.remove(task) self.suspended.remove(task)
try: try:
self.do_cancel(task) self.do_cancel(task)
except CancelledError as cancel: except CancelledError as cancel:
@ -747,7 +748,6 @@ class AsyncScheduler:
task.cancel_pending = False task.cancel_pending = False
task.cancelled = True task.cancelled = True
task.status = "cancelled" task.status = "cancelled"
self.io_release_task(self.current_task)
self.debugger.after_cancel(task) self.debugger.after_cancel(task)
self.tasks.remove(task) self.tasks.remove(task)
else: else:
@ -758,12 +758,12 @@ class AsyncScheduler:
def register_sock(self, sock, evt_type: str): def register_sock(self, sock, evt_type: str):
""" """
Registers the given socket inside the Registers the given socket inside the
selector to perform I/0 multiplexing selector to perform I/O multiplexing
:param sock: The socket on which a read or write operation :param sock: The socket on which a read or write operation
has to be performed has to be performed
:param evt_type: The type of event to perform on the given :param evt_type: The type of event to perform on the given
socket, either "read" or "write" socket, either "read" or "write"
:type evt_type: str :type evt_type: str
""" """
@ -797,5 +797,8 @@ class AsyncScheduler:
try: try:
self.selector.register(sock, evt, self.current_task) self.selector.register(sock, evt, self.current_task)
except KeyError: except KeyError:
# The socket is already registered doing something else # The socket is already registered doing something else, we
raise ResourceBusy("The given socket is being read/written by another task") from None # modify the socket instead (or maybe not?)
self.selector.modify(sock, evt, self.current_task)
# TODO: Does this break stuff?
# raise ResourceBusy("The given socket is being read/written by another task") from None

View File

@ -16,6 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import giambio
from giambio.exceptions import ResourceClosed from giambio.exceptions import ResourceClosed
from giambio.traps import want_write, want_read, io_release from giambio.traps import want_write, want_read, io_release
@ -121,6 +122,7 @@ class AsyncSocket:
if self.sock: if self.sock:
self.sock.shutdown(how) self.sock.shutdown(how)
await giambio.sleep(0) # Checkpoint
async def bind(self, addr: tuple): async def bind(self, addr: tuple):
""" """

56
tests/proxy.py Normal file
View File

@ -0,0 +1,56 @@
from debugger import Debugger
import giambio
import socket
async def proxy_one_way(source: giambio.socket.AsyncSocket, sink: giambio.socket.AsyncSocket):
"""
Sends data from source to sink
"""
sink_addr = ":".join(map(str, await sink.getpeername()))
source_addr = ":".join(map(str, await source.getpeername()))
while True:
data = await source.receive(1024)
if not data:
print(f"{source_addr} has exited, closing connection to {sink_addr}")
await sink.shutdown(socket.SHUT_WR)
break
print(f"Got {data.decode('utf8', errors='ignore')!r} from {source_addr}, forwarding it to {sink_addr}")
await sink.send_all(data)
async def proxy_two_way(a: giambio.socket.AsyncSocket, b: giambio.socket.AsyncSocket):
"""
Sets up a two-way proxy from a to b and from b to a
"""
async with giambio.create_pool() as pool:
await pool.spawn(proxy_one_way, a, b)
await pool.spawn(proxy_one_way, b, a)
async def main(delay: int, a: tuple, b: tuple):
"""
Sets up the proxy
"""
start = giambio.clock()
print(f"Starting two-way proxy from {a[0]}:{a[1]} to {b[0]}:{b[1]}, lasting for {delay} seconds")
async with giambio.skip_after(delay) as p:
sock_a = giambio.socket.socket()
sock_b = giambio.socket.socket()
await sock_a.connect(a)
await sock_b.connect(b)
async with sock_a, sock_b:
await proxy_two_way(sock_a, sock_b)
print(f"Proxy has exited after {giambio.clock() - start:.2f} seconds")
try:
giambio.run(main, 60, ("localhost", 12345), ("localhost", 54321), debugger=())
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
print("Ctrl+C detected, exiting")
else:
print(f"Exiting due to a {type(error).__name__}: {error}")