PySimpleSocial/main.py

130 lines
3.7 KiB
Python

import uvloop
import asyncio
import uvicorn
from fastapi import FastAPI, Request
from fastapi.exceptions import (
HTTPException,
RequestValidationError,
StarletteHTTPException,
)
from slowapi.errors import RateLimitExceeded
from endpoints import users
from config import (
LOGGER,
LIMITER,
NotAuthenticated,
SMTP_HOST,
SMTP_USER,
SMTP_PORT,
SMTP_USE_TLS,
SMTP_PASSWORD,
SMTP_TIMEOUT,
HOST,
PORT,
WORKERS,
)
from orm import create_tables, Media
from util.exception_handlers import (
http_exception,
rate_limited,
request_invalid,
not_authenticated,
generic_error,
)
from util.email import test_smtp
# The ASGI application. Routers, the limiter state and the exception
# handlers are attached to it inside the ``__main__`` section of this file.
app = FastAPI()


@app.get("/")
@LIMITER.limit("10/second")
async def root(request: Request):
    """
    Reject every request made to the bare root path.

    Nothing is served at "/", so callers always get a 401 "Unauthorized"
    error. ``request`` is not read here; it is part of the signature for
    the ``LIMITER.limit`` rate-limiting decorator.
    """
    raise HTTPException(status_code=401, detail="Unauthorized")
@app.get("/ping")
@LIMITER.limit("1/minute")
async def ping(request: Request) -> dict:
    """
    Reply to "ping" requests to check whether the API is up and running.

    Besides answering, this performs a sanity check against the database
    (``SELECT 1``) and the SMTP server so that a successful reply means both
    are working. Because of these extra round trips, this endpoint's rate
    limit (1/minute) is much stricter than the root endpoint's (10/second).

    :param request: The incoming request (also used by the rate limiter).
    :returns: ``{"status_code": 200, "msg": "OK"}`` when every check passes.
    :raises HTTPException: With status 500 if the database or SMTP check fails.
    """
    # Starlette may report no peer address (``request.client is None``),
    # so don't assume ``.host`` is always available.
    client_host = request.client.host if request.client is not None else "unknown client"
    LOGGER.info(f"Processing ping request from {client_host}")
    try:
        await Media.raw("SELECT 1;")
        await test_smtp(
            SMTP_HOST,
            SMTP_PORT,
            SMTP_USER,
            SMTP_PASSWORD,
            SMTP_TIMEOUT,
            SMTP_USE_TLS,
            True,
        )
    except Exception as e:
        # Record why the health check failed instead of silently collapsing
        # it into a bare 500, and keep the original error as the cause.
        LOGGER.error(f"Ping sanity check failed -> {type(e).__name__}: {e}")
        raise HTTPException(500) from e
    return {"status_code": 200, "msg": "OK"}
async def startup_checks():
    """
    Run best-effort sanity checks before the server starts serving.

    First creates the database tables and runs ``SELECT 1`` against the
    database, then tests the connection to the configured SMTP server.
    Each failure is logged as an error but does not abort startup; the
    second check runs even if the first one failed.
    """
    LOGGER.info("Initializing database")
    try:
        await create_tables()
        await Media.raw("SELECT 1;")
    except Exception as e:
        # ``type(e).__name__`` is the exception's class name. Exception
        # instances have no ``__name__`` attribute of their own.
        LOGGER.error(
            f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}"
        )
    else:
        LOGGER.info("Database initialized")
    LOGGER.info("Testing SMTP connection")
    try:
        await test_smtp(
            SMTP_HOST,
            SMTP_PORT,
            SMTP_USER,
            SMTP_PASSWORD,
            SMTP_TIMEOUT,
            SMTP_USE_TLS,
            True,
        )
    except Exception as e:
        LOGGER.error(
            f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}"
        )
    else:
        LOGGER.info("SMTP test was successful")
if __name__ == "__main__":
    import copy

    LOGGER.info("Backend starting up!")
    LOGGER.debug("Including modules")
    app.include_router(users.router)
    # The rate limiter is looked up on the application state.
    app.state.limiter = LIMITER
    LOGGER.debug("Setting exception handlers")
    app.add_exception_handler(RateLimitExceeded, rate_limited)
    app.add_exception_handler(NotAuthenticated, not_authenticated)
    app.add_exception_handler(HTTPException, http_exception)
    app.add_exception_handler(StarletteHTTPException, http_exception)
    app.add_exception_handler(RequestValidationError, request_invalid)
    app.add_exception_handler(Exception, generic_error)
    LOGGER.debug("Installing uvloop")
    uvloop.install()
    # Work on a private copy: ``uvicorn.config.LOGGING_CONFIG`` is a
    # module-level default, and mutating it in place would leak our changes
    # to anything else that reads it.
    log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG)
    # Make uvicorn's log lines match our own logger's format.
    # NOTE(review): assumes LOGGER has at least one handler with a formatter.
    our_formatter = LOGGER.handlers[0].formatter
    for formatter_name in ("access", "default"):
        log_config["formatters"][formatter_name]["datefmt"] = our_formatter.datefmt
        log_config["formatters"][formatter_name]["fmt"] = our_formatter._fmt
    # Route uvicorn's access log to stderr.
    log_config["handlers"]["access"]["stream"] = "ext://sys.stderr"
    asyncio.run(startup_checks())
    # NOTE(review): uvicorn only honours ``workers > 1`` when the app is given
    # as an import string (e.g. "main:app"). Worker processes would then import
    # this module without running this __main__ section, so the routers and
    # handlers above would need to be registered at import time. Confirm WORKERS.
    uvicorn.run(host=HOST, port=PORT, app=app, log_config=log_config, workers=WORKERS)