Large rework of the application for better documentation support
This commit is contained in:
parent
f0ef827577
commit
1c05716345
|
@ -6,10 +6,8 @@ import re
|
|||
import imghdr
|
||||
import uuid
|
||||
import zlib
|
||||
|
||||
import bcrypt
|
||||
from uuid import UUID
|
||||
|
||||
import validators
|
||||
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
@ -52,47 +50,31 @@ from config import (
|
|||
FORCE_EMAIL_VERIFICATION,
|
||||
UNVERIFIED_MANAGER,
|
||||
)
|
||||
from orm.users import UserModel, User
|
||||
from responses import Response as APIResponse, UnprocessableEntity, BadRequest, NotFound, \
|
||||
MediaTypeNotAcceptable, PayloadTooLarge, InternalServerError
|
||||
from responses.users import (
|
||||
PrivateUserResponse,
|
||||
PublicUserResponse,
|
||||
PrivateUserModel,
|
||||
PublicUserModel,
|
||||
)
|
||||
from orm.users import (
|
||||
User,
|
||||
UserModel,
|
||||
get_user_by_username,
|
||||
get_user_by_id,
|
||||
get_user_by_email,
|
||||
)
|
||||
from orm.media import Media, MediaType
|
||||
from orm.email_verification import EmailVerification
|
||||
from util.email import send_email
|
||||
|
||||
|
||||
router = FastAPI()
|
||||
|
||||
|
||||
async def get_user_by_id(
|
||||
public_id: UUID,
|
||||
include_secrets: bool = False,
|
||||
restricted_ok: bool = False,
|
||||
deleted_ok: bool = False,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its public ID
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.public_id == public_id)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (
|
||||
user["restricted"] and not restricted_ok
|
||||
):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
# Credential loaders for our authenticated routes
|
||||
@MANAGER.user_loader()
|
||||
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dict:
|
||||
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
|
||||
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
||||
if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]:
|
||||
raise HTTPException(status_code=401, detail="Email verification is required")
|
||||
|
@ -100,76 +82,22 @@ async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dic
|
|||
|
||||
|
||||
@UNVERIFIED_MANAGER.user_loader()
|
||||
async def get_self_by_id_unverified(public_id: UUID):
|
||||
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
|
||||
return await get_self_by_id(public_id, False)
|
||||
|
||||
|
||||
async def get_user_by_username(
|
||||
username: str,
|
||||
include_secrets: bool = False,
|
||||
restricted_ok: bool = False,
|
||||
deleted_ok: bool = False,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its public username
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.username == username)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (
|
||||
user["restricted"] and not restricted_ok
|
||||
):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
async def get_user_by_email(
|
||||
email: str,
|
||||
include_secrets: bool = False,
|
||||
restricted_ok: bool = False,
|
||||
deleted_ok: bool = False,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its email address (meant to
|
||||
be used internally)
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.email_address == email)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (
|
||||
user["restricted"] and not restricted_ok
|
||||
):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
# Here follow our *beautifully* documented path operations
|
||||
|
||||
@router.post("/user", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity}})
|
||||
@LIMITER.limit("5/minute")
|
||||
@router.post("/user")
|
||||
async def login(
|
||||
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> dict:
|
||||
async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()):
|
||||
"""
|
||||
Performs user authentication. Endpoint is limited to 5 hits per minute
|
||||
"""
|
||||
|
||||
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||
raise HTTPException(status_code=400, detail="Please logout first")
|
||||
username = data.username
|
||||
|
@ -192,21 +120,21 @@ async def login(
|
|||
detail="Authentication failed: invalid characters in password",
|
||||
)
|
||||
if not (
|
||||
user := await get_user_by_username(
|
||||
username, include_secrets=True, restricted_ok=True
|
||||
)
|
||||
user := await get_user_by_username(
|
||||
username, include_secrets=True, restricted_ok=True
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Authentication failed: the user does not exist",
|
||||
)
|
||||
if not bcrypt.checkpw(password, user["password_hash"]):
|
||||
if not bcrypt.checkpw(password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Authentication failed: password mismatch",
|
||||
)
|
||||
token = MANAGER.create_access_token(
|
||||
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])}
|
||||
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user.public_id)}
|
||||
)
|
||||
response.set_cookie(
|
||||
secure=SECURE_COOKIE,
|
||||
|
@ -218,17 +146,17 @@ async def login(
|
|||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",
|
||||
)
|
||||
return {"status_code": 200, "msg": "Authentication successful"}
|
||||
return APIResponse(status_code=200, msg="Authentication successful")
|
||||
|
||||
|
||||
@router.get("/user/logout")
|
||||
@router.get("/user/logout", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
422: {"model": UnprocessableEntity}})
|
||||
@LIMITER.limit("5/minute")
|
||||
async def logout(
|
||||
request: Request, response: Response, _user: dict = Depends(UNVERIFIED_MANAGER)
|
||||
) -> dict:
|
||||
async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||
"""
|
||||
Deletes a user's session cookie, logging them
|
||||
out
|
||||
out. Endpoint is limited to 5 hits per minute
|
||||
"""
|
||||
|
||||
response.delete_cookie(
|
||||
|
@ -239,70 +167,81 @@ async def logout(
|
|||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",
|
||||
)
|
||||
return {"status_code": 200, "msg": "Logged out"}
|
||||
return APIResponse(status_code=200, msg="Logged out")
|
||||
|
||||
|
||||
@router.get("/user/me")
|
||||
@router.get(
|
||||
"/user/me",
|
||||
tags=["Users"],
|
||||
status_code=200,
|
||||
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||
422: {"model": UnprocessableEntity}},
|
||||
)
|
||||
@LIMITER.limit("2/second")
|
||||
async def get_self(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict:
|
||||
async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||
"""
|
||||
Fetches a user's own info. This returns some
|
||||
extra data such as email address, account
|
||||
creation date and email verification status,
|
||||
which is not available from the regular endpoint
|
||||
which is not available from the regular endpoint.
|
||||
Endpoint is limited to 2 hits per second
|
||||
"""
|
||||
|
||||
user.pop("password_hash")
|
||||
user.pop("internal_id")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Success", "data": user}
|
||||
return PrivateUserResponse(status_code=200, msg="Success", data=user)
|
||||
|
||||
|
||||
@router.get("/user/username/{username}")
|
||||
@router.get(
|
||||
"/user/username/{username}",
|
||||
tags=["Users"],
|
||||
status_code=200,
|
||||
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||
404: {"model": NotFound},
|
||||
422: {"model": UnprocessableEntity}
|
||||
},
|
||||
)
|
||||
@LIMITER.limit("30/second")
|
||||
async def get_user_by_name(
|
||||
request: Request, username: str, _auth: dict = Depends(MANAGER)
|
||||
) -> dict:
|
||||
request: Request, username: str, _auth: UserModel = Depends(MANAGER)
|
||||
):
|
||||
"""
|
||||
Fetches a single user by its public ID
|
||||
Fetches a single user by its public username
|
||||
"""
|
||||
|
||||
if not (user := await get_user_by_username(username)):
|
||||
return {
|
||||
"status_code": 404,
|
||||
"msg": "Lookup failed: the user does not exist",
|
||||
}
|
||||
user.pop("restricted")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||
return NotFound(msg="Lookup failed: the user does not exist")
|
||||
return PublicUserResponse(status_code=200, msg="Lookup successful", data=user)
|
||||
|
||||
|
||||
@router.get("/user/id/{public_id}")
|
||||
@router.get(
|
||||
"/user/id/{public_id}",
|
||||
tags=["Users"],
|
||||
status_code=200,
|
||||
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||
404: {"model": NotFound},
|
||||
422: {"model": UnprocessableEntity}
|
||||
},
|
||||
)
|
||||
@LIMITER.limit("30/second")
|
||||
async def get_user_by_public_id(
|
||||
request: Request, public_id: str, _auth: dict = Depends(MANAGER)
|
||||
) -> dict:
|
||||
request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)
|
||||
):
|
||||
"""
|
||||
Fetches a single user by its public ID
|
||||
"""
|
||||
|
||||
if not (user := await get_user_by_id(UUID(public_id))):
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Lookup failed: the user does not exist"
|
||||
)
|
||||
user.pop("restricted")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||
raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist")
|
||||
return PublicUserResponse(data=user)
|
||||
|
||||
|
||||
async def validate_user(
|
||||
first_name: str | None,
|
||||
last_name: str | None,
|
||||
username: str | None,
|
||||
email: str | None,
|
||||
password: str | None,
|
||||
bio: str | None,
|
||||
) -> tuple[bool, str]:
|
||||
first_name: str | None,
|
||||
last_name: str | None,
|
||||
username: str | None,
|
||||
email: str | None,
|
||||
password: str | None,
|
||||
bio: str | None,
|
||||
):
|
||||
"""
|
||||
Performs some validation upon user creation. Returns
|
||||
a tuple (success, msg) to be used by routes. Values
|
||||
|
@ -322,9 +261,9 @@ async def validate_user(
|
|||
if username and len(username) > 32:
|
||||
return False, "username is too long"
|
||||
if (
|
||||
username
|
||||
and VALIDATE_USERNAME_REGEX
|
||||
and not re.match(VALIDATE_USERNAME_REGEX, username)
|
||||
username
|
||||
and VALIDATE_USERNAME_REGEX
|
||||
and not re.match(VALIDATE_USERNAME_REGEX, username)
|
||||
):
|
||||
return False, "username is invalid"
|
||||
if email and not validators.email(email):
|
||||
|
@ -332,13 +271,13 @@ async def validate_user(
|
|||
if password and len(password) > 72:
|
||||
return False, "password is too long"
|
||||
if (
|
||||
password
|
||||
and VALIDATE_PASSWORD_REGEX
|
||||
and not re.match(VALIDATE_PASSWORD_REGEX, password)
|
||||
password
|
||||
and VALIDATE_PASSWORD_REGEX
|
||||
and not re.match(VALIDATE_PASSWORD_REGEX, password)
|
||||
):
|
||||
return False, "password is too weak"
|
||||
if username and await get_user_by_username(
|
||||
username, deleted_ok=True, restricted_ok=True
|
||||
username, deleted_ok=True, restricted_ok=True
|
||||
):
|
||||
return False, "username is already taken"
|
||||
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
||||
|
@ -353,18 +292,20 @@ async def validate_user(
|
|||
return True, ""
|
||||
|
||||
|
||||
@router.delete("/user")
|
||||
@router.delete("/user", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
422: {"model": UnprocessableEntity}})
|
||||
@LIMITER.limit("1/minute")
|
||||
async def delete(
|
||||
request: Request, response: Response, user: dict = Depends(UNVERIFIED_MANAGER)
|
||||
) -> dict:
|
||||
request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)
|
||||
):
|
||||
"""
|
||||
Sets the user's deleted flag in the database,
|
||||
without actually deleting the associated
|
||||
data
|
||||
"""
|
||||
|
||||
await User.update({User.deleted: True}).where(User.public_id == user["id"])
|
||||
await User.update({User.deleted: True}).where(User.public_id == user.public_id)
|
||||
response.delete_cookie(
|
||||
secure=SECURE_COOKIE,
|
||||
key=SESSION_COOKIE_NAME,
|
||||
|
@ -373,22 +314,28 @@ async def delete(
|
|||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",
|
||||
)
|
||||
return {"status_code": 200, "msg": "Success"}
|
||||
return APIResponse(status_code=200, msg="Success")
|
||||
|
||||
|
||||
@router.get("/user/verifyEmail/{verification_id}")
|
||||
@router.get("/user/verifyEmail/{verification_id}", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||
404: {"model": NotFound},
|
||||
422: {"model": UnprocessableEntity}
|
||||
})
|
||||
@LIMITER.limit("3/second")
|
||||
async def verify_email(
|
||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
||||
) -> dict:
|
||||
request: Request,
|
||||
verification_id: str,
|
||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||
):
|
||||
"""
|
||||
Verifies a user's email address
|
||||
"""
|
||||
|
||||
if not (
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Verification ID is invalid")
|
||||
elif not verification["pending"]:
|
||||
|
@ -402,27 +349,31 @@ async def verify_email(
|
|||
)
|
||||
else:
|
||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||
EmailVerification.user == user["id"]
|
||||
EmailVerification.user == user.public_id
|
||||
)
|
||||
await User.update({User.email_verified: True}).where(
|
||||
User.public_id == user["id"]
|
||||
)
|
||||
return {"status_code": 200, "msg": "Verification successful"}
|
||||
await User.update({User.email_verified: True}).where(User.public_id == user.public_id)
|
||||
return APIResponse(status_code=200, msg="Verification successful")
|
||||
|
||||
|
||||
@router.get("/user/resetPassword/{verification_id}")
|
||||
@router.get("/user/resetPassword/{verification_id}", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity},
|
||||
404: {"model": NotFound}})
|
||||
@LIMITER.limit("3/second")
|
||||
async def reset_password(
|
||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
||||
) -> dict:
|
||||
request: Request,
|
||||
verification_id: str,
|
||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||
):
|
||||
"""
|
||||
Modifies a user's password
|
||||
"""
|
||||
|
||||
if not (
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||
elif not verification["pending"]:
|
||||
|
@ -436,27 +387,33 @@ async def reset_password(
|
|||
)
|
||||
else:
|
||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||
EmailVerification.user == user["id"]
|
||||
EmailVerification.user == user.public_id
|
||||
)
|
||||
await User.update({User.password_hash: verification["data"]}).where(
|
||||
User.public_id == user["id"]
|
||||
User.public_id == user.public_id
|
||||
)
|
||||
return {"status_code": 200, "msg": "Password updated"}
|
||||
return APIResponse(status_code=200, msg="Password updated")
|
||||
|
||||
|
||||
@router.get("/user/changeEmail/{verification_id}")
|
||||
@router.get("/user/changeEmail/{verification_id}", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity},
|
||||
404: {"model": NotFound}})
|
||||
@LIMITER.limit("3/second")
|
||||
async def change_email(
|
||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
||||
) -> dict:
|
||||
request: Request,
|
||||
verification_id: str,
|
||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||
):
|
||||
"""
|
||||
Modifies a user's email
|
||||
"""
|
||||
|
||||
if not (
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||
.where(EmailVerification.id == verification_id)
|
||||
.first()
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||
elif not verification["pending"]:
|
||||
|
@ -470,18 +427,24 @@ async def change_email(
|
|||
)
|
||||
else:
|
||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||
EmailVerification.user == user["id"]
|
||||
EmailVerification.user == user.public_id
|
||||
)
|
||||
await User.update({User.email_address: verification["data"].decode(),
|
||||
User.email_verified: False}).where(
|
||||
User.public_id == user["id"]
|
||||
)
|
||||
return {"status_code": 200, "msg": "Email updated"}
|
||||
await User.update(
|
||||
{
|
||||
User.email_address: verification["data"].decode(),
|
||||
User.email_verified: False,
|
||||
}
|
||||
).where(User.public_id == user.public_id)
|
||||
return APIResponse(status_code=200, msg="Email updated")
|
||||
|
||||
|
||||
@router.put("user/resendMail")
|
||||
@router.put("user/resendMail", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity},
|
||||
500: {"model": InternalServerError}})
|
||||
@LIMITER.limit("6/minute")
|
||||
async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict:
|
||||
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||
"""
|
||||
Resends the verification email to the user if the previous has expired
|
||||
"""
|
||||
|
@ -523,26 +486,34 @@ async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER
|
|||
use_tls=SMTP_USE_TLS,
|
||||
):
|
||||
await EmailVerification.update(
|
||||
{
|
||||
EmailVerification.id: verification_id,
|
||||
EmailVerification.creation_date: datetime.now()
|
||||
}
|
||||
{
|
||||
EmailVerification.id: verification_id,
|
||||
EmailVerification.creation_date: datetime.now(),
|
||||
}
|
||||
)
|
||||
return {"status_code": 200, "msg": "Success"}
|
||||
return APIResponse(status_code=200, msg="Success")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="An error occurred while trying to resend the email,"
|
||||
" please try again later")
|
||||
|
||||
|
||||
@router.put("/user")
|
||||
@router.put("/user", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity},
|
||||
500: {"model": InternalServerError},
|
||||
413: {"model": PayloadTooLarge}})
|
||||
@LIMITER.limit("2/minute")
|
||||
async def signup(
|
||||
request: Request,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
username: str,
|
||||
email: str,
|
||||
password: str,
|
||||
bio: str | None = None,
|
||||
locale: str = "en_US",
|
||||
) -> dict:
|
||||
request: Request,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
username: str,
|
||||
email: str,
|
||||
password: str,
|
||||
bio: str | None = None,
|
||||
locale: str = "en_US",
|
||||
):
|
||||
"""
|
||||
Endpoint used to create new users
|
||||
"""
|
||||
|
@ -555,7 +526,7 @@ async def signup(
|
|||
first_name, last_name, username, email, password, bio
|
||||
)
|
||||
if not result:
|
||||
return {"status_code": 413, "msg": f"Signup failed: {msg}"}
|
||||
return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
|
||||
else:
|
||||
salt = bcrypt.gensalt(BCRYPT_ROUNDS)
|
||||
user = User(
|
||||
|
@ -563,9 +534,7 @@ async def signup(
|
|||
last_name=last_name,
|
||||
username=username,
|
||||
email_address=email,
|
||||
password_hash=bcrypt.hashpw(
|
||||
password.encode(), salt
|
||||
),
|
||||
password_hash=bcrypt.hashpw(password.encode(), salt),
|
||||
bio=bio,
|
||||
)
|
||||
email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json"
|
||||
|
@ -577,30 +546,30 @@ async def signup(
|
|||
email_message = json.load(f)["signup"]
|
||||
verification_id = uuid.uuid4()
|
||||
if await send_email(
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
username=username,
|
||||
email=email,
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
email,
|
||||
email_message["subject"].format(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
username=username,
|
||||
email=email,
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
username=username,
|
||||
email=email,
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
email,
|
||||
email_message["subject"].format(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
username=username,
|
||||
email=email,
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
):
|
||||
await User.insert(user)
|
||||
await EmailVerification.insert(
|
||||
|
@ -611,17 +580,17 @@ async def signup(
|
|||
}
|
||||
)
|
||||
)
|
||||
return {"status_code": 200, "msg": "Success"}
|
||||
return APIResponse(status_code=200, msg="Success")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An error occurred while sending verification email, please"
|
||||
" try again later",
|
||||
" try again later",
|
||||
)
|
||||
|
||||
|
||||
async def validate_profile_picture(
|
||||
file: UploadFile,
|
||||
file: UploadFile,
|
||||
) -> tuple[bool | None, str, bytes, str]:
|
||||
"""
|
||||
Validates a profile picture's size and content to see if it fits
|
||||
|
@ -650,18 +619,23 @@ async def validate_profile_picture(
|
|||
return None, "", b"", ""
|
||||
|
||||
|
||||
@router.patch("/user")
|
||||
async def update(
|
||||
request: Request,
|
||||
user: dict = Depends(UNVERIFIED_MANAGER),
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
username: str | None = None,
|
||||
profile_picture: UploadFile | None = None,
|
||||
email_address: str | None = None,
|
||||
password: str | None = None,
|
||||
bio: str | None = None,
|
||||
delete: bool = False,
|
||||
@router.patch("/user", tags=["Users"], status_code=200,
|
||||
responses={200: {"model": APIResponse},
|
||||
400: {"model": BadRequest},
|
||||
422: {"model": UnprocessableEntity},
|
||||
500: {"model": InternalServerError},
|
||||
413: {"model": PayloadTooLarge}})
|
||||
async def update_user(
|
||||
request: Request,
|
||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||
first_name: str | None = None,
|
||||
last_name: str | None = None,
|
||||
username: str | None = None,
|
||||
profile_picture: UploadFile | None = None,
|
||||
email_address: str | None = None,
|
||||
password: str | None = None,
|
||||
bio: str | None = None,
|
||||
delete: bool = False,
|
||||
):
|
||||
"""
|
||||
Updates a user's profile information. Parameters that are not specified are left unchanged unless
|
||||
|
@ -676,28 +650,26 @@ async def update(
|
|||
"""
|
||||
|
||||
if not delete and not any(
|
||||
(first_name, last_name, username, profile_picture, email_address, bio, password)
|
||||
(first_name, last_name, username, profile_picture, email_address, bio, password)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="At least one value has to be specified"
|
||||
)
|
||||
result, msg = (
|
||||
await validate_user(
|
||||
first_name, last_name, username, email_address, password, bio
|
||||
)
|
||||
result, msg = await validate_user(
|
||||
first_name, last_name, username, email_address, password, bio
|
||||
)
|
||||
if not result:
|
||||
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
|
||||
orig_user = user.copy()
|
||||
if not delete:
|
||||
if first_name:
|
||||
user["first_name"] = first_name
|
||||
user.first_name = first_name
|
||||
if last_name:
|
||||
user["last_name"] = last_name
|
||||
user.last_name = last_name
|
||||
if username:
|
||||
user["username"] = username
|
||||
user.username = username
|
||||
if bio:
|
||||
user["bio"] = bio
|
||||
user.bio = bio
|
||||
if profile_picture:
|
||||
result, ext, media, digest = validate_profile_picture(profile_picture)
|
||||
if result is False:
|
||||
|
@ -707,44 +679,44 @@ async def update(
|
|||
elif result is None:
|
||||
raise HTTPException(status_code=413, detail="The file is too large")
|
||||
elif (
|
||||
await (
|
||||
old_media := Media.select(Media.media_id)
|
||||
.where(Media.media_id == digest)
|
||||
.first()
|
||||
)
|
||||
is None
|
||||
await (
|
||||
old_media := Media.select(Media.media_id)
|
||||
.where(Media.media_id == digest)
|
||||
.first()
|
||||
)
|
||||
is None
|
||||
):
|
||||
# This media hasn't been already uploaded (either by this user or by someone
|
||||
# else), so we save it now. If it has been already uploaded, there's no need
|
||||
# to do it again (that's what the hash is for)
|
||||
match STORAGE_ENGINE:
|
||||
case "database":
|
||||
await Media.insert(
|
||||
Media(
|
||||
media_id=digest,
|
||||
media_type=MediaType.BLOB,
|
||||
content_type=ext,
|
||||
content=base64.b64encode(media),
|
||||
)
|
||||
m = Media(
|
||||
media_id=digest,
|
||||
media_type=MediaType.BLOB,
|
||||
content_type=ext,
|
||||
content=base64.b64encode(media),
|
||||
)
|
||||
await Media.insert(m)
|
||||
user.profile_picture = m
|
||||
case "local":
|
||||
file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest)
|
||||
file.touch(mode=0o644)
|
||||
with file.open("wb") as f:
|
||||
f.write(media)
|
||||
await Media.insert(
|
||||
Media(
|
||||
media_id=digest,
|
||||
media_type=MediaType.FILE,
|
||||
content_type=ext,
|
||||
content=file.as_posix(),
|
||||
)
|
||||
m = Media(
|
||||
media_id=digest,
|
||||
media_type=MediaType.FILE,
|
||||
content_type=ext,
|
||||
content=file.as_posix(),
|
||||
)
|
||||
await Media.insert(m)
|
||||
user.profile_picture = m
|
||||
case "url":
|
||||
pass # TODO: Use/implement CDN uploading
|
||||
else:
|
||||
user["media"] = old_media
|
||||
if password and not bcrypt.checkpw(password.encode(), user["password_hash"]):
|
||||
user.media = old_media
|
||||
if password and not bcrypt.checkpw(password.encode(), user.password_hash):
|
||||
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json"
|
||||
try:
|
||||
email_template.resolve(strict=True)
|
||||
|
@ -754,43 +726,48 @@ async def update(
|
|||
email_message = json.load(f)["password_change"]
|
||||
verification_id = uuid.uuid4()
|
||||
if not await send_email(
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
email=user["email_address"],
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}",
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
user["email_address"],
|
||||
email_message["subject"].format(
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
email=user["email_address"],
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
username=user.username,
|
||||
email=user.email_address,
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}",
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
user.email_address,
|
||||
email_message["subject"].format(
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
username=user.username,
|
||||
email=user.email_address,
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
):
|
||||
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later")
|
||||
raise HTTPException(
|
||||
500,
|
||||
detail="An error occurred while trying to send mail, please try again later",
|
||||
)
|
||||
else:
|
||||
await EmailVerification.insert(
|
||||
EmailVerification(
|
||||
{
|
||||
EmailVerification.id: verification_id,
|
||||
EmailVerification.user: User(public_id=user["id"]),
|
||||
EmailVerification.data: bcrypt.hashpw(password.encode(), user["password_hash"][:29])
|
||||
EmailVerification.user: User(public_id=user.public_id),
|
||||
EmailVerification.data: bcrypt.hashpw(
|
||||
password.encode(), user.password_hash[:29]
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
if email_address and user["email_address"] != email_address:
|
||||
if email_address and user.email_address != email_address:
|
||||
if not user["email_verified"]:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
@ -805,51 +782,61 @@ async def update(
|
|||
email_message = json.load(f)["email_change"]
|
||||
verification_id = uuid.uuid4()
|
||||
if not await send_email(
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
email=user["email_address"],
|
||||
platformName=PLATFORM_NAME,
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}",
|
||||
newMail=email_address,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
user["email_address"],
|
||||
email_message["subject"].format(
|
||||
first_name=user["first_name"],
|
||||
last_name=user["last_name"],
|
||||
username=user["username"],
|
||||
email=user["email_address"],
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
email_message["content"].format(
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
username=user.username,
|
||||
email=user.email_address,
|
||||
platformName=PLATFORM_NAME,
|
||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}",
|
||||
newMail=email_address,
|
||||
),
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_FROM_USER,
|
||||
user.email_address,
|
||||
email_message["subject"].format(
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
username=user.username,
|
||||
email=user.email_address,
|
||||
platformName=PLATFORM_NAME,
|
||||
),
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
use_tls=SMTP_USE_TLS,
|
||||
):
|
||||
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later")
|
||||
raise HTTPException(
|
||||
500,
|
||||
detail="An error occurred while trying to send mail, please try again later",
|
||||
)
|
||||
else:
|
||||
await EmailVerification.insert(EmailVerification({
|
||||
EmailVerification.id: verification_id,
|
||||
EmailVerification.user: User(public_id=user["id"]),
|
||||
EmailVerification.data: email_address.encode()
|
||||
}))
|
||||
await EmailVerification.insert(
|
||||
EmailVerification(
|
||||
{
|
||||
EmailVerification.id: verification_id,
|
||||
EmailVerification.user: User(public_id=user.public_id),
|
||||
EmailVerification.data: email_address.encode(),
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
if not bio:
|
||||
user["bio"] = None
|
||||
user.bio = None
|
||||
if not profile_picture:
|
||||
user["profile_picture"] = None
|
||||
fields = []
|
||||
for field in user:
|
||||
if field not in ["email_address", "password"] and orig_user[field] != user[field]:
|
||||
if (
|
||||
field not in ["email_address", "password"]
|
||||
and orig_user[field] != user[field]
|
||||
):
|
||||
fields.append((field, user[field]))
|
||||
if fields:
|
||||
# If anything has changed, we update our info
|
||||
await User.update({field: value for field, value in fields}).where(
|
||||
User.public_id == user["id"]
|
||||
User.public_id == user.public_id
|
||||
)
|
||||
return {"status_code": 200, "msg": "Changes saved successfully"}
|
||||
return APIResponse(status_code=200, msg="Changes saved successfully")
|
||||
|
|
43
main.py
43
main.py
|
@ -9,6 +9,7 @@ from fastapi.exceptions import (
|
|||
StarletteHTTPException,
|
||||
)
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from endpoints import users
|
||||
|
@ -25,6 +26,7 @@ from config import (
|
|||
HOST,
|
||||
PORT,
|
||||
WORKERS,
|
||||
PLATFORM_NAME,
|
||||
)
|
||||
from orm import create_tables, Media
|
||||
from util.exception_handlers import (
|
||||
|
@ -36,24 +38,55 @@ from util.exception_handlers import (
|
|||
)
|
||||
from util.email import test_smtp
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
with (Path(__file__).parent / "README.md").resolve(strict=True).open() as f:
|
||||
description = f.read()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
app = FastAPI(
|
||||
title=PLATFORM_NAME,
|
||||
license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"},
|
||||
contact={
|
||||
"name": "Matt",
|
||||
"url": "https://nocturn9x.space",
|
||||
"email": "nocturn9x@nocturn9x.space",
|
||||
},
|
||||
version="0.0.1",
|
||||
description=description,
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "Miscellaneous",
|
||||
"description": "Simple endpoints that don't fit anywhere else",
|
||||
},
|
||||
{
|
||||
"name": "Users",
|
||||
"description": "Endpoints that handle user-related operations such as"
|
||||
" signing in, signing up, settings management and more",
|
||||
},
|
||||
{
|
||||
"name": "Posts",
|
||||
"description": "Endpoints that handle post creation, modification and deletion",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/", tags=["Miscellaneous"])
|
||||
@LIMITER.limit("10/second")
|
||||
async def root(request: Request):
|
||||
raise HTTPException(401, detail="Unauthorized")
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
@app.get("/ping", tags=["Miscellaneous"])
|
||||
@LIMITER.limit("1/minute")
|
||||
async def ping(request: Request) -> dict:
|
||||
"""
|
||||
This handler simply replies to "ping" requests and
|
||||
This method simply replies to "ping" requests and
|
||||
is used to check whether the API is up and running.
|
||||
It also performs a sanity check with the database and
|
||||
the SMTP server to ensure that they are functioning correctly.
|
||||
For this reason, this endpoint's rate limit is much stricter
|
||||
For this reason, this endpoint's rate limit is very strict:
|
||||
it can only be called once per minute
|
||||
"""
|
||||
|
||||
LOGGER.info(f"Processing ping request from {request.client.host}")
|
||||
|
|
82
orm/users.py
82
orm/users.py
|
@ -2,6 +2,7 @@
|
|||
User relation
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from piccolo.utils.pydantic import create_pydantic_model
|
||||
from piccolo.table import Table
|
||||
from piccolo.columns import (
|
||||
|
@ -15,6 +16,7 @@ from piccolo.columns import (
|
|||
Boolean,
|
||||
Email,
|
||||
Bytea,
|
||||
Column,
|
||||
)
|
||||
from piccolo.columns.defaults.date import DateNow
|
||||
|
||||
|
@ -45,4 +47,82 @@ class User(Table, tablename="users"):
|
|||
locale = Varchar(length=12, default="en_US", null=False, secret=True)
|
||||
|
||||
|
||||
UserModel = create_pydantic_model(User)
|
||||
UserModel = create_pydantic_model(
|
||||
User,
|
||||
nested=True,
|
||||
)
|
||||
PublicUserModel = create_pydantic_model(
|
||||
User,
|
||||
nested=True,
|
||||
exclude_columns=(
|
||||
User.internal_id,
|
||||
User.deleted,
|
||||
User.restricted,
|
||||
User.locale,
|
||||
User.email_verified,
|
||||
User.email_address,
|
||||
User.password_hash,
|
||||
),
|
||||
model_name="PublicUser",
|
||||
)
|
||||
PrivateUserModel = create_pydantic_model(
|
||||
User,
|
||||
nested=True,
|
||||
exclude_columns=(User.internal_id, User.password_hash, User.deleted),
|
||||
model_name="PrivateUser",
|
||||
)
|
||||
|
||||
|
||||
async def get_user_by_column(
|
||||
column: Column,
|
||||
data: Any,
|
||||
include_secrets: bool = False,
|
||||
restricted_ok: bool = False,
|
||||
deleted_ok: bool = False,
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Retrieves a user object by a given criteria.
|
||||
Returns None if the user doesn't exist or
|
||||
if it's restricted/deleted (unless restricted_ok
|
||||
and deleted_ok are set accordingly)
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(column == data)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
user = UserModel(**user)
|
||||
if (user.deleted and not deleted_ok) or (user.restricted and not restricted_ok):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
async def get_user_by_id(public_id: UUID, *args, **kwargs) -> UserModel:
|
||||
"""
|
||||
Retrieves a user by its public ID
|
||||
"""
|
||||
|
||||
return await get_user_by_column(User.public_id, public_id, *args, **kwargs)
|
||||
|
||||
|
||||
async def get_user_by_username(username: str, *args, **kwargs) -> UserModel | None:
|
||||
"""
|
||||
Retrieves a user by its public username
|
||||
"""
|
||||
|
||||
return await get_user_by_column(User.username, username, *args, **kwargs)
|
||||
|
||||
|
||||
async def get_user_by_email(email: str, *args, **kwargs) -> UserModel | None:
|
||||
"""
|
||||
Retrieves a user by its email address
|
||||
"""
|
||||
|
||||
return await get_user_by_column(User.email_address, email, *args, **kwargs)
|
||||
|
|
|
@ -7,4 +7,5 @@ uvloop~=0.17.0
|
|||
fastapi~=0.85.0
|
||||
validators
|
||||
aiosmtplib
|
||||
uvicorn
|
||||
uvicorn
|
||||
pydantic
|
|
@ -0,0 +1,96 @@
|
|||
from pydantic import BaseModel
|
||||
from orm.users import User
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
"""
|
||||
A generic response model
|
||||
"""
|
||||
|
||||
status_code: int = 200
|
||||
msg: str = "Success"
|
||||
|
||||
|
||||
class UnprocessableEntity(Response):
|
||||
"""
|
||||
A 422 Unprocessable Entity response
|
||||
model
|
||||
"""
|
||||
|
||||
status_code: int = 422
|
||||
msg: str = "Input Validation Failure"
|
||||
|
||||
|
||||
class NotFound(Response):
|
||||
"""
|
||||
A 404 Not Found response
|
||||
model
|
||||
"""
|
||||
|
||||
status_code: int = 404
|
||||
msg: str = "Not Found"
|
||||
|
||||
|
||||
class PayloadTooLarge(Response):
|
||||
"""
|
||||
A 413 Payload Too Large response
|
||||
model
|
||||
"""
|
||||
|
||||
status_code: int = 413
|
||||
msg: str = "Payload too large"
|
||||
|
||||
|
||||
class MediaTypeNotAcceptable(Response):
|
||||
"""
|
||||
A 415 Media Type Not Acceptable response
|
||||
model
|
||||
"""
|
||||
|
||||
status_code: int = 415
|
||||
msg: str = "Media type not acceptable"
|
||||
|
||||
|
||||
class BadRequest(Response):
|
||||
"""
|
||||
A 400 Bad Request response model
|
||||
"""
|
||||
|
||||
status_code: int = 400
|
||||
msg: str = "Bad Request"
|
||||
|
||||
|
||||
class InternalServerError(Response):
|
||||
"""
|
||||
A 500 Internal Server Error response model
|
||||
"""
|
||||
|
||||
status_code: int = 500
|
||||
msg: str = "Internal Server Error"
|
||||
|
||||
|
||||
class TooManyRequests(Response):
|
||||
"""
|
||||
A 429 Too Many Requests response model
|
||||
"""
|
||||
|
||||
status_code: int = 429
|
||||
msg: str = "Too Many Requests"
|
||||
|
||||
|
||||
class Unauthorized(Response):
|
||||
"""
|
||||
A 401 Unauthorized response model
|
||||
"""
|
||||
|
||||
status_code: int = 401
|
||||
msg: str = "Unauthorized"
|
||||
|
||||
|
||||
class Forbidden(Response):
|
||||
"""
|
||||
A 403 Forbidden response model
|
||||
"""
|
||||
|
||||
status_code: int = 403
|
||||
msg: str = "Forbidden"
|
|
@ -0,0 +1,24 @@
|
|||
from responses import Response
|
||||
from orm.users import PublicUserModel, PrivateUserModel
|
||||
|
||||
|
||||
class PublicUserResponse(Response):
|
||||
"""
|
||||
A response sent by an endpoint that
|
||||
replies with a public user object
|
||||
"""
|
||||
|
||||
status_code: int = 200
|
||||
msg: str = "Lookup successful"
|
||||
data: PublicUserModel
|
||||
|
||||
|
||||
class PrivateUserResponse(Response):
|
||||
"""
|
||||
A response sent by an endpoint that
|
||||
replies with a private user object
|
||||
"""
|
||||
|
||||
status_code: int = 200
|
||||
data: PrivateUserModel
|
||||
msg: str = "Lookup successful"
|
|
@ -7,6 +7,9 @@ from slowapi.errors import RateLimitExceeded
|
|||
|
||||
|
||||
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
|
||||
"""
|
||||
Handles the equivalent of a 429 Too Many Requests error
|
||||
"""
|
||||
n = 0
|
||||
while True:
|
||||
if error.detail[n].isnumeric():
|
||||
|
@ -18,79 +21,58 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon
|
|||
f"{request.client.host} got rate-limited at {str(request.url)} "
|
||||
f"(exceeded {error.detail})"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
|
||||
"status_code": 429,
|
||||
},
|
||||
)
|
||||
return JSONResponse(status_code=200, content=dict(status_code=429,
|
||||
msg=f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}"))
|
||||
|
||||
|
||||
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
||||
"""
|
||||
Handles the equivalent of a 401 Unauthorized exception
|
||||
"""
|
||||
|
||||
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": "Authentication is required",
|
||||
"status_code": 401,
|
||||
},
|
||||
)
|
||||
return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required"))
|
||||
|
||||
|
||||
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
"""
|
||||
Handles Bad Request exceptions from FastAPI
|
||||
"""
|
||||
|
||||
LOGGER.info(
|
||||
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": f"Bad request: {type(exc).__name__}: {exc}",
|
||||
"status_code": 400,
|
||||
},
|
||||
)
|
||||
return JSONResponse(status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"))
|
||||
|
||||
|
||||
def http_exception(
|
||||
request: Request, exc: HTTPException | StarletteHTTPException
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handles HTTP-specific exceptions raised explicitly by
|
||||
path operations
|
||||
"""
|
||||
|
||||
if exc.status_code >= 500:
|
||||
LOGGER.error(
|
||||
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
||||
f"{type(exc).__name__}: {exc}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": "Internal server error",
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
||||
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
||||
else:
|
||||
LOGGER.info(
|
||||
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": exc.detail,
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
||||
return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail))
|
||||
|
||||
|
||||
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""
|
||||
Handles generic, unexpected errors in the application
|
||||
Handles generic, unexpected errors in the ASGI application
|
||||
"""
|
||||
|
||||
LOGGER.info(
|
||||
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": "Internal Server Error", # We can't leak anything about the error, it would be too risky
|
||||
"status_code": 500,
|
||||
},
|
||||
)
|
||||
# We can't leak anything about the error, it would be too risky
|
||||
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
||||
|
|
Loading…
Reference in New Issue