from __future__ import annotations

import uuid

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    Query,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from jwt import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession

from cpv3.db.session import SessionLocal, get_db
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.security import decode_token
from cpv3.modules.notifications.repository import NotificationRepository
from cpv3.modules.notifications.schemas import NotificationRead
from cpv3.modules.notifications.service import subscribe_and_forward
from cpv3.modules.users.models import User
from cpv3.modules.users.repository import UserRepository

router = APIRouter(prefix="/api/notifications", tags=["notifications"])


async def _authenticate_ws_token(token: str) -> uuid.UUID | None:
    """Resolve a WebSocket access token to the id of an active user.

    Args:
        token: Raw token string taken from the ``token`` query parameter.

    Returns:
        The user's UUID when the token decodes, is of type ``"access"``,
        carries a UUID ``sub`` claim, and that user exists and is active;
        otherwise ``None`` (the caller should reject the connection).
    """
    try:
        payload = decode_token(token)
    except (ExpiredSignatureError, InvalidTokenError):
        return None

    # Only access tokens may open a socket (e.g. not refresh tokens).
    if payload.get("type") != "access":
        return None

    sub = payload.get("sub")
    if not sub:
        return None

    try:
        user_id = uuid.UUID(str(sub))
    except ValueError:
        return None

    # The session is scoped to this lookup only, so it is closed before the
    # (potentially long-lived) subscription starts.
    async with SessionLocal() as db:
        user = await UserRepository(db).get_by_id(user_id)
        # Read ``is_active`` while the session is still open.
        if user is None or not user.is_active:
            return None

    return user_id


@router.websocket("/ws/")
async def notifications_ws(
    websocket: WebSocket,
    token: str = Query(...),
) -> None:
    """WebSocket endpoint for real-time notifications.

    The access token is passed as the ``token`` query parameter. If it does
    not resolve to an active user, the connection is closed with code 1008
    (policy violation) without being accepted. Otherwise the socket is
    accepted and handed to ``subscribe_and_forward``. A
    ``WebSocketDisconnect`` raised from there ends the handler quietly.
    """
    user_id = await _authenticate_ws_token(token)
    if user_id is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    try:
        await subscribe_and_forward(websocket, user_id)
    except WebSocketDisconnect:
        # A client going away is a normal end of the stream, not an error.
        pass


@router.get("/", response_model=list[NotificationRead])
async def list_notifications(
    unread_only: bool = Query(False),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[NotificationRead]:
    """List the current user's notifications, optionally only unread ones."""
    repo = NotificationRepository(db)
    items = await repo.list_for_user(current_user.id, unread_only=unread_only)
    return [NotificationRead.model_validate(n) for n in items]


@router.post("/{notification_id}/read/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_notification_read(
    notification_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark one of the current user's notifications as read.

    Raises:
        HTTPException: 404 when the repository reports no matching
            notification for this user.
    """
    repo = NotificationRepository(db)
    found = await repo.mark_read(notification_id, current_user.id)
    if not found:
        # Detail text is Russian for "Not found".
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Не найдено")


@router.post("/read-all/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_all_read(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark all of the current user's notifications as read."""
    repo = NotificationRepository(db)
    await repo.mark_all_read(current_user.id)


@router.get("/unread-count/")
async def unread_count(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, int]:
    """Return the current user's unread notification count as ``{"count": n}``."""
    repo = NotificationRepository(db)
    count = await repo.count_unread(current_user.id)
    return {"count": count}