From 7679e12e33a14c99a80b942a8437437c6fd4afac Mon Sep 17 00:00:00 2001 From: Uchechukwu Orji Date: Wed, 29 Apr 2026 14:00:06 +0100 Subject: [PATCH] rename user to account --- backend/prestart.sh | 14 -- backend/pyproject.toml | 5 +- backend/src/cms_backend/api/main.py | 8 +- backend/src/cms_backend/api/routes/account.py | 109 ++++++++++++++++ backend/src/cms_backend/api/routes/auth.py | 40 +++--- .../cms_backend/api/routes/dependencies.py | 74 ++++++----- backend/src/cms_backend/api/routes/user.py | 106 ---------------- backend/src/cms_backend/api/token.py | 96 ++++++++++---- backend/src/cms_backend/db/account.py | 120 ++++++++++++++++++ backend/src/cms_backend/db/models.py | 18 ++- backend/src/cms_backend/db/refresh_token.py | 8 +- backend/src/cms_backend/db/user.py | 106 ---------------- .../8f0ee48f7dce_rename_user_to_account.py | 76 +++++++++++ backend/src/cms_backend/roles.py | 12 +- backend/src/cms_backend/schemas/orms.py | 4 +- backend/src/cms_backend/utils/database.py | 14 +- .../routes/{test_user.py => test_account.py} | 28 ++-- backend/tests/api/routes/test_titles.py | 18 ++- backend/tests/api/test_token_decoder.py | 10 +- backend/tests/conftest.py | 24 ++-- backend/tests/db/test_account.py | 94 ++++++++++++++ backend/tests/db/test_refresh_token.py | 14 +- backend/tests/db/test_user.py | 90 ------------- backend/tests/db/test_zimfarm_notification.py | 2 +- dev/README.md | 8 +- frontend/src/stores/auth.ts | 46 +++---- frontend/src/types/account.ts | 4 + frontend/src/types/user.ts | 2 +- frontend/src/views/OAuthCallbackView.vue | 2 +- healthcheck/README.md | 2 +- 30 files changed, 654 insertions(+), 500 deletions(-) delete mode 100755 backend/prestart.sh create mode 100644 backend/src/cms_backend/api/routes/account.py delete mode 100644 backend/src/cms_backend/api/routes/user.py create mode 100644 backend/src/cms_backend/db/account.py delete mode 100644 backend/src/cms_backend/db/user.py create mode 100644 backend/src/cms_backend/migrations/versions/8f0ee48f7dce_rename_user_to_account.py rename backend/tests/api/routes/{test_user.py => test_account.py} (69%) create mode 100644 backend/tests/db/test_account.py delete mode 100644 backend/tests/db/test_user.py create mode 100644 frontend/src/types/account.ts diff --git a/backend/prestart.sh b/backend/prestart.sh deleted file mode 100755 index fad05d0b..00000000 --- a/backend/prestart.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/sh - -die() { - echo "unable to run database initializations. exiting" - exit 1 -} - -if [ ! -z "$ALEMBIC_UPGRADE_HEAD_ON_START" ]; then - echo "Running alembic upgrade" - alembic check || true - alembic history - alembic upgrade head -fi -create-initial-user diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ef52e065..c926dd05 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -78,7 +78,7 @@ Donate = "https://www.kiwix.org/en/support-us/" cms-api = "cms_backend.api.main:app" cms-mill = "cms_backend.mill.main:main" cms-shuttle = "cms_backend.shuttle.main:main" -create-initial-user = "cms_backend.utils.database:create_initial_user" +create-initial-account = "cms_backend.utils.database:create_initial_account" check-db-schema = "cms_backend.utils.database:check_if_schema_is_up_to_date" [tool.hatch.version] @@ -235,9 +235,6 @@ minversion = "8.3.5" testpaths = ["tests"] pythonpath = [".", "src"] addopts = "--strict-markers" -markers = [ - "num_users(num=10, *, permission=...): create num users in the database with permission (default: ADMIN)", -] [tool.coverage.paths] diff --git a/backend/src/cms_backend/api/main.py b/backend/src/cms_backend/api/main.py index 3223d754..b9954ebd 100644 --- a/backend/src/cms_backend/api/main.py +++ b/backend/src/cms_backend/api/main.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse from pydantic import ValidationError +from cms_backend.api.routes.account import router as account_router from cms_backend.api.routes.auth import router as auth_router from cms_backend.api.routes.books import router as books_router from cms_backend.api.routes.collection import router as collection_router @@ -18,7 +19,6 @@ from cms_backend.api.routes.http_errors import BadRequestError from cms_backend.api.routes.staging import router as staging_router from cms_backend.api.routes.titles import router as titles_router -from cms_backend.api.routes.user import router as user_router from cms_backend.api.routes.zimfarm_notifications import ( router as zimfarm_notification_router, ) @@ -30,7 +30,7 @@ ) from cms_backend.utils.database import ( check_if_schema_is_up_to_date, - create_initial_user, + create_initial_account, upgrade_db_schema, ) @@ -40,7 +40,7 @@ async def lifespan(_: FastAPI): if Context.alembic_upgrade_head_on_start: upgrade_db_schema() check_if_schema_is_up_to_date() - create_initial_user() + create_initial_account() yield @@ -72,7 +72,7 @@ def create_app(*, debug: bool = True): main_router.include_router(router=collection_router) main_router.include_router(router=events_router) main_router.include_router(router=auth_router) - main_router.include_router(router=user_router) + main_router.include_router(router=account_router) main_router.include_router(router=staging_router) app.include_router(router=main_router) diff --git a/backend/src/cms_backend/api/routes/account.py b/backend/src/cms_backend/api/routes/account.py new file mode 100644 index 00000000..ffdbaf3f --- /dev/null +++ b/backend/src/cms_backend/api/routes/account.py @@ -0,0 +1,109 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import APIRouter, Depends, Path, Response +from sqlalchemy.orm import Session as OrmSession +from werkzeug.security import check_password_hash, generate_password_hash + +from cms_backend.api.routes.dependencies import get_current_account, require_permission +from cms_backend.api.routes.fields import NotEmptyString +from cms_backend.api.routes.http_errors import BadRequestError, UnauthorizedError +from cms_backend.db import gen_dbsession +from cms_backend.db.account import ( + check_account_permission, + create_account_schema, + get_account_by_username, +) +from cms_backend.db.account import create_account as db_create_account +from cms_backend.db.account import delete_account as db_delete_account +from cms_backend.db.account import update_account_password as db_update_account_password +from cms_backend.db.models import Account +from cms_backend.roles import RoleEnum +from cms_backend.schemas import BaseModel +from cms_backend.schemas.orms import AccountSchema + +router = APIRouter(prefix="/accounts", tags=["accounts"]) + + +class AccountCreateSchema(BaseModel): + """ + Schema for creating an account + """ + + username: NotEmptyString + password: NotEmptyString + role: RoleEnum + + +class PasswordUpdateSchema(BaseModel): + """ + Schema for updating an account's password + """ + + # account with elevated permissions can omit the current password + current: NotEmptyString | None = None + new: NotEmptyString + + +@router.post( + "", + dependencies=[Depends(require_permission(namespace="account", name="create"))], +) +def create_account( + account_schema: AccountCreateSchema, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], +) -> AccountSchema: + account = db_create_account( + db_session, + username=account_schema.username, + password_hash=generate_password_hash(account_schema.password), + role=account_schema.role, + ) + + return create_account_schema(account) + + +@router.delete( + "/{username}", + dependencies=[Depends(require_permission(namespace="account", name="delete"))], +) +def delete_account( + username: Annotated[str, Path()], + db_session: Annotated[OrmSession, Depends(gen_dbsession)], +) -> Response: + """Delete a specific account""" + account = get_account_by_username(db_session, username=username) + db_delete_account(db_session, account_id=account.id) + return Response(status_code=HTTPStatus.NO_CONTENT) + + +@router.patch("/{username}/password") +def update_account_password( + username: Annotated[str, Path()], + password_update: PasswordUpdateSchema, + db_session: Annotated[OrmSession, Depends(gen_dbsession)], + current_account: Annotated[Account, Depends(get_current_account)], +) -> Response: + """Update an account's password""" + account = get_account_by_username(db_session, username=username) + + if current_account.username == username: + if password_update.current is None: + raise BadRequestError("You must enter your current password.") + + if not check_password_hash( + current_account.password_hash or "", password_update.current + ): + raise BadRequestError() + + elif not check_account_permission( + current_account, namespace="account", name="update" + ): + raise UnauthorizedError("You are not allowed to access this resource") + + db_update_account_password( + db_session, + account_id=account.id, + password_hash=generate_password_hash(password_update.new), + ) + return Response(status_code=HTTPStatus.NO_CONTENT) diff --git a/backend/src/cms_backend/api/routes/auth.py b/backend/src/cms_backend/api/routes/auth.py index c7e4d8db..6165bad3 100644 --- a/backend/src/cms_backend/api/routes/auth.py +++ b/backend/src/cms_backend/api/routes/auth.py @@ -7,21 +7,21 @@ from werkzeug.security import check_password_hash from cms_backend.api.context import Context -from cms_backend.api.routes.dependencies import get_current_user +from cms_backend.api.routes.dependencies import get_current_account from cms_backend.api.routes.http_errors import UnauthorizedError from cms_backend.api.token import generate_access_token from cms_backend.db import gen_dbsession +from cms_backend.db.account import create_account_schema, get_account_by_username from cms_backend.db.exceptions import RecordDoesNotExistError -from cms_backend.db.models import User +from cms_backend.db.models import Account from cms_backend.db.refresh_token import ( create_refresh_token, delete_refresh_token, expire_refresh_tokens, get_refresh_token, ) -from cms_backend.db.user import create_user_schema, get_user_by_username from cms_backend.schemas import BaseModel -from cms_backend.schemas.orms import UserSchema +from cms_backend.schemas.orms import AccountSchema from cms_backend.utils.datetime import getnow router = APIRouter(prefix="/auth", tags=["auth"]) @@ -45,17 +45,19 @@ class Token(BaseModel): refresh_token: str -def _access_token_response(db_session: OrmSession, db_user: User, response: Response): +def _access_token_response( + db_session: OrmSession, db_account: Account, response: Response +): response.headers["Cache-Control"] = "no-store" response.headers["Pragma"] = "no-cache" issue_time = getnow() return Token( access_token=generate_access_token( - user_id=str(db_user.id), + account_id=str(db_account.id), issue_time=issue_time, ), refresh_token=str( - create_refresh_token(session=db_session, user_id=db_user.id).token + create_refresh_token(session=db_session, account_id=db_account.id).token ), expires_time=issue_time + datetime.timedelta(seconds=Context.jwt_token_expiry_duration), @@ -65,19 +67,19 @@ def _access_token_response(db_session: OrmSession, db_user: User, response: Resp def _auth_with_credentials( db_session: OrmSession, credentials: CredentialsIn, response: Response ): - """Authorize a user with username and password.""" + """Authorize an account with username and password.""" try: - db_user = get_user_by_username(db_session, username=credentials.username) + db_account = get_account_by_username(db_session, username=credentials.username) except RecordDoesNotExistError as exc: raise UnauthorizedError() from exc if not ( - db_user.password_hash - and check_password_hash(db_user.password_hash, credentials.password) + db_account.password_hash + and check_password_hash(db_account.password_hash, credentials.password) ): raise UnauthorizedError("Invalid credentials") - return _access_token_response(db_session, db_user, response) + return _access_token_response(db_session, db_account, response) def _refresh_access_token( @@ -96,7 +98,7 @@ def _refresh_access_token( delete_refresh_token(db_session, token=refresh_token) expire_refresh_tokens(db_session, expire_time=now) - return _access_token_response(db_session, db_refresh_token.user, response) + return _access_token_response(db_session, db_refresh_token.account, response) @router.post("/authorize") @@ -105,7 +107,7 @@ def auth_with_credentials( response: Response, db_session: Annotated[OrmSession, Depends(gen_dbsession)], ) -> Token: - """Authorize a user with username and password.""" + """Authorize an account with username and password.""" return _auth_with_credentials(db_session, credentials, response) @@ -119,8 +121,8 @@ def refresh_access_token( @router.get("/me") -def get_current_user_info( - current_user: Annotated[User, Depends(get_current_user)], -) -> UserSchema: - """Get the current authenticated user's information including scopes.""" - return create_user_schema(current_user) +def get_current_account_info( + current_account: Annotated[Account, Depends(get_current_account)], +) -> AccountSchema: + """Get the current authenticated account's information including scopes.""" + return create_account_schema(current_account) diff --git a/backend/src/cms_backend/api/routes/dependencies.py b/backend/src/cms_backend/api/routes/dependencies.py index 5ed0cec6..45c115ac 100644 --- a/backend/src/cms_backend/api/routes/dependencies.py +++ b/backend/src/cms_backend/api/routes/dependencies.py @@ -9,12 +9,12 @@ from cms_backend.api.routes.http_errors import UnauthorizedError from cms_backend.api.token import JWTClaims, token_decoder from cms_backend.db import gen_dbsession, gen_manual_dbsession -from cms_backend.db.models import User -from cms_backend.db.user import ( - check_user_permission, - create_user, - get_user_by_id_or_none, +from cms_backend.db.account import ( + check_account_permission, + create_account, + get_account_by_id_or_none, ) +from cms_backend.db.models import Account from cms_backend.roles import RoleEnum security = HTTPBearer(description="Access Token", auto_error=False) @@ -27,7 +27,7 @@ def get_jwt_claims_or_none( authorization: AuthorizationCredentials, ) -> JWTClaims | None: """ - Get the JWT claims or None if the user is not authenticated. + Get the JWT claims or None if the account is not authenticated. """ if authorization is None: return None @@ -45,79 +45,85 @@ def get_jwt_claims_or_none( raise UnauthorizedError("Unable to verify token") from exc -def get_current_user_or_none_with_session( +def get_current_account_or_none_with_session( session_type: Literal["auto", "manual"] = "auto", ): - def _get_current_user_or_none( + def _get_current_account_or_none( claims: Annotated[JWTClaims | None, Depends(get_jwt_claims_or_none)], session: Annotated[ OrmSession, Depends(gen_dbsession if session_type == "auto" else gen_manual_dbsession), ], - ) -> User | None: + ) -> Account | None: if claims is None: return None - user = get_user_by_id_or_none(session, user_id=claims.sub) - # If this claim has a "name" property, we create a new user account - if user is None and Context.create_new_oauth_account: + account = get_account_by_id_or_none(session, account_id=claims.sub) + # If this claim has a "name" property, we create a new account account + if account is None and Context.create_new_oauth_account: if not claims.name: raise UnauthorizedError("Token is missing 'profile' scope") - create_user( + create_account( session, username=claims.name, role=RoleEnum.VIEWER, idp_sub=claims.sub, ) - user = get_user_by_id_or_none(session, user_id=claims.sub) + account = get_account_by_id_or_none(session, account_id=claims.sub) - return user + return account - return _get_current_user_or_none + return _get_current_account_or_none -def get_current_user_with_session( +def get_current_account_with_session( session_type: Literal["auto", "manual"] = "auto", ): - def _get_current_user( - user: Annotated[ - User | None, - Depends(get_current_user_or_none_with_session(session_type=session_type)), + def _get_current_account( + account: Annotated[ + Account | None, + Depends( + get_current_account_or_none_with_session(session_type=session_type) + ), ], - ) -> User: - # If we get here, it means the token was valid but the user being None + ) -> Account: + # If we get here, it means the token was valid but the account being None # means their idp_sub or id doesn't exist on the database or they have been # marked as deleted. - if user is None: + if account is None: raise UnauthorizedError( "This account is not yet authorized on the CMS. " "Please contact CMS admins." ) - if user.deleted: + if account.deleted: raise UnauthorizedError("This account does not exist on the CMS.") - return user + return account - return _get_current_user + return _get_current_account # Convenience functions for common cases -get_current_user_or_none = get_current_user_or_none_with_session(session_type="auto") -get_current_user = get_current_user_with_session(session_type="auto") +get_current_account_or_none = get_current_account_or_none_with_session( + session_type="auto" +) +get_current_account = get_current_account_with_session(session_type="auto") def require_permission(*, namespace: str, name: str): """ - checks if the current user has a specific permission. + checks if the current account has a specific permission. """ def _check_permission( - current_user: Annotated[User, Depends(get_current_user)], - ) -> User: - if not check_user_permission(current_user, namespace=namespace, name=name): + current_account: Annotated[Account, Depends(get_current_account)], + ) -> Account: + if not check_account_permission( + current_account, namespace=namespace, name=name + ): raise UnauthorizedError( "You do not have permission to perform this action. " ) - return current_user + return current_account return _check_permission diff --git a/backend/src/cms_backend/api/routes/user.py b/backend/src/cms_backend/api/routes/user.py deleted file mode 100644 index 7bc22fc5..00000000 --- a/backend/src/cms_backend/api/routes/user.py +++ /dev/null @@ -1,106 +0,0 @@ -from http import HTTPStatus -from typing import Annotated - -from fastapi import APIRouter, Depends, Path, Response -from sqlalchemy.orm import Session as OrmSession -from werkzeug.security import check_password_hash, generate_password_hash - -from cms_backend.api.routes.dependencies import get_current_user, require_permission -from cms_backend.api.routes.fields import NotEmptyString -from cms_backend.api.routes.http_errors import BadRequestError, UnauthorizedError -from cms_backend.db import gen_dbsession -from cms_backend.db.models import User -from cms_backend.db.user import ( - check_user_permission, - create_user_schema, - get_user_by_username, -) -from cms_backend.db.user import create_user as db_create_user -from cms_backend.db.user import delete_user as db_delete_user -from cms_backend.db.user import update_user_password as db_update_user_password -from cms_backend.roles import RoleEnum -from cms_backend.schemas import BaseModel -from cms_backend.schemas.orms import UserSchema - -router = APIRouter(prefix="/users", tags=["users"]) - - -class UserCreateSchema(BaseModel): - """ - Schema for creating a user - """ - - username: NotEmptyString - password: NotEmptyString - role: RoleEnum - - -class PasswordUpdateSchema(BaseModel): - """ - Schema for updating a user's password - """ - - # users with elevated permissions can omit the current password - current: NotEmptyString | None = None - new: NotEmptyString - - -@router.post( - "", dependencies=[Depends(require_permission(namespace="user", name="create"))] -) -def create_user( - user_schema: UserCreateSchema, - db_session: Annotated[OrmSession, Depends(gen_dbsession)], -) -> UserSchema: - user = db_create_user( - db_session, - username=user_schema.username, - password_hash=generate_password_hash(user_schema.password), - role=user_schema.role, - ) - - return create_user_schema(user) - - -@router.delete( - "/{username}", - dependencies=[Depends(require_permission(namespace="user", name="delete"))], -) -def delete_user( - username: Annotated[str, Path()], - db_session: Annotated[OrmSession, Depends(gen_dbsession)], -) -> Response: - """Delete a specific user""" - user = get_user_by_username(db_session, username=username) - db_delete_user(db_session, user_id=user.id) - return Response(status_code=HTTPStatus.NO_CONTENT) - - -@router.patch("/{username}/password") -def update_user_password( - username: Annotated[str, Path()], - password_update: PasswordUpdateSchema, - db_session: Annotated[OrmSession, Depends(gen_dbsession)], - current_user: Annotated[User, Depends(get_current_user)], -) -> Response: - """Update a user's password""" - user = get_user_by_username(db_session, username=username) - - if current_user.username == username: - if password_update.current is None: - raise BadRequestError("You must enter your current password.") - - if not check_password_hash( - current_user.password_hash or "", password_update.current - ): - raise BadRequestError() - - elif not check_user_permission(current_user, namespace="user", name="update"): - raise UnauthorizedError("You are not allowed to access this resource") - - db_update_user_password( - db_session, - user_id=user.id, - password_hash=generate_password_hash(password_update.new), - ) - return Response(status_code=HTTPStatus.NO_CONTENT) diff --git a/backend/src/cms_backend/api/token.py b/backend/src/cms_backend/api/token.py index 562e2434..9b62223f 100644 --- a/backend/src/cms_backend/api/token.py +++ b/backend/src/cms_backend/api/token.py @@ -39,9 +39,8 @@ def name(self) -> str: """ pass - @property @abc.abstractmethod - def can_decode(self) -> bool: + def can_decode(self, token: str) -> bool: """ Check if this decoder can potentially decode the given token. """ @@ -66,10 +65,29 @@ def decode(self, token: str) -> JWTClaims: def name(self) -> str: return "local" - @property - def can_decode(self) -> bool: + def can_decode(self, token: str) -> bool: return "local" in Context.auth_modes + if "local" not in Context.auth_modes: + return False + try: + payload = jwt.decode( + token, + options={ + "verify_signature": False, + "verify_exp": False, + "verify_aud": False, + "verify_iss": False, + }, + ) + except Exception: + return False + + if payload.get("iss") != Context.jwt_token_issuer: + return False + + return True + class OAuthTokenDecoder(TokenDecoder): """Decoder for OAuth JWT tokens.""" @@ -104,7 +122,7 @@ def decode(self, token: str) -> JWTClaims: raise ValueError("Oauth client ID does not match.") # Check for 2FA requirement only if client_id is not present in the token - # as those come from oauth2 clients and not real users + # as those come from oauth2 clients and not real accounts if ( not decoded_token.get("client_id") and Context.oauth_session_login_require_2fa @@ -121,9 +139,28 @@ def decode(self, token: str) -> JWTClaims: def name(self) -> str: return "oauth" - @property - def can_decode(self) -> bool: - return "oauth" in Context.auth_modes + def can_decode(self, token: str) -> bool: + if "oauth-session" not in Context.auth_modes: + return False + try: + payload = jwt.decode( + token, + options={ + "verify_signature": False, + "verify_exp": False, + "verify_aud": False, + "verify_iss": False, + }, + ) + except Exception: + return False + + if ( + payload.get("iss") != Context.oauth_issuer + or Context.oauth_session_audience_id not in payload.get("aud", []) + ): + return False + return True class TokenDecoderChain: @@ -140,27 +177,38 @@ def decode(self, token: str) -> JWTClaims: Try to decode token using each decoder in order. """ exc_cls: Exception | None = None - decoders = [decoder for decoder in self.decoders if decoder.can_decode] + decoders = [decoder for decoder in self.decoders if decoder.can_decode(token)] if not decoders: raise ValueError("No decoders registered for decoding token.") + if not decoders: + raise ValueError("No decoders can decode token.") + + if len(decoders) > 1: + logger.warning( + "Multiple token decoders detected. Set configuration values to match " + "only one token decoder to avoid overwriting exception messages." + ) + for decoder in decoders: - if decoder.can_decode: - try: - return decoder.decode(token) - except ( - jwt_exceptions.PyJWTError, - PydanticValidationError, - Exception, - ) as exc: - logger.debug(f"{decoder.name}: unable to decode token: {exc!s}") - # keep track of the most recent exception class - exc_cls = exc + try: + logger.debug(f"{decoder.name}-decoder: attempting to decode token.") + claims = decoder.decode(token) + except ( + jwt_exceptions.PyJWTError, + PydanticValidationError, + Exception, + ) as exc: + logger.debug(f"{decoder.name}-decoder: unable to decode token: {exc!s}") + exc_cls = exc + else: + logger.debug(f"{decoder.name}-decoder: decoded token successfully.") + return claims if exc_cls: raise exc_cls - raise ValueError("Inavlid token") + raise ValueError("Invalid token") token_decoder = TokenDecoderChain( @@ -173,10 +221,10 @@ def decode(self, token: str) -> JWTClaims: def generate_access_token( *, - user_id: str, + account_id: str, issue_time: datetime.datetime, ) -> str: - """Generate a JWT access token for the given user ID with configured expiry.""" + """Generate a JWT access token for the given account ID with configured expiry.""" expire_time = issue_time + datetime.timedelta( seconds=Context.jwt_token_expiry_duration @@ -185,6 +233,6 @@ def generate_access_token( "iss": Context.jwt_token_issuer, # issuer "exp": expire_time.timestamp(), # expiration time "iat": issue_time.timestamp(), # issued at - "subject": user_id, + "subject": account_id, } return jwt.encode(payload, key=Context.jwt_secret, algorithm="HS256") diff --git a/backend/src/cms_backend/db/account.py b/backend/src/cms_backend/db/account.py new file mode 100644 index 00000000..c189fb07 --- /dev/null +++ b/backend/src/cms_backend/db/account.py @@ -0,0 +1,120 @@ +from uuid import UUID + +from sqlalchemy import select, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session as OrmSession + +from cms_backend.db.exceptions import ( + RecordAlreadyExistsError, + RecordDoesNotExistError, +) +from cms_backend.db.models import Account +from cms_backend.roles import ROLES, RoleEnum, merge_scopes +from cms_backend.schemas.orms import AccountSchema + + +def get_account_by_username_or_none( + session: OrmSession, *, username: str +) -> Account | None: + """Get an account by username or return None if the account does not exist""" + return session.scalars( + select(Account).where(Account.username == username) + ).one_or_none() + + +def get_account_by_username(session: OrmSession, *, username: str) -> Account: + """Get an account by username or raise an exception if the account does not exist""" + if (account := get_account_by_username_or_none(session, username=username)) is None: + raise RecordDoesNotExistError( + f"Account with username {username} does not exist" + ) + return account + + +def get_account_by_id_or_none( + session: OrmSession, *, account_id: UUID +) -> Account | None: + """Get an account by id or return None if the account does not exist""" + return session.scalars( + select(Account).where( + (Account.idp_sub == account_id) | (Account.id == account_id) + ) + ).one_or_none() + + +def get_account_by_id(session: OrmSession, *, account_id: UUID) -> Account: + """Get an account by id or raise an exception if the account does not exist""" + if (account := get_account_by_id_or_none(session, account_id=account_id)) is None: + raise RecordDoesNotExistError(f"Account with id {account_id} does not exist") + return account + + +def check_account_permission( + account: Account, + *, + namespace: str, + name: str, +) -> bool: + """Check if an account has a permission for a given namespace and name""" + # Select the scope that comes with their role enum or scope from the DB + scope = ROLES.get(account.role) + if not scope: + return False + return scope.get(namespace, {}).get(name, False) + + +def create_account_schema(account: Account) -> AccountSchema: + return AccountSchema( + username=account.username, + role=account.role, + scope=merge_scopes(ROLES.get(account.role, {}), ROLES[RoleEnum.EDITOR]), + ) + + +def create_account( + session: OrmSession, + *, + username: str, + role: str, + idp_sub: UUID | None = None, + password_hash: str | None = None, +) -> Account: + """Create a new account""" + account = Account( + username=username, + role=role, + deleted=False, + idp_sub=idp_sub, + password_hash=password_hash, + ) + session.add(account) + try: + session.flush() + except IntegrityError as exc: + raise RecordAlreadyExistsError("Account already exists") from exc + return account + + +def update_account_password( + session: OrmSession, + *, + account_id: UUID, + password_hash: str, +) -> None: + """Update an account's password""" + session.execute( + update(Account) + .where(Account.id == account_id) + .values(password_hash=password_hash) + ) + + +def delete_account( + session: OrmSession, + *, + account_id: UUID, +) -> None: + """Delete an account""" + session.execute( + update(Account).where(Account.id == account_id).values(deleted=True) + ) diff --git a/backend/src/cms_backend/db/models.py b/backend/src/cms_backend/db/models.py index 12b6ceca..7291819b 100644 --- a/backend/src/cms_backend/db/models.py +++ b/backend/src/cms_backend/db/models.py @@ -287,19 +287,19 @@ def full_str(self) -> str: return f"{self.warehouse.name}:{self.path_in_warehouse}" -class User(Base): - __tablename__ = "user" +class Account(Base): + __tablename__ = "account" id: Mapped[UUID] = mapped_column( init=False, primary_key=True, server_default=text("uuid_generate_v4()") ) - idp_sub: Mapped[UUID | None] username: Mapped[str] = mapped_column(unique=True, index=True) role: Mapped[str] password_hash: Mapped[str | None] deleted: Mapped[bool] = mapped_column(default=False, server_default=false()) + idp_sub: Mapped[UUID | None] = mapped_column(index=True, unique=True, default=None) refresh_tokens: Mapped[list["Refreshtoken"]] = relationship( - back_populates="user", cascade="all, delete-orphan", init=False + back_populates="account", cascade="all, delete-orphan", init=False ) @@ -310,11 +310,15 @@ class Refreshtoken(Base): ) token: Mapped[UUID] = mapped_column(server_default=text("uuid_generate_v4()")) expire_time: Mapped[datetime] - user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), init=False) + account_id: Mapped[UUID] = mapped_column(ForeignKey("account.id"), init=False) - user: Mapped["User"] = relationship(back_populates="refresh_tokens", init=False) + account: Mapped["Account"] = relationship( + back_populates="refresh_tokens", init=False + ) - __table__args = (Index("user_id", "token", unique=True),) + __table_args__ = ( + Index("ix_refresh_token_account_id_token", "account_id", "token", unique=True), + ) class Event(Base): diff --git a/backend/src/cms_backend/db/refresh_token.py b/backend/src/cms_backend/db/refresh_token.py index f563e9a7..616ead7f 100644 --- a/backend/src/cms_backend/db/refresh_token.py +++ b/backend/src/cms_backend/db/refresh_token.py @@ -5,9 +5,9 @@ from sqlalchemy.orm import Session as OrmSession from cms_backend.api.context import Context +from cms_backend.db.account import get_account_by_id from cms_backend.db.exceptions import RecordDoesNotExistError from cms_backend.db.models import Refreshtoken -from cms_backend.db.user import get_user_by_id from cms_backend.utils.datetime import getnow @@ -26,14 +26,14 @@ def get_refresh_token(session: OrmSession, token: UUID) -> Refreshtoken: return db_refresh_token -def create_refresh_token(session: OrmSession, user_id: UUID) -> Refreshtoken: - """Create a refresh token for a user""" +def create_refresh_token(session: OrmSession, account_id: UUID) -> Refreshtoken: + """Create a refresh token for an account""" refresh_token = Refreshtoken( token=uuid4(), expire_time=getnow() + datetime.timedelta(seconds=Context.refresh_token_expiry_duration), ) - refresh_token.user = get_user_by_id(session, user_id=user_id) + refresh_token.account = get_account_by_id(session, account_id=account_id) session.add(refresh_token) session.flush() return refresh_token diff --git a/backend/src/cms_backend/db/user.py b/backend/src/cms_backend/db/user.py deleted file mode 100644 index eb1e7034..00000000 --- a/backend/src/cms_backend/db/user.py +++ /dev/null @@ -1,106 +0,0 @@ -from uuid import UUID - -from sqlalchemy import select, update -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session as OrmSession - -from cms_backend.db.exceptions import ( - RecordAlreadyExistsError, - RecordDoesNotExistError, -) -from cms_backend.db.models import User -from cms_backend.roles import ROLES, RoleEnum, merge_scopes -from cms_backend.schemas.orms import UserSchema - - -def get_user_by_username_or_none(session: OrmSession, *, username: str) -> User | None: - """Get a user by username or return None if the user does not exist""" - return session.scalars(select(User).where(User.username == username)).one_or_none() - - -def get_user_by_username(session: OrmSession, *, username: str) -> User: - """Get a user by username or raise an exception if the user does not exist""" - if (user := get_user_by_username_or_none(session, username=username)) is None: - raise RecordDoesNotExistError(f"User with username {username} does not exist") - return user - - -def get_user_by_id_or_none(session: OrmSession, *, user_id: UUID) -> User | None: - """Get a user by id or return None if the user does not exist""" - return session.scalars( - select(User).where((User.idp_sub == user_id) | (User.id == user_id)) - ).one_or_none() - - -def get_user_by_id(session: OrmSession, *, user_id: UUID) -> User: - """Get a user by id or raise an exception if the user does not exist""" - if (user := get_user_by_id_or_none(session, user_id=user_id)) is None: - raise RecordDoesNotExistError(f"User with id {user_id} does not exist") - return user - - -def check_user_permission( - user: User, - *, - namespace: str, - name: str, -) -> bool: - """Check if a user has a permission for a given namespace and name""" - # Select the scope that comes with their role enum or scope from the DB - scope = ROLES.get(user.role) - if not scope: - return False - return scope.get(namespace, {}).get(name, False) - - -def create_user_schema(user: User) -> UserSchema: - return UserSchema( - username=user.username, - role=user.role, - scope=merge_scopes(ROLES.get(user.role, {}), ROLES[RoleEnum.EDITOR]), - ) - - -def create_user( - session: OrmSession, - *, - username: str, - role: str, - idp_sub: UUID | None = None, - password_hash: str | None = None, -) -> User: - """Create a new user""" - user = User( - username=username, - role=role, - deleted=False, - idp_sub=idp_sub, - password_hash=password_hash, - ) - session.add(user) - try: - session.flush() - except IntegrityError as exc: - raise RecordAlreadyExistsError("User already exists") from exc - return user - - -def update_user_password( - session: OrmSession, - *, - user_id: UUID, - password_hash: str, -) -> None: - """Update a user's password""" - session.execute( - update(User).where(User.id == user_id).values(password_hash=password_hash) - ) - - -def delete_user( - session: OrmSession, - *, - user_id: UUID, -) -> None: - """Delete a user""" - session.execute(update(User).where(User.id == user_id).values(deleted=True)) diff --git a/backend/src/cms_backend/migrations/versions/8f0ee48f7dce_rename_user_to_account.py b/backend/src/cms_backend/migrations/versions/8f0ee48f7dce_rename_user_to_account.py new file mode 100644 index 00000000..5d4b56dc --- /dev/null +++ b/backend/src/cms_backend/migrations/versions/8f0ee48f7dce_rename_user_to_account.py @@ -0,0 +1,76 @@ +"""rename user to account + +Revision ID: 8f0ee48f7dce +Revises: eb212a007a90 +Create Date: 2026-04-29 11:55:20.724507 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8f0ee48f7dce" +down_revision = "eb212a007a90" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Drop all indices on the user table + op.drop_index("ix_user_username", table_name="user") + # Drop constraints and rename columns against other tables + op.drop_constraint( + "fk_refresh_token_user_id_user", "refresh_token", type_="foreignkey" + ) + # Rename the table from user to account and recreate the indices + op.rename_table("user", "account") + # Rename the primary key constraint to match the new table name + op.execute("ALTER TABLE account RENAME CONSTRAINT pk_user TO pk_account") + op.create_index(op.f("ix_account_idp_sub"), "account", ["idp_sub"], unique=True) + op.create_index(op.f("ix_account_username"), "account", ["username"], unique=True) + + op.alter_column("refresh_token", "user_id", new_column_name="account_id") + op.create_index( + "ix_refresh_token_account_id_token", + table_name="refresh_token", + columns=["account_id", "token"], + unique=True, + ) + + # Recreate the inbound foreign key constraints to account + op.create_foreign_key( + op.f("fk_refresh_token_account_id_account"), + "refresh_token", + "account", + ["account_id"], + ["id"], + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Drop all foreign keys pointing to account table + op.drop_constraint( + op.f("fk_refresh_token_account_id_account"), + "refresh_token", + type_="foreignkey", + ) + op.drop_index("ix_refresh_token_account_id_token", table_name="refresh_token") + op.drop_index(op.f("ix_account_username"), table_name="account") + op.drop_index(op.f("ix_account_idp_sub"), table_name="account") + + op.execute("ALTER TABLE account RENAME CONSTRAINT pk_account TO pk_user") + op.rename_table("account", "user") + # Recreate foreign keys to user table + op.alter_column("refresh_token", "account_id", new_column_name="user_id") + op.create_foreign_key( + "fk_refresh_token_user_id_user", + "refresh_token", + "user", + ["user_id"], + ["id"], + ) + op.create_index("ix_user_username", "user", ["username"], unique=True) + # ### end Alembic commands ### diff --git a/backend/src/cms_backend/roles.py b/backend/src/cms_backend/roles.py index 631a5cd3..854f3290 100644 --- a/backend/src/cms_backend/roles.py +++ b/backend/src/cms_backend/roles.py @@ -30,7 +30,7 @@ class RoleEnum(StrEnum): "book": ResourcePermissions.get_all(), "title": ResourcePermissions.get_all(), "zimfarm_notification": ResourcePermissions.get_all(), - "user": ResourcePermissions.get_all(), + "account": ResourcePermissions.get_all(), }, RoleEnum.ZIMFARM: { "zimfarm_notification": ResourcePermissions.get(read=True, create=True), @@ -39,16 +39,18 @@ class RoleEnum(StrEnum): def merge_scopes( - user_scope: dict[str, dict[str, bool]], all_scopes: dict[str, dict[str, bool]] + account_scope: dict[str, dict[str, bool]], all_scopes: dict[str, dict[str, bool]] ) -> dict[str, dict[str, bool]]: - """Combine user scope and all scopes populating missing user scopes with False.""" + """ + Combine account scope and all scopes populating missing account scopes with False. + """ merged: dict[str, dict[str, bool]] = {} for category, permissions in all_scopes.items(): merged[category] = {} - user_permissions = user_scope.get(category, {}) + account_permissions = account_scope.get(category, {}) for perm, _ in permissions.items(): - merged[category][perm] = user_permissions.get(perm, False) + merged[category][perm] = account_permissions.get(perm, False) return merged diff --git a/backend/src/cms_backend/schemas/orms.py b/backend/src/cms_backend/schemas/orms.py index e2c80bc9..1c208f5b 100644 --- a/backend/src/cms_backend/schemas/orms.py +++ b/backend/src/cms_backend/schemas/orms.py @@ -118,9 +118,9 @@ class WarehousePathSchema(BaseModel): warehouse_name: str -class UserSchema(BaseModel): +class AccountSchema(BaseModel): """ - Schema for reading a user model + Schema for reading an account model """ username: str diff --git a/backend/src/cms_backend/utils/database.py b/backend/src/cms_backend/utils/database.py index 433e30cd..3b14233a 100644 --- a/backend/src/cms_backend/utils/database.py +++ b/backend/src/cms_backend/utils/database.py @@ -8,7 +8,7 @@ from cms_backend import logger from cms_backend.context import Context from cms_backend.db import Session -from cms_backend.db.user import create_user, get_user_by_username_or_none +from cms_backend.db.account import create_account, get_account_by_username_or_none def check_if_schema_is_up_to_date(): @@ -35,19 +35,19 @@ def upgrade_db_schema(): subprocess.check_output(args=["alembic", "upgrade", "head"], cwd=Context.base_dir) -def create_initial_user(): +def create_initial_account(): with Session.begin() as session: username = os.getenv("INIT_USERNAME", default="admin") password = os.getenv("INIT_PASSWORD", default="admin_pass") - user = get_user_by_username_or_none(session, username=username) - if user is None: - logger.info(f"creating initial user `{username}`") - create_user( + account = get_account_by_username_or_none(session, username=username) + if account is None: + logger.info(f"creating initial account `{username}`") + create_account( session=session, username=username, password_hash=generate_password_hash(password), role="editor", ) else: - logger.info(f"user {username} already exists") + logger.info(f"account {username} already exists") diff --git a/backend/tests/api/routes/test_user.py b/backend/tests/api/routes/test_account.py similarity index 69% rename from backend/tests/api/routes/test_user.py rename to backend/tests/api/routes/test_account.py index 6040e5e0..13366831 100644 --- a/backend/tests/api/routes/test_user.py +++ b/backend/tests/api/routes/test_account.py @@ -5,15 +5,15 @@ from fastapi.testclient import TestClient from cms_backend.api.token import generate_access_token -from cms_backend.db.models import User +from cms_backend.db.models import Account from cms_backend.utils.datetime import getnow -def test_create_user(client: TestClient, user: User): - url = "/v1/users/" +def test_create_account(client: TestClient, account: Account): + url = "/v1/accounts/" access_token = generate_access_token( issue_time=getnow(), - user_id=str(user.id), + account_id=str(account.id), ) response = client.post( url, @@ -27,17 +27,17 @@ def test_create_user(client: TestClient, user: User): assert response.status_code == HTTPStatus.OK -def test_create_user_duplicate(client: TestClient, user: User): - url = "/v1/users/" +def test_create_account_duplicate(client: TestClient, account: Account): + url = "/v1/accounts/" access_token = generate_access_token( issue_time=getnow(), - user_id=str(user.id), + account_id=str(account.id), ) response = client.post( url, headers={"Authorization": f"Bearer {access_token}"}, json={ - "username": user.username, + "username": account.username, "password": "test", "role": "viewer", }, @@ -53,22 +53,22 @@ def test_create_user_duplicate(client: TestClient, user: User): ("testpassword", "test2", HTTPStatus.NO_CONTENT), ], ) -def test_update_user_password_invalid( +def test_update_account_password_invalid( client: TestClient, - create_user: Callable[..., User], + create_account: Callable[..., Account], current: str, new: str, expected: HTTPStatus, ): - """Test updating a user's password with an invalid current password""" - user = create_user(password="testpassword") + """Test updating an account's password with an invalid current password""" + account = create_account(password="testpassword") access_token = generate_access_token( issue_time=getnow(), - user_id=str(user.id), + account_id=str(account.id), ) response = client.patch( - f"/v1/users/{user.username}/password", + f"/v1/accounts/{account.username}/password", headers={"Authorization": f"Bearer {access_token}"}, json={"current": current, "new": new}, ) diff --git a/backend/tests/api/routes/test_titles.py b/backend/tests/api/routes/test_titles.py index 2a54433a..b2fedbd9 100644 --- a/backend/tests/api/routes/test_titles.py +++ b/backend/tests/api/routes/test_titles.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session as OrmSession from cms_backend.api.token import generate_access_token -from cms_backend.db.models import Book, Collection, Event, Title, User +from cms_backend.db.models import Account, Book, Collection, Event, Title from cms_backend.roles import RoleEnum from cms_backend.utils.datetime import getnow @@ -60,7 +60,7 @@ def test_get_titles( ) def test_create_title_required_permissions( client: TestClient, - create_user: Callable[..., User], + create_account: Callable[..., Account], permission: RoleEnum, expected_status_code: HTTPStatus, ): @@ -69,8 +69,10 @@ def test_create_title_required_permissions( "name": "wikipedia_en_test", } - user = create_user(permission=permission) - access_token = generate_access_token(user_id=str(user.id), issue_time=getnow()) + account = create_account(permission=permission) + access_token = generate_access_token( + account_id=str(account.id), issue_time=getnow() + ) response = client.post( "/v1/titles", json=title_data, @@ -316,7 +318,7 @@ def test_get_title_by_id_not_found(client: TestClient): ) def test_update_title_required_permissions( client: TestClient, - create_user: Callable[..., User], + create_account: Callable[..., Account], create_title: Callable[..., Title], permission: RoleEnum, expected_status_code: HTTPStatus, @@ -327,8 +329,10 @@ def test_update_title_required_permissions( "maturity": "robust", } - user = create_user(permission=permission) - access_token = generate_access_token(user_id=str(user.id), issue_time=getnow()) + account = create_account(permission=permission) + access_token = generate_access_token( + account_id=str(account.id), issue_time=getnow() + ) response = client.patch( f"/v1/titles/{title.id}", json=update_data, diff --git a/backend/tests/api/test_token_decoder.py b/backend/tests/api/test_token_decoder.py index 7ece332c..cc151eb2 100644 --- a/backend/tests/api/test_token_decoder.py +++ b/backend/tests/api/test_token_decoder.py @@ -32,7 +32,7 @@ def create_test_session_jwt_token( "exp": int((now + exp_delta).timestamp()), "aud": audience_id, "aal": aal, - "name": "Test User", + "name": "Test Account", } # Create a test token (unsigned for testing purposes) @@ -100,7 +100,7 @@ def test_verify_session_access_token_expired_token( def test_verify_session_access_token_with_2fa_enabled_and_valid_aal( monkeypatch: pytest.MonkeyPatch, ): - """Test successful verification when 2FA is enabled and user has aal2.""" + """Test successful verification when 2FA is enabled and account has aal2.""" monkeypatch.setattr("cms_backend.api.context.Context.oauth_issuer", TEST_ISSUER) monkeypatch.setattr( "cms_backend.api.context.Context.oauth_session_audience_id", @@ -120,7 +120,7 @@ def test_verify_session_access_token_with_2fa_enabled_and_valid_aal( "iss": TEST_ISSUER, "sub": str(UUID(int=0)), "aud": TEST_AUDIENCE_ID, - "name": "Test User", + "name": "Test Account", "iat": int(getnow().timestamp()), "exp": int((getnow() + datetime.timedelta(hours=1)).timestamp()), "aal": "aal2", # Authenticator Assurance Level 2 (2FA) @@ -167,7 +167,7 @@ def test_verify_session_access_token_with_2fa_enabled_only_aal1( "iss": TEST_ISSUER, "sub": str(UUID(int=0)), "aud": TEST_AUDIENCE_ID, - "name": "Test User", + "name": "Test Account", "iat": int(getnow().timestamp()), "exp": int((getnow() + datetime.timedelta(hours=1)).timestamp()), "aal": "aal1", # Only first factor (aal1) @@ -213,7 +213,7 @@ def test_verify_session_access_token_with_2fa_disabled_only_aal1( "iss": TEST_ISSUER, "sub": str(UUID(int=0)), "aud": TEST_AUDIENCE_ID, - "name": "Test User", + "name": "Test Account", "iat": int(getnow().timestamp()), "exp": int((getnow() + datetime.timedelta(hours=1)).timestamp()), "aal": "aal1", # Only first factor (aal1), but 2FA is disabled diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 1c58dda1..a72f988c 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -14,6 +14,7 @@ from cms_backend.context import Context from cms_backend.db import Session from cms_backend.db.models import ( + Account, Base, Book, BookLocation, @@ -21,7 +22,6 @@ CollectionTitle, Event, Title, - User, Warehouse, ZimfarmNotification, ) @@ -329,17 +329,17 @@ def _create_collection_title( @pytest.fixture -def create_user( +def create_account( dbsession: OrmSession, faker: Faker, -) -> Callable[..., User]: - def _create_user( +) -> Callable[..., Account]: + def _create_account( *, username: str | None = None, permission: RoleEnum = RoleEnum.EDITOR, password: str | None = None, ): - user = User( + account = Account( username=username or faker.first_name(), role=permission, idp_sub=uuid4(), @@ -347,25 +347,25 @@ def _create_user( None if password is None else generate_password_hash(password) ), ) - dbsession.add(user) + dbsession.add(account) dbsession.flush() - return user + return account - return _create_user + return _create_account @pytest.fixture -def user(create_user: Callable[..., User]): - return create_user() +def account(create_account: Callable[..., Account]): + return create_account() @pytest.fixture -def access_token(user: User) -> str: +def access_token(account: Account) -> str: return generate_access_token( issue_time=getnow(), - user_id=str(user.id), + account_id=str(account.id), ) diff --git a/backend/tests/db/test_account.py b/backend/tests/db/test_account.py new file mode 100644 index 00000000..4f373c73 --- /dev/null +++ b/backend/tests/db/test_account.py @@ -0,0 +1,94 @@ +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session as OrmSession + +from cms_backend.db.account import ( + delete_account, + get_account_by_id, + get_account_by_id_or_none, + get_account_by_username, + get_account_by_username_or_none, +) +from cms_backend.db.exceptions import ( + RecordDoesNotExistError, +) +from cms_backend.db.models import Account +from cms_backend.roles import merge_scopes + + +@pytest.mark.parametrize( + ["custom_scope", "all_scopes", "expected"], + [ + ( + {"book": {"read": True}}, + { + "book": {"read": True, "write": True}, + "title": {"read": True, "write": True}, + }, + { + "book": {"read": True, "write": False}, + "title": {"read": False, "write": False}, + }, + ), + ( + {}, + { + "book": {"read": True, "write": True}, + "title": {"read": True, "write": True}, + }, + { + "book": {"read": False, "write": False}, + "title": {"read": False, "write": False}, + }, + ), + ], +) +def test_merge_scopes( + custom_scope: dict[str, dict[str, bool]], + all_scopes: dict[str, dict[str, bool]], + expected: dict[str, dict[str, bool]], +): + assert merge_scopes(custom_scope, all_scopes) == expected + + +def test_get_account_by_id_or_none(dbsession: OrmSession): + """Test that get_account_by_id_or_none returns None if the account does not exist""" + account = get_account_by_id_or_none(dbsession, account_id=uuid4()) + assert account is None + + +def test_get_account_by_id_not_found(dbsession: OrmSession): + """Test that get_account_by_id raises an exception if the account does not exist""" + with pytest.raises(RecordDoesNotExistError): + get_account_by_id(dbsession, account_id=uuid4()) + + +def test_get_account_by_id(dbsession: OrmSession, account: Account): + """Test that get_account_by_id returns the account if the account exists""" + db_account = get_account_by_id(dbsession, account_id=account.id) + assert db_account is not None + assert db_account.id == account.id + + +def test_get_account_by_username_or_none(dbsession: OrmSession): + """ + Test that get_account_by_username_or_none returns None if the account does not exist + """ + account = get_account_by_username_or_none(dbsession, username="doesnotexist") + assert account is None + + +def test_get_account_by_username_not_found(dbsession: OrmSession): + """ + Test that get_account_by_username raises an exception if the account does not exist + """ + with pytest.raises(RecordDoesNotExistError): + get_account_by_username(dbsession, username="doesnotexist") + + +def test_delete_account(dbsession: OrmSession, account: Account): + """Test that delete_account marks account as deleted""" + delete_account(dbsession, account_id=account.id) + dbsession.refresh(account) + assert account.deleted diff --git a/backend/tests/db/test_refresh_token.py b/backend/tests/db/test_refresh_token.py index 03b17cf4..d625a6af 100644 --- a/backend/tests/db/test_refresh_token.py +++ b/backend/tests/db/test_refresh_token.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session as OrmSession from cms_backend.db.exceptions import RecordDoesNotExistError -from cms_backend.db.models import Refreshtoken, User +from cms_backend.db.models import Account, Refreshtoken from cms_backend.db.refresh_token import ( create_refresh_token, delete_refresh_token, @@ -17,13 +17,13 @@ @pytest.fixture -def refresh_token(dbsession: OrmSession, user: User) -> Refreshtoken: - """Create a refresh token for a user""" +def refresh_token(dbsession: OrmSession, account: Account) -> Refreshtoken: + """Create a refresh token for an account""" token = Refreshtoken( token=uuid4(), expire_time=getnow() + datetime.timedelta(seconds=1_000), ) - token.user = user + token.account = account dbsession.add(token) dbsession.flush() return token @@ -35,12 +35,12 @@ def test_get_refresh_token_or_none(dbsession: OrmSession): assert refresh_token is None -def test_create_refresh_token(dbsession: OrmSession, user: User): +def test_create_refresh_token(dbsession: OrmSession, account: Account): """Test that create_refresh_token creates a refresh token""" - refresh_token = create_refresh_token(dbsession, user.id) + refresh_token = create_refresh_token(dbsession, account.id) assert refresh_token is not None assert refresh_token.token is not None - assert refresh_token.user_id == user.id + assert refresh_token.account_id == account.id assert refresh_token.expire_time is not None diff --git a/backend/tests/db/test_user.py b/backend/tests/db/test_user.py deleted file mode 100644 index eb6ab85f..00000000 --- a/backend/tests/db/test_user.py +++ /dev/null @@ -1,90 +0,0 @@ -from uuid import uuid4 - -import pytest -from sqlalchemy.orm import Session as OrmSession - -from cms_backend.db.exceptions import ( - RecordDoesNotExistError, -) -from cms_backend.db.models import User -from cms_backend.db.user import ( - delete_user, - get_user_by_id, - get_user_by_id_or_none, - get_user_by_username, - get_user_by_username_or_none, -) -from cms_backend.roles import merge_scopes - - -@pytest.mark.parametrize( - ["custom_scope", "all_scopes", "expected"], - [ - ( - {"book": {"read": True}}, - { - "book": {"read": True, "write": True}, - "title": {"read": True, "write": True}, - }, - { - "book": {"read": True, "write": False}, - "title": {"read": False, "write": False}, - }, - ), - ( - {}, - { - "book": {"read": True, "write": True}, - "title": {"read": True, "write": True}, - }, - { - "book": {"read": False, "write": False}, - "title": {"read": False, "write": False}, - }, - ), - ], -) -def test_merge_scopes( - custom_scope: dict[str, dict[str, bool]], - all_scopes: dict[str, dict[str, bool]], - expected: dict[str, dict[str, bool]], -): - assert merge_scopes(custom_scope, all_scopes) == expected - - -def test_get_user_by_id_or_none(dbsession: OrmSession): - """Test that get_user_by_id_or_none returns None if the user does not exist""" - user = get_user_by_id_or_none(dbsession, user_id=uuid4()) - assert user is None - - -def test_get_user_by_id_not_found(dbsession: OrmSession): - """Test that get_user_by_id raises an exception if the user does not exist""" - with pytest.raises(RecordDoesNotExistError): - get_user_by_id(dbsession, user_id=uuid4()) - - -def test_get_user_by_id(dbsession: OrmSession, user: User): - """Test that get_user_by_id returns the user if the user exists""" - db_user = get_user_by_id(dbsession, user_id=user.id) - assert db_user is not None - assert db_user.id == user.id - - -def test_get_user_by_username_or_none(dbsession: OrmSession): - """Test that get_user_by_username_or_none returns None if the user does not exist""" - user = get_user_by_username_or_none(dbsession, username="doesnotexist") - assert user is None - - -def test_get_user_by_username_not_found(dbsession: OrmSession): - """Test that get_user_by_username raises an exception if the user does not exist""" - with pytest.raises(RecordDoesNotExistError): - get_user_by_username(dbsession, username="doesnotexist") - - -def test_delete_user(dbsession: OrmSession, user: User): - """Test that delete_user marks user as deleted""" - delete_user(dbsession, user_id=user.id) - dbsession.refresh(user) - assert user.deleted diff --git a/backend/tests/db/test_zimfarm_notification.py b/backend/tests/db/test_zimfarm_notification.py index 8d1dd8e0..b278d5f0 100644 --- a/backend/tests/db/test_zimfarm_notification.py +++ b/backend/tests/db/test_zimfarm_notification.py @@ -39,7 +39,7 @@ def test_get_zimfarm_notification_not_found( dbsession: OrmSession, zimfarm_notification: ZimfarmNotification, # noqa: ARG001 - needed for conftest ): - """Raises an exception if the user does not exist""" + """Raises an exception if the notification does not exist""" with pytest.raises(RecordDoesNotExistError): get_zimfarm_notification(dbsession, notification_id=uuid4()) diff --git a/dev/README.md b/dev/README.md index 3dadfee3..39391825 100644 --- a/dev/README.md +++ b/dev/README.md @@ -59,11 +59,13 @@ docker exec cms_shuttle python /scripts/setup_warehouses.py ``` This script will: + - Create warehouse directories in `dev/warehouses/` - Create corresponding database records (Warehouse) - Print the LOCAL_WAREHOUSE_PATHS configuration (already configured in docker-compose.yml) Current warehouse configuration: + - **hidden**: 2 paths (`quarantine`, `staging`) - **prod**: 1 path (`other`, `wikipedia`) - **client1**: 1 path (`all`) @@ -83,6 +85,7 @@ Currently two collections are configured: **prod** (associated with **prod** war To modify collections configuration, edit the `COLLECTIONS_CONFIG` list in [scripts/setup_collections.py](scripts/setup_collections.py) and re-run the script. Once created, collection catalogs are accessible at: + - `http://localhost:37601/v1/collections/prod/catalog.xml` or `http://localhost:37601/v1/collections/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/catalog.xml` - `http://localhost:37601/v1/collections/client1/catalog.xml` or `http://localhost:37601/v1/collections/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/catalog.xml` @@ -105,6 +108,7 @@ docker exec cms_shuttle python /scripts/setup_notifications.py ``` This script will: + - Create ZimfarmNotification records with status "pending" - Create "fake" ZIMs in quarantine folder and subfolders @@ -145,13 +149,13 @@ docker exec -it cms_backend-tests python -m pytest You can select one specific set of tests by path. ```sh -docker exec -it cms_backend-tests python -m pytest tests/routes/test_user.py +docker exec -it cms_backend-tests python -m pytest tests/routes/test_account.py ``` Or just one specific test function. ```sh -docker exec -it cms_backend-tests python -m pytest tests/routes/test_user.py -k test_list_users_no_auth +docker exec -it cms_backend-tests python -m pytest tests/routes/test_account.py -k test_list_accounts_no_auth ``` This is normally not needed, but you might end-up in a situation where test DB gets corrupted. You can recreate test DB. diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index 188afa1c..0eac52bd 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -10,12 +10,12 @@ import { getOAuthConfig } from '@/services/auth/base' import { OAuthSessionProvider } from '@/services/auth/OAuthSessionProvider' import { LocalAuthProvider } from '@/services/auth/LocalAuthProvider' import type { AuthProvider } from '@/services/auth/base' -import type { User } from '@/types/user' +import type { Account } from '@/types/account' export const useAuthStore = defineStore('auth', () => { const errors = ref([]) const token = ref(null) - const user = ref(null) + const account = ref(null) const config = inject(constants.config) @@ -46,11 +46,11 @@ export const useAuthStore = defineStore('auth', () => { // Computed properties const isLoggedIn = computed(() => { - return token.value !== null && user.value !== null + return token.value !== null && account.value !== null }) const username = computed(() => { - return user.value?.username || null + return account.value?.username || null }) const accessToken = computed(() => { @@ -58,7 +58,7 @@ export const useAuthStore = defineStore('auth', () => { }) const permissions = computed(() => { - return user.value?.scope || {} + return account.value?.scope || {} }) const refreshToken = computed(() => { @@ -77,7 +77,7 @@ export const useAuthStore = defineStore('auth', () => { const hasPermission = (resource: string, action: string) => { if (!token.value) return false - return user.value?.scope[resource]?.[action] || false + return account.value?.scope[resource]?.[action] || false } const tokenType = computed(() => { @@ -125,7 +125,7 @@ export const useAuthStore = defineStore('auth', () => { throw new Error('Invalid authentication token') } token.value = newToken - await fetchUserInfo(newToken.access_token) + await fetchAccountInfo(newToken.access_token) errors.value = [] provider.saveToken(newToken) @@ -135,7 +135,7 @@ export const useAuthStore = defineStore('auth', () => { return true } catch (err: unknown) { token.value = null - user.value = null + account.value = null errors.value = translateErrors(err as ErrorResponse) return false } @@ -156,7 +156,7 @@ export const useAuthStore = defineStore('auth', () => { }) } - const fetchUserInfo = async (accessToken: string) => { + const fetchAccountInfo = async (accessToken: string) => { try { const apiService = httpRequest({ baseURL: `${config.CMS_API}/auth`, @@ -164,12 +164,12 @@ export const useAuthStore = defineStore('auth', () => { Authorization: `Bearer ${accessToken}`, }, }) - const response = (await apiService.get('/me')) as User - user.value = response + const response = (await apiService.get('/me')) as Account + account.value = response errors.value = [] } catch (error) { - console.error('Failed to fetch user info:', error) - user.value = null + console.error('Failed to fetch account info:', error) + account.value = null errors.value = translateErrors(error as ErrorResponse) } } @@ -220,8 +220,8 @@ export const useAuthStore = defineStore('auth', () => { await logout() return null } - if (!user.value) { - await fetchUserInfo(storedToken.access_token) + if (!account.value) { + await fetchAccountInfo(storedToken.access_token) } token.value = storedToken return storedToken @@ -253,7 +253,7 @@ export const useAuthStore = defineStore('auth', () => { if (!newToken) { throw new Error('Unable to refresh token') } - await fetchUserInfo(newToken.access_token) + await fetchAccountInfo(newToken.access_token) isRefreshFailed.value = false return newToken } catch (error) { @@ -266,7 +266,7 @@ export const useAuthStore = defineStore('auth', () => { } token.value = null - user.value = null + account.value = null errors.value = translateErrors(error as ErrorResponse) return null } finally { @@ -287,7 +287,7 @@ export const useAuthStore = defineStore('auth', () => { } token.value = null - user.value = null + account.value = null // Reset refresh failure state on logout isRefreshFailed.value = false @@ -300,8 +300,8 @@ export const useAuthStore = defineStore('auth', () => { const newToken = await provider.onCallback(callbackUrl) token.value = newToken - // Fetch user info from backend using the Kiwix token - await fetchUserInfo(newToken.access_token) + // Fetch account info from backend using the Kiwix token + await fetchAccountInfo(newToken.access_token) errors.value = [] provider.saveToken(newToken) @@ -312,7 +312,7 @@ export const useAuthStore = defineStore('auth', () => { return true } catch (err: unknown) { token.value = null - user.value = null + account.value = null errors.value = translateErrors(err as ErrorResponse) return false } @@ -322,7 +322,7 @@ export const useAuthStore = defineStore('auth', () => { // State errors, token, - user, + account, isRefreshFailed, refreshPromise, permissions, @@ -338,7 +338,7 @@ export const useAuthStore = defineStore('auth', () => { // Methods loadToken, - fetchUserInfo, + fetchAccountInfo, renewToken, authenticate, logout, diff --git a/frontend/src/types/account.ts b/frontend/src/types/account.ts new file mode 100644 index 00000000..7d1fcf0a --- /dev/null +++ b/frontend/src/types/account.ts @@ -0,0 +1,4 @@ +export interface Account { + username: string + scope: Record> +} diff --git a/frontend/src/types/user.ts b/frontend/src/types/user.ts index b62b994a..7d1fcf0a 100644 --- a/frontend/src/types/user.ts +++ b/frontend/src/types/user.ts @@ -1,4 +1,4 @@ -export interface User { +export interface Account { username: string scope: Record> } diff --git a/frontend/src/views/OAuthCallbackView.vue b/frontend/src/views/OAuthCallbackView.vue index ca94b7d4..07ae10dc 100644 --- a/frontend/src/views/OAuthCallbackView.vue +++ b/frontend/src/views/OAuthCallbackView.vue @@ -3,7 +3,7 @@ - Receives authorization code from Kiwix auth - Validates state parameter for CSRF protection - Exchanges code for token using PKCE verifier - - Fetches user info from Kiwix auth + - Fetches account info from Kiwix auth - Redirects to appropriate page after successful authentication -->