Compare commits

...

4 Commits

Author SHA1 Message Date
Nocturn9x e37ffdeb06 Initial broken work on a generic streams interface 2022-10-10 13:35:22 +02:00
Nocturn9x 55868c450e Removed debugging raise statement 2022-10-10 10:22:37 +02:00
Nocturn9x 60df2f059a Fixed mistake from rebase 2022-10-10 10:21:37 +02:00
Nocturn9x d408cffa87 Bug fixes with exception handling and minor documentation improvements 2022-10-10 09:55:18 +02:00
10 changed files with 402 additions and 250 deletions

View File

@ -16,6 +16,10 @@ rock-solid and structured concurrency framework (I personally recommend trio and
that most of the content of this document is ~~stolen~~ inspired from its documentation)
# Disclaimer #2
This is a toy project. Don't try to use it in production, it *will* explode
## Goals of this project
Making yet another async library might sound dumb in an already fragmented ecosystem like Python's.

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@ -17,8 +17,7 @@ limitations under the License.
"""
import giambio
from giambio.task import Task
from typing import List, Optional, Callable, Coroutine, Any
from typing import List, Optional, Any, Coroutine, Callable
class TaskManager:
@ -32,13 +31,13 @@ class TaskManager:
:type raise_on_timeout: bool, optional
"""
def __init__(self, current_task: Task, timeout: float = None, raise_on_timeout: bool = False) -> None:
def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None:
"""
Object constructor
"""
# All the tasks that belong to this pool
self.tasks: List[Task] = []
self.tasks: List[giambio.task.Task] = []
# Whether we have been cancelled or not
self.cancelled: bool = False
# The clock time of when we started running, used for
@ -51,19 +50,10 @@ class TaskManager:
self.timeout = None
# Whether our timeout expired or not
self.timed_out: bool = False
# Internal check so users don't try
# to use the pool manually
self._proper_init = False
# We keep track of any inner pools to propagate
# exceptions properly
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
# Do we raise an error after timeout?
self.raise_on_timeout: bool = raise_on_timeout
# The task that created the pool. We keep track of
# it because we only cancel ourselves if this task
# errors out (so if the error is caught before reaching
# it we just do nothing)
self.owner: Task = current_task
self.entry_point: Optional[Task] = None
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
"""
@ -76,10 +66,11 @@ class TaskManager:
async def __aenter__(self):
"""
Implements the asynchronous context manager interface,
Implements the asynchronous context manager interface
"""
self._proper_init = True
self.entry_point = await giambio.traps.current_task()
return self
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
@ -95,14 +86,13 @@ class TaskManager:
# children to exit
await task.join()
self.tasks.remove(task)
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
except giambio.exceptions.TooSlowError:
if self.raise_on_timeout:
raise
finally:
self._proper_init = False
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
return True
async def cancel(self):
"""
Cancels the pool entirely, iterating over all
@ -120,4 +110,4 @@ class TaskManager:
pool have exited, False otherwise
"""
return self._proper_init and all([task.done() for task in self.tasks])
return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())

View File

@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@ -15,8 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import functools
# Import libraries and internal resources
from numbers import Number
from giambio.task import Task
from collections import deque
from functools import partial
@ -32,14 +33,15 @@ from giambio.exceptions import (
ResourceBusy,
GiambioError,
TooSlowError,
ResourceClosed
)
class AsyncScheduler:
"""
A simple task scheduler implementation that tries to mimic thread programming
in its simplicity, without using actual threads, but rather alternating
across coroutines execution to let more than one thing at a time to proceed
in its simplicity, without using actual threads, but rather alternating the
execution of coroutines to let more than one thing at a time to proceed
with its calculations. An attempt to fix the threaded model has been made
without making the API unnecessarily complicated.
@ -55,7 +57,7 @@ class AsyncScheduler:
:param clock: A callable returning monotonically increasing values at each call,
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
:type clock: :class: Callable
:type clock: :class: types.FunctionType
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
is desired, defaults to None
:type debugger: :class: giambio.util.BaseDebugger
@ -72,7 +74,7 @@ class AsyncScheduler:
def __init__(
self,
clock: Callable = default_timer,
clock: Callable[[], Number] = default_timer,
debugger: Optional[BaseDebugger] = None,
selector: Optional[Any] = None,
io_skip_limit: Optional[int] = None,
@ -94,7 +96,7 @@ class AsyncScheduler:
or type(
"DumbDebugger",
(object,),
{"__getattr__": lambda *args: lambda *arg: None},
{"__getattr__": lambda *_: lambda *_: None},
)()
)
# All tasks the loop has
@ -106,7 +108,7 @@ class AsyncScheduler:
# This will always point to the currently running coroutine (Task object)
self.current_task: Optional[Task] = None
# Monotonic clock to keep track of elapsed time reliably
self.clock: Callable = clock
self.clock: Callable[[], Number] = clock
# Tasks that are asleep
self.paused: TimeQueue = TimeQueue(self.clock)
# Have we ever ran?
@ -129,6 +131,7 @@ class AsyncScheduler:
self.entry_point: Optional[Task] = None
# Suspended tasks
self.suspended: deque = deque()
def __repr__(self):
"""
@ -150,6 +153,8 @@ class AsyncScheduler:
"_data",
"io_skip_limit",
"io_max_timeout",
"suspended",
"entry_point"
}
data = ", ".join(
name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields))
@ -168,7 +173,7 @@ class AsyncScheduler:
Shuts down the event loop
"""
for task in self.tasks:
for task in self.get_all_tasks():
self.io_release_task(task)
self.selector.close()
# TODO: Anything else?
@ -206,7 +211,10 @@ class AsyncScheduler:
# after it is set, but it makes the implementation easier
if not self.current_pool and self.current_task.pool:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = self.current_pool.enclosed_pool
# If there are no actively running tasks, we start by
# checking for I/O. This method will wait for I/O until
# the closest deadline to avoid starving sleeping tasks
@ -230,9 +238,10 @@ class AsyncScheduler:
# some tricky behaviors, and this is one of them. When a coroutine
# hits a return statement (either explicit or implicit), it raises
# a StopIteration exception, which has an attribute named value that
# represents the return value of the coroutine, if any. Of course this
# exception is not an error and we should happily keep going after it,
# represents the return value of the coroutine, if it has one. Of course
# this exception is not an error and we should happily keep going after it:
# most of this code below is just useful for internal/debugging purposes
self.current_task.status = "end"
self.current_task.result = ret.value
self.current_task.finished = True
self.join(self.current_task)
@ -244,20 +253,22 @@ class AsyncScheduler:
self.current_task.exc = err
self.join(self.current_task)
def create_task(self, coro: Coroutine[Any, Any, Any], pool) -> Task:
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
"""
Creates a task from a coroutine function and schedules it
to run. The associated pool that spawned said task is also
needed, while any extra keyword or positional arguments are
passed to the function itself
:param coro: The coroutine to spawn
:type coro: Coroutine[Any, Any, Any]
:param corofunc: The coroutine function (NOT a coroutine!) to
spawn
:type corofunc: function
:param pool: The giambio.context.TaskManager object that
spawned the task
"""
task = Task(coro.__name__ or str(coro), coro, pool)
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
task.next_deadline = pool.timeout or 0.0
task.joiners = {self.current_task}
self._data[self.current_task] = task
@ -288,9 +299,15 @@ class AsyncScheduler:
# We need to make sure we don't try to execute
# exited tasks that are on the running queue
return
if not self.current_pool and self.current_task.pool:
if self.current_pool:
if self.current_task.pool and self.current_task.pool is not self.current_pool:
self.current_task.pool.enclosed_pool = self.current_pool
else:
self.current_pool = self.current_task.pool
self.deadlines.put(self.current_pool)
pool = self.current_pool
while pool:
self.deadlines.put(pool)
pool = self.current_pool.enclosed_pool
self.debugger.before_task_step(self.current_task)
# Some debugging and internal chatter here
self.current_task.status = "run"
@ -319,7 +336,7 @@ class AsyncScheduler:
def io_release(self, sock):
"""
Releases the given resource from our
selector.
selector
:param sock: The resource to be released
"""
@ -334,7 +351,7 @@ class AsyncScheduler:
if self.selector.get_map():
for k in filter(
lambda o: o.data == self.current_task,
lambda o: o.data == task,
dict(self.selector.get_map()).values(),
):
self.io_release(k.fileobj)
@ -344,11 +361,16 @@ class AsyncScheduler:
"""
Suspends execution of the current task. This is basically
a do-nothing method, since it will not reschedule the task
before returning. The task will stay suspended until a timer,
I/O operation or cancellation wakes it up, or until another
running task reschedules it.
before returning. The task will stay suspended as long as
something else outside the loop calls a trap to reschedule it.
Any pending I/O for the task is temporarily unscheduled to
avoid some previous network operation to reschedule the task
before it's due
"""
if self.current_task.last_io or self.current_task.status == "io":
self.io_release_task(self.current_task)
self.current_task.status = "sleep"
self.suspended.append(self.current_task)
def reschedule_running(self):
@ -408,27 +430,35 @@ class AsyncScheduler:
try:
to_call()
except StopIteration as ret:
task.status = "end"
task.result = ret.value
task.finished = True
self.join(task)
except CancelledError as cancel:
task.status = "cancelled"
task.cancel_pending = False
task.cancelled = True
self.join(task)
except BaseException as err:
task.exc = err
self.join(task)
def prune_deadlines(self):
"""
Removes expired deadlines after their timeout
has expired and cancels their associated pool
has expired
"""
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
pool = self.deadlines.get()
pool.timed_out = True
self.cancel_pool(pool)
for task in pool.tasks:
if task is not pool.owner:
self.handle_task_exit(task, partial(task.throw, TooSlowError(self.current_task)))
if pool.raise_on_timeout:
self.handle_task_exit(pool.owner, partial(pool.owner.throw, TooSlowError(self.current_task)))
self.join(task)
if pool.entry_point is self.entry_point:
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
self.run_ready.append(self.entry_point)
def schedule_tasks(self, tasks: List[Task]):
"""
@ -439,7 +469,8 @@ class AsyncScheduler:
for task in tasks:
self.paused.discard(task)
self.suspended.remove(task)
if task in self.suspended:
self.suspended.remove(task)
self.run_ready.extend(tasks)
self.reschedule_running()
@ -462,6 +493,7 @@ class AsyncScheduler:
self.run_ready.append(task)
self.debugger.after_sleep(task, slept)
def get_closest_deadline(self) -> float:
"""
Gets the closest expiration deadline (asleep tasks, timeouts)
@ -469,7 +501,7 @@ class AsyncScheduler:
:return: The closest deadline according to our clock
:rtype: float
"""
if not self.deadlines:
# If there are no deadlines just wait until the first task wakeup
timeout = max(0.0, self.paused.get_closest_deadline() - self.clock())
@ -535,9 +567,12 @@ class AsyncScheduler:
self.run_ready.append(entry)
self.debugger.on_start()
if loop:
self.run()
self.has_ran = True
self.debugger.on_exit()
try:
self.run()
finally:
self.has_ran = True
self.close()
self.debugger.on_exit()
def cancel_pool(self, pool: TaskManager) -> bool:
"""
@ -589,8 +624,9 @@ class AsyncScheduler:
If ensure_done equals False, the loop will cancel ALL
running and scheduled tasks and then tear itself down.
If ensure_done equals True, which is the default behavior,
this method will raise a GiambioError if the loop hasn't
finished running.
this method will raise a GiambioError exception if the loop
hasn't finished running. The state of the event loop is reset
so it can be reused with another run() call
"""
if ensure_done:
@ -598,6 +634,16 @@ class AsyncScheduler:
elif not self.done():
raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit")
self.shutdown()
# We reset the event loop's state
self.tasks = []
self.entry_point = None
self.current_pool = None
self.current_task = None
self.paused = TimeQueue(self.clock)
self.deadlines = DeadlinesQueue()
self.run_ready = deque()
self.suspended = deque()
def reschedule_joiners(self, task: Task):
"""
@ -605,87 +651,71 @@ class AsyncScheduler:
given task, if any
"""
for t in task.joiners:
self.run_ready.append(t)
# noinspection PyMethodMayBeStatic
def is_pool_done(self, pool: Optional[TaskManager]):
"""
Returns True if a given pool has finished
executing
"""
while pool:
if not pool.done():
return False
pool = pool.enclosed_pool
return True
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
return
self.run_ready.extend(task.joiners)
def join(self, task: Task):
"""
Joins a task to its callers (implicitly, the parent
Joins a task to its callers (implicitly the parent
task, but also every other task who called await
task.join() on the task object)
"""
task.joined = True
if any([task.finished, task.cancelled, task.exc]) and task in self.tasks:
self.io_release_task(task)
self.tasks.remove(task)
self.paused.discard(task)
if task.finished or task.cancelled:
task.status = "end"
if not task.cancelled:
task.status = "cancelled"
# This way join() returns the
# task's return value
for joiner in task.joiners:
self._data[joiner] = task.result
self.debugger.on_task_exit(task)
# If the pool has finished executing or we're at the first parent
# task that kicked the loop, we can safely reschedule the parent(s)
if self.is_pool_done(task.pool):
if task.last_io:
self.io_release_task(task)
if task in self.suspended:
self.suspended.remove(task)
if task in self.tasks:
self.tasks.remove(task)
# If the pool (including any enclosing pools) has finished executing
# or we're at the first task that kicked the loop, we can safely
# reschedule the parent(s)
if task.pool is None:
return
if task.pool.done():
self.reschedule_joiners(task)
self.reschedule_running()
elif task.exc:
if task in self.suspended:
self.suspended.remove(task)
task.status = "crashed"
if task.exc.__traceback__:
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
# frames from the exception call stack, but for now removing at least the first few
# seems a sensible approach (it's us catching it so we don't care about that)
for _ in range(5):
if task.exc.__traceback__.tb_next:
task.exc.__traceback__ = task.exc.__traceback__.tb_next
self.debugger.on_exception_raised(task, task.exc)
if task is self.entry_point and not task.pool:
try:
task.throw(task.exc)
except StopIteration:
... # TODO: ?
except BaseException:
# TODO: No idea what to do here
raise
elif any(map(lambda tk: tk is task.pool.owner, task.joiners)) or task is task.pool.owner:
# We check if the pool's
# owner catches our error
# or not. If they don't, we
# cancel the entire pool, but
# if they do, we do nothing
if task.pool.owner is not task:
self.handle_task_exit(task.pool.owner, partial(task.pool.owner.coroutine.throw, task.exc))
if any([task.pool.owner.exc, task.pool.owner.cancelled, task.pool.owner.finished]):
for t in task.joiners.copy():
# Propagate the exception
self.handle_task_exit(t, partial(t.throw, task.exc))
if any([t.exc, t.finished, t.cancelled]):
task.joiners.remove(t)
for t in task.pool.tasks:
if not t.joined:
self.handle_task_exit(t, partial(t.throw, task.exc))
if any([t.exc, t.finished, t.cancelled]):
task.joiners.discard(t)
self.reschedule_joiners(task)
self.reschedule_running()
if task.pool is None or task is self.entry_point:
# Parent task has no pool, so we propagate
raise task.exc
if self.cancel_pool(task.pool):
# This will reschedule the parent(s)
# only if all the tasks inside the task's
# pool have finished executing, either
# by cancellation, an exception
# or just returned
for t in task.joiners.copy():
# Propagate the exception
try:
t.throw(task.exc)
except (StopIteration, CancelledError, RuntimeError) as e:
# TODO: Need anything else?
task.joiners.remove(t)
if isinstance(e, StopIteration):
t.status = "end"
t.result = e.value
t.finished = True
elif isinstance(e, CancelledError):
t = e.task
t.cancel_pending = False
t.cancelled = True
t.status = "cancelled"
self.debugger.after_cancel(t)
elif isinstance(e, BaseException):
t.exc = e
finally:
if t in self.tasks:
self.tasks.remove(t)
self.reschedule_joiners(task)
def sleep(self, seconds: int or float):
"""
@ -727,6 +757,8 @@ class AsyncScheduler:
self.io_release_task(task)
elif task.status == "sleep":
self.paused.discard(task)
if task in self.suspended:
self.suspended.remove(task)
try:
self.do_cancel(task)
except CancelledError as cancel:
@ -742,24 +774,36 @@ class AsyncScheduler:
task = cancel.task
task.cancel_pending = False
task.cancelled = True
self.io_release_task(self.current_task)
task.status = "cancelled"
self.debugger.after_cancel(task)
self.tasks.remove(task)
self.join(task)
else:
# If the task ignores our exception, we'll
# raise it later again
task.cancel_pending = True
self.join(task)
def notify_closing(self, stream):
"""
Implements the notify_closing trap
"""
if self.selector.get_map():
for k in filter(
lambda o: o.data == self.current_task,
dict(self.selector.get_map()).values(),
):
self.handle_task_exit(k.data,
functools.partial(k.data.throw(ResourceClosed("stream has been closed"))))
def register_sock(self, sock, evt_type: str):
"""
Registers the given socket inside the
selector to perform I/0 multiplexing
selector to perform I/O multiplexing
:param sock: The socket on which a read or write operation
has to be performed
has to be performed
:param evt_type: The type of event to perform on the given
socket, either "read" or "write"
socket, either "read" or "write"
:type evt_type: str
"""
@ -793,5 +837,8 @@ class AsyncScheduler:
try:
self.selector.register(sock, evt, self.current_task)
except KeyError:
# The socket is already registered doing something else
raise ResourceBusy("The given socket is being read/written by another task") from None
# The socket is already registered doing something else, we
# modify the socket instead (or maybe not?)
self.selector.modify(sock, evt, self.current_task)
# TODO: Does this break stuff?
# raise ResourceBusy("The given socket is being read/written by another task") from None

View File

@ -37,7 +37,7 @@ class InternalError(GiambioError):
...
class CancelledError(GiambioError):
class CancelledError(BaseException):
"""
Exception raised by the giambio.objects.Task.cancel() method
to terminate a child task. This should NOT be caught, or

View File

@ -15,14 +15,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import socket
import warnings
import os
import giambio
from giambio.exceptions import ResourceClosed
from giambio.traps import want_write, want_read, io_release
from giambio.traps import want_write, want_read, io_release, notify_closing
try:
from ssl import SSLWantReadError, SSLWantWriteError
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
@ -31,16 +33,115 @@ except ImportError:
WantWrite = (BlockingIOError, InterruptedError)
class AsyncSocket:
class AsyncStream:
"""
A generic asynchronous stream over
a file descriptor. Only works on Linux
& co because windows doesn't like select()
to be called on non-socket objects
(Thanks, Microsoft)
"""
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
self._fd = fd
self.stream = None
if open_fd:
self.stream = os.fdopen(self._fd, **kwargs)
os.set_blocking(self._fd, False)
self.close_on_context_exit = close_on_context_exit
async def read(self, size: int = -1):
"""
Reads up to size bytes from the
given stream. If size == -1, read
until EOF is reached
"""
while True:
try:
return self.stream.read(size)
except WantRead:
await want_read(self.stream)
async def write(self, data):
"""
Writes data b to the file.
Returns the number of bytes
written
"""
while True:
try:
return self.stream.write(data)
except WantWrite:
await want_write(self.stream)
async def close(self):
"""
Closes the stream asynchronously
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed stream")
self._fd = -1
await notify_closing(self.stream)
await io_release(self.stream)
self.stream.close()
self.stream = None
@property
async def fileno(self):
"""
Wrapper socket method
"""
return self._fd
async def __aenter__(self):
self.stream.__enter__()
return self
async def __aexit__(self, *args):
if self._fd != -1 and self.close_on_context_exit:
await self.close()
async def dup(self):
"""
Wrapper stream method
"""
return type(self)(os.dup(self._fd))
def __repr__(self):
return f"AsyncStream({self.stream})"
def __del__(self):
"""
Stream destructor. Do *not* call
this directly: stuff will break
"""
if self._fd != -1:
try:
os.set_blocking(self._fd, False)
os.close(self._fd)
except OSError as e:
warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}")
class AsyncSocket(AsyncStream):
"""
Abstraction layer for asynchronous sockets
"""
def __init__(self, sock, do_handshake_on_connect: bool = True):
self.sock = sock
def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True):
super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
self.do_handshake_on_connect = do_handshake_on_connect
self._fd = sock.fileno()
self.sock.setblocking(False)
self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto)
self.stream.setblocking(False)
# A socket that isn't connected doesn't
# need to be closed
self.needs_closing: bool = False
async def receive(self, max_size: int, flags: int = 0) -> bytes:
"""
@ -52,11 +153,11 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
return self.sock.recv(max_size, flags)
return self.stream.recv(max_size, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def connect(self, address):
"""
@ -67,12 +168,21 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
self.sock.connect(address)
self.stream.connect(address)
if self.do_handshake_on_connect:
await self.do_handshake()
return
break
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
self.needs_closing = True
async def close(self):
"""
Wrapper socket method
"""
if self.needs_closing:
await super().close()
async def accept(self):
"""
@ -83,10 +193,10 @@ class AsyncSocket:
raise ResourceClosed("I/O operation on closed socket")
while True:
try:
remote, addr = self.sock.accept()
remote, addr = self.stream.accept()
return type(self)(remote), addr
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def send_all(self, data: bytes, flags: int = 0):
"""
@ -98,32 +208,20 @@ class AsyncSocket:
sent_no = 0
while data:
try:
sent_no = self.sock.send(data, flags)
sent_no = self.stream.send(data, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
data = data[sent_no:]
async def close(self):
"""
Closes the socket asynchronously
"""
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
await io_release(self.sock)
self.sock.close()
self._fd = -1
self.sock = None
async def shutdown(self, how):
"""
Wrapper socket method
"""
if self.sock:
self.sock.shutdown(how)
if self.stream:
self.stream.shutdown(how)
await giambio.sleep(0) # Checkpoint
async def bind(self, addr: tuple):
@ -136,7 +234,7 @@ class AsyncSocket:
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.sock.bind(addr)
self.stream.bind(addr)
async def listen(self, backlog: int):
"""
@ -148,27 +246,12 @@ class AsyncSocket:
if self._fd == -1:
raise ResourceClosed("I/O operation on closed socket")
self.sock.listen(backlog)
async def __aenter__(self):
self.sock.__enter__()
return self
async def __aexit__(self, *args):
if self.sock:
self.sock.__exit__(*args)
self.stream.listen(backlog)
# Yes, I stole these from Curio because I could not be
# arsed to write a bunch of uninteresting simple socket
# methods from scratch, deal with it.
async def fileno(self):
"""
Wrapper socket method
"""
return self._fd
async def settimeout(self, seconds):
"""
Wrapper socket method
@ -188,22 +271,23 @@ class AsyncSocket:
Wrapper socket method
"""
return type(self)(self.sock.dup())
return type(self)(self.stream.dup(), self.do_handshake_on_connect)
async def do_handshake(self):
"""
Wrapper socket method
"""
if not hasattr(self.sock, "do_handshake"):
if not hasattr(self.stream, "do_handshake"):
return
while True:
try:
return self.sock.do_handshake()
self.stream: SSLSocket # Silences pycharm warnings
return self.stream.do_handshake()
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def recvfrom(self, buffersize, flags=0):
"""
@ -212,11 +296,11 @@ class AsyncSocket:
while True:
try:
return self.sock.recvfrom(buffersize, flags)
return self.stream.recvfrom(buffersize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def recvfrom_into(self, buffer, bytes=0, flags=0):
"""
@ -225,11 +309,11 @@ class AsyncSocket:
while True:
try:
return self.sock.recvfrom_into(buffer, bytes, flags)
return self.stream.recvfrom_into(buffer, bytes, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
async def sendto(self, bytes, flags_or_address, address=None):
"""
@ -243,11 +327,11 @@ class AsyncSocket:
flags = 0
while True:
try:
return self.sock.sendto(bytes, flags, address)
return self.stream.sendto(bytes, flags, address)
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def getpeername(self):
"""
@ -256,11 +340,11 @@ class AsyncSocket:
while True:
try:
return self.sock.getpeername()
return self.stream.getpeername()
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def getsockname(self):
"""
@ -269,11 +353,11 @@ class AsyncSocket:
while True:
try:
return self.sock.getpeername()
return self.stream.getpeername()
except WantWrite:
await want_write(self.sock)
await want_write(self.stream)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
"""
@ -282,9 +366,9 @@ class AsyncSocket:
while True:
try:
return self.sock.recvmsg(bufsize, ancbufsize, flags)
return self.stream.recvmsg(bufsize, ancbufsize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
"""
@ -293,9 +377,9 @@ class AsyncSocket:
while True:
try:
return self.sock.recvmsg_into(buffers, ancbufsize, flags)
return self.stream.recvmsg_into(buffers, ancbufsize, flags)
except WantRead:
await want_read(self.sock)
await want_read(self.stream)
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
"""
@ -304,17 +388,13 @@ class AsyncSocket:
while True:
try:
return self.sock.sendmsg(buffers, ancdata, flags, address)
return self.stream.sendmsg(buffers, ancdata, flags, address)
except WantRead:
await want_write(self.sock)
await want_write(self.stream)
def __repr__(self):
return f"AsyncSocket({self.sock})"
return f"AsyncSocket({self.stream})"
def __del__(self):
"""
Socket destructor
"""
if not self._fd != -1:
warnings.warn(f"socket '{self}' was destroyed, but was not closed, leading to a potential resource leak")
if self.needs_closing:
super().__del__()

View File

@ -92,7 +92,7 @@ def create_pool():
Creates an async pool
"""
return TaskManager(get_event_loop().current_task)
return TaskManager()
def with_timeout(timeout: int or float):
@ -101,7 +101,7 @@ def with_timeout(timeout: int or float):
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(get_event_loop().current_task, timeout, True)
mgr = TaskManager(timeout, True)
loop = get_event_loop()
if loop.current_task is loop.entry_point:
loop.current_pool = mgr
@ -117,7 +117,7 @@ def skip_after(timeout: int or float):
"""
assert timeout > 0, "The timeout must be greater than 0"
mgr = TaskManager(get_event_loop().current_task, timeout)
mgr = TaskManager(timeout)
loop = get_event_loop()
if loop.current_task is loop.entry_point:
loop.current_pool = mgr

View File

@ -67,7 +67,7 @@ async def create_task(coro: Callable[[Any, Any], Coroutine[Any, Any, Any]], pool
"\nWhat you wanna do, instead, is this: pool.create_task(your_func, arg1, arg2, ...)"
)
elif inspect.iscoroutinefunction(coro):
return await create_trap("create_task", coro(*args, **kwargs), pool)
return await create_trap("create_task", coro, pool, *args, **kwargs)
else:
raise TypeError("coro must be a coroutine function")
@ -178,6 +178,19 @@ async def want_write(stream):
await create_trap("register_sock", stream, "write")
async def notify_closing(stream):
"""
Notifies the event loop that a given
stream needs to be closed. This makes
all callers waiting on want_read or
want_write crash with a ResourceClosed
exception, but it doesn't actually close
the socket object itself
"""
await create_trap("notify_closing", stream)
async def schedule_tasks(tasks: Iterable[Task]):
"""
Schedules a list of tasks for execution. Usuaully

View File

@ -1,44 +1,49 @@
import sys
from typing import Tuple
import giambio
import logging
from debugger import Debugger
async def sender(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
async def reader(q: giambio.Queue, prompt: str = ""):
in_stream = giambio.io.AsyncStream(sys.stdin.fileno(), close_on_context_exit=False, mode="r")
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
while True:
await sock.send_all(b"yo")
await q.put((0, ""))
await giambio.sleep(1)
await out_stream.write(prompt)
await q.put((0, await in_stream.read()))
async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
data = b""
while True:
while not data.endswith(b"\n"):
data += await sock.receive(1024)
temp = await sock.receive(1024)
if not temp:
raise EOFError("end of file")
data += temp
data, rest = data.split(b"\n", maxsplit=2)
buffer = b"".join(rest)
await q.put((1, data.decode()))
data = buffer
async def main(host: Tuple[str, int]):
async def main(host: tuple[str, int]):
"""
Main client entry point
"""
queue = giambio.Queue()
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
async with giambio.create_pool() as pool:
async with giambio.socket.socket() as sock:
await sock.connect(host)
await pool.spawn(sender, sock, queue)
await out_stream.write("Connection successful\n")
await pool.spawn(receiver, sock, queue)
await pool.spawn(reader, queue, "> ")
while True:
op, data = await queue.get()
if op == 0:
print(f"Sent.")
else:
print(f"Received: {data}")
if op == 1:
await out_stream.write(data)
if __name__ == "__main__":
@ -49,7 +54,7 @@ if __name__ == "__main__":
datefmt="%d/%m/%Y %p",
)
try:
giambio.run(main, ("localhost", port))
giambio.run(main, ("localhost", port), debugger=Debugger())
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
if isinstance(error, KeyboardInterrupt):
logging.info("Ctrl+C detected, exiting")

View File

@ -1,4 +1,3 @@
from typing import List
import giambio
from giambio.socket import AsyncSocket
import logging
@ -6,7 +5,8 @@ import sys
# An asynchronous chatroom
clients: List[giambio.socket.AsyncSocket] = []
clients: dict[AsyncSocket, list[str, str]] = {}
names: set[str] = set()
async def serve(bind_address: tuple):
@ -26,39 +26,52 @@ async def serve(bind_address: tuple):
while True:
try:
conn, address_tuple = await sock.accept()
clients.append(conn)
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
await pool.spawn(handler, conn, address_tuple)
await pool.spawn(handler, conn)
except Exception as err:
# Because exceptions just *work*
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
async def handler(sock: AsyncSocket, client_address: tuple):
async def handler(sock: AsyncSocket):
"""
Handles a single client connection
:param sock: The AsyncSocket object connected to the client
:param client_address: The client's address represented as a tuple
(address, port) where address is a string and port is an integer
:type client_address: tuple
"""
address = f"{client_address[0]}:{client_address[1]}"
address = clients[sock][1]
name = ""
async with sock: # Closes the socket automatically
await sock.send_all(b"Welcome to the chatroom pal, start typing and press enter!\n")
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
while True:
while not name.endswith("\n"):
name = (await sock.receive(64)).decode()
name = name[:-1]
if name not in names:
names.add(name)
clients[sock][0] = name
break
else:
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
for i, client_sock in enumerate(clients):
if client_sock != sock and clients[client_sock][0]:
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
while True:
await sock.send_all(b"> ")
data = await sock.receive(1024)
if not data:
break
elif data == b"exit\n":
await sock.send_all(b"I'm dead dude\n")
raise TypeError("Oh, no, I'm gonna die!")
logging.info(f"Got: {data!r} from {address}")
for i, client_sock in enumerate(clients):
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if client_sock != sock:
await client_sock.send_all(data)
if client_sock != sock and clients[client_sock][0]:
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
if not data.endswith(b"\n"):
data += b"\n"
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
logging.info(f"Sent {data!r} to {i} clients")
logging.info(f"Connection from {address} closed")
clients.remove(sock)

View File

@ -63,7 +63,7 @@ if __name__ == "__main__":
logging.basicConfig(
level=20,
format="[%(levelname)s] %(asctime)s %(message)s",
datefmt="%d/%m/%Y %p",
datefmt="%d/%m/%Y %H:%M:%S %p",
)
try:
giambio.run(serve, ("localhost", port), debugger=())