diff --git a/.gitignore b/.gitignore index f8b73e7..6b77028 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,5 @@ dmypy.json # Cython debug symbols cython_debug/ +/config.py +/piccolo_conf.py diff --git a/.idea/socialMedia.iml b/.idea/socialMedia.iml new file mode 100644 index 0000000..74d515a --- /dev/null +++ b/.idea/socialMedia.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 3c953c7..655f72a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,38 @@ # PySimpleSocial -An advanced REST API written in Python for a generic social media website \ No newline at end of file +An advanced REST API written in Python for a generic social media website. + +__Note__: This is a WIP so far. + +## Tech Stack + +The project is written using [FastAPI](https://https://fastapi.tiangolo.com/), [piccolo](https://piccolo-orm.readthedocs.io/en/latest/) and [uvicorn](https://www.uvicorn.org/). Other awesome libraries used (only direct dependencies are listed here): +- [validators](https://validators.readthedocs.io/en/latest/) +- [pydantic](https://pydantic-docs.helpmanual.io/) +- [aiosmtplib](https://aiosmtplib.readthedocs.io/en/latest/usage.html) +- [uvloop](https://uvloop.readthedocs.io/) +- [slowapi](https://slowapi.readthedocs.io/en/latest/) +- [fastapi-login](https://fastapi-login.readthedocs.io/) +- [bcrypt](https://github.com/pyca/bcrypt/) + +# Feature overview + +__Note__: Not all of this is implemented yet + +- Simple authentication system using salted bcrypt hashes for password storage +- PostgreSQL is used as the main database +- Simple rate limiting using redis/in-memory storage +- Support for various kinds of media stored in a CDN, directly inside the database or on a local/remote filesystem +- Regular social media mechanics: (Un)following users, posting media with captions, etc. +- User settings (change username, profile picture, etc.) +- Simple messaging system using websockets or a polling HTTP API +- Admin functionality with basic metrics and administration features (flagging/deleting users/posts, handling tickets, etc.) + +# Setup + +Move the *.py.example files to their respective *.py files, fill them as necessary, then simply install the dependencies via pip and run main.py + +# License + +This software is licensed under the MIT license. For more information, read the [license file](LICENSE) + diff --git a/config.py.example b/config.py.example new file mode 100644 index 0000000..87fe7aa --- /dev/null +++ b/config.py.example @@ -0,0 +1,115 @@ +# Configuration file. Each variable can be overridden by a +# corresponding environment variable +import os +import sys +import logging +import pathlib + +from slowapi import Limiter +from slowapi.util import get_remote_address +from fastapi_login import LoginManager + + +# Authentication configuration +LOGIN_SECRET_KEY = ( + os.getenv("LOGIN_SECRET_KEY") or "login-secret" +) # Recommended value: os.urandom(24).hex() +USE_BEARER_HEADER = True +USE_COOKIE = True + +# Logging configuration +LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20 +LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable +LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s" +LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p" + +# Bcrypt configuration +BCRYPT_ROUNDS = ( + os.getenv("BCRYPT_ROUNDS") or 10 +) # How many rounds are used when salting + + +# Rate limit configuration +REDIS_URL = "" # Used for rate limits. Empty to fall back to memory +REDIS_OPTIONS = {} # Options for redis +RATELIMIT_ENABLED = False # False to disable rate limiting +RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html + +# Session configuration +SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds +SESSION_COOKIE_NAME = "_social_media_session" +COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict" +COOKIE_DOMAIN = "localhost" # Empty to disable this +COOKIE_PATH = "/" +COOKIE_HTTPONLY = True +SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS) + +# SMTP configuration +SMTP_HOST = "smtp.nocturn9x.space" +SMTP_USER = "info@example.com" +SMTP_PASSWORD = "password" +SMTP_PORT = 587 +SMTP_USE_TLS = True +SMTP_FROM_USER = "info@example.com" +SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email" +SMTP_TIMEOUT = 10 + +# Miscellaneous + +# Usernames containing these characters are not valid +INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character +# Empty this to disable username validation +VALIDATE_USERNAME_REGEX = ( + rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$" +) +# Criteria: +# - Between 10 and 72 characters long +# - At least 2 uppercase letters +# - At least 3 lowercase letters +# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^ +# - At least 2 numbers +# You can change the repetitions to enforce stricter/laxer rules or empty +# this field to disable weakness validation +VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$" + + +class NotAuthenticated(Exception): + pass + + +if __name__ != "__main__": + LOGGER: logging.Logger = logging.getLogger("socialMedia") + LOGGER.setLevel(LOG_LEVEL) + handler = logging.StreamHandler(sys.stderr) + formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT) + handler.setFormatter(formatter) + LOGGER.addHandler(handler) + handler.setLevel(LOG_LEVEL) + if LOG_FILE: + file_handler = logging.FileHandler(LOG_FILE, "a", "utf8") + file_handler.setFormatter(formatter) + LOGGER.addHandler(handler) + file_handler.setLevel(LOG_LEVEL) + logging.getLogger("uvicorn").addHandler(file_handler) + LIMITER = Limiter( + key_func=get_remote_address, + strategy=RATELIMIT_STRATEGY, + storage_uri=REDIS_URL or None, + in_memory_fallback_enabled=bool(REDIS_URL), + storage_options=REDIS_OPTIONS, + enabled=RATELIMIT_ENABLED, + ) + MANAGER = LoginManager( + LOGIN_SECRET_KEY, + "/login", + use_cookie=True, + cookie_name=SESSION_COOKIE_NAME, + use_header=USE_BEARER_HEADER, + custom_exception=NotAuthenticated, + ) + +# Uvicorn config + +HOST = os.getenv("HOST") or "localhost" +PORT = int(os.getenv("PORT") or 0) or 8000 +WORKERS = int(os.getenv("PORT") or 0) or 1 diff --git a/endpoints/__init__.py b/endpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/endpoints/users.py b/endpoints/users.py new file mode 100644 index 0000000..fee0a20 --- /dev/null +++ b/endpoints/users.py @@ -0,0 +1,334 @@ +import re + +import bcrypt +from uuid import UUID + +import validators +from fastapi import APIRouter as FastAPI, Depends, Response, Request +from fastapi.exceptions import HTTPException +from fastapi.security import OAuth2PasswordRequestForm +from datetime import timedelta + +from config import ( + BCRYPT_ROUNDS, + LOGGER, + SESSION_EXPIRE_LIMIT, + COOKIE_SAMESITE_POLICY, + COOKIE_DOMAIN, + MANAGER, + SESSION_COOKIE_NAME, + LIMITER, + VALIDATE_PASSWORD_REGEX, + VALIDATE_USERNAME_REGEX, + SECURE_COOKIE, + COOKIE_PATH, + COOKIE_HTTPONLY, +) +from orm.users import UserModel, User + +router = FastAPI() + + +async def get_user_by_id( + public_id: UUID, include_secrets: bool = False, restricted_ok: bool = False, + deleted_ok: bool = False +) -> dict | None: + """ + Retrieves a user by its public ID + """ + + user = ( + await User.select( + *User.all_columns(exclude=["public_id"]), + User.public_id.as_alias("id"), + exclude_secrets=not include_secrets, + ) + .where(User.public_id == public_id) + .first() + ) + if user: + # Performs validation + UserModel(**user) + if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok): + return + return user + return + + +@MANAGER.user_loader() +async def get_self_by_id(public_id: UUID) -> dict: + return await get_user_by_id(public_id, include_secrets=True, restricted_ok=True) + + +async def get_user_by_username( + username: str, include_secrets: bool = False, restricted_ok: bool = False, + deleted_ok: bool = False +) -> dict | None: + """ + Retrieves a user by its public username + """ + + user = ( + await User.select( + *User.all_columns(exclude=["public_id"]), + User.public_id.as_alias("id"), + exclude_secrets=not include_secrets, + ) + .where(User.username == username) + .first() + ) + if user: + # Performs validation + UserModel(**user) + if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok): + return + return user + return + + +async def get_user_by_email( + email: str, include_secrets: bool = False, restricted_ok: bool = False, + deleted_ok: bool = False +) -> dict | None: + """ + Retrieves a user by its email address (meant to + be used internally) + """ + + user = ( + await User.select( + *User.all_columns(exclude=["public_id"]), + User.public_id.as_alias("id"), + exclude_secrets=not include_secrets, + ) + .where(User.email_address == email) + .first() + ) + if user: + # Performs validation + UserModel(**user) + if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok): + return + return user + return + + +@LIMITER.limit("5/minute") +@router.post("/user") +async def login( + request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends() +) -> dict: + if request.cookies.get(SESSION_COOKIE_NAME): + raise HTTPException(status_code=400, detail="Please logout first") + username = data.username + if len(username) > 32: + raise HTTPException( + status_code=413, detail="Authentication failed: username is too long" + ) + try: + password = data.password.encode() + if len(password) > 72: + raise HTTPException( + status_code=413, detail="Authentication failed: password is too long" + ) + except UnicodeEncodeError as e: + LOGGER.warning( + f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}" + ) + raise HTTPException( + status_code=413, + detail="Authentication failed: invalid characters in password", + ) + if not ( + user := await get_user_by_username( + username, include_secrets=True, restricted_ok=True + ) + ): + raise HTTPException( + status_code=413, + detail="Authentication failed: the user does not exist", + ) + if not bcrypt.checkpw(password, user["password_hash"]): + raise HTTPException( + status_code=413, + detail="Authentication failed: password mismatch", + ) + token = MANAGER.create_access_token( + expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])} + ) + response.set_cookie( + secure=SECURE_COOKIE, + key=SESSION_COOKIE_NAME, + max_age=SESSION_EXPIRE_LIMIT, + value=token, + httponly=COOKIE_HTTPONLY, + samesite=COOKIE_SAMESITE_POLICY, + domain=COOKIE_DOMAIN or None, + path=COOKIE_PATH or "/", + ) + return {"status_code": 200, "msg": "Authentication successful"} + + +@router.get("/user/logout") +@LIMITER.limit("5/minute") +async def logout( + request: Request, response: Response, user: dict = Depends(MANAGER) +) -> dict: + """ + Deletes a user's session cookie, logging them + out + """ + + response.delete_cookie( + secure=SECURE_COOKIE, + key=SESSION_COOKIE_NAME, + httponly=COOKIE_HTTPONLY, + samesite=COOKIE_SAMESITE_POLICY, + domain=COOKIE_DOMAIN or None, + path=COOKIE_PATH or "/", + ) + return {"status_code": 200, "msg": "Logged out"} + + +@router.get("/user/me") +@LIMITER.limit("2/second") +async def get_self(request: Request, user: dict = Depends(MANAGER)) -> dict: + """ + Fetches a user's own info. This returns some + extra data such as email address, account + creation date and email verification status, + which is not available from the regular endpoint + """ + + user.pop("password_hash") + user.pop("internal_id") + user.pop("deleted") + return {"status_code": 200, "msg": "Success", "data": user} + + +@router.get("/user/username/{username}") +@LIMITER.limit("30/second") +async def get_user_by_name( + request: Request, username: str, _auth: dict = Depends(MANAGER) +) -> dict: + """ + Fetches a single user by its public ID + """ + + if not (user := await get_user_by_username(username)): + return { + "status_code": 404, + "msg": "Lookup failed: the user does not exist", + } + user.pop("restricted") + user.pop("deleted") + return {"status_code": 200, "msg": "Lookup successful", "data": user} + + +@router.get("/user/id/{public_id}") +@LIMITER.limit("30/second") +async def get_user_by_public_id( + request: Request, public_id: str, _auth: dict = Depends(MANAGER) +) -> dict: + """ + Fetches a single user by its public ID + """ + + if not (user := await get_user_by_id(UUID(public_id))): + raise HTTPException( + status_code=404, detail="Lookup failed: the user does not exist" + ) + user.pop("restricted") + user.pop("deleted") + return {"status_code": 200, "msg": "Lookup successful", "data": user} + + +async def validate_user( + first_name: str, last_name: str, username: str, email: str, password: str +) -> tuple[bool, str]: + """ + Performs some validation upon user creation. Returns + a tuple (success, msg) to be used by routes + """ + + if len(first_name) > 64: + return False, "first name is too long" + if len(first_name) < 5: + return False, "first name is too short" + if len(last_name) > 64: + return False, "last name is too long" + if len(last_name) < 2: + return False, "last name is too short" + if len(username) < 5: + return False, "username is too short" + if len(username) > 32: + return False, "username is too long" + if VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username): + return False, "username is invalid" + if not validators.email(email): + return False, "email is not valid" + if len(password) > 72: + return False, "password is too long" + if VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password): + return False, "password is too weak" + if await get_user_by_username(username, deleted_ok=True, restricted_ok=True): + return False, "username is already taken" + if await get_user_by_email(email, deleted_ok=True, restricted_ok=True): + return False, "email is already registered" + return True, "" + + +@router.delete("/user") +@LIMITER.limit("1/minute") +async def delete(request: Request, response: Response, user: dict = Depends(MANAGER)) -> dict: + """ + Sets the user's deleted flag in the database, + without actually deleting the associated + data + """ + + await User.update({User.deleted: True}).where(User.public_id == user["id"]) + response.delete_cookie( + secure=SECURE_COOKIE, + key=SESSION_COOKIE_NAME, + httponly=COOKIE_HTTPONLY, + samesite=COOKIE_SAMESITE_POLICY, + domain=COOKIE_DOMAIN or None, + path=COOKIE_PATH or "/",) + return {"status_code": 200, "msg": "Success"} + + +@router.put("/user") +@LIMITER.limit("2/minute") +async def signup( + request: Request, + first_name: str, + last_name: str, + username: str, + email: str, + password: str, +) -> dict: + """ + Endpoint used to create new users + """ + + if request.cookies.get(SESSION_COOKIE_NAME): + raise HTTPException(status_code=400, detail="Please logout first") + # We don't use FastAPI's validation because we want custom error + # messages + result, msg = await validate_user(first_name, last_name, username, email, password) + if not result: + return {"status_code": 413, "msg": f"Signup failed: {msg}"} + else: + await User.insert( + User( + first_name=first_name, + last_name=last_name, + username=username, + email_address=email, + password_hash=bcrypt.hashpw( + password.encode(), bcrypt.gensalt(BCRYPT_ROUNDS) + ), + ) + ) + return {"status_code": 200, "msg": "Success"} diff --git a/main.py b/main.py new file mode 100644 index 0000000..98d16b9 --- /dev/null +++ b/main.py @@ -0,0 +1,127 @@ +import uvloop +import asyncio +import uvicorn + +from fastapi import FastAPI, Request +from fastapi.exceptions import ( + HTTPException, + RequestValidationError, + StarletteHTTPException, +) +from slowapi.errors import RateLimitExceeded + + +from endpoints import users +from config import ( + LOGGER, + LIMITER, + NotAuthenticated, + SMTP_HOST, + SMTP_USER, + SMTP_PORT, + SMTP_USE_TLS, + SMTP_PASSWORD, + SMTP_TIMEOUT, + HOST, + PORT, + WORKERS, +) +from orm import create_tables, Media +from util.exception_handlers import ( + http_exception, + rate_limited, + request_invalid, + not_authenticated, +) +from util.email import test_smtp + +app = FastAPI() + + +@app.get("/") +@LIMITER.limit("10/second") +async def root(request: Request): + return {"status_code": 403, "msg": "Unauthorized"} + + +@app.get("/ping") +@LIMITER.limit("1/minute") +async def ping(request: Request) -> dict: + """ + This handler simply replies to "ping" requests and + is used to check whether the API is up and running. + It also performs a sanity check with the database and + the SMTP server to ensure that they are functioning correctly. + For this reason, this endpoint's rate limit is much stricter + """ + + LOGGER.info(f"Processing ping request from {request.client.host}") + try: + await Media.raw("SELECT 1;") + await test_smtp( + SMTP_HOST, + SMTP_PORT, + SMTP_USER, + SMTP_PASSWORD, + SMTP_TIMEOUT, + SMTP_USE_TLS, + True, + ) + return {"status_code": 200, "msg": "OK"} + except Exception: + raise HTTPException(500) + + +async def startup_checks(): + LOGGER.info("Initializing database") + try: + await create_tables() + await Media.raw("SELECT 1;") + except Exception as e: + LOGGER.error( + f"An error occurred while trying to initialize the database -> {type(e.__name__)}: {e}" + ) + else: + LOGGER.info("Database initialized") + LOGGER.info("Testing SMTP connection") + try: + await test_smtp( + SMTP_HOST, + SMTP_PORT, + SMTP_USER, + SMTP_PASSWORD, + SMTP_TIMEOUT, + SMTP_USE_TLS, + True, + ) + except Exception as e: + LOGGER.error( + f"An error occurred while trying to connect to the SMTP server -> {type(e.__name__)}: {e}" + ) + else: + LOGGER.info("SMTP test was successful") + + +if __name__ == "__main__": + LOGGER.info("Backend starting up!") + LOGGER.debug("Including modules") + app.include_router(users.router) + app.state.limiter = LIMITER + LOGGER.debug("Setting exception handlers") + app.add_exception_handler(RateLimitExceeded, rate_limited) + app.add_exception_handler(NotAuthenticated, not_authenticated) + app.add_exception_handler(HTTPException, http_exception) + app.add_exception_handler(StarletteHTTPException, http_exception) + app.add_exception_handler(RequestValidationError, request_invalid) + LOGGER.debug("Installing uvloop") + uvloop.install() + log_config = uvicorn.config.LOGGING_CONFIG + log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt + log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[ + 0 + ].formatter.datefmt + log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt + log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt + log_config["handlers"]["access"]["stream"] = "ext://sys.stderr" + asyncio.run(startup_checks()) + uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS) diff --git a/orm/__init__.py b/orm/__init__.py new file mode 100644 index 0000000..ab5efaf --- /dev/null +++ b/orm/__init__.py @@ -0,0 +1,14 @@ +import asyncio +from piccolo.table import create_db_tables +from .users import User +from .media import Media + + +async def create_tables(): + """ + Initializes the database + """ + + await create_db_tables(User, Media, if_not_exists=True) + await User.create_index([User.public_id], if_not_exists=True) + await User.create_index([User.username], if_not_exists=True) diff --git a/orm/media.py b/orm/media.py new file mode 100644 index 0000000..fd3a634 --- /dev/null +++ b/orm/media.py @@ -0,0 +1,19 @@ +""" +Media relation +""" + +from piccolo.table import Table +from piccolo.columns import UUID, Text, Boolean, Date +from piccolo.columns.defaults.date import DateNow + + +class Media(Table): + """ + A piece of media on a CDN + """ + + media_id = UUID(primary_key=True) + media_url = Text(null=False) + flagged = Boolean(default=False, null=False) + deleted = Boolean(default=False, null=False) + creation_date = Date(default=DateNow(), null=False) diff --git a/orm/users.py b/orm/users.py new file mode 100644 index 0000000..42deb61 --- /dev/null +++ b/orm/users.py @@ -0,0 +1,47 @@ +""" +User relation +""" + +from piccolo.utils.pydantic import create_pydantic_model +from piccolo.table import Table +from piccolo.columns import ( + ForeignKey, + Varchar, + BigSerial, + UUID, + Date, + OnDelete, + OnUpdate, + Boolean, + Email, + Bytea, +) +from piccolo.columns.defaults.date import DateNow + +from .media import Media + + +class User(Table, tablename="users"): + internal_id = BigSerial(null=False, secret=True) + public_id = UUID(primary_key=True) + first_name = Varchar(length=64, null=False) + last_name = Varchar(length=64, null=True) + email_address = Email(secret=True) + username = Varchar(length=32, null=False, unique=True) + password_hash = Bytea(null=False, secret=True) + profile_picture = ForeignKey( + references=Media, + on_delete=OnDelete.set_null, + on_update=OnUpdate.cascade, + null=True, + default=None, + ) + creation_date = Date(secret=True, default=DateNow()) + bio = Varchar(length=4096, null=True, default=None) + restricted = Boolean(default=False, null=False) + email_verified = Boolean(default=False, null=False, secret=True) + verified_account = Boolean(default=False, null=False) + deleted = Boolean(default=False, null=False) + + +UserModel = create_pydantic_model(User) diff --git a/piccolo_conf.py.example b/piccolo_conf.py.example new file mode 100644 index 0000000..0752e31 --- /dev/null +++ b/piccolo_conf.py.example @@ -0,0 +1,16 @@ +from piccolo.conf.apps import AppRegistry +from piccolo.engine.postgres import PostgresEngine + +DB = PostgresEngine( + config={ + "host": "example.com", + "user": "database", + "password": "password", + "database": "database", + } +) + + +# A list of paths to piccolo apps +# e.g. ['blog.piccolo_app'] +APP_REGISTRY = AppRegistry(apps=[]) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3b35fd3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +piccolo[all]~=0.91.0 +fastapi-login +bcrypt~=4.0.0 +slowapi~=0.1.6 +python-multipart +uvloop~=0.17.0 +fastapi~=0.85.0 +validators +aiosmtplib +uvicorn \ No newline at end of file diff --git a/templates/email/en_US.json b/templates/email/en_US.json new file mode 100644 index 0000000..64966a3 --- /dev/null +++ b/templates/email/en_US.json @@ -0,0 +1,3 @@ +[{ + +}] \ No newline at end of file diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/email.py b/util/email.py new file mode 100644 index 0000000..fa3961c --- /dev/null +++ b/util/email.py @@ -0,0 +1,87 @@ +import asyncio +import logging +import ssl + +import aiosmtplib +from email.mime.base import MIMEBase +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from fastapi.responses import RedirectResponse +from typing import Union, List + + +async def send_email( + host: str, + port: int, + message: str, + timeout: int, + sender: str, + recipient: str, + subject: str, + login_email: str, + password: str, + attachments: List[MIMEBase] = tuple(), + use_tls: bool = True, + check_hostname: bool = True, +) -> Union[bool, aiosmtplib.SMTPException]: + """ + Sends an email with the given details. Returns True on success + or an exception object upon failure + """ + + try: + async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv: + msg = MIMEMultipart() + msg["From"] = sender + msg["To"] = recipient + msg["Subject"] = subject + msg.attach(MIMEText(message, "html")) + for attachment in attachments: + msg.attach(attachment) + await srv.ehlo() # We identify ourselves + if use_tls: + context = ssl.create_default_context() + context.check_hostname = check_hostname + await srv.starttls(tls_context=context) + await srv.ehlo() # We do it again, but encrypted! + await srv.login(login_email, password) + await srv.sendmail(sender, recipient, msg.as_string()) + except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error: + logging.error( + f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}" + ) + return error + return True + + +async def test_smtp( + host: str, + port: int, + login_email: str, + password: str, + timeout: int, + use_tls: bool = True, + check_hostname: bool = True, +): + """ + Attempts login to the given SMTP server with the given credentials. + Used upon startup, raises an exception upon failure. This will + fail if the server does not support TLS encryption for login + """ + + async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv: + await srv.ehlo() + if use_tls: + context = ssl.create_default_context() + context.check_hostname = check_hostname + await srv.starttls(tls_context=context) + await srv.ehlo() + await srv.login(login_email, password) + + +def redirect(url: str) -> RedirectResponse: + """ + Returns a redirect response to the given url + """ + + return RedirectResponse(url=url) diff --git a/util/exception_handlers.py b/util/exception_handlers.py new file mode 100644 index 0000000..060d673 --- /dev/null +++ b/util/exception_handlers.py @@ -0,0 +1,79 @@ +from config import LOGGER, NotAuthenticated +from fastapi import Request +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import HTTPException, StarletteHTTPException +from slowapi.errors import RateLimitExceeded + + +async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse: + n = 0 + while True: + if error.detail[n].isnumeric(): + n += 1 + else: + break + error.detail = error.detail[:n] + " requests" + error.detail[n:] + LOGGER.info( + f"{request.client.host} got rate-limited at {str(request.url)} " + f"(exceeded {error.detail})" + ) + return JSONResponse( + status_code=200, + content={ + "msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}", + "status_code": 429, + }, + ) + + +def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse: + LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}") + return JSONResponse( + status_code=200, + content={ + "msg": "Authentication is required", + "status_code": 401, + }, + ) + + +def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse: + LOGGER.info( + f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}" + ) + return JSONResponse( + status_code=200, + content={ + "msg": f"Bad request: {type(exc).__name__}: {exc}", + "status_code": 400, + }, + ) + + +def http_exception( + request: Request, exc: HTTPException | StarletteHTTPException +) -> JSONResponse: + if exc.status_code >= 500: + LOGGER.error( + f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" + f"{type(exc).__name__}: {exc}" + ) + return JSONResponse( + status_code=200, + content={ + "msg": "Internal server error", + "status_code": exc.status_code, + }, + ) + else: + LOGGER.info( + f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}" + ) + return JSONResponse( + status_code=200, + content={ + "msg": exc.detail, + "status_code": exc.status_code, + }, + )