Initial work on users
This commit is contained in:
parent
4a06840a81
commit
d468aff227
|
@ -138,3 +138,5 @@ dmypy.json
|
||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
|
/config.py
|
||||||
|
/piccolo_conf.py
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
37
README.md
37
README.md
|
@ -1,3 +1,38 @@
|
||||||
# PySimpleSocial
|
# PySimpleSocial
|
||||||
|
|
||||||
An advanced REST API written in Python for a generic social media website
|
An advanced REST API written in Python for a generic social media website.
|
||||||
|
|
||||||
|
__Note__: This is a WIP so far.
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
The project is written using [FastAPI](https://https://fastapi.tiangolo.com/), [piccolo](https://piccolo-orm.readthedocs.io/en/latest/) and [uvicorn](https://www.uvicorn.org/). Other awesome libraries used (only direct dependencies are listed here):
|
||||||
|
- [validators](https://validators.readthedocs.io/en/latest/)
|
||||||
|
- [pydantic](https://pydantic-docs.helpmanual.io/)
|
||||||
|
- [aiosmtplib](https://aiosmtplib.readthedocs.io/en/latest/usage.html)
|
||||||
|
- [uvloop](https://uvloop.readthedocs.io/)
|
||||||
|
- [slowapi](https://slowapi.readthedocs.io/en/latest/)
|
||||||
|
- [fastapi-login](https://fastapi-login.readthedocs.io/)
|
||||||
|
- [bcrypt](https://github.com/pyca/bcrypt/)
|
||||||
|
|
||||||
|
# Feature overview
|
||||||
|
|
||||||
|
__Note__: Not all of this is implemented yet
|
||||||
|
|
||||||
|
- Simple authentication system using salted bcrypt hashes for password storage
|
||||||
|
- PostgreSQL is used as the main database
|
||||||
|
- Simple rate limiting using redis/in-memory storage
|
||||||
|
- Support for various kinds of media stored in a CDN, directly inside the database or on a local/remote filesystem
|
||||||
|
- Regular social media mechanics: (Un)following users, posting media with captions, etc.
|
||||||
|
- User settings (change username, profile picture, etc.)
|
||||||
|
- Simple messaging system using websockets or a polling HTTP API
|
||||||
|
- Admin functionality with basic metrics and administration features (flagging/deleting users/posts, handling tickets, etc.)
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
|
||||||
|
Move the *.py.example files to their respective *.py files, fill them as necessary, then simply install the dependencies via pip and run main.py
|
||||||
|
|
||||||
|
# License
|
||||||
|
|
||||||
|
This software is licensed under the MIT license. For more information, read the [license file](LICENSE)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
# Configuration file. Each variable can be overridden by a
|
||||||
|
# corresponding environment variable
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from fastapi_login import LoginManager
|
||||||
|
|
||||||
|
|
||||||
|
# Authentication configuration
|
||||||
|
LOGIN_SECRET_KEY = (
|
||||||
|
os.getenv("LOGIN_SECRET_KEY") or "login-secret"
|
||||||
|
) # Recommended value: os.urandom(24).hex()
|
||||||
|
USE_BEARER_HEADER = True
|
||||||
|
USE_COOKIE = True
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20
|
||||||
|
LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable
|
||||||
|
LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s"
|
||||||
|
LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p"
|
||||||
|
|
||||||
|
# Bcrypt configuration
|
||||||
|
BCRYPT_ROUNDS = (
|
||||||
|
os.getenv("BCRYPT_ROUNDS") or 10
|
||||||
|
) # How many rounds are used when salting
|
||||||
|
|
||||||
|
|
||||||
|
# Rate limit configuration
|
||||||
|
REDIS_URL = "" # Used for rate limits. Empty to fall back to memory
|
||||||
|
REDIS_OPTIONS = {} # Options for redis
|
||||||
|
RATELIMIT_ENABLED = False # False to disable rate limiting
|
||||||
|
RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html
|
||||||
|
|
||||||
|
# Session configuration
|
||||||
|
SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds
|
||||||
|
SESSION_COOKIE_NAME = "_social_media_session"
|
||||||
|
COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict"
|
||||||
|
COOKIE_DOMAIN = "localhost" # Empty to disable this
|
||||||
|
COOKIE_PATH = "/"
|
||||||
|
COOKIE_HTTPONLY = True
|
||||||
|
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
|
||||||
|
|
||||||
|
# SMTP configuration
|
||||||
|
SMTP_HOST = "smtp.nocturn9x.space"
|
||||||
|
SMTP_USER = "info@example.com"
|
||||||
|
SMTP_PASSWORD = "password"
|
||||||
|
SMTP_PORT = 587
|
||||||
|
SMTP_USE_TLS = True
|
||||||
|
SMTP_FROM_USER = "info@example.com"
|
||||||
|
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email"
|
||||||
|
SMTP_TIMEOUT = 10
|
||||||
|
|
||||||
|
# Miscellaneous
|
||||||
|
|
||||||
|
# Usernames containing these characters are not valid
|
||||||
|
INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character
|
||||||
|
# Empty this to disable username validation
|
||||||
|
VALIDATE_USERNAME_REGEX = (
|
||||||
|
rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$"
|
||||||
|
)
|
||||||
|
# Criteria:
|
||||||
|
# - Between 10 and 72 characters long
|
||||||
|
# - At least 2 uppercase letters
|
||||||
|
# - At least 3 lowercase letters
|
||||||
|
# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^
|
||||||
|
# - At least 2 numbers
|
||||||
|
# You can change the repetitions to enforce stricter/laxer rules or empty
|
||||||
|
# this field to disable weakness validation
|
||||||
|
VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$"
|
||||||
|
|
||||||
|
|
||||||
|
class NotAuthenticated(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ != "__main__":
|
||||||
|
LOGGER: logging.Logger = logging.getLogger("socialMedia")
|
||||||
|
LOGGER.setLevel(LOG_LEVEL)
|
||||||
|
handler = logging.StreamHandler(sys.stderr)
|
||||||
|
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
LOGGER.addHandler(handler)
|
||||||
|
handler.setLevel(LOG_LEVEL)
|
||||||
|
if LOG_FILE:
|
||||||
|
file_handler = logging.FileHandler(LOG_FILE, "a", "utf8")
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
LOGGER.addHandler(handler)
|
||||||
|
file_handler.setLevel(LOG_LEVEL)
|
||||||
|
logging.getLogger("uvicorn").addHandler(file_handler)
|
||||||
|
LIMITER = Limiter(
|
||||||
|
key_func=get_remote_address,
|
||||||
|
strategy=RATELIMIT_STRATEGY,
|
||||||
|
storage_uri=REDIS_URL or None,
|
||||||
|
in_memory_fallback_enabled=bool(REDIS_URL),
|
||||||
|
storage_options=REDIS_OPTIONS,
|
||||||
|
enabled=RATELIMIT_ENABLED,
|
||||||
|
)
|
||||||
|
MANAGER = LoginManager(
|
||||||
|
LOGIN_SECRET_KEY,
|
||||||
|
"/login",
|
||||||
|
use_cookie=True,
|
||||||
|
cookie_name=SESSION_COOKIE_NAME,
|
||||||
|
use_header=USE_BEARER_HEADER,
|
||||||
|
custom_exception=NotAuthenticated,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Uvicorn config
|
||||||
|
|
||||||
|
HOST = os.getenv("HOST") or "localhost"
|
||||||
|
PORT = int(os.getenv("PORT") or 0) or 8000
|
||||||
|
WORKERS = int(os.getenv("PORT") or 0) or 1
|
|
@ -0,0 +1,334 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import validators
|
||||||
|
from fastapi import APIRouter as FastAPI, Depends, Response, Request
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
BCRYPT_ROUNDS,
|
||||||
|
LOGGER,
|
||||||
|
SESSION_EXPIRE_LIMIT,
|
||||||
|
COOKIE_SAMESITE_POLICY,
|
||||||
|
COOKIE_DOMAIN,
|
||||||
|
MANAGER,
|
||||||
|
SESSION_COOKIE_NAME,
|
||||||
|
LIMITER,
|
||||||
|
VALIDATE_PASSWORD_REGEX,
|
||||||
|
VALIDATE_USERNAME_REGEX,
|
||||||
|
SECURE_COOKIE,
|
||||||
|
COOKIE_PATH,
|
||||||
|
COOKIE_HTTPONLY,
|
||||||
|
)
|
||||||
|
from orm.users import UserModel, User
|
||||||
|
|
||||||
|
router = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_id(
|
||||||
|
public_id: UUID, include_secrets: bool = False, restricted_ok: bool = False,
|
||||||
|
deleted_ok: bool = False
|
||||||
|
) -> dict | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its public ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await User.select(
|
||||||
|
*User.all_columns(exclude=["public_id"]),
|
||||||
|
User.public_id.as_alias("id"),
|
||||||
|
exclude_secrets=not include_secrets,
|
||||||
|
)
|
||||||
|
.where(User.public_id == public_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
# Performs validation
|
||||||
|
UserModel(**user)
|
||||||
|
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||||
|
return
|
||||||
|
return user
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.user_loader()
|
||||||
|
async def get_self_by_id(public_id: UUID) -> dict:
|
||||||
|
return await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_username(
|
||||||
|
username: str, include_secrets: bool = False, restricted_ok: bool = False,
|
||||||
|
deleted_ok: bool = False
|
||||||
|
) -> dict | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its public username
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await User.select(
|
||||||
|
*User.all_columns(exclude=["public_id"]),
|
||||||
|
User.public_id.as_alias("id"),
|
||||||
|
exclude_secrets=not include_secrets,
|
||||||
|
)
|
||||||
|
.where(User.username == username)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
# Performs validation
|
||||||
|
UserModel(**user)
|
||||||
|
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||||
|
return
|
||||||
|
return user
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_by_email(
|
||||||
|
email: str, include_secrets: bool = False, restricted_ok: bool = False,
|
||||||
|
deleted_ok: bool = False
|
||||||
|
) -> dict | None:
|
||||||
|
"""
|
||||||
|
Retrieves a user by its email address (meant to
|
||||||
|
be used internally)
|
||||||
|
"""
|
||||||
|
|
||||||
|
user = (
|
||||||
|
await User.select(
|
||||||
|
*User.all_columns(exclude=["public_id"]),
|
||||||
|
User.public_id.as_alias("id"),
|
||||||
|
exclude_secrets=not include_secrets,
|
||||||
|
)
|
||||||
|
.where(User.email_address == email)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
# Performs validation
|
||||||
|
UserModel(**user)
|
||||||
|
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||||
|
return
|
||||||
|
return user
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@LIMITER.limit("5/minute")
|
||||||
|
@router.post("/user")
|
||||||
|
async def login(
|
||||||
|
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
|
||||||
|
) -> dict:
|
||||||
|
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||||
|
raise HTTPException(status_code=400, detail="Please logout first")
|
||||||
|
username = data.username
|
||||||
|
if len(username) > 32:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413, detail="Authentication failed: username is too long"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
password = data.password.encode()
|
||||||
|
if len(password) > 72:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413, detail="Authentication failed: password is too long"
|
||||||
|
)
|
||||||
|
except UnicodeEncodeError as e:
|
||||||
|
LOGGER.warning(
|
||||||
|
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: invalid characters in password",
|
||||||
|
)
|
||||||
|
if not (
|
||||||
|
user := await get_user_by_username(
|
||||||
|
username, include_secrets=True, restricted_ok=True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: the user does not exist",
|
||||||
|
)
|
||||||
|
if not bcrypt.checkpw(password, user["password_hash"]):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail="Authentication failed: password mismatch",
|
||||||
|
)
|
||||||
|
token = MANAGER.create_access_token(
|
||||||
|
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])}
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
secure=SECURE_COOKIE,
|
||||||
|
key=SESSION_COOKIE_NAME,
|
||||||
|
max_age=SESSION_EXPIRE_LIMIT,
|
||||||
|
value=token,
|
||||||
|
httponly=COOKIE_HTTPONLY,
|
||||||
|
samesite=COOKIE_SAMESITE_POLICY,
|
||||||
|
domain=COOKIE_DOMAIN or None,
|
||||||
|
path=COOKIE_PATH or "/",
|
||||||
|
)
|
||||||
|
return {"status_code": 200, "msg": "Authentication successful"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/logout")
|
||||||
|
@LIMITER.limit("5/minute")
|
||||||
|
async def logout(
|
||||||
|
request: Request, response: Response, user: dict = Depends(MANAGER)
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Deletes a user's session cookie, logging them
|
||||||
|
out
|
||||||
|
"""
|
||||||
|
|
||||||
|
response.delete_cookie(
|
||||||
|
secure=SECURE_COOKIE,
|
||||||
|
key=SESSION_COOKIE_NAME,
|
||||||
|
httponly=COOKIE_HTTPONLY,
|
||||||
|
samesite=COOKIE_SAMESITE_POLICY,
|
||||||
|
domain=COOKIE_DOMAIN or None,
|
||||||
|
path=COOKIE_PATH or "/",
|
||||||
|
)
|
||||||
|
return {"status_code": 200, "msg": "Logged out"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/me")
|
||||||
|
@LIMITER.limit("2/second")
|
||||||
|
async def get_self(request: Request, user: dict = Depends(MANAGER)) -> dict:
|
||||||
|
"""
|
||||||
|
Fetches a user's own info. This returns some
|
||||||
|
extra data such as email address, account
|
||||||
|
creation date and email verification status,
|
||||||
|
which is not available from the regular endpoint
|
||||||
|
"""
|
||||||
|
|
||||||
|
user.pop("password_hash")
|
||||||
|
user.pop("internal_id")
|
||||||
|
user.pop("deleted")
|
||||||
|
return {"status_code": 200, "msg": "Success", "data": user}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/username/{username}")
|
||||||
|
@LIMITER.limit("30/second")
|
||||||
|
async def get_user_by_name(
|
||||||
|
request: Request, username: str, _auth: dict = Depends(MANAGER)
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Fetches a single user by its public ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (user := await get_user_by_username(username)):
|
||||||
|
return {
|
||||||
|
"status_code": 404,
|
||||||
|
"msg": "Lookup failed: the user does not exist",
|
||||||
|
}
|
||||||
|
user.pop("restricted")
|
||||||
|
user.pop("deleted")
|
||||||
|
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/id/{public_id}")
|
||||||
|
@LIMITER.limit("30/second")
|
||||||
|
async def get_user_by_public_id(
|
||||||
|
request: Request, public_id: str, _auth: dict = Depends(MANAGER)
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Fetches a single user by its public ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not (user := await get_user_by_id(UUID(public_id))):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail="Lookup failed: the user does not exist"
|
||||||
|
)
|
||||||
|
user.pop("restricted")
|
||||||
|
user.pop("deleted")
|
||||||
|
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_user(
|
||||||
|
first_name: str, last_name: str, username: str, email: str, password: str
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Performs some validation upon user creation. Returns
|
||||||
|
a tuple (success, msg) to be used by routes
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(first_name) > 64:
|
||||||
|
return False, "first name is too long"
|
||||||
|
if len(first_name) < 5:
|
||||||
|
return False, "first name is too short"
|
||||||
|
if len(last_name) > 64:
|
||||||
|
return False, "last name is too long"
|
||||||
|
if len(last_name) < 2:
|
||||||
|
return False, "last name is too short"
|
||||||
|
if len(username) < 5:
|
||||||
|
return False, "username is too short"
|
||||||
|
if len(username) > 32:
|
||||||
|
return False, "username is too long"
|
||||||
|
if VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username):
|
||||||
|
return False, "username is invalid"
|
||||||
|
if not validators.email(email):
|
||||||
|
return False, "email is not valid"
|
||||||
|
if len(password) > 72:
|
||||||
|
return False, "password is too long"
|
||||||
|
if VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password):
|
||||||
|
return False, "password is too weak"
|
||||||
|
if await get_user_by_username(username, deleted_ok=True, restricted_ok=True):
|
||||||
|
return False, "username is already taken"
|
||||||
|
if await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
||||||
|
return False, "email is already registered"
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/user")
|
||||||
|
@LIMITER.limit("1/minute")
|
||||||
|
async def delete(request: Request, response: Response, user: dict = Depends(MANAGER)) -> dict:
|
||||||
|
"""
|
||||||
|
Sets the user's deleted flag in the database,
|
||||||
|
without actually deleting the associated
|
||||||
|
data
|
||||||
|
"""
|
||||||
|
|
||||||
|
await User.update({User.deleted: True}).where(User.public_id == user["id"])
|
||||||
|
response.delete_cookie(
|
||||||
|
secure=SECURE_COOKIE,
|
||||||
|
key=SESSION_COOKIE_NAME,
|
||||||
|
httponly=COOKIE_HTTPONLY,
|
||||||
|
samesite=COOKIE_SAMESITE_POLICY,
|
||||||
|
domain=COOKIE_DOMAIN or None,
|
||||||
|
path=COOKIE_PATH or "/",)
|
||||||
|
return {"status_code": 200, "msg": "Success"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/user")
|
||||||
|
@LIMITER.limit("2/minute")
|
||||||
|
async def signup(
|
||||||
|
request: Request,
|
||||||
|
first_name: str,
|
||||||
|
last_name: str,
|
||||||
|
username: str,
|
||||||
|
email: str,
|
||||||
|
password: str,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Endpoint used to create new users
|
||||||
|
"""
|
||||||
|
|
||||||
|
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||||
|
raise HTTPException(status_code=400, detail="Please logout first")
|
||||||
|
# We don't use FastAPI's validation because we want custom error
|
||||||
|
# messages
|
||||||
|
result, msg = await validate_user(first_name, last_name, username, email, password)
|
||||||
|
if not result:
|
||||||
|
return {"status_code": 413, "msg": f"Signup failed: {msg}"}
|
||||||
|
else:
|
||||||
|
await User.insert(
|
||||||
|
User(
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name,
|
||||||
|
username=username,
|
||||||
|
email_address=email,
|
||||||
|
password_hash=bcrypt.hashpw(
|
||||||
|
password.encode(), bcrypt.gensalt(BCRYPT_ROUNDS)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"status_code": 200, "msg": "Success"}
|
|
@ -0,0 +1,127 @@
|
||||||
|
import uvloop
|
||||||
|
import asyncio
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.exceptions import (
|
||||||
|
HTTPException,
|
||||||
|
RequestValidationError,
|
||||||
|
StarletteHTTPException,
|
||||||
|
)
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
|
||||||
|
from endpoints import users
|
||||||
|
from config import (
|
||||||
|
LOGGER,
|
||||||
|
LIMITER,
|
||||||
|
NotAuthenticated,
|
||||||
|
SMTP_HOST,
|
||||||
|
SMTP_USER,
|
||||||
|
SMTP_PORT,
|
||||||
|
SMTP_USE_TLS,
|
||||||
|
SMTP_PASSWORD,
|
||||||
|
SMTP_TIMEOUT,
|
||||||
|
HOST,
|
||||||
|
PORT,
|
||||||
|
WORKERS,
|
||||||
|
)
|
||||||
|
from orm import create_tables, Media
|
||||||
|
from util.exception_handlers import (
|
||||||
|
http_exception,
|
||||||
|
rate_limited,
|
||||||
|
request_invalid,
|
||||||
|
not_authenticated,
|
||||||
|
)
|
||||||
|
from util.email import test_smtp
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
@LIMITER.limit("10/second")
|
||||||
|
async def root(request: Request):
|
||||||
|
return {"status_code": 403, "msg": "Unauthorized"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/ping")
|
||||||
|
@LIMITER.limit("1/minute")
|
||||||
|
async def ping(request: Request) -> dict:
|
||||||
|
"""
|
||||||
|
This handler simply replies to "ping" requests and
|
||||||
|
is used to check whether the API is up and running.
|
||||||
|
It also performs a sanity check with the database and
|
||||||
|
the SMTP server to ensure that they are functioning correctly.
|
||||||
|
For this reason, this endpoint's rate limit is much stricter
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOGGER.info(f"Processing ping request from {request.client.host}")
|
||||||
|
try:
|
||||||
|
await Media.raw("SELECT 1;")
|
||||||
|
await test_smtp(
|
||||||
|
SMTP_HOST,
|
||||||
|
SMTP_PORT,
|
||||||
|
SMTP_USER,
|
||||||
|
SMTP_PASSWORD,
|
||||||
|
SMTP_TIMEOUT,
|
||||||
|
SMTP_USE_TLS,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
return {"status_code": 200, "msg": "OK"}
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(500)
|
||||||
|
|
||||||
|
|
||||||
|
async def startup_checks():
|
||||||
|
LOGGER.info("Initializing database")
|
||||||
|
try:
|
||||||
|
await create_tables()
|
||||||
|
await Media.raw("SELECT 1;")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(
|
||||||
|
f"An error occurred while trying to initialize the database -> {type(e.__name__)}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOGGER.info("Database initialized")
|
||||||
|
LOGGER.info("Testing SMTP connection")
|
||||||
|
try:
|
||||||
|
await test_smtp(
|
||||||
|
SMTP_HOST,
|
||||||
|
SMTP_PORT,
|
||||||
|
SMTP_USER,
|
||||||
|
SMTP_PASSWORD,
|
||||||
|
SMTP_TIMEOUT,
|
||||||
|
SMTP_USE_TLS,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(
|
||||||
|
f"An error occurred while trying to connect to the SMTP server -> {type(e.__name__)}: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOGGER.info("SMTP test was successful")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
LOGGER.info("Backend starting up!")
|
||||||
|
LOGGER.debug("Including modules")
|
||||||
|
app.include_router(users.router)
|
||||||
|
app.state.limiter = LIMITER
|
||||||
|
LOGGER.debug("Setting exception handlers")
|
||||||
|
app.add_exception_handler(RateLimitExceeded, rate_limited)
|
||||||
|
app.add_exception_handler(NotAuthenticated, not_authenticated)
|
||||||
|
app.add_exception_handler(HTTPException, http_exception)
|
||||||
|
app.add_exception_handler(StarletteHTTPException, http_exception)
|
||||||
|
app.add_exception_handler(RequestValidationError, request_invalid)
|
||||||
|
LOGGER.debug("Installing uvloop")
|
||||||
|
uvloop.install()
|
||||||
|
log_config = uvicorn.config.LOGGING_CONFIG
|
||||||
|
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
|
||||||
|
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[
|
||||||
|
0
|
||||||
|
].formatter.datefmt
|
||||||
|
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
||||||
|
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
||||||
|
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
|
||||||
|
asyncio.run(startup_checks())
|
||||||
|
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)
|
|
@ -0,0 +1,14 @@
|
||||||
|
import asyncio
|
||||||
|
from piccolo.table import create_db_tables
|
||||||
|
from .users import User
|
||||||
|
from .media import Media
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tables():
|
||||||
|
"""
|
||||||
|
Initializes the database
|
||||||
|
"""
|
||||||
|
|
||||||
|
await create_db_tables(User, Media, if_not_exists=True)
|
||||||
|
await User.create_index([User.public_id], if_not_exists=True)
|
||||||
|
await User.create_index([User.username], if_not_exists=True)
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""
|
||||||
|
Media relation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from piccolo.table import Table
|
||||||
|
from piccolo.columns import UUID, Text, Boolean, Date
|
||||||
|
from piccolo.columns.defaults.date import DateNow
|
||||||
|
|
||||||
|
|
||||||
|
class Media(Table):
|
||||||
|
"""
|
||||||
|
A piece of media on a CDN
|
||||||
|
"""
|
||||||
|
|
||||||
|
media_id = UUID(primary_key=True)
|
||||||
|
media_url = Text(null=False)
|
||||||
|
flagged = Boolean(default=False, null=False)
|
||||||
|
deleted = Boolean(default=False, null=False)
|
||||||
|
creation_date = Date(default=DateNow(), null=False)
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""
|
||||||
|
User relation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from piccolo.utils.pydantic import create_pydantic_model
|
||||||
|
from piccolo.table import Table
|
||||||
|
from piccolo.columns import (
|
||||||
|
ForeignKey,
|
||||||
|
Varchar,
|
||||||
|
BigSerial,
|
||||||
|
UUID,
|
||||||
|
Date,
|
||||||
|
OnDelete,
|
||||||
|
OnUpdate,
|
||||||
|
Boolean,
|
||||||
|
Email,
|
||||||
|
Bytea,
|
||||||
|
)
|
||||||
|
from piccolo.columns.defaults.date import DateNow
|
||||||
|
|
||||||
|
from .media import Media
|
||||||
|
|
||||||
|
|
||||||
|
class User(Table, tablename="users"):
|
||||||
|
internal_id = BigSerial(null=False, secret=True)
|
||||||
|
public_id = UUID(primary_key=True)
|
||||||
|
first_name = Varchar(length=64, null=False)
|
||||||
|
last_name = Varchar(length=64, null=True)
|
||||||
|
email_address = Email(secret=True)
|
||||||
|
username = Varchar(length=32, null=False, unique=True)
|
||||||
|
password_hash = Bytea(null=False, secret=True)
|
||||||
|
profile_picture = ForeignKey(
|
||||||
|
references=Media,
|
||||||
|
on_delete=OnDelete.set_null,
|
||||||
|
on_update=OnUpdate.cascade,
|
||||||
|
null=True,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
creation_date = Date(secret=True, default=DateNow())
|
||||||
|
bio = Varchar(length=4096, null=True, default=None)
|
||||||
|
restricted = Boolean(default=False, null=False)
|
||||||
|
email_verified = Boolean(default=False, null=False, secret=True)
|
||||||
|
verified_account = Boolean(default=False, null=False)
|
||||||
|
deleted = Boolean(default=False, null=False)
|
||||||
|
|
||||||
|
|
||||||
|
UserModel = create_pydantic_model(User)
|
|
@ -0,0 +1,16 @@
|
||||||
|
from piccolo.conf.apps import AppRegistry
|
||||||
|
from piccolo.engine.postgres import PostgresEngine
|
||||||
|
|
||||||
|
DB = PostgresEngine(
|
||||||
|
config={
|
||||||
|
"host": "example.com",
|
||||||
|
"user": "database",
|
||||||
|
"password": "password",
|
||||||
|
"database": "database",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A list of paths to piccolo apps
|
||||||
|
# e.g. ['blog.piccolo_app']
|
||||||
|
APP_REGISTRY = AppRegistry(apps=[])
|
|
@ -0,0 +1,10 @@
|
||||||
|
piccolo[all]~=0.91.0
|
||||||
|
fastapi-login
|
||||||
|
bcrypt~=4.0.0
|
||||||
|
slowapi~=0.1.6
|
||||||
|
python-multipart
|
||||||
|
uvloop~=0.17.0
|
||||||
|
fastapi~=0.85.0
|
||||||
|
validators
|
||||||
|
aiosmtplib
|
||||||
|
uvicorn
|
|
@ -0,0 +1,3 @@
|
||||||
|
[{
|
||||||
|
|
||||||
|
}]
|
|
@ -0,0 +1,87 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
import aiosmtplib
|
||||||
|
from email.mime.base import MIMEBase
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
|
||||||
|
async def send_email(
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
message: str,
|
||||||
|
timeout: int,
|
||||||
|
sender: str,
|
||||||
|
recipient: str,
|
||||||
|
subject: str,
|
||||||
|
login_email: str,
|
||||||
|
password: str,
|
||||||
|
attachments: List[MIMEBase] = tuple(),
|
||||||
|
use_tls: bool = True,
|
||||||
|
check_hostname: bool = True,
|
||||||
|
) -> Union[bool, aiosmtplib.SMTPException]:
|
||||||
|
"""
|
||||||
|
Sends an email with the given details. Returns True on success
|
||||||
|
or an exception object upon failure
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
|
||||||
|
msg = MIMEMultipart()
|
||||||
|
msg["From"] = sender
|
||||||
|
msg["To"] = recipient
|
||||||
|
msg["Subject"] = subject
|
||||||
|
msg.attach(MIMEText(message, "html"))
|
||||||
|
for attachment in attachments:
|
||||||
|
msg.attach(attachment)
|
||||||
|
await srv.ehlo() # We identify ourselves
|
||||||
|
if use_tls:
|
||||||
|
context = ssl.create_default_context()
|
||||||
|
context.check_hostname = check_hostname
|
||||||
|
await srv.starttls(tls_context=context)
|
||||||
|
await srv.ehlo() # We do it again, but encrypted!
|
||||||
|
await srv.login(login_email, password)
|
||||||
|
await srv.sendmail(sender, recipient, msg.as_string())
|
||||||
|
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
|
||||||
|
logging.error(
|
||||||
|
f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}"
|
||||||
|
)
|
||||||
|
return error
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_smtp(
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
login_email: str,
|
||||||
|
password: str,
|
||||||
|
timeout: int,
|
||||||
|
use_tls: bool = True,
|
||||||
|
check_hostname: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Attempts login to the given SMTP server with the given credentials.
|
||||||
|
Used upon startup, raises an exception upon failure. This will
|
||||||
|
fail if the server does not support TLS encryption for login
|
||||||
|
"""
|
||||||
|
|
||||||
|
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
|
||||||
|
await srv.ehlo()
|
||||||
|
if use_tls:
|
||||||
|
context = ssl.create_default_context()
|
||||||
|
context.check_hostname = check_hostname
|
||||||
|
await srv.starttls(tls_context=context)
|
||||||
|
await srv.ehlo()
|
||||||
|
await srv.login(login_email, password)
|
||||||
|
|
||||||
|
|
||||||
|
def redirect(url: str) -> RedirectResponse:
|
||||||
|
"""
|
||||||
|
Returns a redirect response to the given url
|
||||||
|
"""
|
||||||
|
|
||||||
|
return RedirectResponse(url=url)
|
|
@ -0,0 +1,79 @@
|
||||||
|
from config import LOGGER, NotAuthenticated
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.exceptions import HTTPException, StarletteHTTPException
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
|
||||||
|
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
|
||||||
|
n = 0
|
||||||
|
while True:
|
||||||
|
if error.detail[n].isnumeric():
|
||||||
|
n += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
error.detail = error.detail[:n] + " requests" + error.detail[n:]
|
||||||
|
LOGGER.info(
|
||||||
|
f"{request.client.host} got rate-limited at {str(request.url)} "
|
||||||
|
f"(exceeded {error.detail})"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
|
||||||
|
"status_code": 429,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
||||||
|
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"msg": "Authentication is required",
|
||||||
|
"status_code": 401,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||||
|
LOGGER.info(
|
||||||
|
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"msg": f"Bad request: {type(exc).__name__}: {exc}",
|
||||||
|
"status_code": 400,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def http_exception(
|
||||||
|
request: Request, exc: HTTPException | StarletteHTTPException
|
||||||
|
) -> JSONResponse:
|
||||||
|
if exc.status_code >= 500:
|
||||||
|
LOGGER.error(
|
||||||
|
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
||||||
|
f"{type(exc).__name__}: {exc}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"msg": "Internal server error",
|
||||||
|
"status_code": exc.status_code,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOGGER.info(
|
||||||
|
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"msg": exc.detail,
|
||||||
|
"status_code": exc.status_code,
|
||||||
|
},
|
||||||
|
)
|
Loading…
Reference in New Issue