diff --git a/endpoints/users.py b/endpoints/users.py index f782274..20f3f2c 100644 --- a/endpoints/users.py +++ b/endpoints/users.py @@ -73,7 +73,7 @@ from orm.users import ( get_user_by_email, ) from orm.media import Media, MediaType, PublicMediaModel -from orm.email_verification import EmailVerification +from orm.email_verification import EmailVerification, EmailVerificationType from util.email import send_email router = FastAPI() @@ -95,7 +95,6 @@ async def get_self_by_id_unverified(public_id: UUID) -> UserModel: # Here follow our *beautifully* documented path operations - @router.post( "/user", tags=["Users"], @@ -467,8 +466,11 @@ async def reset_password( detail="Verification window has expired. Try again", ) else: + # Note how we don't update based on the verification ID: + # this way, multiple pending email verification requests + # are all cleared at once await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user.public_id + EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET ) await User.update({User.password_hash: verification["data"]}).where( User.public_id == user.public_id @@ -513,8 +515,11 @@ async def change_email( detail="Verification window has expired. Try again", ) else: + # Note how we don't update based on the verification ID: + # this way, multiple pending email verification requests + # are all cleared at once await EmailVerification.update({EmailVerification.pending: False}).where( - EmailVerification.user == user.public_id + EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL ) await User.update( { @@ -937,9 +942,9 @@ async def update_user( for field in user: if ( field not in ["email_address", "password"] - and orig_user[field] != user[field] + and getattr(orig_user, field) != getattr(user, field) ): - fields.append((field, user[field])) + fields.append((field, getattr(user, field))) if fields: # If anything has changed, we update our info await User.update({field: value for field, value in fields}).where( diff --git a/orm/__init__.py b/orm/__init__.py index f0427fd..a88e9c4 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -7,9 +7,15 @@ from .email_verification import EmailVerification async def create_tables(): """ - Initializes the database + Initializes the database by creating the + necessary tables and indexes """ await create_db_tables(User, Media, EmailVerification, if_not_exists=True) + # Even though we use an auto-incrementing internal ID as the primary key, + # we'll almost never use it for lookups from the API (it's mostly needed + # for statistics), so we index by public_id and username since those are + # the two main fields we're going to fetch users with await User.create_index([User.public_id], if_not_exists=True) await User.create_index([User.username], if_not_exists=True) + await EmailVerification.create_index([EmailVerification.user], if_not_exists=True) diff --git a/orm/email_verification.py b/orm/email_verification.py index b401cdf..b6c0b25 100644 --- a/orm/email_verification.py +++ b/orm/email_verification.py @@ -1,14 +1,27 @@ from piccolo.table import Table -from piccolo.columns import ForeignKey, Timestamptz, Boolean, UUID, Bytea +from piccolo.columns import ForeignKey, Timestamptz, Boolean, UUID, Bytea, Integer from piccolo.columns.defaults.timestamptz import TimestamptzNow +from enum import Enum, auto from .users import User +class EmailVerificationType(Enum): + """ + Enumeration of email verification + types + """ + + PASSWORD_RESET: int = auto() + CHANGE_EMAIL: int = auto() + OTHER: int = auto() + + class EmailVerification(Table, tablename="email_verifications"): id = UUID(primary_key=True, null=False) user = ForeignKey(references=User, null=False) + kind = Integer(choices=EmailVerificationType) creation_date = Timestamptz(default=TimestamptzNow(), null=False) pending = Boolean(default=True, null=False) data = Bytea(default=None, null=True)