144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
|
|
import redis.asyncio as aioredis
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from cpv3.db.base import utcnow
|
|
from cpv3.infrastructure.settings import get_settings
|
|
from cpv3.modules.jobs.models import Job
|
|
from cpv3.modules.notifications.repository import NotificationRepository
|
|
from cpv3.modules.notifications.schemas import (
|
|
NotificationCreate,
|
|
NotificationTypeEnum,
|
|
WebSocketMessage,
|
|
)
|
|
from cpv3.modules.tasks.schemas import TaskWebhookEvent
|
|
|
|
logger = logging.getLogger(__name__)


# Human-readable (Russian) titles per job type, used as the title of stored
# notifications and of WebSocket task updates.
JOB_TYPE_LABELS: dict[str, str] = dict(
    MEDIA_PROBE="Анализ медиа",
    SILENCE_REMOVE="Удаление тишины",
    SILENCE_DETECT="Обнаружение тишины",
    SILENCE_APPLY="Применение вырезок",
    MEDIA_CONVERT="Конвертация",
    TRANSCRIPTION_GENERATE="Транскрипция",
    CAPTIONS_GENERATE="Генерация субтитров",
)
|
|
|
|
# ---------------------------------------------------------------------------
# Redis pub/sub plumbing for WebSocket notifications (lazy module-level client)
# ---------------------------------------------------------------------------

# Shared Redis client; stays None until ``_get_redis`` is first awaited.
_redis_client: aioredis.Redis | None = None


async def _get_redis() -> aioredis.Redis:
    """Return the module-wide Redis client, creating it on the first call.

    The client is built from ``get_settings().redis_url`` with
    ``decode_responses=True`` and cached in ``_redis_client`` so later calls
    reuse it.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    redis_url = get_settings().redis_url
    _redis_client = aioredis.from_url(redis_url, decode_responses=True)
    return _redis_client
|
|
|
|
|
|
def _channel_name(user_id: uuid.UUID) -> str:
|
|
return f"notifications:{user_id}"
|
|
|
|
|
|
async def publish_to_user(user_id: uuid.UUID, message: WebSocketMessage) -> None:
    """Serialize *message* to JSON and publish it on *user_id*'s Redis channel."""
    client = await _get_redis()
    body = message.model_dump_json()
    channel = _channel_name(user_id)
    await client.publish(channel, body)
|
|
|
|
|
|
async def subscribe_and_forward(websocket: object, user_id: uuid.UUID) -> None:
    """Subscribe to a user's Redis channel and forward messages to WebSocket.

    Iterates ``pubsub.listen()`` until it stops yielding or an exception
    propagates (for example from ``send_text``). Only entries whose ``type``
    is ``"message"`` are forwarded; other entries (such as the subscribe
    confirmation) are skipped.

    ``websocket`` must be a ``fastapi.WebSocket`` instance — typed as ``object``
    to avoid importing FastAPI at module level.

    The pub/sub object is always closed on exit, even when subscribing or
    unsubscribing fails, so its Redis connection is not leaked.
    """
    from fastapi import WebSocket as _WS

    ws: _WS = websocket  # type: ignore[assignment]
    channel = _channel_name(user_id)
    redis = await _get_redis()
    pubsub = redis.pubsub()

    try:
        # Subscribe inside the ``try`` so the pub/sub is closed even if the
        # subscribe call itself fails.
        await pubsub.subscribe(channel)
        async for raw_message in pubsub.listen():
            if raw_message["type"] != "message":
                continue
            await ws.send_text(raw_message["data"])
    finally:
        try:
            await pubsub.unsubscribe(channel)
        except Exception:
            # Best-effort cleanup: a failing unsubscribe must neither mask the
            # exception that ended the loop nor prevent ``aclose()`` below.
            logger.warning("Failed to unsubscribe from %s", channel, exc_info=True)
        finally:
            await pubsub.aclose()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# NotificationService
# ---------------------------------------------------------------------------


class NotificationService:
    """Turn task webhook events into stored notifications and WebSocket pushes."""

    def __init__(self, session: AsyncSession) -> None:
        """Bind the service to *session* through a ``NotificationRepository``."""
        self._repo = NotificationRepository(session)

    async def create_task_notification(
        self, *, user_id: uuid.UUID, job: Job, event: TaskWebhookEvent
    ) -> None:
        """Create a notification for a task status change and push via WebSocket.

        Events whose ``status`` is ``"RUNNING"``, ``"DONE"`` or ``"FAILED"`` are
        persisted as ``task_progress``, ``task_complete`` or ``task_failed``
        notifications respectively. Every event, persisted or not, is then
        published to the user's channel as a ``"task_update"`` WebSocket
        message, with fields the event leaves empty filled in from *job*.

        Args:
            user_id: Recipient of the notification.
            job: Job the event refers to.
            event: Incoming task webhook event.

        Errors from the repository propagate to the caller. Errors while
        publishing the WebSocket message are logged and swallowed, so a failed
        push does not fail the caller.
        """
        notification_type: NotificationTypeEnum | None = None
        if event.status == "RUNNING":
            notification_type = "task_progress"
        elif event.status == "DONE":
            notification_type = "task_complete"
        elif event.status == "FAILED":
            notification_type = "task_failed"

        # Unknown job types fall back to the raw job_type string as the title.
        job_type_label = JOB_TYPE_LABELS.get(job.job_type, job.job_type)
        now = utcnow()

        # Persist only events carrying one of the statuses above. Anything else
        # (e.g. a progress-only update without a status) is pushed, not stored.
        notification_id: uuid.UUID | None = None
        if notification_type is not None:
            notification = await self._repo.create(
                NotificationCreate(
                    user_id=user_id,
                    job_id=job.id,
                    project_id=job.project_id,
                    notification_type=notification_type,
                    title=job_type_label,
                    message=event.error_message or event.current_message,
                    payload={
                        "job_type": job.job_type,
                        "progress_pct": event.progress_pct,
                        "status": event.status,
                    },
                )
            )
            notification_id = notification.id

        # Prefer values from the event; fall back to the job's last known state.
        # NOTE(review): assumes ``Job`` exposes ``progress_pct`` (the event field
        # of the same name suggests so) — confirm against the model.
        status = event.status if event.status is not None else job.status
        progress_pct = (
            event.progress_pct if event.progress_pct is not None else job.progress_pct
        )
        message = event.error_message or event.current_message or job.current_message

        # Always push a WebSocket message, including for progress-only updates.
        ws_message = WebSocketMessage(
            event="task_update",
            notification_id=notification_id,
            job_id=job.id,
            project_id=job.project_id,
            job_type=job.job_type,
            status=status,
            progress_pct=progress_pct,
            message=message,
            title=job_type_label,
            created_at=now,
        )

        try:
            await publish_to_user(user_id, ws_message)
        except Exception:
            logger.exception("Failed to publish WebSocket notification for user %s", user_id)
|