diff --git a/endpoints/media.py b/endpoints/media.py index c1b8f94..40e2e24 100644 --- a/endpoints/media.py +++ b/endpoints/media.py @@ -22,18 +22,12 @@ router = FastAPI() }, ) @LIMITER.limit("2/second") -async def get_media( - request: Request, media_id: str, _user: UserModel = Depends(MANAGER) -): +async def get_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)): """ Gets a media object by its ID """ - if ( - m := await Media.select(Media.media_id) - .where(Media.media_id == media_id) - .first() - ) is None: + if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None: raise HTTPException(status_code=404, detail="Media not found") m = Media(**m) if m.media_type == MediaType.FILE: @@ -60,20 +54,14 @@ async def get_media( }, ) @LIMITER.limit("2/second") -async def report_media( - request: Request, media_id: str, _user: UserModel = Depends(MANAGER) -): +async def report_media(request: Request, media_id: str, _user: UserModel = Depends(MANAGER)): """ Reports a piece of media by its ID. This creates a report that can be seen by admins, which can then decide what to do """ - if ( - m := await Media.select(Media.media_id) - .where(Media.media_id == media_id) - .first() - ) is None: + if (m := await Media.select(Media.media_id).where(Media.media_id == media_id).first()) is None: raise HTTPException(status_code=404, detail="Media not found") # TODO: Create report return APIResponse(msg="Success") diff --git a/endpoints/users.py b/endpoints/users.py index 20f3f2c..ca745ad 100644 --- a/endpoints/users.py +++ b/endpoints/users.py @@ -95,6 +95,7 @@ async def get_self_by_id_unverified(public_id: UUID) -> UserModel: # Here follow our *beautifully* documented path operations + @router.post( "/user", tags=["Users"], @@ -106,9 +107,7 @@ async def get_self_by_id_unverified(public_id: UUID) -> UserModel: }, ) @LIMITER.limit("5/minute") -async def login( - request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends() -): +async def login(request: Request, response: Response, data: OAuth2PasswordRequestForm = Depends()): """ Performs user authentication. Endpoint is limited to 5 hits per minute """ @@ -117,15 +116,11 @@ async def login( raise HTTPException(status_code=400, detail="Please logout first") username = data.username if len(username) > 32: - raise HTTPException( - status_code=413, detail="Authentication failed: username is too long" - ) + raise HTTPException(status_code=413, detail="Authentication failed: username is too long") try: password = data.password.encode() if len(password) > 72: - raise HTTPException( - status_code=413, detail="Authentication failed: password is too long" - ) + raise HTTPException(status_code=413, detail="Authentication failed: password is too long") except UnicodeEncodeError as e: LOGGER.warning( f"An error occurred while attempting to decode password for user {username} -> {type(e).__name__}: {e}" @@ -134,11 +129,7 @@ async def login( status_code=413, detail="Authentication failed: invalid characters in password", ) - if not ( - user := await get_user_by_username( - username, include_secrets=True, restricted_ok=True - ) - ): + if not (user := await get_user_by_username(username, include_secrets=True, restricted_ok=True)): raise HTTPException( status_code=413, detail="Authentication failed: the user does not exist", @@ -172,9 +163,7 @@ async def login( responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}}, ) @LIMITER.limit("5/minute") -async def logout( - request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER) -): +async def logout(request: Request, response: Response, _user: UserModel = Depends(UNVERIFIED_MANAGER)): """ Deletes a user's session cookie, logging them out. Endpoint is limited to 5 hits per minute @@ -230,9 +219,7 @@ async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGE }, ) @LIMITER.limit("30/second") -async def get_user_by_name( - request: Request, username: str, _auth: UserModel = Depends(MANAGER) -): +async def get_user_by_name(request: Request, username: str, _auth: UserModel = Depends(MANAGER)): """ Fetches a single user by its public username """ @@ -252,7 +239,9 @@ async def get_user_by_name( content=user.profile_picture.content, content_type=user.profile_picture.content_type, creation_date=user.profile_picture.creation_date, - ) if user.profile_picture else None, + ) + if user.profile_picture + else None, ) ) @@ -268,17 +257,13 @@ async def get_user_by_name( }, ) @LIMITER.limit("30/second") -async def get_user_by_public_id( - request: Request, public_id: str, _auth: UserModel = Depends(MANAGER) -): +async def get_user_by_public_id(request: Request, public_id: str, _auth: UserModel = Depends(MANAGER)): """ Fetches a single user by its public ID """ if not (user := await get_user_by_id(UUID(public_id))): - raise HTTPException( - status_code=404, detail="Lookup failed: the user does not exist" - ) + raise HTTPException(status_code=404, detail="Lookup failed: the user does not exist") return PublicUserResponse( data=PublicUserModel( public_id=user.public_id, @@ -292,7 +277,9 @@ async def get_user_by_public_id( content=user.profile_picture.content, content_type=user.profile_picture.content_type, creation_date=user.profile_picture.creation_date, - ) if user.profile_picture else None, + ) + if user.profile_picture + else None, ) ) @@ -323,25 +310,15 @@ async def validate_user( return False, "username is too short" if username and len(username) > 32: return False, "username is too long" - if ( - username - and VALIDATE_USERNAME_REGEX - and not re.match(VALIDATE_USERNAME_REGEX, username) - ): + if username and VALIDATE_USERNAME_REGEX and not re.match(VALIDATE_USERNAME_REGEX, username): return False, "username is invalid" if email and not validators.email(email): return False, "email is not valid" if password and len(password) > 72: return False, "password is too long" - if ( - password - and VALIDATE_PASSWORD_REGEX - and not re.match(VALIDATE_PASSWORD_REGEX, password) - ): + if password and VALIDATE_PASSWORD_REGEX and not re.match(VALIDATE_PASSWORD_REGEX, password): return False, "password is too weak" - if username and await get_user_by_username( - username, deleted_ok=True, restricted_ok=True - ): + if username and await get_user_by_username(username, deleted_ok=True, restricted_ok=True): return False, "username is already taken" if email and await get_user_by_email(email, deleted_ok=True, restricted_ok=True): return False, "email is already registered" @@ -359,13 +336,10 @@ async def validate_user( "/user", tags=["Users"], status_code=200, - responses={200: {"model": APIResponse}, - 422: {"model": UnprocessableEntity}}, + responses={200: {"model": APIResponse}, 422: {"model": UnprocessableEntity}}, ) @LIMITER.limit("1/minute") -async def delete( - request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER) -): +async def delete(request: Request, response: Response, user: UserModel = Depends(UNVERIFIED_MANAGER)): """ Sets the user's deleted flag in the database, without actually deleting the associated @@ -412,9 +386,9 @@ async def verify_email( raise HTTPException(status_code=404, detail="Verification ID is invalid") elif not verification["pending"]: raise HTTPException(status_code=400, detail="Email is already verified") - elif datetime.now().astimezone(timezone.utc) - verification[ - "creation_date" - ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( + seconds=EMAIL_VERIFICATION_EXPIRATION + ): raise HTTPException( status_code=400, detail="Verification window has expired. Try again", @@ -423,9 +397,7 @@ async def verify_email( await EmailVerification.update({EmailVerification.pending: False}).where( EmailVerification.user == user.public_id ) - await User.update({User.email_verified: True}).where( - User.public_id == user.public_id - ) + await User.update({User.email_verified: True}).where(User.public_id == user.public_id) return APIResponse(status_code=200, msg="Verification successful") @@ -458,9 +430,9 @@ async def reset_password( raise HTTPException(status_code=404, detail="Request ID is invalid") elif not verification["pending"]: raise HTTPException(status_code=400, detail="This link has already been used") - elif datetime.now().astimezone(timezone.utc) - verification[ - "creation_date" - ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( + seconds=EMAIL_VERIFICATION_EXPIRATION + ): raise HTTPException( status_code=400, detail="Verification window has expired. Try again", @@ -472,9 +444,7 @@ async def reset_password( await EmailVerification.update({EmailVerification.pending: False}).where( EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET ) - await User.update({User.password_hash: verification["data"]}).where( - User.public_id == user.public_id - ) + await User.update({User.password_hash: verification["data"]}).where(User.public_id == user.public_id) return APIResponse(status_code=200, msg="Password updated") @@ -507,9 +477,9 @@ async def change_email( raise HTTPException(status_code=404, detail="Request ID is invalid") elif not verification["pending"]: raise HTTPException(status_code=400, detail="This link has already been used") - elif datetime.now().astimezone(timezone.utc) - verification[ - "creation_date" - ].astimezone(timezone.utc) > timedelta(seconds=EMAIL_VERIFICATION_EXPIRATION): + elif datetime.now().astimezone(timezone.utc) - verification["creation_date"].astimezone(timezone.utc) > timedelta( + seconds=EMAIL_VERIFICATION_EXPIRATION + ): raise HTTPException( status_code=400, detail="Verification window has expired. Try again", @@ -593,8 +563,7 @@ async def resend_email(request: Request, user: UserModel = Depends(UNVERIFIED_MA else: raise HTTPException( status_code=500, - detail="An error occurred while trying to resend the email," - " please try again later", + detail="An error occurred while trying to resend the email," " please try again later", ) @@ -629,9 +598,7 @@ async def signup( raise HTTPException(status_code=400, detail="Please logout first") # We don't use FastAPI's validation because we want custom error # messages - result, msg = await validate_user( - first_name, last_name, username, email, password, bio - ) + result, msg = await validate_user(first_name, last_name, username, email, password, bio) if not result: return APIResponse(status_code=413, msg=f"Signup failed: {msg}") else: @@ -691,8 +658,7 @@ async def signup( else: raise HTTPException( status_code=500, - detail="An error occurred while sending verification email, please" - " try again later", + detail="An error occurred while sending verification email, please" " try again later", ) @@ -763,15 +729,9 @@ async def update_user( since they're the only ones that can be set to a null value """ - if not delete and not any( - (first_name, last_name, username, profile_picture, email_address, bio, password) - ): - raise HTTPException( - status_code=400, detail="At least one value has to be specified" - ) - result, msg = await validate_user( - first_name, last_name, username, email_address, password, bio - ) + if not delete and not any((first_name, last_name, username, profile_picture, email_address, bio, password)): + raise HTTPException(status_code=400, detail="At least one value has to be specified") + result, msg = await validate_user(first_name, last_name, username, email_address, password, bio) if not result: raise HTTPException(status_code=413, detail=f"Update failed: {msg}") orig_user = user.copy() @@ -787,16 +747,10 @@ async def update_user( if profile_picture: result, ext, media, digest = validate_profile_picture(profile_picture) if result is False: - raise HTTPException( - status_code=415, detail="The file type is unsupported" - ) + raise HTTPException(status_code=415, detail="The file type is unsupported") elif result is None: raise HTTPException(status_code=413, detail="The file is too large") - elif ( - m := await Media.select(Media.media_id) - .where(Media.media_id == digest) - .first() - ) is None: + elif (m := await Media.select(Media.media_id).where(Media.media_id == digest).first()) is None: # This media hasn't been already uploaded (either by this user or by someone # else), so we save it now. If it has been already uploaded, there's no need # to do it again (that's what the hash is for) @@ -872,9 +826,7 @@ async def update_user( { EmailVerification.id: verification_id, EmailVerification.user: User(public_id=user.public_id), - EmailVerification.data: bcrypt.hashpw( - password.encode(), user.password_hash[:29] - ), + EmailVerification.data: bcrypt.hashpw(password.encode(), user.password_hash[:29]), } ) ) @@ -940,14 +892,9 @@ async def update_user( user.profile_picture = None fields = [] for field in user: - if ( - field not in ["email_address", "password"] - and getattr(orig_user, field) != getattr(user, field) - ): + if field not in ["email_address", "password"] and getattr(orig_user, field) != getattr(user, field): fields.append((field, getattr(user, field))) if fields: # If anything has changed, we update our info - await User.update({field: value for field, value in fields}).where( - User.public_id == user.public_id - ) + await User.update({field: value for field, value in fields}).where(User.public_id == user.public_id) return APIResponse(status_code=200, msg="Changes saved successfully") diff --git a/main.py b/main.py index 60e9d54..a9ee7e7 100644 --- a/main.py +++ b/main.py @@ -116,9 +116,7 @@ async def startup_checks(): await create_tables() await Media.raw("SELECT 1;") except Exception as e: - LOGGER.error( - f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}" - ) + LOGGER.error(f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}") else: LOGGER.info("Database initialized") LOGGER.info("Testing SMTP connection") @@ -133,9 +131,7 @@ async def startup_checks(): True, ) except Exception as e: - LOGGER.error( - f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}" - ) + LOGGER.error(f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}") else: LOGGER.info("SMTP test was successful") @@ -157,9 +153,7 @@ if __name__ == "__main__": uvloop.install() log_config = uvicorn.config.LOGGING_CONFIG log_config["formatters"]["access"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt - log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[ - 0 - ].formatter.datefmt + log_config["formatters"]["default"]["datefmt"] = LOGGER.handlers[0].formatter.datefmt log_config["formatters"]["access"]["fmt"] = LOGGER.handlers[0].formatter._fmt log_config["formatters"]["default"]["fmt"] = LOGGER.handlers[0].formatter._fmt log_config["handlers"]["access"]["stream"] = "ext://sys.stderr" diff --git a/orm/media.py b/orm/media.py index 67851b0..3d64429 100644 --- a/orm/media.py +++ b/orm/media.py @@ -45,9 +45,7 @@ class Media(Table): MediaModel = create_pydantic_model(Media) -PublicMediaModel = create_pydantic_model( - Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type) -) +PublicMediaModel = create_pydantic_model(Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type)) async def get_media_by_column( diff --git a/orm/posts.py b/orm/posts.py index 28c1a04..2830000 100644 --- a/orm/posts.py +++ b/orm/posts.py @@ -38,9 +38,7 @@ class Post(Table, tablename="posts"): PostModel = create_pydantic_model(Post, nested=True) -PrivatePostModelInternal = create_pydantic_model( - Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id) -) +PrivatePostModelInternal = create_pydantic_model(Post, nested=True, exclude_columns=(Post.flagged, Post.internal_id)) PublicPostModelInternal = create_pydantic_model( Post, nested=True, exclude_columns=(Post.flagged, Post.deleted, Post.internal_id) ) diff --git a/orm/users.py b/orm/users.py index 4a1e8ae..58275e8 100644 --- a/orm/users.py +++ b/orm/users.py @@ -65,7 +65,7 @@ PublicUserModelInternal = create_pydantic_model( User.email_verified, User.email_address, User.password_hash, - User.creation_date + User.creation_date, ), model_name="PublicUser", ) diff --git a/util/email.py b/util/email.py index fa3961c..731bf62 100644 --- a/util/email.py +++ b/util/email.py @@ -47,9 +47,7 @@ async def send_email( await srv.login(login_email, password) await srv.sendmail(sender, recipient, msg.as_string()) except (aiosmtplib.SMTPException, asyncio.TimeoutError) as error: - logging.error( - f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}" - ) + logging.error(f"An error occurred while dealing with {host}:{port} (SMTP): {type(error).__name__}: {error}") return error return True diff --git a/util/exception_handlers.py b/util/exception_handlers.py index cbf093d..6910642 100644 --- a/util/exception_handlers.py +++ b/util/exception_handlers.py @@ -10,6 +10,7 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon """ Handles the equivalent of a 429 Too Many Requests error """ + n = 0 while True: if error.detail[n].isnumeric(): @@ -17,10 +18,7 @@ async def rate_limited(request: Request, error: RateLimitExceeded) -> JSONRespon else: break error.detail = error.detail[:n] + " requests" + error.detail[n:] - LOGGER.info( - f"{request.client.host} got rate-limited at {str(request.url)} " - f"(exceeded {error.detail})" - ) + LOGGER.info(f"{request.client.host} got rate-limited at {str(request.url)} " f"(exceeded {error.detail})") return JSONResponse( status_code=200, content=dict( @@ -36,9 +34,7 @@ def not_authenticated(request: Request, _: NotAuthenticated) -> JSONResponse: """ LOGGER.info(f"{request.client.host} failed to authenticate at {str(request.url)}") - return JSONResponse( - status_code=200, content=dict(status_code=401, msg="Authentication is required") - ) + return JSONResponse(status_code=200, content=dict(status_code=401, msg="Authentication is required")) def request_invalid(request: Request, exc: RequestValidationError) -> JSONResponse: @@ -46,18 +42,14 @@ def request_invalid(request: Request, exc: RequestValidationError) -> JSONRespon Handles Bad Request exceptions from FastAPI """ - LOGGER.info( - f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}" - ) + LOGGER.info(f"{request.client.host} sent an invalid request at {request.url!r}: {type(exc).__name__}: {exc}") return JSONResponse( status_code=200, content=dict(status_code=400, msg=f"Bad request: {type(exc).__name__}: {exc}"), ) -def http_exception( - request: Request, exc: HTTPException | StarletteHTTPException -) -> JSONResponse: +def http_exception(request: Request, exc: HTTPException | StarletteHTTPException) -> JSONResponse: """ Handles HTTP-specific exceptions raised explicitly by path operations @@ -65,19 +57,12 @@ def http_exception( if exc.status_code >= 500: LOGGER.error( - f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" - f"{type(exc).__name__}: {exc}" - ) - return JSONResponse( - status_code=200, content=dict(status_code=500, msg="Internal Server Error") + f"{request.client.host} raised a {exc.status_code} error at {request.url!r}:" f"{type(exc).__name__}: {exc}" ) + return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error")) else: - LOGGER.info( - f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}" - ) - return JSONResponse( - status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail) - ) + LOGGER.info(f"{request.client.host} raised an HTTP error ({exc.status_code}) at {str(request.url)}") + return JSONResponse(status_code=200, content=dict(status_code=exc.status_code, msg=exc.detail)) async def generic_error(request: Request, exc: Exception) -> JSONResponse: @@ -85,10 +70,6 @@ async def generic_error(request: Request, exc: Exception) -> JSONResponse: Handles generic, unexpected errors in the ASGI application """ - LOGGER.info( - f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}" - ) + LOGGER.info(f"{request.client.host} raised an unexpected error ({type(exc).__name__}: {exc}) at {str(request.url)}") # We can't leak anything about the error, it would be too risky - return JSONResponse( - status_code=200, content=dict(status_code=500, msg="Internal Server Error") - ) + return JSONResponse(status_code=200, content=dict(status_code=500, msg="Internal Server Error"))