new features

This commit is contained in:
Daniil
2026-02-27 23:33:56 +03:00
parent 937e58859a
commit dc04efe0fb
41 changed files with 2067 additions and 141 deletions
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
import uuid
from sqlalchemy import Boolean, ForeignKey, JSON, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from cpv3.db.base import Base, BaseModelMixin
class Notification(Base, BaseModelMixin):
__tablename__ = "notifications"
user_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), index=True
)
job_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True
)
project_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("projects.id", ondelete="SET NULL"),
nullable=True,
)
notification_type: Mapped[str] = mapped_column(String(32))
title: Mapped[str] = mapped_column(String(255))
message: Mapped[str | None] = mapped_column(Text, nullable=True)
payload: Mapped[dict | None] = mapped_column(JSON, nullable=True)
is_read: Mapped[bool] = mapped_column(Boolean, default=False)
+79
View File
@@ -0,0 +1,79 @@
from __future__ import annotations
import uuid
from sqlalchemy import Select, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.modules.notifications.models import Notification
from cpv3.modules.notifications.schemas import NotificationCreate
class NotificationRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def list_for_user(
self,
user_id: uuid.UUID,
*,
limit: int = 50,
unread_only: bool = False,
) -> list[Notification]:
stmt: Select[tuple[Notification]] = (
select(Notification)
.where(Notification.user_id == user_id)
.where(Notification.is_active.is_(True))
.order_by(Notification.created_at.desc())
.limit(limit)
)
if unread_only:
stmt = stmt.where(Notification.is_read.is_(False))
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def create(self, data: NotificationCreate) -> Notification:
notification = Notification(
user_id=data.user_id,
job_id=data.job_id,
project_id=data.project_id,
notification_type=data.notification_type,
title=data.title,
message=data.message,
payload=data.payload,
)
self._session.add(notification)
await self._session.commit()
await self._session.refresh(notification)
return notification
async def mark_read(self, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
result = await self._session.execute(
update(Notification)
.where(Notification.id == notification_id)
.where(Notification.user_id == user_id)
.values(is_read=True)
)
await self._session.commit()
return result.rowcount > 0 # type: ignore[union-attr]
async def mark_all_read(self, user_id: uuid.UUID) -> int:
result = await self._session.execute(
update(Notification)
.where(Notification.user_id == user_id)
.where(Notification.is_read.is_(False))
.values(is_read=True)
)
await self._session.commit()
return result.rowcount # type: ignore[return-value]
async def count_unread(self, user_id: uuid.UUID) -> int:
result = await self._session.execute(
select(func.count())
.select_from(Notification)
.where(Notification.user_id == user_id)
.where(Notification.is_active.is_(True))
.where(Notification.is_read.is_(False))
)
return result.scalar_one()
+102
View File
@@ -0,0 +1,102 @@
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
from jwt import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.db.session import SessionLocal, get_db
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.security import decode_token
from cpv3.modules.notifications.repository import NotificationRepository
from cpv3.modules.notifications.schemas import NotificationRead
from cpv3.modules.notifications.service import subscribe_and_forward
from cpv3.modules.users.models import User
from cpv3.modules.users.repository import UserRepository
router = APIRouter(prefix="/api/notifications", tags=["notifications"])
@router.websocket("/ws/")
async def notifications_ws(
websocket: WebSocket,
token: str = Query(...),
) -> None:
"""WebSocket endpoint for real-time notifications."""
try:
payload = decode_token(token)
except (ExpiredSignatureError, InvalidTokenError):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
if payload.get("type") != "access":
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
sub = payload.get("sub")
if not sub:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
try:
user_id = uuid.UUID(str(sub))
except ValueError:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
async with SessionLocal() as db:
user_repo = UserRepository(db)
user = await user_repo.get_by_id(user_id)
if user is None or not user.is_active:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await websocket.accept()
try:
await subscribe_and_forward(websocket, user_id)
except WebSocketDisconnect:
pass
@router.get("/", response_model=list[NotificationRead])
async def list_notifications(
unread_only: bool = Query(False),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> list[NotificationRead]:
repo = NotificationRepository(db)
items = await repo.list_for_user(current_user.id, unread_only=unread_only)
return [NotificationRead.model_validate(n) for n in items]
@router.post("/{notification_id}/read/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_notification_read(
notification_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> None:
repo = NotificationRepository(db)
found = await repo.mark_read(notification_id, current_user.id)
if not found:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
@router.post("/read-all/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_all_read(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> None:
repo = NotificationRepository(db)
await repo.mark_all_read(current_user.id)
@router.get("/unread-count/")
async def unread_count(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> dict[str, int]:
repo = NotificationRepository(db)
count = await repo.count_unread(current_user.id)
return {"count": count}
+53
View File
@@ -0,0 +1,53 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from uuid import UUID
from cpv3.common.schemas import Schema
NotificationTypeEnum = Literal["task_progress", "task_complete", "task_failed"]
class NotificationCreate(Schema):
user_id: UUID
job_id: UUID | None = None
project_id: UUID | None = None
notification_type: NotificationTypeEnum
title: str
message: str | None = None
payload: dict | None = None
class NotificationRead(Schema):
id: UUID
user_id: UUID
job_id: UUID | None
project_id: UUID | None
notification_type: NotificationTypeEnum
title: str
message: str | None
payload: dict | None
is_read: bool
created_at: datetime
updated_at: datetime
class NotificationUpdate(Schema):
is_read: bool | None = None
class WebSocketMessage(Schema):
"""JSON shape pushed over WebSocket."""
event: str
notification_id: UUID | None = None
job_id: UUID | None = None
project_id: UUID | None = None
job_type: str | None = None
status: str | None = None
progress_pct: float | None = None
message: str | None = None
title: str | None = None
created_at: datetime | None = None
+152
View File
@@ -0,0 +1,152 @@
from __future__ import annotations
import json
import logging
import uuid
import redis.asyncio as aioredis
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.db.base import utcnow
from cpv3.infrastructure.settings import get_settings
from cpv3.modules.jobs.models import Job
from cpv3.modules.notifications.repository import NotificationRepository
from cpv3.modules.notifications.schemas import (
NotificationCreate,
NotificationTypeEnum,
WebSocketMessage,
)
from cpv3.modules.tasks.schemas import TaskWebhookEvent
logger = logging.getLogger(__name__)
JOB_TYPE_LABELS: dict[str, str] = {
"MEDIA_PROBE": "Анализ медиа",
"SILENCE_REMOVE": "Удаление тишины",
"SILENCE_DETECT": "Обнаружение тишины",
"SILENCE_APPLY": "Применение вырезок",
"MEDIA_CONVERT": "Конвертация",
"TRANSCRIPTION_GENERATE": "Транскрипция",
"CAPTIONS_GENERATE": "Генерация субтитров",
}
STATUS_TITLES: dict[str, str] = {
"RUNNING": "Задача запущена",
"DONE": "Задача завершена",
"FAILED": "Ошибка выполнения",
}
# ---------------------------------------------------------------------------
# ConnectionManager — singleton for WebSocket pub/sub via Redis
# ---------------------------------------------------------------------------
_redis_client: aioredis.Redis | None = None
async def _get_redis() -> aioredis.Redis:
global _redis_client
if _redis_client is None:
settings = get_settings()
_redis_client = aioredis.from_url(settings.redis_url, decode_responses=True)
return _redis_client
def _channel_name(user_id: uuid.UUID) -> str:
return f"notifications:{user_id}"
async def publish_to_user(user_id: uuid.UUID, message: WebSocketMessage) -> None:
"""Publish a notification message to a user's Redis channel."""
redis = await _get_redis()
payload = message.model_dump_json()
await redis.publish(_channel_name(user_id), payload)
async def subscribe_and_forward(websocket: object, user_id: uuid.UUID) -> None:
"""Subscribe to a user's Redis channel and forward messages to WebSocket.
``websocket`` must be a ``fastapi.WebSocket`` instance — typed as ``object``
to avoid importing FastAPI at module level.
"""
from fastapi import WebSocket as _WS
ws: _WS = websocket # type: ignore[assignment]
redis = await _get_redis()
pubsub = redis.pubsub()
await pubsub.subscribe(_channel_name(user_id))
try:
async for raw_message in pubsub.listen():
if raw_message["type"] != "message":
continue
await ws.send_text(raw_message["data"])
finally:
await pubsub.unsubscribe(_channel_name(user_id))
await pubsub.aclose()
# ---------------------------------------------------------------------------
# NotificationService
# ---------------------------------------------------------------------------
class NotificationService:
def __init__(self, session: AsyncSession) -> None:
self._repo = NotificationRepository(session)
async def create_task_notification(
self, *, user_id: uuid.UUID, job: Job, event: TaskWebhookEvent
) -> None:
"""Create a notification for a task status change and push via WebSocket."""
notification_type: NotificationTypeEnum | None = None
if event.status == "RUNNING":
notification_type = "task_progress"
elif event.status == "DONE":
notification_type = "task_complete"
elif event.status == "FAILED":
notification_type = "task_failed"
job_type_label = JOB_TYPE_LABELS.get(job.job_type, job.job_type)
now = utcnow()
# Only persist notifications on status changes (not progress-only updates)
notification_id: uuid.UUID | None = None
if notification_type is not None:
title = STATUS_TITLES.get(event.status or "", job_type_label)
notification = await self._repo.create(
NotificationCreate(
user_id=user_id,
job_id=job.id,
project_id=job.project_id,
notification_type=notification_type,
title=title,
message=event.error_message or event.current_message,
payload={
"job_type": job.job_type,
"progress_pct": event.progress_pct,
"status": event.status,
},
)
)
notification_id = notification.id
# Always push WebSocket message (including progress-only updates)
ws_event = "task_update"
ws_message = WebSocketMessage(
event=ws_event,
notification_id=notification_id,
job_id=job.id,
project_id=job.project_id,
job_type=job.job_type,
status=event.status or job.status,
progress_pct=event.progress_pct or job.project_pct,
message=event.error_message or event.current_message or job.current_message,
title=job_type_label,
created_at=now,
)
try:
await publish_to_user(user_id, ws_message)
except Exception:
logger.exception("Failed to publish WebSocket notification for user %s", user_id)