181 lines
6.2 KiB
Python
181 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import timedelta
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
|
from jwt import ExpiredSignatureError, InvalidTokenError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from cpv3.infrastructure.auth import get_current_user
|
|
from cpv3.infrastructure.security import create_token, decode_token
|
|
from cpv3.infrastructure.settings import get_settings
|
|
from cpv3.db.session import get_db
|
|
from cpv3.modules.users.models import User
|
|
from cpv3.modules.users.schemas import (
|
|
TokenRefresh,
|
|
TokenRefreshResponse,
|
|
UserCreate,
|
|
UserLogin,
|
|
UserRead,
|
|
UserRegister,
|
|
UserRegisterResponse,
|
|
UserUpdate,
|
|
)
|
|
from cpv3.modules.users.service import UserService
|
|
|
|
# Router for the user-management endpoints, mounted under /api/users.
users_router = APIRouter(tags=["Users"], prefix="/api/users")
# Router for the authentication endpoints (register / login / refresh), mounted under /auth.
auth_router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
|
|
def _issue_tokens(user: User) -> tuple[str, str]:
    """Mint a fresh ``(access, refresh)`` JWT pair for ``user``.

    Both tokens use ``str(user.id)`` as their subject. Only the access token
    carries the ``is_staff`` / ``is_superuser`` flags as extra claims. Token
    lifetimes come from the settings fields ``jwt_access_ttl_minutes`` (access)
    and ``jwt_refresh_ttl_days`` (refresh).

    Returns:
        A tuple ``(access_token, refresh_token)``.
    """
    cfg = get_settings()
    subject = str(user.id)
    privilege_claims = {"is_staff": user.is_staff, "is_superuser": user.is_superuser}

    access_token = create_token(
        subject=subject,
        token_type="access",
        expires_in=timedelta(minutes=cfg.jwt_access_ttl_minutes),
        extra=privilege_claims,
    )
    # The refresh token intentionally carries no privilege claims.
    refresh_token = create_token(
        subject=subject,
        token_type="refresh",
        expires_in=timedelta(days=cfg.jwt_refresh_ttl_days),
    )
    return access_token, refresh_token
|
|
|
|
|
|
@users_router.get("/", response_model=list[UserRead])
async def list_all_users(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[UserRead]:
    """Return the users the authenticated requester is allowed to see.

    Which users are included is decided by ``UserService.list_users``,
    which receives the requester.
    """
    visible = await UserService(db).list_users(requester=current_user)
    return list(map(UserRead.model_validate, visible))
|
|
|
|
|
|
@users_router.post("/", response_model=UserRead, status_code=status.HTTP_201_CREATED)
async def create_user_endpoint(
    body: UserCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> UserRead:
    """Create a user on behalf of the authenticated requester.

    The requester is forwarded to ``UserService.create_user`` so the service
    can apply its own permission rules.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; the
            error message becomes the response detail.
    """
    users = UserService(db)
    try:
        created = await users.create_user(body, requester=current_user)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    else:
        return UserRead.model_validate(created)
|
|
|
|
|
|
# NOTE: keep this route registered before ``/{user_id}/``. FastAPI matches
# paths in declaration order, so "me" must not be captured as a user id.
@users_router.get("/me/", response_model=UserRead)
async def me(current_user: User = Depends(get_current_user)) -> UserRead:
    """Return the profile of the currently authenticated user."""
    profile = UserRead.model_validate(current_user)
    return profile
|
|
|
|
|
|
@users_router.get("/{user_id}/", response_model=UserRead)
async def retrieve_user(
    user_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> UserRead:
    """Fetch a single user by id.

    Staff may read any user; everyone else may only read their own record.

    Raises:
        HTTPException: 404 if no user has ``user_id``; 403 if the requester
            is neither staff nor the target user.
    """
    target = await UserService(db).get_user_by_id(user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")

    # Short-circuits on staff, so ids are only compared for non-staff requesters.
    if not (current_user.is_staff or target.id == current_user.id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    return UserRead.model_validate(target)
|
|
|
|
|
|
@users_router.patch("/{user_id}/", response_model=UserRead)
async def patch_user(
    user_id: uuid.UUID,
    body: UserUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> UserRead:
    """Apply a partial update to a user.

    Staff may update any user; everyone else may only update their own record.
    The actual field changes are performed by ``UserService.update_user``.

    Raises:
        HTTPException: 404 if no user has ``user_id``; 403 if the requester
            is neither staff nor the target user; 400 when the service raises
            ``ValueError`` (its message becomes the detail).
    """
    svc = UserService(db)

    target = await svc.get_user_by_id(user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")

    if not (current_user.is_staff or target.id == current_user.id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    try:
        updated = await svc.update_user(target, body)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    else:
        return UserRead.model_validate(updated)
|
|
|
|
|
|
@users_router.delete("/{user_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
    user_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Handle DELETE for a user by calling ``UserService.deactivate_user``.

    Staff may deactivate any user; everyone else may only deactivate
    themselves. On success an empty 204 response is returned.

    Raises:
        HTTPException: 404 if no user has ``user_id``; 403 if the requester
            is neither staff nor the target user.
    """
    svc = UserService(db)

    target = await svc.get_user_by_id(user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")

    if not (current_user.is_staff or target.id == current_user.id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")

    await svc.deactivate_user(target)
    # Explicit empty Response so no body is serialized for the 204.
    return Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
|
|
|
|
@auth_router.post(
    "/register", response_model=UserRegisterResponse, status_code=status.HTTP_201_CREATED
)
async def register(body: UserRegister, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
    """Register a new account and return it together with a fresh token pair.

    Raises:
        HTTPException: 400 when ``UserService.register_user`` raises
            ``ValueError``; the error message becomes the detail.
    """
    svc = UserService(db)
    try:
        new_user = await svc.register_user(body)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc

    access_token, refresh_token = _issue_tokens(new_user)
    return UserRegisterResponse(
        user=UserRead.model_validate(new_user),
        access=access_token,
        refresh=refresh_token,
    )
|
|
|
|
|
|
@auth_router.post("/login", response_model=UserRegisterResponse)
async def login(body: UserLogin, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
    """Authenticate with username/password and return the user plus a token pair.

    Raises:
        HTTPException: 401 ("Invalid credentials") when
            ``UserService.authenticate`` returns ``None``.
    """
    account = await UserService(db).authenticate(body.username, body.password)
    if account is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    access_token, refresh_token = _issue_tokens(account)
    return UserRegisterResponse(
        user=UserRead.model_validate(account),
        access=access_token,
        refresh=refresh_token,
    )
|
|
|
|
|
|
@auth_router.post("/refresh", response_model=TokenRefreshResponse)
async def refresh(
    body: TokenRefresh, db: AsyncSession = Depends(get_db)
) -> TokenRefreshResponse:
    """Exchange a valid refresh token for a new access token.

    The refresh token must:
    - decode successfully;
    - carry ``type == "refresh"``;
    - carry a UUID-parsable ``sub`` claim that still resolves to a user via
      ``UserService.get_user_by_id``.

    The new access token carries the same ``is_staff`` / ``is_superuser``
    claims that ``_issue_tokens`` embeds at login/registration. This keeps
    privilege claims from being silently dropped across a refresh. The
    submitted refresh token is returned unchanged; it is not rotated.

    Raises:
        HTTPException: 401 ("Invalid refresh token") in any of these cases:
            - the token is expired or malformed;
            - the token is of the wrong type;
            - the subject is not a UUID;
            - the subject no longer maps to an active user.
    """
    # Only the decoding/validation steps belong in the try: errors raised
    # later (DB access, token creation) must not be masked as a 401.
    try:
        payload = decode_token(body.refresh)
        if payload.get("type") != "refresh":
            raise InvalidTokenError("wrong type")
        # A missing "sub" becomes "None", which uuid.UUID rejects with ValueError.
        user_id = uuid.UUID(str(payload.get("sub")))
    except (ExpiredSignatureError, InvalidTokenError, ValueError):
        # Expected client error; suppress the JWT/UUID traceback context.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from None

    # Refuse to mint access tokens for users removed or deactivated after the
    # refresh token was issued.
    # NOTE(review): assumes User exposes an ``is_active`` flag cleared by
    # ``UserService.deactivate_user``. getattr() makes this a no-op if the
    # attribute does not exist. Confirm, then read it directly.
    user = await UserService(db).get_user_by_id(user_id)
    if user is None or not getattr(user, "is_active", True):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        )

    settings = get_settings()
    access = create_token(
        subject=str(user.id),
        token_type="access",
        expires_in=timedelta(minutes=settings.jwt_access_ttl_minutes),
        # Same privilege claims as the access token issued by _issue_tokens.
        extra={"is_staff": user.is_staff, "is_superuser": user.is_superuser},
    )
    return TokenRefreshResponse(access=access, refresh=body.refresh)
|