mirror of https://github.com/nocturn9x/giambio.git
Compare commits
4 Commits
6d089d7d5f
...
e37ffdeb06
Author | SHA1 | Date |
---|---|---|
Nocturn9x | e37ffdeb06 | |
Nocturn9x | 55868c450e | |
Nocturn9x | 60df2f059a | |
Nocturn9x | d408cffa87 |
|
@ -16,6 +16,10 @@ rock-solid and structured concurrency framework (I personally recommend trio and
|
|||
that most of the content of this document is ~~stolen~~ inspired from its documentation)
|
||||
|
||||
|
||||
# Disclaimer #2
|
||||
|
||||
This is a toy project. Don't try to use it in production, it *will* explode
|
||||
|
||||
## Goals of this project
|
||||
|
||||
Making yet another async library might sound dumb in an already fragmented ecosystem like Python's.
|
||||
|
|
|
@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
|||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
@ -17,8 +17,7 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import giambio
|
||||
from giambio.task import Task
|
||||
from typing import List, Optional, Callable, Coroutine, Any
|
||||
from typing import List, Optional, Any, Coroutine, Callable
|
||||
|
||||
|
||||
class TaskManager:
|
||||
|
@ -32,13 +31,13 @@ class TaskManager:
|
|||
:type raise_on_timeout: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self, current_task: Task, timeout: float = None, raise_on_timeout: bool = False) -> None:
|
||||
def __init__(self, timeout: float = None, raise_on_timeout: bool = True) -> None:
|
||||
"""
|
||||
Object constructor
|
||||
"""
|
||||
|
||||
# All the tasks that belong to this pool
|
||||
self.tasks: List[Task] = []
|
||||
self.tasks: List[giambio.task.Task] = []
|
||||
# Whether we have been cancelled or not
|
||||
self.cancelled: bool = False
|
||||
# The clock time of when we started running, used for
|
||||
|
@ -51,19 +50,10 @@ class TaskManager:
|
|||
self.timeout = None
|
||||
# Whether our timeout expired or not
|
||||
self.timed_out: bool = False
|
||||
# Internal check so users don't try
|
||||
# to use the pool manually
|
||||
self._proper_init = False
|
||||
# We keep track of any inner pools to propagate
|
||||
# exceptions properly
|
||||
self.enclosed_pool: Optional["giambio.context.TaskManager"] = None
|
||||
# Do we raise an error after timeout?
|
||||
self.raise_on_timeout: bool = raise_on_timeout
|
||||
# The task that created the pool. We keep track of
|
||||
# it because we only cancel ourselves if this task
|
||||
# errors out (so if the error is caught before reaching
|
||||
# it we just do nothing)
|
||||
self.owner: Task = current_task
|
||||
self.entry_point: Optional[Task] = None
|
||||
|
||||
async def spawn(self, func: Callable[..., Coroutine[Any, Any, Any]], *args, **kwargs) -> "giambio.task.Task":
|
||||
"""
|
||||
|
@ -76,10 +66,11 @@ class TaskManager:
|
|||
|
||||
async def __aenter__(self):
|
||||
"""
|
||||
Implements the asynchronous context manager interface,
|
||||
Implements the asynchronous context manager interface
|
||||
"""
|
||||
|
||||
self._proper_init = True
|
||||
self.entry_point = await giambio.traps.current_task()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Exception, exc: Exception, tb):
|
||||
|
@ -95,14 +86,13 @@ class TaskManager:
|
|||
# children to exit
|
||||
await task.join()
|
||||
self.tasks.remove(task)
|
||||
self._proper_init = False
|
||||
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
|
||||
return True
|
||||
except giambio.exceptions.TooSlowError:
|
||||
if self.raise_on_timeout:
|
||||
raise
|
||||
finally:
|
||||
self._proper_init = False
|
||||
if isinstance(exc, giambio.exceptions.TooSlowError) and not self.raise_on_timeout:
|
||||
return True
|
||||
|
||||
|
||||
async def cancel(self):
|
||||
"""
|
||||
Cancels the pool entirely, iterating over all
|
||||
|
@ -120,4 +110,4 @@ class TaskManager:
|
|||
pool have exited, False otherwise
|
||||
"""
|
||||
|
||||
return self._proper_init and all([task.done() for task in self.tasks])
|
||||
return self._proper_init and all([task.done() for task in self.tasks]) and (True if not self.enclosed_pool else self.enclosed_pool.done())
|
||||
|
|
267
giambio/core.py
267
giambio/core.py
|
@ -7,7 +7,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
|
|||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
@ -15,8 +15,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import functools
|
||||
# Import libraries and internal resources
|
||||
from numbers import Number
|
||||
from giambio.task import Task
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
|
@ -32,14 +33,15 @@ from giambio.exceptions import (
|
|||
ResourceBusy,
|
||||
GiambioError,
|
||||
TooSlowError,
|
||||
ResourceClosed
|
||||
)
|
||||
|
||||
|
||||
class AsyncScheduler:
|
||||
"""
|
||||
A simple task scheduler implementation that tries to mimic thread programming
|
||||
in its simplicity, without using actual threads, but rather alternating
|
||||
across coroutines execution to let more than one thing at a time to proceed
|
||||
in its simplicity, without using actual threads, but rather alternating the
|
||||
execution of coroutines to let more than one thing at a time to proceed
|
||||
with its calculations. An attempt to fix the threaded model has been made
|
||||
without making the API unnecessarily complicated.
|
||||
|
||||
|
@ -55,7 +57,7 @@ class AsyncScheduler:
|
|||
|
||||
:param clock: A callable returning monotonically increasing values at each call,
|
||||
usually using seconds as units, but this is not enforced, defaults to timeit.default_timer
|
||||
:type clock: :class: Callable
|
||||
:type clock: :class: types.FunctionType
|
||||
:param debugger: A subclass of giambio.util.BaseDebugger or None if no debugging output
|
||||
is desired, defaults to None
|
||||
:type debugger: :class: giambio.util.BaseDebugger
|
||||
|
@ -72,7 +74,7 @@ class AsyncScheduler:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
clock: Callable = default_timer,
|
||||
clock: Callable[[], Number] = default_timer,
|
||||
debugger: Optional[BaseDebugger] = None,
|
||||
selector: Optional[Any] = None,
|
||||
io_skip_limit: Optional[int] = None,
|
||||
|
@ -94,7 +96,7 @@ class AsyncScheduler:
|
|||
or type(
|
||||
"DumbDebugger",
|
||||
(object,),
|
||||
{"__getattr__": lambda *args: lambda *arg: None},
|
||||
{"__getattr__": lambda *_: lambda *_: None},
|
||||
)()
|
||||
)
|
||||
# All tasks the loop has
|
||||
|
@ -106,7 +108,7 @@ class AsyncScheduler:
|
|||
# This will always point to the currently running coroutine (Task object)
|
||||
self.current_task: Optional[Task] = None
|
||||
# Monotonic clock to keep track of elapsed time reliably
|
||||
self.clock: Callable = clock
|
||||
self.clock: Callable[[], Number] = clock
|
||||
# Tasks that are asleep
|
||||
self.paused: TimeQueue = TimeQueue(self.clock)
|
||||
# Have we ever ran?
|
||||
|
@ -129,6 +131,7 @@ class AsyncScheduler:
|
|||
self.entry_point: Optional[Task] = None
|
||||
# Suspended tasks
|
||||
self.suspended: deque = deque()
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
|
@ -150,6 +153,8 @@ class AsyncScheduler:
|
|||
"_data",
|
||||
"io_skip_limit",
|
||||
"io_max_timeout",
|
||||
"suspended",
|
||||
"entry_point"
|
||||
}
|
||||
data = ", ".join(
|
||||
name + "=" + str(value) for name, value in zip(fields, (getattr(self, field) for field in fields))
|
||||
|
@ -168,7 +173,7 @@ class AsyncScheduler:
|
|||
Shuts down the event loop
|
||||
"""
|
||||
|
||||
for task in self.tasks:
|
||||
for task in self.get_all_tasks():
|
||||
self.io_release_task(task)
|
||||
self.selector.close()
|
||||
# TODO: Anything else?
|
||||
|
@ -206,7 +211,10 @@ class AsyncScheduler:
|
|||
# after it is set, but it makes the implementation easier
|
||||
if not self.current_pool and self.current_task.pool:
|
||||
self.current_pool = self.current_task.pool
|
||||
self.deadlines.put(self.current_pool)
|
||||
pool = self.current_pool
|
||||
while pool:
|
||||
self.deadlines.put(pool)
|
||||
pool = self.current_pool.enclosed_pool
|
||||
# If there are no actively running tasks, we start by
|
||||
# checking for I/O. This method will wait for I/O until
|
||||
# the closest deadline to avoid starving sleeping tasks
|
||||
|
@ -230,9 +238,10 @@ class AsyncScheduler:
|
|||
# some tricky behaviors, and this is one of them. When a coroutine
|
||||
# hits a return statement (either explicit or implicit), it raises
|
||||
# a StopIteration exception, which has an attribute named value that
|
||||
# represents the return value of the coroutine, if any. Of course this
|
||||
# exception is not an error and we should happily keep going after it,
|
||||
# represents the return value of the coroutine, if it has one. Of course
|
||||
# this exception is not an error and we should happily keep going after it:
|
||||
# most of this code below is just useful for internal/debugging purposes
|
||||
self.current_task.status = "end"
|
||||
self.current_task.result = ret.value
|
||||
self.current_task.finished = True
|
||||
self.join(self.current_task)
|
||||
|
@ -244,20 +253,22 @@ class AsyncScheduler:
|
|||
self.current_task.exc = err
|
||||
self.join(self.current_task)
|
||||
|
||||
def create_task(self, coro: Coroutine[Any, Any, Any], pool) -> Task:
|
||||
|
||||
def create_task(self, corofunc: Callable[..., Coroutine[Any, Any, Any]], pool, *args, **kwargs) -> Task:
|
||||
"""
|
||||
Creates a task from a coroutine function and schedules it
|
||||
to run. The associated pool that spawned said task is also
|
||||
needed, while any extra keyword or positional arguments are
|
||||
passed to the function itself
|
||||
|
||||
:param coro: The coroutine to spawn
|
||||
:type coro: Coroutine[Any, Any, Any]
|
||||
:param corofunc: The coroutine function (NOT a coroutine!) to
|
||||
spawn
|
||||
:type corofunc: function
|
||||
:param pool: The giambio.context.TaskManager object that
|
||||
spawned the task
|
||||
"""
|
||||
|
||||
task = Task(coro.__name__ or str(coro), coro, pool)
|
||||
task = Task(corofunc.__name__ or str(corofunc), corofunc(*args, **kwargs), pool)
|
||||
task.next_deadline = pool.timeout or 0.0
|
||||
task.joiners = {self.current_task}
|
||||
self._data[self.current_task] = task
|
||||
|
@ -288,9 +299,15 @@ class AsyncScheduler:
|
|||
# We need to make sure we don't try to execute
|
||||
# exited tasks that are on the running queue
|
||||
return
|
||||
if not self.current_pool and self.current_task.pool:
|
||||
if self.current_pool:
|
||||
if self.current_task.pool and self.current_task.pool is not self.current_pool:
|
||||
self.current_task.pool.enclosed_pool = self.current_pool
|
||||
else:
|
||||
self.current_pool = self.current_task.pool
|
||||
self.deadlines.put(self.current_pool)
|
||||
pool = self.current_pool
|
||||
while pool:
|
||||
self.deadlines.put(pool)
|
||||
pool = self.current_pool.enclosed_pool
|
||||
self.debugger.before_task_step(self.current_task)
|
||||
# Some debugging and internal chatter here
|
||||
self.current_task.status = "run"
|
||||
|
@ -319,7 +336,7 @@ class AsyncScheduler:
|
|||
def io_release(self, sock):
|
||||
"""
|
||||
Releases the given resource from our
|
||||
selector.
|
||||
selector
|
||||
:param sock: The resource to be released
|
||||
"""
|
||||
|
||||
|
@ -334,7 +351,7 @@ class AsyncScheduler:
|
|||
|
||||
if self.selector.get_map():
|
||||
for k in filter(
|
||||
lambda o: o.data == self.current_task,
|
||||
lambda o: o.data == task,
|
||||
dict(self.selector.get_map()).values(),
|
||||
):
|
||||
self.io_release(k.fileobj)
|
||||
|
@ -344,11 +361,16 @@ class AsyncScheduler:
|
|||
"""
|
||||
Suspends execution of the current task. This is basically
|
||||
a do-nothing method, since it will not reschedule the task
|
||||
before returning. The task will stay suspended until a timer,
|
||||
I/O operation or cancellation wakes it up, or until another
|
||||
running task reschedules it.
|
||||
before returning. The task will stay suspended as long as
|
||||
something else outside the loop calls a trap to reschedule it.
|
||||
Any pending I/O for the task is temporarily unscheduled to
|
||||
avoid some previous network operation to reschedule the task
|
||||
before it's due
|
||||
"""
|
||||
|
||||
|
||||
if self.current_task.last_io or self.current_task.status == "io":
|
||||
self.io_release_task(self.current_task)
|
||||
self.current_task.status = "sleep"
|
||||
self.suspended.append(self.current_task)
|
||||
|
||||
def reschedule_running(self):
|
||||
|
@ -408,27 +430,35 @@ class AsyncScheduler:
|
|||
try:
|
||||
to_call()
|
||||
except StopIteration as ret:
|
||||
task.status = "end"
|
||||
task.result = ret.value
|
||||
task.finished = True
|
||||
self.join(task)
|
||||
except CancelledError as cancel:
|
||||
task.status = "cancelled"
|
||||
task.cancel_pending = False
|
||||
task.cancelled = True
|
||||
self.join(task)
|
||||
except BaseException as err:
|
||||
task.exc = err
|
||||
self.join(task)
|
||||
|
||||
|
||||
def prune_deadlines(self):
|
||||
"""
|
||||
Removes expired deadlines after their timeout
|
||||
has expired and cancels their associated pool
|
||||
has expired
|
||||
"""
|
||||
|
||||
while self.deadlines and self.deadlines.get_closest_deadline() <= self.clock():
|
||||
pool = self.deadlines.get()
|
||||
pool.timed_out = True
|
||||
self.cancel_pool(pool)
|
||||
for task in pool.tasks:
|
||||
if task is not pool.owner:
|
||||
self.handle_task_exit(task, partial(task.throw, TooSlowError(self.current_task)))
|
||||
if pool.raise_on_timeout:
|
||||
self.handle_task_exit(pool.owner, partial(pool.owner.throw, TooSlowError(self.current_task)))
|
||||
self.join(task)
|
||||
if pool.entry_point is self.entry_point:
|
||||
self.handle_task_exit(self.entry_point, partial(self.entry_point.throw, TooSlowError(self.entry_point)))
|
||||
self.run_ready.append(self.entry_point)
|
||||
|
||||
def schedule_tasks(self, tasks: List[Task]):
|
||||
"""
|
||||
|
@ -439,7 +469,8 @@ class AsyncScheduler:
|
|||
|
||||
for task in tasks:
|
||||
self.paused.discard(task)
|
||||
self.suspended.remove(task)
|
||||
if task in self.suspended:
|
||||
self.suspended.remove(task)
|
||||
self.run_ready.extend(tasks)
|
||||
self.reschedule_running()
|
||||
|
||||
|
@ -462,6 +493,7 @@ class AsyncScheduler:
|
|||
self.run_ready.append(task)
|
||||
self.debugger.after_sleep(task, slept)
|
||||
|
||||
|
||||
def get_closest_deadline(self) -> float:
|
||||
"""
|
||||
Gets the closest expiration deadline (asleep tasks, timeouts)
|
||||
|
@ -469,7 +501,7 @@ class AsyncScheduler:
|
|||
:return: The closest deadline according to our clock
|
||||
:rtype: float
|
||||
"""
|
||||
|
||||
|
||||
if not self.deadlines:
|
||||
# If there are no deadlines just wait until the first task wakeup
|
||||
timeout = max(0.0, self.paused.get_closest_deadline() - self.clock())
|
||||
|
@ -535,9 +567,12 @@ class AsyncScheduler:
|
|||
self.run_ready.append(entry)
|
||||
self.debugger.on_start()
|
||||
if loop:
|
||||
self.run()
|
||||
self.has_ran = True
|
||||
self.debugger.on_exit()
|
||||
try:
|
||||
self.run()
|
||||
finally:
|
||||
self.has_ran = True
|
||||
self.close()
|
||||
self.debugger.on_exit()
|
||||
|
||||
def cancel_pool(self, pool: TaskManager) -> bool:
|
||||
"""
|
||||
|
@ -589,8 +624,9 @@ class AsyncScheduler:
|
|||
If ensure_done equals False, the loop will cancel ALL
|
||||
running and scheduled tasks and then tear itself down.
|
||||
If ensure_done equals True, which is the default behavior,
|
||||
this method will raise a GiambioError if the loop hasn't
|
||||
finished running.
|
||||
this method will raise a GiambioError exception if the loop
|
||||
hasn't finished running. The state of the event loop is reset
|
||||
so it can be reused with another run() call
|
||||
"""
|
||||
|
||||
if ensure_done:
|
||||
|
@ -598,6 +634,16 @@ class AsyncScheduler:
|
|||
elif not self.done():
|
||||
raise GiambioError("event loop not terminated, call this method with ensure_done=False to forcefully exit")
|
||||
self.shutdown()
|
||||
# We reset the event loop's state
|
||||
self.tasks = []
|
||||
self.entry_point = None
|
||||
self.current_pool = None
|
||||
self.current_task = None
|
||||
self.paused = TimeQueue(self.clock)
|
||||
self.deadlines = DeadlinesQueue()
|
||||
self.run_ready = deque()
|
||||
self.suspended = deque()
|
||||
|
||||
|
||||
def reschedule_joiners(self, task: Task):
|
||||
"""
|
||||
|
@ -605,87 +651,71 @@ class AsyncScheduler:
|
|||
given task, if any
|
||||
"""
|
||||
|
||||
for t in task.joiners:
|
||||
self.run_ready.append(t)
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def is_pool_done(self, pool: Optional[TaskManager]):
|
||||
"""
|
||||
Returns True if a given pool has finished
|
||||
executing
|
||||
"""
|
||||
|
||||
while pool:
|
||||
if not pool.done():
|
||||
return False
|
||||
pool = pool.enclosed_pool
|
||||
return True
|
||||
if task.pool and task.pool.enclosed_pool and not task.pool.enclosed_pool.done():
|
||||
return
|
||||
self.run_ready.extend(task.joiners)
|
||||
|
||||
def join(self, task: Task):
|
||||
"""
|
||||
Joins a task to its callers (implicitly, the parent
|
||||
Joins a task to its callers (implicitly the parent
|
||||
task, but also every other task who called await
|
||||
task.join() on the task object)
|
||||
"""
|
||||
|
||||
task.joined = True
|
||||
if any([task.finished, task.cancelled, task.exc]) and task in self.tasks:
|
||||
self.io_release_task(task)
|
||||
self.tasks.remove(task)
|
||||
self.paused.discard(task)
|
||||
if task.finished or task.cancelled:
|
||||
task.status = "end"
|
||||
if not task.cancelled:
|
||||
task.status = "cancelled"
|
||||
# This way join() returns the
|
||||
# task's return value
|
||||
for joiner in task.joiners:
|
||||
self._data[joiner] = task.result
|
||||
self.debugger.on_task_exit(task)
|
||||
# If the pool has finished executing or we're at the first parent
|
||||
# task that kicked the loop, we can safely reschedule the parent(s)
|
||||
if self.is_pool_done(task.pool):
|
||||
if task.last_io:
|
||||
self.io_release_task(task)
|
||||
if task in self.suspended:
|
||||
self.suspended.remove(task)
|
||||
if task in self.tasks:
|
||||
self.tasks.remove(task)
|
||||
# If the pool (including any enclosing pools) has finished executing
|
||||
# or we're at the first task that kicked the loop, we can safely
|
||||
# reschedule the parent(s)
|
||||
if task.pool is None:
|
||||
return
|
||||
if task.pool.done():
|
||||
self.reschedule_joiners(task)
|
||||
self.reschedule_running()
|
||||
elif task.exc:
|
||||
if task in self.suspended:
|
||||
self.suspended.remove(task)
|
||||
task.status = "crashed"
|
||||
if task.exc.__traceback__:
|
||||
# TODO: We might want to do a bit more complex traceback hacking to remove any extra
|
||||
# frames from the exception call stack, but for now removing at least the first few
|
||||
# seems a sensible approach (it's us catching it so we don't care about that)
|
||||
for _ in range(5):
|
||||
if task.exc.__traceback__.tb_next:
|
||||
task.exc.__traceback__ = task.exc.__traceback__.tb_next
|
||||
self.debugger.on_exception_raised(task, task.exc)
|
||||
if task is self.entry_point and not task.pool:
|
||||
try:
|
||||
task.throw(task.exc)
|
||||
except StopIteration:
|
||||
... # TODO: ?
|
||||
except BaseException:
|
||||
# TODO: No idea what to do here
|
||||
raise
|
||||
elif any(map(lambda tk: tk is task.pool.owner, task.joiners)) or task is task.pool.owner:
|
||||
# We check if the pool's
|
||||
# owner catches our error
|
||||
# or not. If they don't, we
|
||||
# cancel the entire pool, but
|
||||
# if they do, we do nothing
|
||||
if task.pool.owner is not task:
|
||||
self.handle_task_exit(task.pool.owner, partial(task.pool.owner.coroutine.throw, task.exc))
|
||||
if any([task.pool.owner.exc, task.pool.owner.cancelled, task.pool.owner.finished]):
|
||||
for t in task.joiners.copy():
|
||||
# Propagate the exception
|
||||
self.handle_task_exit(t, partial(t.throw, task.exc))
|
||||
if any([t.exc, t.finished, t.cancelled]):
|
||||
task.joiners.remove(t)
|
||||
for t in task.pool.tasks:
|
||||
if not t.joined:
|
||||
self.handle_task_exit(t, partial(t.throw, task.exc))
|
||||
if any([t.exc, t.finished, t.cancelled]):
|
||||
task.joiners.discard(t)
|
||||
self.reschedule_joiners(task)
|
||||
self.reschedule_running()
|
||||
if task.pool is None or task is self.entry_point:
|
||||
# Parent task has no pool, so we propagate
|
||||
raise task.exc
|
||||
if self.cancel_pool(task.pool):
|
||||
# This will reschedule the parent(s)
|
||||
# only if all the tasks inside the task's
|
||||
# pool have finished executing, either
|
||||
# by cancellation, an exception
|
||||
# or just returned
|
||||
for t in task.joiners.copy():
|
||||
# Propagate the exception
|
||||
try:
|
||||
t.throw(task.exc)
|
||||
except (StopIteration, CancelledError, RuntimeError) as e:
|
||||
# TODO: Need anything else?
|
||||
task.joiners.remove(t)
|
||||
if isinstance(e, StopIteration):
|
||||
t.status = "end"
|
||||
t.result = e.value
|
||||
t.finished = True
|
||||
elif isinstance(e, CancelledError):
|
||||
t = e.task
|
||||
t.cancel_pending = False
|
||||
t.cancelled = True
|
||||
t.status = "cancelled"
|
||||
self.debugger.after_cancel(t)
|
||||
elif isinstance(e, BaseException):
|
||||
t.exc = e
|
||||
finally:
|
||||
if t in self.tasks:
|
||||
self.tasks.remove(t)
|
||||
self.reschedule_joiners(task)
|
||||
|
||||
def sleep(self, seconds: int or float):
|
||||
"""
|
||||
|
@ -727,6 +757,8 @@ class AsyncScheduler:
|
|||
self.io_release_task(task)
|
||||
elif task.status == "sleep":
|
||||
self.paused.discard(task)
|
||||
if task in self.suspended:
|
||||
self.suspended.remove(task)
|
||||
try:
|
||||
self.do_cancel(task)
|
||||
except CancelledError as cancel:
|
||||
|
@ -742,24 +774,36 @@ class AsyncScheduler:
|
|||
task = cancel.task
|
||||
task.cancel_pending = False
|
||||
task.cancelled = True
|
||||
self.io_release_task(self.current_task)
|
||||
task.status = "cancelled"
|
||||
self.debugger.after_cancel(task)
|
||||
self.tasks.remove(task)
|
||||
self.join(task)
|
||||
else:
|
||||
# If the task ignores our exception, we'll
|
||||
# raise it later again
|
||||
task.cancel_pending = True
|
||||
self.join(task)
|
||||
|
||||
def notify_closing(self, stream):
|
||||
"""
|
||||
Implements the notify_closing trap
|
||||
"""
|
||||
|
||||
if self.selector.get_map():
|
||||
for k in filter(
|
||||
lambda o: o.data == self.current_task,
|
||||
dict(self.selector.get_map()).values(),
|
||||
):
|
||||
self.handle_task_exit(k.data,
|
||||
functools.partial(k.data.throw(ResourceClosed("stream has been closed"))))
|
||||
|
||||
def register_sock(self, sock, evt_type: str):
|
||||
"""
|
||||
Registers the given socket inside the
|
||||
selector to perform I/0 multiplexing
|
||||
selector to perform I/O multiplexing
|
||||
|
||||
:param sock: The socket on which a read or write operation
|
||||
has to be performed
|
||||
has to be performed
|
||||
:param evt_type: The type of event to perform on the given
|
||||
socket, either "read" or "write"
|
||||
socket, either "read" or "write"
|
||||
:type evt_type: str
|
||||
"""
|
||||
|
||||
|
@ -793,5 +837,8 @@ class AsyncScheduler:
|
|||
try:
|
||||
self.selector.register(sock, evt, self.current_task)
|
||||
except KeyError:
|
||||
# The socket is already registered doing something else
|
||||
raise ResourceBusy("The given socket is being read/written by another task") from None
|
||||
# The socket is already registered doing something else, we
|
||||
# modify the socket instead (or maybe not?)
|
||||
self.selector.modify(sock, evt, self.current_task)
|
||||
# TODO: Does this break stuff?
|
||||
# raise ResourceBusy("The given socket is being read/written by another task") from None
|
||||
|
|
|
@ -37,7 +37,7 @@ class InternalError(GiambioError):
|
|||
...
|
||||
|
||||
|
||||
class CancelledError(GiambioError):
|
||||
class CancelledError(BaseException):
|
||||
"""
|
||||
Exception raised by the giambio.objects.Task.cancel() method
|
||||
to terminate a child task. This should NOT be caught, or
|
||||
|
|
246
giambio/io.py
246
giambio/io.py
|
@ -15,14 +15,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import socket
|
||||
import warnings
|
||||
|
||||
import os
|
||||
import giambio
|
||||
from giambio.exceptions import ResourceClosed
|
||||
from giambio.traps import want_write, want_read, io_release
|
||||
from giambio.traps import want_write, want_read, io_release, notify_closing
|
||||
|
||||
|
||||
try:
|
||||
from ssl import SSLWantReadError, SSLWantWriteError
|
||||
from ssl import SSLWantReadError, SSLWantWriteError, SSLSocket
|
||||
|
||||
WantRead = (BlockingIOError, InterruptedError, SSLWantReadError)
|
||||
WantWrite = (BlockingIOError, InterruptedError, SSLWantWriteError)
|
||||
|
@ -31,16 +33,115 @@ except ImportError:
|
|||
WantWrite = (BlockingIOError, InterruptedError)
|
||||
|
||||
|
||||
class AsyncSocket:
|
||||
class AsyncStream:
|
||||
"""
|
||||
A generic asynchronous stream over
|
||||
a file descriptor. Only works on Linux
|
||||
& co because windows doesn't like select()
|
||||
to be called on non-socket objects
|
||||
(Thanks, Microsoft)
|
||||
"""
|
||||
|
||||
def __init__(self, fd: int, open_fd: bool = True, close_on_context_exit: bool = True, **kwargs):
|
||||
self._fd = fd
|
||||
self.stream = None
|
||||
if open_fd:
|
||||
self.stream = os.fdopen(self._fd, **kwargs)
|
||||
os.set_blocking(self._fd, False)
|
||||
self.close_on_context_exit = close_on_context_exit
|
||||
|
||||
async def read(self, size: int = -1):
|
||||
"""
|
||||
Reads up to size bytes from the
|
||||
given stream. If size == -1, read
|
||||
until EOF is reached
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self.stream.read(size)
|
||||
except WantRead:
|
||||
await want_read(self.stream)
|
||||
|
||||
async def write(self, data):
|
||||
"""
|
||||
Writes data b to the file.
|
||||
Returns the number of bytes
|
||||
written
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self.stream.write(data)
|
||||
except WantWrite:
|
||||
await want_write(self.stream)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the stream asynchronously
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed stream")
|
||||
self._fd = -1
|
||||
await notify_closing(self.stream)
|
||||
await io_release(self.stream)
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
|
||||
@property
|
||||
async def fileno(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return self._fd
|
||||
|
||||
async def __aenter__(self):
|
||||
self.stream.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self._fd != -1 and self.close_on_context_exit:
|
||||
await self.close()
|
||||
|
||||
async def dup(self):
|
||||
"""
|
||||
Wrapper stream method
|
||||
"""
|
||||
|
||||
return type(self)(os.dup(self._fd))
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncStream({self.stream})"
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Stream destructor. Do *not* call
|
||||
this directly: stuff will break
|
||||
"""
|
||||
|
||||
if self._fd != -1:
|
||||
try:
|
||||
os.set_blocking(self._fd, False)
|
||||
os.close(self._fd)
|
||||
except OSError as e:
|
||||
warnings.warn(f"An exception occurred in __del__ for stream {self} -> {type(e).__name__}: {e}")
|
||||
|
||||
|
||||
class AsyncSocket(AsyncStream):
|
||||
"""
|
||||
Abstraction layer for asynchronous sockets
|
||||
"""
|
||||
|
||||
def __init__(self, sock, do_handshake_on_connect: bool = True):
|
||||
self.sock = sock
|
||||
def __init__(self, sock: socket.socket, close_on_context_exit: bool = True, do_handshake_on_connect: bool = True):
|
||||
super().__init__(sock.fileno(), open_fd=False, close_on_context_exit=close_on_context_exit)
|
||||
self.do_handshake_on_connect = do_handshake_on_connect
|
||||
self._fd = sock.fileno()
|
||||
self.sock.setblocking(False)
|
||||
self.stream = socket.fromfd(self._fd, sock.family, sock.type, sock.proto)
|
||||
self.stream.setblocking(False)
|
||||
# A socket that isn't connected doesn't
|
||||
# need to be closed
|
||||
self.needs_closing: bool = False
|
||||
|
||||
async def receive(self, max_size: int, flags: int = 0) -> bytes:
|
||||
"""
|
||||
|
@ -52,11 +153,11 @@ class AsyncSocket:
|
|||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
return self.sock.recv(max_size, flags)
|
||||
return self.stream.recv(max_size, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
|
||||
async def connect(self, address):
|
||||
"""
|
||||
|
@ -67,12 +168,21 @@ class AsyncSocket:
|
|||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
self.sock.connect(address)
|
||||
self.stream.connect(address)
|
||||
if self.do_handshake_on_connect:
|
||||
await self.do_handshake()
|
||||
return
|
||||
break
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
self.needs_closing = True
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.needs_closing:
|
||||
await super().close()
|
||||
|
||||
async def accept(self):
|
||||
"""
|
||||
|
@ -83,10 +193,10 @@ class AsyncSocket:
|
|||
raise ResourceClosed("I/O operation on closed socket")
|
||||
while True:
|
||||
try:
|
||||
remote, addr = self.sock.accept()
|
||||
remote, addr = self.stream.accept()
|
||||
return type(self)(remote), addr
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def send_all(self, data: bytes, flags: int = 0):
|
||||
"""
|
||||
|
@ -98,32 +208,20 @@ class AsyncSocket:
|
|||
sent_no = 0
|
||||
while data:
|
||||
try:
|
||||
sent_no = self.sock.send(data, flags)
|
||||
sent_no = self.stream.send(data, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
data = data[sent_no:]
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the socket asynchronously
|
||||
"""
|
||||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
await io_release(self.sock)
|
||||
self.sock.close()
|
||||
self._fd = -1
|
||||
self.sock = None
|
||||
|
||||
async def shutdown(self, how):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if self.sock:
|
||||
self.sock.shutdown(how)
|
||||
if self.stream:
|
||||
self.stream.shutdown(how)
|
||||
await giambio.sleep(0) # Checkpoint
|
||||
|
||||
async def bind(self, addr: tuple):
|
||||
|
@ -136,7 +234,7 @@ class AsyncSocket:
|
|||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.sock.bind(addr)
|
||||
self.stream.bind(addr)
|
||||
|
||||
async def listen(self, backlog: int):
|
||||
"""
|
||||
|
@ -148,27 +246,12 @@ class AsyncSocket:
|
|||
|
||||
if self._fd == -1:
|
||||
raise ResourceClosed("I/O operation on closed socket")
|
||||
self.sock.listen(backlog)
|
||||
|
||||
async def __aenter__(self):
|
||||
self.sock.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
if self.sock:
|
||||
self.sock.__exit__(*args)
|
||||
self.stream.listen(backlog)
|
||||
|
||||
# Yes, I stole these from Curio because I could not be
|
||||
# arsed to write a bunch of uninteresting simple socket
|
||||
# methods from scratch, deal with it.
|
||||
|
||||
async def fileno(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return self._fd
|
||||
|
||||
async def settimeout(self, seconds):
|
||||
"""
|
||||
Wrapper socket method
|
||||
|
@ -188,22 +271,23 @@ class AsyncSocket:
|
|||
Wrapper socket method
|
||||
"""
|
||||
|
||||
return type(self)(self.sock.dup())
|
||||
return type(self)(self.stream.dup(), self.do_handshake_on_connect)
|
||||
|
||||
async def do_handshake(self):
|
||||
"""
|
||||
Wrapper socket method
|
||||
"""
|
||||
|
||||
if not hasattr(self.sock, "do_handshake"):
|
||||
if not hasattr(self.stream, "do_handshake"):
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
return self.sock.do_handshake()
|
||||
self.stream: SSLSocket # Silences pycharm warnings
|
||||
return self.stream.do_handshake()
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
|
||||
async def recvfrom(self, buffersize, flags=0):
|
||||
"""
|
||||
|
@ -212,11 +296,11 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.recvfrom(buffersize, flags)
|
||||
return self.stream.recvfrom(buffersize, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
|
||||
async def recvfrom_into(self, buffer, bytes=0, flags=0):
|
||||
"""
|
||||
|
@ -225,11 +309,11 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.recvfrom_into(buffer, bytes, flags)
|
||||
return self.stream.recvfrom_into(buffer, bytes, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
|
||||
async def sendto(self, bytes, flags_or_address, address=None):
|
||||
"""
|
||||
|
@ -243,11 +327,11 @@ class AsyncSocket:
|
|||
flags = 0
|
||||
while True:
|
||||
try:
|
||||
return self.sock.sendto(bytes, flags, address)
|
||||
return self.stream.sendto(bytes, flags, address)
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def getpeername(self):
|
||||
"""
|
||||
|
@ -256,11 +340,11 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.getpeername()
|
||||
return self.stream.getpeername()
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def getsockname(self):
|
||||
"""
|
||||
|
@ -269,11 +353,11 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.getpeername()
|
||||
return self.stream.getpeername()
|
||||
except WantWrite:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
|
||||
"""
|
||||
|
@ -282,9 +366,9 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.recvmsg(bufsize, ancbufsize, flags)
|
||||
return self.stream.recvmsg(bufsize, ancbufsize, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
|
||||
"""
|
||||
|
@ -293,9 +377,9 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.recvmsg_into(buffers, ancbufsize, flags)
|
||||
return self.stream.recvmsg_into(buffers, ancbufsize, flags)
|
||||
except WantRead:
|
||||
await want_read(self.sock)
|
||||
await want_read(self.stream)
|
||||
|
||||
async def sendmsg(self, buffers, ancdata=(), flags=0, address=None):
|
||||
"""
|
||||
|
@ -304,17 +388,13 @@ class AsyncSocket:
|
|||
|
||||
while True:
|
||||
try:
|
||||
return self.sock.sendmsg(buffers, ancdata, flags, address)
|
||||
return self.stream.sendmsg(buffers, ancdata, flags, address)
|
||||
except WantRead:
|
||||
await want_write(self.sock)
|
||||
await want_write(self.stream)
|
||||
|
||||
def __repr__(self):
|
||||
return f"AsyncSocket({self.sock})"
|
||||
return f"AsyncSocket({self.stream})"
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Socket destructor
|
||||
"""
|
||||
|
||||
if not self._fd != -1:
|
||||
warnings.warn(f"socket '{self}' was destroyed, but was not closed, leading to a potential resource leak")
|
||||
if self.needs_closing:
|
||||
super().__del__()
|
||||
|
|
|
@ -92,7 +92,7 @@ def create_pool():
|
|||
Creates an async pool
|
||||
"""
|
||||
|
||||
return TaskManager(get_event_loop().current_task)
|
||||
return TaskManager()
|
||||
|
||||
|
||||
def with_timeout(timeout: int or float):
|
||||
|
@ -101,7 +101,7 @@ def with_timeout(timeout: int or float):
|
|||
"""
|
||||
|
||||
assert timeout > 0, "The timeout must be greater than 0"
|
||||
mgr = TaskManager(get_event_loop().current_task, timeout, True)
|
||||
mgr = TaskManager(timeout, True)
|
||||
loop = get_event_loop()
|
||||
if loop.current_task is loop.entry_point:
|
||||
loop.current_pool = mgr
|
||||
|
@ -117,7 +117,7 @@ def skip_after(timeout: int or float):
|
|||
"""
|
||||
|
||||
assert timeout > 0, "The timeout must be greater than 0"
|
||||
mgr = TaskManager(get_event_loop().current_task, timeout)
|
||||
mgr = TaskManager(timeout)
|
||||
loop = get_event_loop()
|
||||
if loop.current_task is loop.entry_point:
|
||||
loop.current_pool = mgr
|
||||
|
|
|
@ -67,7 +67,7 @@ async def create_task(coro: Callable[[Any, Any], Coroutine[Any, Any, Any]], pool
|
|||
"\nWhat you wanna do, instead, is this: pool.create_task(your_func, arg1, arg2, ...)"
|
||||
)
|
||||
elif inspect.iscoroutinefunction(coro):
|
||||
return await create_trap("create_task", coro(*args, **kwargs), pool)
|
||||
return await create_trap("create_task", coro, pool, *args, **kwargs)
|
||||
else:
|
||||
raise TypeError("coro must be a coroutine function")
|
||||
|
||||
|
@ -178,6 +178,19 @@ async def want_write(stream):
|
|||
await create_trap("register_sock", stream, "write")
|
||||
|
||||
|
||||
async def notify_closing(stream):
|
||||
"""
|
||||
Notifies the event loop that a given
|
||||
stream needs to be closed. This makes
|
||||
all callers waiting on want_read or
|
||||
want_write crash with a ResourceClosed
|
||||
exception, but it doesn't actually close
|
||||
the socket object itself
|
||||
"""
|
||||
|
||||
await create_trap("notify_closing", stream)
|
||||
|
||||
|
||||
async def schedule_tasks(tasks: Iterable[Task]):
|
||||
"""
|
||||
Schedules a list of tasks for execution. Usuaully
|
||||
|
|
|
@ -1,44 +1,49 @@
|
|||
import sys
|
||||
from typing import Tuple
|
||||
import giambio
|
||||
import logging
|
||||
|
||||
from debugger import Debugger
|
||||
|
||||
async def sender(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
|
||||
|
||||
async def reader(q: giambio.Queue, prompt: str = ""):
|
||||
in_stream = giambio.io.AsyncStream(sys.stdin.fileno(), close_on_context_exit=False, mode="r")
|
||||
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
|
||||
while True:
|
||||
await sock.send_all(b"yo")
|
||||
await q.put((0, ""))
|
||||
await giambio.sleep(1)
|
||||
await out_stream.write(prompt)
|
||||
await q.put((0, await in_stream.read()))
|
||||
|
||||
|
||||
async def receiver(sock: giambio.socket.AsyncSocket, q: giambio.Queue):
|
||||
data = b""
|
||||
while True:
|
||||
while not data.endswith(b"\n"):
|
||||
data += await sock.receive(1024)
|
||||
temp = await sock.receive(1024)
|
||||
if not temp:
|
||||
raise EOFError("end of file")
|
||||
data += temp
|
||||
data, rest = data.split(b"\n", maxsplit=2)
|
||||
buffer = b"".join(rest)
|
||||
await q.put((1, data.decode()))
|
||||
data = buffer
|
||||
|
||||
|
||||
async def main(host: Tuple[str, int]):
|
||||
async def main(host: tuple[str, int]):
|
||||
"""
|
||||
Main client entry point
|
||||
"""
|
||||
|
||||
queue = giambio.Queue()
|
||||
out_stream = giambio.io.AsyncStream(sys.stdout.fileno(), close_on_context_exit=False, mode="w")
|
||||
async with giambio.create_pool() as pool:
|
||||
async with giambio.socket.socket() as sock:
|
||||
await sock.connect(host)
|
||||
await pool.spawn(sender, sock, queue)
|
||||
await out_stream.write("Connection successful\n")
|
||||
await pool.spawn(receiver, sock, queue)
|
||||
await pool.spawn(reader, queue, "> ")
|
||||
while True:
|
||||
op, data = await queue.get()
|
||||
if op == 0:
|
||||
print(f"Sent.")
|
||||
else:
|
||||
print(f"Received: {data}")
|
||||
if op == 1:
|
||||
await out_stream.write(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,7 +54,7 @@ if __name__ == "__main__":
|
|||
datefmt="%d/%m/%Y %p",
|
||||
)
|
||||
try:
|
||||
giambio.run(main, ("localhost", port))
|
||||
giambio.run(main, ("localhost", port), debugger=Debugger())
|
||||
except (Exception, KeyboardInterrupt) as error: # Exceptions propagate!
|
||||
if isinstance(error, KeyboardInterrupt):
|
||||
logging.info("Ctrl+C detected, exiting")
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from typing import List
|
||||
import giambio
|
||||
from giambio.socket import AsyncSocket
|
||||
import logging
|
||||
|
@ -6,7 +5,8 @@ import sys
|
|||
|
||||
# An asynchronous chatroom
|
||||
|
||||
clients: List[giambio.socket.AsyncSocket] = []
|
||||
clients: dict[AsyncSocket, list[str, str]] = {}
|
||||
names: set[str] = set()
|
||||
|
||||
|
||||
async def serve(bind_address: tuple):
|
||||
|
@ -26,39 +26,52 @@ async def serve(bind_address: tuple):
|
|||
while True:
|
||||
try:
|
||||
conn, address_tuple = await sock.accept()
|
||||
clients.append(conn)
|
||||
clients[conn] = ["", f"{address_tuple[0]}:{address_tuple[1]}"]
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} connected")
|
||||
await pool.spawn(handler, conn, address_tuple)
|
||||
await pool.spawn(handler, conn)
|
||||
except Exception as err:
|
||||
# Because exceptions just *work*
|
||||
logging.info(f"{address_tuple[0]}:{address_tuple[1]} has raised {type(err).__name__}: {err}")
|
||||
|
||||
|
||||
async def handler(sock: AsyncSocket, client_address: tuple):
|
||||
async def handler(sock: AsyncSocket):
|
||||
"""
|
||||
Handles a single client connection
|
||||
|
||||
:param sock: The AsyncSocket object connected to the client
|
||||
:param client_address: The client's address represented as a tuple
|
||||
(address, port) where address is a string and port is an integer
|
||||
:type client_address: tuple
|
||||
"""
|
||||
|
||||
address = f"{client_address[0]}:{client_address[1]}"
|
||||
address = clients[sock][1]
|
||||
name = ""
|
||||
async with sock: # Closes the socket automatically
|
||||
await sock.send_all(b"Welcome to the chatroom pal, start typing and press enter!\n")
|
||||
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
|
||||
while True:
|
||||
while not name.endswith("\n"):
|
||||
name = (await sock.receive(64)).decode()
|
||||
name = name[:-1]
|
||||
if name not in names:
|
||||
names.add(name)
|
||||
clients[sock][0] = name
|
||||
break
|
||||
else:
|
||||
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
||||
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
||||
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
||||
for i, client_sock in enumerate(clients):
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
|
||||
while True:
|
||||
await sock.send_all(b"> ")
|
||||
data = await sock.receive(1024)
|
||||
if not data:
|
||||
break
|
||||
elif data == b"exit\n":
|
||||
await sock.send_all(b"I'm dead dude\n")
|
||||
raise TypeError("Oh, no, I'm gonna die!")
|
||||
logging.info(f"Got: {data!r} from {address}")
|
||||
for i, client_sock in enumerate(clients):
|
||||
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
|
||||
if client_sock != sock:
|
||||
await client_sock.send_all(data)
|
||||
if client_sock != sock and clients[client_sock][0]:
|
||||
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
|
||||
if not data.endswith(b"\n"):
|
||||
data += b"\n"
|
||||
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
||||
logging.info(f"Sent {data!r} to {i} clients")
|
||||
logging.info(f"Connection from {address} closed")
|
||||
clients.remove(sock)
|
||||
|
|
|
@ -63,7 +63,7 @@ if __name__ == "__main__":
|
|||
logging.basicConfig(
|
||||
level=20,
|
||||
format="[%(levelname)s] %(asctime)s %(message)s",
|
||||
datefmt="%d/%m/%Y %p",
|
||||
datefmt="%d/%m/%Y %H:%M:%S %p",
|
||||
)
|
||||
try:
|
||||
giambio.run(serve, ("localhost", port), debugger=())
|
||||
|
|
Loading…
Reference in New Issue