Minor fixes to EmailVerification class and bug fixes

This commit is contained in:
Nocturn9x 2022-10-06 13:46:11 +02:00
parent d28675b0d6
commit 73ac217b97
3 changed files with 32 additions and 8 deletions

View File

@ -73,7 +73,7 @@ from orm.users import (
get_user_by_email, get_user_by_email,
) )
from orm.media import Media, MediaType, PublicMediaModel from orm.media import Media, MediaType, PublicMediaModel
from orm.email_verification import EmailVerification from orm.email_verification import EmailVerification, EmailVerificationType
from util.email import send_email from util.email import send_email
router = FastAPI() router = FastAPI()
@ -95,7 +95,6 @@ async def get_self_by_id_unverified(public_id: UUID) -> UserModel:
# Here follow our *beautifully* documented path operations # Here follow our *beautifully* documented path operations
@router.post( @router.post(
"/user", "/user",
tags=["Users"], tags=["Users"],
@ -467,8 +466,11 @@ async def reset_password(
detail="Verification window has expired. Try again", detail="Verification window has expired. Try again",
) )
else: else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where( await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.PASSWORD_RESET
) )
await User.update({User.password_hash: verification["data"]}).where( await User.update({User.password_hash: verification["data"]}).where(
User.public_id == user.public_id User.public_id == user.public_id
@ -513,8 +515,11 @@ async def change_email(
detail="Verification window has expired. Try again", detail="Verification window has expired. Try again",
) )
else: else:
# Note how we don't update based on the verification ID:
# this way, multiple pending email verification requests
# are all cleared at once
await EmailVerification.update({EmailVerification.pending: False}).where( await EmailVerification.update({EmailVerification.pending: False}).where(
EmailVerification.user == user.public_id EmailVerification.user == user.public_id and EmailVerification.kind == EmailVerificationType.CHANGE_EMAIL
) )
await User.update( await User.update(
{ {
@ -937,9 +942,9 @@ async def update_user(
for field in user: for field in user:
if ( if (
field not in ["email_address", "password"] field not in ["email_address", "password"]
and orig_user[field] != user[field] and getattr(orig_user, field) != getattr(user, field)
): ):
fields.append((field, user[field])) fields.append((field, getattr(user, field)))
if fields: if fields:
# If anything has changed, we update our info # If anything has changed, we update our info
await User.update({field: value for field, value in fields}).where( await User.update({field: value for field, value in fields}).where(

View File

@ -7,9 +7,15 @@ from .email_verification import EmailVerification
async def create_tables(): async def create_tables():
""" """
Initializes the database Initializes the database by creating the
necessary tables and indexes
""" """
await create_db_tables(User, Media, EmailVerification, if_not_exists=True) await create_db_tables(User, Media, EmailVerification, if_not_exists=True)
# Even though we use an auto-incrementing internal ID as the primary key,
# we'll almost never use it for lookups from the API (it's mostly needed
# for statistics), so we index by public_id and username since those are
# the two main fields we're going to fetch users with
await User.create_index([User.public_id], if_not_exists=True) await User.create_index([User.public_id], if_not_exists=True)
await User.create_index([User.username], if_not_exists=True) await User.create_index([User.username], if_not_exists=True)
await EmailVerification.create_index([EmailVerification.user], if_not_exists=True)

View File

@ -1,14 +1,27 @@
from piccolo.table import Table from piccolo.table import Table
from piccolo.columns import ForeignKey, Timestamptz, Boolean, UUID, Bytea from piccolo.columns import ForeignKey, Timestamptz, Boolean, UUID, Bytea, Integer
from piccolo.columns.defaults.timestamptz import TimestamptzNow from piccolo.columns.defaults.timestamptz import TimestamptzNow
from enum import Enum, auto
from .users import User from .users import User
class EmailVerificationType(Enum):
"""
Enumeration of email verification
types
"""
PASSWORD_RESET: int = auto()
CHANGE_EMAIL: int = auto()
OTHER: int = auto()
class EmailVerification(Table, tablename="email_verifications"): class EmailVerification(Table, tablename="email_verifications"):
id = UUID(primary_key=True, null=False) id = UUID(primary_key=True, null=False)
user = ForeignKey(references=User, null=False) user = ForeignKey(references=User, null=False)
kind = Integer(choices=EmailVerificationType)
creation_date = Timestamptz(default=TimestamptzNow(), null=False) creation_date = Timestamptz(default=TimestamptzNow(), null=False)
pending = Boolean(default=True, null=False) pending = Boolean(default=True, null=False)
data = Bytea(default=None, null=True) data = Bytea(default=None, null=True)