new features
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
|
||||
from jwt import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.db.session import SessionLocal, get_db
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.infrastructure.security import decode_token
|
||||
from cpv3.modules.notifications.repository import NotificationRepository
|
||||
from cpv3.modules.notifications.schemas import NotificationRead
|
||||
from cpv3.modules.notifications.service import subscribe_and_forward
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.users.repository import UserRepository
|
||||
|
||||
router = APIRouter(prefix="/api/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@router.websocket("/ws/")
|
||||
async def notifications_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
) -> None:
|
||||
"""WebSocket endpoint for real-time notifications."""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except (ExpiredSignatureError, InvalidTokenError):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if payload.get("type") != "access":
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
try:
|
||||
user_id = uuid.UUID(str(sub))
|
||||
except ValueError:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
async with SessionLocal() as db:
|
||||
user_repo = UserRepository(db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if user is None or not user.is_active:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
await subscribe_and_forward(websocket, user_id)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/", response_model=list[NotificationRead])
|
||||
async def list_notifications(
|
||||
unread_only: bool = Query(False),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[NotificationRead]:
|
||||
repo = NotificationRepository(db)
|
||||
items = await repo.list_for_user(current_user.id, unread_only=unread_only)
|
||||
return [NotificationRead.model_validate(n) for n in items]
|
||||
|
||||
|
||||
@router.post("/{notification_id}/read/", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def mark_notification_read(
|
||||
notification_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> None:
|
||||
repo = NotificationRepository(db)
|
||||
found = await repo.mark_read(notification_id, current_user.id)
|
||||
if not found:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
|
||||
|
||||
@router.post("/read-all/", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def mark_all_read(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> None:
|
||||
repo = NotificationRepository(db)
|
||||
await repo.mark_all_read(current_user.id)
|
||||
|
||||
|
||||
@router.get("/unread-count/")
|
||||
async def unread_count(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict[str, int]:
|
||||
repo = NotificationRepository(db)
|
||||
count = await repo.count_unread(current_user.id)
|
||||
return {"count": count}
|
||||
Reference in New Issue
Block a user