mirror of https://github.com/nocturn9x/giambio.git
Added two-way proxy example stolen from njsmith and fixed bug with io_release_task being fucking dumb
This commit is contained in:
parent
b8ee9945c1
commit
ed6aba490f
|
@ -16,6 +16,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from lib2to3.pgen2.token import OP
|
||||||
import types
|
import types
|
||||||
import giambio
|
import giambio
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
@ -54,6 +55,7 @@ class TaskManager:
|
||||||
self._proper_init = False
|
self._proper_init = False
|
||||||
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
||||||
self.raise_on_timeout: bool = raise_on_timeout
|
self.raise_on_timeout: bool = raise_on_timeout
|
||||||
|
self.entry_point: Optional[Task] = None
|
||||||
|
|
||||||
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
|
async def spawn(self, func: types.FunctionType, *args, **kwargs) -> "giambio.task.Task":
|
||||||
"""
|
"""
|
||||||
|
@ -70,6 +72,7 @@ class TaskManager:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._proper_init = True
|
self._proper_init = True
|
||||||
|
self.entry_point = await giambio.traps.current_task()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
||||||
|
@ -89,8 +92,9 @@ class TaskManager:
|
||||||
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
|
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
|
||||||
return True
|
return True
|
||||||
except giambio.exceptions.TooSlowError:
|
except giambio.exceptions.TooSlowError:
|
||||||
return True
|
if not self.raise_on_timeout:
|
||||||
|
raise
|
||||||
|
|
||||||
async def cancel(self):
|
async def cancel(self):
|
||||||
"""
|
"""
|
||||||
Cancels the pool entirely, iterating over all
|
Cancels the pool entirely, iterating over all
|
||||||
|
@ -108,4 +112,4 @@ class TaskManager:
|
||||||
pool have exited, False otherwise
|
pool have exited, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._proper_init and all([task.done() for task in self.tasks])
|
return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())
|
||||||
|
|
|
@ -341,7 +341,7 @@ class AsyncScheduler:
|
||||||
|
|
||||||
if self.selector.get_map():
|
if self.selector.get_map():
|
||||||
for k in filter(
|
for k in filter(
|
||||||
lambda o: o.data == self.current_task,
|
lambda o: o.data == task,
|
||||||
dict(self.selector.get_map()).values(),
|
dict(self.selector.get_map()).values(),
|
||||||
):
|
):
|
||||||
self.io_release(k.fileobj)
|
self.io_release(k.fileobj)
|
||||||
|
@ -358,7 +358,7 @@ class AsyncScheduler:
|
||||||
before it's due
|
before it's due
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.current_task.last_io:
|
if self.current_task.last_io or self.current_task.status == "io":
|
||||||
self.io_release_task(self.current_task)
|
self.io_release_task(self.current_task)
|
||||||
self.current_task.status = "sleep"
|
self.current_task.status = "sleep"
|
||||||
self.suspended.append(self.current_task)
|
self.suspended.append(self.current_task)
|
||||||
|
@ -441,13 +441,11 @@ class AsyncScheduler:
|
||||||
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
|
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
|
||||||
pool = self.deadlines.get()
|
pool = self.deadlines.get()
|
||||||
pool.timed_out = True
|
pool.timed_out = True
|
||||||
if not pool.tasks and self.current_task is self.entry_point:
|
self.cancel_pool(pool)
|
||||||
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
|
|
||||||
for task in pool.tasks:
|
for task in pool.tasks:
|
||||||
if not task.done():
|
self.join(task)
|
||||||
self.paused.discard(task)
|
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
|
||||||
self.io_release_task(task)
|
|
||||||
self.handle_task_exit(task, partial(task.throw, TooSlowError(task)))
|
|
||||||
|
|
||||||
def schedule_tasks(self, tasks: List[Task]):
|
def schedule_tasks(self, tasks: List[Task]):
|
||||||
"""
|
"""
|
||||||
|
@ -554,9 +552,12 @@ class AsyncScheduler:
|
||||||
self.run_ready.append(entry)
|
self.run_ready.append(entry)
|
||||||
self.debugger.on_start()
|
self.debugger.on_start()
|
||||||
if loop:
|
if loop:
|
||||||
self.run()
|
try:
|
||||||
self.has_ran = True
|
self.run()
|
||||||
self.debugger.on_exit()
|
finally:
|
||||||
|
self.has_ran = True
|
||||||
|
self.close()
|
||||||
|
self.debugger.on_exit()
|
||||||
|
|
||||||
def cancel_pool(self, pool: TaskManager) -> bool:
|
def cancel_pool(self, pool: TaskManager) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -729,8 +730,8 @@ class AsyncScheduler:
|
||||||
self.io_release_task(task)
|
self.io_release_task(task)
|
||||||
elif task.status == "sleep":
|
elif task.status == "sleep":
|
||||||
self.paused.discard(task)
|
self.paused.discard(task)
|
||||||
if task in self.suspended:
|
if task in self.suspended:
|
||||||
self.suspended.remove(task)
|
self.suspended.remove(task)
|
||||||
try:
|
try:
|
||||||
self.do_cancel(task)
|
self.do_cancel(task)
|
||||||
except CancelledError as cancel:
|
except CancelledError as cancel:
|
||||||
|
@ -747,7 +748,6 @@ class AsyncScheduler:
|
||||||
task.cancel_pending = False
|
task.cancel_pending = False
|
||||||
task.cancelled = True
|
task.cancelled = True
|
||||||
task.status = "cancelled"
|
task.status = "cancelled"
|
||||||
self.io_release_task(self.current_task)
|
|
||||||
self.debugger.after_cancel(task)
|
self.debugger.after_cancel(task)
|
||||||
self.tasks.remove(task)
|
self.tasks.remove(task)
|
||||||
else:
|
else:
|
||||||
|
@ -758,12 +758,12 @@ class AsyncScheduler:
|
||||||
def register_sock(self, sock, evt_type: str):
|
def register_sock(self, sock, evt_type: str):
|
||||||
"""
|
"""
|
||||||
Registers the given socket inside the
|
Registers the given socket inside the
|
||||||
selector to perform I/0 multiplexing
|
selector to perform I/O multiplexing
|
||||||
|
|
||||||
:param sock: The socket on which a read or write operation
|
:param sock: The socket on which a read or write operation
|
||||||
has to be performed
|
has to be performed
|
||||||
:param evt_type: The type of event to perform on the given
|
:param evt_type: The type of event to perform on the given
|
||||||
socket, either "read" or "write"
|
socket, either "read" or "write"
|
||||||
:type evt_type: str
|
:type evt_type: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -797,5 +797,8 @@ class AsyncScheduler:
|
||||||
try:
|
try:
|
||||||
self.selector.register(sock, evt, self.current_task)
|
self.selector.register(sock, evt, self.current_task)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# The socket is already registered doing something else
|
# The socket is already registered doing something else, we
|
||||||
raise ResourceBusy("The given socket is being read/written by another task") from None
|
# modify the socket instead (or maybe not?)
|
||||||
|
self.selector.modify(sock, evt, self.current_task)
|
||||||
|
# TODO: Does this break stuff?
|
||||||
|
# raise ResourceBusy("The given socket is being read/written by another task") from None
|
||||||
|
|
|
@ -16,6 +16,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import giambio
|
||||||
from giambio.exceptions import ResourceClosed
|
from giambio.exceptions import ResourceClosed
|
||||||
from giambio.traps import want_write, want_read, io_release
|
from giambio.traps import want_write, want_read, io_release
|
||||||
|
|
||||||
|
@ -121,6 +122,7 @@ class AsyncSocket:
|
||||||
|
|
||||||
if self.sock:
|
if self.sock:
|
||||||
self.sock.shutdown(how)
|
self.sock.shutdown(how)
|
||||||
|
await giambio.sleep(0) # Checkpoint
|
||||||
|
|
||||||
async def bind(self, addr: tuple):
|
async def bind(self, addr: tuple):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
from debugger import Debugger
|
||||||
|
import giambio
|
||||||
|
import socket
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_one_way(source: giambio.socket.AsyncSocket, sink: giambio.socket.AsyncSocket):
|
||||||
|
"""
|
||||||
|
Sends data from source to sink
|
||||||
|
"""
|
||||||
|
|
||||||
|
sink_addr = ":".join(map(str, await sink.getpeername()))
|
||||||
|
source_addr = ":".join(map(str, await source.getpeername()))
|
||||||
|
while True:
|
||||||
|
data = await source.receive(1024)
|
||||||
|
if not data:
|
||||||
|
print(f"{source_addr} has exited, closing connection to {sink_addr}")
|
||||||
|
await sink.shutdown(socket.SHUT_WR)
|
||||||
|
break
|
||||||
|
print(f"Got {data.decode('utf8', errors='ignore')!r} from {source_addr}, forwarding it to {sink_addr}")
|
||||||
|
await sink.send_all(data)
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_two_way(a: giambio.socket.AsyncSocket, b: giambio.socket.AsyncSocket):
|
||||||
|
"""
|
||||||
|
Sets up a two-way proxy from a to b and from b to a
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with giambio.create_pool() as pool:
|
||||||
|
await pool.spawn(proxy_one_way, a, b)
|
||||||
|
await pool.spawn(proxy_one_way, b, a)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(delay: int, a: tuple, b: tuple):
|
||||||
|
"""
|
||||||
|
Sets up the proxy
|
||||||
|
"""
|
||||||
|
|
||||||
|
start = giambio.clock()
|
||||||
|
print(f"Starting two-way proxy from {a[0]}:{a[1]} to {b[0]}:{b[1]}, lasting for {delay} seconds")
|
||||||
|
async with giambio.skip_after(delay) as p:
|
||||||
|
sock_a = giambio.socket.socket()
|
||||||
|
sock_b = giambio.socket.socket()
|
||||||
|
await sock_a.connect(a)
|
||||||
|
await sock_b.connect(b)
|
||||||
|
async with sock_a, sock_b:
|
||||||
|
await proxy_two_way(sock_a, sock_b)
|
||||||
|
print(f"Proxy has exited after {giambio.clock() - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
giambio.run(main, 60, ("localhost", 12345), ("localhost", 54321), debugger=())
|
||||||
|
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||||
|
if isinstance(error, KeyboardInterrupt):
|
||||||
|
print("Ctrl+C detected, exiting")
|
||||||
|
else:
|
||||||
|
print(f"Exiting due to a {type(error).__name__}: {error}")
|
Loading…
Reference in New Issue