2023-06-06 11:31:30 +02:00
|
|
|
import structio
|
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
|
|
|
|
# An asynchronous chatroom
|
|
|
|
|
|
|
|
# Maps each connected client's socket to a mutable [name, address] pair.
# The name is "" until the client has picked one; address is "host:port".
clients: dict[structio.socket.AsyncSocket, list[str]] = {}
# Names currently taken, so that no two clients can share a name
names: set[str] = set()
|
|
|
|
|
|
|
|
|
|
|
|
async def event_handler(q: structio.Queue):
    """
    Reads (event, payload) tuples submitted onto the queue and logs them

    Runs until cancelled; the cancellation is logged and re-raised. Any other
    exception is logged and ends the handler without being re-raised.

    :param q: The queue events are read from
    """

    try:
        logging.info("Event handler spawned")
        while True:
            msg, payload = await q.get()
            logging.info(f"Caught event {msg!r} with the following payload: {payload}")
    # Cancellation must be matched first: if Cancelled ever derives from
    # Exception, the generic clause below would otherwise swallow it
    except structio.exceptions.Cancelled:
        logging.warning("Cancellation detected, message handler shutting down")
        # Propagate the cancellation
        raise
    except Exception as e:
        logging.error(f"An exception occurred in the message handler -> {type(e).__name__}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
async def serve(bind_address: tuple):
    """
    Serves asynchronously forever (or until Ctrl+C ;))

    Binds a listening socket, spawns the event handler and then one handler
    task per accepted client, registering each client in the clients dict.

    :param bind_address: The address to bind the server to, represented as a tuple
        (address, port) where address is a string and port is an integer
    """

    sock = structio.socket.socket()
    queue = structio.Queue()
    await sock.bind(bind_address)
    await sock.listen(5)
    logging.info(f"Serving asynchronously at {bind_address[0]}:{bind_address[1]}")
    async with structio.create_pool() as pool:
        pool.spawn(event_handler, queue)
        async with sock:
            while True:
                try:
                    conn, address_tuple = await sock.accept()
                except Exception as err:
                    # No client address exists yet: address_tuple is unbound
                    # (or left over from a previous client) at this point
                    logging.error(f"Failed to accept a connection -> {type(err).__name__}: {err}")
                    continue
                address = f"{address_tuple[0]}:{address_tuple[1]}"
                try:
                    clients[conn] = ["", address]
                    await queue.put(("connect", clients[conn]))
                    logging.info(f"{address} connected")
                    pool.spawn(handler, conn, queue)
                except Exception as err:
                    # Because exceptions just *work*
                    logging.info(f"{address} has raised {type(err).__name__}: {err}")
                    # No handler was spawned for this connection, so nothing
                    # else would ever remove it from the client table
                    clients.pop(conn, None)
|
|
|
|
|
|
|
|
|
|
|
|
async def handler(sock: structio.socket.AsyncSocket, q: structio.Queue):
|
|
|
|
"""
|
|
|
|
Handles a single client connection
|
|
|
|
:param sock: The AsyncSocket object connected to the client
|
|
|
|
"""
|
|
|
|
|
|
|
|
address = clients[sock][1]
|
|
|
|
name = ""
|
|
|
|
async with sock: # Closes the socket automatically
|
|
|
|
await sock.send_all(b"Welcome to the chatroom pal, may you tell me your name?\n> ")
|
|
|
|
cond = True
|
|
|
|
while cond:
|
|
|
|
while not name.endswith("\n"):
|
|
|
|
name = (await sock.receive(64)).decode()
|
|
|
|
if name == "":
|
|
|
|
cond = False
|
|
|
|
break
|
|
|
|
name = name.rstrip("\n")
|
|
|
|
if name not in names:
|
|
|
|
names.add(name)
|
|
|
|
clients[sock][0] = name
|
|
|
|
break
|
|
|
|
else:
|
|
|
|
await sock.send_all(b"Sorry, but that name is already taken. Try again!\n> ")
|
|
|
|
await sock.send_all(f"Okay {name}, welcome to the chatroom!\n".encode())
|
|
|
|
await q.put(("join", (address, name)))
|
|
|
|
logging.info(f"{name} has joined the chatroom ({address}), informing clients")
|
|
|
|
for i, client_sock in enumerate(clients):
|
|
|
|
if client_sock != sock and clients[client_sock][0]:
|
|
|
|
await client_sock.send_all(f"{name} joins the chatroom!\n> ".encode())
|
|
|
|
while True:
|
|
|
|
await sock.send_all(b"> ")
|
|
|
|
data = await sock.receive(1024)
|
|
|
|
if not data:
|
|
|
|
break
|
|
|
|
decoded = data.decode().rstrip("\n")
|
|
|
|
if decoded.startswith("/"):
|
|
|
|
logging.info(f"{name} issued server command {decoded}")
|
|
|
|
await q.put(("cmd", (name, decoded[1:])))
|
|
|
|
match decoded[1:]:
|
|
|
|
case "bye":
|
|
|
|
await sock.send_all(b"Bye!\n")
|
|
|
|
break
|
|
|
|
case _:
|
|
|
|
await sock.send_all(b"Unknown command\n")
|
|
|
|
else:
|
|
|
|
await q.put(("msg", (name, data)))
|
|
|
|
logging.info(f"Got: {data!r} from {address}")
|
|
|
|
for i, client_sock in enumerate(clients):
|
|
|
|
if client_sock != sock and clients[client_sock][0]:
|
|
|
|
logging.info(f"Sending {data!r} to {':'.join(map(str, await client_sock.getpeername()))}")
|
|
|
|
if not data.endswith(b"\n"):
|
|
|
|
data += b"\n"
|
|
|
|
await client_sock.send_all(f"[{name}] ({address}): {data.decode()}> ".encode())
|
|
|
|
logging.info(f"Sent {data!r} to {i} clients")
|
|
|
|
await q.put(("leave", name))
|
|
|
|
logging.info(f"Connection from {address} closed")
|
|
|
|
logging.info(f"{name} has left the chatroom ({address}), informing clients")
|
|
|
|
for i, client_sock in enumerate(clients):
|
|
|
|
if client_sock != sock and clients[client_sock][0]:
|
|
|
|
await client_sock.send_all(f"{name} has left the chatroom\n> ".encode())
|
|
|
|
clients.pop(sock)
|
|
|
|
names.discard(name)
|
|
|
|
logging.info("Handler shutting down")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        port = int(sys.argv[1]) if len(sys.argv) > 1 else 1501
    except ValueError:
        # sys.exit with a string prints it to stderr and exits with status 1
        sys.exit(f"Invalid port number: {sys.argv[1]!r}")
    logging.basicConfig(
        level=logging.INFO,
        format="[%(levelname)s] %(asctime)s %(message)s",
        # Include the time of day: a date plus AM/PM alone stamps every
        # record from the same half-day identically
        datefmt="%d/%m/%Y %I:%M:%S %p",
    )
    try:
        structio.run(serve, ("0.0.0.0", port))
    except KeyboardInterrupt:
        logging.info("Ctrl+C detected, exiting")
    except Exception as error:  # Exceptions propagate!
        logging.error(f"Exiting due to a {type(error).__name__}: {error}")
        # Signal the failure to whoever launched the server
        sys.exit(1)
|
|
|
|
|