From ce6ee2335a3c0f4cdf98a76032996c2d4d77931e Mon Sep 17 00:00:00 2001 From: Nocturn9x Date: Tue, 4 Oct 2022 23:45:21 +0200 Subject: [PATCH] Updated Media schema and implemented most of the PATCH functionality for the user endpoint --- endpoints/users.py | 129 ++++++++++++++++++++++++++++++++----- main.py | 4 +- orm/media.py | 19 +++++- util/exception_handlers.py | 18 ++++++ 4 files changed, 151 insertions(+), 19 deletions(-) diff --git a/endpoints/users.py b/endpoints/users.py index fee0a20..f4f6f45 100644 --- a/endpoints/users.py +++ b/endpoints/users.py @@ -1,13 +1,19 @@ +import base64 +import hashlib import re +import imghdr +import uuid +import zlib import bcrypt from uuid import UUID import validators -from fastapi import APIRouter as FastAPI, Depends, Response, Request +from fastapi import APIRouter as FastAPI, Depends, Response, Request, UploadFile from fastapi.exceptions import HTTPException from fastapi.security import OAuth2PasswordRequestForm from datetime import timedelta +from pathlib import Path from config import ( BCRYPT_ROUNDS, @@ -23,8 +29,14 @@ from config import ( SECURE_COOKIE, COOKIE_PATH, COOKIE_HTTPONLY, + ALLOWED_MEDIA_TYPES, + ZLIB_COMPRESSION_LEVEL, + MAX_MEDIA_SIZE, + STORAGE_ENGINE, + STORAGE_FOLDER ) from orm.users import UserModel, User +from orm.media import Media, MediaType router = FastAPI() @@ -244,36 +256,37 @@ async def get_user_by_public_id( async def validate_user( - first_name: str, last_name: str, username: str, email: str, password: str + first_name: str, last_name: str, username: str, email: str, password: str | None ) -> tuple[bool, str]: """ Performs some validation upon user creation. Returns - a tuple (success, msg) to be used by routes + a tuple (success, msg) to be used by routes. Values + set to None are not checked against """ - if len(first_name) > 64: + if first_name and len(first_name) > 64: return False, "first name is too long" - if len(first_name) < 5: + if first_name and len(first_name) < 5: return False, "first name is too short" - if len(last_name) > 64: + if last_name and len(last_name) > 64: return False, "last name is too long" - if len(last_name) < 2: + if last_name and len(last_name) < 2: return False, "last name is too short" - if len(username) < 5: + if username and len(username) < 5: return False, "username is too short" - if len(username) > 32: + if username and len(username) > 32: return False, "username is too long" - if VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username): + if username and VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username): return False, "username is invalid" - if not validators.email(email): + if email and not validators.email(email): return False, "email is not valid" - if len(password) > 72: + if password and len(password) > 72: return False, "password is too long" - if VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password): + if password and VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password): return False, "password is too weak" - if await get_user_by_username(username, deleted_ok=True, restricted_ok=True): + if username and await get_user_by_username(username, deleted_ok=True, restricted_ok=True): return False, "username is already taken" - if await get_user_by_email(email, deleted_ok=True, restricted_ok=True): + if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): return False, "email is already registered" return True, "" @@ -332,3 +345,89 @@ async def signup( ) ) return {"status_code": 200, "msg": "Success"} + + +async def validate_profile_picture(file: UploadFile) -> tuple[bool | None, str, bytes, str]: + """ + Validates a profile picture's size and content to see if it fits + our criteria and returns a tuple result, ext, data where result is a + boolean or none (True = check was passed, False = size too large, + None = check was failed for other reasons) indicating if the check was successful, + ext is the file's type and extension, data is a compressed stream of bytes + representing the original media and hash is the file's SHA256 hash encoded in + hexadecimal, before the compression. This function never raises an exception + """ + + async with file: + try: + content = await file.read() + if len(content) > MAX_MEDIA_SIZE: + return False, "", b"", "" + if not (ext := imghdr.what(content.decode())) in ALLOWED_MEDIA_TYPES: + return None, "", b"", "" + return True, ext, zlib.compress(content, ZLIB_COMPRESSION_LEVEL), hashlib.sha256(content).hexdigest() + except (UnicodeDecodeError, zlib.error): + return None, "", b"", "" + + +@router.patch("/user") +async def update(request: Request, user: dict = Depends(MANAGER), + first_name: str | None = None, last_name: str | None = None, + username: str | None = None, profile_picture: UploadFile | None = None, + email_address: str | None = None, bio: str | None = None): + """ + Updates a user's profile information. Parameters that are not specified are left unchanged. + At least one parameter has to be non-null. Setting a new email address is only allowed if the + old one is verified and will require the user to click a link sent to the current email address + to authorize the operation, after which the address is modified. + """ + + if not any((first_name, last_name, username, profile_picture, email_address, bio)): + raise HTTPException(status_code=400, detail="At least one value has to be specified") + result, msg = await validate_user(first_name, last_name, username, email_address, None) + if not result: + raise HTTPException(status_code=413, detail=f"Update failed: {msg}") + orig_user = user.copy() + if first_name: + user["first_name"] = first_name + if last_name: + user["last_name"] = last_name + if username: + user["username"] = username + if profile_picture: + result, ext, media, digest = validate_profile_picture(profile_picture) + if result is False: + raise HTTPException(status_code=415, detail="The file type is unsupported") + elif result is None: + raise HTTPException(status_code=413, detail="The file is too large") + elif await (old_media := Media.select(Media.media_id).where(Media.media_id == digest).first()) is None: + # This media hasn't been already uploaded (either by this user or by someone + # else), so we save it now. If it has been already uploaded, there's no need + # to do it again (that's what the hash is for) + match STORAGE_ENGINE: + case "database": + await Media.insert(Media(media_id=digest, media_type=MediaType.BLOB, + content_type=ext, content=base64.b64encode(media))) + case "local": + file = Path(STORAGE_FOLDER).resolve(strict=True) / str(digest) + file.touch(mode=0o644) + with file.open("wb") as f: + f.write(media) + await Media.insert(Media(media_id=digest, media_type=MediaType.FILE, + content_type=ext, content=file.as_posix())) + case "url": + pass # TODO: Use/implement CDN uploading + else: + user["media"] = old_media + if email_address: + if not user["email_verified"]: + raise HTTPException(status_code=403, detail="The email address needs to be verified first") + pass # TODO: Requires email verification + fields = [] + for field in user: + if field != "email_address" and orig_user[field] != user[field]: + fields.append((field, user[field])) + if fields: + # If anything has changed, we update our info + await User.update({field: value for field, value in fields}).where(User.public_id == user["id"]) + return {"status_code": 200, "msg": "Changes saved successfully"} diff --git a/main.py b/main.py index 98d16b9..20cff00 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,7 @@ from util.exception_handlers import ( rate_limited, request_invalid, not_authenticated, + generic_error ) from util.email import test_smtp @@ -41,7 +42,7 @@ app = FastAPI() @app.get("/") @LIMITER.limit("10/second") async def root(request: Request): - return {"status_code": 403, "msg": "Unauthorized"} + raise HTTPException(401, detail="Unauthorized") @app.get("/ping") @@ -113,6 +114,7 @@ if __name__ == "__main__": app.add_exception_handler(HTTPException, http_exception) app.add_exception_handler(StarletteHTTPException, http_exception) app.add_exception_handler(RequestValidationError, request_invalid) + app.add_exception_handler(Exception, generic_error) LOGGER.debug("Installing uvloop") uvloop.install() log_config = uvicorn.config.LOGGING_CONFIG diff --git a/orm/media.py b/orm/media.py index fd3a634..5615371 100644 --- a/orm/media.py +++ b/orm/media.py @@ -3,8 +3,19 @@ Media relation """ from piccolo.table import Table -from piccolo.columns import UUID, Text, Boolean, Date +from piccolo.columns import UUID, Text, Boolean, Date, SmallInt, Varchar from piccolo.columns.defaults.date import DateNow +from enum import Enum, auto + + +class MediaType(Enum): + """ + Represents a media type + """ + + URL: int = auto() + BLOB: int = auto() + FILE: int = auto() class Media(Table): @@ -12,8 +23,10 @@ class Media(Table): A piece of media on a CDN """ - media_id = UUID(primary_key=True) - media_url = Text(null=False) + media_id = Varchar(length=64, primary_key=True) + media_type = SmallInt(null=False, choices=MediaType) + content = Text(null=False) + content_type = Varchar(length=32, null=False) flagged = Boolean(default=False, null=False) deleted = Boolean(default=False, null=False) creation_date = Date(default=DateNow(), null=False) diff --git a/util/exception_handlers.py b/util/exception_handlers.py index 060d673..085ec8c 100644 --- a/util/exception_handlers.py +++ b/util/exception_handlers.py @@ -77,3 +77,21 @@ def http_exception( "status_code": exc.status_code, }, ) + + +async def generic_error(request: Request, exc: Exception) -> JSONResponse: + """ + Handles generic, unexpected errors in the application + """ + + LOGGER.info( + f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}" + ) + return JSONResponse( + status_code=200, + content={ + "msg": "Internal Server Error", # We can't leak anything about the error, it would be too risky + "status_code": 500, + }, + ) +