Initial work on users

This commit is contained in:
Nocturn9x 2022-10-04 21:13:26 +02:00
parent 4a06840a81
commit d468aff227
16 changed files with 899 additions and 1 deletions

2
.gitignore vendored
View File

@ -138,3 +138,5 @@ dmypy.json
# Cython debug symbols
cython_debug/
/config.py
/piccolo_conf.py

10
.idea/socialMedia.iml Normal file
View File

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -1,3 +1,38 @@
# PySimpleSocial
An advanced REST API written in Python for a generic social media website
An advanced REST API written in Python for a generic social media website.
__Note__: This is a WIP so far.
## Tech Stack
The project is written using [FastAPI](https://https://fastapi.tiangolo.com/), [piccolo](https://piccolo-orm.readthedocs.io/en/latest/) and [uvicorn](https://www.uvicorn.org/). Other awesome libraries used (only direct dependencies are listed here):
- [validators](https://validators.readthedocs.io/en/latest/)
- [pydantic](https://pydantic-docs.helpmanual.io/)
- [aiosmtplib](https://aiosmtplib.readthedocs.io/en/latest/usage.html)
- [uvloop](https://uvloop.readthedocs.io/)
- [slowapi](https://slowapi.readthedocs.io/en/latest/)
- [fastapi-login](https://fastapi-login.readthedocs.io/)
- [bcrypt](https://github.com/pyca/bcrypt/)
# Feature overview
__Note__: Not all of this is implemented yet
- Simple authentication system using salted bcrypt hashes for password storage
- PostgreSQL is used as the main database
- Simple rate limiting using redis/in-memory storage
- Support for various kinds of media stored in a CDN, directly inside the database or on a local/remote filesystem
- Regular social media mechanics: (Un)following users, posting media with captions, etc.
- User settings (change username, profile picture, etc.)
- Simple messaging system using websockets or a polling HTTP API
- Admin functionality with basic metrics and administration features (flagging/deleting users/posts, handling tickets, etc.)
# Setup
Move the *.py.example files to their respective *.py files, fill them as necessary, then simply install the dependencies via pip and run main.py
# License
This software is licensed under the MIT license. For more information, read the [license file](LICENSE)

115
config.py.example Normal file
View File

@ -0,0 +1,115 @@
# Configuration file. Each variable can be overridden by a
# corresponding environment variable
import os
import sys
import logging
import pathlib
from slowapi import Limiter
from slowapi.util import get_remote_address
from fastapi_login import LoginManager
# Authentication configuration
LOGIN_SECRET_KEY = (
os.getenv("LOGIN_SECRET_KEY") or "login-secret"
) # Recommended value: os.urandom(24).hex()
USE_BEARER_HEADER = True
USE_COOKIE = True
# Logging configuration
LOG_LEVEL = int(os.getenv("LOG_LEVEL") or 0) or 20
LOG_FILE = os.getenv("LOG_FILE") or "" # Empty to disable
LOG_FORMAT = os.getenv("LOG_FORMAT") or "[%(levelname)s - %(asctime)s] %(message)s"
LOG_DATE_FORMAT = os.getenv("LOG_DATE_FORMAT") or "%d/%m/%Y %p"
# Bcrypt configuration
BCRYPT_ROUNDS = (
os.getenv("BCRYPT_ROUNDS") or 10
) # How many rounds are used when salting
# Rate limit configuration
REDIS_URL = "" # Used for rate limits. Empty to fall back to memory
REDIS_OPTIONS = {} # Options for redis
RATELIMIT_ENABLED = False # False to disable rate limiting
RATELIMIT_STRATEGY = "moving-window" # Refer to https://flask-limiter.readthedocs.io/en/stable/strategies.html
# Session configuration
SESSION_EXPIRE_LIMIT = 3600 # Unit is in seconds
SESSION_COOKIE_NAME = "_social_media_session"
COOKIE_SAMESITE_POLICY = "none" # Options are "lax", "none", "strict"
COOKIE_DOMAIN = "localhost" # Empty to disable this
COOKIE_PATH = "/"
COOKIE_HTTPONLY = True
SECURE_COOKIE = False # Set to true in production, False during development (unless your local server has HTTPS)
# SMTP configuration
SMTP_HOST = "smtp.nocturn9x.space"
SMTP_USER = "info@example.com"
SMTP_PASSWORD = "password"
SMTP_PORT = 587
SMTP_USE_TLS = True
SMTP_FROM_USER = "info@example.com"
SMTP_TEMPLATES_DIRECTORY = pathlib.Path(__file__) / "templates" / "email"
SMTP_TIMEOUT = 10
# Miscellaneous
# Usernames containing these characters are not valid
INVALID_USERNAME_CHARACTERS = ["@", "\\", "/"] # Empty this to allow any character
# Empty this to disable username validation
VALIDATE_USERNAME_REGEX = (
rf"^([^{''.join(INVALID_USERNAME_CHARACTERS)}]|[a-z0-9A-Z]){{5,32}}$"
)
# Criteria:
# - Between 10 and 72 characters long
# - At least 2 uppercase letters
# - At least 3 lowercase letters
# - At least one special character (!, @, #, $, %, &, *, _, +, /, \, (, ), £, ", ?, ^
# - At least 2 numbers
# You can change the repetitions to enforce stricter/laxer rules or empty
# this field to disable weakness validation
VALIDATE_PASSWORD_REGEX = r"^(?=.*[A-Z]){2,}(?=.*[!@%#$&*_^\?\\\/(\)\+\-])+(?=.*[0-9]){2,}(?=.*[a-z]){3,}.{10,72}$"
class NotAuthenticated(Exception):
pass
if __name__ != "__main__":
LOGGER: logging.Logger = logging.getLogger("socialMedia")
LOGGER.setLevel(LOG_LEVEL)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
handler.setLevel(LOG_LEVEL)
if LOG_FILE:
file_handler = logging.FileHandler(LOG_FILE, "a", "utf8")
file_handler.setFormatter(formatter)
LOGGER.addHandler(handler)
file_handler.setLevel(LOG_LEVEL)
logging.getLogger("uvicorn").addHandler(file_handler)
LIMITER = Limiter(
key_func=get_remote_address,
strategy=RATELIMIT_STRATEGY,
storage_uri=REDIS_URL or None,
in_memory_fallback_enabled=bool(REDIS_URL),
storage_options=REDIS_OPTIONS,
enabled=RATELIMIT_ENABLED,
)
MANAGER = LoginManager(
LOGIN_SECRET_KEY,
"/login",
use_cookie=True,
cookie_name=SESSION_COOKIE_NAME,
use_header=USE_BEARER_HEADER,
custom_exception=NotAuthenticated,
)
# Uvicorn config
HOST = os.getenv("HOST") or "localhost"
PORT = int(os.getenv("PORT") or 0) or 8000
WORKERS = int(os.getenv("PORT") or 0) or 1

0
endpoints/__init__.py Normal file
View File

334
endpoints/users.py Normal file
View File

@ -0,0 +1,334 @@
import re
import bcrypt
from uuid import UUID
import validators
from fastapi import APIRouter as FastAPI, Depends, Response, Request
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from datetime import timedelta
from config import (
BCRYPT_ROUNDS,
LOGGER,
SESSION_EXPIRE_LIMIT,
COOKIE_SAMESITE_POLICY,
COOKIE_DOMAIN,
MANAGER,
SESSION_COOKIE_NAME,
LIMITER,
VALIDATE_PASSWORD_REGEX,
VALIDATE_USERNAME_REGEX,
SECURE_COOKIE,
COOKIE_PATH,
COOKIE_HTTPONLY,
)
from orm.users import UserModel, User
router = FastAPI()
async def get_user_by_id(
public_id: UUID, include_secrets: bool = False, restricted_ok: bool = False,
deleted_ok: bool = False
) -> dict | None:
"""
Retrieves a user by its public ID
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.public_id == public_id)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
return
return user
return
@MANAGER.user_loader()
async def get_self_by_id(public_id: UUID) -> dict:
return await get_user_by_id(public_id, include_secrets=True, restricted_ok=True)
async def get_user_by_username(
username: str, include_secrets: bool = False, restricted_ok: bool = False,
deleted_ok: bool = False
) -> dict | None:
"""
Retrieves a user by its public username
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.username == username)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
return
return user
return
async def get_user_by_email(
email: str, include_secrets: bool = False, restricted_ok: bool = False,
deleted_ok: bool = False
) -> dict | None:
"""
Retrieves a user by its email address (meant to
be used internally)
"""
user = (
await User.select(
*User.all_columns(exclude=["public_id"]),
User.public_id.as_alias("id"),
exclude_secrets=not include_secrets,
)
.where(User.email_address == email)
.first()
)
if user:
# Performs validation
UserModel(**user)
if (user["deleted"] and not deleted_ok) or (user["restricted"] and not restricted_ok):
return
return user
return
@LIMITER.limit("5/minute")
@router.post("/user")
async def login(
request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()
) -> dict:
if request.cookies.get(SESSION_COOKIE_NAME):
raise HTTPException(status_code=400, detail="Please logout first")
username = data.username
if len(username) > 32:
raise HTTPException(
status_code=413, detail="Authentication failed: username is too long"
)
try:
password = data.password.encode()
if len(password) > 72:
raise HTTPException(
status_code=413, detail="Authentication failed: password is too long"
)
except UnicodeEncodeError as e:
LOGGER.warning(
f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}"
)
raise HTTPException(
status_code=413,
detail="Authentication failed: invalid characters in password",
)
if not (
user := await get_user_by_username(
username, include_secrets=True, restricted_ok=True
)
):
raise HTTPException(
status_code=413,
detail="Authentication failed: the user does not exist",
)
if not bcrypt.checkpw(password, user["password_hash"]):
raise HTTPException(
status_code=413,
detail="Authentication failed: password mismatch",
)
token = MANAGER.create_access_token(
expires=timedelta(seconds=SESSION_EXPIRE_LIMIT), data={"sub": str(user["id"])}
)
response.set_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
max_age=SESSION_EXPIRE_LIMIT,
value=token,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return {"status_code": 200, "msg": "Authentication successful"}
@router.get("/user/logout")
@LIMITER.limit("5/minute")
async def logout(
request: Request, response: Response, user: dict = Depends(MANAGER)
) -> dict:
"""
Deletes a user's session cookie, logging them
out
"""
response.delete_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",
)
return {"status_code": 200, "msg": "Logged out"}
@router.get("/user/me")
@LIMITER.limit("2/second")
async def get_self(request: Request, user: dict = Depends(MANAGER)) -> dict:
"""
Fetches a user's own info. This returns some
extra data such as email address, account
creation date and email verification status,
which is not available from the regular endpoint
"""
user.pop("password_hash")
user.pop("internal_id")
user.pop("deleted")
return {"status_code": 200, "msg": "Success", "data": user}
@router.get("/user/username/{username}")
@LIMITER.limit("30/second")
async def get_user_by_name(
request: Request, username: str, _auth: dict = Depends(MANAGER)
) -> dict:
"""
Fetches a single user by its public ID
"""
if not (user := await get_user_by_username(username)):
return {
"status_code": 404,
"msg": "Lookup failed: the user does not exist",
}
user.pop("restricted")
user.pop("deleted")
return {"status_code": 200, "msg": "Lookup successful", "data": user}
@router.get("/user/id/{public_id}")
@LIMITER.limit("30/second")
async def get_user_by_public_id(
request: Request, public_id: str, _auth: dict = Depends(MANAGER)
) -> dict:
"""
Fetches a single user by its public ID
"""
if not (user := await get_user_by_id(UUID(public_id))):
raise HTTPException(
status_code=404, detail="Lookup failed: the user does not exist"
)
user.pop("restricted")
user.pop("deleted")
return {"status_code": 200, "msg": "Lookup successful", "data": user}
async def validate_user(
first_name: str, last_name: str, username: str, email: str, password: str
) -> tuple[bool, str]:
"""
Performs some validation upon user creation. Returns
a tuple (success, msg) to be used by routes
"""
if len(first_name) > 64:
return False, "first name is too long"
if len(first_name) < 5:
return False, "first name is too short"
if len(last_name) > 64:
return False, "last name is too long"
if len(last_name) < 2:
return False, "last name is too short"
if len(username) < 5:
return False, "username is too short"
if len(username) > 32:
return False, "username is too long"
if VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username):
return False, "username is invalid"
if not validators.email(email):
return False, "email is not valid"
if len(password) > 72:
return False, "password is too long"
if VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password):
return False, "password is too weak"
if await get_user_by_username(username, deleted_ok=True, restricted_ok=True):
return False, "username is already taken"
if await get_user_by_email(email, deleted_ok=True, restricted_ok=True):
return False, "email is already registered"
return True, ""
@router.delete("/user")
@LIMITER.limit("1/minute")
async def delete(request: Request, response: Response, user: dict = Depends(MANAGER)) -> dict:
"""
Sets the user's deleted flag in the database,
without actually deleting the associated
data
"""
await User.update({User.deleted: True}).where(User.public_id == user["id"])
response.delete_cookie(
secure=SECURE_COOKIE,
key=SESSION_COOKIE_NAME,
httponly=COOKIE_HTTPONLY,
samesite=COOKIE_SAMESITE_POLICY,
domain=COOKIE_DOMAIN or None,
path=COOKIE_PATH or "/",)
return {"status_code": 200, "msg": "Success"}
@router.put("/user")
@LIMITER.limit("2/minute")
async def signup(
request: Request,
first_name: str,
last_name: str,
username: str,
email: str,
password: str,
) -> dict:
"""
Endpoint used to create new users
"""
if request.cookies.get(SESSION_COOKIE_NAME):
raise HTTPException(status_code=400, detail="Please logout first")
# We don't use FastAPI's validation because we want custom error
# messages
result, msg = await validate_user(first_name, last_name, username, email, password)
if not result:
return {"status_code": 413, "msg": f"Signup failed: {msg}"}
else:
await User.insert(
User(
first_name=first_name,
last_name=last_name,
username=username,
email_address=email,
password_hash=bcrypt.hashpw(
password.encode(), bcrypt.gensalt(BCRYPT_ROUNDS)
),
)
)
return {"status_code": 200, "msg": "Success"}

127
main.py Normal file
View File

@ -0,0 +1,127 @@
import uvloop
import asyncio
import uvicorn
from fastapi import FastAPI, Request
from fastapi.exceptions import (
HTTPException,
RequestValidationError,
StarletteHTTPException,
)
from slowapi.errors import RateLimitExceeded
from endpoints import users
from config import (
LOGGER,
LIMITER,
NotAuthenticated,
SMTP_HOST,
SMTP_USER,
SMTP_PORT,
SMTP_USE_TLS,
SMTP_PASSWORD,
SMTP_TIMEOUT,
HOST,
PORT,
WORKERS,
)
from orm import create_tables, Media
from util.exception_handlers import (
http_exception,
rate_limited,
request_invalid,
not_authenticated,
)
from util.email import test_smtp
app = FastAPI()
@app.get("/")
@LIMITER.limit("10/second")
async def root(request: Request):
return {"status_code": 403, "msg": "Unauthorized"}
@app.get("/ping")
@LIMITER.limit("1/minute")
async def ping(request: Request) -> dict:
"""
This handler simply replies to "ping" requests and
is used to check whether the API is up and running.
It also performs a sanity check with the database and
the SMTP server to ensure that they are functioning correctly.
For this reason, this endpoint's rate limit is much stricter
"""
LOGGER.info(f"Processing ping request from {request.client.host}")
try:
await Media.raw("SELECT 1;")
await test_smtp(
SMTP_HOST,
SMTP_PORT,
SMTP_USER,
SMTP_PASSWORD,
SMTP_TIMEOUT,
SMTP_USE_TLS,
True,
)
return {"status_code": 200, "msg": "OK"}
except Exception:
raise HTTPException(500)
async def startup_checks():
LOGGER.info("Initializing database")
try:
await create_tables()
await Media.raw("SELECT 1;")
except Exception as e:
LOGGER.error(
f"An error occurred while trying to initialize the database -> {type(e.__name__)}: {e}"
)
else:
LOGGER.info("Database initialized")
LOGGER.info("Testing SMTP connection")
try:
await test_smtp(
SMTP_HOST,
SMTP_PORT,
SMTP_USER,
SMTP_PASSWORD,
SMTP_TIMEOUT,
SMTP_USE_TLS,
True,
)
except Exception as e:
LOGGER.error(
f"An error occurred while trying to connect to the SMTP server -> {type(e.__name__)}: {e}"
)
else:
LOGGER.info("SMTP test was successful")
if __name__ == "__main__":
LOGGER.info("Backend starting up!")
LOGGER.debug("Including modules")
app.include_router(users.router)
app.state.limiter = LIMITER
LOGGER.debug("Setting exception handlers")
app.add_exception_handler(RateLimitExceeded, rate_limited)
app.add_exception_handler(NotAuthenticated, not_authenticated)
app.add_exception_handler(HTTPException, http_exception)
app.add_exception_handler(StarletteHTTPException, http_exception)
app.add_exception_handler(RequestValidationError, request_invalid)
LOGGER.debug("Installing uvloop")
uvloop.install()
log_config = uvicorn.config.LOGGING_CONFIG
log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt
log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[
0
].formatter.datefmt
log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt
log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt
log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
asyncio.run(startup_checks())
uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)

14
orm/__init__.py Normal file
View File

@ -0,0 +1,14 @@
import asyncio
from piccolo.table import create_db_tables
from .users import User
from .media import Media
async def create_tables():
"""
Initializes the database
"""
await create_db_tables(User, Media, if_not_exists=True)
await User.create_index([User.public_id], if_not_exists=True)
await User.create_index([User.username], if_not_exists=True)

19
orm/media.py Normal file
View File

@ -0,0 +1,19 @@
"""
Media relation
"""
from piccolo.table import Table
from piccolo.columns import UUID, Text, Boolean, Date
from piccolo.columns.defaults.date import DateNow
class Media(Table):
"""
A piece of media on a CDN
"""
media_id = UUID(primary_key=True)
media_url = Text(null=False)
flagged = Boolean(default=False, null=False)
deleted = Boolean(default=False, null=False)
creation_date = Date(default=DateNow(), null=False)

47
orm/users.py Normal file
View File

@ -0,0 +1,47 @@
"""
User relation
"""
from piccolo.utils.pydantic import create_pydantic_model
from piccolo.table import Table
from piccolo.columns import (
ForeignKey,
Varchar,
BigSerial,
UUID,
Date,
OnDelete,
OnUpdate,
Boolean,
Email,
Bytea,
)
from piccolo.columns.defaults.date import DateNow
from .media import Media
class User(Table, tablename="users"):
internal_id = BigSerial(null=False, secret=True)
public_id = UUID(primary_key=True)
first_name = Varchar(length=64, null=False)
last_name = Varchar(length=64, null=True)
email_address = Email(secret=True)
username = Varchar(length=32, null=False, unique=True)
password_hash = Bytea(null=False, secret=True)
profile_picture = ForeignKey(
references=Media,
on_delete=OnDelete.set_null,
on_update=OnUpdate.cascade,
null=True,
default=None,
)
creation_date = Date(secret=True, default=DateNow())
bio = Varchar(length=4096, null=True, default=None)
restricted = Boolean(default=False, null=False)
email_verified = Boolean(default=False, null=False, secret=True)
verified_account = Boolean(default=False, null=False)
deleted = Boolean(default=False, null=False)
UserModel = create_pydantic_model(User)

16
piccolo_conf.py.example Normal file
View File

@ -0,0 +1,16 @@
from piccolo.conf.apps import AppRegistry
from piccolo.engine.postgres import PostgresEngine
DB = PostgresEngine(
config={
"host": "example.com",
"user": "database",
"password": "password",
"database": "database",
}
)
# A list of paths to piccolo apps
# e.g. ['blog.piccolo_app']
APP_REGISTRY = AppRegistry(apps=[])

10
requirements.txt Normal file
View File

@ -0,0 +1,10 @@
piccolo[all]~=0.91.0
fastapi-login
bcrypt~=4.0.0
slowapi~=0.1.6
python-multipart
uvloop~=0.17.0
fastapi~=0.85.0
validators
aiosmtplib
uvicorn

View File

@ -0,0 +1,3 @@
[{
}]

0
util/__init__.py Normal file
View File

87
util/email.py Normal file
View File

@ -0,0 +1,87 @@
import asyncio
import logging
import ssl
import aiosmtplib
from email.mime.base import MIMEBase
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from fastapi.responses import RedirectResponse
from typing import Union, List
async def send_email(
host: str,
port: int,
message: str,
timeout: int,
sender: str,
recipient: str,
subject: str,
login_email: str,
password: str,
attachments: List[MIMEBase] = tuple(),
use_tls: bool = True,
check_hostname: bool = True,
) -> Union[bool, aiosmtplib.SMTPException]:
"""
Sends an email with the given details. Returns True on success
or an exception object upon failure
"""
try:
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
msg = MIMEMultipart()
msg["From"] = sender
msg["To"] = recipient
msg["Subject"] = subject
msg.attach(MIMEText(message, "html"))
for attachment in attachments:
msg.attach(attachment)
await srv.ehlo() # We identify ourselves
if use_tls:
context = ssl.create_default_context()
context.check_hostname = check_hostname
await srv.starttls(tls_context=context)
await srv.ehlo() # We do it again, but encrypted!
await srv.login(login_email, password)
await srv.sendmail(sender, recipient, msg.as_string())
except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error:
logging.error(
f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}"
)
return error
return True
async def test_smtp(
host: str,
port: int,
login_email: str,
password: str,
timeout: int,
use_tls: bool = True,
check_hostname: bool = True,
):
"""
Attempts login to the given SMTP server with the given credentials.
Used upon startup, raises an exception upon failure. This will
fail if the server does not support TLS encryption for login
"""
async with aiosmtplib.SMTP(host, port, timeout=timeout) as srv:
await srv.ehlo()
if use_tls:
context = ssl.create_default_context()
context.check_hostname = check_hostname
await srv.starttls(tls_context=context)
await srv.ehlo()
await srv.login(login_email, password)
def redirect(url: str) -> RedirectResponse:
"""
Returns a redirect response to the given url
"""
return RedirectResponse(url=url)

View File

@ -0,0 +1,79 @@
from config import LOGGER, NotAuthenticated
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import HTTPException, StarletteHTTPException
from slowapi.errors import RateLimitExceeded
async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONResponse:
n = 0
while True:
if error.detail[n].isnumeric():
n += 1
else:
break
error.detail = error.detail[:n] + " requests" + error.detail[n:]
LOGGER.info(
f"{request.client.host} got rate-limited at {str(request.url)} "
f"(exceeded {error.detail})"
)
return JSONResponse(
status_code=200,
content={
"msg": f"Too many requests, retry after {error.detail[error.detail.find('per') + 4:]}",
"status_code": 429,
},
)
def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse:
LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}")
return JSONResponse(
status_code=200,
content={
"msg": "Authentication is required",
"status_code": 401,
},
)
def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse:
LOGGER.info(
f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}"
)
return JSONResponse(
status_code=200,
content={
"msg": f"Bad request: {type(exc).__name__}: {exc}",
"status_code": 400,
},
)
def http_exception(
request: Request, exc: HTTPException | StarletteHTTPException
) -> JSONResponse:
if exc.status_code >= 500:
LOGGER.error(
f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:"
f"{type(exc).__name__}: {exc}"
)
return JSONResponse(
status_code=200,
content={
"msg": "Internal server error",
"status_code": exc.status_code,
},
)
else:
LOGGER.info(
f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}"
)
return JSONResponse(
status_code=200,
content={
"msg": exc.detail,
"status_code": exc.status_code,
},
)