Large rework of the application for better documentation support
This commit is contained in:
parent
f0ef827577
commit
1c05716345
|
@ -6,10 +6,8 @@ import re
|
||||||
import imghdr
|
import imghdr
|
||||||
import uuid
|
import uuid
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import validators
|
import validators
|
||||||
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
|
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
|
@ -52,47 +50,31 @@ from config import (
|
||||||
FORCE_EMAIL_VERIFICATION,
|
FORCE_EMAIL_VERIFICATION,
|
||||||
UNVERIFIED_MANAGER,
|
UNVERIFIED_MANAGER,
|
||||||
)
|
)
|
||||||
from orm.users import UserModel, User
|
from responses import Response as APIResponse, UnprocessableEntity, BadRequest, NotFound, \
|
||||||
|
MediaTypeNotAcceptable, PayloadTooLarge, InternalServerError
|
||||||
|
from responses.users import (
|
||||||
|
PrivateUserResponse,
|
||||||
|
PublicUserResponse,
|
||||||
|
PrivateUserModel,
|
||||||
|
PublicUserModel,
|
||||||
|
)
|
||||||
|
from orm.users import (
|
||||||
|
User,
|
||||||
|
UserModel,
|
||||||
|
get_user_by_username,
|
||||||
|
get_user_by_id,
|
||||||
|
get_user_by_email,
|
||||||
|
)
|
||||||
from orm.media import Media, MediaType
|
from orm.media import Media, MediaType
|
||||||
from orm.email_verification import EmailVerification
|
from orm.email_verification import EmailVerification
|
||||||
from util.email import send_email
|
from util.email import send_email
|
||||||
|
|
||||||
|
|
||||||
router = FastAPI()
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_id(
|
# Credential loaders for our authenticated routes
|
||||||
public_id: UUID,
|
|
||||||
include_secrets: bool = False,
|
|
||||||
restricted_ok: bool = False,
|
|
||||||
deleted_ok: bool = False,
|
|
||||||
) -> dict | None:
|
|
||||||
"""
|
|
||||||
Retrieves a user by its public ID
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = (
|
|
||||||
await User.select(
|
|
||||||
*User.all_columns(exclude=["public_id"]),
|
|
||||||
User.public_id.as_alias("id"),
|
|
||||||
exclude_secrets=not include_secrets,
|
|
||||||
)
|
|
||||||
.where(User.public_id == public_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if user:
|
|
||||||
# Performs validation
|
|
||||||
UserModel(**user)
|
|
||||||
if (user["deleted"] and not deleted_ok) or (
|
|
||||||
user["restricted"] and not restricted_ok
|
|
||||||
):
|
|
||||||
return
|
|
||||||
return user
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.user_loader()
|
@MANAGER.user_loader()
|
||||||
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dict:
|
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
|
||||||
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
||||||
if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]:
|
if FORCE_EMAIL_VERIFICATION and requires_verified and not user["email_verified"]:
|
||||||
raise HTTPException(status_code=401, detail="Email verification is required")
|
raise HTTPException(status_code=401, detail="Email verification is required")
|
||||||
|
@ -100,76 +82,22 @@ async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> dic
|
||||||
|
|
||||||
|
|
||||||
@UNVERIFIED_MANAGER.user_loader()
|
@UNVERIFIED_MANAGER.user_loader()
|
||||||
async def get_self_by_id_unverified(public_id: UUID):
|
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
|
||||||
return await get_self_by_id(public_id, False)
|
return await get_self_by_id(public_id, False)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_username(
|
# Here follow our *beautifully* documented path operations
|
||||||
username: str,
|
|
||||||
include_secrets: bool = False,
|
|
||||||
restricted_ok: bool = False,
|
|
||||||
deleted_ok: bool = False,
|
|
||||||
) -> dict | None:
|
|
||||||
"""
|
|
||||||
Retrieves a user by its public username
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = (
|
|
||||||
await User.select(
|
|
||||||
*User.all_columns(exclude=["public_id"]),
|
|
||||||
User.public_id.as_alias("id"),
|
|
||||||
exclude_secrets=not include_secrets,
|
|
||||||
)
|
|
||||||
.where(User.username == username)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if user:
|
|
||||||
# Performs validation
|
|
||||||
UserModel(**user)
|
|
||||||
if (user["deleted"] and not deleted_ok) or (
|
|
||||||
user["restricted"] and not restricted_ok
|
|
||||||
):
|
|
||||||
return
|
|
||||||
return user
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(
|
|
||||||
email: str,
|
|
||||||
include_secrets: bool = False,
|
|
||||||
restricted_ok: bool = False,
|
|
||||||
deleted_ok: bool = False,
|
|
||||||
) -> dict | None:
|
|
||||||
"""
|
|
||||||
Retrieves a user by its email address (meant to
|
|
||||||
be used internally)
|
|
||||||
"""
|
|
||||||
|
|
||||||
user = (
|
|
||||||
await User.select(
|
|
||||||
*User.all_columns(exclude=["public_id"]),
|
|
||||||
User.public_id.as_alias("id"),
|
|
||||||
exclude_secrets=not include_secrets,
|
|
||||||
)
|
|
||||||
.where(User.email_address == email)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if user:
|
|
||||||
# Performs validation
|
|
||||||
UserModel(**user)
|
|
||||||
if (user["deleted"] and not deleted_ok) or (
|
|
||||||
user["restricted"] and not restricted_ok
|
|
||||||
):
|
|
||||||
return
|
|
||||||
return user
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/user", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity}})
|
||||||
@LIMITER.limit("5/minute")
|
@LIMITER.limit("5/minute")
|
||||||
@router.post("/user")
|
async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()):
|
||||||
async def login(
|
"""
|
||||||
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
|
Performs user authentication. Endpoint is limited to 5 hits per minute
|
||||||
) -> dict:
|
"""
|
||||||
|
|
||||||
if request.cookies.get(SESSION_COOKIE_NAME):
|
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||||
raise HTTPException(status_code=400, detail="Please logout first")
|
raise HTTPException(status_code=400, detail="Please logout first")
|
||||||
username = data.username
|
username = data.username
|
||||||
|
@ -192,21 +120,21 @@ async def login(
|
||||||
detail="Authentication failed: invalid characters in password",
|
detail="Authentication failed: invalid characters in password",
|
||||||
)
|
)
|
||||||
if not (
|
if not (
|
||||||
user := await get_user_by_username(
|
user := await get_user_by_username(
|
||||||
username, include_secrets=True, restricted_ok=True
|
username, include_secrets=True, restricted_ok=True
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=413,
|
status_code=413,
|
||||||
detail="Authentication failed: the user does not exist",
|
detail="Authentication failed: the user does not exist",
|
||||||
)
|
)
|
||||||
if not bcrypt.checkpw(password, user["password_hash"]):
|
if not bcrypt.checkpw(password, user.password_hash):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=413,
|
status_code=413,
|
||||||
detail="Authentication failed: password mismatch",
|
detail="Authentication failed: password mismatch",
|
||||||
)
|
)
|
||||||
token = MANAGER.create_access_token(
|
token = MANAGER.create_access_token(
|
||||||
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])}
|
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user.public_id)}
|
||||||
)
|
)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
secure=SECURE_COOKIE,
|
secure=SECURE_COOKIE,
|
||||||
|
@ -218,17 +146,17 @@ async def login(
|
||||||
domain=COOKIE_DOMAIN or None,
|
domain=COOKIE_DOMAIN or None,
|
||||||
path=COOKIE_PATH or "/",
|
path=COOKIE_PATH or "/",
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Authentication successful"}
|
return APIResponse(status_code=200, msg="Authentication successful")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/logout")
|
@router.get("/user/logout", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
422: {"model": UnprocessableEntity}})
|
||||||
@LIMITER.limit("5/minute")
|
@LIMITER.limit("5/minute")
|
||||||
async def logout(
|
async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||||
request: Request, response: Response, _user: dict = Depends(UNVERIFIED_MANAGER)
|
|
||||||
) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Deletes a user's session cookie, logging them
|
Deletes a user's session cookie, logging them
|
||||||
out
|
out. Endpoint is limited to 5 hits per minute
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
|
@ -239,70 +167,81 @@ async def logout(
|
||||||
domain=COOKIE_DOMAIN or None,
|
domain=COOKIE_DOMAIN or None,
|
||||||
path=COOKIE_PATH or "/",
|
path=COOKIE_PATH or "/",
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Logged out"}
|
return APIResponse(status_code=200, msg="Logged out")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/me")
|
@router.get(
|
||||||
|
"/user/me",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||||
|
422: {"model": UnprocessableEntity}},
|
||||||
|
)
|
||||||
@LIMITER.limit("2/second")
|
@LIMITER.limit("2/second")
|
||||||
async def get_self(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict:
|
async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||||
"""
|
"""
|
||||||
Fetches a user's own info. This returns some
|
Fetches a user's own info. This returns some
|
||||||
extra data such as email address, account
|
extra data such as email address, account
|
||||||
creation date and email verification status,
|
creation date and email verification status,
|
||||||
which is not available from the regular endpoint
|
which is not available from the regular endpoint.
|
||||||
|
Endpoint is limited to 2 hits per second
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user.pop("password_hash")
|
return PrivateUserResponse(status_code=200, msg="Success", data=user)
|
||||||
user.pop("internal_id")
|
|
||||||
user.pop("deleted")
|
|
||||||
return {"status_code": 200, "msg": "Success", "data": user}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/username/{username}")
|
@router.get(
|
||||||
|
"/user/username/{username}",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
422: {"model": UnprocessableEntity}
|
||||||
|
},
|
||||||
|
)
|
||||||
@LIMITER.limit("30/second")
|
@LIMITER.limit("30/second")
|
||||||
async def get_user_by_name(
|
async def get_user_by_name(
|
||||||
request: Request, username: str, _auth: dict = Depends(MANAGER)
|
request: Request, username: str, _auth: UserModel = Depends(MANAGER)
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Fetches a single user by its public ID
|
Fetches a single user by its public username
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (user := await get_user_by_username(username)):
|
if not (user := await get_user_by_username(username)):
|
||||||
return {
|
return NotFound(msg="Lookup failed: the user does not exist")
|
||||||
"status_code": 404,
|
return PublicUserResponse(status_code=200, msg="Lookup successful", data=user)
|
||||||
"msg": "Lookup failed: the user does not exist",
|
|
||||||
}
|
|
||||||
user.pop("restricted")
|
|
||||||
user.pop("deleted")
|
|
||||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/id/{public_id}")
|
@router.get(
|
||||||
|
"/user/id/{public_id}",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
422: {"model": UnprocessableEntity}
|
||||||
|
},
|
||||||
|
)
|
||||||
@LIMITER.limit("30/second")
|
@LIMITER.limit("30/second")
|
||||||
async def get_user_by_public_id(
|
async def get_user_by_public_id(
|
||||||
request: Request, public_id: str, _auth: dict = Depends(MANAGER)
|
request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Fetches a single user by its public ID
|
Fetches a single user by its public ID
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (user := await get_user_by_id(UUID(public_id))):
|
if not (user := await get_user_by_id(UUID(public_id))):
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist")
|
||||||
status_code=404, detail="Lookup failed: the user does not exist"
|
return PublicUserResponse(data=user)
|
||||||
)
|
|
||||||
user.pop("restricted")
|
|
||||||
user.pop("deleted")
|
|
||||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_user(
|
async def validate_user(
|
||||||
first_name: str | None,
|
first_name: str | None,
|
||||||
last_name: str | None,
|
last_name: str | None,
|
||||||
username: str | None,
|
username: str | None,
|
||||||
email: str | None,
|
email: str | None,
|
||||||
password: str | None,
|
password: str | None,
|
||||||
bio: str | None,
|
bio: str | None,
|
||||||
) -> tuple[bool, str]:
|
):
|
||||||
"""
|
"""
|
||||||
Performs some validation upon user creation. Returns
|
Performs some validation upon user creation. Returns
|
||||||
a tuple (success, msg) to be used by routes. Values
|
a tuple (success, msg) to be used by routes. Values
|
||||||
|
@ -322,9 +261,9 @@ async def validate_user(
|
||||||
if username and len(username) > 32:
|
if username and len(username) > 32:
|
||||||
return False, "username is too long"
|
return False, "username is too long"
|
||||||
if (
|
if (
|
||||||
username
|
username
|
||||||
and VALIDATE_USERNAME_REGEX
|
and VALIDATE_USERNAME_REGEX
|
||||||
and not re.match(VALIDATE_USERNAME_REGEX, username)
|
and not re.match(VALIDATE_USERNAME_REGEX, username)
|
||||||
):
|
):
|
||||||
return False, "username is invalid"
|
return False, "username is invalid"
|
||||||
if email and not validators.email(email):
|
if email and not validators.email(email):
|
||||||
|
@ -332,13 +271,13 @@ async def validate_user(
|
||||||
if password and len(password) > 72:
|
if password and len(password) > 72:
|
||||||
return False, "password is too long"
|
return False, "password is too long"
|
||||||
if (
|
if (
|
||||||
password
|
password
|
||||||
and VALIDATE_PASSWORD_REGEX
|
and VALIDATE_PASSWORD_REGEX
|
||||||
and not re.match(VALIDATE_PASSWORD_REGEX, password)
|
and not re.match(VALIDATE_PASSWORD_REGEX, password)
|
||||||
):
|
):
|
||||||
return False, "password is too weak"
|
return False, "password is too weak"
|
||||||
if username and await get_user_by_username(
|
if username and await get_user_by_username(
|
||||||
username, deleted_ok=True, restricted_ok=True
|
username, deleted_ok=True, restricted_ok=True
|
||||||
):
|
):
|
||||||
return False, "username is already taken"
|
return False, "username is already taken"
|
||||||
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
||||||
|
@ -353,18 +292,20 @@ async def validate_user(
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/user")
|
@router.delete("/user", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
422: {"model": UnprocessableEntity}})
|
||||||
@LIMITER.limit("1/minute")
|
@LIMITER.limit("1/minute")
|
||||||
async def delete(
|
async def delete(
|
||||||
request: Request, response: Response, user: dict = Depends(UNVERIFIED_MANAGER)
|
request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Sets the user's deleted flag in the database,
|
Sets the user's deleted flag in the database,
|
||||||
without actually deleting the associated
|
without actually deleting the associated
|
||||||
data
|
data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await User.update({User.deleted: True}).where(User.public_id == user["id"])
|
await User.update({User.deleted: True}).where(User.public_id == user.public_id)
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
secure=SECURE_COOKIE,
|
secure=SECURE_COOKIE,
|
||||||
key=SESSION_COOKIE_NAME,
|
key=SESSION_COOKIE_NAME,
|
||||||
|
@ -373,22 +314,28 @@ async def delete(
|
||||||
domain=COOKIE_DOMAIN or None,
|
domain=COOKIE_DOMAIN or None,
|
||||||
path=COOKIE_PATH or "/",
|
path=COOKIE_PATH or "/",
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Success"}
|
return APIResponse(status_code=200, msg="Success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/verifyEmail/{verification_id}")
|
@router.get("/user/verifyEmail/{verification_id}", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
422: {"model": UnprocessableEntity}
|
||||||
|
})
|
||||||
@LIMITER.limit("3/second")
|
@LIMITER.limit("3/second")
|
||||||
async def verify_email(
|
async def verify_email(
|
||||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
request: Request,
|
||||||
) -> dict:
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Verifies a user's email address
|
Verifies a user's email address
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
.where(EmailVerification.id == verification_id)
|
.where(EmailVerification.id == verification_id)
|
||||||
.first()
|
.first()
|
||||||
):
|
):
|
||||||
raise HTTPException(status_code=404, detail="Verification ID is invalid")
|
raise HTTPException(status_code=404, detail="Verification ID is invalid")
|
||||||
elif not verification["pending"]:
|
elif not verification["pending"]:
|
||||||
|
@ -402,27 +349,31 @@ async def verify_email(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
EmailVerification.user == user["id"]
|
EmailVerification.user == user.public_id
|
||||||
)
|
)
|
||||||
await User.update({User.email_verified: True}).where(
|
await User.update({User.email_verified: True}).where(User.public_id == user.public_id)
|
||||||
User.public_id == user["id"]
|
return APIResponse(status_code=200, msg="Verification successful")
|
||||||
)
|
|
||||||
return {"status_code": 200, "msg": "Verification successful"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/resetPassword/{verification_id}")
|
@router.get("/user/resetPassword/{verification_id}", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
404: {"model": NotFound}})
|
||||||
@LIMITER.limit("3/second")
|
@LIMITER.limit("3/second")
|
||||||
async def reset_password(
|
async def reset_password(
|
||||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
request: Request,
|
||||||
) -> dict:
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Modifies a user's password
|
Modifies a user's password
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
.where(EmailVerification.id == verification_id)
|
.where(EmailVerification.id == verification_id)
|
||||||
.first()
|
.first()
|
||||||
):
|
):
|
||||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||||
elif not verification["pending"]:
|
elif not verification["pending"]:
|
||||||
|
@ -436,27 +387,33 @@ async def reset_password(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
EmailVerification.user == user["id"]
|
EmailVerification.user == user.public_id
|
||||||
)
|
)
|
||||||
await User.update({User.password_hash: verification["data"]}).where(
|
await User.update({User.password_hash: verification["data"]}).where(
|
||||||
User.public_id == user["id"]
|
User.public_id == user.public_id
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Password updated"}
|
return APIResponse(status_code=200, msg="Password updated")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user/changeEmail/{verification_id}")
|
@router.get("/user/changeEmail/{verification_id}", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
404: {"model": NotFound}})
|
||||||
@LIMITER.limit("3/second")
|
@LIMITER.limit("3/second")
|
||||||
async def change_email(
|
async def change_email(
|
||||||
request: Request, verification_id: str, user: dict = Depends(UNVERIFIED_MANAGER)
|
request: Request,
|
||||||
) -> dict:
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Modifies a user's email
|
Modifies a user's email
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
.where(EmailVerification.id == verification_id)
|
.where(EmailVerification.id == verification_id)
|
||||||
.first()
|
.first()
|
||||||
):
|
):
|
||||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||||
elif not verification["pending"]:
|
elif not verification["pending"]:
|
||||||
|
@ -470,18 +427,24 @@ async def change_email(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
EmailVerification.user == user["id"]
|
EmailVerification.user == user.public_id
|
||||||
)
|
)
|
||||||
await User.update({User.email_address: verification["data"].decode(),
|
await User.update(
|
||||||
User.email_verified: False}).where(
|
{
|
||||||
User.public_id == user["id"]
|
User.email_address: verification["data"].decode(),
|
||||||
)
|
User.email_verified: False,
|
||||||
return {"status_code": 200, "msg": "Email updated"}
|
}
|
||||||
|
).where(User.public_id == user.public_id)
|
||||||
|
return APIResponse(status_code=200, msg="Email updated")
|
||||||
|
|
||||||
|
|
||||||
@router.put("user/resendMail")
|
@router.put("user/resendMail", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
500: {"model": InternalServerError}})
|
||||||
@LIMITER.limit("6/minute")
|
@LIMITER.limit("6/minute")
|
||||||
async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER)) -> dict:
|
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||||
"""
|
"""
|
||||||
Resends the verification email to the user if the previous has expired
|
Resends the verification email to the user if the previous has expired
|
||||||
"""
|
"""
|
||||||
|
@ -523,26 +486,34 @@ async def resend_email(request: Request, user: dict = Depends(UNVERIFIED_MANAGER
|
||||||
use_tls=SMTP_USE_TLS,
|
use_tls=SMTP_USE_TLS,
|
||||||
):
|
):
|
||||||
await EmailVerification.update(
|
await EmailVerification.update(
|
||||||
{
|
{
|
||||||
EmailVerification.id: verification_id,
|
EmailVerification.id: verification_id,
|
||||||
EmailVerification.creation_date: datetime.now()
|
EmailVerification.creation_date: datetime.now(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Success"}
|
return APIResponse(status_code=200, msg="Success")
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="An error occurred while trying to resend the email,"
|
||||||
|
" please try again later")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/user")
|
@router.put("/user", tags=["Users"], status_code=200,
|
||||||
|
responses={200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
500: {"model": InternalServerError},
|
||||||
|
413: {"model": PayloadTooLarge}})
|
||||||
@LIMITER.limit("2/minute")
|
@LIMITER.limit("2/minute")
|
||||||
async def signup(
|
async def signup(
|
||||||
request: Request,
|
request: Request,
|
||||||
first_name: str,
|
first_name: str,
|
||||||
last_name: str,
|
last_name: str,
|
||||||
username: str,
|
username: str,
|
||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
bio: str | None = None,
|
bio: str | None = None,
|
||||||
locale: str = "en_US",
|
locale: str = "en_US",
|
||||||
) -> dict:
|
):
|
||||||
"""
|
"""
|
||||||
Endpoint used to create new users
|
Endpoint used to create new users
|
||||||
"""
|
"""
|
||||||
|
@ -555,7 +526,7 @@ async def signup(
|
||||||
first_name, last_name, username, email, password, bio
|
first_name, last_name, username, email, password, bio
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
return {"status_code": 413, "msg": f"Signup failed: {msg}"}
|
return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
|
||||||
else:
|
else:
|
||||||
salt = bcrypt.gensalt(BCRYPT_ROUNDS)
|
salt = bcrypt.gensalt(BCRYPT_ROUNDS)
|
||||||
user = User(
|
user = User(
|
||||||
|
@ -563,9 +534,7 @@ async def signup(
|
||||||
last_name=last_name,
|
last_name=last_name,
|
||||||
username=username,
|
username=username,
|
||||||
email_address=email,
|
email_address=email,
|
||||||
password_hash=bcrypt.hashpw(
|
password_hash=bcrypt.hashpw(password.encode(), salt),
|
||||||
password.encode(), salt
|
|
||||||
),
|
|
||||||
bio=bio,
|
bio=bio,
|
||||||
)
|
)
|
||||||
email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json"
|
email_template = SMTP_TEMPLATES_DIRECTORY / f"{locale}.json"
|
||||||
|
@ -577,30 +546,30 @@ async def signup(
|
||||||
email_message = json.load(f)["signup"]
|
email_message = json.load(f)["signup"]
|
||||||
verification_id = uuid.uuid4()
|
verification_id = uuid.uuid4()
|
||||||
if await send_email(
|
if await send_email(
|
||||||
SMTP_HOST,
|
SMTP_HOST,
|
||||||
SMTP_PORT,
|
SMTP_PORT,
|
||||||
email_message["content"].format(
|
email_message["content"].format(
|
||||||
first_name=first_name,
|
first_name=first_name,
|
||||||
last_name=last_name,
|
last_name=last_name,
|
||||||
username=username,
|
username=username,
|
||||||
email=email,
|
email=email,
|
||||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
),
|
),
|
||||||
SMTP_TIMEOUT,
|
SMTP_TIMEOUT,
|
||||||
SMTP_FROM_USER,
|
SMTP_FROM_USER,
|
||||||
email,
|
email,
|
||||||
email_message["subject"].format(
|
email_message["subject"].format(
|
||||||
first_name=first_name,
|
first_name=first_name,
|
||||||
last_name=last_name,
|
last_name=last_name,
|
||||||
username=username,
|
username=username,
|
||||||
email=email,
|
email=email,
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
),
|
),
|
||||||
SMTP_USER,
|
SMTP_USER,
|
||||||
SMTP_PASSWORD,
|
SMTP_PASSWORD,
|
||||||
use_tls=SMTP_USE_TLS,
|
use_tls=SMTP_USE_TLS,
|
||||||
):
|
):
|
||||||
await User.insert(user)
|
await User.insert(user)
|
||||||
await EmailVerification.insert(
|
await EmailVerification.insert(
|
||||||
|
@ -611,17 +580,17 @@ async def signup(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Success"}
|
return APIResponse(status_code=200, msg="Success")
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail="An error occurred while sending verification email, please"
|
detail="An error occurred while sending verification email, please"
|
||||||
" try again later",
|
" try again later",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validate_profile_picture(
|
async def validate_profile_picture(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
) -> tuple[bool | None, str, bytes, str]:
|
) -> tuple[bool | None, str, bytes, str]:
|
||||||
"""
|
"""
|
||||||
Validates a profile picture's size and content to see if it fits
|
Validates a profile picture's size and content to see if it fits
|
||||||
|
@ -650,18 +619,23 @@ async def validate_profile_picture(
|
||||||
return None, "", b"", ""
|
return None, "", b"", ""
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/user")
|
@router.patch("/user", tags=["Users"], status_code=200,
|
||||||
async def update(
|
responses={200: {"model": APIResponse},
|
||||||
request: Request,
|
400: {"model": BadRequest},
|
||||||
user: dict = Depends(UNVERIFIED_MANAGER),
|
422: {"model": UnprocessableEntity},
|
||||||
first_name: str | None = None,
|
500: {"model": InternalServerError},
|
||||||
last_name: str | None = None,
|
413: {"model": PayloadTooLarge}})
|
||||||
username: str | None = None,
|
async def update_user(
|
||||||
profile_picture: UploadFile | None = None,
|
request: Request,
|
||||||
email_address: str | None = None,
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
password: str | None = None,
|
first_name: str | None = None,
|
||||||
bio: str | None = None,
|
last_name: str | None = None,
|
||||||
delete: bool = False,
|
username: str | None = None,
|
||||||
|
profile_picture: UploadFile | None = None,
|
||||||
|
email_address: str | None = None,
|
||||||
|
password: str | None = None,
|
||||||
|
bio: str | None = None,
|
||||||
|
delete: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Updates a user's profile information. Parameters that are not specified are left unchanged unless
|
Updates a user's profile information. Parameters that are not specified are left unchanged unless
|
||||||
|
@ -676,28 +650,26 @@ async def update(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not delete and not any(
|
if not delete and not any(
|
||||||
(first_name, last_name, username, profile_picture, email_address, bio, password)
|
(first_name, last_name, username, profile_picture, email_address, bio, password)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="At least one value has to be specified"
|
status_code=400, detail="At least one value has to be specified"
|
||||||
)
|
)
|
||||||
result, msg = (
|
result, msg = await validate_user(
|
||||||
await validate_user(
|
first_name, last_name, username, email_address, password, bio
|
||||||
first_name, last_name, username, email_address, password, bio
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
|
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
|
||||||
orig_user = user.copy()
|
orig_user = user.copy()
|
||||||
if not delete:
|
if not delete:
|
||||||
if first_name:
|
if first_name:
|
||||||
user["first_name"] = first_name
|
user.first_name = first_name
|
||||||
if last_name:
|
if last_name:
|
||||||
user["last_name"] = last_name
|
user.last_name = last_name
|
||||||
if username:
|
if username:
|
||||||
user["username"] = username
|
user.username = username
|
||||||
if bio:
|
if bio:
|
||||||
user["bio"] = bio
|
user.bio = bio
|
||||||
if profile_picture:
|
if profile_picture:
|
||||||
result, ext, media, digest = validate_profile_picture(profile_picture)
|
result, ext, media, digest = validate_profile_picture(profile_picture)
|
||||||
if result is False:
|
if result is False:
|
||||||
|
@ -707,44 +679,44 @@ async def update(
|
||||||
elif result is None:
|
elif result is None:
|
||||||
raise HTTPException(status_code=413, detail="The file is too large")
|
raise HTTPException(status_code=413, detail="The file is too large")
|
||||||
elif (
|
elif (
|
||||||
await (
|
await (
|
||||||
old_media := Media.select(Media.media_id)
|
old_media := Media.select(Media.media_id)
|
||||||
.where(Media.media_id == digest)
|
.where(Media.media_id == digest)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
is None
|
is None
|
||||||
):
|
):
|
||||||
# This media hasn't been already uploaded (either by this user or by someone
|
# This media hasn't been already uploaded (either by this user or by someone
|
||||||
# else), so we save it now. If it has been already uploaded, there's no need
|
# else), so we save it now. If it has been already uploaded, there's no need
|
||||||
# to do it again (that's what the hash is for)
|
# to do it again (that's what the hash is for)
|
||||||
match STORAGE_ENGINE:
|
match STORAGE_ENGINE:
|
||||||
case "database":
|
case "database":
|
||||||
await Media.insert(
|
m = Media(
|
||||||
Media(
|
media_id=digest,
|
||||||
media_id=digest,
|
media_type=MediaType.BLOB,
|
||||||
media_type=MediaType.BLOB,
|
content_type=ext,
|
||||||
content_type=ext,
|
content=base64.b64encode(media),
|
||||||
content=base64.b64encode(media),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await Media.insert(m)
|
||||||
|
user.profile_picture = m
|
||||||
case "local":
|
case "local":
|
||||||
file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest)
|
file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest)
|
||||||
file.touch(mode=0o644)
|
file.touch(mode=0o644)
|
||||||
with file.open("wb") as f:
|
with file.open("wb") as f:
|
||||||
f.write(media)
|
f.write(media)
|
||||||
await Media.insert(
|
m = Media(
|
||||||
Media(
|
media_id=digest,
|
||||||
media_id=digest,
|
media_type=MediaType.FILE,
|
||||||
media_type=MediaType.FILE,
|
content_type=ext,
|
||||||
content_type=ext,
|
content=file.as_posix(),
|
||||||
content=file.as_posix(),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await Media.insert(m)
|
||||||
|
user.profile_picture = m
|
||||||
case "url":
|
case "url":
|
||||||
pass # TODO: Use/implement CDN uploading
|
pass # TODO: Use/implement CDN uploading
|
||||||
else:
|
else:
|
||||||
user["media"] = old_media
|
user.media = old_media
|
||||||
if password and not bcrypt.checkpw(password.encode(), user["password_hash"]):
|
if password and not bcrypt.checkpw(password.encode(), user.password_hash):
|
||||||
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json"
|
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user['locale']}.json"
|
||||||
try:
|
try:
|
||||||
email_template.resolve(strict=True)
|
email_template.resolve(strict=True)
|
||||||
|
@ -754,43 +726,48 @@ async def update(
|
||||||
email_message = json.load(f)["password_change"]
|
email_message = json.load(f)["password_change"]
|
||||||
verification_id = uuid.uuid4()
|
verification_id = uuid.uuid4()
|
||||||
if not await send_email(
|
if not await send_email(
|
||||||
SMTP_HOST,
|
SMTP_HOST,
|
||||||
SMTP_PORT,
|
SMTP_PORT,
|
||||||
email_message["content"].format(
|
email_message["content"].format(
|
||||||
first_name=user["first_name"],
|
first_name=user.first_name,
|
||||||
last_name=user["last_name"],
|
last_name=user.last_name,
|
||||||
username=user["username"],
|
username=user.username,
|
||||||
email=user["email_address"],
|
email=user.email_address,
|
||||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}",
|
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/resetPassword/{verification_id}",
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
),
|
),
|
||||||
SMTP_TIMEOUT,
|
SMTP_TIMEOUT,
|
||||||
SMTP_FROM_USER,
|
SMTP_FROM_USER,
|
||||||
user["email_address"],
|
user.email_address,
|
||||||
email_message["subject"].format(
|
email_message["subject"].format(
|
||||||
first_name=user["first_name"],
|
first_name=user.first_name,
|
||||||
last_name=user["last_name"],
|
last_name=user.last_name,
|
||||||
username=user["username"],
|
username=user.username,
|
||||||
email=user["email_address"],
|
email=user.email_address,
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
),
|
),
|
||||||
SMTP_USER,
|
SMTP_USER,
|
||||||
SMTP_PASSWORD,
|
SMTP_PASSWORD,
|
||||||
use_tls=SMTP_USE_TLS,
|
use_tls=SMTP_USE_TLS,
|
||||||
):
|
):
|
||||||
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later")
|
raise HTTPException(
|
||||||
|
500,
|
||||||
|
detail="An error occurred while trying to send mail, please try again later",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await EmailVerification.insert(
|
await EmailVerification.insert(
|
||||||
EmailVerification(
|
EmailVerification(
|
||||||
{
|
{
|
||||||
EmailVerification.id: verification_id,
|
EmailVerification.id: verification_id,
|
||||||
EmailVerification.user: User(public_id=user["id"]),
|
EmailVerification.user: User(public_id=user.public_id),
|
||||||
EmailVerification.data: bcrypt.hashpw(password.encode(), user["password_hash"][:29])
|
EmailVerification.data: bcrypt.hashpw(
|
||||||
|
password.encode(), user.password_hash[:29]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if email_address and user["email_address"] != email_address:
|
if email_address and user.email_address != email_address:
|
||||||
if not user["email_verified"]:
|
if not user["email_verified"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
|
@ -805,51 +782,61 @@ async def update(
|
||||||
email_message = json.load(f)["email_change"]
|
email_message = json.load(f)["email_change"]
|
||||||
verification_id = uuid.uuid4()
|
verification_id = uuid.uuid4()
|
||||||
if not await send_email(
|
if not await send_email(
|
||||||
SMTP_HOST,
|
SMTP_HOST,
|
||||||
SMTP_PORT,
|
SMTP_PORT,
|
||||||
email_message["content"].format(
|
email_message["content"].format(
|
||||||
first_name=user["first_name"],
|
first_name=user.first_name,
|
||||||
last_name=user["last_name"],
|
last_name=user.last_name,
|
||||||
username=user["username"],
|
username=user.username,
|
||||||
email=user["email_address"],
|
email=user.email_address,
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}",
|
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/changeEmail/{verification_id}",
|
||||||
newMail=email_address,
|
newMail=email_address,
|
||||||
),
|
),
|
||||||
SMTP_TIMEOUT,
|
SMTP_TIMEOUT,
|
||||||
SMTP_FROM_USER,
|
SMTP_FROM_USER,
|
||||||
user["email_address"],
|
user.email_address,
|
||||||
email_message["subject"].format(
|
email_message["subject"].format(
|
||||||
first_name=user["first_name"],
|
first_name=user.first_name,
|
||||||
last_name=user["last_name"],
|
last_name=user.last_name,
|
||||||
username=user["username"],
|
username=user.username,
|
||||||
email=user["email_address"],
|
email=user.email_address,
|
||||||
platformName=PLATFORM_NAME,
|
platformName=PLATFORM_NAME,
|
||||||
),
|
),
|
||||||
SMTP_USER,
|
SMTP_USER,
|
||||||
SMTP_PASSWORD,
|
SMTP_PASSWORD,
|
||||||
use_tls=SMTP_USE_TLS,
|
use_tls=SMTP_USE_TLS,
|
||||||
):
|
):
|
||||||
raise HTTPException(500, detail="An error occurred while trying to send mail, please try again later")
|
raise HTTPException(
|
||||||
|
500,
|
||||||
|
detail="An error occurred while trying to send mail, please try again later",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await EmailVerification.insert(EmailVerification({
|
await EmailVerification.insert(
|
||||||
EmailVerification.id: verification_id,
|
EmailVerification(
|
||||||
EmailVerification.user: User(public_id=user["id"]),
|
{
|
||||||
EmailVerification.data: email_address.encode()
|
EmailVerification.id: verification_id,
|
||||||
}))
|
EmailVerification.user: User(public_id=user.public_id),
|
||||||
|
EmailVerification.data: email_address.encode(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if not bio:
|
if not bio:
|
||||||
user["bio"] = None
|
user.bio = None
|
||||||
if not profile_picture:
|
if not profile_picture:
|
||||||
user["profile_picture"] = None
|
user["profile_picture"] = None
|
||||||
fields = []
|
fields = []
|
||||||
for field in user:
|
for field in user:
|
||||||
if field not in ["email_address", "password"] and orig_user[field] != user[field]:
|
if (
|
||||||
|
field not in ["email_address", "password"]
|
||||||
|
and orig_user[field] != user[field]
|
||||||
|
):
|
||||||
fields.append((field, user[field]))
|
fields.append((field, user[field]))
|
||||||
if fields:
|
if fields:
|
||||||
# If anything has changed, we update our info
|
# If anything has changed, we update our info
|
||||||
await User.update({field: value for field, value in fields}).where(
|
await User.update({field: value for field, value in fields}).where(
|
||||||
User.public_id == user["id"]
|
User.public_id == user.public_id
|
||||||
)
|
)
|
||||||
return {"status_code": 200, "msg": "Changes saved successfully"}
|
return APIResponse(status_code=200, msg="Changes saved successfully")
|
||||||
|
|
43
main.py
43
main.py
|
@ -9,6 +9,7 @@ from fastapi.exceptions import (
|
||||||
StarletteHTTPException,
|
StarletteHTTPException,
|
||||||
)
|
)
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
from endpoints import users
|
from endpoints import users
|
||||||
|
@ -25,6 +26,7 @@ from config import (
|
||||||
HOST,
|
HOST,
|
||||||
PORT,
|
PORT,
|
||||||
WORKERS,
|
WORKERS,
|
||||||
|
PLATFORM_NAME,
|
||||||
)
|
)
|
||||||
from orm import create_tables, Media
|
from orm import create_tables, Media
|
||||||
from util.exception_handlers import (
|
from util.exception_handlers import (
|
||||||
|
@ -36,24 +38,55 @@ from util.exception_handlers import (
|
||||||
)
|
)
|
||||||
from util.email import test_smtp
|
from util.email import test_smtp
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
with (Path(__file__).parent / "README.md").resolve(strict=True).open() as f:
|
||||||
|
description = f.read()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
app = FastAPI(
|
||||||
|
title=PLATFORM_NAME,
|
||||||
|
license_info={"name": "MIT", "url": "https://opensource.org/licenses/MIT"},
|
||||||
|
contact={
|
||||||
|
"name": "Matt",
|
||||||
|
"url": "https://nocturn9x.space",
|
||||||
|
"email": "nocturn9x@nocturn9x.space",
|
||||||
|
},
|
||||||
|
version="0.0.1",
|
||||||
|
description=description,
|
||||||
|
openapi_tags=[
|
||||||
|
{
|
||||||
|
"name": "Miscellaneous",
|
||||||
|
"description": "Simple endpoints that don't fit anywhere else",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Users",
|
||||||
|
"description": "Endpoints that handle user-related operations such as"
|
||||||
|
" signing in, signing up, settings management and more",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Posts",
|
||||||
|
"description": "Endpoints that handle post creation, modification and deletion",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", tags=["Miscellaneous"])
|
||||||
@LIMITER.limit("10/second")
|
@LIMITER.limit("10/second")
|
||||||
async def root(request: Request):
|
async def root(request: Request):
|
||||||
raise HTTPException(401, detail="Unauthorized")
|
raise HTTPException(401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/ping")
|
@app.get("/ping", tags=["Miscellaneous"])
|
||||||
@LIMITER.limit("1/minute")
|
@LIMITER.limit("1/minute")
|
||||||
async def ping(request: Request) -> dict:
|
async def ping(request: Request) -> dict:
|
||||||
"""
|
"""
|
||||||
This handler simply replies to "ping" requests and
|
This method simply replies to "ping" requests and
|
||||||
is used to check whether the API is up and running.
|
is used to check whether the API is up and running.
|
||||||
It also performs a sanity check with the database and
|
It also performs a sanity check with the database and
|
||||||
the SMTP server to ensure that they are functioning correctly.
|
the SMTP server to ensure that they are functioning correctly.
|
||||||
For this reason, this endpoint's rate limit is much stricter
|
For this reason, this endpoint's rate limit is very strict:
|
||||||
|
it can only be called once per minute
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOGGER.info(f"Processing ping request from {request.client.host}")
|
LOGGER.info(f"Processing ping request from {request.client.host}")
|
||||||
|
|
82
orm/users.py
82
orm/users.py
|
@ -2,6 +2,7 @@
|
||||||
User relation
|
User relation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
from piccolo.utils.pydantic import create_pydantic_model
|
from piccolo.utils.pydantic import create_pydantic_model
|
||||||
from piccolo.table import Table
|
from piccolo.table import Table
|
||||||
from piccolo.columns import (
|
from piccolo.columns import (
|
||||||
|
@ -15,6 +16,7 @@ from piccolo.columns import (
|
||||||
Boolean,
|
Boolean,
|
||||||
Email,
|
Email,
|
||||||
Bytea,
|
Bytea,
|
||||||
|
Column,
|
||||||
)
|
)
|
||||||
from piccolo.columns.defaults.date import DateNow
|
from piccolo.columns.defaults.date import DateNow
|
||||||
|
|
||||||
|
@ -45,4 +47,82 @@ class User(Table, tablename="users"):
|
||||||
locale = Varchar(length=12, default="en_US", null=False, secret=True)
|
locale = Varchar(length=12, default="en_US", null=False, secret=True)
|
||||||
|
|
||||||
|
|
||||||
UserModel = create_pydantic_model(User)
|
UserModel = create_pydantic_model(
|
||||||
|
User,
|
||||||
|
nested=True,
|
||||||
|
)
|
||||||
|
PublicUserModel = create_pydantic_model(
|
||||||
|
User,
|
||||||
|
nested=True,
|
||||||
|
exclude_columns=(
|
||||||
|
User.internal_id,
|
||||||
|
User.deleted,
|
||||||
|
User.restricted,
|
||||||
|
User.locale,
|
||||||
|
User.email_verified,
|
||||||
|
User.email_address,
|
||||||
|
User.password_hash,
|
||||||
|
),
|
||||||
|
model_name="PublicUser",
|
||||||
|
)
|
||||||
|
PrivateUserModel = create_pydantic_model(
|
||||||
|
User,
|
||||||
|
nested=True,
|
||||||
|
exclude_columns=(User.internal_id, User.password_hash, User.deleted),
|
||||||
|
model_name="PrivateUser",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_column(
|
||||||
|
column: Column,
|
||||||
|
data: Any,
|
||||||
|
include_secrets: bool = False,
|
||||||
|
restricted_ok: bool = False,
|
||||||
|
deleted_ok: bool = False,
|
||||||
|
) -> UserModel | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user object by a given criteria.
|
||||||
|
Returns None if the user doesn't exist or
|
||||||
|
if it's restricted/deleted (unless restricted_ok
|
||||||
|
and deleted_ok are set accordingly)
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await User.select(
|
||||||
|
*User.all_columns(),
|
||||||
|
exclude_secrets=not include_secrets,
|
||||||
|
)
|
||||||
|
.where(column == data)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
# Performs validation
|
||||||
|
user = UserModel(**user)
|
||||||
|
if (user.deleted and not deleted_ok) or (user.restricted and not restricted_ok):
|
||||||
|
return
|
||||||
|
return user
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_id(public_id: UUID, *args, **kwargs) -> UserModel:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its public ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await get_user_by_column(User.public_id, public_id, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_username(username: str, *args, **kwargs) -> UserModel | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its public username
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await get_user_by_column(User.username, username, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_email(email: str, *args, **kwargs) -> UserModel | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its email address
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await get_user_by_column(User.email_address, email, *args, **kwargs)
|
||||||
|
|
|
@ -7,4 +7,5 @@ uvloop~=0.17.0
|
||||||
fastapi~=0.85.0
|
fastapi~=0.85.0
|
||||||
validators
|
validators
|
||||||
aiosmtplib
|
aiosmtplib
|
||||||
uvicorn
|
uvicorn
|
||||||
|
pydantic
|
|
@ -0,0 +1,96 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from orm.users import User
|
||||||
|
|
||||||
|
|
||||||
|
class Response(BaseModel):
|
||||||
|
"""
|
||||||
|
A generic response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 200
|
||||||
|
msg: str = "Success"
|
||||||
|
|
||||||
|
|
||||||
|
class UnprocessableEntity(Response):
|
||||||
|
"""
|
||||||
|
A 422 Unprocessable Entity response
|
||||||
|
model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 422
|
||||||
|
msg: str = "Input Validation Failure"
|
||||||
|
|
||||||
|
|
||||||
|
class NotFound(Response):
|
||||||
|
"""
|
||||||
|
A 404 Not Found response
|
||||||
|
model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 404
|
||||||
|
msg: str = "Not Found"
|
||||||
|
|
||||||
|
|
||||||
|
class PayloadTooLarge(Response):
|
||||||
|
"""
|
||||||
|
A 413 Payload Too Large response
|
||||||
|
model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 413
|
||||||
|
msg: str = "Payload too large"
|
||||||
|
|
||||||
|
|
||||||
|
class MediaTypeNotAcceptable(Response):
|
||||||
|
"""
|
||||||
|
A 415 Media Type Not Acceptable response
|
||||||
|
model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 415
|
||||||
|
msg: str = "Media type not acceptable"
|
||||||
|
|
||||||
|
|
||||||
|
class BadRequest(Response):
|
||||||
|
"""
|
||||||
|
A 400 Bad Request response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 400
|
||||||
|
msg: str = "Bad Request"
|
||||||
|
|
||||||
|
|
||||||
|
class InternalServerError(Response):
|
||||||
|
"""
|
||||||
|
A 500 Internal Server Error response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 500
|
||||||
|
msg: str = "Internal Server Error"
|
||||||
|
|
||||||
|
|
||||||
|
class TooManyRequests(Response):
|
||||||
|
"""
|
||||||
|
A 429 Too Many Requests response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 429
|
||||||
|
msg: str = "Too Many Requests"
|
||||||
|
|
||||||
|
|
||||||
|
class Unauthorized(Response):
|
||||||
|
"""
|
||||||
|
A 401 Unauthorized response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 401
|
||||||
|
msg: str = "Unauthorized"
|
||||||
|
|
||||||
|
|
||||||
|
class Forbidden(Response):
|
||||||
|
"""
|
||||||
|
A 403 Forbidden response model
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 403
|
||||||
|
msg: str = "Forbidden"
|
|
@ -0,0 +1,24 @@
|
||||||
|
from responses import Response
|
||||||
|
from orm.users import PublicUserModel, PrivateUserModel
|
||||||
|
|
||||||
|
|
||||||
|
class PublicUserResponse(Response):
|
||||||
|
"""
|
||||||
|
A response sent by an endpoint that
|
||||||
|
replies with a public user object
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 200
|
||||||
|
msg: str = "Lookup successful"
|
||||||
|
data: PublicUserModel
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateUserResponse(Response):
|
||||||
|
"""
|
||||||
|
A response sent by an endpoint that
|
||||||
|
replies with a private user object
|
||||||
|
"""
|
||||||
|
|
||||||
|
status_code: int = 200
|
||||||
|
data: PrivateUserModel
|
||||||
|
msg: str = "Lookup successful"
|
|
@ -7,6 +7,9 @@ from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
|
||||||
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
|
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Handles the equivalent of a 429 Too Many Requests error
|
||||||
|
"""
|
||||||
n = 0
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
if error.detail[n].isnumeric():
|
if error.detail[n].isnumeric():
|
||||||
|
@ -18,79 +21,58 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon
|
||||||
f"{request.client.host} got rate-limited at {str(request.url)} "
|
f"{request.client.host} got rate-limited at {str(request.url)} "
|
||||||
f"(exceeded {error.detail})"
|
f"(exceeded {error.detail})"
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(status_code=200, content=dict(status_code=429,
|
||||||
status_code=200,
|
msg=f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}"))
|
||||||
content={
|
|
||||||
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
|
|
||||||
"status_code": 429,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Handles the equivalent of a 401 Unauthorized exception
|
||||||
|
"""
|
||||||
|
|
||||||
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
||||||
return JSONResponse(
|
return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required"))
|
||||||
status_code=200,
|
|
||||||
content={
|
|
||||||
"msg": "Authentication is required",
|
|
||||||
"status_code": 401,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Handles Bad Request exceptions from FastAPI
|
||||||
|
"""
|
||||||
|
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"))
|
||||||
status_code=200,
|
|
||||||
content={
|
|
||||||
"msg": f"Bad request: {type(exc).__name__}: {exc}",
|
|
||||||
"status_code": 400,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def http_exception(
|
def http_exception(
|
||||||
request: Request, exc: HTTPException | StarletteHTTPException
|
request: Request, exc: HTTPException | StarletteHTTPException
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Handles HTTP-specific exceptions raised explicitly by
|
||||||
|
path operations
|
||||||
|
"""
|
||||||
|
|
||||||
if exc.status_code >= 500:
|
if exc.status_code >= 500:
|
||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
||||||
f"{type(exc).__name__}: {exc}"
|
f"{type(exc).__name__}: {exc}"
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
||||||
status_code=200,
|
|
||||||
content={
|
|
||||||
"msg": "Internal server error",
|
|
||||||
"status_code": exc.status_code,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail))
|
||||||
status_code=200,
|
|
||||||
content={
|
|
||||||
"msg": exc.detail,
|
|
||||||
"status_code": exc.status_code,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
Handles generic, unexpected errors in the application
|
Handles generic, unexpected errors in the ASGI application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
|
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
# We can't leak anything about the error, it would be too risky
|
||||||
status_code=200,
|
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
||||||
content={
|
|
||||||
"msg": "Internal Server Error", # We can't leak anything about the error, it would be too risky
|
|
||||||
"status_code": 500,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue