From 1c05716345d8b2fb04bf50a416e20aa75388f60c Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Wed, 5 Oct 2022 19:52:37 +0200 Subject: [PATCH] Large rework of the application for better documentation support --- endpoints/users.py | 675 ++++++++++++++++++------------------- main.py | 43 ++- orm/users.py | 82 ++++- requirements.txt | 3 +- responses/__init__.py | 96 ++++++ responses/users.py | 24 ++ util/exception_handlers.py | 68 ++-- 7 files changed, 597 insertions(+), 394 deletions(-) create mode 100644 responses/__init__.py create mode 100644 responses/users.py diff --git a/endpoints/users.py b/endpoints/users.py index fd5d451..3a626a5 100644 --- a/endpoints/users.py +++ b/endpoints/users.py @@ -6,10 +6,8 @@ import re import imghdr import uuid import zlib - import bcrypt from uuid import UUID - import validators from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile from fastapi.exceptions import HTTPException @@ -52,47 +50,31 @@ from config import ( FORCE_EMAIL_VERIFICATION, UNVERIFIED_MANAGER, ) -from orm.users import UserModel, User +from responses import Response as APIResponse, UnprocessableEntity, BadRequest, NotFound, \ + MediaTypeNotAcceptable, PayloadTooLarge, InternalServerError +from responses.users import ( + PrivateUserResponse, + PublicUserResponse, + PrivateUserModel, + PublicUserModel, +) +from orm.users import ( + User, + UserModel, + get_user_by_username, + get_user_by_id, + get_user_by_email, +) from orm.media import Media, MediaType from orm.email_verification import EmailVerification from util.email import send_email - router = FastAPI() -async def get_user_by_id( - public_id: UUID, - include_secrets: bool = False, - restricted_ok: bool = False, - deleted_ok: bool = False, -) -> dict | None: - """ - Retrieves a user by its public ID - """ - - user = ( - await User.select( - *User.all_columns(exclude=["public_id"]), - User.public_id.as_alias("id"), - exclude_secrets=not include_secrets, - ) - .where(User.public_id == public_id) - .first() - ) - if user: - # Performs validation - UserModel(**user) - if (user["deleted"] and not deleted_ok) or ( - user["restricted"] and not restricted_ok - ): - return - return user - return - - +# Credential loaders for our authenticated routes @MANAGER.user_loader() -async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dict: +async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel: user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True) if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]: raise HTTPException(status_code=401, detail="Email verification is required") @@ -100,76 +82,22 @@ async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dic @UNVERIFIED_MANAGER.user_loader() -async def get_self_by_id_unverified(public_id: UUID): +async def get_self_by_id_unverified(public_id: UUID) -> UserModel: return await get_self_by_id(public_id, False) -async def get_user_by_username( - username: str, - include_secrets: bool = False, - restricted_ok: bool = False, - deleted_ok: bool = False, -) -> dict | None: - """ - Retrieves a user by its public username - """ - - user = ( - await User.select( - *User.all_columns(exclude=["public_id"]), - User.public_id.as_alias("id"), - exclude_secrets=not include_secrets, - ) - .where(User.username == username) - .first() - ) - if user: - # Performs validation - UserModel(**user) - if (user["deleted"] and not deleted_ok) or ( - user["restricted"] and not restricted_ok - ): - return - return user - return - - -async def get_user_by_email( - email: str, - include_secrets: bool = False, - restricted_ok: bool = False, - deleted_ok: bool = False, -) -> dict | None: - """ - Retrieves a user by its email address (meant to - be used internally) - """ - - user = ( - await User.select( - *User.all_columns(exclude=["public_id"]), - User.public_id.as_alias("id"), - exclude_secrets=not include_secrets, - ) - .where(User.email_address == email) - .first() - ) - if user: - # Performs validation - UserModel(**user) - if (user["deleted"] and not deleted_ok) or ( - user["restricted"] and not restricted_ok - ): - return - return user - return - +# Here follow our *beautifully* documented path operations +@router.post("/user", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}}) @LIMITER.limit("5/minute") -@router.post("/user") -async def login( - request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends() -) -> dict: +async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()): + """ + Performs user authentication. Endpoint is limited to 5 hits per minute + """ + if request.cookies.get(SESSION_COOKIE_NAME): raise HTTPException(status_code=400, detail="Please logout first") username = data.username @@ -192,21 +120,21 @@ async def login( detail="Authentication failed: invalid characters in password", ) if not ( - user := await get_user_by_username( - username, include_secrets=True, restricted_ok=True - ) + user := await get_user_by_username( + username, include_secrets=True, restricted_ok=True + ) ): raise HTTPException( status_code=413, detail="Authentication failed: the user does not exist", ) - if not bcrypt.checkpw(password, user["password_hash"]): + if not bcrypt.checkpw(password, user.password_hash): raise HTTPException( status_code=413, detail="Authentication failed: password mismatch", ) token = MANAGER.create_access_token( - expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])} + expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user.public_id)} ) response.set_cookie( secure=SECURE_COOKIE, @@ -218,17 +146,17 @@ async def login( domain=COOKIE_DOMAIN or None, path=COOKIE_PATH or "/", ) - return {"status_code": 200, "msg": "Authentication successful"} + return APIResponse(status_code=200, msg="Authentication successful") -@router.get("/user/logout") +@router.get("/user/logout", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 422: {"model": UnprocessableEntity}}) @LIMITER.limit("5/minute") -async def logout( - request: Request, response: Response, _user: dict = Depends(UNVERIFIED_MANAGER) -) -> dict: +async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)): """ Deletes a user's session cookie, logging them - out + out. Endpoint is limited to 5 hits per minute """ response.delete_cookie( @@ -239,70 +167,81 @@ async def logout( domain=COOKIE_DOMAIN or None, path=COOKIE_PATH or "/", ) - return {"status_code": 200, "msg": "Logged out"} + return APIResponse(status_code=200, msg="Logged out") -@router.get("/user/me") +@router.get( + "/user/me", + tags=["Users"], + status_code=200, + responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + 422: {"model": UnprocessableEntity}}, +) @LIMITER.limit("2/second") -async def get_self(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict: +async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)): """ Fetches a user's own info. This returns some extra data such as email address, account creation date and email verification status, - which is not available from the regular endpoint + which is not available from the regular endpoint. + Endpoint is limited to 2 hits per second """ - user.pop("password_hash") - user.pop("internal_id") - user.pop("deleted") - return {"status_code": 200, "msg": "Success", "data": user} + return PrivateUserResponse(status_code=200, msg="Success", data=user) -@router.get("/user/username/{username}") +@router.get( + "/user/username/{username}", + tags=["Users"], + status_code=200, + responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + 404: {"model": NotFound}, + 422: {"model": UnprocessableEntity} + }, +) @LIMITER.limit("30/second") async def get_user_by_name( - request: Request, username: str, _auth: dict = Depends(MANAGER) -) -> dict: + request: Request, username: str, _auth: UserModel = Depends(MANAGER) +): """ - Fetches a single user by its public ID + Fetches a single user by its public username """ if not (user := await get_user_by_username(username)): - return { - "status_code": 404, - "msg": "Lookup failed: the user does not exist", - } - user.pop("restricted") - user.pop("deleted") - return {"status_code": 200, "msg": "Lookup successful", "data": user} + return NotFound(msg="Lookup failed: the user does not exist") + return PublicUserResponse(status_code=200, msg="Lookup successful", data=user) -@router.get("/user/id/{public_id}") +@router.get( + "/user/id/{public_id}", + tags=["Users"], + status_code=200, + responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + 404: {"model": NotFound}, + 422: {"model": UnprocessableEntity} + }, +) @LIMITER.limit("30/second") async def get_user_by_public_id( - request: Request, public_id: str, _auth: dict = Depends(MANAGER) -) -> dict: + request: Request, public_id: str, _auth: UserModel = Depends(MANAGER) +): """ Fetches a single user by its public ID """ if not (user := await get_user_by_id(UUID(public_id))): - raise HTTPException( - status_code=404, detail="Lookup failed: the user does not exist" - ) - user.pop("restricted") - user.pop("deleted") - return {"status_code": 200, "msg": "Lookup successful", "data": user} + raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist") + return PublicUserResponse(data=user) async def validate_user( - first_name: str | None, - last_name: str | None, - username: str | None, - email: str | None, - password: str | None, - bio: str | None, -) -> tuple[bool, str]: + first_name: str | None, + last_name: str | None, + username: str | None, + email: str | None, + password: str | None, + bio: str | None, +): """ Performs some validation upon user creation. Returns a tuple (success, msg) to be used by routes. Values @@ -322,9 +261,9 @@ async def validate_user( if username and len(username) > 32: return False, "username is too long" if ( - username - and VALIDATE_USERNAME_REGEX - and not re.match(VALIDATE_USERNAME_REGEX, username) + username + and VALIDATE_USERNAME_REGEX + and not re.match(VALIDATE_USERNAME_REGEX, username) ): return False, "username is invalid" if email and not validators.email(email): @@ -332,13 +271,13 @@ async def validate_user( if password and len(password) > 72: return False, "password is too long" if ( - password - and VALIDATE_PASSWORD_REGEX - and not re.match(VALIDATE_PASSWORD_REGEX, password) + password + and VALIDATE_PASSWORD_REGEX + and not re.match(VALIDATE_PASSWORD_REGEX, password) ): return False, "password is too weak" if username and await get_user_by_username( - username, deleted_ok=True, restricted_ok=True + username, deleted_ok=True, restricted_ok=True ): return False, "username is already taken" if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): @@ -353,18 +292,20 @@ async def validate_user( return True, "" -@router.delete("/user") +@router.delete("/user", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 422: {"model": UnprocessableEntity}}) @LIMITER.limit("1/minute") async def delete( - request: Request, response: Response, user: dict = Depends(UNVERIFIED_MANAGER) -) -> dict: + request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER) +): """ Sets the user's deleted flag in the database, without actually deleting the associated data """ - await User.update({User.deleted: True}).where(User.public_id == user["id"]) + await User.update({User.deleted: True}).where(User.public_id == user.public_id) response.delete_cookie( secure=SECURE_COOKIE, key=SESSION_COOKIE_NAME, @@ -373,22 +314,28 @@ async def delete( domain=COOKIE_DOMAIN or None, path=COOKIE_PATH or "/", ) - return {"status_code": 200, "msg": "Success"} + return APIResponse(status_code=200, msg="Success") -@router.get("/user/verifyEmail/{verification_id}") +@router.get("/user/verifyEmail/{verification_id}", tags=["Users"], status_code=200, + responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + 404: {"model": NotFound}, + 422: {"model": UnprocessableEntity} + }) @LIMITER.limit("3/second") async def verify_email( - request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) -) -> dict: + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): """ Verifies a user's email address """ if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() ): raise HTTPException(status_code=404, detail="Verification ID is invalid") elif not verification["pending"]: @@ -402,27 +349,31 @@ async def verify_email( ) else: await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user["id"] + EmailVerification.user == user.public_id ) - await User.update({User.email_verified: True}).where( - User.public_id == user["id"] - ) - return {"status_code": 200, "msg": "Verification successful"} + await User.update({User.email_verified: True}).where(User.public_id == user.public_id) + return APIResponse(status_code=200, msg="Verification successful") -@router.get("/user/resetPassword/{verification_id}") +@router.get("/user/resetPassword/{verification_id}", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 404: {"model": NotFound}}) @LIMITER.limit("3/second") async def reset_password( - request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) -) -> dict: + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): """ Modifies a user's password """ if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() ): raise HTTPException(status_code=404, detail="Request ID is invalid") elif not verification["pending"]: @@ -436,27 +387,33 @@ async def reset_password( ) else: await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user["id"] + EmailVerification.user == user.public_id ) await User.update({User.password_hash: verification["data"]}).where( - User.public_id == user["id"] + User.public_id == user.public_id ) - return {"status_code": 200, "msg": "Password updated"} + return APIResponse(status_code=200, msg="Password updated") -@router.get("/user/changeEmail/{verification_id}") +@router.get("/user/changeEmail/{verification_id}", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 404: {"model": NotFound}}) @LIMITER.limit("3/second") async def change_email( - request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) -) -> dict: + request: Request, + verification_id: str, + user: UserModel = Depends(UNVERIFIED_MANAGER), +): """ Modifies a user's email """ if not ( - verification := await EmailVerification.select(*EmailVerification.all_columns()) - .where(EmailVerification.id == verification_id) - .first() + verification := await EmailVerification.select(*EmailVerification.all_columns()) + .where(EmailVerification.id == verification_id) + .first() ): raise HTTPException(status_code=404, detail="Request ID is invalid") elif not verification["pending"]: @@ -470,18 +427,24 @@ async def change_email( ) else: await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user["id"] + EmailVerification.user == user.public_id ) - await User.update({User.email_address: verification["data"].decode(), - User.email_verified: False}).where( - User.public_id == user["id"] - ) - return {"status_code": 200, "msg": "Email updated"} + await User.update( + { + User.email_address: verification["data"].decode(), + User.email_verified: False, + } + ).where(User.public_id == user.public_id) + return APIResponse(status_code=200, msg="Email updated") -@router.put("user/resendMail") +@router.put("user/resendMail", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 500: {"model": InternalServerError}}) @LIMITER.limit("6/minute") -async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict: +async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)): """ Resends the verification email to the user if the previous has expired """ @@ -523,26 +486,34 @@ async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER use_tls=SMTP_USE_TLS, ): await EmailVerification.update( - { - EmailVerification.id: verification_id, - EmailVerification.creation_date: datetime.now() - } + { + EmailVerification.id: verification_id, + EmailVerification.creation_date: datetime.now(), + } ) - return {"status_code": 200, "msg": "Success"} + return APIResponse(status_code=200, msg="Success") + else: + raise HTTPException(status_code=500, detail="An error occurred while trying to resend the email," + " please try again later") -@router.put("/user") +@router.put("/user", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 500: {"model": InternalServerError}, + 413: {"model": PayloadTooLarge}}) @LIMITER.limit("2/minute") async def signup( - request: Request, - first_name: str, - last_name: str, - username: str, - email: str, - password: str, - bio: str | None = None, - locale: str = "en_US", -) -> dict: + request: Request, + first_name: str, + last_name: str, + username: str, + email: str, + password: str, + bio: str | None = None, + locale: str = "en_US", +): """ Endpoint used to create new users """ @@ -555,7 +526,7 @@ async def signup( first_name, last_name, username, email, password, bio ) if not result: - return {"status_code": 413, "msg": f"Signup failed: {msg}"} + return APIResponse(status_code=413, msg=f"Signup failed: {msg}") else: salt = bcrypt.gensalt(BCRYPT_ROUNDS) user = User( @@ -563,9 +534,7 @@ async def signup( last_name=last_name, username=username, email_address=email, - password_hash=bcrypt.hashpw( - password.encode(), salt - ), + password_hash=bcrypt.hashpw(password.encode(), salt), bio=bio, ) email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json" @@ -577,30 +546,30 @@ async def signup( email_message = json.load(f)["signup"] verification_id = uuid.uuid4() if await send_email( - SMTP_HOST, - SMTP_PORT, - email_message["content"].format( - first_name=first_name, - last_name=last_name, - username=username, - email=email, - link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" - f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}", - platformName=PLATFORM_NAME, - ), - SMTP_TIMEOUT, - SMTP_FROM_USER, - email, - email_message["subject"].format( - first_name=first_name, - last_name=last_name, - username=username, - email=email, - platformName=PLATFORM_NAME, - ), - SMTP_USER, - SMTP_PASSWORD, - use_tls=SMTP_USE_TLS, + SMTP_HOST, + SMTP_PORT, + email_message["content"].format( + first_name=first_name, + last_name=last_name, + username=username, + email=email, + link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" + f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}", + platformName=PLATFORM_NAME, + ), + SMTP_TIMEOUT, + SMTP_FROM_USER, + email, + email_message["subject"].format( + first_name=first_name, + last_name=last_name, + username=username, + email=email, + platformName=PLATFORM_NAME, + ), + SMTP_USER, + SMTP_PASSWORD, + use_tls=SMTP_USE_TLS, ): await User.insert(user) await EmailVerification.insert( @@ -611,17 +580,17 @@ async def signup( } ) ) - return {"status_code": 200, "msg": "Success"} + return APIResponse(status_code=200, msg="Success") else: raise HTTPException( status_code=500, detail="An error occurred while sending verification email, please" - " try again later", + " try again later", ) async def validate_profile_picture( - file: UploadFile, + file: UploadFile, ) -> tuple[bool | None, str, bytes, str]: """ Validates a profile picture's size and content to see if it fits @@ -650,18 +619,23 @@ async def validate_profile_picture( return None, "", b"", "" -@router.patch("/user") -async def update( - request: Request, - user: dict = Depends(UNVERIFIED_MANAGER), - first_name: str | None = None, - last_name: str | None = None, - username: str | None = None, - profile_picture: UploadFile | None = None, - email_address: str | None = None, - password: str | None = None, - bio: str | None = None, - delete: bool = False, +@router.patch("/user", tags=["Users"], status_code=200, + responses={200: {"model": APIResponse}, + 400: {"model": BadRequest}, + 422: {"model": UnprocessableEntity}, + 500: {"model": InternalServerError}, + 413: {"model": PayloadTooLarge}}) +async def update_user( + request: Request, + user: UserModel = Depends(UNVERIFIED_MANAGER), + first_name: str | None = None, + last_name: str | None = None, + username: str | None = None, + profile_picture: UploadFile | None = None, + email_address: str | None = None, + password: str | None = None, + bio: str | None = None, + delete: bool = False, ): """ Updates a user's profile information. Parameters that are not specified are left unchanged unless @@ -676,28 +650,26 @@ async def update( """ if not delete and not any( - (first_name, last_name, username, profile_picture, email_address, bio, password) + (first_name, last_name, username, profile_picture, email_address, bio, password) ): raise HTTPException( status_code=400, detail="At least one value has to be specified" ) - result, msg = ( - await validate_user( - first_name, last_name, username, email_address, password, bio - ) + result, msg = await validate_user( + first_name, last_name, username, email_address, password, bio ) if not result: raise HTTPException(status_code=413, detail=f"Update failed: {msg}") orig_user = user.copy() if not delete: if first_name: - user["first_name"] = first_name + user.first_name = first_name if last_name: - user["last_name"] = last_name + user.last_name = last_name if username: - user["username"] = username + user.username = username if bio: - user["bio"] = bio + user.bio = bio if profile_picture: result, ext, media, digest = validate_profile_picture(profile_picture) if result is False: @@ -707,44 +679,44 @@ async def update( elif result is None: raise HTTPException(status_code=413, detail="The file is too large") elif ( - await ( - old_media := Media.select(Media.media_id) - .where(Media.media_id == digest) - .first() - ) - is None + await ( + old_media := Media.select(Media.media_id) + .where(Media.media_id == digest) + .first() + ) + is None ): # This media hasn't been already uploaded (either by this user or by someone # else), so we save it now. If it has been already uploaded, there's no need # to do it again (that's what the hash is for) match STORAGE_ENGINE: case "database": - await Media.insert( - Media( - media_id=digest, - media_type=MediaType.BLOB, - content_type=ext, - content=base64.b64encode(media), - ) + m = Media( + media_id=digest, + media_type=MediaType.BLOB, + content_type=ext, + content=base64.b64encode(media), ) + await Media.insert(m) + user.profile_picture = m case "local": file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest) file.touch(mode=0o644) with file.open("wb") as f: f.write(media) - await Media.insert( - Media( - media_id=digest, - media_type=MediaType.FILE, - content_type=ext, - content=file.as_posix(), - ) + m = Media( + media_id=digest, + media_type=MediaType.FILE, + content_type=ext, + content=file.as_posix(), ) + await Media.insert(m) + user.profile_picture = m case "url": pass # TODO: Use/implement CDN uploading else: - user["media"] = old_media - if password and not bcrypt.checkpw(password.encode(), user["password_hash"]): + user.media = old_media + if password and not bcrypt.checkpw(password.encode(), user.password_hash): email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json" try: email_template.resolve(strict=True) @@ -754,43 +726,48 @@ async def update( email_message = json.load(f)["password_change"] verification_id = uuid.uuid4() if not await send_email( - SMTP_HOST, - SMTP_PORT, - email_message["content"].format( - first_name=user["first_name"], - last_name=user["last_name"], - username=user["username"], - email=user["email_address"], - link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" - f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}", - platformName=PLATFORM_NAME, - ), - SMTP_TIMEOUT, - SMTP_FROM_USER, - user["email_address"], - email_message["subject"].format( - first_name=user["first_name"], - last_name=user["last_name"], - username=user["username"], - email=user["email_address"], - platformName=PLATFORM_NAME, - ), - SMTP_USER, - SMTP_PASSWORD, - use_tls=SMTP_USE_TLS, + SMTP_HOST, + SMTP_PORT, + email_message["content"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" + f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}", + platformName=PLATFORM_NAME, + ), + SMTP_TIMEOUT, + SMTP_FROM_USER, + user.email_address, + email_message["subject"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + platformName=PLATFORM_NAME, + ), + SMTP_USER, + SMTP_PASSWORD, + use_tls=SMTP_USE_TLS, ): - raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later") + raise HTTPException( + 500, + detail="An error occurred while trying to send mail, please try again later", + ) else: await EmailVerification.insert( EmailVerification( { EmailVerification.id: verification_id, - EmailVerification.user: User(public_id=user["id"]), - EmailVerification.data: bcrypt.hashpw(password.encode(), user["password_hash"][:29]) + EmailVerification.user: User(public_id=user.public_id), + EmailVerification.data: bcrypt.hashpw( + password.encode(), user.password_hash[:29] + ), } ) ) - if email_address and user["email_address"] != email_address: + if email_address and user.email_address != email_address: if not user["email_verified"]: raise HTTPException( status_code=403, @@ -805,51 +782,61 @@ async def update( email_message = json.load(f)["email_change"] verification_id = uuid.uuid4() if not await send_email( - SMTP_HOST, - SMTP_PORT, - email_message["content"].format( - first_name=user["first_name"], - last_name=user["last_name"], - username=user["username"], - email=user["email_address"], - platformName=PLATFORM_NAME, - link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" - f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}", - newMail=email_address, - ), - SMTP_TIMEOUT, - SMTP_FROM_USER, - user["email_address"], - email_message["subject"].format( - first_name=user["first_name"], - last_name=user["last_name"], - username=user["username"], - email=user["email_address"], - platformName=PLATFORM_NAME, - ), - SMTP_USER, - SMTP_PASSWORD, - use_tls=SMTP_USE_TLS, + SMTP_HOST, + SMTP_PORT, + email_message["content"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + platformName=PLATFORM_NAME, + link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" + f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}", + newMail=email_address, + ), + SMTP_TIMEOUT, + SMTP_FROM_USER, + user.email_address, + email_message["subject"].format( + first_name=user.first_name, + last_name=user.last_name, + username=user.username, + email=user.email_address, + platformName=PLATFORM_NAME, + ), + SMTP_USER, + SMTP_PASSWORD, + use_tls=SMTP_USE_TLS, ): - raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later") + raise HTTPException( + 500, + detail="An error occurred while trying to send mail, please try again later", + ) else: - await EmailVerification.insert(EmailVerification({ - EmailVerification.id: verification_id, - EmailVerification.user: User(public_id=user["id"]), - EmailVerification.data: email_address.encode() - })) + await EmailVerification.insert( + EmailVerification( + { + EmailVerification.id: verification_id, + EmailVerification.user: User(public_id=user.public_id), + EmailVerification.data: email_address.encode(), + } + ) + ) else: if not bio: - user["bio"] = None + user.bio = None if not profile_picture: user["profile_picture"] = None fields = [] for field in user: - if field not in ["email_address", "password"] and orig_user[field] != user[field]: + if ( + field not in ["email_address", "password"] + and orig_user[field] != user[field] + ): fields.append((field, user[field])) if fields: # If anything has changed, we update our info await User.update({field: value for field, value in fields}).where( - User.public_id == user["id"] + User.public_id == user.public_id ) - return {"status_code": 200, "msg": "Changes saved successfully"} + return APIResponse(status_code=200, msg="Changes saved successfully") diff --git a/main.py b/main.py index 8bb9483..538db9f 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from fastapi.exceptions import ( StarletteHTTPException, ) from slowapi.errors import RateLimitExceeded +from pathlib import Path from endpoints import users @@ -25,6 +26,7 @@ from config import ( HOST, PORT, WORKERS, + PLATFORM_NAME, ) from orm import create_tables, Media from util.exception_handlers import ( @@ -36,24 +38,55 @@ from util.exception_handlers import ( ) from util.email import test_smtp -app = FastAPI() + +with (Path(__file__).parent / "README.md").resolve(strict=True).open() as f: + description = f.read() -@app.get("/") +app = FastAPI( + title=PLATFORM_NAME, + license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"}, + contact={ + "name": "Matt", + "url": "https://nocturn9x.space", + "email": "nocturn9x@nocturn9x.space", + }, + version="0.0.1", + description=description, + openapi_tags=[ + { + "name": "Miscellaneous", + "description": "Simple endpoints that don't fit anywhere else", + }, + { + "name": "Users", + "description": "Endpoints that handle user-related operations such as" + " signing in, signing up, settings management and more", + }, + { + "name": "Posts", + "description": "Endpoints that handle post creation, modification and deletion", + }, + ], +) + + +@app.get("/", tags=["Miscellaneous"]) @LIMITER.limit("10/second") async def root(request: Request): raise HTTPException(401, detail="Unauthorized") -@app.get("/ping") +@app.get("/ping", tags=["Miscellaneous"]) @LIMITER.limit("1/minute") async def ping(request: Request) -> dict: """ - This handler simply replies to "ping" requests and + This method simply replies to "ping" requests and is used to check whether the API is up and running. It also performs a sanity check with the database and the SMTP server to ensure that they are functioning correctly. - For this reason, this endpoint's rate limit is much stricter + For this reason, this endpoint's rate limit is very strict: + it can only be called once per minute """ LOGGER.info(f"Processing ping request from {request.client.host}") diff --git a/orm/users.py b/orm/users.py index d5521a8..981e426 100644 --- a/orm/users.py +++ b/orm/users.py @@ -2,6 +2,7 @@ User relation """ +from typing import Any from piccolo.utils.pydantic import create_pydantic_model from piccolo.table import Table from piccolo.columns import ( @@ -15,6 +16,7 @@ from piccolo.columns import ( Boolean, Email, Bytea, + Column, ) from piccolo.columns.defaults.date import DateNow @@ -45,4 +47,82 @@ class User(Table, tablename="users"): locale = Varchar(length=12, default="en_US", null=False, secret=True) -UserModel = create_pydantic_model(User) +UserModel = create_pydantic_model( + User, + nested=True, +) +PublicUserModel = create_pydantic_model( + User, + nested=True, + exclude_columns=( + User.internal_id, + User.deleted, + User.restricted, + User.locale, + User.email_verified, + User.email_address, + User.password_hash, + ), + model_name="PublicUser", +) +PrivateUserModel = create_pydantic_model( + User, + nested=True, + exclude_columns=(User.internal_id, User.password_hash, User.deleted), + model_name="PrivateUser", +) + + +async def get_user_by_column( + column: Column, + data: Any, + include_secrets: bool = False, + restricted_ok: bool = False, + deleted_ok: bool = False, +) -> UserModel | None: + """ + Retrieves a user object by a given criteria. + Returns None if the user doesn't exist or + if it's restricted/deleted (unless restricted_ok + and deleted_ok are set accordingly) + """ + + user = ( + await User.select( + *User.all_columns(), + exclude_secrets=not include_secrets, + ) + .where(column == data) + .first() + ) + if user: + # Performs validation + user = UserModel(**user) + if (user.deleted and not deleted_ok) or (user.restricted and not restricted_ok): + return + return user + return + + +async def get_user_by_id(public_id: UUID, *args, **kwargs) -> UserModel: + """ + Retrieves a user by its public ID + """ + + return await get_user_by_column(User.public_id, public_id, *args, **kwargs) + + +async def get_user_by_username(username: str, *args, **kwargs) -> UserModel | None: + """ + Retrieves a user by its public username + """ + + return await get_user_by_column(User.username, username, *args, **kwargs) + + +async def get_user_by_email(email: str, *args, **kwargs) -> UserModel | None: + """ + Retrieves a user by its email address + """ + + return await get_user_by_column(User.email_address, email, *args, **kwargs) diff --git a/requirements.txt b/requirements.txt index 3b35fd3..e03cb5f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ uvloop~=0.17.0 fastapi~=0.85.0 validators aiosmtplib -uvicorn \ No newline at end of file +uvicorn +pydantic \ No newline at end of file diff --git a/responses/__init__.py b/responses/__init__.py new file mode 100644 index 0000000..fc45fd8 --- /dev/null +++ b/responses/__init__.py @@ -0,0 +1,96 @@ +from pydantic import BaseModel +from orm.users import User + + +class Response(BaseModel): + """ + A generic response model + """ + + status_code: int = 200 + msg: str = "Success" + + +class UnprocessableEntity(Response): + """ + A 422 Unprocessable Entity response + model + """ + + status_code: int = 422 + msg: str = "Input Validation Failure" + + +class NotFound(Response): + """ + A 404 Not Found response + model + """ + + status_code: int = 404 + msg: str = "Not Found" + + +class PayloadTooLarge(Response): + """ + A 413 Payload Too Large response + model + """ + + status_code: int = 413 + msg: str = "Payload too large" + + +class MediaTypeNotAcceptable(Response): + """ + A 415 Media Type Not Acceptable response + model + """ + + status_code: int = 415 + msg: str = "Media type not acceptable" + + +class BadRequest(Response): + """ + A 400 Bad Request response model + """ + + status_code: int = 400 + msg: str = "Bad Request" + + +class InternalServerError(Response): + """ + A 500 Internal Server Error response model + """ + + status_code: int = 500 + msg: str = "Internal Server Error" + + +class TooManyRequests(Response): + """ + A 429 Too Many Requests response model + """ + + status_code: int = 429 + msg: str = "Too Many Requests" + + +class Unauthorized(Response): + """ + A 401 Unauthorized response model + """ + + status_code: int = 401 + msg: str = "Unauthorized" + + +class Forbidden(Response): + """ + A 403 Forbidden response model + """ + + status_code: int = 403 + msg: str = "Forbidden" diff --git a/responses/users.py b/responses/users.py new file mode 100644 index 0000000..a38c433 --- /dev/null +++ b/responses/users.py @@ -0,0 +1,24 @@ +from responses import Response +from orm.users import PublicUserModel, PrivateUserModel + + +class PublicUserResponse(Response): + """ + A response sent by an endpoint that + replies with a public user object + """ + + status_code: int = 200 + msg: str = "Lookup successful" + data: PublicUserModel + + +class PrivateUserResponse(Response): + """ + A response sent by an endpoint that + replies with a private user object + """ + + status_code: int = 200 + data: PrivateUserModel + msg: str = "Lookup successful" diff --git a/util/exception_handlers.py b/util/exception_handlers.py index ca89dd2..3f2963d 100644 --- a/util/exception_handlers.py +++ b/util/exception_handlers.py @@ -7,6 +7,9 @@ from slowapi.errors import RateLimitExceeded async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse: + """ + Handles the equivalent of a 429 Too Many Requests error + """ n = 0 while True: if error.detail[n].isnumeric(): @@ -18,79 +21,58 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon f"{request.client.host} got rate-limited at {str(request.url)} " f"(exceeded {error.detail})" ) - return JSONResponse( - status_code=200, - content={ - "msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}", - "status_code": 429, - }, - ) + return JSONResponse(status_code=200, content=dict(status_code=429, + msg=f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}")) def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse: + """ + Handles the equivalent of a 401 Unauthorized exception + """ + LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}") - return JSONResponse( - status_code=200, - content={ - "msg": "Authentication is required", - "status_code": 401, - }, - ) + return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required")) def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse: + """ + Handles Bad Request exceptions from FastAPI + """ + LOGGER.info( f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}" ) - return JSONResponse( - status_code=200, - content={ - "msg": f"Bad request: {type(exc).__name__}: {exc}", - "status_code": 400, - }, - ) + return JSONResponse(status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}")) def http_exception( request: Request, exc: HTTPException | StarletteHTTPException ) -> JSONResponse: + """ + Handles HTTP-specific exceptions raised explicitly by + path operations + """ + if exc.status_code >= 500: LOGGER.error( f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{type(exc).__name__}: {exc}" ) - return JSONResponse( - status_code=200, - content={ - "msg": "Internal server error", - "status_code": exc.status_code, - }, - ) + return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error")) else: LOGGER.info( f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}" ) - return JSONResponse( - status_code=200, - content={ - "msg": exc.detail, - "status_code": exc.status_code, - }, - ) + return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail)) async def generic_error(request: Request, exc: Exception) -> JSONResponse: """ - Handles generic, unexpected errors in the application + Handles generic, unexpected errors in the ASGI application """ LOGGER.info( f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}" ) - return JSONResponse( - status_code=200, - content={ - "msg": "Internal Server Error", # We can't leak anything about the error, it would be too risky - "status_code": 500, - }, - ) + # We can't leak anything about the error, it would be too risky + return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))