diff --git a/endpoints/users.py b/endpoints/users.py index 3a626a5..c53887b 100644 --- a/endpoints/users.py +++ b/endpoints/users.py @@ -194,7 +194,7 @@ async def get_self(request: Request, user: UserModel = Depends(UNVERIFIED_MANAGE "/user/username/{username}", tags=["Users"], status_code=200, - responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + responses={200: {"model": PublicUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, 404: {"model": NotFound}, 422: {"model": UnprocessableEntity} }, @@ -216,7 +216,7 @@ async def get_user_by_name( "/user/id/{public_id}", tags=["Users"], status_code=200, - responses={200: {"model": PrivateUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, + responses={200: {"model": PublicUserResponse, "exclude": {"password_hash", "internal_id", "deleted"}}, 404: {"model": NotFound}, 422: {"model": UnprocessableEntity} }, diff --git a/main.py b/main.py index 538db9f..b3e7d22 100644 --- a/main.py +++ b/main.py @@ -113,7 +113,7 @@ async def startup_checks(): await Media.raw("SELECT 1;") except Exception as e: LOGGER.error( - f"An error occurred while trying to initialize the database -> {type(e.__name__)}: {e}" + f"An error occurred while trying to initialize the database -> {type(e).__name__}: {e}" ) else: LOGGER.info("Database initialized") @@ -130,7 +130,7 @@ async def startup_checks(): ) except Exception as e: LOGGER.error( - f"An error occurred while trying to connect to the SMTP server -> {type(e.__name__)}: {e}" + f"An error occurred while trying to connect to the SMTP server -> {type(e).__name__}: {e}" ) else: LOGGER.info("SMTP test was successful") diff --git a/orm/media.py b/orm/media.py index 5615371..e497b38 100644 --- a/orm/media.py +++ b/orm/media.py @@ -3,9 +3,11 @@ Media relation """ from piccolo.table import Table -from piccolo.columns import UUID, Text, Boolean, Date, SmallInt, Varchar +from piccolo.utils.pydantic import create_pydantic_model +from piccolo.columns import Text, Boolean, Date, SmallInt, Varchar, Column, ForeignKey, OnUpdate, OnDelete from piccolo.columns.defaults.date import DateNow from enum import Enum, auto +from typing import Any class MediaType(Enum): @@ -24,9 +26,52 @@ class Media(Table): """ media_id = Varchar(length=64, primary_key=True) - media_type = SmallInt(null=False, choices=MediaType) + media_type = SmallInt(null=False, choices=MediaType, secret=True) content = Text(null=False) content_type = Varchar(length=32, null=False) - flagged = Boolean(default=False, null=False) - deleted = Boolean(default=False, null=False) + flagged = Boolean(default=False, null=False, secret=True) + deleted = Boolean(default=False, null=False, secret=True) creation_date = Date(default=DateNow(), null=False) + + +MediaModel = create_pydantic_model(Media) +PublicMediaModel = create_pydantic_model(Media, exclude_columns=(Media.flagged, Media.deleted, Media.media_type)) + + +async def get_media_by_column( + column: Column, + data: Any, + include_secrets: bool = False, + flagged_ok: bool = False, + deleted_ok: bool = False, +) -> MediaModel | None: + """ + Retrieves a media object by a given criteria. + Returns None if the media doesn't exist or + if it's restricted/deleted (unless flagged_ok + and deleted_ok are set accordingly) + """ + + media = ( + await Media.select( + *Media.all_columns(), + exclude_secrets=not include_secrets, + ) + .where(column == data) + .first() + ) + if media: + # Performs validation + media = MediaModel(**media) + if (media.deleted and not deleted_ok) or (media.flagged and not flagged_ok): + return + return media + return + + +async def get_media_by_id(media_id: str, *args, **kwargs) -> MediaModel: + """ + Retrieves a media object by its ID + """ + + return await get_media_by_column(Media.media_id, media_id, *args, **kwargs) diff --git a/orm/users.py b/orm/users.py index 981e426..2c7a791 100644 --- a/orm/users.py +++ b/orm/users.py @@ -20,7 +20,7 @@ from piccolo.columns import ( ) from piccolo.columns.defaults.date import DateNow -from .media import Media +from .media import Media, PublicMediaModel class User(Table, tablename="users"): @@ -51,7 +51,10 @@ UserModel = create_pydantic_model( User, nested=True, ) -PublicUserModel = create_pydantic_model( + +# This madness is needed because we need to exclude +# some fields from our public API responses +PublicUserModelInternal = create_pydantic_model( User, nested=True, exclude_columns=( @@ -65,7 +68,7 @@ PublicUserModel = create_pydantic_model( ), model_name="PublicUser", ) -PrivateUserModel = create_pydantic_model( +PrivateUserModelInternal = create_pydantic_model( User, nested=True, exclude_columns=(User.internal_id, User.password_hash, User.deleted), @@ -73,6 +76,14 @@ PrivateUserModel = create_pydantic_model( ) +class PublicUserModel(PublicUserModelInternal): + profile_picture: PublicMediaModel + + +class PrivateUserModel(PrivateUserModelInternal): + profile_picture: PublicMediaModel + + async def get_user_by_column( column: Column, data: Any,