diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..775baae --- /dev/null +++ b/src/config.py @@ -0,0 +1,143 @@ +# Configuration file +import os +import sys +import copy +import logging +import pathlib + +from slowapi import Limiter +from slowapi.util import get_remote_address +from fastapi_login import LoginManager + + +# Authentication configuration +LOGIN_SECRET_KEY = ( + os.getenv("LOGIN_SECRET_KEY") or "login-secret" +) # Recommended value: os.urandom(24).hex() +USE_BEARER_HEADER = True +USE_COOKIE = True + +# Logging configuration +LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20 +LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable +LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s" +LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p" + +# Bcrypt configuration +BCRYPT_ROUNDS = ( + os.getenv("BCRYPT_ROUNDS") or 10 +) # How many rounds are used when salting + + +# Rate limit configuration +REDIS_URL = "" # Used for rate limits. Empty to fall back to memory +REDIS_OPTIONS = {} # Options for redis +RATELIMIT_ENABLED = False # False to disable rate limiting +RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html + +# Session configuration +SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds +SESSION_COOKIE_NAME = "_social_media_session" +COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict" +COOKIE_DOMAIN = "localhost" # Empty to disable this +COOKIE_PATH = "/" +COOKIE_HTTPONLY = True +SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS) + +# SMTP configuration +SMTP_HOST = "mail.nocturn9x.space" +SMTP_USER = "nocturn9x@nocturn9x.space" +SMTP_PASSWORD = "Je2aXGVH33" +SMTP_PORT = 587 +SMTP_USE_TLS = True +SMTP_FROM_USER = "nocturn9x@nocturn9x.space" +SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email" +SMTP_TIMEOUT = 10 + +# Miscellaneous + +# Usernames containing these characters are not valid +INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character +# Empty this to disable username validation +VALIDATE_USERNAME_REGEX = ( + rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$" +) +# Criteria: +# - Between 10 and 72 characters long +# - At least 2 uppercase letters +# - At least 3 lowercase letters +# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^ +# - At least 2 numbers +# You can change the repetitions to enforce stricter/laxer rules or empty +# this field to disable weakness validation +VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[\"!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$" +FORCE_EMAIL_VERIFICATION = False +EMAIL_VERIFICATION_EXPIRATION = 3600 # In seconds +PLATFORM_NAME = "PySimpleSocial" # Used in emails +HAS_HTTPS = False + + +class NotAuthenticated(Exception): + pass + + +class AdminNotAuthenticated(Exception): + pass + + +if __name__ != "__main__": + LOGGER: logging.Logger = logging.getLogger("socialMedia") + LOGGER.setLevel(LOG_LEVEL) + handler = logging.StreamHandler(sys.stderr) + formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT) + handler.setFormatter(formatter) + LOGGER.addHandler(handler) + handler.setLevel(LOG_LEVEL) + if LOG_FILE: + file_handler = logging.FileHandler(LOG_FILE, "a", "utf8") + file_handler.setFormatter(formatter) + LOGGER.addHandler(handler) + file_handler.setLevel(LOG_LEVEL) + logging.getLogger("uvicorn").addHandler(file_handler) + LIMITER = Limiter( + key_func=get_remote_address, + strategy=RATELIMIT_STRATEGY, + storage_uri=REDIS_URL or None, + in_memory_fallback_enabled=bool(REDIS_URL), + storage_options=REDIS_OPTIONS, + enabled=RATELIMIT_ENABLED, + ) + MANAGER = LoginManager( + LOGIN_SECRET_KEY, + "/user", + use_cookie=True, + cookie_name=SESSION_COOKIE_NAME, + use_header=USE_BEARER_HEADER, + custom_exception=NotAuthenticated, + ) + UNVERIFIED_MANAGER = copy.deepcopy(MANAGER) + ADMIN_MANAGER = LoginManager( + LOGIN_SECRET_KEY, + "/admin", + use_cookie=True, + cookie_name=f"{SESSION_COOKIE_NAME}_admin", + use_header=USE_BEARER_HEADER, + custom_exception=AdminNotAuthenticated, + ) + +# Uvicorn config + +HOST = os.getenv("HOST") or "localhost" +PORT = int(os.getenv("PORT") or 0) or 8000 +WORKERS = int(os.getenv("WORKERS") or 0) or 1 + +# Storage configuration + +STORAGE_ENGINE = "database" # Stores media inside the database. Other options are "local" to use a local/remote folder +# or "url" for uploading to a CDN +STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local +MAX_MEDIA_SIZE = ( + 5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception +) +ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"] +ZLIB_COMPRESSION_LEVEL = 9 diff --git a/config.py.example b/src/config.py.example similarity index 93% rename from config.py.example rename to src/config.py.example index 09bf90a..ea27cc0 100644 --- a/config.py.example +++ b/src/config.py.example @@ -1,5 +1,4 @@ -config.py.example# Configuration file. Each variable can be overridden by a -# corresponding environment variable +# Configuration file import os import sys import copy @@ -46,13 +45,13 @@ COOKIE_HTTPONLY = True SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS) # SMTP configuration -SMTP_HOST = "smtp.nocturn9x.space" +SMTP_HOST = "smtp.example.com" SMTP_USER = "info@example.com" SMTP_PASSWORD = "password" SMTP_PORT = 587 SMTP_USE_TLS = True SMTP_FROM_USER = "info@example.com" -SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email" +SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email" SMTP_TIMEOUT = 10 # Miscellaneous @@ -116,7 +115,7 @@ if __name__ != "__main__": use_header=USE_BEARER_HEADER, custom_exception=NotAuthenticated, ) - UNVERIFIED_MANAGER = copy.deepcopy(MANAGE) + UNVERIFIED_MANAGER = copy.deepcopy(MANAGER) ADMIN_MANAGER = LoginManager( LOGIN_SECRET_KEY, "/admin", @@ -130,7 +129,7 @@ if __name__ != "__main__": HOST = os.getenv("HOST") or "localhost" PORT = int(os.getenv("PORT") or 0) or 8000 -WORKERS = int(os.getenv("PORT") or 0) or 1 +WORKERS = int(os.getenv("WORKERS") or 0) or 1 # Storage configuration @@ -139,4 +138,4 @@ STORAGE_ENGINE = "database" # Stores media inside the database. Other options STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local MAX_MEDIA_SIZE = 5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"] -ZLIB_COMPRESSION_LEVEL = 9 \ No newline at end of file +ZLIB_COMPRESSION_LEVEL = 9 diff --git a/endpoints/__init__.py b/src/endpoints/__init__.py similarity index 100% rename from endpoints/__init__.py rename to src/endpoints/__init__.py diff --git a/src/endpoints/auth.py b/src/endpoints/auth.py new file mode 100644 index 0000000..0f8fbf4 --- /dev/null +++ b/src/endpoints/auth.py @@ -0,0 +1,129 @@ +import bcrypt +from datetime import timedelta +from fastapi.exceptions import HTTPException +from fastapi.security import OAuth2PasswordRequestForm +from fastapi import APIRouter as FastAPI, Depends, Response, Request + + +from config import ( + LOGGER, + SESSION_EXPIRE_LIMIT, + COOKIE_SAMESITE_POLICY, + COOKIE_DOMAIN, + MANAGER, + SESSION_COOKIE_NAME, + LIMITER, + SECURE_COOKIE, + COOKIE_PATH, + COOKIE_HTTPONLY, + UNVERIFIED_MANAGER, +) +from responses import ( + Response as APIResponse, + UnprocessableEntity, + BadRequest, +) +from orm.users import ( + UserModel, + PrivateUserModel, + get_user_by_username, +) + +router = FastAPI() + + +@router.post( + "/user", + tags=["Users"], + status_code=200, + responses={ + 200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + }, +) +@LIMITER.limit("5/minute") +async def login( + request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends() +): + """ + Performs user authentication. Endpoint is limited to 5 hits per minute + """ + + if request.cookies.get(SESSION_COOKIE_NAME): + raise HTTPException(status_code=400, detail="Please logout first") + username = data.username + if len(username) > 32: + raise HTTPException( + status_code=413, detail="Authentication failed: username is too long" + ) + try: + password = data.password.encode() + if len(password) > 72: + raise HTTPException( + status_code=413, detail="Authentication failed: password is too long" + ) + except UnicodeEncodeError as e: + LOGGER.warning( + f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}" + ) + raise HTTPException( + status_code=413, + detail="Authentication failed: invalid characters in password", + ) + if not ( + user := await get_user_by_username( + username, include_secrets=True, restricted_ok=True + ) + ): + raise HTTPException( + status_code=413, + detail="Authentication failed: the user does not exist", + ) + user: PrivateUserModel + if not bcrypt.checkpw(password, user.password_hash): + raise HTTPException( + status_code=413, + detail="Authentication failed: password mismatch", + ) + token = MANAGER.create_access_token( + expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), + data={"sub": str(user.public_id)}, + ) + response.set_cookie( + secure=SECURE_COOKIE, + key=SESSION_COOKIE_NAME, + max_age=SESSION_EXPIRE_LIMIT, + value=token, + httponly=COOKIE_HTTPONLY, + samesite=COOKIE_SAMESITE_POLICY, + domain=COOKIE_DOMAIN or None, + path=COOKIE_PATH or "/", + ) + return APIResponse(status_code=200, msg="Authentication successful") + + +@router.get( + "/user/logout", + tags=["Users"], + status_code=200, + responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}}, +) +@LIMITER.limit("5/minute") +async def logout( + request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER) +): + """ + Deletes a user's session cookie, logging them + out. Endpoint is limited to 5 hits per minute + """ + + response.delete_cookie( + secure=SECURE_COOKIE, + key=SESSION_COOKIE_NAME, + httponly=COOKIE_HTTPONLY, + samesite=COOKIE_SAMESITE_POLICY, + domain=COOKIE_DOMAIN or None, + path=COOKIE_PATH or "/", + ) + return APIResponse(status_code=200, msg="Logged out") diff --git a/src/endpoints/email.py b/src/endpoints/email.py new file mode 100644 index 0000000..3f274f6 --- /dev/null +++ b/src/endpoints/email.py @@ -0,0 +1,210 @@ +import json +import uuid +from datetime import timedelta +from datetime import timezone, datetime +from fastapi.exceptions import HTTPException +from fastapi import APIRouter as FastAPI, Depends, Request + + +from config import ( + LIMITER, + SMTP_HOST, + SMTP_PORT, + SMTP_USER, + SMTP_PASSWORD, + SMTP_USE_TLS, + SMTP_TIMEOUT, + SMTP_FROM_USER, + SMTP_TEMPLATES_DIRECTORY, + PLATFORM_NAME, + HAS_HTTPS, + HOST, + PORT, + EMAIL_VERIFICATION_EXPIRATION, + UNVERIFIED_MANAGER, +) +from responses import ( + Response as APIResponse, + UnprocessableEntity, + BadRequest, + NotFound, + InternalServerError, +) +from orm.users import ( + User, + UserModel, +) +from orm.email_verification import EmailVerification, EmailVerificationType +from util.email import send_email + + +router = FastAPI() + + +@router.get( + "/user/verifyEmail/{verification_id}", + tags=["Users"], + status_code=200, + responses={ + 200: {"model": APIResponse}, + 404: {"model": NotFound}, + 422: {"model": UnprocessableEntity}, + }, +) +@LIMITER.limit("3/second") +async def verify_email( + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): + """ + Verifies a user's email address. Endpoint is + limited to 3 hits per second + """ + + if not ( + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() + ): + raise HTTPException(status_code=404, detail="Verification ID is invalid") + elif not verification["pending"]: + raise HTTPException(status_code=400, detail="Email is already verified") + elif datetime.now().astimezone(timezone.utc) - verification[ + "creation_date" + ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + raise HTTPException( + status_code=400, + detail="Verification window has expired. Try again", + ) + else: + await EmailVerification.update({EmailVerification.pending: False}).where( + EmailVerification.user == user.public_id + ) + await User.update({User.email_verified: True}).where( + User.public_id == user.public_id + ) + return APIResponse(status_code=200, msg="Verification successful") + + +@router.get( + "/user/changeEmail/{verification_id}", + tags=["Users"], + status_code=200, + responses={ + 200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 404: {"model": NotFound}, + }, +) +@LIMITER.limit("3/second") +async def change_email( + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): + """ + Modifies a user's email address. + Endpoint is limited to 3 hits per second + """ + + if not ( + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() + ): + raise HTTPException(status_code=404, detail="Request ID is invalid") + elif not verification["pending"]: + raise HTTPException(status_code=400, detail="This link has already been used") + elif datetime.now().astimezone(timezone.utc) - verification[ + "creation_date" + ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + raise HTTPException( + status_code=400, + detail="Verification window has expired. Try again", + ) + else: + # Note how we don't update based on the verification ID: + # this way, multiple pending email verification requests + # are all cleared at once + await EmailVerification.update({EmailVerification.pending: False}).where( + EmailVerification.user == user.public_id + and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL + ) + await User.update( + { + User.email_address: verification["data"].decode(), + User.email_verified: False, + } + ).where(User.public_id == user.public_id) + return APIResponse(status_code=200, msg="Email updated") + + +@router.put( + "user/resendMail", + tags=["Users"], + status_code=200, + responses={ + 200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 500: {"model": InternalServerError}, + }, +) +@LIMITER.limit("6/minute") +async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)): + """ + Resends the verification email to the user if the previous has expired. + Endpoint is limited to 6 hits per minute + """ + + if user.email_verified: + raise HTTPException(status_code=400, detail="Email is already verified") + email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json" + try: + email_template.resolve(strict=True) + except FileNotFoundError: + email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json" + with email_template.open() as f: + email_message = json.load(f)["signup"] + verification_id = uuid.uuid4() + if await send_email( + SMTP_HOST, + SMTP_PORT, + email_message["content"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" + f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}", + platformName=PLATFORM_NAME, + ), + SMTP_TIMEOUT, + SMTP_FROM_USER, + user.email_address, + email_message["subject"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + platformName=PLATFORM_NAME, + ), + SMTP_USER, + SMTP_PASSWORD, + use_tls=SMTP_USE_TLS, + ): + await EmailVerification.update( + { + EmailVerification.id: verification_id, + EmailVerification.creation_date: datetime.now(), + } + ) + return APIResponse(status_code=200, msg="Success") + else: + raise HTTPException( + status_code=500, + detail="An error occurred while trying to resend the email," + " please try again later", + ) diff --git a/endpoints/media.py b/src/endpoints/media.py similarity index 79% rename from endpoints/media.py rename to src/endpoints/media.py index a6c37a5..1c39b57 100644 --- a/endpoints/media.py +++ b/src/endpoints/media.py @@ -22,13 +22,19 @@ router = FastAPI() }, ) @LIMITER.limit("2/second") -async def get_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)): +async def get_media( + request: Request, media_id: str, _user: UserModel = Depends(MANAGER) +): """ Gets a media object by its ID. Endpoint is limited to 2 hits per second """ - if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None: + if ( + m := await Media.select(Media.media_id) + .where(Media.media_id == media_id) + .first() + ) is None: raise HTTPException(status_code=404, detail="Media not found") m = Media(**m) if m.media_type == MediaType.FILE: @@ -55,7 +61,9 @@ async def get_media(request: Request, media_id: str, _user: UserModel = Depends( }, ) @LIMITER.limit("2/second") -async def report_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)): +async def report_media( + request: Request, media_id: str, _user: UserModel = Depends(MANAGER) +): """ Reports a piece of media by its ID. This creates a report that can be seen by admins, which can @@ -63,7 +71,11 @@ async def report_media(request: Request, media_id: str, _user: UserModel = Depen hits per second """ - if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None: + if ( + m := await Media.select(Media.media_id) + .where(Media.media_id == media_id) + .first() + ) is None: raise HTTPException(status_code=404, detail="Media not found") # TODO: Create report return APIResponse(msg="Success") diff --git a/src/endpoints/password.py b/src/endpoints/password.py new file mode 100644 index 0000000..4ef8978 --- /dev/null +++ b/src/endpoints/password.py @@ -0,0 +1,73 @@ +from datetime import timedelta +from datetime import timezone, datetime +from fastapi.exceptions import HTTPException +from fastapi import APIRouter as FastAPI, Depends, Request + + +from config import ( + LIMITER, + EMAIL_VERIFICATION_EXPIRATION, + UNVERIFIED_MANAGER, +) +from responses import ( + Response as APIResponse, + UnprocessableEntity, + BadRequest, + NotFound, +) +from orm.users import User, UserModel +from orm.email_verification import EmailVerification, EmailVerificationType + + +router = FastAPI() + + +@router.get( + "/user/resetPassword/{verification_id}", + tags=["Users"], + status_code=200, + responses={ + 200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 404: {"model": NotFound}, + }, +) +@LIMITER.limit("3/second") +async def reset_password( + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): + """ + Modifies a user's password. Endpoint is limited + to 3 hits per second + """ + + if not ( + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() + ): + raise HTTPException(status_code=404, detail="Request ID is invalid") + elif not verification["pending"]: + raise HTTPException(status_code=400, detail="This link has already been used") + elif datetime.now().astimezone(timezone.utc) - verification[ + "creation_date" + ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + raise HTTPException( + status_code=400, + detail="Verification window has expired. Try again", + ) + else: + # Note how we don't update based on the verification ID: + # this way, multiple pending email verification requests + # are all cleared at once + await EmailVerification.update({EmailVerification.pending: False}).where( + EmailVerification.user == user.public_id + and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET + ) + await User.update({User.password_hash: verification["data"]}).where( + User.public_id == user.public_id + ) + return APIResponse(status_code=200, msg="Password updated") diff --git a/endpoints/users.py b/src/endpoints/users.py similarity index 52% rename from endpoints/users.py rename to src/endpoints/users.py index 84f14cf..3188706 100644 --- a/endpoints/users.py +++ b/src/endpoints/users.py @@ -1,37 +1,23 @@ import base64 -from datetime import timezone, datetime -import hashlib import json -import re -import imghdr import uuid -import zlib import bcrypt from uuid import UUID -import validators -from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile -from fastapi.exceptions import HTTPException -from fastapi.security import OAuth2PasswordRequestForm -from datetime import timedelta from pathlib import Path +from fastapi.exceptions import HTTPException +from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile + from config import ( BCRYPT_ROUNDS, - LOGGER, - SESSION_EXPIRE_LIMIT, COOKIE_SAMESITE_POLICY, COOKIE_DOMAIN, MANAGER, SESSION_COOKIE_NAME, LIMITER, - VALIDATE_PASSWORD_REGEX, - VALIDATE_USERNAME_REGEX, SECURE_COOKIE, COOKIE_PATH, COOKIE_HTTPONLY, - ALLOWED_MEDIA_TYPES, - ZLIB_COMPRESSION_LEVEL, - MAX_MEDIA_SIZE, STORAGE_ENGINE, STORAGE_FOLDER, SMTP_HOST, @@ -46,8 +32,6 @@ from config import ( HAS_HTTPS, HOST, PORT, - EMAIL_VERIFICATION_EXPIRATION, - FORCE_EMAIL_VERIFICATION, UNVERIFIED_MANAGER, ) from responses import ( @@ -70,116 +54,19 @@ from orm.users import ( PrivateUserModel, get_user_by_username, get_user_by_id, - get_user_by_email, ) from orm.media import Media, MediaType, PublicMediaModel -from orm.email_verification import EmailVerification, EmailVerificationType +from orm.email_verification import EmailVerification from util.email import send_email +from util.users import validate_user, validate_profile_picture + router = FastAPI() -# Credential loaders for our authenticated routes -@MANAGER.user_loader() -async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel: - user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True) - if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified: - raise HTTPException(status_code=401, detail="Email verification is required") - return user - - -@UNVERIFIED_MANAGER.user_loader() -async def get_self_by_id_unverified(public_id: UUID) -> UserModel: - return await get_self_by_id(public_id, False) - - # Here follow our *beautifully* documented path operations -@router.post( - "/user", - tags=["Users"], - status_code=200, - responses={ - 200: {"model": APIResponse}, - 400: {"model": BadRequest}, - 422: {"model": UnprocessableEntity}, - }, -) -@LIMITER.limit("5/minute") -async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()): - """ - Performs user authentication. Endpoint is limited to 5 hits per minute - """ - - if request.cookies.get(SESSION_COOKIE_NAME): - raise HTTPException(status_code=400, detail="Please logout first") - username = data.username - if len(username) > 32: - raise HTTPException(status_code=413, detail="Authentication failed: username is too long") - try: - password = data.password.encode() - if len(password) > 72: - raise HTTPException(status_code=413, detail="Authentication failed: password is too long") - except UnicodeEncodeError as e: - LOGGER.warning( - f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}" - ) - raise HTTPException( - status_code=413, - detail="Authentication failed: invalid characters in password", - ) - if not (user := await get_user_by_username(username, include_secrets=True, restricted_ok=True)): - raise HTTPException( - status_code=413, - detail="Authentication failed: the user does not exist", - ) - if not bcrypt.checkpw(password, user.password_hash): - raise HTTPException( - status_code=413, - detail="Authentication failed: password mismatch", - ) - token = MANAGER.create_access_token( - expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), - data={"sub": str(user.public_id)}, - ) - response.set_cookie( - secure=SECURE_COOKIE, - key=SESSION_COOKIE_NAME, - max_age=SESSION_EXPIRE_LIMIT, - value=token, - httponly=COOKIE_HTTPONLY, - samesite=COOKIE_SAMESITE_POLICY, - domain=COOKIE_DOMAIN or None, - path=COOKIE_PATH or "/", - ) - return APIResponse(status_code=200, msg="Authentication successful") - - -@router.get( - "/user/logout", - tags=["Users"], - status_code=200, - responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}}, -) -@LIMITER.limit("5/minute") -async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)): - """ - Deletes a user's session cookie, logging them - out. Endpoint is limited to 5 hits per minute - """ - - response.delete_cookie( - secure=SECURE_COOKIE, - key=SESSION_COOKIE_NAME, - httponly=COOKIE_HTTPONLY, - samesite=COOKIE_SAMESITE_POLICY, - domain=COOKIE_DOMAIN or None, - path=COOKIE_PATH or "/", - ) - return APIResponse(status_code=200, msg="Logged out") - - @router.get( "/user/me", tags=["Users"], @@ -219,7 +106,9 @@ async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGE }, ) @LIMITER.limit("30/second") -async def get_user_by_name(request: Request, username: str, _auth: UserModel = Depends(MANAGER)): +async def get_user_by_name( + request: Request, username: str, _auth: UserModel = Depends(MANAGER) +): """ Fetches a single user by its public username. Endpoint is limited to 30 hits per second @@ -227,6 +116,7 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D if not (user := await get_user_by_username(username)): return NotFound(msg="Lookup failed: the user does not exist") + user: PrivateUserModel return PublicUserResponse( data=PublicUserModel( public_id=user.public_id, @@ -258,14 +148,19 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D }, ) @LIMITER.limit("30/second") -async def get_user_by_public_id(request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)): +async def get_user_by_public_id( + request: Request, public_id: str, _auth: UserModel = Depends(MANAGER) +): """ Fetches a single user by its public ID. Endpoint is limited to 30 hits per second """ if not (user := await get_user_by_id(UUID(public_id))): - raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist") + raise HTTPException( + status_code=404, detail="Lookup failed: the user does not exist" + ) + user: PrivateUserModel return PublicUserResponse( data=PublicUserModel( public_id=user.public_id, @@ -286,54 +181,6 @@ async def get_user_by_public_id(request: Request, public_id: str, _auth: UserMod ) -async def validate_user( - first_name: str | None, - last_name: str | None, - username: str | None, - email: str | None, - password: str | None, - bio: str | None, -): - """ - Performs some validation upon user creation. Returns - a tuple (success, msg) to be used by routes. Values - set to None are not checked against - """ - - if first_name and len(first_name) > 64: - return False, "first name is too long" - if first_name and len(first_name) < 5: - return False, "first name is too short" - if last_name and len(last_name) > 64: - return False, "last name is too long" - if last_name and len(last_name) < 2: - return False, "last name is too short" - if username and len(username) < 5: - return False, "username is too short" - if username and len(username) > 32: - return False, "username is too long" - if username and VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username): - return False, "username is invalid" - if email and not validators.email(email): - return False, "email is not valid" - if password and len(password) > 72: - return False, "password is too long" - if password and VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password): - return False, "password is too weak" - if username and await get_user_by_username(username, deleted_ok=True, restricted_ok=True): - return False, "username is already taken" - if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): - return False, "email is already registered" - if bio and len(bio) > 4096: - return False, "bio is too long" - if bio: - try: - bio.encode("utf-8") - except UnicodeDecodeError: - return False, "bio contains invalid characters" - return True, "" - - @router.delete( "/user", tags=["Users"], @@ -341,7 +188,9 @@ async def validate_user( responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}}, ) @LIMITER.limit("1/minute") -async def delete(request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)): +async def delete_user( + request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER) +): """ Sets the user's deleted flag in the database, without actually deleting the associated @@ -362,219 +211,6 @@ async def delete(request: Request, response: Response, user: UserModel = Depends return APIResponse(status_code=200, msg="Success") -@router.get( - "/user/verifyEmail/{verification_id}", - tags=["Users"], - status_code=200, - responses={ - 200: {"model": APIResponse}, - 404: {"model": NotFound}, - 422: {"model": UnprocessableEntity}, - }, -) -@LIMITER.limit("3/second") -async def verify_email( - request: Request, - verification_id: str, - user: UserModel = Depends(UNVERIFIED_MANAGER), -): - """ - Verifies a user's email address. Endpoint is - limited to 3 hits per second - """ - - if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() - ): - raise HTTPException(status_code=404, detail="Verification ID is invalid") - elif not verification["pending"]: - raise HTTPException(status_code=400, detail="Email is already verified") - elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( - seconds=EMAIL_VERIFICATION_EXPIRATION - ): - raise HTTPException( - status_code=400, - detail="Verification window has expired. Try again", - ) - else: - await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user.public_id - ) - await User.update({User.email_verified: True}).where(User.public_id == user.public_id) - return APIResponse(status_code=200, msg="Verification successful") - - -@router.get( - "/user/resetPassword/{verification_id}", - tags=["Users"], - status_code=200, - responses={ - 200: {"model": APIResponse}, - 400: {"model": BadRequest}, - 422: {"model": UnprocessableEntity}, - 404: {"model": NotFound}, - }, -) -@LIMITER.limit("3/second") -async def reset_password( - request: Request, - verification_id: str, - user: UserModel = Depends(UNVERIFIED_MANAGER), -): - """ - Modifies a user's password. Endpoint is limited - to 3 hits per second - """ - - if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() - ): - raise HTTPException(status_code=404, detail="Request ID is invalid") - elif not verification["pending"]: - raise HTTPException(status_code=400, detail="This link has already been used") - elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( - seconds=EMAIL_VERIFICATION_EXPIRATION - ): - raise HTTPException( - status_code=400, - detail="Verification window has expired. Try again", - ) - else: - # Note how we don't update based on the verification ID: - # this way, multiple pending email verification requests - # are all cleared at once - await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET - ) - await User.update({User.password_hash: verification["data"]}).where(User.public_id == user.public_id) - return APIResponse(status_code=200, msg="Password updated") - - -@router.get( - "/user/changeEmail/{verification_id}", - tags=["Users"], - status_code=200, - responses={ - 200: {"model": APIResponse}, - 400: {"model": BadRequest}, - 422: {"model": UnprocessableEntity}, - 404: {"model": NotFound}, - }, -) -@LIMITER.limit("3/second") -async def change_email( - request: Request, - verification_id: str, - user: UserModel = Depends(UNVERIFIED_MANAGER), -): - """ - Modifies a user's email address. - Endpoint is limited to 3 hits per second - """ - - if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() - ): - raise HTTPException(status_code=404, detail="Request ID is invalid") - elif not verification["pending"]: - raise HTTPException(status_code=400, detail="This link has already been used") - elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( - seconds=EMAIL_VERIFICATION_EXPIRATION - ): - raise HTTPException( - status_code=400, - detail="Verification window has expired. Try again", - ) - else: - # Note how we don't update based on the verification ID: - # this way, multiple pending email verification requests - # are all cleared at once - await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL - ) - await User.update( - { - User.email_address: verification["data"].decode(), - User.email_verified: False, - } - ).where(User.public_id == user.public_id) - return APIResponse(status_code=200, msg="Email updated") - - -@router.put( - "user/resendMail", - tags=["Users"], - status_code=200, - responses={ - 200: {"model": APIResponse}, - 400: {"model": BadRequest}, - 422: {"model": UnprocessableEntity}, - 500: {"model": InternalServerError}, - }, -) -@LIMITER.limit("6/minute") -async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)): - """ - Resends the verification email to the user if the previous has expired. - Endpoint is limited to 6 hits per minute - """ - - if user.email_verified: - raise HTTPException(status_code=400, detail="Email is already verified") - email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json" - try: - email_template.resolve(strict=True) - except FileNotFoundError: - email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json" - with email_template.open() as f: - email_message = json.load(f)["signup"] - verification_id = uuid.uuid4() - if await send_email( - SMTP_HOST, - SMTP_PORT, - email_message["content"].format( - first_name=user.first_name, - last_name=user.last_name, - username=user.username, - email=user.email_address, - link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" - f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}", - platformName=PLATFORM_NAME, - ), - SMTP_TIMEOUT, - SMTP_FROM_USER, - user.email_address, - email_message["subject"].format( - first_name=user.first_name, - last_name=user.last_name, - username=user.username, - email=user.email_address, - platformName=PLATFORM_NAME, - ), - SMTP_USER, - SMTP_PASSWORD, - use_tls=SMTP_USE_TLS, - ): - await EmailVerification.update( - { - EmailVerification.id: verification_id, - EmailVerification.creation_date: datetime.now(), - } - ) - return APIResponse(status_code=200, msg="Success") - else: - raise HTTPException( - status_code=500, - detail="An error occurred while trying to resend the email," " please try again later", - ) - - @router.put( "/user", tags=["Users"], @@ -607,7 +243,9 @@ async def signup( raise HTTPException(status_code=400, detail="Please logout first") # We don't use FastAPI's validation because we want custom error # messages - result, msg = await validate_user(first_name, last_name, username, email, password, bio) + result, msg = await validate_user( + first_name, last_name, username, email, password, bio + ) if not result: return APIResponse(status_code=413, msg=f"Signup failed: {msg}") else: @@ -667,40 +305,11 @@ async def signup( else: raise HTTPException( status_code=500, - detail="An error occurred while sending verification email, please" " try again later", + detail="An error occurred while sending verification email, please" + " try again later", ) -async def validate_profile_picture( - file: UploadFile, -) -> tuple[bool | None, str, bytes, str]: - """ - Validates a profile picture's size and content to see if it fits - our criteria and returns a tuple result, ext, data where result is a - boolean or none (True = check was passed, False = size too large, - None = check was failed for other reasons) indicating if the check was successful, - ext is the file's type and extension, data is a compressed stream of bytes - representing the original media and hash is the file's SHA256 hash encoded in - hexadecimal, before the compression. This function never raises an exception - """ - - async with file: - try: - content = await file.read() - if len(content) > MAX_MEDIA_SIZE: - return False, "", b"", "" - if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES: - return None, "", b"", "" - return ( - True, - ext, - zlib.compress(content, ZLIB_COMPRESSION_LEVEL), - hashlib.sha256(content).hexdigest(), - ) - except (UnicodeDecodeError, zlib.error): - return None, "", b"", "" - - @router.patch( "/user", tags=["Users"], @@ -739,9 +348,15 @@ async def update_user( since they're the only ones that can be set to a null value. Endpoint is limited to 6 hits per minute """ - if not delete and not any((first_name, last_name, username, profile_picture, email_address, bio, password)): - raise HTTPException(status_code=400, detail="At least one value has to be specified") - result, msg = await validate_user(first_name, last_name, username, email_address, password, bio) + if not delete and not any( + (first_name, last_name, username, profile_picture, email_address, bio, password) + ): + raise HTTPException( + status_code=400, detail="At least one value has to be specified" + ) + result, msg = await validate_user( + first_name, last_name, username, email_address, password, bio + ) if not result: raise HTTPException(status_code=413, detail=f"Update failed: {msg}") orig_user = user.copy() @@ -757,10 +372,16 @@ async def update_user( if profile_picture: result, ext, media, digest = validate_profile_picture(profile_picture) if result is False: - raise HTTPException(status_code=415, detail="The file type is unsupported") + raise HTTPException( + status_code=415, detail="The file type is unsupported" + ) elif result is None: raise HTTPException(status_code=413, detail="The file is too large") - elif (m := await Media.select(Media.media_id).where(Media.media_id == digest).first()) is None: + elif ( + m := await Media.select(Media.media_id) + .where(Media.media_id == digest) + .first() + ) is None: # This media hasn't been already uploaded (either by this user or by someone # else), so we save it now. If it has been already uploaded, there's no need # to do it again (that's what the hash is for) @@ -836,7 +457,9 @@ async def update_user( { EmailVerification.id: verification_id, EmailVerification.user: User(public_id=user.public_id), - EmailVerification.data: bcrypt.hashpw(password.encode(), user.password_hash[:29]), + EmailVerification.data: bcrypt.hashpw( + password.encode(), user.password_hash[:29] + ), } ) ) @@ -902,9 +525,13 @@ async def update_user( user.profile_picture = None fields = [] for field in user: - if field not in ["email_address", "password"] and getattr(orig_user, field) != getattr(user, field): + if field not in ["email_address", "password"] and getattr( + orig_user, field + ) != getattr(user, field): fields.append((field, getattr(user, field))) if fields: # If anything has changed, we update our info - await User.update({field: value for field, value in fields}).where(User.public_id == user.public_id) + await User.update({field: value for field, value in fields}).where( + User.public_id == user.public_id + ) return APIResponse(status_code=200, msg="Changes saved successfully") diff --git a/main.py b/src/main.py similarity index 79% rename from main.py rename to src/main.py index a9ee7e7..43ca53f 100644 --- a/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ import uvloop import asyncio import uvicorn - +from pathlib import Path from fastapi import FastAPI, Request from fastapi.exceptions import ( HTTPException, @@ -9,10 +9,7 @@ from fastapi.exceptions import ( StarletteHTTPException, ) from slowapi.errors import RateLimitExceeded -from pathlib import Path - - -from endpoints import users, media +from endpoints import users, media, auth, email, password from config import ( LOGGER, LIMITER, @@ -37,9 +34,11 @@ from util.exception_handlers import ( generic_error, ) from util.email import test_smtp +from config import MANAGER, UNVERIFIED_MANAGER +from util.users import get_self_by_id, get_self_by_id_unverified -with (Path(__file__).parent / "API.md").resolve(strict=True).open() as f: +with (Path(__file__).parent.parent / "API.md").resolve(strict=True).open() as f: description = f.read() @@ -116,7 +115,9 @@ async def startup_checks(): await create_tables() await Media.raw("SELECT 1;") except Exception as e: - LOGGER.error(f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}") + LOGGER.error( + f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}" + ) else: LOGGER.info("Database initialized") LOGGER.info("Testing SMTP connection") @@ -131,7 +132,10 @@ async def startup_checks(): True, ) except Exception as e: - LOGGER.error(f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}") + LOGGER.error( + f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}" + ) + raise else: LOGGER.info("SMTP test was successful") @@ -141,7 +145,13 @@ if __name__ == "__main__": LOGGER.debug("Including modules") app.include_router(users.router) app.include_router(media.router) + app.include_router(auth.router) + app.include_router(email.router) + app.include_router(password.router) + LOGGER.debug("Configuring authentication and rate-limiting") app.state.limiter = LIMITER + MANAGER.user_loader(get_self_by_id) + UNVERIFIED_MANAGER.user_loader(get_self_by_id_unverified) LOGGER.debug("Setting exception handlers") app.add_exception_handler(RateLimitExceeded, rate_limited) app.add_exception_handler(NotAuthenticated, not_authenticated) @@ -153,9 +163,21 @@ if __name__ == "__main__": uvloop.install() log_config = uvicorn.config.LOGGING_CONFIG log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt - log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt - log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt - log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt + log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[ + 0 + ].formatter.datefmt + log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[ + 0 + ].formatter._fmt # noqa + log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[ + 0 + ].formatter._fmt # noqa log_config["handlers"]["access"]["stream"] = "ext://sys.stderr" - asyncio.run(startup_checks()) + try: + asyncio.run(startup_checks()) + except Exception as e: + LOGGER.error( + f"Startup checks failed due to a {type(e).__name__} exception -> {e}" + ) + exit(1) uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS) diff --git a/orm/__init__.py b/src/orm/__init__.py similarity index 100% rename from orm/__init__.py rename to src/orm/__init__.py diff --git a/orm/email_verification.py b/src/orm/email_verification.py similarity index 100% rename from orm/email_verification.py rename to src/orm/email_verification.py diff --git a/orm/media.py b/src/orm/media.py similarity index 94% rename from orm/media.py rename to src/orm/media.py index 3d64429..67851b0 100644 --- a/orm/media.py +++ b/src/orm/media.py @@ -45,7 +45,9 @@ class Media(Table): MediaModel = create_pydantic_model(Media) -PublicMediaModel = create_pydantic_model(Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type)) +PublicMediaModel = create_pydantic_model( + Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type) +) async def get_media_by_column( diff --git a/orm/posts.py b/src/orm/posts.py similarity index 92% rename from orm/posts.py rename to src/orm/posts.py index 2830000..28c1a04 100644 --- a/orm/posts.py +++ b/src/orm/posts.py @@ -38,7 +38,9 @@ class Post(Table, tablename="posts"): PostModel = create_pydantic_model(Post, nested=True) -PrivatePostModelInternal = create_pydantic_model(Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id)) +PrivatePostModelInternal = create_pydantic_model( + Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id) +) PublicPostModelInternal = create_pydantic_model( Post, nested=True, exclude_columns=(Post.flagged, Post.deleted, Post.internal_id) ) diff --git a/orm/users.py b/src/orm/users.py similarity index 100% rename from orm/users.py rename to src/orm/users.py diff --git a/src/piccolo_conf.py b/src/piccolo_conf.py new file mode 100644 index 0000000..8d7733d --- /dev/null +++ b/src/piccolo_conf.py @@ -0,0 +1,16 @@ +from piccolo.conf.apps import AppRegistry +from piccolo.engine.postgres import PostgresEngine + +DB = PostgresEngine( + config={ + "host": "mouse.db.elephantsql.com", + "user": "zubtlqiv", + "password": "l9bgeChDsTHJUwSFex2N2kSbLZ8VkRch", + "database": "zubtlqiv", + } +) + + +# A list of paths to piccolo apps +# e.g. ['blog.piccolo_app'] +APP_REGISTRY = AppRegistry(apps=[]) diff --git a/piccolo_conf.py.example b/src/piccolo_conf.py.example similarity index 100% rename from piccolo_conf.py.example rename to src/piccolo_conf.py.example diff --git a/responses/__init__.py b/src/responses/__init__.py similarity index 100% rename from responses/__init__.py rename to src/responses/__init__.py diff --git a/responses/media.py b/src/responses/media.py similarity index 100% rename from responses/media.py rename to src/responses/media.py diff --git a/responses/users.py b/src/responses/users.py similarity index 100% rename from responses/users.py rename to src/responses/users.py diff --git a/util/__init__.py b/src/util/__init__.py similarity index 100% rename from util/__init__.py rename to src/util/__init__.py diff --git a/util/email.py b/src/util/email.py similarity index 94% rename from util/email.py rename to src/util/email.py index 731bf62..fa3961c 100644 --- a/util/email.py +++ b/src/util/email.py @@ -47,7 +47,9 @@ async def send_email( await srv.login(login_email, password) await srv.sendmail(sender, recipient, msg.as_string()) except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error: - logging.error(f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}") + logging.error( + f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}" + ) return error return True diff --git a/util/exception_handlers.py b/src/util/exception_handlers.py similarity index 60% rename from util/exception_handlers.py rename to src/util/exception_handlers.py index 6910642..a9f2527 100644 --- a/util/exception_handlers.py +++ b/src/util/exception_handlers.py @@ -18,7 +18,10 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon else: break error.detail = error.detail[:n] + " requests" + error.detail[n:] - LOGGER.info(f"{request.client.host} got rate-limited at {str(request.url)} " f"(exceeded {error.detail})") + LOGGER.info( + f"{request.client.host} got rate-limited at {str(request.url)} " + f"(exceeded {error.detail})" + ) return JSONResponse( status_code=200, content=dict( @@ -34,7 +37,9 @@ def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse: """ LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}") - return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required")) + return JSONResponse( + status_code=200, content=dict(status_code=401, msg="Authentication is required") + ) def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse: @@ -42,14 +47,18 @@ def request_invalid(request: Request, exc: RequestValidationError) -> JSONRespon Handles Bad Request exceptions from FastAPI """ - LOGGER.info(f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}") + LOGGER.info( + f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}" + ) return JSONResponse( status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"), ) -def http_exception(request: Request, exc: HTTPException | StarletteHTTPException) -> JSONResponse: +def http_exception( + request: Request, exc: HTTPException | StarletteHTTPException +) -> JSONResponse: """ Handles HTTP-specific exceptions raised explicitly by path operations @@ -57,12 +66,19 @@ def http_exception(request: Request, exc: HTTPException | StarletteHTTPException if exc.status_code >= 500: LOGGER.error( - f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{type(exc).__name__}: {exc}" + f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" + f"{type(exc).__name__}: {exc}" + ) + return JSONResponse( + status_code=200, content=dict(status_code=500, msg="Internal Server Error") ) - return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error")) else: - LOGGER.info(f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}") - return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail)) + LOGGER.info( + f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}" + ) + return JSONResponse( + status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail) + ) async def generic_error(request: Request, exc: Exception) -> JSONResponse: @@ -70,6 +86,10 @@ async def generic_error(request: Request, exc: Exception) -> JSONResponse: Handles generic, unexpected errors in the ASGI application """ - LOGGER.info(f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}") + LOGGER.info( + f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}" + ) # We can't leak anything about the error, it would be too risky - return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error")) + return JSONResponse( + status_code=200, content=dict(status_code=500, msg="Internal Server Error") + ) diff --git a/src/util/users.py b/src/util/users.py new file mode 100644 index 0000000..8ba75eb --- /dev/null +++ b/src/util/users.py @@ -0,0 +1,122 @@ +import hashlib +import imghdr +import re +import zlib +import validators +from uuid import UUID +from config import ( + VALIDATE_PASSWORD_REGEX, + VALIDATE_USERNAME_REGEX, + FORCE_EMAIL_VERIFICATION, + MAX_MEDIA_SIZE, + ZLIB_COMPRESSION_LEVEL, + ALLOWED_MEDIA_TYPES, +) +from orm.users import ( + get_user_by_username, + get_user_by_email, + UserModel, + get_user_by_id, +) +from fastapi import UploadFile +from fastapi.exceptions import HTTPException + + +async def validate_user( + first_name: str | None, + last_name: str | None, + username: str | None, + email: str | None, + password: str | None, + bio: str | None, +): + """ + Performs some validation upon user creation. Returns + a tuple (success, msg) to be used by routes. Values + set to None are not checked against + """ + + if first_name and len(first_name) > 64: + return False, "first name is too long" + if first_name and len(first_name) < 5: + return False, "first name is too short" + if last_name and len(last_name) > 64: + return False, "last name is too long" + if last_name and len(last_name) < 2: + return False, "last name is too short" + if username and len(username) < 5: + return False, "username is too short" + if username and len(username) > 32: + return False, "username is too long" + if ( + username + and VALIDATE_USERNAME_REGEX + and not re.match(VALIDATE_USERNAME_REGEX, username) + ): + return False, "username is invalid" + if email and not validators.email(email): + return False, "email is not valid" + if password and len(password) > 72: + return False, "password is too long" + if ( + password + and VALIDATE_PASSWORD_REGEX + and not re.match(VALIDATE_PASSWORD_REGEX, password) + ): + return False, "password is too weak" + if username and await get_user_by_username( + username, deleted_ok=True, restricted_ok=True + ): + return False, "username is already taken" + if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): + return False, "email is already registered" + if bio and len(bio) > 4096: + return False, "bio is too long" + if bio: + try: + bio.encode("utf-8") + except UnicodeDecodeError: + return False, "bio contains invalid characters" + return True, "" + + +async def validate_profile_picture( + file: UploadFile, +) -> tuple[bool | None, str, bytes, str]: + """ + Validates a profile picture's size and content to see if it fits + our criteria and returns a tuple result, ext, data where result is a + boolean or none (True = check was passed, False = size too large, + None = check was failed for other reasons) indicating if the check was successful, + ext is the file's type and extension, data is a compressed stream of bytes + representing the original media and hash is the file's SHA256 hash encoded in + hexadecimal, before the compression. This function never raises an exception + """ + + async with file: + try: + content = await file.read() + if len(content) > MAX_MEDIA_SIZE: + return False, "", b"", "" + if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES: + return None, "", b"", "" + return ( + True, + ext, + zlib.compress(content, ZLIB_COMPRESSION_LEVEL), + hashlib.sha256(content).hexdigest(), + ) + except (UnicodeDecodeError, zlib.error): + return None, "", b"", "" + + +# Credential loaders for our authenticated routes +async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel: + user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True) + if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified: + raise HTTPException(status_code=401, detail="Email verification is required") + return user + + +async def get_self_by_id_unverified(public_id: UUID) -> UserModel: + return await get_self_by_id(public_id, False)