Large rework of the application for better documentation support

This commit is contained in:
Mattia Giambirtone 2022-10-05 19:52:37 +02:00
parent f0ef827577
commit 1c05716345
7 changed files with 597 additions and 394 deletions

View File

@ -6,10 +6,8 @@ import re
import imghdr import imghdr
import uuid import uuid
import zlib import zlib
import bcrypt import bcrypt
from uuid import UUID from uuid import UUID
import validators import validators
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
@ -52,47 +50,31 @@ from config import (
FORCE_EMAIL_VERIFICATION, FORCE_EMAIL_VERIFICATION,
UNVERIFIED_MANAGER, UNVERIFIED_MANAGER,
) )
from orm.users import UserModel, User from responses import Response as APIResponse, UnprocessableEntity, BadRequest, NotFound, \
MediaTypeNotAcceptable, PayloadTooLarge, InternalServerError
from responses.users import (
PrivateUserResponse,
PublicUserResponse,
PrivateUserModel,
PublicUserModel,
)
from orm.users import (
User,
UserModel,
get_user_by_username,
get_user_by_id,
get_user_by_email,
)
from orm.media import Media, MediaType from orm.media import Media, MediaType
from orm.email_verification import EmailVerification from orm.email_verification import EmailVerification
from util.email import send_email from util.email import send_email
router = FastAPI() router = FastAPI()
async def get_user_by_id( # Credential loaders for our authenticated routes
public_id: UUID,
include_secrets: bool = False,
restricted_ok: bool = False,
deleted_ok: bool = False,
) -> dict | None:
"""
Retrieves a user by its public ID
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.public_id == public_id)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (
user["restricted"] and not restricted_ok
):
return
return user
return
@MANAGER.user_loader() @MANAGER.user_loader()
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dict: async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True) user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]: if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]:
raise HTTPException(status_code=401, detail="Email verification is required") raise HTTPException(status_code=401, detail="Email verification is required")
@ -100,76 +82,22 @@ async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dic
@UNVERIFIED_MANAGER.user_loader() @UNVERIFIED_MANAGER.user_loader()
async def get_self_by_id_unverified(public_id: UUID): async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
return await get_self_by_id(public_id, False) return await get_self_by_id(public_id, False)
async def get_user_by_username( # Here follow our *beautifully* documented path operations
username: str,
include_secrets: bool = False,
restricted_ok: bool = False,
deleted_ok: bool = False,
) -> dict | None:
"""
Retrieves a user by its public username
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.username == username)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (
user["restricted"] and not restricted_ok
):
return
return user
return
async def get_user_by_email(
email: str,
include_secrets: bool = False,
restricted_ok: bool = False,
deleted_ok: bool = False,
) -> dict | None:
"""
Retrieves a user by its email address (meant to
be used internally)
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.email_address == email)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (
user["restricted"] and not restricted_ok
):
return
return user
return
@router.post("/user", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity}})
@LIMITER.limit("5/minute") @LIMITER.limit("5/minute")
@router.post("/user") async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()):
async def login( """
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends() Performs user authentication. Endpoint is limited to 5 hits per minute
) -> dict: """
if request.cookies.get(SESSION_COOKIE_NAME): if request.cookies.get(SESSION_COOKIE_NAME):
raise HTTPException(status_code=400, detail="Please logout first") raise HTTPException(status_code=400, detail="Please logout first")
username = data.username username = data.username
@ -192,21 +120,21 @@ async def login(
detail="Authentication failed: invalid characters in password", detail="Authentication failed: invalid characters in password",
) )
if not ( if not (
user := await get_user_by_username( user := await get_user_by_username(
username, include_secrets=True, restricted_ok=True username, include_secrets=True, restricted_ok=True
) )
): ):
raise HTTPException( raise HTTPException(
status_code=413, status_code=413,
detail="Authentication failed: the user does not exist", detail="Authentication failed: the user does not exist",
) )
if not bcrypt.checkpw(password, user["password_hash"]): if not bcrypt.checkpw(password, user.password_hash):
raise HTTPException( raise HTTPException(
status_code=413, status_code=413,
detail="Authentication failed: password mismatch", detail="Authentication failed: password mismatch",
) )
token = MANAGER.create_access_token( token = MANAGER.create_access_token(
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])} expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user.public_id)}
) )
response.set_cookie( response.set_cookie(
secure=SECURE_COOKIE, secure=SECURE_COOKIE,
@ -218,17 +146,17 @@ async def login(
domain=COOKIE_DOMAIN or None, domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/", path=COOKIE_PATH or "/",
) )
return {"status_code": 200, "msg": "Authentication successful"} return APIResponse(status_code=200, msg="Authentication successful")
@router.get("/user/logout") @router.get("/user/logout", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
422: {"model": UnprocessableEntity}})
@LIMITER.limit("5/minute") @LIMITER.limit("5/minute")
async def logout( async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)):
request: Request, response: Response, _user: dict = Depends(UNVERIFIED_MANAGER)
) -> dict:
""" """
Deletes a user's session cookie, logging them Deletes a user's session cookie, logging them
out out. Endpoint is limited to 5 hits per minute
""" """
response.delete_cookie( response.delete_cookie(
@ -239,70 +167,81 @@ async def logout(
domain=COOKIE_DOMAIN or None, domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/", path=COOKIE_PATH or "/",
) )
return {"status_code": 200, "msg": "Logged out"} return APIResponse(status_code=200, msg="Logged out")
@router.get("/user/me") @router.get(
"/user/me",
tags=["Users"],
status_code=200,
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
422: {"model": UnprocessableEntity}},
)
@LIMITER.limit("2/second") @LIMITER.limit("2/second")
async def get_self(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict: async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
""" """
Fetches a user's own info. This returns some Fetches a user's own info. This returns some
extra data such as email address, account extra data such as email address, account
creation date and email verification status, creation date and email verification status,
which is not available from the regular endpoint which is not available from the regular endpoint.
Endpoint is limited to 2 hits per second
""" """
user.pop("password_hash") return PrivateUserResponse(status_code=200, msg="Success", data=user)
user.pop("internal_id")
user.pop("deleted")
return {"status_code": 200, "msg": "Success", "data": user}
@router.get("/user/username/{username}") @router.get(
"/user/username/{username}",
tags=["Users"],
status_code=200,
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
404: {"model": NotFound},
422: {"model": UnprocessableEntity}
},
)
@LIMITER.limit("30/second") @LIMITER.limit("30/second")
async def get_user_by_name( async def get_user_by_name(
request: Request, username: str, _auth: dict = Depends(MANAGER) request: Request, username: str, _auth: UserModel = Depends(MANAGER)
) -> dict: ):
""" """
Fetches a single user by its public ID Fetches a single user by its public username
""" """
if not (user := await get_user_by_username(username)): if not (user := await get_user_by_username(username)):
return { return NotFound(msg="Lookup failed: the user does not exist")
"status_code": 404, return PublicUserResponse(status_code=200, msg="Lookup successful", data=user)
"msg": "Lookup failed: the user does not exist",
}
user.pop("restricted")
user.pop("deleted")
return {"status_code": 200, "msg": "Lookup successful", "data": user}
@router.get("/user/id/{public_id}") @router.get(
"/user/id/{public_id}",
tags=["Users"],
status_code=200,
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
404: {"model": NotFound},
422: {"model": UnprocessableEntity}
},
)
@LIMITER.limit("30/second") @LIMITER.limit("30/second")
async def get_user_by_public_id( async def get_user_by_public_id(
request: Request, public_id: str, _auth: dict = Depends(MANAGER) request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)
) -> dict: ):
""" """
Fetches a single user by its public ID Fetches a single user by its public ID
""" """
if not (user := await get_user_by_id(UUID(public_id))): if not (user := await get_user_by_id(UUID(public_id))):
raise HTTPException( raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist")
status_code=404, detail="Lookup failed: the user does not exist" return PublicUserResponse(data=user)
)
user.pop("restricted")
user.pop("deleted")
return {"status_code": 200, "msg": "Lookup successful", "data": user}
async def validate_user( async def validate_user(
first_name: str | None, first_name: str | None,
last_name: str | None, last_name: str | None,
username: str | None, username: str | None,
email: str | None, email: str | None,
password: str | None, password: str | None,
bio: str | None, bio: str | None,
) -> tuple[bool, str]: ):
""" """
Performs some validation upon user creation. Returns Performs some validation upon user creation. Returns
a tuple (success, msg) to be used by routes. Values a tuple (success, msg) to be used by routes. Values
@ -322,9 +261,9 @@ async def validate_user(
if username and len(username) > 32: if username and len(username) > 32:
return False, "username is too long" return False, "username is too long"
if ( if (
username username
and VALIDATE_USERNAME_REGEX and VALIDATE_USERNAME_REGEX
and not re.match(VALIDATE_USERNAME_REGEX, username) and not re.match(VALIDATE_USERNAME_REGEX, username)
): ):
return False, "username is invalid" return False, "username is invalid"
if email and not validators.email(email): if email and not validators.email(email):
@ -332,13 +271,13 @@ async def validate_user(
if password and len(password) > 72: if password and len(password) > 72:
return False, "password is too long" return False, "password is too long"
if ( if (
password password
and VALIDATE_PASSWORD_REGEX and VALIDATE_PASSWORD_REGEX
and not re.match(VALIDATE_PASSWORD_REGEX, password) and not re.match(VALIDATE_PASSWORD_REGEX, password)
): ):
return False, "password is too weak" return False, "password is too weak"
if username and await get_user_by_username( if username and await get_user_by_username(
username, deleted_ok=True, restricted_ok=True username, deleted_ok=True, restricted_ok=True
): ):
return False, "username is already taken" return False, "username is already taken"
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
@ -353,18 +292,20 @@ async def validate_user(
return True, "" return True, ""
@router.delete("/user") @router.delete("/user", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
422: {"model": UnprocessableEntity}})
@LIMITER.limit("1/minute") @LIMITER.limit("1/minute")
async def delete( async def delete(
request: Request, response: Response, user: dict = Depends(UNVERIFIED_MANAGER) request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)
) -> dict: ):
""" """
Sets the user's deleted flag in the database, Sets the user's deleted flag in the database,
without actually deleting the associated without actually deleting the associated
data data
""" """
await User.update({User.deleted: True}).where(User.public_id == user["id"]) await User.update({User.deleted: True}).where(User.public_id == user.public_id)
response.delete_cookie( response.delete_cookie(
secure=SECURE_COOKIE, secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME, key=SESSION_COOKIE_NAME,
@ -373,22 +314,28 @@ async def delete(
domain=COOKIE_DOMAIN or None, domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/", path=COOKIE_PATH or "/",
) )
return {"status_code": 200, "msg": "Success"} return APIResponse(status_code=200, msg="Success")
@router.get("/user/verifyEmail/{verification_id}") @router.get("/user/verifyEmail/{verification_id}", tags=["Users"], status_code=200,
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
404: {"model": NotFound},
422: {"model": UnprocessableEntity}
})
@LIMITER.limit("3/second") @LIMITER.limit("3/second")
async def verify_email( async def verify_email(
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) request: Request,
) -> dict: verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
""" """
Verifies a user's email address Verifies a user's email address
""" """
if not ( if not (
verification := await EmailVerification.select(*EmailVerification.all_columns()) verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id) .where(EmailVerification.id == verification_id)
.first() .first()
): ):
raise HTTPException(status_code=404, detail="Verification ID is invalid") raise HTTPException(status_code=404, detail="Verification ID is invalid")
elif not verification["pending"]: elif not verification["pending"]:
@ -402,27 +349,31 @@ async def verify_email(
) )
else: else:
await EmailVerification.update({EmailVerification.pending: False}).where( await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user["id"] EmailVerification.user == user.public_id
) )
await User.update({User.email_verified: True}).where( await User.update({User.email_verified: True}).where(User.public_id == user.public_id)
User.public_id == user["id"] return APIResponse(status_code=200, msg="Verification successful")
)
return {"status_code": 200, "msg": "Verification successful"}
@router.get("/user/resetPassword/{verification_id}") @router.get("/user/resetPassword/{verification_id}", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound}})
@LIMITER.limit("3/second") @LIMITER.limit("3/second")
async def reset_password( async def reset_password(
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) request: Request,
) -> dict: verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
""" """
Modifies a user's password Modifies a user's password
""" """
if not ( if not (
verification := await EmailVerification.select(*EmailVerification.all_columns()) verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id) .where(EmailVerification.id == verification_id)
.first() .first()
): ):
raise HTTPException(status_code=404, detail="Request ID is invalid") raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]: elif not verification["pending"]:
@ -436,27 +387,33 @@ async def reset_password(
) )
else: else:
await EmailVerification.update({EmailVerification.pending: False}).where( await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user["id"] EmailVerification.user == user.public_id
) )
await User.update({User.password_hash: verification["data"]}).where( await User.update({User.password_hash: verification["data"]}).where(
User.public_id == user["id"] User.public_id == user.public_id
) )
return {"status_code": 200, "msg": "Password updated"} return APIResponse(status_code=200, msg="Password updated")
@router.get("/user/changeEmail/{verification_id}") @router.get("/user/changeEmail/{verification_id}", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound}})
@LIMITER.limit("3/second") @LIMITER.limit("3/second")
async def change_email( async def change_email(
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER) request: Request,
) -> dict: verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
""" """
Modifies a user's email Modifies a user's email
""" """
if not ( if not (
verification := await EmailVerification.select(*EmailVerification.all_columns()) verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id) .where(EmailVerification.id == verification_id)
.first() .first()
): ):
raise HTTPException(status_code=404, detail="Request ID is invalid") raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]: elif not verification["pending"]:
@ -470,18 +427,24 @@ async def change_email(
) )
else: else:
await EmailVerification.update({EmailVerification.pending: False}).where( await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user["id"] EmailVerification.user == user.public_id
) )
await User.update({User.email_address: verification["data"].decode(), await User.update(
User.email_verified: False}).where( {
User.public_id == user["id"] User.email_address: verification["data"].decode(),
) User.email_verified: False,
return {"status_code": 200, "msg": "Email updated"} }
).where(User.public_id == user.public_id)
return APIResponse(status_code=200, msg="Email updated")
@router.put("user/resendMail") @router.put("user/resendMail", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
500: {"model": InternalServerError}})
@LIMITER.limit("6/minute") @LIMITER.limit("6/minute")
async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict: async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
""" """
Resends the verification email to the user if the previous has expired Resends the verification email to the user if the previous has expired
""" """
@ -523,26 +486,34 @@ async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER
use_tls=SMTP_USE_TLS, use_tls=SMTP_USE_TLS,
): ):
await EmailVerification.update( await EmailVerification.update(
{ {
EmailVerification.id: verification_id, EmailVerification.id: verification_id,
EmailVerification.creation_date: datetime.now() EmailVerification.creation_date: datetime.now(),
} }
) )
return {"status_code": 200, "msg": "Success"} return APIResponse(status_code=200, msg="Success")
else:
raise HTTPException(status_code=500, detail="An error occurred while trying to resend the email,"
" please try again later")
@router.put("/user") @router.put("/user", tags=["Users"], status_code=200,
responses={200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
500: {"model": InternalServerError},
413: {"model": PayloadTooLarge}})
@LIMITER.limit("2/minute") @LIMITER.limit("2/minute")
async def signup( async def signup(
request: Request, request: Request,
first_name: str, first_name: str,
last_name: str, last_name: str,
username: str, username: str,
email: str, email: str,
password: str, password: str,
bio: str | None = None, bio: str | None = None,
locale: str = "en_US", locale: str = "en_US",
) -> dict: ):
""" """
Endpoint used to create new users Endpoint used to create new users
""" """
@ -555,7 +526,7 @@ async def signup(
first_name, last_name, username, email, password, bio first_name, last_name, username, email, password, bio
) )
if not result: if not result:
return {"status_code": 413, "msg": f"Signup failed: {msg}"} return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
else: else:
salt = bcrypt.gensalt(BCRYPT_ROUNDS) salt = bcrypt.gensalt(BCRYPT_ROUNDS)
user = User( user = User(
@ -563,9 +534,7 @@ async def signup(
last_name=last_name, last_name=last_name,
username=username, username=username,
email_address=email, email_address=email,
password_hash=bcrypt.hashpw( password_hash=bcrypt.hashpw(password.encode(), salt),
password.encode(), salt
),
bio=bio, bio=bio,
) )
email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json" email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json"
@ -577,30 +546,30 @@ async def signup(
email_message = json.load(f)["signup"] email_message = json.load(f)["signup"]
verification_id = uuid.uuid4() verification_id = uuid.uuid4()
if await send_email( if await send_email(
SMTP_HOST, SMTP_HOST,
SMTP_PORT, SMTP_PORT,
email_message["content"].format( email_message["content"].format(
first_name=first_name, first_name=first_name,
last_name=last_name, last_name=last_name,
username=username, username=username,
email=email, email=email,
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}", f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
), ),
SMTP_TIMEOUT, SMTP_TIMEOUT,
SMTP_FROM_USER, SMTP_FROM_USER,
email, email,
email_message["subject"].format( email_message["subject"].format(
first_name=first_name, first_name=first_name,
last_name=last_name, last_name=last_name,
username=username, username=username,
email=email, email=email,
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
), ),
SMTP_USER, SMTP_USER,
SMTP_PASSWORD, SMTP_PASSWORD,
use_tls=SMTP_USE_TLS, use_tls=SMTP_USE_TLS,
): ):
await User.insert(user) await User.insert(user)
await EmailVerification.insert( await EmailVerification.insert(
@ -611,17 +580,17 @@ async def signup(
} }
) )
) )
return {"status_code": 200, "msg": "Success"} return APIResponse(status_code=200, msg="Success")
else: else:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="An error occurred while sending verification email, please" detail="An error occurred while sending verification email, please"
" try again later", " try again later",
) )
async def validate_profile_picture( async def validate_profile_picture(
file: UploadFile, file: UploadFile,
) -> tuple[bool | None, str, bytes, str]: ) -> tuple[bool | None, str, bytes, str]:
""" """
Validates a profile picture's size and content to see if it fits Validates a profile picture's size and content to see if it fits
@ -650,18 +619,23 @@ async def validate_profile_picture(
return None, "", b"", "" return None, "", b"", ""
@router.patch("/user") @router.patch("/user", tags=["Users"], status_code=200,
async def update( responses={200: {"model": APIResponse},
request: Request, 400: {"model": BadRequest},
user: dict = Depends(UNVERIFIED_MANAGER), 422: {"model": UnprocessableEntity},
first_name: str | None = None, 500: {"model": InternalServerError},
last_name: str | None = None, 413: {"model": PayloadTooLarge}})
username: str | None = None, async def update_user(
profile_picture: UploadFile | None = None, request: Request,
email_address: str | None = None, user: UserModel = Depends(UNVERIFIED_MANAGER),
password: str | None = None, first_name: str | None = None,
bio: str | None = None, last_name: str | None = None,
delete: bool = False, username: str | None = None,
profile_picture: UploadFile | None = None,
email_address: str | None = None,
password: str | None = None,
bio: str | None = None,
delete: bool = False,
): ):
""" """
Updates a user's profile information. Parameters that are not specified are left unchanged unless Updates a user's profile information. Parameters that are not specified are left unchanged unless
@ -676,28 +650,26 @@ async def update(
""" """
if not delete and not any( if not delete and not any(
(first_name, last_name, username, profile_picture, email_address, bio, password) (first_name, last_name, username, profile_picture, email_address, bio, password)
): ):
raise HTTPException( raise HTTPException(
status_code=400, detail="At least one value has to be specified" status_code=400, detail="At least one value has to be specified"
) )
result, msg = ( result, msg = await validate_user(
await validate_user( first_name, last_name, username, email_address, password, bio
first_name, last_name, username, email_address, password, bio
)
) )
if not result: if not result:
raise HTTPException(status_code=413, detail=f"Update failed: {msg}") raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
orig_user = user.copy() orig_user = user.copy()
if not delete: if not delete:
if first_name: if first_name:
user["first_name"] = first_name user.first_name = first_name
if last_name: if last_name:
user["last_name"] = last_name user.last_name = last_name
if username: if username:
user["username"] = username user.username = username
if bio: if bio:
user["bio"] = bio user.bio = bio
if profile_picture: if profile_picture:
result, ext, media, digest = validate_profile_picture(profile_picture) result, ext, media, digest = validate_profile_picture(profile_picture)
if result is False: if result is False:
@ -707,44 +679,44 @@ async def update(
elif result is None: elif result is None:
raise HTTPException(status_code=413, detail="The file is too large") raise HTTPException(status_code=413, detail="The file is too large")
elif ( elif (
await ( await (
old_media := Media.select(Media.media_id) old_media := Media.select(Media.media_id)
.where(Media.media_id == digest) .where(Media.media_id == digest)
.first() .first()
) )
is None is None
): ):
# This media hasn't been already uploaded (either by this user or by someone # This media hasn't been already uploaded (either by this user or by someone
# else), so we save it now. If it has been already uploaded, there's no need # else), so we save it now. If it has been already uploaded, there's no need
# to do it again (that's what the hash is for) # to do it again (that's what the hash is for)
match STORAGE_ENGINE: match STORAGE_ENGINE:
case "database": case "database":
await Media.insert( m = Media(
Media( media_id=digest,
media_id=digest, media_type=MediaType.BLOB,
media_type=MediaType.BLOB, content_type=ext,
content_type=ext, content=base64.b64encode(media),
content=base64.b64encode(media),
)
) )
await Media.insert(m)
user.profile_picture = m
case "local": case "local":
file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest) file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest)
file.touch(mode=0o644) file.touch(mode=0o644)
with file.open("wb") as f: with file.open("wb") as f:
f.write(media) f.write(media)
await Media.insert( m = Media(
Media( media_id=digest,
media_id=digest, media_type=MediaType.FILE,
media_type=MediaType.FILE, content_type=ext,
content_type=ext, content=file.as_posix(),
content=file.as_posix(),
)
) )
await Media.insert(m)
user.profile_picture = m
case "url": case "url":
pass # TODO: Use/implement CDN uploading pass # TODO: Use/implement CDN uploading
else: else:
user["media"] = old_media user.media = old_media
if password and not bcrypt.checkpw(password.encode(), user["password_hash"]): if password and not bcrypt.checkpw(password.encode(), user.password_hash):
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json" email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json"
try: try:
email_template.resolve(strict=True) email_template.resolve(strict=True)
@ -754,43 +726,48 @@ async def update(
email_message = json.load(f)["password_change"] email_message = json.load(f)["password_change"]
verification_id = uuid.uuid4() verification_id = uuid.uuid4()
if not await send_email( if not await send_email(
SMTP_HOST, SMTP_HOST,
SMTP_PORT, SMTP_PORT,
email_message["content"].format( email_message["content"].format(
first_name=user["first_name"], first_name=user.first_name,
last_name=user["last_name"], last_name=user.last_name,
username=user["username"], username=user.username,
email=user["email_address"], email=user.email_address,
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}", f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}",
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
), ),
SMTP_TIMEOUT, SMTP_TIMEOUT,
SMTP_FROM_USER, SMTP_FROM_USER,
user["email_address"], user.email_address,
email_message["subject"].format( email_message["subject"].format(
first_name=user["first_name"], first_name=user.first_name,
last_name=user["last_name"], last_name=user.last_name,
username=user["username"], username=user.username,
email=user["email_address"], email=user.email_address,
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
), ),
SMTP_USER, SMTP_USER,
SMTP_PASSWORD, SMTP_PASSWORD,
use_tls=SMTP_USE_TLS, use_tls=SMTP_USE_TLS,
): ):
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later") raise HTTPException(
500,
detail="An error occurred while trying to send mail, please try again later",
)
else: else:
await EmailVerification.insert( await EmailVerification.insert(
EmailVerification( EmailVerification(
{ {
EmailVerification.id: verification_id, EmailVerification.id: verification_id,
EmailVerification.user: User(public_id=user["id"]), EmailVerification.user: User(public_id=user.public_id),
EmailVerification.data: bcrypt.hashpw(password.encode(), user["password_hash"][:29]) EmailVerification.data: bcrypt.hashpw(
password.encode(), user.password_hash[:29]
),
} }
) )
) )
if email_address and user["email_address"] != email_address: if email_address and user.email_address != email_address:
if not user["email_verified"]: if not user["email_verified"]:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
@ -805,51 +782,61 @@ async def update(
email_message = json.load(f)["email_change"] email_message = json.load(f)["email_change"]
verification_id = uuid.uuid4() verification_id = uuid.uuid4()
if not await send_email( if not await send_email(
SMTP_HOST, SMTP_HOST,
SMTP_PORT, SMTP_PORT,
email_message["content"].format( email_message["content"].format(
first_name=user["first_name"], first_name=user.first_name,
last_name=user["last_name"], last_name=user.last_name,
username=user["username"], username=user.username,
email=user["email_address"], email=user.email_address,
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}" link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}", f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}",
newMail=email_address, newMail=email_address,
), ),
SMTP_TIMEOUT, SMTP_TIMEOUT,
SMTP_FROM_USER, SMTP_FROM_USER,
user["email_address"], user.email_address,
email_message["subject"].format( email_message["subject"].format(
first_name=user["first_name"], first_name=user.first_name,
last_name=user["last_name"], last_name=user.last_name,
username=user["username"], username=user.username,
email=user["email_address"], email=user.email_address,
platformName=PLATFORM_NAME, platformName=PLATFORM_NAME,
), ),
SMTP_USER, SMTP_USER,
SMTP_PASSWORD, SMTP_PASSWORD,
use_tls=SMTP_USE_TLS, use_tls=SMTP_USE_TLS,
): ):
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later") raise HTTPException(
500,
detail="An error occurred while trying to send mail, please try again later",
)
else: else:
await EmailVerification.insert(EmailVerification({ await EmailVerification.insert(
EmailVerification.id: verification_id, EmailVerification(
EmailVerification.user: User(public_id=user["id"]), {
EmailVerification.data: email_address.encode() EmailVerification.id: verification_id,
})) EmailVerification.user: User(public_id=user.public_id),
EmailVerification.data: email_address.encode(),
}
)
)
else: else:
if not bio: if not bio:
user["bio"] = None user.bio = None
if not profile_picture: if not profile_picture:
user["profile_picture"] = None user["profile_picture"] = None
fields = [] fields = []
for field in user: for field in user:
if field not in ["email_address", "password"] and orig_user[field] != user[field]: if (
field not in ["email_address", "password"]
and orig_user[field] != user[field]
):
fields.append((field, user[field])) fields.append((field, user[field]))
if fields: if fields:
# If anything has changed, we update our info # If anything has changed, we update our info
await User.update({field: value for field, value in fields}).where( await User.update({field: value for field, value in fields}).where(
User.public_id == user["id"] User.public_id == user.public_id
) )
return {"status_code": 200, "msg": "Changes saved successfully"} return APIResponse(status_code=200, msg="Changes saved successfully")

43
main.py
View File

@ -9,6 +9,7 @@ from fastapi.exceptions import (
StarletteHTTPException, StarletteHTTPException,
) )
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from pathlib import Path
from endpoints import users from endpoints import users
@ -25,6 +26,7 @@ from config import (
HOST, HOST,
PORT, PORT,
WORKERS, WORKERS,
PLATFORM_NAME,
) )
from orm import create_tables, Media from orm import create_tables, Media
from util.exception_handlers import ( from util.exception_handlers import (
@ -36,24 +38,55 @@ from util.exception_handlers import (
) )
from util.email import test_smtp from util.email import test_smtp
app = FastAPI()
with (Path(__file__).parent / "README.md").resolve(strict=True).open() as f:
description = f.read()
@app.get("/") app = FastAPI(
title=PLATFORM_NAME,
license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"},
contact={
"name": "Matt",
"url": "https://nocturn9x.space",
"email": "nocturn9x@nocturn9x.space",
},
version="0.0.1",
description=description,
openapi_tags=[
{
"name": "Miscellaneous",
"description": "Simple endpoints that don't fit anywhere else",
},
{
"name": "Users",
"description": "Endpoints that handle user-related operations such as"
" signing in, signing up, settings management and more",
},
{
"name": "Posts",
"description": "Endpoints that handle post creation, modification and deletion",
},
],
)
@app.get("/", tags=["Miscellaneous"])
@LIMITER.limit("10/second") @LIMITER.limit("10/second")
async def root(request: Request): async def root(request: Request):
raise HTTPException(401, detail="Unauthorized") raise HTTPException(401, detail="Unauthorized")
@app.get("/ping") @app.get("/ping", tags=["Miscellaneous"])
@LIMITER.limit("1/minute") @LIMITER.limit("1/minute")
async def ping(request: Request) -> dict: async def ping(request: Request) -> dict:
""" """
This handler simply replies to "ping" requests and This method simply replies to "ping" requests and
is used to check whether the API is up and running. is used to check whether the API is up and running.
It also performs a sanity check with the database and It also performs a sanity check with the database and
the SMTP server to ensure that they are functioning correctly. the SMTP server to ensure that they are functioning correctly.
For this reason, this endpoint's rate limit is much stricter For this reason, this endpoint's rate limit is very strict:
it can only be called once per minute
""" """
LOGGER.info(f"Processing ping request from {request.client.host}") LOGGER.info(f"Processing ping request from {request.client.host}")

View File

@ -2,6 +2,7 @@
User relation User relation
""" """
from typing import Any
from piccolo.utils.pydantic import create_pydantic_model from piccolo.utils.pydantic import create_pydantic_model
from piccolo.table import Table from piccolo.table import Table
from piccolo.columns import ( from piccolo.columns import (
@ -15,6 +16,7 @@ from piccolo.columns import (
Boolean, Boolean,
Email, Email,
Bytea, Bytea,
Column,
) )
from piccolo.columns.defaults.date import DateNow from piccolo.columns.defaults.date import DateNow
@ -45,4 +47,82 @@ class User(Table, tablename="users"):
locale = Varchar(length=12, default="en_US", null=False, secret=True) locale = Varchar(length=12, default="en_US", null=False, secret=True)
UserModel = create_pydantic_model(User) UserModel = create_pydantic_model(
User,
nested=True,
)
PublicUserModel = create_pydantic_model(
User,
nested=True,
exclude_columns=(
User.internal_id,
User.deleted,
User.restricted,
User.locale,
User.email_verified,
User.email_address,
User.password_hash,
),
model_name="PublicUser",
)
PrivateUserModel = create_pydantic_model(
User,
nested=True,
exclude_columns=(User.internal_id, User.password_hash, User.deleted),
model_name="PrivateUser",
)
async def get_user_by_column(
column: Column,
data: Any,
include_secrets: bool = False,
restricted_ok: bool = False,
deleted_ok: bool = False,
) -> UserModel | None:
"""
Retrieves a user object by a given criteria.
Returns None if the user doesn't exist or
if it's restricted/deleted (unless restricted_ok
and deleted_ok are set accordingly)
"""
user = (
await User.select(
*User.all_columns(),
exclude_secrets=not include_secrets,
)
.where(column == data)
.first()
)
if user:
# Performs validation
user = UserModel(**user)
if (user.deleted and not deleted_ok) or (user.restricted and not restricted_ok):
return
return user
return
async def get_user_by_id(public_id: UUID, *args, **kwargs) -> UserModel:
"""
Retrieves a user by its public ID
"""
return await get_user_by_column(User.public_id, public_id, *args, **kwargs)
async def get_user_by_username(username: str, *args, **kwargs) -> UserModel | None:
"""
Retrieves a user by its public username
"""
return await get_user_by_column(User.username, username, *args, **kwargs)
async def get_user_by_email(email: str, *args, **kwargs) -> UserModel | None:
"""
Retrieves a user by its email address
"""
return await get_user_by_column(User.email_address, email, *args, **kwargs)

View File

@ -7,4 +7,5 @@ uvloop~=0.17.0
fastapi~=0.85.0 fastapi~=0.85.0
validators validators
aiosmtplib aiosmtplib
uvicorn uvicorn
pydantic

96
responses/__init__.py Normal file
View File

@ -0,0 +1,96 @@
from pydantic import BaseModel
from orm.users import User
class Response(BaseModel):
"""
A generic response model
"""
status_code: int = 200
msg: str = "Success"
class UnprocessableEntity(Response):
"""
A 422 Unprocessable Entity response
model
"""
status_code: int = 422
msg: str = "Input Validation Failure"
class NotFound(Response):
"""
A 404 Not Found response
model
"""
status_code: int = 404
msg: str = "Not Found"
class PayloadTooLarge(Response):
"""
A 413 Payload Too Large response
model
"""
status_code: int = 413
msg: str = "Payload too large"
class MediaTypeNotAcceptable(Response):
"""
A 415 Media Type Not Acceptable response
model
"""
status_code: int = 415
msg: str = "Media type not acceptable"
class BadRequest(Response):
"""
A 400 Bad Request response model
"""
status_code: int = 400
msg: str = "Bad Request"
class InternalServerError(Response):
"""
A 500 Internal Server Error response model
"""
status_code: int = 500
msg: str = "Internal Server Error"
class TooManyRequests(Response):
"""
A 429 Too Many Requests response model
"""
status_code: int = 429
msg: str = "Too Many Requests"
class Unauthorized(Response):
"""
A 401 Unauthorized response model
"""
status_code: int = 401
msg: str = "Unauthorized"
class Forbidden(Response):
"""
A 403 Forbidden response model
"""
status_code: int = 403
msg: str = "Forbidden"

24
responses/users.py Normal file
View File

@ -0,0 +1,24 @@
from responses import Response
from orm.users import PublicUserModel, PrivateUserModel
class PublicUserResponse(Response):
"""
A response sent by an endpoint that
replies with a public user object
"""
status_code: int = 200
msg: str = "Lookup successful"
data: PublicUserModel
class PrivateUserResponse(Response):
"""
A response sent by an endpoint that
replies with a private user object
"""
status_code: int = 200
data: PrivateUserModel
msg: str = "Lookup successful"

View File

@ -7,6 +7,9 @@ from slowapi.errors import RateLimitExceeded
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse: async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
"""
Handles the equivalent of a 429 Too Many Requests error
"""
n = 0 n = 0
while True: while True:
if error.detail[n].isnumeric(): if error.detail[n].isnumeric():
@ -18,79 +21,58 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon
f"{request.client.host} got rate-limited at {str(request.url)} " f"{request.client.host} got rate-limited at {str(request.url)} "
f"(exceeded {error.detail})" f"(exceeded {error.detail})"
) )
return JSONResponse( return JSONResponse(status_code=200, content=dict(status_code=429,
status_code=200, msg=f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}"))
content={
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
"status_code": 429,
},
)
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse: def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
"""
Handles the equivalent of a 401 Unauthorized exception
"""
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}") LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
return JSONResponse( return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required"))
status_code=200,
content={
"msg": "Authentication is required",
"status_code": 401,
},
)
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse: def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
"""
Handles Bad Request exceptions from FastAPI
"""
LOGGER.info( LOGGER.info(
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}" f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
) )
return JSONResponse( return JSONResponse(status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"))
status_code=200,
content={
"msg": f"Bad request: {type(exc).__name__}: {exc}",
"status_code": 400,
},
)
def http_exception( def http_exception(
request: Request, exc: HTTPException | StarletteHTTPException request: Request, exc: HTTPException | StarletteHTTPException
) -> JSONResponse: ) -> JSONResponse:
"""
Handles HTTP-specific exceptions raised explicitly by
path operations
"""
if exc.status_code >= 500: if exc.status_code >= 500:
LOGGER.error( LOGGER.error(
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
f"{type(exc).__name__}: {exc}" f"{type(exc).__name__}: {exc}"
) )
return JSONResponse( return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
status_code=200,
content={
"msg": "Internal server error",
"status_code": exc.status_code,
},
)
else: else:
LOGGER.info( LOGGER.info(
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}" f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
) )
return JSONResponse( return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail))
status_code=200,
content={
"msg": exc.detail,
"status_code": exc.status_code,
},
)
async def generic_error(request: Request, exc: Exception) -> JSONResponse: async def generic_error(request: Request, exc: Exception) -> JSONResponse:
""" """
Handles generic, unexpected errors in the application Handles generic, unexpected errors in the ASGI application
""" """
LOGGER.info( LOGGER.info(
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}" f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
) )
return JSONResponse( # We can't leak anything about the error, it would be too risky
status_code=200, return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
content={
"msg": "Internal Server Error", # We can't leak anything about the error, it would be too risky
"status_code": 500,
},
)