Made the API a little bit cleaner

This commit is contained in:
nocturn9x 2020-11-12 23:25:51 +01:00
parent 3c9421c84c
commit daf727d67d
3 changed files with 36 additions and 40 deletions

View File

@ -53,7 +53,7 @@ class AsyncScheduler:
)
self.paused = TimeQueue(self.clock) # Tasks that are asleep
self.events = {} # All Event objects
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects
self._event_waiting = defaultdict(list) # Coroutines waiting on event objects
self.sequence = 0
def run(self):
@ -67,16 +67,16 @@ class AsyncScheduler:
while True:
try:
if not self.selector.get_map() and not any(
[self.paused, self.tasks, self.event_waiting]
[self.paused, self.tasks, self._event_waiting]
): # If there is nothing to do, just exit
break
if not self.tasks:
if (
self.paused
): # If there are no actively running tasks, we try to schedule the asleep ones
self.check_sleeping()
self._check_sleeping()
if self.selector.get_map():
self.check_io() # The next step is checking for I/O
self._check_io() # The next step is checking for I/O
while self.tasks: # While there are tasks to run
self.current_task = (
self.tasks.popleft()
@ -88,10 +88,10 @@ class AsyncScheduler:
self.current_task._notify
) # Run a single step with the calculation (and awake event-waiting tasks if any)
self.current_task.status = "run"
getattr(self, method)(
getattr(self, f"_{method}")(
*args
) # Sneaky method call, thanks to David Beazley for this ;)
if self.event_waiting:
if self._event_waiting:
self.check_events()
except CancelledError as cancelled:
self.tasks.remove(cancelled.args[0]) # Remove the dead task
@ -103,22 +103,22 @@ class AsyncScheduler:
except BaseException as error: # Coroutine raised
self.current_task.exc = error
self.reschedule_parent(self.current_task)
self.join(self.current_task)
self._join(self.current_task)
def check_events(self):
"""
Checks for ready or expired events and triggers them
"""
for event, tasks in self.event_waiting.copy().items():
for event, tasks in self._event_waiting.copy().items():
if event._set:
event.event_caught = True
for task in tasks:
task._notify = event._notify
self.tasks.extend(tasks + [event.notifier])
self.event_waiting.pop(event)
self._event_waiting.pop(event)
def check_sleeping(self):
def _check_sleeping(self):
"""
Checks and reschedules sleeping tasks
"""
@ -133,7 +133,7 @@ class AsyncScheduler:
if not self.paused:
break
def check_io(self):
def _check_io(self):
"""
Checks and schedules task to perform I/O
"""
@ -192,7 +192,7 @@ class AsyncScheduler:
return parent
# TODO: More generic I/O rather than just sockets
def want_read(self, sock: socket.socket):
def _want_read(self, sock: socket.socket):
"""
Handler for the 'want_read' event, registers the socket inside the selector to perform I/0 multiplexing
"""
@ -203,16 +203,13 @@ class AsyncScheduler:
return # Socket is already scheduled!
else:
self.selector.unregister(sock)
busy = False
self.current_task._last_io = "READ", sock
try:
self.selector.register(sock, EVENT_READ, self.current_task)
except KeyError:
busy = True
if busy:
raise ResourceBusy("The given resource is busy!")
raise ResourceBusy("The given resource is busy!") from None
def want_write(self, sock: socket.socket):
def _want_write(self, sock: socket.socket):
"""
Handler for the 'want_write' event, registers the socket inside the selector to perform I/0 multiplexing
"""
@ -223,16 +220,13 @@ class AsyncScheduler:
return # Socket is already scheduled!
else:
self.selector.unregister(sock) # modify() causes issues
busy = False
self.current_task._last_io = "WRITE", sock
try:
self.selector.register(sock, EVENT_WRITE, self.current_task)
except KeyError:
busy = True
if busy:
raise ResourceBusy("The given resource is busy!")
raise ResourceBusy("The given resource is busy!") from None
def join(self, child: types.coroutine):
def _join(self, child: types.coroutine):
"""
Handler for the 'join' event, does some magic to tell the scheduler
to wait until the passed coroutine ends. The result of this call equals whatever the
@ -255,7 +249,7 @@ class AsyncScheduler:
"Joining the same task multiple times is not allowed!"
)
def sleep(self, seconds: int or float):
def _sleep(self, seconds: int or float):
"""
Puts the caller to sleep for a given amount of seconds
"""
@ -266,7 +260,7 @@ class AsyncScheduler:
else:
self.tasks.append(self.current_task)
def event_set(self, event, value):
def _event_set(self, event, value):
"""
Sets an event
"""
@ -276,7 +270,7 @@ class AsyncScheduler:
event._notify = value
self.events[event] = value
def event_wait(self, event):
def _event_wait(self, event):
"""
Waits for an event
"""
@ -288,9 +282,9 @@ class AsyncScheduler:
else:
return self.events[event]
else:
self.event_waiting[event].append(self.current_task)
self._event_waiting[event].append(self.current_task)
def cancel(self, task):
def _cancel(self, task):
"""
Handler for the 'cancel' event, throws CancelledError inside a coroutine
in order to stop it from executing. The loop continues to execute as tasks
@ -312,7 +306,7 @@ class AsyncScheduler:
return AsyncSocket(sock, self)
async def read_sock(self, sock: socket.socket, buffer: int):
async def _read_sock(self, sock: socket.socket, buffer: int):
"""
Reads from a socket asynchronously, waiting until the resource is available and returning up to buffer bytes
from the socket
@ -321,7 +315,7 @@ class AsyncScheduler:
await want_read(sock)
return sock.recv(buffer)
async def accept_sock(self, sock: socket.socket):
async def _accept_sock(self, sock: socket.socket):
"""
Accepts a socket connection asynchronously, waiting until the resource is available and returning the
result of the accept() call
@ -330,7 +324,7 @@ class AsyncScheduler:
await want_read(sock)
return sock.accept()
async def sock_sendall(self, sock: socket.socket, data: bytes):
async def _sock_sendall(self, sock: socket.socket, data: bytes):
"""
Sends all the passed data, as bytes, trough the socket asynchronously
"""
@ -340,7 +334,7 @@ class AsyncScheduler:
sent_no = sock.send(data)
data = data[sent_no:]
async def close_sock(self, sock: socket.socket):
async def _close_sock(self, sock: socket.socket):
"""
Closes the socket asynchronously
"""
@ -348,7 +342,7 @@ class AsyncScheduler:
await want_write(sock)
return sock.close()
async def connect_sock(self, sock: socket.socket, addr: tuple):
async def _connect_sock(self, sock: socket.socket, addr: tuple):
"""
Connects a socket asynchronously
"""

View File

@ -47,14 +47,14 @@ class AsyncSocket(object):
if self._closed:
raise ResourceClosed("I/O operation on closed socket")
self.loop.current_task.status = "I/O"
return await self.loop.read_sock(self.sock, max_size)
return await self.loop._read_sock(self.sock, max_size)
async def accept(self):
"""Accepts the socket, completing the 3-step TCP handshake asynchronously"""
if self._closed:
raise ResourceClosed("I/O operation on closed socket")
to_wrap = await self.loop.accept_sock(self.sock)
to_wrap = await self.loop._accept_sock(self.sock)
return self.loop.wrap_socket(to_wrap[0]), to_wrap[1]
async def send_all(self, data: bytes):
@ -62,7 +62,7 @@ class AsyncSocket(object):
if self._closed:
raise ResourceClosed("I/O operation on closed socket")
return await self.loop.sock_sendall(self.sock, data)
return await self.loop._sock_sendall(self.sock, data)
async def close(self):
"""Closes the socket asynchronously"""
@ -70,7 +70,7 @@ class AsyncSocket(object):
if self._closed:
raise ResourceClosed("I/O operation on closed socket")
await sleep(0) # Give the scheduler the time to unregister the socket first
await self.loop.close_sock(self.sock)
await self.loop._close_sock(self.sock)
self._closed = True
async def connect(self, addr: tuple):
@ -78,12 +78,14 @@ class AsyncSocket(object):
if self._closed:
raise ResourceClosed("I/O operation on closed socket")
await self.loop.connect_sock(self.sock, addr)
await self.loop._connect_sock(self.sock, addr)
def __enter__(self):
async def __aenter__(self):
await sleep(0)
return self.sock.__enter__()
def __exit__(self, *args):
async def __aexit__(self, *args):
await sleep(0)
return self.sock.__exit__(*args)
def __repr__(self):

View File

@ -24,7 +24,7 @@ async def server(address: tuple):
async def echo_handler(sock: AsyncSocket, addr: tuple):
with sock:
async with sock:
await sock.send_all(b"Welcome to the server pal!\n")
while True:
data = await sock.receive(1000)