Initial work on users
This commit is contained in:
parent
4a06840a81
commit
d468aff227
|
@ -138,3 +138,5 @@ dmypy.json
|
|||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
/config.py
|
||||
/piccolo_conf.py
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
37
README.md
37
README.md
|
@ -1,3 +1,38 @@
|
|||
# PySimpleSocial
|
||||
|
||||
An advanced REST API written in Python for a generic social media website
|
||||
An advanced REST API written in Python for a generic social media website.
|
||||
|
||||
__Note__: This is a WIP so far.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
The project is written using [FastAPI](https://https://fastapi.tiangolo.com/), [piccolo](https://piccolo-orm.readthedocs.io/en/latest/) and [uvicorn](https://www.uvicorn.org/). Other awesome libraries used (only direct dependencies are listed here):
|
||||
- [validators](https://validators.readthedocs.io/en/latest/)
|
||||
- [pydantic](https://pydantic-docs.helpmanual.io/)
|
||||
- [aiosmtplib](https://aiosmtplib.readthedocs.io/en/latest/usage.html)
|
||||
- [uvloop](https://uvloop.readthedocs.io/)
|
||||
- [slowapi](https://slowapi.readthedocs.io/en/latest/)
|
||||
- [fastapi-login](https://fastapi-login.readthedocs.io/)
|
||||
- [bcrypt](https://github.com/pyca/bcrypt/)
|
||||
|
||||
# Feature overview
|
||||
|
||||
__Note__: Not all of this is implemented yet
|
||||
|
||||
- Simple authentication system using salted bcrypt hashes for password storage
|
||||
- PostgreSQL is used as the main database
|
||||
- Simple rate limiting using redis/in-memory storage
|
||||
- Support for various kinds of media stored in a CDN, directly inside the database or on a local/remote filesystem
|
||||
- Regular social media mechanics: (Un)following users, posting media with captions, etc.
|
||||
- User settings (change username, profile picture, etc.)
|
||||
- Simple messaging system using websockets or a polling HTTP API
|
||||
- Admin functionality with basic metrics and administration features (flagging/deleting users/posts, handling tickets, etc.)
|
||||
|
||||
# Setup
|
||||
|
||||
Move the *.py.example files to their respective *.py files, fill them as necessary, then simply install the dependencies via pip and run main.py
|
||||
|
||||
# License
|
||||
|
||||
This software is licensed under the MIT license. For more information, read the [license file](LICENSE)
|
||||
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# Configuration file. Each variable can be overridden by a
|
||||
# corresponding environment variable
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from fastapi_login import LoginManager
|
||||
|
||||
|
||||
# Authentication configuration
|
||||
LOGIN_SECRET_KEY = (
|
||||
os.getenv("LOGIN_SECRET_KEY") or "login-secret"
|
||||
) # Recommended value: os.urandom(24).hex()
|
||||
USE_BEARER_HEADER = True
|
||||
USE_COOKIE = True
|
||||
|
||||
# Logging configuration
|
||||
LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20
|
||||
LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable
|
||||
LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s"
|
||||
LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p"
|
||||
|
||||
# Bcrypt configuration
|
||||
BCRYPT_ROUNDS = (
|
||||
os.getenv("BCRYPT_ROUNDS") or 10
|
||||
) # How many rounds are used when salting
|
||||
|
||||
|
||||
# Rate limit configuration
|
||||
REDIS_URL = "" # Used for rate limits. Empty to fall back to memory
|
||||
REDIS_OPTIONS = {} # Options for redis
|
||||
RATELIMIT_ENABLED = False # False to disable rate limiting
|
||||
RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html
|
||||
|
||||
# Session configuration
|
||||
SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds
|
||||
SESSION_COOKIE_NAME = "_social_media_session"
|
||||
COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict"
|
||||
COOKIE_DOMAIN = "localhost" # Empty to disable this
|
||||
COOKIE_PATH = "/"
|
||||
COOKIE_HTTPONLY = True
|
||||
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
|
||||
|
||||
# SMTP configuration
|
||||
SMTP_HOST = "smtp.nocturn9x.space"
|
||||
SMTP_USER = "info@example.com"
|
||||
SMTP_PASSWORD = "password"
|
||||
SMTP_PORT = 587
|
||||
SMTP_USE_TLS = True
|
||||
SMTP_FROM_USER = "info@example.com"
|
||||
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email"
|
||||
SMTP_TIMEOUT = 10
|
||||
|
||||
# Miscellaneous
|
||||
|
||||
# Usernames containing these characters are not valid
|
||||
INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character
|
||||
# Empty this to disable username validation
|
||||
VALIDATE_USERNAME_REGEX = (
|
||||
rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$"
|
||||
)
|
||||
# Criteria:
|
||||
# - Between 10 and 72 characters long
|
||||
# - At least 2 uppercase letters
|
||||
# - At least 3 lowercase letters
|
||||
# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^
|
||||
# - At least 2 numbers
|
||||
# You can change the repetitions to enforce stricter/laxer rules or empty
|
||||
# this field to disable weakness validation
|
||||
VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$"
|
||||
|
||||
|
||||
class NotAuthenticated(Exception):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ != "__main__":
|
||||
LOGGER: logging.Logger = logging.getLogger("socialMedia")
|
||||
LOGGER.setLevel(LOG_LEVEL)
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
|
||||
handler.setFormatter(formatter)
|
||||
LOGGER.addHandler(handler)
|
||||
handler.setLevel(LOG_LEVEL)
|
||||
if LOG_FILE:
|
||||
file_handler = logging.FileHandler(LOG_FILE, "a", "utf8")
|
||||
file_handler.setFormatter(formatter)
|
||||
LOGGER.addHandler(handler)
|
||||
file_handler.setLevel(LOG_LEVEL)
|
||||
logging.getLogger("uvicorn").addHandler(file_handler)
|
||||
LIMITER = Limiter(
|
||||
key_func=get_remote_address,
|
||||
strategy=RATELIMIT_STRATEGY,
|
||||
storage_uri=REDIS_URL or None,
|
||||
in_memory_fallback_enabled=bool(REDIS_URL),
|
||||
storage_options=REDIS_OPTIONS,
|
||||
enabled=RATELIMIT_ENABLED,
|
||||
)
|
||||
MANAGER = LoginManager(
|
||||
LOGIN_SECRET_KEY,
|
||||
"/login",
|
||||
use_cookie=True,
|
||||
cookie_name=SESSION_COOKIE_NAME,
|
||||
use_header=USE_BEARER_HEADER,
|
||||
custom_exception=NotAuthenticated,
|
||||
)
|
||||
|
||||
# Uvicorn config
|
||||
|
||||
HOST = os.getenv("HOST") or "localhost"
|
||||
PORT = int(os.getenv("PORT") or 0) or 8000
|
||||
WORKERS = int(os.getenv("PORT") or 0) or 1
|
|
@ -0,0 +1,334 @@
|
|||
import re
|
||||
|
||||
import bcrypt
|
||||
from uuid import UUID
|
||||
|
||||
import validators
|
||||
from fastapi import APIRouter as FastAPI, Depends, Response, Request
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from datetime import timedelta
|
||||
|
||||
from config import (
|
||||
BCRYPT_ROUNDS,
|
||||
LOGGER,
|
||||
SESSION_EXPIRE_LIMIT,
|
||||
COOKIE_SAMESITE_POLICY,
|
||||
COOKIE_DOMAIN,
|
||||
MANAGER,
|
||||
SESSION_COOKIE_NAME,
|
||||
LIMITER,
|
||||
VALIDATE_PASSWORD_REGEX,
|
||||
VALIDATE_USERNAME_REGEX,
|
||||
SECURE_COOKIE,
|
||||
COOKIE_PATH,
|
||||
COOKIE_HTTPONLY,
|
||||
)
|
||||
from orm.users import UserModel, User
|
||||
|
||||
router = FastAPI()
|
||||
|
||||
|
||||
async def get_user_by_id(
|
||||
public_id: UUID, include_secrets: bool = False, restricted_ok: bool = False,
|
||||
deleted_ok: bool = False
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its public ID
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.public_id == public_id)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
@MANAGER.user_loader()
|
||||
async def get_self_by_id(public_id: UUID) -> dict:
|
||||
return await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
|
||||
|
||||
|
||||
async def get_user_by_username(
|
||||
username: str, include_secrets: bool = False, restricted_ok: bool = False,
|
||||
deleted_ok: bool = False
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its public username
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.username == username)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
async def get_user_by_email(
|
||||
email: str, include_secrets: bool = False, restricted_ok: bool = False,
|
||||
deleted_ok: bool = False
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieves a user by its email address (meant to
|
||||
be used internally)
|
||||
"""
|
||||
|
||||
user = (
|
||||
await User.select(
|
||||
*User.all_columns(exclude=["public_id"]),
|
||||
User.public_id.as_alias("id"),
|
||||
exclude_secrets=not include_secrets,
|
||||
)
|
||||
.where(User.email_address == email)
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
# Performs validation
|
||||
UserModel(**user)
|
||||
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
|
||||
return
|
||||
return user
|
||||
return
|
||||
|
||||
|
||||
@LIMITER.limit("5/minute")
|
||||
@router.post("/user")
|
||||
async def login(
|
||||
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> dict:
|
||||
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||
raise HTTPException(status_code=400, detail="Please logout first")
|
||||
username = data.username
|
||||
if len(username) > 32:
|
||||
raise HTTPException(
|
||||
status_code=413, detail="Authentication failed: username is too long"
|
||||
)
|
||||
try:
|
||||
password = data.password.encode()
|
||||
if len(password) > 72:
|
||||
raise HTTPException(
|
||||
status_code=413, detail="Authentication failed: password is too long"
|
||||
)
|
||||
except UnicodeEncodeError as e:
|
||||
LOGGER.warning(
|
||||
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Authentication failed: invalid characters in password",
|
||||
)
|
||||
if not (
|
||||
user := await get_user_by_username(
|
||||
username, include_secrets=True, restricted_ok=True
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Authentication failed: the user does not exist",
|
||||
)
|
||||
if not bcrypt.checkpw(password, user["password_hash"]):
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail="Authentication failed: password mismatch",
|
||||
)
|
||||
token = MANAGER.create_access_token(
|
||||
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])}
|
||||
)
|
||||
response.set_cookie(
|
||||
secure=SECURE_COOKIE,
|
||||
key=SESSION_COOKIE_NAME,
|
||||
max_age=SESSION_EXPIRE_LIMIT,
|
||||
value=token,
|
||||
httponly=COOKIE_HTTPONLY,
|
||||
samesite=COOKIE_SAMESITE_POLICY,
|
||||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",
|
||||
)
|
||||
return {"status_code": 200, "msg": "Authentication successful"}
|
||||
|
||||
|
||||
@router.get("/user/logout")
|
||||
@LIMITER.limit("5/minute")
|
||||
async def logout(
|
||||
request: Request, response: Response, user: dict = Depends(MANAGER)
|
||||
) -> dict:
|
||||
"""
|
||||
Deletes a user's session cookie, logging them
|
||||
out
|
||||
"""
|
||||
|
||||
response.delete_cookie(
|
||||
secure=SECURE_COOKIE,
|
||||
key=SESSION_COOKIE_NAME,
|
||||
httponly=COOKIE_HTTPONLY,
|
||||
samesite=COOKIE_SAMESITE_POLICY,
|
||||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",
|
||||
)
|
||||
return {"status_code": 200, "msg": "Logged out"}
|
||||
|
||||
|
||||
@router.get("/user/me")
|
||||
@LIMITER.limit("2/second")
|
||||
async def get_self(request: Request, user: dict = Depends(MANAGER)) -> dict:
|
||||
"""
|
||||
Fetches a user's own info. This returns some
|
||||
extra data such as email address, account
|
||||
creation date and email verification status,
|
||||
which is not available from the regular endpoint
|
||||
"""
|
||||
|
||||
user.pop("password_hash")
|
||||
user.pop("internal_id")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Success", "data": user}
|
||||
|
||||
|
||||
@router.get("/user/username/{username}")
|
||||
@LIMITER.limit("30/second")
|
||||
async def get_user_by_name(
|
||||
request: Request, username: str, _auth: dict = Depends(MANAGER)
|
||||
) -> dict:
|
||||
"""
|
||||
Fetches a single user by its public ID
|
||||
"""
|
||||
|
||||
if not (user := await get_user_by_username(username)):
|
||||
return {
|
||||
"status_code": 404,
|
||||
"msg": "Lookup failed: the user does not exist",
|
||||
}
|
||||
user.pop("restricted")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||
|
||||
|
||||
@router.get("/user/id/{public_id}")
|
||||
@LIMITER.limit("30/second")
|
||||
async def get_user_by_public_id(
|
||||
request: Request, public_id: str, _auth: dict = Depends(MANAGER)
|
||||
) -> dict:
|
||||
"""
|
||||
Fetches a single user by its public ID
|
||||
"""
|
||||
|
||||
if not (user := await get_user_by_id(UUID(public_id))):
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Lookup failed: the user does not exist"
|
||||
)
|
||||
user.pop("restricted")
|
||||
user.pop("deleted")
|
||||
return {"status_code": 200, "msg": "Lookup successful", "data": user}
|
||||
|
||||
|
||||
async def validate_user(
|
||||
first_name: str, last_name: str, username: str, email: str, password: str
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Performs some validation upon user creation. Returns
|
||||
a tuple (success, msg) to be used by routes
|
||||
"""
|
||||
|
||||
if len(first_name) > 64:
|
||||
return False, "first name is too long"
|
||||
if len(first_name) < 5:
|
||||
return False, "first name is too short"
|
||||
if len(last_name) > 64:
|
||||
return False, "last name is too long"
|
||||
if len(last_name) < 2:
|
||||
return False, "last name is too short"
|
||||
if len(username) < 5:
|
||||
return False, "username is too short"
|
||||
if len(username) > 32:
|
||||
return False, "username is too long"
|
||||
if VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username):
|
||||
return False, "username is invalid"
|
||||
if not validators.email(email):
|
||||
return False, "email is not valid"
|
||||
if len(password) > 72:
|
||||
return False, "password is too long"
|
||||
if VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password):
|
||||
return False, "password is too weak"
|
||||
if await get_user_by_username(username, deleted_ok=True, restricted_ok=True):
|
||||
return False, "username is already taken"
|
||||
if await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
|
||||
return False, "email is already registered"
|
||||
return True, ""
|
||||
|
||||
|
||||
@router.delete("/user")
|
||||
@LIMITER.limit("1/minute")
|
||||
async def delete(request: Request, response: Response, user: dict = Depends(MANAGER)) -> dict:
|
||||
"""
|
||||
Sets the user's deleted flag in the database,
|
||||
without actually deleting the associated
|
||||
data
|
||||
"""
|
||||
|
||||
await User.update({User.deleted: True}).where(User.public_id == user["id"])
|
||||
response.delete_cookie(
|
||||
secure=SECURE_COOKIE,
|
||||
key=SESSION_COOKIE_NAME,
|
||||
httponly=COOKIE_HTTPONLY,
|
||||
samesite=COOKIE_SAMESITE_POLICY,
|
||||
domain=COOKIE_DOMAIN or None,
|
||||
path=COOKIE_PATH or "/",)
|
||||
return {"status_code": 200, "msg": "Success"}
|
||||
|
||||
|
||||
@router.put("/user")
|
||||
@LIMITER.limit("2/minute")
|
||||
async def signup(
|
||||
request: Request,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
username: str,
|
||||
email: str,
|
||||
password: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Endpoint used to create new users
|
||||
"""
|
||||
|
||||
if request.cookies.get(SESSION_COOKIE_NAME):
|
||||
raise HTTPException(status_code=400, detail="Please logout first")
|
||||
# We don't use FastAPI's validation because we want custom error
|
||||
# messages
|
||||
result, msg = await validate_user(first_name, last_name, username, email, password)
|
||||
if not result:
|
||||
return {"status_code": 413, "msg": f"Signup failed: {msg}"}
|
||||
else:
|
||||
await User.insert(
|
||||
User(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
username=username,
|
||||
email_address=email,
|
||||
password_hash=bcrypt.hashpw(
|
||||
password.encode(), bcrypt.gensalt(BCRYPT_ROUNDS)
|
||||
),
|
||||
)
|
||||
)
|
||||
return {"status_code": 200, "msg": "Success"}
|
|
@ -0,0 +1,127 @@
|
|||
import uvloop
|
||||
import asyncio
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import (
|
||||
HTTPException,
|
||||
RequestValidationError,
|
||||
StarletteHTTPException,
|
||||
)
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
|
||||
from endpoints import users
|
||||
from config import (
|
||||
LOGGER,
|
||||
LIMITER,
|
||||
NotAuthenticated,
|
||||
SMTP_HOST,
|
||||
SMTP_USER,
|
||||
SMTP_PORT,
|
||||
SMTP_USE_TLS,
|
||||
SMTP_PASSWORD,
|
||||
SMTP_TIMEOUT,
|
||||
HOST,
|
||||
PORT,
|
||||
WORKERS,
|
||||
)
|
||||
from orm import create_tables, Media
|
||||
from util.exception_handlers import (
|
||||
http_exception,
|
||||
rate_limited,
|
||||
request_invalid,
|
||||
not_authenticated,
|
||||
)
|
||||
from util.email import test_smtp
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@LIMITER.limit("10/second")
|
||||
async def root(request: Request):
|
||||
return {"status_code": 403, "msg": "Unauthorized"}
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
@LIMITER.limit("1/minute")
|
||||
async def ping(request: Request) -> dict:
|
||||
"""
|
||||
This handler simply replies to "ping" requests and
|
||||
is used to check whether the API is up and running.
|
||||
It also performs a sanity check with the database and
|
||||
the SMTP server to ensure that they are functioning correctly.
|
||||
For this reason, this endpoint's rate limit is much stricter
|
||||
"""
|
||||
|
||||
LOGGER.info(f"Processing ping request from {request.client.host}")
|
||||
try:
|
||||
await Media.raw("SELECT 1;")
|
||||
await test_smtp(
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_USE_TLS,
|
||||
True,
|
||||
)
|
||||
return {"status_code": 200, "msg": "OK"}
|
||||
except Exception:
|
||||
raise HTTPException(500)
|
||||
|
||||
|
||||
async def startup_checks():
|
||||
LOGGER.info("Initializing database")
|
||||
try:
|
||||
await create_tables()
|
||||
await Media.raw("SELECT 1;")
|
||||
except Exception as e:
|
||||
LOGGER.error(
|
||||
f"An error occurred while trying to initialize the database -> {type(e.__name__)}: {e}"
|
||||
)
|
||||
else:
|
||||
LOGGER.info("Database initialized")
|
||||
LOGGER.info("Testing SMTP connection")
|
||||
try:
|
||||
await test_smtp(
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
SMTP_USER,
|
||||
SMTP_PASSWORD,
|
||||
SMTP_TIMEOUT,
|
||||
SMTP_USE_TLS,
|
||||
True,
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.error(
|
||||
f"An error occurred while trying to connect to the SMTP server -> {type(e.__name__)}: {e}"
|
||||
)
|
||||
else:
|
||||
LOGGER.info("SMTP test was successful")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
LOGGER.info("Backend starting up!")
|
||||
LOGGER.debug("Including modules")
|
||||
app.include_router(users.router)
|
||||
app.state.limiter = LIMITER
|
||||
LOGGER.debug("Setting exception handlers")
|
||||
app.add_exception_handler(RateLimitExceeded, rate_limited)
|
||||
app.add_exception_handler(NotAuthenticated, not_authenticated)
|
||||
app.add_exception_handler(HTTPException, http_exception)
|
||||
app.add_exception_handler(StarletteHTTPException, http_exception)
|
||||
app.add_exception_handler(RequestValidationError, request_invalid)
|
||||
LOGGER.debug("Installing uvloop")
|
||||
uvloop.install()
|
||||
log_config = uvicorn.config.LOGGING_CONFIG
|
||||
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
|
||||
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[
|
||||
0
|
||||
].formatter.datefmt
|
||||
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
||||
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt
|
||||
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
|
||||
asyncio.run(startup_checks())
|
||||
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)
|
|
@ -0,0 +1,14 @@
|
|||
import asyncio
|
||||
from piccolo.table import create_db_tables
|
||||
from .users import User
|
||||
from .media import Media
|
||||
|
||||
|
||||
async def create_tables():
|
||||
"""
|
||||
Initializes the database
|
||||
"""
|
||||
|
||||
await create_db_tables(User, Media, if_not_exists=True)
|
||||
await User.create_index([User.public_id], if_not_exists=True)
|
||||
await User.create_index([User.username], if_not_exists=True)
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Media relation
|
||||
"""
|
||||
|
||||
from piccolo.table import Table
|
||||
from piccolo.columns import UUID, Text, Boolean, Date
|
||||
from piccolo.columns.defaults.date import DateNow
|
||||
|
||||
|
||||
class Media(Table):
|
||||
"""
|
||||
A piece of media on a CDN
|
||||
"""
|
||||
|
||||
media_id = UUID(primary_key=True)
|
||||
media_url = Text(null=False)
|
||||
flagged = Boolean(default=False, null=False)
|
||||
deleted = Boolean(default=False, null=False)
|
||||
creation_date = Date(default=DateNow(), null=False)
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
User relation
|
||||
"""
|
||||
|
||||
from piccolo.utils.pydantic import create_pydantic_model
|
||||
from piccolo.table import Table
|
||||
from piccolo.columns import (
|
||||
ForeignKey,
|
||||
Varchar,
|
||||
BigSerial,
|
||||
UUID,
|
||||
Date,
|
||||
OnDelete,
|
||||
OnUpdate,
|
||||
Boolean,
|
||||
Email,
|
||||
Bytea,
|
||||
)
|
||||
from piccolo.columns.defaults.date import DateNow
|
||||
|
||||
from .media import Media
|
||||
|
||||
|
||||
class User(Table, tablename="users"):
|
||||
internal_id = BigSerial(null=False, secret=True)
|
||||
public_id = UUID(primary_key=True)
|
||||
first_name = Varchar(length=64, null=False)
|
||||
last_name = Varchar(length=64, null=True)
|
||||
email_address = Email(secret=True)
|
||||
username = Varchar(length=32, null=False, unique=True)
|
||||
password_hash = Bytea(null=False, secret=True)
|
||||
profile_picture = ForeignKey(
|
||||
references=Media,
|
||||
on_delete=OnDelete.set_null,
|
||||
on_update=OnUpdate.cascade,
|
||||
null=True,
|
||||
default=None,
|
||||
)
|
||||
creation_date = Date(secret=True, default=DateNow())
|
||||
bio = Varchar(length=4096, null=True, default=None)
|
||||
restricted = Boolean(default=False, null=False)
|
||||
email_verified = Boolean(default=False, null=False, secret=True)
|
||||
verified_account = Boolean(default=False, null=False)
|
||||
deleted = Boolean(default=False, null=False)
|
||||
|
||||
|
||||
UserModel = create_pydantic_model(User)
|
|
@ -0,0 +1,16 @@
|
|||
from piccolo.conf.apps import AppRegistry
|
||||
from piccolo.engine.postgres import PostgresEngine
|
||||
|
||||
DB = PostgresEngine(
|
||||
config={
|
||||
"host": "example.com",
|
||||
"user": "database",
|
||||
"password": "password",
|
||||
"database": "database",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# A list of paths to piccolo apps
|
||||
# e.g. ['blog.piccolo_app']
|
||||
APP_REGISTRY = AppRegistry(apps=[])
|
|
@ -0,0 +1,10 @@
|
|||
piccolo[all]~=0.91.0
|
||||
fastapi-login
|
||||
bcrypt~=4.0.0
|
||||
slowapi~=0.1.6
|
||||
python-multipart
|
||||
uvloop~=0.17.0
|
||||
fastapi~=0.85.0
|
||||
validators
|
||||
aiosmtplib
|
||||
uvicorn
|
|
@ -0,0 +1,3 @@
|
|||
[{
|
||||
|
||||
}]
|
|
@ -0,0 +1,87 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
import aiosmtplib
|
||||
from email.mime.base import MIMEBase
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from fastapi.responses import RedirectResponse
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
async def send_email(
|
||||
host: str,
|
||||
port: int,
|
||||
message: str,
|
||||
timeout: int,
|
||||
sender: str,
|
||||
recipient: str,
|
||||
subject: str,
|
||||
login_email: str,
|
||||
password: str,
|
||||
attachments: List[MIMEBase] = tuple(),
|
||||
use_tls: bool = True,
|
||||
check_hostname: bool = True,
|
||||
) -> Union[bool, aiosmtplib.SMTPException]:
|
||||
"""
|
||||
Sends an email with the given details. Returns True on success
|
||||
or an exception object upon failure
|
||||
"""
|
||||
|
||||
try:
|
||||
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
|
||||
msg = MIMEMultipart()
|
||||
msg["From"] = sender
|
||||
msg["To"] = recipient
|
||||
msg["Subject"] = subject
|
||||
msg.attach(MIMEText(message, "html"))
|
||||
for attachment in attachments:
|
||||
msg.attach(attachment)
|
||||
await srv.ehlo() # We identify ourselves
|
||||
if use_tls:
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = check_hostname
|
||||
await srv.starttls(tls_context=context)
|
||||
await srv.ehlo() # We do it again, but encrypted!
|
||||
await srv.login(login_email, password)
|
||||
await srv.sendmail(sender, recipient, msg.as_string())
|
||||
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
|
||||
logging.error(
|
||||
f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}"
|
||||
)
|
||||
return error
|
||||
return True
|
||||
|
||||
|
||||
async def test_smtp(
|
||||
host: str,
|
||||
port: int,
|
||||
login_email: str,
|
||||
password: str,
|
||||
timeout: int,
|
||||
use_tls: bool = True,
|
||||
check_hostname: bool = True,
|
||||
):
|
||||
"""
|
||||
Attempts login to the given SMTP server with the given credentials.
|
||||
Used upon startup, raises an exception upon failure. This will
|
||||
fail if the server does not support TLS encryption for login
|
||||
"""
|
||||
|
||||
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
|
||||
await srv.ehlo()
|
||||
if use_tls:
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = check_hostname
|
||||
await srv.starttls(tls_context=context)
|
||||
await srv.ehlo()
|
||||
await srv.login(login_email, password)
|
||||
|
||||
|
||||
def redirect(url: str) -> RedirectResponse:
|
||||
"""
|
||||
Returns a redirect response to the given url
|
||||
"""
|
||||
|
||||
return RedirectResponse(url=url)
|
|
@ -0,0 +1,79 @@
|
|||
from config import LOGGER, NotAuthenticated
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.exceptions import HTTPException, StarletteHTTPException
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
|
||||
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
|
||||
n = 0
|
||||
while True:
|
||||
if error.detail[n].isnumeric():
|
||||
n += 1
|
||||
else:
|
||||
break
|
||||
error.detail = error.detail[:n] + " requests" + error.detail[n:]
|
||||
LOGGER.info(
|
||||
f"{request.client.host} got rate-limited at {str(request.url)} "
|
||||
f"(exceeded {error.detail})"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
|
||||
"status_code": 429,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
|
||||
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": "Authentication is required",
|
||||
"status_code": 401,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
LOGGER.info(
|
||||
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": f"Bad request: {type(exc).__name__}: {exc}",
|
||||
"status_code": 400,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def http_exception(
|
||||
request: Request, exc: HTTPException | StarletteHTTPException
|
||||
) -> JSONResponse:
|
||||
if exc.status_code >= 500:
|
||||
LOGGER.error(
|
||||
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
|
||||
f"{type(exc).__name__}: {exc}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": "Internal server error",
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
||||
else:
|
||||
LOGGER.info(
|
||||
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"msg": exc.detail,
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
Loading…
Reference in New Issue