Attempts to fix SSL

This commit is contained in:
nocturn9x 2021-05-31 22:56:03 +02:00
parent 668404b049
commit 268745c552
2 changed files with 14 additions and 7 deletions

View File

@ -599,7 +599,7 @@ class AsyncScheduler:
# Since we don't reschedule the task, it will # Since we don't reschedule the task, it will
# not execute until check_events is called # not execute until check_events is called
def register_sock(self, sock: socket.socket, evt_type: str): def register_sock(self, sock, evt_type: str):
""" """
Registers the given socket inside the Registers the given socket inside the
selector to perform I/0 multiplexing selector to perform I/0 multiplexing
@ -612,6 +612,7 @@ class AsyncScheduler:
:type evt_type: str :type evt_type: str
""" """
self.current_task.status = "io" self.current_task.status = "io"
evt = EVENT_READ if evt_type == "read" else EVENT_WRITE evt = EVENT_READ if evt_type == "read" else EVENT_WRITE
if self.current_task.last_io: if self.current_task.last_io:
@ -635,8 +636,9 @@ class AsyncScheduler:
# If the event to listen for has changed we just modify it # If the event to listen for has changed we just modify it
self.selector.modify(sock, evt, self.current_task) self.selector.modify(sock, evt, self.current_task)
self.current_task.last_io = (evt_type, sock) self.current_task.last_io = (evt_type, sock)
else: elif not self.current_task.last_io or self.current_task.last_io[1] != sock:
# Otherwise we register the new socket in our selector # The task has either registered a new socket or is doing
# I/O for the first time. In both cases, we register a new socket
self.current_task.last_io = evt_type, sock self.current_task.last_io = evt_type, sock
try: try:
self.selector.register(sock, evt, self.current_task) self.selector.register(sock, evt, self.current_task)

View File

@ -15,6 +15,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import ssl
import socket as builtin_socket import socket as builtin_socket
from giambio.run import get_event_loop from giambio.run import get_event_loop
from giambio.exceptions import ResourceClosed from giambio.exceptions import ResourceClosed
@ -29,7 +30,7 @@ class AsyncSocket:
Abstraction layer for asynchronous sockets Abstraction layer for asynchronous sockets
""" """
def __init__(self, sock: builtin_socket.socket): def __init__(self, sock):
self.sock = sock self.sock = sock
self.loop = get_event_loop() self.loop = get_event_loop()
self._closed = False self._closed = False
@ -43,6 +44,9 @@ class AsyncSocket:
if self._closed: if self._closed:
raise ResourceClosed("I/O operation on closed socket") raise ResourceClosed("I/O operation on closed socket")
assert max_size >= 1, "max_size must be >= 1" assert max_size >= 1, "max_size must be >= 1"
if isinstance(self.sock, ssl.SSLSocket) and self.sock.pending():
print(self.sock.pending())
return self.sock.recv(self.sock.pending())
await want_read(self.sock) await want_read(self.sock)
try: try:
return self.sock.recv(max_size) return self.sock.recv(max_size)
@ -112,9 +116,8 @@ class AsyncSocket:
await want_write(self.sock) await want_write(self.sock)
try: try:
self.sock.connect(addr) self.sock.connect(addr)
except IOInterrupt: except IOInterrupt as io_interrupt:
await want_write(self.sock) await want_write(self.sock)
self.sock.connect(addr)
async def bind(self, addr: tuple): async def bind(self, addr: tuple):
""" """
@ -132,7 +135,7 @@ class AsyncSocket:
""" """
Starts listening with the given backlog Starts listening with the given backlog
:param backlog: The address, port tuple to bind to :param backlog: The socket's backlog
:type backlog: int :type backlog: int
""" """
@ -151,6 +154,8 @@ class AsyncSocket:
if not self._closed and self.loop.selector.get_map() and self.sock in self.loop.selector.get_map(): if not self._closed and self.loop.selector.get_map() and self.sock in self.loop.selector.get_map():
self.loop.selector.unregister(self.sock) self.loop.selector.unregister(self.sock)
self.loop.current_task.last_io = ()
self._closed = True
async def __aenter__(self): async def __aenter__(self):
return self return self