Refactoring
This commit is contained in:
parent
6d0d822083
commit
1d1f2368fe
|
@ -0,0 +1,143 @@
|
||||||
|
# Configuration file
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from fastapi_login import LoginManager
|
||||||
|
|
||||||
|
|
||||||
|
# Authentication configuration
|
||||||
|
LOGIN_SECRET_KEY = (
|
||||||
|
os.getenv("LOGIN_SECRET_KEY") or "login-secret"
|
||||||
|
) # Recommended value: os.urandom(24).hex()
|
||||||
|
USE_BEARER_HEADER = True
|
||||||
|
USE_COOKIE = True
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20
|
||||||
|
LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable
|
||||||
|
LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s"
|
||||||
|
LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p"
|
||||||
|
|
||||||
|
# Bcrypt configuration
|
||||||
|
BCRYPT_ROUNDS = (
|
||||||
|
os.getenv("BCRYPT_ROUNDS") or 10
|
||||||
|
) # How many rounds are used when salting
|
||||||
|
|
||||||
|
|
||||||
|
# Rate limit configuration
|
||||||
|
REDIS_URL = "" # Used for rate limits. Empty to fall back to memory
|
||||||
|
REDIS_OPTIONS = {} # Options for redis
|
||||||
|
RATELIMIT_ENABLED = False # False to disable rate limiting
|
||||||
|
RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html
|
||||||
|
|
||||||
|
# Session configuration
|
||||||
|
SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds
|
||||||
|
SESSION_COOKIE_NAME = "_social_media_session"
|
||||||
|
COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict"
|
||||||
|
COOKIE_DOMAIN = "localhost" # Empty to disable this
|
||||||
|
COOKIE_PATH = "/"
|
||||||
|
COOKIE_HTTPONLY = True
|
||||||
|
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
|
||||||
|
|
||||||
|
# SMTP configuration
|
||||||
|
SMTP_HOST = "mail.nocturn9x.space"
|
||||||
|
SMTP_USER = "nocturn9x@nocturn9x.space"
|
||||||
|
SMTP_PASSWORD = "Je2aXGVH33"
|
||||||
|
SMTP_PORT = 587
|
||||||
|
SMTP_USE_TLS = True
|
||||||
|
SMTP_FROM_USER = "nocturn9x@nocturn9x.space"
|
||||||
|
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email"
|
||||||
|
SMTP_TIMEOUT = 10
|
||||||
|
|
||||||
|
# Miscellaneous
|
||||||
|
|
||||||
|
# Usernames containing these characters are not valid
|
||||||
|
INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character
|
||||||
|
# Empty this to disable username validation
|
||||||
|
VALIDATE_USERNAME_REGEX = (
|
||||||
|
rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$"
|
||||||
|
)
|
||||||
|
# Criteria:
|
||||||
|
# - Between 10 and 72 characters long
|
||||||
|
# - At least 2 uppercase letters
|
||||||
|
# - At least 3 lowercase letters
|
||||||
|
# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^
|
||||||
|
# - At least 2 numbers
|
||||||
|
# You can change the repetitions to enforce stricter/laxer rules or empty
|
||||||
|
# this field to disable weakness validation
|
||||||
|
VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[\"!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$"
|
||||||
|
FORCE_EMAIL_VERIFICATION = False
|
||||||
|
EMAIL_VERIFICATION_EXPIRATION = 3600 # In seconds
|
||||||
|
PLATFORM_NAME = "PySimpleSocial" # Used in emails
|
||||||
|
HAS_HTTPS = False
|
||||||
|
|
||||||
|
|
||||||
|
class NotAuthenticated(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AdminNotAuthenticated(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ != "__main__":
|
||||||
|
LOGGER: logging.Logger = logging.getLogger("socialMedia")
|
||||||
|
LOGGER.setLevel(LOG_LEVEL)
|
||||||
|
handler = logging.StreamHandler(sys.stderr)
|
||||||
|
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
LOGGER.addHandler(handler)
|
||||||
|
handler.setLevel(LOG_LEVEL)
|
||||||
|
if LOG_FILE:
|
||||||
|
file_handler = logging.FileHandler(LOG_FILE, "a", "utf8")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
LOGGER.addHandler(handler)
|
||||||
|
file_handler.setLevel(LOG_LEVEL)
|
||||||
|
logging.getLogger("uvicorn").addHandler(file_handler)
|
||||||
|
LIMITER = Limiter(
|
||||||
|
key_func=get_remote_address,
|
||||||
|
strategy=RATELIMIT_STRATEGY,
|
||||||
|
storage_uri=REDIS_URL or None,
|
||||||
|
in_memory_fallback_enabled=bool(REDIS_URL),
|
||||||
|
storage_options=REDIS_OPTIONS,
|
||||||
|
enabled=RATELIMIT_ENABLED,
|
||||||
|
)
|
||||||
|
MANAGER = LoginManager(
|
||||||
|
LOGIN_SECRET_KEY,
|
||||||
|
"/user",
|
||||||
|
use_cookie=True,
|
||||||
|
cookie_name=SESSION_COOKIE_NAME,
|
||||||
|
use_header=USE_BEARER_HEADER,
|
||||||
|
custom_exception=NotAuthenticated,
|
||||||
|
)
|
||||||
|
UNVERIFIED_MANAGER = copy.deepcopy(MANAGER)
|
||||||
|
ADMIN_MANAGER = LoginManager(
|
||||||
|
LOGIN_SECRET_KEY,
|
||||||
|
"/admin",
|
||||||
|
use_cookie=True,
|
||||||
|
cookie_name=f"{SESSION_COOKIE_NAME}_admin",
|
||||||
|
use_header=USE_BEARER_HEADER,
|
||||||
|
custom_exception=AdminNotAuthenticated,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uvicorn config
|
||||||
|
|
||||||
|
HOST = os.getenv("HOST") or "localhost"
|
||||||
|
PORT = int(os.getenv("PORT") or 0) or 8000
|
||||||
|
WORKERS = int(os.getenv("WORKERS") or 0) or 1
|
||||||
|
|
||||||
|
# Storage configuration
|
||||||
|
|
||||||
|
STORAGE_ENGINE = "database" # Stores media inside the database. Other options are "local" to use a local/remote folder
|
||||||
|
# or "url" for uploading to a CDN
|
||||||
|
STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local
|
||||||
|
MAX_MEDIA_SIZE = (
|
||||||
|
5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception
|
||||||
|
)
|
||||||
|
ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"]
|
||||||
|
ZLIB_COMPRESSION_LEVEL = 9
|
|
@ -1,5 +1,4 @@
|
||||||
config.py.example# Configuration file. Each variable can be overridden by a
|
# Configuration file
|
||||||
# corresponding environment variable
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
|
@ -46,13 +45,13 @@ COOKIE_HTTPONLY = True
|
||||||
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
|
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
|
||||||
|
|
||||||
# SMTP configuration
|
# SMTP configuration
|
||||||
SMTP_HOST = "smtp.nocturn9x.space"
|
SMTP_HOST = "smtp.example.com"
|
||||||
SMTP_USER = "info@example.com"
|
SMTP_USER = "info@example.com"
|
||||||
SMTP_PASSWORD = "password"
|
SMTP_PASSWORD = "password"
|
||||||
SMTP_PORT = 587
|
SMTP_PORT = 587
|
||||||
SMTP_USE_TLS = True
|
SMTP_USE_TLS = True
|
||||||
SMTP_FROM_USER = "info@example.com"
|
SMTP_FROM_USER = "info@example.com"
|
||||||
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email"
|
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__).parent.parent / "templates" / "email"
|
||||||
SMTP_TIMEOUT = 10
|
SMTP_TIMEOUT = 10
|
||||||
|
|
||||||
# Miscellaneous
|
# Miscellaneous
|
||||||
|
@ -116,7 +115,7 @@ if __name__ != "__main__":
|
||||||
use_header=USE_BEARER_HEADER,
|
use_header=USE_BEARER_HEADER,
|
||||||
custom_exception=NotAuthenticated,
|
custom_exception=NotAuthenticated,
|
||||||
)
|
)
|
||||||
UNVERIFIED_MANAGER = copy.deepcopy(MANAGE)
|
UNVERIFIED_MANAGER = copy.deepcopy(MANAGER)
|
||||||
ADMIN_MANAGER = LoginManager(
|
ADMIN_MANAGER = LoginManager(
|
||||||
LOGIN_SECRET_KEY,
|
LOGIN_SECRET_KEY,
|
||||||
"/admin",
|
"/admin",
|
||||||
|
@ -130,7 +129,7 @@ if __name__ != "__main__":
|
||||||
|
|
||||||
HOST = os.getenv("HOST") or "localhost"
|
HOST = os.getenv("HOST") or "localhost"
|
||||||
PORT = int(os.getenv("PORT") or 0) or 8000
|
PORT = int(os.getenv("PORT") or 0) or 8000
|
||||||
WORKERS = int(os.getenv("PORT") or 0) or 1
|
WORKERS = int(os.getenv("WORKERS") or 0) or 1
|
||||||
|
|
||||||
# Storage configuration
|
# Storage configuration
|
||||||
|
|
||||||
|
@ -139,4 +138,4 @@ STORAGE_ENGINE = "database" # Stores media inside the database. Other options
|
||||||
STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local
|
STORAGE_FOLDER = "" # Only needed when STORAGE_ENGINE is set to local
|
||||||
MAX_MEDIA_SIZE = 5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception
|
MAX_MEDIA_SIZE = 5242880 # Max length of media allowed. Anything bigger raises a 415 HTTP Exception
|
||||||
ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"]
|
ALLOWED_MEDIA_TYPES = ["gif", "png", "jpeg", "tiff"]
|
||||||
ZLIB_COMPRESSION_LEVEL = 9
|
ZLIB_COMPRESSION_LEVEL = 9
|
|
@ -0,0 +1,129 @@
|
||||||
|
import bcrypt
|
||||||
|
from datetime import timedelta
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from fastapi import APIRouter as FastAPI, Depends, Response, Request
|
||||||
|
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
LOGGER,
|
||||||
|
SESSION_EXPIRE_LIMIT,
|
||||||
|
COOKIE_SAMESITE_POLICY,
|
||||||
|
COOKIE_DOMAIN,
|
||||||
|
MANAGER,
|
||||||
|
SESSION_COOKIE_NAME,
|
||||||
|
LIMITER,
|
||||||
|
SECURE_COOKIE,
|
||||||
|
COOKIE_PATH,
|
||||||
|
COOKIE_HTTPONLY,
|
||||||
|
UNVERIFIED_MANAGER,
|
||||||
|
)
|
||||||
|
from responses import (
|
||||||
|
Response as APIResponse,
|
||||||
|
UnprocessableEntity,
|
||||||
|
BadRequest,
|
||||||
|
)
|
||||||
|
from orm.users import (
|
||||||
|
UserModel,
|
||||||
|
PrivateUserModel,
|
||||||
|
get_user_by_username,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/user",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={
|
||||||
|
200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("5/minute")
|
||||||
|
async def login(
|
||||||
|
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Performs user authentication. Endpoint is limited to 5 hits per minute
|
||||||
|
"""
|
||||||
|
|
||||||
|
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||||
|
raise HTTPException(status_code=400, detail="Please logout first")
|
||||||
|
username = data.username
|
||||||
|
if len(username) > 32:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413, detail="Authentication failed: username is too long"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
password = data.password.encode()
|
||||||
|
if len(password) > 72:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413, detail="Authentication failed: password is too long"
|
||||||
|
)
|
||||||
|
except UnicodeEncodeError as e:
|
||||||
|
LOGGER.warning(
|
||||||
|
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: invalid characters in password",
|
||||||
|
)
|
||||||
|
if not (
|
||||||
|
user := await get_user_by_username(
|
||||||
|
username, include_secrets=True, restricted_ok=True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: the user does not exist",
|
||||||
|
)
|
||||||
|
user: PrivateUserModel
|
||||||
|
if not bcrypt.checkpw(password, user.password_hash):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: password mismatch",
|
||||||
|
)
|
||||||
|
token = MANAGER.create_access_token(
|
||||||
|
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT),
|
||||||
|
data={"sub": str(user.public_id)},
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
secure=SECURE_COOKIE,
|
||||||
|
key=SESSION_COOKIE_NAME,
|
||||||
|
max_age=SESSION_EXPIRE_LIMIT,
|
||||||
|
value=token,
|
||||||
|
httponly=COOKIE_HTTPONLY,
|
||||||
|
samesite=COOKIE_SAMESITE_POLICY,
|
||||||
|
domain=COOKIE_DOMAIN or None,
|
||||||
|
path=COOKIE_PATH or "/",
|
||||||
|
)
|
||||||
|
return APIResponse(status_code=200, msg="Authentication successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/user/logout",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("5/minute")
|
||||||
|
async def logout(
|
||||||
|
request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Deletes a user's session cookie, logging them
|
||||||
|
out. Endpoint is limited to 5 hits per minute
|
||||||
|
"""
|
||||||
|
|
||||||
|
response.delete_cookie(
|
||||||
|
secure=SECURE_COOKIE,
|
||||||
|
key=SESSION_COOKIE_NAME,
|
||||||
|
httponly=COOKIE_HTTPONLY,
|
||||||
|
samesite=COOKIE_SAMESITE_POLICY,
|
||||||
|
domain=COOKIE_DOMAIN or None,
|
||||||
|
path=COOKIE_PATH or "/",
|
||||||
|
)
|
||||||
|
return APIResponse(status_code=200, msg="Logged out")
|
|
@ -0,0 +1,210 @@
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import timedelta
|
||||||
|
from datetime import timezone, datetime
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi import APIRouter as FastAPI, Depends, Request
|
||||||
|
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
LIMITER,
|
||||||
|
SMTP_HOST,
|
||||||
|
SMTP_PORT,
|
||||||
|
SMTP_USER,
|
||||||
|
SMTP_PASSWORD,
|
||||||
|
SMTP_USE_TLS,
|
||||||
|
SMTP_TIMEOUT,
|
||||||
|
SMTP_FROM_USER,
|
||||||
|
SMTP_TEMPLATES_DIRECTORY,
|
||||||
|
PLATFORM_NAME,
|
||||||
|
HAS_HTTPS,
|
||||||
|
HOST,
|
||||||
|
PORT,
|
||||||
|
EMAIL_VERIFICATION_EXPIRATION,
|
||||||
|
UNVERIFIED_MANAGER,
|
||||||
|
)
|
||||||
|
from responses import (
|
||||||
|
Response as APIResponse,
|
||||||
|
UnprocessableEntity,
|
||||||
|
BadRequest,
|
||||||
|
NotFound,
|
||||||
|
InternalServerError,
|
||||||
|
)
|
||||||
|
from orm.users import (
|
||||||
|
User,
|
||||||
|
UserModel,
|
||||||
|
)
|
||||||
|
from orm.email_verification import EmailVerification, EmailVerificationType
|
||||||
|
from util.email import send_email
|
||||||
|
|
||||||
|
|
||||||
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/user/verifyEmail/{verification_id}",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={
|
||||||
|
200: {"model": APIResponse},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("3/second")
|
||||||
|
async def verify_email(
|
||||||
|
request: Request,
|
||||||
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verifies a user's email address. Endpoint is
|
||||||
|
limited to 3 hits per second
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (
|
||||||
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
|
.where(EmailVerification.id == verification_id)
|
||||||
|
.first()
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=404, detail="Verification ID is invalid")
|
||||||
|
elif not verification["pending"]:
|
||||||
|
raise HTTPException(status_code=400, detail="Email is already verified")
|
||||||
|
elif datetime.now().astimezone(timezone.utc) - verification[
|
||||||
|
"creation_date"
|
||||||
|
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Verification window has expired. Try again",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
|
EmailVerification.user == user.public_id
|
||||||
|
)
|
||||||
|
await User.update({User.email_verified: True}).where(
|
||||||
|
User.public_id == user.public_id
|
||||||
|
)
|
||||||
|
return APIResponse(status_code=200, msg="Verification successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/user/changeEmail/{verification_id}",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={
|
||||||
|
200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("3/second")
|
||||||
|
async def change_email(
|
||||||
|
request: Request,
|
||||||
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Modifies a user's email address.
|
||||||
|
Endpoint is limited to 3 hits per second
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (
|
||||||
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
|
.where(EmailVerification.id == verification_id)
|
||||||
|
.first()
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||||
|
elif not verification["pending"]:
|
||||||
|
raise HTTPException(status_code=400, detail="This link has already been used")
|
||||||
|
elif datetime.now().astimezone(timezone.utc) - verification[
|
||||||
|
"creation_date"
|
||||||
|
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Verification window has expired. Try again",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Note how we don't update based on the verification ID:
|
||||||
|
# this way, multiple pending email verification requests
|
||||||
|
# are all cleared at once
|
||||||
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
|
EmailVerification.user == user.public_id
|
||||||
|
and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL
|
||||||
|
)
|
||||||
|
await User.update(
|
||||||
|
{
|
||||||
|
User.email_address: verification["data"].decode(),
|
||||||
|
User.email_verified: False,
|
||||||
|
}
|
||||||
|
).where(User.public_id == user.public_id)
|
||||||
|
return APIResponse(status_code=200, msg="Email updated")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"user/resendMail",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={
|
||||||
|
200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
500: {"model": InternalServerError},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("6/minute")
|
||||||
|
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
||||||
|
"""
|
||||||
|
Resends the verification email to the user if the previous has expired.
|
||||||
|
Endpoint is limited to 6 hits per minute
|
||||||
|
"""
|
||||||
|
|
||||||
|
if user.email_verified:
|
||||||
|
raise HTTPException(status_code=400, detail="Email is already verified")
|
||||||
|
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json"
|
||||||
|
try:
|
||||||
|
email_template.resolve(strict=True)
|
||||||
|
except FileNotFoundError:
|
||||||
|
email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json"
|
||||||
|
with email_template.open() as f:
|
||||||
|
email_message = json.load(f)["signup"]
|
||||||
|
verification_id = uuid.uuid4()
|
||||||
|
if await send_email(
|
||||||
|
SMTP_HOST,
|
||||||
|
SMTP_PORT,
|
||||||
|
email_message["content"].format(
|
||||||
|
first_name=user.first_name,
|
||||||
|
last_name=user.last_name,
|
||||||
|
username=user.username,
|
||||||
|
email=user.email_address,
|
||||||
|
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
||||||
|
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
||||||
|
platformName=PLATFORM_NAME,
|
||||||
|
),
|
||||||
|
SMTP_TIMEOUT,
|
||||||
|
SMTP_FROM_USER,
|
||||||
|
user.email_address,
|
||||||
|
email_message["subject"].format(
|
||||||
|
first_name=user.first_name,
|
||||||
|
last_name=user.last_name,
|
||||||
|
username=user.username,
|
||||||
|
email=user.email_address,
|
||||||
|
platformName=PLATFORM_NAME,
|
||||||
|
),
|
||||||
|
SMTP_USER,
|
||||||
|
SMTP_PASSWORD,
|
||||||
|
use_tls=SMTP_USE_TLS,
|
||||||
|
):
|
||||||
|
await EmailVerification.update(
|
||||||
|
{
|
||||||
|
EmailVerification.id: verification_id,
|
||||||
|
EmailVerification.creation_date: datetime.now(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return APIResponse(status_code=200, msg="Success")
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="An error occurred while trying to resend the email,"
|
||||||
|
" please try again later",
|
||||||
|
)
|
|
@ -22,13 +22,19 @@ router = FastAPI()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@LIMITER.limit("2/second")
|
@LIMITER.limit("2/second")
|
||||||
async def get_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)):
|
async def get_media(
|
||||||
|
request: Request, media_id: str, _user: UserModel = Depends(MANAGER)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Gets a media object by its ID. Endpoint is
|
Gets a media object by its ID. Endpoint is
|
||||||
limited to 2 hits per second
|
limited to 2 hits per second
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None:
|
if (
|
||||||
|
m := await Media.select(Media.media_id)
|
||||||
|
.where(Media.media_id == media_id)
|
||||||
|
.first()
|
||||||
|
) is None:
|
||||||
raise HTTPException(status_code=404, detail="Media not found")
|
raise HTTPException(status_code=404, detail="Media not found")
|
||||||
m = Media(**m)
|
m = Media(**m)
|
||||||
if m.media_type == MediaType.FILE:
|
if m.media_type == MediaType.FILE:
|
||||||
|
@ -55,7 +61,9 @@ async def get_media(request: Request, media_id: str, _user: UserModel = Depends(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@LIMITER.limit("2/second")
|
@LIMITER.limit("2/second")
|
||||||
async def report_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)):
|
async def report_media(
|
||||||
|
request: Request, media_id: str, _user: UserModel = Depends(MANAGER)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Reports a piece of media by its ID. This creates
|
Reports a piece of media by its ID. This creates
|
||||||
a report that can be seen by admins, which can
|
a report that can be seen by admins, which can
|
||||||
|
@ -63,7 +71,11 @@ async def report_media(request: Request, media_id: str, _user: UserModel = Depen
|
||||||
hits per second
|
hits per second
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None:
|
if (
|
||||||
|
m := await Media.select(Media.media_id)
|
||||||
|
.where(Media.media_id == media_id)
|
||||||
|
.first()
|
||||||
|
) is None:
|
||||||
raise HTTPException(status_code=404, detail="Media not found")
|
raise HTTPException(status_code=404, detail="Media not found")
|
||||||
# TODO: Create report
|
# TODO: Create report
|
||||||
return APIResponse(msg="Success")
|
return APIResponse(msg="Success")
|
|
@ -0,0 +1,73 @@
|
||||||
|
from datetime import timedelta
|
||||||
|
from datetime import timezone, datetime
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi import APIRouter as FastAPI, Depends, Request
|
||||||
|
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
LIMITER,
|
||||||
|
EMAIL_VERIFICATION_EXPIRATION,
|
||||||
|
UNVERIFIED_MANAGER,
|
||||||
|
)
|
||||||
|
from responses import (
|
||||||
|
Response as APIResponse,
|
||||||
|
UnprocessableEntity,
|
||||||
|
BadRequest,
|
||||||
|
NotFound,
|
||||||
|
)
|
||||||
|
from orm.users import User, UserModel
|
||||||
|
from orm.email_verification import EmailVerification, EmailVerificationType
|
||||||
|
|
||||||
|
|
||||||
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/user/resetPassword/{verification_id}",
|
||||||
|
tags=["Users"],
|
||||||
|
status_code=200,
|
||||||
|
responses={
|
||||||
|
200: {"model": APIResponse},
|
||||||
|
400: {"model": BadRequest},
|
||||||
|
422: {"model": UnprocessableEntity},
|
||||||
|
404: {"model": NotFound},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@LIMITER.limit("3/second")
|
||||||
|
async def reset_password(
|
||||||
|
request: Request,
|
||||||
|
verification_id: str,
|
||||||
|
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Modifies a user's password. Endpoint is limited
|
||||||
|
to 3 hits per second
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (
|
||||||
|
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
||||||
|
.where(EmailVerification.id == verification_id)
|
||||||
|
.first()
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
||||||
|
elif not verification["pending"]:
|
||||||
|
raise HTTPException(status_code=400, detail="This link has already been used")
|
||||||
|
elif datetime.now().astimezone(timezone.utc) - verification[
|
||||||
|
"creation_date"
|
||||||
|
].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Verification window has expired. Try again",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Note how we don't update based on the verification ID:
|
||||||
|
# this way, multiple pending email verification requests
|
||||||
|
# are all cleared at once
|
||||||
|
await EmailVerification.update({EmailVerification.pending: False}).where(
|
||||||
|
EmailVerification.user == user.public_id
|
||||||
|
and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET
|
||||||
|
)
|
||||||
|
await User.update({User.password_hash: verification["data"]}).where(
|
||||||
|
User.public_id == user.public_id
|
||||||
|
)
|
||||||
|
return APIResponse(status_code=200, msg="Password updated")
|
|
@ -1,37 +1,23 @@
|
||||||
import base64
|
import base64
|
||||||
from datetime import timezone, datetime
|
|
||||||
import hashlib
|
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import imghdr
|
|
||||||
import uuid
|
import uuid
|
||||||
import zlib
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import validators
|
|
||||||
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
|
|
||||||
from fastapi.exceptions import HTTPException
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
|
||||||
from datetime import timedelta
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile
|
||||||
|
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
BCRYPT_ROUNDS,
|
BCRYPT_ROUNDS,
|
||||||
LOGGER,
|
|
||||||
SESSION_EXPIRE_LIMIT,
|
|
||||||
COOKIE_SAMESITE_POLICY,
|
COOKIE_SAMESITE_POLICY,
|
||||||
COOKIE_DOMAIN,
|
COOKIE_DOMAIN,
|
||||||
MANAGER,
|
MANAGER,
|
||||||
SESSION_COOKIE_NAME,
|
SESSION_COOKIE_NAME,
|
||||||
LIMITER,
|
LIMITER,
|
||||||
VALIDATE_PASSWORD_REGEX,
|
|
||||||
VALIDATE_USERNAME_REGEX,
|
|
||||||
SECURE_COOKIE,
|
SECURE_COOKIE,
|
||||||
COOKIE_PATH,
|
COOKIE_PATH,
|
||||||
COOKIE_HTTPONLY,
|
COOKIE_HTTPONLY,
|
||||||
ALLOWED_MEDIA_TYPES,
|
|
||||||
ZLIB_COMPRESSION_LEVEL,
|
|
||||||
MAX_MEDIA_SIZE,
|
|
||||||
STORAGE_ENGINE,
|
STORAGE_ENGINE,
|
||||||
STORAGE_FOLDER,
|
STORAGE_FOLDER,
|
||||||
SMTP_HOST,
|
SMTP_HOST,
|
||||||
|
@ -46,8 +32,6 @@ from config import (
|
||||||
HAS_HTTPS,
|
HAS_HTTPS,
|
||||||
HOST,
|
HOST,
|
||||||
PORT,
|
PORT,
|
||||||
EMAIL_VERIFICATION_EXPIRATION,
|
|
||||||
FORCE_EMAIL_VERIFICATION,
|
|
||||||
UNVERIFIED_MANAGER,
|
UNVERIFIED_MANAGER,
|
||||||
)
|
)
|
||||||
from responses import (
|
from responses import (
|
||||||
|
@ -70,116 +54,19 @@ from orm.users import (
|
||||||
PrivateUserModel,
|
PrivateUserModel,
|
||||||
get_user_by_username,
|
get_user_by_username,
|
||||||
get_user_by_id,
|
get_user_by_id,
|
||||||
get_user_by_email,
|
|
||||||
)
|
)
|
||||||
from orm.media import Media, MediaType, PublicMediaModel
|
from orm.media import Media, MediaType, PublicMediaModel
|
||||||
from orm.email_verification import EmailVerification, EmailVerificationType
|
from orm.email_verification import EmailVerification
|
||||||
from util.email import send_email
|
from util.email import send_email
|
||||||
|
from util.users import validate_user, validate_profile_picture
|
||||||
|
|
||||||
|
|
||||||
router = FastAPI()
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
# Credential loaders for our authenticated routes
|
|
||||||
@MANAGER.user_loader()
|
|
||||||
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
|
|
||||||
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
|
||||||
if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified:
|
|
||||||
raise HTTPException(status_code=401, detail="Email verification is required")
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@UNVERIFIED_MANAGER.user_loader()
|
|
||||||
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
|
|
||||||
return await get_self_by_id(public_id, False)
|
|
||||||
|
|
||||||
|
|
||||||
# Here follow our *beautifully* documented path operations
|
# Here follow our *beautifully* documented path operations
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/user",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={
|
|
||||||
200: {"model": APIResponse},
|
|
||||||
400: {"model": BadRequest},
|
|
||||||
422: {"model": UnprocessableEntity},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("5/minute")
|
|
||||||
async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()):
|
|
||||||
"""
|
|
||||||
Performs user authentication. Endpoint is limited to 5 hits per minute
|
|
||||||
"""
|
|
||||||
|
|
||||||
if request.cookies.get(SESSION_COOKIE_NAME):
|
|
||||||
raise HTTPException(status_code=400, detail="Please logout first")
|
|
||||||
username = data.username
|
|
||||||
if len(username) > 32:
|
|
||||||
raise HTTPException(status_code=413, detail="Authentication failed: username is too long")
|
|
||||||
try:
|
|
||||||
password = data.password.encode()
|
|
||||||
if len(password) > 72:
|
|
||||||
raise HTTPException(status_code=413, detail="Authentication failed: password is too long")
|
|
||||||
except UnicodeEncodeError as e:
|
|
||||||
LOGGER.warning(
|
|
||||||
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail="Authentication failed: invalid characters in password",
|
|
||||||
)
|
|
||||||
if not (user := await get_user_by_username(username, include_secrets=True, restricted_ok=True)):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail="Authentication failed: the user does not exist",
|
|
||||||
)
|
|
||||||
if not bcrypt.checkpw(password, user.password_hash):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=413,
|
|
||||||
detail="Authentication failed: password mismatch",
|
|
||||||
)
|
|
||||||
token = MANAGER.create_access_token(
|
|
||||||
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT),
|
|
||||||
data={"sub": str(user.public_id)},
|
|
||||||
)
|
|
||||||
response.set_cookie(
|
|
||||||
secure=SECURE_COOKIE,
|
|
||||||
key=SESSION_COOKIE_NAME,
|
|
||||||
max_age=SESSION_EXPIRE_LIMIT,
|
|
||||||
value=token,
|
|
||||||
httponly=COOKIE_HTTPONLY,
|
|
||||||
samesite=COOKIE_SAMESITE_POLICY,
|
|
||||||
domain=COOKIE_DOMAIN or None,
|
|
||||||
path=COOKIE_PATH or "/",
|
|
||||||
)
|
|
||||||
return APIResponse(status_code=200, msg="Authentication successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/user/logout",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("5/minute")
|
|
||||||
async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
|
||||||
"""
|
|
||||||
Deletes a user's session cookie, logging them
|
|
||||||
out. Endpoint is limited to 5 hits per minute
|
|
||||||
"""
|
|
||||||
|
|
||||||
response.delete_cookie(
|
|
||||||
secure=SECURE_COOKIE,
|
|
||||||
key=SESSION_COOKIE_NAME,
|
|
||||||
httponly=COOKIE_HTTPONLY,
|
|
||||||
samesite=COOKIE_SAMESITE_POLICY,
|
|
||||||
domain=COOKIE_DOMAIN or None,
|
|
||||||
path=COOKIE_PATH or "/",
|
|
||||||
)
|
|
||||||
return APIResponse(status_code=200, msg="Logged out")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/user/me",
|
"/user/me",
|
||||||
tags=["Users"],
|
tags=["Users"],
|
||||||
|
@ -219,7 +106,9 @@ async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGE
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@LIMITER.limit("30/second")
|
@LIMITER.limit("30/second")
|
||||||
async def get_user_by_name(request: Request, username: str, _auth: UserModel = Depends(MANAGER)):
|
async def get_user_by_name(
|
||||||
|
request: Request, username: str, _auth: UserModel = Depends(MANAGER)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Fetches a single user by its public username.
|
Fetches a single user by its public username.
|
||||||
Endpoint is limited to 30 hits per second
|
Endpoint is limited to 30 hits per second
|
||||||
|
@ -227,6 +116,7 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D
|
||||||
|
|
||||||
if not (user := await get_user_by_username(username)):
|
if not (user := await get_user_by_username(username)):
|
||||||
return NotFound(msg="Lookup failed: the user does not exist")
|
return NotFound(msg="Lookup failed: the user does not exist")
|
||||||
|
user: PrivateUserModel
|
||||||
return PublicUserResponse(
|
return PublicUserResponse(
|
||||||
data=PublicUserModel(
|
data=PublicUserModel(
|
||||||
public_id=user.public_id,
|
public_id=user.public_id,
|
||||||
|
@ -258,14 +148,19 @@ async def get_user_by_name(request: Request, username: str, _auth: UserModel = D
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@LIMITER.limit("30/second")
|
@LIMITER.limit("30/second")
|
||||||
async def get_user_by_public_id(request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)):
|
async def get_user_by_public_id(
|
||||||
|
request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Fetches a single user by its public ID.
|
Fetches a single user by its public ID.
|
||||||
Endpoint is limited to 30 hits per second
|
Endpoint is limited to 30 hits per second
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (user := await get_user_by_id(UUID(public_id))):
|
if not (user := await get_user_by_id(UUID(public_id))):
|
||||||
raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist")
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Lookup failed: the user does not exist"
|
||||||
|
)
|
||||||
|
user: PrivateUserModel
|
||||||
return PublicUserResponse(
|
return PublicUserResponse(
|
||||||
data=PublicUserModel(
|
data=PublicUserModel(
|
||||||
public_id=user.public_id,
|
public_id=user.public_id,
|
||||||
|
@ -286,54 +181,6 @@ async def get_user_by_public_id(request: Request, public_id: str, _auth: UserMod
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validate_user(
|
|
||||||
first_name: str | None,
|
|
||||||
last_name: str | None,
|
|
||||||
username: str | None,
|
|
||||||
email: str | None,
|
|
||||||
password: str | None,
|
|
||||||
bio: str | None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Performs some validation upon user creation. Returns
|
|
||||||
a tuple (success, msg) to be used by routes. Values
|
|
||||||
set to None are not checked against
|
|
||||||
"""
|
|
||||||
|
|
||||||
if first_name and len(first_name) > 64:
|
|
||||||
return False, "first name is too long"
|
|
||||||
if first_name and len(first_name) < 5:
|
|
||||||
return False, "first name is too short"
|
|
||||||
if last_name and len(last_name) > 64:
|
|
||||||
return False, "last name is too long"
|
|
||||||
if last_name and len(last_name) < 2:
|
|
||||||
return False, "last name is too short"
|
|
||||||
if username and len(username) < 5:
|
|
||||||
return False, "username is too short"
|
|
||||||
if username and len(username) > 32:
|
|
||||||
return False, "username is too long"
|
|
||||||
if username and VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username):
|
|
||||||
return False, "username is invalid"
|
|
||||||
if email and not validators.email(email):
|
|
||||||
return False, "email is not valid"
|
|
||||||
if password and len(password) > 72:
|
|
||||||
return False, "password is too long"
|
|
||||||
if password and VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password):
|
|
||||||
return False, "password is too weak"
|
|
||||||
if username and await get_user_by_username(username, deleted_ok=True, restricted_ok=True):
|
|
||||||
return False, "username is already taken"
|
|
||||||
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
|
||||||
return False, "email is already registered"
|
|
||||||
if bio and len(bio) > 4096:
|
|
||||||
return False, "bio is too long"
|
|
||||||
if bio:
|
|
||||||
try:
|
|
||||||
bio.encode("utf-8")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
return False, "bio contains invalid characters"
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/user",
|
"/user",
|
||||||
tags=["Users"],
|
tags=["Users"],
|
||||||
|
@ -341,7 +188,9 @@ async def validate_user(
|
||||||
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
|
responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}},
|
||||||
)
|
)
|
||||||
@LIMITER.limit("1/minute")
|
@LIMITER.limit("1/minute")
|
||||||
async def delete(request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
async def delete_user(
|
||||||
|
request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Sets the user's deleted flag in the database,
|
Sets the user's deleted flag in the database,
|
||||||
without actually deleting the associated
|
without actually deleting the associated
|
||||||
|
@ -362,219 +211,6 @@ async def delete(request: Request, response: Response, user: UserModel = Depends
|
||||||
return APIResponse(status_code=200, msg="Success")
|
return APIResponse(status_code=200, msg="Success")
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/user/verifyEmail/{verification_id}",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={
|
|
||||||
200: {"model": APIResponse},
|
|
||||||
404: {"model": NotFound},
|
|
||||||
422: {"model": UnprocessableEntity},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("3/second")
|
|
||||||
async def verify_email(
|
|
||||||
request: Request,
|
|
||||||
verification_id: str,
|
|
||||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Verifies a user's email address. Endpoint is
|
|
||||||
limited to 3 hits per second
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (
|
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
|
||||||
.where(EmailVerification.id == verification_id)
|
|
||||||
.first()
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=404, detail="Verification ID is invalid")
|
|
||||||
elif not verification["pending"]:
|
|
||||||
raise HTTPException(status_code=400, detail="Email is already verified")
|
|
||||||
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
|
|
||||||
seconds=EMAIL_VERIFICATION_EXPIRATION
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Verification window has expired. Try again",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
|
||||||
EmailVerification.user == user.public_id
|
|
||||||
)
|
|
||||||
await User.update({User.email_verified: True}).where(User.public_id == user.public_id)
|
|
||||||
return APIResponse(status_code=200, msg="Verification successful")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/user/resetPassword/{verification_id}",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={
|
|
||||||
200: {"model": APIResponse},
|
|
||||||
400: {"model": BadRequest},
|
|
||||||
422: {"model": UnprocessableEntity},
|
|
||||||
404: {"model": NotFound},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("3/second")
|
|
||||||
async def reset_password(
|
|
||||||
request: Request,
|
|
||||||
verification_id: str,
|
|
||||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Modifies a user's password. Endpoint is limited
|
|
||||||
to 3 hits per second
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (
|
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
|
||||||
.where(EmailVerification.id == verification_id)
|
|
||||||
.first()
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
|
||||||
elif not verification["pending"]:
|
|
||||||
raise HTTPException(status_code=400, detail="This link has already been used")
|
|
||||||
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
|
|
||||||
seconds=EMAIL_VERIFICATION_EXPIRATION
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Verification window has expired. Try again",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Note how we don't update based on the verification ID:
|
|
||||||
# this way, multiple pending email verification requests
|
|
||||||
# are all cleared at once
|
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
|
||||||
EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET
|
|
||||||
)
|
|
||||||
await User.update({User.password_hash: verification["data"]}).where(User.public_id == user.public_id)
|
|
||||||
return APIResponse(status_code=200, msg="Password updated")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/user/changeEmail/{verification_id}",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={
|
|
||||||
200: {"model": APIResponse},
|
|
||||||
400: {"model": BadRequest},
|
|
||||||
422: {"model": UnprocessableEntity},
|
|
||||||
404: {"model": NotFound},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("3/second")
|
|
||||||
async def change_email(
|
|
||||||
request: Request,
|
|
||||||
verification_id: str,
|
|
||||||
user: UserModel = Depends(UNVERIFIED_MANAGER),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Modifies a user's email address.
|
|
||||||
Endpoint is limited to 3 hits per second
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (
|
|
||||||
verification := await EmailVerification.select(*EmailVerification.all_columns())
|
|
||||||
.where(EmailVerification.id == verification_id)
|
|
||||||
.first()
|
|
||||||
):
|
|
||||||
raise HTTPException(status_code=404, detail="Request ID is invalid")
|
|
||||||
elif not verification["pending"]:
|
|
||||||
raise HTTPException(status_code=400, detail="This link has already been used")
|
|
||||||
elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta(
|
|
||||||
seconds=EMAIL_VERIFICATION_EXPIRATION
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Verification window has expired. Try again",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Note how we don't update based on the verification ID:
|
|
||||||
# this way, multiple pending email verification requests
|
|
||||||
# are all cleared at once
|
|
||||||
await EmailVerification.update({EmailVerification.pending: False}).where(
|
|
||||||
EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL
|
|
||||||
)
|
|
||||||
await User.update(
|
|
||||||
{
|
|
||||||
User.email_address: verification["data"].decode(),
|
|
||||||
User.email_verified: False,
|
|
||||||
}
|
|
||||||
).where(User.public_id == user.public_id)
|
|
||||||
return APIResponse(status_code=200, msg="Email updated")
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
|
||||||
"user/resendMail",
|
|
||||||
tags=["Users"],
|
|
||||||
status_code=200,
|
|
||||||
responses={
|
|
||||||
200: {"model": APIResponse},
|
|
||||||
400: {"model": BadRequest},
|
|
||||||
422: {"model": UnprocessableEntity},
|
|
||||||
500: {"model": InternalServerError},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@LIMITER.limit("6/minute")
|
|
||||||
async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGER)):
|
|
||||||
"""
|
|
||||||
Resends the verification email to the user if the previous has expired.
|
|
||||||
Endpoint is limited to 6 hits per minute
|
|
||||||
"""
|
|
||||||
|
|
||||||
if user.email_verified:
|
|
||||||
raise HTTPException(status_code=400, detail="Email is already verified")
|
|
||||||
email_template = SMTP_TEMPLATES_DIRECTORY / f"{user.locale}.json"
|
|
||||||
try:
|
|
||||||
email_template.resolve(strict=True)
|
|
||||||
except FileNotFoundError:
|
|
||||||
email_template = SMTP_TEMPLATES_DIRECTORY / "en_US.json"
|
|
||||||
with email_template.open() as f:
|
|
||||||
email_message = json.load(f)["signup"]
|
|
||||||
verification_id = uuid.uuid4()
|
|
||||||
if await send_email(
|
|
||||||
SMTP_HOST,
|
|
||||||
SMTP_PORT,
|
|
||||||
email_message["content"].format(
|
|
||||||
first_name=user.first_name,
|
|
||||||
last_name=user.last_name,
|
|
||||||
username=user.username,
|
|
||||||
email=user.email_address,
|
|
||||||
link=f"http{'s' if HAS_HTTPS else ''}://{HOST}"
|
|
||||||
f"{'' if PORT == 443 and HAS_HTTPS or PORT == 80 else f':{PORT}'}/user/verifyEmail/{verification_id}",
|
|
||||||
platformName=PLATFORM_NAME,
|
|
||||||
),
|
|
||||||
SMTP_TIMEOUT,
|
|
||||||
SMTP_FROM_USER,
|
|
||||||
user.email_address,
|
|
||||||
email_message["subject"].format(
|
|
||||||
first_name=user.first_name,
|
|
||||||
last_name=user.last_name,
|
|
||||||
username=user.username,
|
|
||||||
email=user.email_address,
|
|
||||||
platformName=PLATFORM_NAME,
|
|
||||||
),
|
|
||||||
SMTP_USER,
|
|
||||||
SMTP_PASSWORD,
|
|
||||||
use_tls=SMTP_USE_TLS,
|
|
||||||
):
|
|
||||||
await EmailVerification.update(
|
|
||||||
{
|
|
||||||
EmailVerification.id: verification_id,
|
|
||||||
EmailVerification.creation_date: datetime.now(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return APIResponse(status_code=200, msg="Success")
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="An error occurred while trying to resend the email," " please try again later",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/user",
|
"/user",
|
||||||
tags=["Users"],
|
tags=["Users"],
|
||||||
|
@ -607,7 +243,9 @@ async def signup(
|
||||||
raise HTTPException(status_code=400, detail="Please logout first")
|
raise HTTPException(status_code=400, detail="Please logout first")
|
||||||
# We don't use FastAPI's validation because we want custom error
|
# We don't use FastAPI's validation because we want custom error
|
||||||
# messages
|
# messages
|
||||||
result, msg = await validate_user(first_name, last_name, username, email, password, bio)
|
result, msg = await validate_user(
|
||||||
|
first_name, last_name, username, email, password, bio
|
||||||
|
)
|
||||||
if not result:
|
if not result:
|
||||||
return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
|
return APIResponse(status_code=413, msg=f"Signup failed: {msg}")
|
||||||
else:
|
else:
|
||||||
|
@ -667,40 +305,11 @@ async def signup(
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail="An error occurred while sending verification email, please" " try again later",
|
detail="An error occurred while sending verification email, please"
|
||||||
|
" try again later",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def validate_profile_picture(
|
|
||||||
file: UploadFile,
|
|
||||||
) -> tuple[bool | None, str, bytes, str]:
|
|
||||||
"""
|
|
||||||
Validates a profile picture's size and content to see if it fits
|
|
||||||
our criteria and returns a tuple result, ext, data where result is a
|
|
||||||
boolean or none (True = check was passed, False = size too large,
|
|
||||||
None = check was failed for other reasons) indicating if the check was successful,
|
|
||||||
ext is the file's type and extension, data is a compressed stream of bytes
|
|
||||||
representing the original media and hash is the file's SHA256 hash encoded in
|
|
||||||
hexadecimal, before the compression. This function never raises an exception
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with file:
|
|
||||||
try:
|
|
||||||
content = await file.read()
|
|
||||||
if len(content) > MAX_MEDIA_SIZE:
|
|
||||||
return False, "", b"", ""
|
|
||||||
if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES:
|
|
||||||
return None, "", b"", ""
|
|
||||||
return (
|
|
||||||
True,
|
|
||||||
ext,
|
|
||||||
zlib.compress(content, ZLIB_COMPRESSION_LEVEL),
|
|
||||||
hashlib.sha256(content).hexdigest(),
|
|
||||||
)
|
|
||||||
except (UnicodeDecodeError, zlib.error):
|
|
||||||
return None, "", b"", ""
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/user",
|
"/user",
|
||||||
tags=["Users"],
|
tags=["Users"],
|
||||||
|
@ -739,9 +348,15 @@ async def update_user(
|
||||||
since they're the only ones that can be set to a null value. Endpoint is limited to 6 hits per minute
|
since they're the only ones that can be set to a null value. Endpoint is limited to 6 hits per minute
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not delete and not any((first_name, last_name, username, profile_picture, email_address, bio, password)):
|
if not delete and not any(
|
||||||
raise HTTPException(status_code=400, detail="At least one value has to be specified")
|
(first_name, last_name, username, profile_picture, email_address, bio, password)
|
||||||
result, msg = await validate_user(first_name, last_name, username, email_address, password, bio)
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="At least one value has to be specified"
|
||||||
|
)
|
||||||
|
result, msg = await validate_user(
|
||||||
|
first_name, last_name, username, email_address, password, bio
|
||||||
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
|
raise HTTPException(status_code=413, detail=f"Update failed: {msg}")
|
||||||
orig_user = user.copy()
|
orig_user = user.copy()
|
||||||
|
@ -757,10 +372,16 @@ async def update_user(
|
||||||
if profile_picture:
|
if profile_picture:
|
||||||
result, ext, media, digest = validate_profile_picture(profile_picture)
|
result, ext, media, digest = validate_profile_picture(profile_picture)
|
||||||
if result is False:
|
if result is False:
|
||||||
raise HTTPException(status_code=415, detail="The file type is unsupported")
|
raise HTTPException(
|
||||||
|
status_code=415, detail="The file type is unsupported"
|
||||||
|
)
|
||||||
elif result is None:
|
elif result is None:
|
||||||
raise HTTPException(status_code=413, detail="The file is too large")
|
raise HTTPException(status_code=413, detail="The file is too large")
|
||||||
elif (m := await Media.select(Media.media_id).where(Media.media_id == digest).first()) is None:
|
elif (
|
||||||
|
m := await Media.select(Media.media_id)
|
||||||
|
.where(Media.media_id == digest)
|
||||||
|
.first()
|
||||||
|
) is None:
|
||||||
# This media hasn't been already uploaded (either by this user or by someone
|
# This media hasn't been already uploaded (either by this user or by someone
|
||||||
# else), so we save it now. If it has been already uploaded, there's no need
|
# else), so we save it now. If it has been already uploaded, there's no need
|
||||||
# to do it again (that's what the hash is for)
|
# to do it again (that's what the hash is for)
|
||||||
|
@ -836,7 +457,9 @@ async def update_user(
|
||||||
{
|
{
|
||||||
EmailVerification.id: verification_id,
|
EmailVerification.id: verification_id,
|
||||||
EmailVerification.user: User(public_id=user.public_id),
|
EmailVerification.user: User(public_id=user.public_id),
|
||||||
EmailVerification.data: bcrypt.hashpw(password.encode(), user.password_hash[:29]),
|
EmailVerification.data: bcrypt.hashpw(
|
||||||
|
password.encode(), user.password_hash[:29]
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -902,9 +525,13 @@ async def update_user(
|
||||||
user.profile_picture = None
|
user.profile_picture = None
|
||||||
fields = []
|
fields = []
|
||||||
for field in user:
|
for field in user:
|
||||||
if field not in ["email_address", "password"] and getattr(orig_user, field) != getattr(user, field):
|
if field not in ["email_address", "password"] and getattr(
|
||||||
|
orig_user, field
|
||||||
|
) != getattr(user, field):
|
||||||
fields.append((field, getattr(user, field)))
|
fields.append((field, getattr(user, field)))
|
||||||
if fields:
|
if fields:
|
||||||
# If anything has changed, we update our info
|
# If anything has changed, we update our info
|
||||||
await User.update({field: value for field, value in fields}).where(User.public_id == user.public_id)
|
await User.update({field: value for field, value in fields}).where(
|
||||||
|
User.public_id == user.public_id
|
||||||
|
)
|
||||||
return APIResponse(status_code=200, msg="Changes saved successfully")
|
return APIResponse(status_code=200, msg="Changes saved successfully")
|
|
@ -1,7 +1,7 @@
|
||||||
import uvloop
|
import uvloop
|
||||||
import asyncio
|
import asyncio
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from pathlib import Path
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.exceptions import (
|
from fastapi.exceptions import (
|
||||||
HTTPException,
|
HTTPException,
|
||||||
|
@ -9,10 +9,7 @@ from fastapi.exceptions import (
|
||||||
StarletteHTTPException,
|
StarletteHTTPException,
|
||||||
)
|
)
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
from pathlib import Path
|
from endpoints import users, media, auth, email, password
|
||||||
|
|
||||||
|
|
||||||
from endpoints import users, media
|
|
||||||
from config import (
|
from config import (
|
||||||
LOGGER,
|
LOGGER,
|
||||||
LIMITER,
|
LIMITER,
|
||||||
|
@ -37,9 +34,11 @@ from util.exception_handlers import (
|
||||||
generic_error,
|
generic_error,
|
||||||
)
|
)
|
||||||
from util.email import test_smtp
|
from util.email import test_smtp
|
||||||
|
from config import MANAGER, UNVERIFIED_MANAGER
|
||||||
|
from util.users import get_self_by_id, get_self_by_id_unverified
|
||||||
|
|
||||||
|
|
||||||
with (Path(__file__).parent / "API.md").resolve(strict=True).open() as f:
|
with (Path(__file__).parent.parent / "API.md").resolve(strict=True).open() as f:
|
||||||
description = f.read()
|
description = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,7 +115,9 @@ async def startup_checks():
|
||||||
await create_tables()
|
await create_tables()
|
||||||
await Media.raw("SELECT 1;")
|
await Media.raw("SELECT 1;")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.error(f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}")
|
LOGGER.error(
|
||||||
|
f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
LOGGER.info("Database initialized")
|
LOGGER.info("Database initialized")
|
||||||
LOGGER.info("Testing SMTP connection")
|
LOGGER.info("Testing SMTP connection")
|
||||||
|
@ -131,7 +132,10 @@ async def startup_checks():
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.error(f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}")
|
LOGGER.error(
|
||||||
|
f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
LOGGER.info("SMTP test was successful")
|
LOGGER.info("SMTP test was successful")
|
||||||
|
|
||||||
|
@ -141,7 +145,13 @@ if __name__ == "__main__":
|
||||||
LOGGER.debug("Including modules")
|
LOGGER.debug("Including modules")
|
||||||
app.include_router(users.router)
|
app.include_router(users.router)
|
||||||
app.include_router(media.router)
|
app.include_router(media.router)
|
||||||
|
app.include_router(auth.router)
|
||||||
|
app.include_router(email.router)
|
||||||
|
app.include_router(password.router)
|
||||||
|
LOGGER.debug("Configuring authentication and rate-limiting")
|
||||||
app.state.limiter = LIMITER
|
app.state.limiter = LIMITER
|
||||||
|
MANAGER.user_loader(get_self_by_id)
|
||||||
|
UNVERIFIED_MANAGER.user_loader(get_self_by_id_unverified)
|
||||||
LOGGER.debug("Setting exception handlers")
|
LOGGER.debug("Setting exception handlers")
|
||||||
app.add_exception_handler(RateLimitExceeded, rate_limited)
|
app.add_exception_handler(RateLimitExceeded, rate_limited)
|
||||||
app.add_exception_handler(NotAuthenticated, not_authenticated)
|
app.add_exception_handler(NotAuthenticated, not_authenticated)
|
||||||
|
@ -153,9 +163,21 @@ if __name__ == "__main__":
|
||||||
uvloop.install()
|
uvloop.install()
|
||||||
log_config = uvicorn.config.LOGGING_CONFIG
|
log_config = uvicorn.config.LOGGING_CONFIG
|
||||||
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
|
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
|
||||||
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
|
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[
|
||||||
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
0
|
||||||
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
].formatter.datefmt
|
||||||
|
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[
|
||||||
|
0
|
||||||
|
].formatter._fmt # noqa
|
||||||
|
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[
|
||||||
|
0
|
||||||
|
].formatter._fmt # noqa
|
||||||
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
|
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
|
||||||
asyncio.run(startup_checks())
|
try:
|
||||||
|
asyncio.run(startup_checks())
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(
|
||||||
|
f"Startup checks failed due to a {type(e).__name__} exception -> {e}"
|
||||||
|
)
|
||||||
|
exit(1)
|
||||||
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)
|
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)
|
|
@ -45,7 +45,9 @@ class Media(Table):
|
||||||
|
|
||||||
|
|
||||||
MediaModel = create_pydantic_model(Media)
|
MediaModel = create_pydantic_model(Media)
|
||||||
PublicMediaModel = create_pydantic_model(Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type))
|
PublicMediaModel = create_pydantic_model(
|
||||||
|
Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_media_by_column(
|
async def get_media_by_column(
|
|
@ -38,7 +38,9 @@ class Post(Table, tablename="posts"):
|
||||||
|
|
||||||
|
|
||||||
PostModel = create_pydantic_model(Post, nested=True)
|
PostModel = create_pydantic_model(Post, nested=True)
|
||||||
PrivatePostModelInternal = create_pydantic_model(Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id))
|
PrivatePostModelInternal = create_pydantic_model(
|
||||||
|
Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id)
|
||||||
|
)
|
||||||
PublicPostModelInternal = create_pydantic_model(
|
PublicPostModelInternal = create_pydantic_model(
|
||||||
Post, nested=True, exclude_columns=(Post.flagged, Post.deleted, Post.internal_id)
|
Post, nested=True, exclude_columns=(Post.flagged, Post.deleted, Post.internal_id)
|
||||||
)
|
)
|
|
@ -0,0 +1,16 @@
|
||||||
|
from piccolo.conf.apps import AppRegistry
|
||||||
|
from piccolo.engine.postgres import PostgresEngine
|
||||||
|
|
||||||
|
DB = PostgresEngine(
|
||||||
|
config={
|
||||||
|
"host": "mouse.db.elephantsql.com",
|
||||||
|
"user": "zubtlqiv",
|
||||||
|
"password": "l9bgeChDsTHJUwSFex2N2kSbLZ8VkRch",
|
||||||
|
"database": "zubtlqiv",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A list of paths to piccolo apps
|
||||||
|
# e.g. ['blog.piccolo_app']
|
||||||
|
APP_REGISTRY = AppRegistry(apps=[])
|
|
@ -47,7 +47,9 @@ async def send_email(
|
||||||
await srv.login(login_email, password)
|
await srv.login(login_email, password)
|
||||||
await srv.sendmail(sender, recipient, msg.as_string())
|
await srv.sendmail(sender, recipient, msg.as_string())
|
||||||
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
|
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
|
||||||
logging.error(f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}")
|
logging.error(
|
||||||
|
f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}"
|
||||||
|
)
|
||||||
return error
|
return error
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -18,7 +18,10 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
error.detail = error.detail[:n] + " requests" + error.detail[n:]
|
error.detail = error.detail[:n] + " requests" + error.detail[n:]
|
||||||
LOGGER.info(f"{request.client.host} got rate-limited at {str(request.url)} " f"(exceeded {error.detail})")
|
LOGGER.info(
|
||||||
|
f"{request.client.host} got rate-limited at {str(request.url)} "
|
||||||
|
f"(exceeded {error.detail})"
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
content=dict(
|
content=dict(
|
||||||
|
@ -34,7 +37,9 @@ def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
||||||
return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required"))
|
return JSONResponse(
|
||||||
|
status_code=200, content=dict(status_code=401, msg="Authentication is required")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||||
|
@ -42,14 +47,18 @@ def request_invalid(request: Request, exc: RequestValidationError) -> JSONRespon
|
||||||
Handles Bad Request exceptions from FastAPI
|
Handles Bad Request exceptions from FastAPI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOGGER.info(f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}")
|
LOGGER.info(
|
||||||
|
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=200,
|
status_code=200,
|
||||||
content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"),
|
content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def http_exception(request: Request, exc: HTTPException | StarletteHTTPException) -> JSONResponse:
|
def http_exception(
|
||||||
|
request: Request, exc: HTTPException | StarletteHTTPException
|
||||||
|
) -> JSONResponse:
|
||||||
"""
|
"""
|
||||||
Handles HTTP-specific exceptions raised explicitly by
|
Handles HTTP-specific exceptions raised explicitly by
|
||||||
path operations
|
path operations
|
||||||
|
@ -57,12 +66,19 @@ def http_exception(request: Request, exc: HTTPException | StarletteHTTPException
|
||||||
|
|
||||||
if exc.status_code >= 500:
|
if exc.status_code >= 500:
|
||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{type(exc).__name__}: {exc}"
|
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
||||||
|
f"{type(exc).__name__}: {exc}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200, content=dict(status_code=500, msg="Internal Server Error")
|
||||||
)
|
)
|
||||||
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
|
||||||
else:
|
else:
|
||||||
LOGGER.info(f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}")
|
LOGGER.info(
|
||||||
return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail))
|
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
||||||
|
@ -70,6 +86,10 @@ async def generic_error(request: Request, exc: Exception) -> JSONResponse:
|
||||||
Handles generic, unexpected errors in the ASGI application
|
Handles generic, unexpected errors in the ASGI application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOGGER.info(f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}")
|
LOGGER.info(
|
||||||
|
f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}"
|
||||||
|
)
|
||||||
# We can't leak anything about the error, it would be too risky
|
# We can't leak anything about the error, it would be too risky
|
||||||
return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))
|
return JSONResponse(
|
||||||
|
status_code=200, content=dict(status_code=500, msg="Internal Server Error")
|
||||||
|
)
|
|
@ -0,0 +1,122 @@
|
||||||
|
import hashlib
|
||||||
|
import imghdr
|
||||||
|
import re
|
||||||
|
import zlib
|
||||||
|
import validators
|
||||||
|
from uuid import UUID
|
||||||
|
from config import (
|
||||||
|
VALIDATE_PASSWORD_REGEX,
|
||||||
|
VALIDATE_USERNAME_REGEX,
|
||||||
|
FORCE_EMAIL_VERIFICATION,
|
||||||
|
MAX_MEDIA_SIZE,
|
||||||
|
ZLIB_COMPRESSION_LEVEL,
|
||||||
|
ALLOWED_MEDIA_TYPES,
|
||||||
|
)
|
||||||
|
from orm.users import (
|
||||||
|
get_user_by_username,
|
||||||
|
get_user_by_email,
|
||||||
|
UserModel,
|
||||||
|
get_user_by_id,
|
||||||
|
)
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_user(
|
||||||
|
first_name: str | None,
|
||||||
|
last_name: str | None,
|
||||||
|
username: str | None,
|
||||||
|
email: str | None,
|
||||||
|
password: str | None,
|
||||||
|
bio: str | None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Performs some validation upon user creation. Returns
|
||||||
|
a tuple (success, msg) to be used by routes. Values
|
||||||
|
set to None are not checked against
|
||||||
|
"""
|
||||||
|
|
||||||
|
if first_name and len(first_name) > 64:
|
||||||
|
return False, "first name is too long"
|
||||||
|
if first_name and len(first_name) < 5:
|
||||||
|
return False, "first name is too short"
|
||||||
|
if last_name and len(last_name) > 64:
|
||||||
|
return False, "last name is too long"
|
||||||
|
if last_name and len(last_name) < 2:
|
||||||
|
return False, "last name is too short"
|
||||||
|
if username and len(username) < 5:
|
||||||
|
return False, "username is too short"
|
||||||
|
if username and len(username) > 32:
|
||||||
|
return False, "username is too long"
|
||||||
|
if (
|
||||||
|
username
|
||||||
|
and VALIDATE_USERNAME_REGEX
|
||||||
|
and not re.match(VALIDATE_USERNAME_REGEX, username)
|
||||||
|
):
|
||||||
|
return False, "username is invalid"
|
||||||
|
if email and not validators.email(email):
|
||||||
|
return False, "email is not valid"
|
||||||
|
if password and len(password) > 72:
|
||||||
|
return False, "password is too long"
|
||||||
|
if (
|
||||||
|
password
|
||||||
|
and VALIDATE_PASSWORD_REGEX
|
||||||
|
and not re.match(VALIDATE_PASSWORD_REGEX, password)
|
||||||
|
):
|
||||||
|
return False, "password is too weak"
|
||||||
|
if username and await get_user_by_username(
|
||||||
|
username, deleted_ok=True, restricted_ok=True
|
||||||
|
):
|
||||||
|
return False, "username is already taken"
|
||||||
|
if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
||||||
|
return False, "email is already registered"
|
||||||
|
if bio and len(bio) > 4096:
|
||||||
|
return False, "bio is too long"
|
||||||
|
if bio:
|
||||||
|
try:
|
||||||
|
bio.encode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return False, "bio contains invalid characters"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_profile_picture(
|
||||||
|
file: UploadFile,
|
||||||
|
) -> tuple[bool | None, str, bytes, str]:
|
||||||
|
"""
|
||||||
|
Validates a profile picture's size and content to see if it fits
|
||||||
|
our criteria and returns a tuple result, ext, data where result is a
|
||||||
|
boolean or none (True = check was passed, False = size too large,
|
||||||
|
None = check was failed for other reasons) indicating if the check was successful,
|
||||||
|
ext is the file's type and extension, data is a compressed stream of bytes
|
||||||
|
representing the original media and hash is the file's SHA256 hash encoded in
|
||||||
|
hexadecimal, before the compression. This function never raises an exception
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with file:
|
||||||
|
try:
|
||||||
|
content = await file.read()
|
||||||
|
if len(content) > MAX_MEDIA_SIZE:
|
||||||
|
return False, "", b"", ""
|
||||||
|
if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES:
|
||||||
|
return None, "", b"", ""
|
||||||
|
return (
|
||||||
|
True,
|
||||||
|
ext,
|
||||||
|
zlib.compress(content, ZLIB_COMPRESSION_LEVEL),
|
||||||
|
hashlib.sha256(content).hexdigest(),
|
||||||
|
)
|
||||||
|
except (UnicodeDecodeError, zlib.error):
|
||||||
|
return None, "", b"", ""
|
||||||
|
|
||||||
|
|
||||||
|
# Credential loaders for our authenticated routes
|
||||||
|
async def get_self_by_id(public_id: UUID, requires_verified: bool = True) -> UserModel:
|
||||||
|
user = await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
||||||
|
if FORCE_EMAIL_VERIFICATION and requires_verified and not user.email_verified:
|
||||||
|
raise HTTPException(status_code=401, detail="Email verification is required")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
|
||||||
|
return await get_self_by_id(public_id, False)
|
Loading…
Reference in New Issue