Made the API a little bit cleaner

This commit is contained in:
nocturn9x 2020-11-12 23:25:51 +01:00
parent 3c9421c84c
commit daf727d67d
3 changed files with 36 additions and 40 deletions

View File

@ -53,7 +53,7 @@ class AsyncScheduler:
) )
self.paused = TimeQueue(self.clock) # Tasks that are asleep self.paused = TimeQueue(self.clock) # Tasks that are asleep
self.events = {} # All Event objects self.events = {} # All Event objects
self.event_waiting = defaultdict(list) # Coroutines waiting on event objects self._event_waiting = defaultdict(list) # Coroutines waiting on event objects
self.sequence = 0 self.sequence = 0
def run(self): def run(self):
@ -67,16 +67,16 @@ class AsyncScheduler:
while True: while True:
try: try:
if not self.selector.get_map() and not any( if not self.selector.get_map() and not any(
[self.paused, self.tasks, self.event_waiting] [self.paused, self.tasks, self._event_waiting]
): # If there is nothing to do, just exit ): # If there is nothing to do, just exit
break break
if not self.tasks: if not self.tasks:
if ( if (
self.paused self.paused
): # If there are no actively running tasks, we try to schedule the asleep ones ): # If there are no actively running tasks, we try to schedule the asleep ones
self.check_sleeping() self._check_sleeping()
if self.selector.get_map(): if self.selector.get_map():
self.check_io() # The next step is checking for I/O self._check_io() # The next step is checking for I/O
while self.tasks: # While there are tasks to run while self.tasks: # While there are tasks to run
self.current_task = ( self.current_task = (
self.tasks.popleft() self.tasks.popleft()
@ -88,10 +88,10 @@ class AsyncScheduler:
self.current_task._notify self.current_task._notify
) # Run a single step with the calculation (and awake event-waiting tasks if any) ) # Run a single step with the calculation (and awake event-waiting tasks if any)
self.current_task.status = "run" self.current_task.status = "run"
getattr(self, method)( getattr(self, f"_{method}")(
*args *args
) # Sneaky method call, thanks to David Beazley for this ;) ) # Sneaky method call, thanks to David Beazley for this ;)
if self.event_waiting: if self._event_waiting:
self.check_events() self.check_events()
except CancelledError as cancelled: except CancelledError as cancelled:
self.tasks.remove(cancelled.args[0]) # Remove the dead task self.tasks.remove(cancelled.args[0]) # Remove the dead task
@ -103,22 +103,22 @@ class AsyncScheduler:
except BaseException as error: # Coroutine raised except BaseException as error: # Coroutine raised
self.current_task.exc = error self.current_task.exc = error
self.reschedule_parent(self.current_task) self.reschedule_parent(self.current_task)
self.join(self.current_task) self._join(self.current_task)
def check_events(self): def check_events(self):
""" """
Checks for ready or expired events and triggers them Checks for ready or expired events and triggers them
""" """
for event, tasks in self.event_waiting.copy().items(): for event, tasks in self._event_waiting.copy().items():
if event._set: if event._set:
event.event_caught = True event.event_caught = True
for task in tasks: for task in tasks:
task._notify = event._notify task._notify = event._notify
self.tasks.extend(tasks + [event.notifier]) self.tasks.extend(tasks + [event.notifier])
self.event_waiting.pop(event) self._event_waiting.pop(event)
def check_sleeping(self): def _check_sleeping(self):
""" """
Checks and reschedules sleeping tasks Checks and reschedules sleeping tasks
""" """
@ -133,7 +133,7 @@ class AsyncScheduler:
if not self.paused: if not self.paused:
break break
def check_io(self): def _check_io(self):
""" """
Checks and schedules task to perform I/O Checks and schedules task to perform I/O
""" """
@ -192,7 +192,7 @@ class AsyncScheduler:
return parent return parent
# TODO: More generic I/O rather than just sockets # TODO: More generic I/O rather than just sockets
def want_read(self, sock: socket.socket): def _want_read(self, sock: socket.socket):
""" """
Handler for the 'want_read' event, registers the socket inside the selector to perform I/0 multiplexing Handler for the 'want_read' event, registers the socket inside the selector to perform I/0 multiplexing
""" """
@ -203,16 +203,13 @@ class AsyncScheduler:
return # Socket is already scheduled! return # Socket is already scheduled!
else: else:
self.selector.unregister(sock) self.selector.unregister(sock)
busy = False
self.current_task._last_io = "READ", sock self.current_task._last_io = "READ", sock
try: try:
self.selector.register(sock, EVENT_READ, self.current_task) self.selector.register(sock, EVENT_READ, self.current_task)
except KeyError: except KeyError:
busy = True raise ResourceBusy("The given resource is busy!") from None
if busy:
raise ResourceBusy("The given resource is busy!")
def want_write(self, sock: socket.socket): def _want_write(self, sock: socket.socket):
""" """
Handler for the 'want_write' event, registers the socket inside the selector to perform I/0 multiplexing Handler for the 'want_write' event, registers the socket inside the selector to perform I/0 multiplexing
""" """
@ -223,16 +220,13 @@ class AsyncScheduler:
return # Socket is already scheduled! return # Socket is already scheduled!
else: else:
self.selector.unregister(sock) # modify() causes issues self.selector.unregister(sock) # modify() causes issues
busy = False
self.current_task._last_io = "WRITE", sock self.current_task._last_io = "WRITE", sock
try: try:
self.selector.register(sock, EVENT_WRITE, self.current_task) self.selector.register(sock, EVENT_WRITE, self.current_task)
except KeyError: except KeyError:
busy = True raise ResourceBusy("The given resource is busy!") from None
if busy:
raise ResourceBusy("The given resource is busy!")
def join(self, child: types.coroutine): def _join(self, child: types.coroutine):
""" """
Handler for the 'join' event, does some magic to tell the scheduler Handler for the 'join' event, does some magic to tell the scheduler
to wait until the passed coroutine ends. The result of this call equals whatever the to wait until the passed coroutine ends. The result of this call equals whatever the
@ -255,7 +249,7 @@ class AsyncScheduler:
"Joining the same task multiple times is not allowed!" "Joining the same task multiple times is not allowed!"
) )
def sleep(self, seconds: int or float): def _sleep(self, seconds: int or float):
""" """
Puts the caller to sleep for a given amount of seconds Puts the caller to sleep for a given amount of seconds
""" """
@ -266,7 +260,7 @@ class AsyncScheduler:
else: else:
self.tasks.append(self.current_task) self.tasks.append(self.current_task)
def event_set(self, event, value): def _event_set(self, event, value):
""" """
Sets an event Sets an event
""" """
@ -276,7 +270,7 @@ class AsyncScheduler:
event._notify = value event._notify = value
self.events[event] = value self.events[event] = value
def event_wait(self, event): def _event_wait(self, event):
""" """
Waits for an event Waits for an event
""" """
@ -288,9 +282,9 @@ class AsyncScheduler:
else: else:
return self.events[event] return self.events[event]
else: else:
self.event_waiting[event].append(self.current_task) self._event_waiting[event].append(self.current_task)
def cancel(self, task): def _cancel(self, task):
""" """
Handler for the 'cancel' event, throws CancelledError inside a coroutine Handler for the 'cancel' event, throws CancelledError inside a coroutine
in order to stop it from executing. The loop continues to execute as tasks in order to stop it from executing. The loop continues to execute as tasks
@ -312,7 +306,7 @@ class AsyncScheduler:
return AsyncSocket(sock, self) return AsyncSocket(sock, self)
async def read_sock(self, sock: socket.socket, buffer: int): async def _read_sock(self, sock: socket.socket, buffer: int):
""" """
Reads from a socket asynchronously, waiting until the resource is available and returning up to buffer bytes Reads from a socket asynchronously, waiting until the resource is available and returning up to buffer bytes
from the socket from the socket
@ -321,7 +315,7 @@ class AsyncScheduler:
await want_read(sock) await want_read(sock)
return sock.recv(buffer) return sock.recv(buffer)
async def accept_sock(self, sock: socket.socket): async def _accept_sock(self, sock: socket.socket):
""" """
Accepts a socket connection asynchronously, waiting until the resource is available and returning the Accepts a socket connection asynchronously, waiting until the resource is available and returning the
result of the accept() call result of the accept() call
@ -330,7 +324,7 @@ class AsyncScheduler:
await want_read(sock) await want_read(sock)
return sock.accept() return sock.accept()
async def sock_sendall(self, sock: socket.socket, data: bytes): async def _sock_sendall(self, sock: socket.socket, data: bytes):
""" """
Sends all the passed data, as bytes, trough the socket asynchronously Sends all the passed data, as bytes, trough the socket asynchronously
""" """
@ -340,7 +334,7 @@ class AsyncScheduler:
sent_no = sock.send(data) sent_no = sock.send(data)
data = data[sent_no:] data = data[sent_no:]
async def close_sock(self, sock: socket.socket): async def _close_sock(self, sock: socket.socket):
""" """
Closes the socket asynchronously Closes the socket asynchronously
""" """
@ -348,7 +342,7 @@ class AsyncScheduler:
await want_write(sock) await want_write(sock)
return sock.close() return sock.close()
async def connect_sock(self, sock: socket.socket, addr: tuple): async def _connect_sock(self, sock: socket.socket, addr: tuple):
""" """
Connects a socket asynchronously Connects a socket asynchronously
""" """

View File

@ -47,14 +47,14 @@ class AsyncSocket(object):
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
self.loop.current_task.status = "I/O" self.loop.current_task.status = "I/O"
return await self.loop.read_sock(self.sock, max_size) return await self.loop._read_sock(self.sock, max_size)
async def accept(self): async def accept(self):
"""Accepts the socket, completing the 3-step TCP handshake asynchronously""" """Accepts the socket, completing the 3-step TCP handshake asynchronously"""
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
to_wrap = await self.loop.accept_sock(self.sock) to_wrap = await self.loop._accept_sock(self.sock)
return self.loop.wrap_socket(to_wrap[0]), to_wrap[1] return self.loop.wrap_socket(to_wrap[0]), to_wrap[1]
async def send_all(self, data: bytes): async def send_all(self, data: bytes):
@ -62,7 +62,7 @@ class AsyncSocket(object):
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
return await self.loop.sock_sendall(self.sock, data) return await self.loop._sock_sendall(self.sock, data)
async def close(self): async def close(self):
"""Closes the socket asynchronously""" """Closes the socket asynchronously"""
@ -70,7 +70,7 @@ class AsyncSocket(object):
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
await sleep(0) # Give the scheduler the time to unregister the socket first await sleep(0) # Give the scheduler the time to unregister the socket first
await self.loop.close_sock(self.sock) await self.loop._close_sock(self.sock)
self._closed = True self._closed = True
async def connect(self, addr: tuple): async def connect(self, addr: tuple):
@ -78,12 +78,14 @@ class AsyncSocket(object):
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
await self.loop.connect_sock(self.sock, addr) await self.loop._connect_sock(self.sock, addr)
def __enter__(self): async def __aenter__(self):
await sleep(0)
return self.sock.__enter__() return self.sock.__enter__()
def __exit__(self, *args): async def __aexit__(self, *args):
await sleep(0)
return self.sock.__exit__(*args) return self.sock.__exit__(*args)
def __repr__(self): def __repr__(self):

View File

@ -24,7 +24,7 @@ async def server(address: tuple):
async def echo_handler(sock: AsyncSocket, addr: tuple): async def echo_handler(sock: AsyncSocket, addr: tuple):
with sock: async with sock:
await sock.send_all(b"Welcome to the server pal!\n") await sock.send_all(b"Welcome to the server pal!\n")
while True: while True:
data = await sock.receive(1000) data = await sock.receive(1000)