From 268745c552efde8556afd4d5833950dd1abdea6c Mon Sep 17 00:00:00 2001 From: nocturn9x Date: Mon, 31 May 2021 22:56:03 +0200 Subject: [PATCH] Attempts to fix SSL --- giambio/core.py | 8 +++++--- giambio/socket.py | 13 +++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/giambio/core.py b/giambio/core.py index ba5ec6e..22afb8f 100644 --- a/giambio/core.py +++ b/giambio/core.py @@ -599,7 +599,7 @@ class AsyncScheduler: # Since we don't reschedule the task, it will # not execute until check_events is called - def register_sock(self, sock: socket.socket, evt_type: str): + def register_sock(self, sock, evt_type: str): """ Registers the given socket inside the selector to perform I/0 multiplexing @@ -612,6 +612,7 @@ class AsyncScheduler: :type evt_type: str """ + self.current_task.status = "io" evt = EVENT_READ if evt_type == "read" else EVENT_WRITE if self.current_task.last_io: @@ -635,8 +636,9 @@ class AsyncScheduler: # If the event to listen for has changed we just modify it self.selector.modify(sock, evt, self.current_task) self.current_task.last_io = (evt_type, sock) - else: - # Otherwise we register the new socket in our selector + elif not self.current_task.last_io or self.current_task.last_io[1] != sock: + # The task has either registered a new socket or is doing + # I/O for the first time. In both cases, we register a new socket self.current_task.last_io = evt_type, sock try: self.selector.register(sock, evt, self.current_task) diff --git a/giambio/socket.py b/giambio/socket.py index eefdb21..f1636fb 100644 --- a/giambio/socket.py +++ b/giambio/socket.py @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import ssl import socket as builtin_socket from giambio.run import get_event_loop from giambio.exceptions import ResourceClosed @@ -29,7 +30,7 @@ class AsyncSocket: Abstraction layer for asynchronous sockets """ - def __init__(self, sock: builtin_socket.socket): + def __init__(self, sock): self.sock = sock self.loop = get_event_loop() self._closed = False @@ -43,6 +44,9 @@ class AsyncSocket: if self._closed: raise ResourceClosed("I/O operation on closed socket") assert max_size >= 1, "max_size must be >= 1" + if isinstance(self.sock, ssl.SSLSocket) and self.sock.pending(): + print(self.sock.pending()) + return self.sock.recv(self.sock.pending()) await want_read(self.sock) try: return self.sock.recv(max_size) @@ -112,9 +116,8 @@ class AsyncSocket: await want_write(self.sock) try: self.sock.connect(addr) - except IOInterrupt: + except IOInterrupt as io_interrupt: await want_write(self.sock) - self.sock.connect(addr) async def bind(self, addr: tuple): """ @@ -132,7 +135,7 @@ class AsyncSocket: """ Starts listening with the given backlog - :param backlog: The address, port tuple to bind to + :param backlog: The socket's backlog :type backlog: int """ @@ -151,6 +154,8 @@ class AsyncSocket: if not self._closed and self.loop.selector.get_map() and self.sock in self.loop.selector.get_map(): self.loop.selector.unregister(self.sock) + self.loop.current_task.last_io = () + self._closed = True async def __aenter__(self): return self