Refactoring

This commit is contained in:
Mattia Giambirtone 2023-03-13 15:59:03 +01:00
parent 6d0d822083
commit 1d1f2368fe
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
23 changed files with 839 additions and 460 deletions

143
src/config.py Normal file
View File

@ -0,0 +1,143 @@
# Configuration file
import os
import sys
import copy
import logging
import pathlib
from slowapi import Limiter
from slowapi.util import get_remote_address
from fastapi_login import LoginManager
# Authentication configuration
LOGIN_SECRET_KEY = (
os.getenv("LOGIN_SECRET_KEY") or "login-secret"
) # Recommended value: os.urandom(24).hex()
USE_BEARER_HEADER = True
USE_COOKIE = True
# Logging configuration
LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20
LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable
LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s"
LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p"
# Bcrypt configuration
BCRYPT_ROUNDS = (
os.getenv("BCRYPT_ROUNDS") or 10
) # How many rounds are used when salting
# Rate limit configuration
REDIS_URL = "" # Used for rate limits. Empty to fall back to memory
REDIS_OPTIONS = {} # Options for redis
RATELIMIT_ENABLED = False # False to disable rate limiting
RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html
# Session configuration
SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds
SESSION_COOKIE_NAME = "_social_media_session"
COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict"
COOKIE_DOMAIN = "localhost" # Empty to disable this
COOKIE_PATH = "/"
COOKIE_HTTPONLY = True
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
# SMTP configuration
SMTP_HOST = "mail.nocturn9x.space"
SMTP_USER = "nocturn9x@nocturn9x.space"
SMTP_PASSWORD = "Je2aXGVH33"
SMTP_PORT = 587
SMTP_USE_TLS = True
SMTP_FROM_USER = "nocturn9x@nocturn9x.space"
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email"
SMTP_TIMEOUT = 10
# Miscellaneous
# Usernames containing these characters are not valid
INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character
# Empty this to disable username validation
VALIDATE_USERNAME_REGEX = (
rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$"
)
# Criteria:
# - Between 10 and 72 characters long
# - At least 2 uppercase letters
# - At least 3 lowercase letters
# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^
# - At least 2 numbers
# You can change the repetitions to enforce stricter/laxer rules or empty
# this field to disable weakness validation
VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[\"!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$"
FORCE_EMAIL_VERIFICATION = False
EMAIL_VERIFICATION_EXPIRATION = 3600 # In seconds
PLATFORM_NAME = "PySimpleSocial" # Used in emails
HAS_HTTPS = False
class NotAuthenticated(Exception):
pass
class AdminNotAuthenticated(Exception):
pass
if __name__ != "__main__":
LOGGER: logging.Logger = logging.getLogger("socialMedia")
LOGGER.setLevel(LOG_LEVEL)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
handler.setLevel(LOG_LEVEL)
if LOG_FILE:
file_handler = logging.FileHandler(LOG_FILE, "a", "utf8")
file_handler.setFormatter(formatter)
LOGGER.addHandler(handler)
file_handler.setLevel(LOG_LEVEL)
logging.getLogger("uvicorn").addHandler(file_handler)
LIMITER = Limiter(
key_func=get_remote_address,
strategy=RATELIMIT_STRATEGY,
storage_uri=REDIS_URL or None,
in_memory_fallback_enabled=bool(REDIS_URL),
storage_options=REDIS_OPTIONS,
enabled=RATELIMIT_ENABLED,
)
MANAGER = LoginManager(
LOGIN_SECRET_KEY,
"/user",
use_cookie=True,
cookie_name=SESSION_COOKIE_NAME,
use_header=USE_BEARER_HEADER,
custom_exception=NotAuthenticated,
)
UNVERIFIED_MANAGER = copy.deepcopy(MANAGER)
ADMIN_MANAGER = LoginManager(
LOGIN_SECRET_KEY,
"/admin",
use_cookie=True,
cookie_name=f"{SESSION_COOKIE_NAME}_admin",
use_header=USE_BEARER_HEADER,
custom_exception=AdminNotAuthenticated,
)
# Uvicorn config
HOST = os.getenv("HOST") or "localhost"
PORT = int(os.getenv("PORT") or 0) or 8000
WORKERS = int(os.getenv("WORKERS") or 0) or 1
# Storage configuration
STORAGE_ENGINE = "database" # Stores media inside the database. Other options are "local" to use a local/remote folder
# or "url" for uploading to a CDN
STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local
MAX_MEDIA_SIZE = (
5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception
)
ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"]
ZLIB_COMPRESSION_LEVEL = 9

View File

@ -1,5 +1,4 @@
config.py.example# Configuration file. Each variable can be overridden by a
# corresponding environment variable
# Configuration file
import os
import sys
import copy
@ -46,13 +45,13 @@ COOKIE_HTTPONLY = True
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
# SMTP configuration
SMTP_HOST = "smtp.nocturn9x.space"
SMTP_HOST = "smtp.example.com"
SMTP_USER = "info@example.com"
SMTP_PASSWORD = "password"
SMTP_PORT = 587
SMTP_USE_TLS = True
SMTP_FROM_USER = "info@example.com"
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email"
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email"
SMTP_TIMEOUT = 10
# Miscellaneous
@ -116,7 +115,7 @@ if __name__ != "__main__":
use_header=USE_BEARER_HEADER,
custom_exception=NotAuthenticated,
)
UNVERIFIED_MANAGER = copy.deepcopy(MANAGE)
UNVERIFIED_MANAGER = copy.deepcopy(MANAGER)
ADMIN_MANAGER = LoginManager(
LOGIN_SECRET_KEY,
"/admin",
@ -130,7 +129,7 @@ if __name__ != "__main__":
HOST = os.getenv("HOST") or "localhost"
PORT = int(os.getenv("PORT") or 0) or 8000
WORKERS = int(os.getenv("PORT") or 0) or 1
WORKERS = int(os.getenv("WORKERS") or 0) or 1
# Storage configuration
@ -139,4 +138,4 @@ STORAGE_ENGINE = "database" # Stores media inside the database. Other options
STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local
MAX_MEDIA_SIZE = 5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception
ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"]
ZLIB_COMPRESSION_LEVEL = 9
ZLIB_COMPRESSION_LEVEL = 9

129
src/endpoints/auth.py Normal file
View File

@ -0,0 +1,129 @@
import bcrypt
from datetime import timedelta
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter as FastAPI, Depends, Response, Request
from config import (
LOGGER,
SESSION_EXPIRE_LIMIT,
COOKIE_SAMESITE_POLICY,
COOKIE_DOMAIN,
MANAGER,
SESSION_COOKIE_NAME,
LIMITER,
SECURE_COOKIE,
COOKIE_PATH,
COOKIE_HTTPONLY,
UNVERIFIED_MANAGER,
)
from responses import (
Response as APIResponse,
UnprocessableEntity,
BadRequest,
)
from orm.users import (
UserModel,
PrivateUserModel,
get_user_by_username,
)
router = FastAPI()
@router.post(
"/user",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
},
)
@LIMITER.limit("5/minute")
async def login(
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
):
"""
Performs user authentication. Endpoint is limited to 5 hits per minute
"""
if request.cookies.get(SESSION_COOKIE_NAME):
raise HTTPException(status_code=400, detail="Please logout first")
username = data.username
if len(username) > 32:
raise HTTPException(
status_code=413, detail="Authentication failed: username is too long"
)
try:
password = data.password.encode()
if len(password) > 72:
raise HTTPException(
status_code=413, detail="Authentication failed: password is too long"
)
except UnicodeEncodeError as e:
LOGGER.warning(
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
)
raise HTTPException(
status_code=413,
detail="Authentication failed: invalid characters in password",
)
if not (
user := await get_user_by_username(
username, include_secrets=True, restricted_ok=True
)
):
raise HTTPException(
status_code=413,
detail="Authentication failed: the user does not exist",
)
user: PrivateUserModel
if not bcrypt.checkpw(password, user.password_hash):
raise HTTPException(
status_code=413,
detail="Authentication failed: password mismatch",
)
token = MANAGER.create_access_token(
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT),
data={"sub": str(user.public_id)},
)
response.set_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
max_age=SESSION_EXPIRE_LIMIT,
value=token,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return APIResponse(status_code=200, msg="Authentication successful")
@router.get(
"/user/logout",
tags=["Users"],
status_code=200,
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
)
@LIMITER.limit("5/minute")
async def logout(
request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)
):
"""
Deletes a user's session cookie, logging them
out. Endpoint is limited to 5 hits per minute
"""
response.delete_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return APIResponse(status_code=200, msg="Logged out")

210
src/endpoints/email.py Normal file
View File

@ -0,0 +1,210 @@
import json
import uuid
from datetime import timedelta
from datetime import timezone, datetime
from fastapi.exceptions import HTTPException
from fastapi import APIRouter as FastAPI, Depends, Request
from config import (
LIMITER,
SMTP_HOST,
SMTP_PORT,
SMTP_USER,
SMTP_PASSWORD,
SMTP_USE_TLS,
SMTP_TIMEOUT,
SMTP_FROM_USER,
SMTP_TEMPLATES_DIRECTORY,
PLATFORM_NAME,
HAS_HTTPS,
HOST,
PORT,
EMAIL_VERIFICATION_EXPIRATION,
UNVERIFIED_MANAGER,
)
from responses import (
Response as APIResponse,
UnprocessableEntity,
BadRequest,
NotFound,
InternalServerError,
)
from orm.users import (
User,
UserModel,
)
from orm.email_verification import EmailVerification, EmailVerificationType
from util.email import send_email
router = FastAPI()
@router.get(
"/user/verifyEmail/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
404: {"model": NotFound},
422: {"model": UnprocessableEntity},
},
)
@LIMITER.limit("3/second")
async def verify_email(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Verifies a user's email address. Endpoint is
limited to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Verification ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="Email is already verified")
elif datetime.now().astimezone(timezone.utc) - verification[
"creation_date"
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id
)
await User.update({User.email_verified: True}).where(
User.public_id == user.public_id
)
return APIResponse(status_code=200, msg="Verification successful")
@router.get(
"/user/changeEmail/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound},
},
)
@LIMITER.limit("3/second")
async def change_email(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Modifies a user's email address.
Endpoint is limited to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="This link has already been used")
elif datetime.now().astimezone(timezone.utc) - verification[
"creation_date"
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id
and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL
)
await User.update(
{
User.email_address: verification["data"].decode(),
User.email_verified: False,
}
).where(User.public_id == user.public_id)
return APIResponse(status_code=200, msg="Email updated")
@router.put(
"user/resendMail",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
500: {"model": InternalServerError},
},
)
@LIMITER.limit("6/minute")
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
"""
Resends the verification email to the user if the previous has expired.
Endpoint is limited to 6 hits per minute
"""
if user.email_verified:
raise HTTPException(status_code=400, detail="Email is already verified")
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json"
try:
email_template.resolve(strict=True)
except FileNotFoundError:
email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json"
with email_template.open() as f:
email_message = json.load(f)["signup"]
verification_id = uuid.uuid4()
if await send_email(
SMTP_HOST,
SMTP_PORT,
email_message["content"].format(
first_name=user.first_name,
last_name=user.last_name,
username=user.username,
email=user.email_address,
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
platformName=PLATFORM_NAME,
),
SMTP_TIMEOUT,
SMTP_FROM_USER,
user.email_address,
email_message["subject"].format(
first_name=user.first_name,
last_name=user.last_name,
username=user.username,
email=user.email_address,
platformName=PLATFORM_NAME,
),
SMTP_USER,
SMTP_PASSWORD,
use_tls=SMTP_USE_TLS,
):
await EmailVerification.update(
{
EmailVerification.id: verification_id,
EmailVerification.creation_date: datetime.now(),
}
)
return APIResponse(status_code=200, msg="Success")
else:
raise HTTPException(
status_code=500,
detail="An error occurred while trying to resend the email,"
" please try again later",
)

View File

@ -22,13 +22,19 @@ router = FastAPI()
},
)
@LIMITER.limit("2/second")
async def get_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)):
async def get_media(
request: Request, media_id: str, _user: UserModel = Depends(MANAGER)
):
"""
Gets a media object by its ID. Endpoint is
limited to 2 hits per second
"""
if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None:
if (
m := await Media.select(Media.media_id)
.where(Media.media_id == media_id)
.first()
) is None:
raise HTTPException(status_code=404, detail="Media not found")
m = Media(**m)
if m.media_type == MediaType.FILE:
@ -55,7 +61,9 @@ async def get_media(request: Request, media_id: str, _user: UserModel = Depends(
},
)
@LIMITER.limit("2/second")
async def report_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)):
async def report_media(
request: Request, media_id: str, _user: UserModel = Depends(MANAGER)
):
"""
Reports a piece of media by its ID. This creates
a report that can be seen by admins, which can
@ -63,7 +71,11 @@ async def report_media(request: Request, media_id: str, _user: UserModel = Depen
hits per second
"""
if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None:
if (
m := await Media.select(Media.media_id)
.where(Media.media_id == media_id)
.first()
) is None:
raise HTTPException(status_code=404, detail="Media not found")
# TODO: Create report
return APIResponse(msg="Success")

73
src/endpoints/password.py Normal file
View File

@ -0,0 +1,73 @@
from datetime import timedelta
from datetime import timezone, datetime
from fastapi.exceptions import HTTPException
from fastapi import APIRouter as FastAPI, Depends, Request
from config import (
LIMITER,
EMAIL_VERIFICATION_EXPIRATION,
UNVERIFIED_MANAGER,
)
from responses import (
Response as APIResponse,
UnprocessableEntity,
BadRequest,
NotFound,
)
from orm.users import User, UserModel
from orm.email_verification import EmailVerification, EmailVerificationType
router = FastAPI()
@router.get(
"/user/resetPassword/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound},
},
)
@LIMITER.limit("3/second")
async def reset_password(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Modifies a user's password. Endpoint is limited
to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="This link has already been used")
elif datetime.now().astimezone(timezone.utc) - verification[
"creation_date"
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id
and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET
)
await User.update({User.password_hash: verification["data"]}).where(
User.public_id == user.public_id
)
return APIResponse(status_code=200, msg="Password updated")

View File

@ -1,37 +1,23 @@
import base64
from datetime import timezone, datetime
import hashlib
import json
import re
import imghdr
import uuid
import zlib
import bcrypt
from uuid import UUID
import validators
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from datetime import timedelta
from pathlib import Path
from fastapi.exceptions import HTTPException
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
from config import (
BCRYPT_ROUNDS,
LOGGER,
SESSION_EXPIRE_LIMIT,
COOKIE_SAMESITE_POLICY,
COOKIE_DOMAIN,
MANAGER,
SESSION_COOKIE_NAME,
LIMITER,
VALIDATE_PASSWORD_REGEX,
VALIDATE_USERNAME_REGEX,
SECURE_COOKIE,
COOKIE_PATH,
COOKIE_HTTPONLY,
ALLOWED_MEDIA_TYPES,
ZLIB_COMPRESSION_LEVEL,
MAX_MEDIA_SIZE,
STORAGE_ENGINE,
STORAGE_FOLDER,
SMTP_HOST,
@ -46,8 +32,6 @@ from config import (
HAS_HTTPS,
HOST,
PORT,
EMAIL_VERIFICATION_EXPIRATION,
FORCE_EMAIL_VERIFICATION,
UNVERIFIED_MANAGER,
)
from responses import (
@ -70,116 +54,19 @@ from orm.users import (
PrivateUserModel,
get_user_by_username,
get_user_by_id,
get_user_by_email,
)
from orm.media import Media, MediaType, PublicMediaModel
from orm.email_verification import EmailVerification, EmailVerificationType
from orm.email_verification import EmailVerification
from util.email import send_email
from util.users import validate_user, validate_profile_picture
router = FastAPI()
# Credential loaders for our authenticated routes
@MANAGER.user_loader()
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified:
raise HTTPException(status_code=401, detail="Email verification is required")
return user
@UNVERIFIED_MANAGER.user_loader()
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
return await get_self_by_id(public_id, False)
# Here follow our *beautifully* documented path operations
@router.post(
"/user",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
},
)
@LIMITER.limit("5/minute")
async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()):
"""
Performs user authentication. Endpoint is limited to 5 hits per minute
"""
if request.cookies.get(SESSION_COOKIE_NAME):
raise HTTPException(status_code=400, detail="Please logout first")
username = data.username
if len(username) > 32:
raise HTTPException(status_code=413, detail="Authentication failed: username is too long")
try:
password = data.password.encode()
if len(password) > 72:
raise HTTPException(status_code=413, detail="Authentication failed: password is too long")
except UnicodeEncodeError as e:
LOGGER.warning(
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
)
raise HTTPException(
status_code=413,
detail="Authentication failed: invalid characters in password",
)
if not (user := await get_user_by_username(username, include_secrets=True, restricted_ok=True)):
raise HTTPException(
status_code=413,
detail="Authentication failed: the user does not exist",
)
if not bcrypt.checkpw(password, user.password_hash):
raise HTTPException(
status_code=413,
detail="Authentication failed: password mismatch",
)
token = MANAGER.create_access_token(
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT),
data={"sub": str(user.public_id)},
)
response.set_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
max_age=SESSION_EXPIRE_LIMIT,
value=token,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return APIResponse(status_code=200, msg="Authentication successful")
@router.get(
"/user/logout",
tags=["Users"],
status_code=200,
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
)
@LIMITER.limit("5/minute")
async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)):
"""
Deletes a user's session cookie, logging them
out. Endpoint is limited to 5 hits per minute
"""
response.delete_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return APIResponse(status_code=200, msg="Logged out")
@router.get(
"/user/me",
tags=["Users"],
@ -219,7 +106,9 @@ async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGE
},
)
@LIMITER.limit("30/second")
async def get_user_by_name(request: Request, username: str, _auth: UserModel = Depends(MANAGER)):
async def get_user_by_name(
request: Request, username: str, _auth: UserModel = Depends(MANAGER)
):
"""
Fetches a single user by its public username.
Endpoint is limited to 30 hits per second
@ -227,6 +116,7 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D
if not (user := await get_user_by_username(username)):
return NotFound(msg="Lookup failed: the user does not exist")
user: PrivateUserModel
return PublicUserResponse(
data=PublicUserModel(
public_id=user.public_id,
@ -258,14 +148,19 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D
},
)
@LIMITER.limit("30/second")
async def get_user_by_public_id(request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)):
async def get_user_by_public_id(
request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)
):
"""
Fetches a single user by its public ID.
Endpoint is limited to 30 hits per second
"""
if not (user := await get_user_by_id(UUID(public_id))):
raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist")
raise HTTPException(
status_code=404, detail="Lookup failed: the user does not exist"
)
user: PrivateUserModel
return PublicUserResponse(
data=PublicUserModel(
public_id=user.public_id,
@ -286,54 +181,6 @@ async def get_user_by_public_id(request: Request, public_id: str, _auth: UserMod
)
async def validate_user(
first_name: str | None,
last_name: str | None,
username: str | None,
email: str | None,
password: str | None,
bio: str | None,
):
"""
Performs some validation upon user creation. Returns
a tuple (success, msg) to be used by routes. Values
set to None are not checked against
"""
if first_name and len(first_name) > 64:
return False, "first name is too long"
if first_name and len(first_name) < 5:
return False, "first name is too short"
if last_name and len(last_name) > 64:
return False, "last name is too long"
if last_name and len(last_name) < 2:
return False, "last name is too short"
if username and len(username) < 5:
return False, "username is too short"
if username and len(username) > 32:
return False, "username is too long"
if username and VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username):
return False, "username is invalid"
if email and not validators.email(email):
return False, "email is not valid"
if password and len(password) > 72:
return False, "password is too long"
if password and VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password):
return False, "password is too weak"
if username and await get_user_by_username(username, deleted_ok=True, restricted_ok=True):
return False, "username is already taken"
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
return False, "email is already registered"
if bio and len(bio) > 4096:
return False, "bio is too long"
if bio:
try:
bio.encode("utf-8")
except UnicodeDecodeError:
return False, "bio contains invalid characters"
return True, ""
@router.delete(
"/user",
tags=["Users"],
@ -341,7 +188,9 @@ async def validate_user(
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
)
@LIMITER.limit("1/minute")
async def delete(request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)):
async def delete_user(
request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)
):
"""
Sets the user's deleted flag in the database,
without actually deleting the associated
@ -362,219 +211,6 @@ async def delete(request: Request, response: Response, user: UserModel = Depends
return APIResponse(status_code=200, msg="Success")
@router.get(
"/user/verifyEmail/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
404: {"model": NotFound},
422: {"model": UnprocessableEntity},
},
)
@LIMITER.limit("3/second")
async def verify_email(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Verifies a user's email address. Endpoint is
limited to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Verification ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="Email is already verified")
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
seconds=EMAIL_VERIFICATION_EXPIRATION
):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id
)
await User.update({User.email_verified: True}).where(User.public_id == user.public_id)
return APIResponse(status_code=200, msg="Verification successful")
@router.get(
"/user/resetPassword/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound},
},
)
@LIMITER.limit("3/second")
async def reset_password(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Modifies a user's password. Endpoint is limited
to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="This link has already been used")
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
seconds=EMAIL_VERIFICATION_EXPIRATION
):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET
)
await User.update({User.password_hash: verification["data"]}).where(User.public_id == user.public_id)
return APIResponse(status_code=200, msg="Password updated")
@router.get(
"/user/changeEmail/{verification_id}",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
404: {"model": NotFound},
},
)
@LIMITER.limit("3/second")
async def change_email(
request: Request,
verification_id: str,
user: UserModel = Depends(UNVERIFIED_MANAGER),
):
"""
Modifies a user's email address.
Endpoint is limited to 3 hits per second
"""
if not (
verification := await EmailVerification.select(*EmailVerification.all_columns())
.where(EmailVerification.id == verification_id)
.first()
):
raise HTTPException(status_code=404, detail="Request ID is invalid")
elif not verification["pending"]:
raise HTTPException(status_code=400, detail="This link has already been used")
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
seconds=EMAIL_VERIFICATION_EXPIRATION
):
raise HTTPException(
status_code=400,
detail="Verification window has expired. Try again",
)
else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL
)
await User.update(
{
User.email_address: verification["data"].decode(),
User.email_verified: False,
}
).where(User.public_id == user.public_id)
return APIResponse(status_code=200, msg="Email updated")
@router.put(
"user/resendMail",
tags=["Users"],
status_code=200,
responses={
200: {"model": APIResponse},
400: {"model": BadRequest},
422: {"model": UnprocessableEntity},
500: {"model": InternalServerError},
},
)
@LIMITER.limit("6/minute")
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
"""
Resends the verification email to the user if the previous has expired.
Endpoint is limited to 6 hits per minute
"""
if user.email_verified:
raise HTTPException(status_code=400, detail="Email is already verified")
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json"
try:
email_template.resolve(strict=True)
except FileNotFoundError:
email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json"
with email_template.open() as f:
email_message = json.load(f)["signup"]
verification_id = uuid.uuid4()
if await send_email(
SMTP_HOST,
SMTP_PORT,
email_message["content"].format(
first_name=user.first_name,
last_name=user.last_name,
username=user.username,
email=user.email_address,
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
platformName=PLATFORM_NAME,
),
SMTP_TIMEOUT,
SMTP_FROM_USER,
user.email_address,
email_message["subject"].format(
first_name=user.first_name,
last_name=user.last_name,
username=user.username,
email=user.email_address,
platformName=PLATFORM_NAME,
),
SMTP_USER,
SMTP_PASSWORD,
use_tls=SMTP_USE_TLS,
):
await EmailVerification.update(
{
EmailVerification.id: verification_id,
EmailVerification.creation_date: datetime.now(),
}
)
return APIResponse(status_code=200, msg="Success")
else:
raise HTTPException(
status_code=500,
detail="An error occurred while trying to resend the email," " please try again later",
)
@router.put(
"/user",
tags=["Users"],
@ -607,7 +243,9 @@ async def signup(
raise HTTPException(status_code=400, detail="Please logout first")
# We don't use FastAPI's validation because we want custom error
# messages
result, msg = await validate_user(first_name, last_name, username, email, password, bio)
result, msg = await validate_user(
first_name, last_name, username, email, password, bio
)
if not result:
return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
else:
@ -667,40 +305,11 @@ async def signup(
else:
raise HTTPException(
status_code=500,
detail="An error occurred while sending verification email, please" " try again later",
detail="An error occurred while sending verification email, please"
" try again later",
)
async def validate_profile_picture(
file: UploadFile,
) -> tuple[bool | None, str, bytes, str]:
"""
Validates a profile picture's size and content to see if it fits
our criteria and returns a tuple result, ext, data where result is a
boolean or none (True = check was passed, False = size too large,
None = check was failed for other reasons) indicating if the check was successful,
ext is the file's type and extension, data is a compressed stream of bytes
representing the original media and hash is the file's SHA256 hash encoded in
hexadecimal, before the compression. This function never raises an exception
"""
async with file:
try:
content = await file.read()
if len(content) > MAX_MEDIA_SIZE:
return False, "", b"", ""
if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES:
return None, "", b"", ""
return (
True,
ext,
zlib.compress(content, ZLIB_COMPRESSION_LEVEL),
hashlib.sha256(content).hexdigest(),
)
except (UnicodeDecodeError, zlib.error):
return None, "", b"", ""
@router.patch(
"/user",
tags=["Users"],
@ -739,9 +348,15 @@ async def update_user(
since they're the only ones that can be set to a null value. Endpoint is limited to 6 hits per minute
"""
if not delete and not any((first_name, last_name, username, profile_picture, email_address, bio, password)):
raise HTTPException(status_code=400, detail="At least one value has to be specified")
result, msg = await validate_user(first_name, last_name, username, email_address, password, bio)
if not delete and not any(
(first_name, last_name, username, profile_picture, email_address, bio, password)
):
raise HTTPException(
status_code=400, detail="At least one value has to be specified"
)
result, msg = await validate_user(
first_name, last_name, username, email_address, password, bio
)
if not result:
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
orig_user = user.copy()
@ -757,10 +372,16 @@ async def update_user(
if profile_picture:
result, ext, media, digest = validate_profile_picture(profile_picture)
if result is False:
raise HTTPException(status_code=415, detail="The file type is unsupported")
raise HTTPException(
status_code=415, detail="The file type is unsupported"
)
elif result is None:
raise HTTPException(status_code=413, detail="The file is too large")
elif (m := await Media.select(Media.media_id).where(Media.media_id == digest).first()) is None:
elif (
m := await Media.select(Media.media_id)
.where(Media.media_id == digest)
.first()
) is None:
# This media hasn't been already uploaded (either by this user or by someone
# else), so we save it now. If it has been already uploaded, there's no need
# to do it again (that's what the hash is for)
@ -836,7 +457,9 @@ async def update_user(
{
EmailVerification.id: verification_id,
EmailVerification.user: User(public_id=user.public_id),
EmailVerification.data: bcrypt.hashpw(password.encode(), user.password_hash[:29]),
EmailVerification.data: bcrypt.hashpw(
password.encode(), user.password_hash[:29]
),
}
)
)
@ -902,9 +525,13 @@ async def update_user(
user.profile_picture = None
fields = []
for field in user:
if field not in ["email_address", "password"] and getattr(orig_user, field) != getattr(user, field):
if field not in ["email_address", "password"] and getattr(
orig_user, field
) != getattr(user, field):
fields.append((field, getattr(user, field)))
if fields:
# If anything has changed, we update our info
await User.update({field: value for field, value in fields}).where(User.public_id == user.public_id)
await User.update({field: value for field, value in fields}).where(
User.public_id == user.public_id
)
return APIResponse(status_code=200, msg="Changes saved successfully")

View File

@ -1,7 +1,7 @@
import uvloop
import asyncio
import uvicorn
from pathlib import Path
from fastapi import FastAPI, Request
from fastapi.exceptions import (
HTTPException,
@ -9,10 +9,7 @@ from fastapi.exceptions import (
StarletteHTTPException,
)
from slowapi.errors import RateLimitExceeded
from pathlib import Path
from endpoints import users, media
from endpoints import users, media, auth, email, password
from config import (
LOGGER,
LIMITER,
@ -37,9 +34,11 @@ from util.exception_handlers import (
generic_error,
)
from util.email import test_smtp
from config import MANAGER, UNVERIFIED_MANAGER
from util.users import get_self_by_id, get_self_by_id_unverified
with (Path(__file__).parent / "API.md").resolve(strict=True).open() as f:
with (Path(__file__).parent.parent / "API.md").resolve(strict=True).open() as f:
description = f.read()
@ -116,7 +115,9 @@ async def startup_checks():
await create_tables()
await Media.raw("SELECT 1;")
except Exception as e:
LOGGER.error(f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}")
LOGGER.error(
f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}"
)
else:
LOGGER.info("Database initialized")
LOGGER.info("Testing SMTP connection")
@ -131,7 +132,10 @@ async def startup_checks():
True,
)
except Exception as e:
LOGGER.error(f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}")
LOGGER.error(
f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}"
)
raise
else:
LOGGER.info("SMTP test was successful")
@ -141,7 +145,13 @@ if __name__ == "__main__":
LOGGER.debug("Including modules")
app.include_router(users.router)
app.include_router(media.router)
app.include_router(auth.router)
app.include_router(email.router)
app.include_router(password.router)
LOGGER.debug("Configuring authentication and rate-limiting")
app.state.limiter = LIMITER
MANAGER.user_loader(get_self_by_id)
UNVERIFIED_MANAGER.user_loader(get_self_by_id_unverified)
LOGGER.debug("Setting exception handlers")
app.add_exception_handler(RateLimitExceeded, rate_limited)
app.add_exception_handler(NotAuthenticated, not_authenticated)
@ -153,9 +163,21 @@ if __name__ == "__main__":
uvloop.install()
log_config = uvicorn.config.LOGGING_CONFIG
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[
0
].formatter.datefmt
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[
0
].formatter._fmt # noqa
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[
0
].formatter._fmt # noqa
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
asyncio.run(startup_checks())
try:
asyncio.run(startup_checks())
except Exception as e:
LOGGER.error(
f"Startup checks failed due to a {type(e).__name__} exception -> {e}"
)
exit(1)
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)

View File

@ -45,7 +45,9 @@ class Media(Table):
MediaModel = create_pydantic_model(Media)
PublicMediaModel = create_pydantic_model(Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type))
PublicMediaModel = create_pydantic_model(
Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type)
)
async def get_media_by_column(

View File

@ -38,7 +38,9 @@ class Post(Table, tablename="posts"):
PostModel = create_pydantic_model(Post, nested=True)
PrivatePostModelInternal = create_pydantic_model(Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id))
PrivatePostModelInternal = create_pydantic_model(
Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id)
)
PublicPostModelInternal = create_pydantic_model(
Post, nested=True, exclude_columns=(Post.flagged, Post.deleted, Post.internal_id)
)

16
src/piccolo_conf.py Normal file
View File

@ -0,0 +1,16 @@
from piccolo.conf.apps import AppRegistry
from piccolo.engine.postgres import PostgresEngine
DB = PostgresEngine(
config={
"host": "mouse.db.elephantsql.com",
"user": "zubtlqiv",
"password": "l9bgeChDsTHJUwSFex2N2kSbLZ8VkRch",
"database": "zubtlqiv",
}
)
# A list of paths to piccolo apps
# e.g. ['blog.piccolo_app']
APP_REGISTRY = AppRegistry(apps=[])

View File

@ -47,7 +47,9 @@ async def send_email(
await srv.login(login_email, password)
await srv.sendmail(sender, recipient, msg.as_string())
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
logging.error(f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}")
logging.error(
f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}"
)
return error
return True

View File

@ -18,7 +18,10 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon
else:
break
error.detail = error.detail[:n] + " requests" + error.detail[n:]
LOGGER.info(f"{request.client.host} got rate-limited at {str(request.url)} " f"(exceeded {error.detail})")
LOGGER.info(
f"{request.client.host} got rate-limited at {str(request.url)} "
f"(exceeded {error.detail})"
)
return JSONResponse(
status_code=200,
content=dict(
@ -34,7 +37,9 @@ def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
"""
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required"))
return JSONResponse(
status_code=200, content=dict(status_code=401, msg="Authentication is required")
)
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
@ -42,14 +47,18 @@ def request_invalid(request: Request, exc: RequestValidationError) -> JSONRespon
Handles Bad Request exceptions from FastAPI
"""
LOGGER.info(f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}")
LOGGER.info(
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
)
return JSONResponse(
status_code=200,
content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"),
)
def http_exception(request: Request, exc: HTTPException | StarletteHTTPException) -> JSONResponse:
def http_exception(
request: Request, exc: HTTPException | StarletteHTTPException
) -> JSONResponse:
"""
Handles HTTP-specific exceptions raised explicitly by
path operations
@ -57,12 +66,19 @@ def http_exception(request: Request, exc: HTTPException | StarletteHTTPException
if exc.status_code >= 500:
LOGGER.error(
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{type(exc).__name__}: {exc}"
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
f"{type(exc).__name__}: {exc}"
)
return JSONResponse(
status_code=200, content=dict(status_code=500, msg="Internal Server Error")
)
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
else:
LOGGER.info(f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}")
return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail))
LOGGER.info(
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
)
return JSONResponse(
status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail)
)
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
@ -70,6 +86,10 @@ async def generic_error(request: Request, exc: Exception) -> JSONResponse:
Handles generic, unexpected errors in the ASGI application
"""
LOGGER.info(f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}")
LOGGER.info(
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
)
# We can't leak anything about the error, it would be too risky
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
return JSONResponse(
status_code=200, content=dict(status_code=500, msg="Internal Server Error")
)

122
src/util/users.py Normal file
View File

@ -0,0 +1,122 @@
import hashlib
import imghdr
import re
import zlib
import validators
from uuid import UUID
from config import (
VALIDATE_PASSWORD_REGEX,
VALIDATE_USERNAME_REGEX,
FORCE_EMAIL_VERIFICATION,
MAX_MEDIA_SIZE,
ZLIB_COMPRESSION_LEVEL,
ALLOWED_MEDIA_TYPES,
)
from orm.users import (
get_user_by_username,
get_user_by_email,
UserModel,
get_user_by_id,
)
from fastapi import UploadFile
from fastapi.exceptions import HTTPException
async def validate_user(
first_name: str | None,
last_name: str | None,
username: str | None,
email: str | None,
password: str | None,
bio: str | None,
):
"""
Performs some validation upon user creation. Returns
a tuple (success, msg) to be used by routes. Values
set to None are not checked against
"""
if first_name and len(first_name) > 64:
return False, "first name is too long"
if first_name and len(first_name) < 5:
return False, "first name is too short"
if last_name and len(last_name) > 64:
return False, "last name is too long"
if last_name and len(last_name) < 2:
return False, "last name is too short"
if username and len(username) < 5:
return False, "username is too short"
if username and len(username) > 32:
return False, "username is too long"
if (
username
and VALIDATE_USERNAME_REGEX
and not re.match(VALIDATE_USERNAME_REGEX, username)
):
return False, "username is invalid"
if email and not validators.email(email):
return False, "email is not valid"
if password and len(password) > 72:
return False, "password is too long"
if (
password
and VALIDATE_PASSWORD_REGEX
and not re.match(VALIDATE_PASSWORD_REGEX, password)
):
return False, "password is too weak"
if username and await get_user_by_username(
username, deleted_ok=True, restricted_ok=True
):
return False, "username is already taken"
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
return False, "email is already registered"
if bio and len(bio) > 4096:
return False, "bio is too long"
if bio:
try:
bio.encode("utf-8")
except UnicodeDecodeError:
return False, "bio contains invalid characters"
return True, ""
async def validate_profile_picture(
file: UploadFile,
) -> tuple[bool | None, str, bytes, str]:
"""
Validates a profile picture's size and content to see if it fits
our criteria and returns a tuple result, ext, data where result is a
boolean or none (True = check was passed, False = size too large,
None = check was failed for other reasons) indicating if the check was successful,
ext is the file's type and extension, data is a compressed stream of bytes
representing the original media and hash is the file's SHA256 hash encoded in
hexadecimal, before the compression. This function never raises an exception
"""
async with file:
try:
content = await file.read()
if len(content) > MAX_MEDIA_SIZE:
return False, "", b"", ""
if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES:
return None, "", b"", ""
return (
True,
ext,
zlib.compress(content, ZLIB_COMPRESSION_LEVEL),
hashlib.sha256(content).hexdigest(),
)
except (UnicodeDecodeError, zlib.error):
return None, "", b"", ""
# Credential loaders for our authenticated routes
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified:
raise HTTPException(status_code=401, detail="Email verification is required")
return user
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
return await get_self_by_id(public_id, False)