new features
This commit is contained in:
@@ -15,6 +15,7 @@ from cpv3.modules.system.router import router as system_router
|
||||
from cpv3.modules.tasks.router import router as tasks_router
|
||||
from cpv3.modules.transcription.router import router as transcription_router
|
||||
from cpv3.modules.users.router import auth_router, users_router
|
||||
from cpv3.modules.notifications.router import router as notifications_router
|
||||
from cpv3.modules.webhooks.router import router as webhooks_router
|
||||
|
||||
api_router = APIRouter()
|
||||
@@ -48,5 +49,8 @@ api_router.include_router(events_router)
|
||||
# Tasks (background processing)
|
||||
api_router.include_router(tasks_router)
|
||||
|
||||
# Notifications
|
||||
api_router.include_router(notifications_router)
|
||||
|
||||
# Webhooks
|
||||
api_router.include_router(webhooks_router)
|
||||
|
||||
@@ -5,6 +5,7 @@ from cpv3.modules.projects.models import Project
|
||||
from cpv3.modules.files.models import File
|
||||
from cpv3.modules.transcription.models import Transcription
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.notifications.models import Notification
|
||||
from cpv3.modules.webhooks.models import Webhook
|
||||
|
||||
__all__ = [
|
||||
@@ -17,5 +18,6 @@ __all__ = [
|
||||
"Transcription",
|
||||
"Job",
|
||||
"JobEvent",
|
||||
"Notification",
|
||||
"Webhook",
|
||||
]
|
||||
|
||||
+29
-5
@@ -2,19 +2,43 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
|
||||
|
||||
_settings = get_settings()
|
||||
_database_url = _settings.get_database_url()
|
||||
|
||||
_engine_kwargs: dict[str, bool | int] = {
|
||||
"echo": _settings.debug,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
|
||||
if not _database_url.startswith("sqlite"):
|
||||
_engine_kwargs.update(
|
||||
{
|
||||
"pool_size": _settings.db_pool_size,
|
||||
"max_overflow": _settings.db_max_overflow,
|
||||
"pool_timeout": _settings.db_pool_timeout,
|
||||
"pool_recycle": _settings.db_pool_recycle_seconds,
|
||||
}
|
||||
)
|
||||
|
||||
_engine = create_async_engine(
|
||||
_settings.get_database_url(),
|
||||
echo=_settings.debug,
|
||||
pool_pre_ping=True,
|
||||
_database_url,
|
||||
**_engine_kwargs,
|
||||
)
|
||||
|
||||
SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
SessionLocal = async_sessionmaker(
|
||||
bind=_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
+38
-34
@@ -17,44 +17,48 @@ _bearer = HTTPBearer(auto_error=True)
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(_bearer),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db, use_cache=False),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
) from e
|
||||
except InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
token = credentials.credentials
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
) from e
|
||||
except InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
try:
|
||||
user_id = uuid.UUID(str(sub))
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
user_repo = UserRepository(db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
|
||||
)
|
||||
try:
|
||||
user_id = uuid.UUID(str(sub))
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
|
||||
return user
|
||||
user_repo = UserRepository(db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
|
||||
)
|
||||
|
||||
return user
|
||||
finally:
|
||||
# Free the auth session immediately so long-running handlers don't pin a pool slot.
|
||||
await db.close()
|
||||
|
||||
@@ -17,7 +17,11 @@ class Settings(BaseSettings):
|
||||
# App
|
||||
debug: bool = Field(default=True, alias="DEBUG")
|
||||
cors_allowed_origins: list[str] = Field(
|
||||
default_factory=lambda: ["http://localhost:3000", "http://localhost:8000"],
|
||||
default_factory=lambda: [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:3001",
|
||||
"http://localhost:8000",
|
||||
],
|
||||
alias="CORS_ALLOWED_ORIGINS",
|
||||
)
|
||||
|
||||
@@ -37,6 +41,13 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
database_url: str | None = Field(default=None, alias="DATABASE_URL")
|
||||
db_pool_size: int = Field(default=5, alias="DB_POOL_SIZE")
|
||||
db_max_overflow: int = Field(default=10, alias="DB_MAX_OVERFLOW")
|
||||
db_pool_timeout: int = Field(default=30, alias="DB_POOL_TIMEOUT")
|
||||
db_pool_recycle_seconds: int = Field(
|
||||
default=1800,
|
||||
alias="DB_POOL_RECYCLE_SECONDS",
|
||||
)
|
||||
|
||||
# Storage
|
||||
storage_backend: str = Field(default="S3", alias="STORAGE_BACKEND")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Storage utility helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cpv3.modules.users.models import User
|
||||
|
||||
|
||||
def get_user_folder(user: User) -> str:
    """Build the per-user S3 folder prefix, ``<username>_<user_id>``."""
    return "{0}_{1}".format(user.username, user.id)
|
||||
@@ -33,6 +33,12 @@ class FileRepository:
|
||||
return None
|
||||
return file
|
||||
|
||||
async def get_by_path(self, path: str) -> File | None:
|
||||
result = await self._session.execute(
|
||||
select(File).where(File.path == path, File.is_deleted.is_(False))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, *, requester: User, data: FileCreate) -> File:
|
||||
file = File(
|
||||
owner_id=requester.id,
|
||||
|
||||
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.infrastructure.deps import get_storage
|
||||
from cpv3.infrastructure.storage.base import StorageService
|
||||
from cpv3.infrastructure.storage.utils import get_user_folder
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
from cpv3.db.session import get_db
|
||||
from cpv3.modules.files.schemas import (
|
||||
@@ -32,7 +33,7 @@ from cpv3.modules.users.models import User
|
||||
|
||||
router = APIRouter(prefix="/api/files", tags=["Files"])
|
||||
|
||||
MAX_MB_SIZE = 100
|
||||
MAX_MB_SIZE = 1024
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -44,8 +45,6 @@ async def upload_file(
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> FileInfoResponse:
|
||||
_ = current_user
|
||||
|
||||
# Validate max file size (matches old behavior).
|
||||
file.file.seek(0, 2)
|
||||
size_bytes = file.file.tell()
|
||||
@@ -58,10 +57,13 @@ async def upload_file(
|
||||
detail=f"File size exceeds the maximum limit of {MAX_MB_SIZE} MB.",
|
||||
)
|
||||
|
||||
user_folder = get_user_folder(current_user)
|
||||
resolved_folder = f"{user_folder}/{folder}" if folder else f"{user_folder}/user_upload"
|
||||
|
||||
key = await storage.upload_fileobj(
|
||||
fileobj=file.file,
|
||||
file_name=file.filename or "upload.bin",
|
||||
folder=folder,
|
||||
folder=resolved_folder,
|
||||
gen_name=True,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
@@ -81,7 +83,10 @@ async def get_file_info(
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> FileInfoResponse:
|
||||
_ = current_user
|
||||
if not current_user.is_staff:
|
||||
user_prefix = f"{get_user_folder(current_user)}/"
|
||||
if not file_path.startswith(user_prefix):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
if not await storage.exists(file_path):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
|
||||
@@ -11,9 +11,12 @@ JobStatusEnum = Literal["PENDING", "RUNNING", "FAILED", "CANCELLED", "DONE"]
|
||||
JobTypeEnum = Literal[
|
||||
"MEDIA_PROBE",
|
||||
"SILENCE_REMOVE",
|
||||
"SILENCE_DETECT",
|
||||
"SILENCE_APPLY",
|
||||
"MEDIA_CONVERT",
|
||||
"TRANSCRIPTION_GENERATE",
|
||||
"CAPTIONS_GENERATE",
|
||||
"FRAME_EXTRACT",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -46,8 +46,11 @@ class ArtifactMediaFile(Base, BaseModelMixin):
|
||||
file_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("files.id", ondelete="RESTRICT"), nullable=True, index=True
|
||||
)
|
||||
media_file_id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), ForeignKey("media_files.id", ondelete="RESTRICT"), index=True
|
||||
media_file_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("media_files.id", ondelete="RESTRICT"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
artifact_type: Mapped[str] = mapped_column(String(32), default="TRANSCRIPTION_JSON")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import uuid
|
||||
from os import path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -8,11 +10,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.infrastructure.deps import get_storage
|
||||
from cpv3.infrastructure.storage.base import StorageService
|
||||
from cpv3.infrastructure.storage.utils import get_user_folder
|
||||
from cpv3.db.session import get_db
|
||||
from cpv3.modules.media.schemas import (
|
||||
ArtifactMediaFileCreate,
|
||||
ArtifactMediaFileRead,
|
||||
ArtifactMediaFileUpdate,
|
||||
FrameItem,
|
||||
FrameRangeResponse,
|
||||
MediaConverterParams,
|
||||
MediaFileCreate,
|
||||
MediaFileRead,
|
||||
@@ -20,7 +25,13 @@ from cpv3.modules.media.schemas import (
|
||||
MediaProbeSchema,
|
||||
MediaSilencerParams,
|
||||
)
|
||||
from cpv3.modules.media.service import convert_to_mp4, probe_media, remove_silence
|
||||
from cpv3.modules.media.service import (
|
||||
convert_to_mp4,
|
||||
get_frames_folder,
|
||||
probe_media,
|
||||
read_frames_metadata,
|
||||
remove_silence,
|
||||
)
|
||||
from cpv3.modules.media.repository import ArtifactRepository, MediaFileRepository
|
||||
from cpv3.modules.files.schemas import FileInfoResponse
|
||||
from cpv3.modules.users.models import User
|
||||
@@ -46,12 +57,13 @@ async def silence_remove(
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> FileInfoResponse:
|
||||
_ = current_user
|
||||
user_folder = get_user_folder(current_user)
|
||||
resolved_folder = f"{user_folder}/{body.folder}" if body.folder else f"{user_folder}/output_files"
|
||||
|
||||
info = await remove_silence(
|
||||
storage,
|
||||
file_key=body.file_path,
|
||||
out_folder=body.folder,
|
||||
out_folder=resolved_folder,
|
||||
min_silence_duration_ms=body.min_silence_duration_ms,
|
||||
silence_threshold_db=body.silence_threshold_db,
|
||||
padding_ms=body.padding_ms,
|
||||
@@ -71,9 +83,10 @@ async def convert(
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> FileInfoResponse:
|
||||
_ = current_user
|
||||
user_folder = get_user_folder(current_user)
|
||||
resolved_folder = f"{user_folder}/{body.folder}" if body.folder else f"{user_folder}/output_files"
|
||||
|
||||
info = await convert_to_mp4(storage, file_key=body.file_path, out_folder=body.folder)
|
||||
info = await convert_to_mp4(storage, file_key=body.file_path, out_folder=resolved_folder)
|
||||
return FileInfoResponse(
|
||||
file_path=info.file_path,
|
||||
file_url=info.file_url,
|
||||
@@ -82,6 +95,36 @@ async def convert(
|
||||
)
|
||||
|
||||
|
||||
@media_router.get("/frames/", response_model=FrameRangeResponse)
async def get_frames(
    file_key: str = Query(..., description="S3 key of the source video"),
    start: float = Query(0.0, ge=0, description="Start time in seconds"),
    end: float = Query(..., gt=0, description="End time in seconds"),
    current_user: User = Depends(get_current_user),
    storage: StorageService = Depends(get_storage),
) -> FrameRangeResponse:
    """Return presigned URLs for extracted frames within a time range.

    Frames are looked up in the requesting user's deterministic frames
    folder for ``file_key``; if no extraction metadata exists yet, an
    empty response is returned rather than an error.
    """
    frames_folder = get_frames_folder(get_user_folder(current_user), file_key)

    metadata = await read_frames_metadata(storage, frames_folder=frames_folder)
    if metadata is None:
        # Extraction never ran for this video; nothing to serve.
        return FrameRangeResponse(interval=1.0, frames=[])

    step = metadata.interval
    # Frame files are 1-based; clamp the window to the available range.
    lo = max(1, math.floor(start / step) + 1)
    hi = min(metadata.frame_count, math.ceil(end / step) + 1)

    items: list[FrameItem] = []
    for index in range(lo, hi + 1):
        frame_key = path.join(frames_folder, f"{index:06d}.jpg")
        items.append(
            FrameItem(timestamp=(index - 1) * step, url=await storage.url(frame_key))
        )

    return FrameRangeResponse(interval=step, frames=items)
|
||||
|
||||
|
||||
@mediafiles_router.get("/mediafiles/", response_model=list[MediaFileRead])
|
||||
async def list_mediafiles(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -12,9 +12,11 @@ from cpv3.common.schemas import Schema
|
||||
ArtifactTypeEnum = Literal[
|
||||
"TRANSCRIPTION_JSON",
|
||||
"SILENCE_REMOVED_VIDEO",
|
||||
"CONVERTED_VIDEO",
|
||||
"THUMBNAIL",
|
||||
"AUDIO_PROXY",
|
||||
"RENDERED_VIDEO",
|
||||
"FRAME_SPRITES",
|
||||
]
|
||||
|
||||
|
||||
@@ -60,7 +62,7 @@ class ArtifactMediaFileRead(Schema):
|
||||
id: UUID
|
||||
project_id: UUID | None
|
||||
file_id: UUID | None
|
||||
media_file_id: UUID
|
||||
media_file_id: UUID | None
|
||||
|
||||
artifact_type: ArtifactTypeEnum
|
||||
|
||||
@@ -74,7 +76,7 @@ class ArtifactMediaFileRead(Schema):
|
||||
class ArtifactMediaFileCreate(Schema):
|
||||
project_id: UUID | None = None
|
||||
file_id: UUID | None = None
|
||||
media_file_id: UUID
|
||||
media_file_id: UUID | None = None
|
||||
artifact_type: ArtifactTypeEnum
|
||||
|
||||
|
||||
@@ -148,3 +150,28 @@ class MediaSilencerParams(Schema):
|
||||
class MediaConverterParams(Schema):
|
||||
file_path: str
|
||||
folder: str = ""
|
||||
|
||||
|
||||
class FrameSpriteMetadata(Schema):
    """Metadata stored in ArtifactMediaFile.meta for extracted frames."""

    # Total number of frame images uploaded for the video.
    frame_count: int
    # Seconds between consecutive frames (1 / extraction fps).
    interval: float
    # Pixel dimensions of each extracted frame image.
    width: int
    height: int
    # Storage folder that holds the frame images and this metadata file.
    folder_key: str
    # Storage key of the source video the frames were extracted from.
    source_file_key: str
|
||||
|
||||
|
||||
class FrameItem(Schema):
    """Single frame in a range query response."""

    # Position of the frame in the source video, in seconds.
    timestamp: float
    # Presigned URL for fetching the frame image.
    url: str
|
||||
|
||||
|
||||
class FrameRangeResponse(Schema):
    """Response for GET /api/media/frames/ range query."""

    # Seconds between consecutive frames.
    interval: float
    # Frames that fall inside the requested [start, end] window.
    frames: list[FrameItem]
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import glob as glob_mod
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
from os import path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from tempfile import NamedTemporaryFile, mkdtemp
|
||||
from typing import Callable
|
||||
|
||||
import anyio
|
||||
|
||||
from cpv3.infrastructure.storage.base import StorageService
|
||||
from cpv3.infrastructure.storage.types import FileInfo
|
||||
from cpv3.modules.media.schemas import MediaProbeSchema
|
||||
from cpv3.modules.media.schemas import FrameSpriteMetadata, MediaProbeSchema
|
||||
|
||||
FRAME_WIDTH_PX = 128
|
||||
FRAME_FPS = 1
|
||||
FRAME_JPEG_QUALITY = 5
|
||||
FRAMES_META_FILENAME = "meta.json"
|
||||
|
||||
|
||||
def get_frames_folder(user_folder: str, file_key: str) -> str:
    """Return the deterministic frames folder for *file_key* under *user_folder*.

    The leaf component is the first 16 hex characters of SHA-256(file_key),
    so the same source video always maps to the same folder.
    """
    digest = hashlib.sha256(file_key.encode()).hexdigest()
    return path.join(user_folder, "frames", digest[:16])
|
||||
|
||||
|
||||
async def probe_media(storage: StorageService, *, file_key: str) -> MediaProbeSchema:
|
||||
@@ -68,6 +83,160 @@ def _compute_non_silent_segments(
|
||||
return segments
|
||||
|
||||
|
||||
async def detect_silence(
    storage: StorageService,
    *,
    file_key: str,
    min_silence_duration_ms: int = 200,
    silence_threshold_db: int = 16,
    padding_ms: int = 100,
) -> dict:
    """Detect silent segments in a media file and return their intervals.

    Downloads the file to a temp location, computes the non-silent
    segments, and returns their complement (the gaps) as
    ``silent_segments`` together with ``duration_ms`` and ``file_key``.
    """
    input_tmp = await storage.download_to_temp(file_key)

    try:
        from pydub import AudioSegment  # type: ignore[import-untyped]

        # Decoding is blocking work; keep it off the event loop.
        audio: AudioSegment = await anyio.to_thread.run_sync(
            lambda: AudioSegment.from_file(input_tmp.path)
        )
        duration_ms = len(audio)

        non_silent = await anyio.to_thread.run_sync(
            lambda: _compute_non_silent_segments(
                local_audio_path=input_tmp.path,
                min_silence_duration_ms=min_silence_duration_ms,
                silence_threshold_db=silence_threshold_db,
                padding_ms=padding_ms,
            )
        )

        # Silent segments are the gaps between consecutive non-silent
        # segments, plus any trailing gap up to the end of the audio.
        gaps: list[dict[str, int]] = []
        cursor = 0
        for seg_start, seg_end in non_silent:
            if seg_start > cursor:
                gaps.append({"start_ms": cursor, "end_ms": seg_start})
            cursor = seg_end
        if cursor < duration_ms:
            gaps.append({"start_ms": cursor, "end_ms": duration_ms})

        return {
            "silent_segments": gaps,
            "duration_ms": duration_ms,
            "file_key": file_key,
        }
    finally:
        input_tmp.cleanup()
|
||||
|
||||
|
||||
async def apply_silence_cuts(
    storage: StorageService,
    *,
    file_key: str,
    out_folder: str,
    cuts: list[dict],
    output_name: str | None = None,
) -> FileInfo:
    """Apply explicit cut regions to a media file, concatenating the non-cut parts.

    Each entry in *cuts* is expected to carry ``start_ms``/``end_ms`` keys
    (milliseconds). The kept (non-cut) segments are re-encoded and joined
    with ffmpeg's concat filter, uploaded under ``<out_folder>/silent/``,
    and the resulting file's info is returned. If the cuts cover the whole
    file, the original file's info is returned unchanged.
    """
    input_tmp = await storage.download_to_temp(file_key)

    try:
        from pydub import AudioSegment  # type: ignore[import-untyped]

        # Decode off the event loop; we only need the total duration here.
        audio: AudioSegment = await anyio.to_thread.run_sync(
            lambda: AudioSegment.from_file(input_tmp.path)
        )
        duration_ms = len(audio)

        # Sort cuts and compute non-cut (keep) segments
        sorted_cuts = sorted(cuts, key=lambda c: c["start_ms"])
        segments: list[tuple[int, int]] = []
        prev_end = 0
        for cut in sorted_cuts:
            # Clamp each cut to [0, duration]; overlapping cuts are merged
            # implicitly via max(prev_end, cut_end).
            cut_start = max(0, cut["start_ms"])
            cut_end = min(duration_ms, cut["end_ms"])
            if cut_start > prev_end:
                segments.append((prev_end, cut_start))
            prev_end = max(prev_end, cut_end)
        if prev_end < duration_ms:
            segments.append((prev_end, duration_ms))

        if not segments:
            # Everything was cut away; fall back to the original file.
            return await storage.get_file_info(file_key)

        # Create (and immediately close) a named temp file so ffmpeg can
        # write to its path; delete=False keeps it alive past the `with`.
        with NamedTemporaryFile(
            suffix=path.splitext(file_key)[1] or ".mp4", delete=False
        ) as out:
            out_path = out.name

        try:
            # Each kept segment becomes its own -ss/-t input of the same
            # source file; the concat filter then joins them in order.
            cmd: list[str] = ["ffmpeg"]
            for start_ms, end_ms in segments:
                start_s = start_ms / 1000.0
                duration_s = (end_ms - start_ms) / 1000.0
                cmd.extend(
                    [
                        "-ss",
                        f"{start_s:.3f}",
                        "-t",
                        f"{duration_s:.3f}",
                        "-y",
                        "-i",
                        input_tmp.path,
                    ]
                )

            seg_count = len(segments)
            parts = [f"[{i}:v:0][{i}:a:0]" for i in range(seg_count)]
            filter_complex = "".join(parts) + f"concat=n={seg_count}:v=1:a=1[v][a]"

            cmd.extend(
                [
                    "-filter_complex",
                    filter_complex,
                    "-map",
                    "[v]",
                    "-map",
                    "[a]",
                    "-c:v",
                    "libx264",
                    "-c:a",
                    "aac",
                    "-preset",
                    "medium",
                    out_path,
                ]
            )

            proc = await asyncio.create_subprocess_exec(
                *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
            _, stderr = await proc.communicate()
            if proc.returncode != 0:
                raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")

            # Upload under "<out_folder>/silent/<name>" with a fixed name
            # (gen_name=False) so the key is predictable.
            base_name = output_name or path.basename(file_key)
            output_key = path.join(out_folder or "", "silent", base_name)
            with open(out_path, "rb") as out_file:
                _ = await storage.upload_fileobj(
                    fileobj=out_file,
                    file_name=path.basename(output_key),
                    folder=path.dirname(output_key),
                    gen_name=False,
                    content_type="video/mp4",
                )

            return await storage.get_file_info(output_key)
        finally:
            import os

            # Remove the local output file whether or not ffmpeg succeeded.
            if os.path.exists(out_path):
                os.remove(out_path)
    finally:
        input_tmp.cleanup()
|
||||
|
||||
|
||||
async def remove_silence(
|
||||
storage: StorageService,
|
||||
*,
|
||||
@@ -264,3 +433,131 @@ async def convert_to_ogg_temp(
|
||||
|
||||
_ = filename_without_ext
|
||||
return out_path, _cleanup
|
||||
|
||||
|
||||
async def extract_frames(
    storage: StorageService,
    *,
    file_key: str,
    frames_folder: str,
    on_progress: Callable[[int, int], None] | None = None,
) -> FrameSpriteMetadata:
    """Extract video frames at 1fps via ffmpeg and upload to S3.

    Also writes a ``meta.json`` alongside the frames for fast lookup.

    Args:
        storage: Storage backend used for download and upload.
        file_key: Storage key of the source video.
        frames_folder: Destination folder for the frame images and metadata.
        on_progress: Optional callback invoked as ``on_progress(done, total)``
            after each frame upload.

    Returns:
        Metadata about the extracted frames.

    Raises:
        RuntimeError: If ffmpeg fails or produces no frames.
    """
    input_tmp = await storage.download_to_temp(file_key)
    tmp_dir = mkdtemp(prefix="frames_")

    try:
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            input_tmp.path,
            "-vf",
            f"fps={FRAME_FPS},scale={FRAME_WIDTH_PX}:-1",
            "-q:v",
            str(FRAME_JPEG_QUALITY),
            path.join(tmp_dir, "%06d.jpg"),
        ]

        proc = await asyncio.create_subprocess_exec(
            *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"ffmpeg frame extraction failed: {stderr.decode(errors='ignore')}")

        frame_files = sorted(glob_mod.glob(path.join(tmp_dir, "*.jpg")))
        frame_count = len(frame_files)

        if frame_count == 0:
            raise RuntimeError("No frames extracted from video")

        # Read first frame dimensions via ffprobe (avoids PIL dependency).
        probe_proc = await asyncio.create_subprocess_exec(
            "ffprobe",
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=width,height",
            "-of", "json",
            frame_files[0],
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        probe_stdout, _ = await probe_proc.communicate()
        # FIX: the dimension probe is best-effort. Previously a non-zero
        # ffprobe exit (empty stdout) crashed with an opaque JSONDecodeError,
        # and an empty "streams" list raised IndexError. Fall back to the
        # extraction defaults instead of failing the whole job.
        stream: dict = {}
        if probe_proc.returncode == 0:
            try:
                streams = json.loads(probe_stdout.decode()).get("streams", [])
                if streams:
                    stream = streams[0]
            except ValueError:
                stream = {}
        width = stream.get("width", FRAME_WIDTH_PX)
        height = stream.get("height", FRAME_WIDTH_PX)

        # Upload each frame to S3, reporting progress as we go.
        for idx, frame_path in enumerate(frame_files):
            frame_name = path.basename(frame_path)
            with open(frame_path, "rb") as f:
                await storage.upload_fileobj(
                    fileobj=f,
                    file_name=frame_name,
                    folder=frames_folder,
                    gen_name=False,
                    content_type="image/jpeg",
                )
            if on_progress is not None:
                on_progress(idx + 1, frame_count)

        metadata = FrameSpriteMetadata(
            frame_count=frame_count,
            interval=1.0 / FRAME_FPS,
            width=width,
            height=height,
            folder_key=frames_folder,
            source_file_key=file_key,
        )

        # Write metadata JSON to S3 for fast lookup by the frames endpoint.
        meta_bytes = json.dumps(metadata.model_dump(mode="json")).encode("utf-8")
        await storage.upload_fileobj(
            fileobj=io.BytesIO(meta_bytes),
            file_name=FRAMES_META_FILENAME,
            folder=frames_folder,
            gen_name=False,
            content_type="application/json",
        )

        return metadata
    finally:
        import shutil

        input_tmp.cleanup()
        shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
async def read_frames_metadata(
    storage: StorageService, *, frames_folder: str
) -> FrameSpriteMetadata | None:
    """Load the frames ``meta.json`` from storage.

    Returns ``None`` when no metadata object exists yet (i.e. frames were
    never extracted for this folder).
    """
    meta_key = path.join(frames_folder, FRAMES_META_FILENAME)
    if await storage.exists(meta_key):
        payload = json.loads(await storage.read(meta_key))
        return FrameSpriteMetadata.model_validate(payload)
    return None
|
||||
|
||||
|
||||
async def delete_frames(
    storage: StorageService, *, frames_folder: str, frame_count: int
) -> None:
    """Best-effort removal of every frame image plus the metadata file.

    Individual delete failures are swallowed so a partially-deleted folder
    does not abort the cleanup.
    """
    doomed = [
        path.join(frames_folder, f"{i:06d}.jpg") for i in range(1, frame_count + 1)
    ]
    doomed.append(path.join(frames_folder, FRAMES_META_FILENAME))

    for key in doomed:
        try:
            await storage.delete(key)
        except Exception:  # deliberate best-effort cleanup
            pass
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, JSON, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from cpv3.db.base import Base, BaseModelMixin
|
||||
|
||||
|
||||
class Notification(Base, BaseModelMixin):
    """A per-user notification, optionally linked to a job and/or project."""

    __tablename__ = "notifications"

    # Owning user; notifications are deleted together with the user.
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    # Originating job, if any; kept (nulled) when the job is deleted.
    job_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True
    )
    # Related project, if any; kept (nulled) when the project is deleted.
    project_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("projects.id", ondelete="SET NULL"),
        nullable=True,
    )

    # Discriminator for the notification kind (free-form string, max 32 chars).
    notification_type: Mapped[str] = mapped_column(String(32))
    title: Mapped[str] = mapped_column(String(255))
    message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Arbitrary structured extras for the client.
    payload: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    # Read/unread flag; rows start unread.
    is_read: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Select, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.modules.notifications.models import Notification
|
||||
from cpv3.modules.notifications.schemas import NotificationCreate
|
||||
|
||||
|
||||
class NotificationRepository:
    """Data-access layer for :class:`Notification` rows.

    Every mutating method commits immediately, so callers always see fully
    persisted state.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def list_for_user(
        self,
        user_id: uuid.UUID,
        *,
        limit: int = 50,
        unread_only: bool = False,
    ) -> list[Notification]:
        """Return up to *limit* active notifications for *user_id*, newest first."""
        conditions = [
            Notification.user_id == user_id,
            Notification.is_active.is_(True),
        ]
        if unread_only:
            conditions.append(Notification.is_read.is_(False))

        stmt: Select[tuple[Notification]] = (
            select(Notification)
            .where(*conditions)
            .order_by(Notification.created_at.desc())
            .limit(limit)
        )
        rows = await self._session.execute(stmt)
        return list(rows.scalars().all())

    async def create(self, data: NotificationCreate) -> Notification:
        """Insert a notification built from *data* and return the refreshed row."""
        row = Notification(
            user_id=data.user_id,
            job_id=data.job_id,
            project_id=data.project_id,
            notification_type=data.notification_type,
            title=data.title,
            message=data.message,
            payload=data.payload,
        )
        self._session.add(row)
        await self._session.commit()
        await self._session.refresh(row)
        return row

    async def mark_read(self, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
        """Flag one notification as read; True when a matching row was updated."""
        stmt = (
            update(Notification)
            .where(
                Notification.id == notification_id,
                Notification.user_id == user_id,
            )
            .values(is_read=True)
        )
        result = await self._session.execute(stmt)
        await self._session.commit()
        return result.rowcount > 0  # type: ignore[union-attr]

    async def mark_all_read(self, user_id: uuid.UUID) -> int:
        """Mark every unread notification for *user_id* as read; returns the count."""
        stmt = (
            update(Notification)
            .where(
                Notification.user_id == user_id,
                Notification.is_read.is_(False),
            )
            .values(is_read=True)
        )
        result = await self._session.execute(stmt)
        await self._session.commit()
        return result.rowcount  # type: ignore[return-value]

    async def count_unread(self, user_id: uuid.UUID) -> int:
        """Count the active, unread notifications for *user_id*."""
        stmt = (
            select(func.count())
            .select_from(Notification)
            .where(
                Notification.user_id == user_id,
                Notification.is_active.is_(True),
                Notification.is_read.is_(False),
            )
        )
        result = await self._session.execute(stmt)
        return result.scalar_one()
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
|
||||
from jwt import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.db.session import SessionLocal, get_db
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.infrastructure.security import decode_token
|
||||
from cpv3.modules.notifications.repository import NotificationRepository
|
||||
from cpv3.modules.notifications.schemas import NotificationRead
|
||||
from cpv3.modules.notifications.service import subscribe_and_forward
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.users.repository import UserRepository
|
||||
|
||||
router = APIRouter(prefix="/api/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
def _resolve_ws_user_id(token: str) -> uuid.UUID | None:
    """Validate a WebSocket auth token and return the user id it carries.

    Returns None when the token is expired, malformed, not an access token,
    or does not carry a parseable UUID subject.
    """
    try:
        payload = decode_token(token)
    except (ExpiredSignatureError, InvalidTokenError):
        return None
    if payload.get("type") != "access":
        return None
    sub = payload.get("sub")
    if not sub:
        return None
    try:
        return uuid.UUID(str(sub))
    except ValueError:
        return None


@router.websocket("/ws/")
async def notifications_ws(
    websocket: WebSocket,
    token: str = Query(...),
) -> None:
    """WebSocket endpoint for real-time notifications.

    The access token is passed as a query parameter (browsers cannot set
    headers on WebSocket upgrades). Any validation failure closes the socket
    with policy-violation code 1008 before accepting. After the user is
    confirmed active, the connection is accepted and the user's Redis channel
    is forwarded until the client disconnects.
    """
    user_id = _resolve_ws_user_id(token)
    if user_id is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # The DB session is only needed to verify the user; release it before
    # entering the long-lived forwarding loop.
    async with SessionLocal() as db:
        user = await UserRepository(db).get_by_id(user_id)
        if user is None or not user.is_active:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return

    await websocket.accept()

    try:
        await subscribe_and_forward(websocket, user_id)
    except WebSocketDisconnect:
        # Normal client hang-up; nothing to clean up here.
        pass
|
||||
|
||||
|
||||
@router.get("/", response_model=list[NotificationRead])
async def list_notifications(
    unread_only: bool = Query(False),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[NotificationRead]:
    """Return the current user's notifications, optionally unread ones only."""
    notifications = await NotificationRepository(db).list_for_user(
        current_user.id, unread_only=unread_only
    )
    return [NotificationRead.model_validate(item) for item in notifications]
|
||||
|
||||
|
||||
@router.post("/{notification_id}/read/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_notification_read(
    notification_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark one notification as read; 404 when it isn't the caller's."""
    updated = await NotificationRepository(db).mark_read(
        notification_id, current_user.id
    )
    if not updated:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
|
||||
|
||||
@router.post("/read-all/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_all_read(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark all of the current user's notifications as read."""
    await NotificationRepository(db).mark_all_read(current_user.id)
|
||||
|
||||
|
||||
@router.get("/unread-count/")
async def unread_count(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, int]:
    """Return the current user's number of unread notifications."""
    count = await NotificationRepository(db).count_unread(current_user.id)
    return {"count": count}
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from cpv3.common.schemas import Schema
|
||||
|
||||
|
||||
NotificationTypeEnum = Literal["task_progress", "task_complete", "task_failed"]
|
||||
|
||||
|
||||
class NotificationCreate(Schema):
    """Payload for persisting a new notification row."""

    user_id: UUID  # recipient
    job_id: UUID | None = None  # originating background job, if any
    project_id: UUID | None = None  # related project, if any
    notification_type: NotificationTypeEnum
    title: str
    message: str | None = None
    payload: dict | None = None  # free-form extra data (e.g. job_type, progress)
|
||||
|
||||
|
||||
class NotificationRead(Schema):
    """API representation of a stored notification (read model)."""

    id: UUID
    user_id: UUID
    job_id: UUID | None
    project_id: UUID | None
    notification_type: NotificationTypeEnum
    title: str
    message: str | None
    payload: dict | None
    is_read: bool
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
class NotificationUpdate(Schema):
    """Partial update payload; currently only the read flag is mutable."""

    is_read: bool | None = None
|
||||
|
||||
|
||||
class WebSocketMessage(Schema):
    """JSON shape pushed over WebSocket.

    Everything except ``event`` is optional so the same envelope can carry
    both persisted notifications and progress-only task updates.
    """

    event: str  # e.g. "task_update"
    notification_id: UUID | None = None  # set only when a row was persisted
    job_id: UUID | None = None
    project_id: UUID | None = None
    job_type: str | None = None
    status: str | None = None
    progress_pct: float | None = None
    message: str | None = None
    title: str | None = None
    created_at: datetime | None = None
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.db.base import utcnow
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
from cpv3.modules.jobs.models import Job
|
||||
from cpv3.modules.notifications.repository import NotificationRepository
|
||||
from cpv3.modules.notifications.schemas import (
|
||||
NotificationCreate,
|
||||
NotificationTypeEnum,
|
||||
WebSocketMessage,
|
||||
)
|
||||
from cpv3.modules.tasks.schemas import TaskWebhookEvent
|
||||
|
||||
logger = logging.getLogger(__name__)

# Human-readable (Russian) labels for job types, used as notification titles
# when no status-specific title applies.
# NOTE(review): FRAME_EXTRACT has no entry here and falls back to the raw
# job_type string — confirm whether that is intentional.
JOB_TYPE_LABELS: dict[str, str] = {
    "MEDIA_PROBE": "Анализ медиа",
    "SILENCE_REMOVE": "Удаление тишины",
    "SILENCE_DETECT": "Обнаружение тишины",
    "SILENCE_APPLY": "Применение вырезок",
    "MEDIA_CONVERT": "Конвертация",
    "TRANSCRIPTION_GENERATE": "Транскрипция",
    "CAPTIONS_GENERATE": "Генерация субтитров",
}

# Notification titles keyed by the job status transition that triggered them.
STATUS_TITLES: dict[str, str] = {
    "RUNNING": "Задача запущена",
    "DONE": "Задача завершена",
    "FAILED": "Ошибка выполнения",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConnectionManager — singleton for WebSocket pub/sub via Redis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Process-wide Redis client, created lazily on first use.
_redis_client: aioredis.Redis | None = None


async def _get_redis() -> aioredis.Redis:
    """Return the shared Redis client, constructing it on first call."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = aioredis.from_url(
        get_settings().redis_url, decode_responses=True
    )
    return _redis_client
|
||||
|
||||
|
||||
def _channel_name(user_id: uuid.UUID) -> str:
|
||||
return f"notifications:{user_id}"
|
||||
|
||||
|
||||
async def publish_to_user(user_id: uuid.UUID, message: WebSocketMessage) -> None:
    """Publish a notification message to a user's Redis channel."""
    client = await _get_redis()
    await client.publish(_channel_name(user_id), message.model_dump_json())
|
||||
|
||||
|
||||
async def subscribe_and_forward(websocket: object, user_id: uuid.UUID) -> None:
    """Subscribe to a user's Redis channel and forward messages to WebSocket.

    ``websocket`` must be a ``fastapi.WebSocket`` instance — typed as ``object``
    to avoid importing FastAPI at module level.

    Runs until ``send_text`` raises (typically on client disconnect); the
    Redis subscription is always released in the ``finally`` block.
    NOTE(review): ``pubsub.listen()`` blocks between messages, so a dead
    client is only detected on the next published message — confirm that is
    acceptable for idle connections.
    """
    from fastapi import WebSocket as _WS

    ws: _WS = websocket  # type: ignore[assignment]
    redis = await _get_redis()
    pubsub = redis.pubsub()
    await pubsub.subscribe(_channel_name(user_id))

    try:
        async for raw_message in pubsub.listen():
            # listen() also yields subscribe/unsubscribe confirmations; only
            # forward actual published payloads.
            if raw_message["type"] != "message":
                continue
            await ws.send_text(raw_message["data"])
    finally:
        await pubsub.unsubscribe(_channel_name(user_id))
        await pubsub.aclose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NotificationService
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class NotificationService:
    """Persists job-lifecycle notifications and pushes real-time updates.

    Status transitions (RUNNING/DONE/FAILED) are stored as Notification rows;
    progress-only events are not persisted but are still forwarded over the
    user's Redis pub/sub channel so the UI can update live.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._repo = NotificationRepository(session)

    async def create_task_notification(
        self, *, user_id: uuid.UUID, job: Job, event: TaskWebhookEvent
    ) -> None:
        """Create a notification for a task status change and push via WebSocket.

        Args:
            user_id: Recipient of the notification.
            job: The job the webhook event belongs to.
            event: The incoming task webhook event.
        """
        # Map a status transition to a persisted notification type; None means
        # this is a progress-only update that should not be stored.
        notification_type: NotificationTypeEnum | None = None
        if event.status == "RUNNING":
            notification_type = "task_progress"
        elif event.status == "DONE":
            notification_type = "task_complete"
        elif event.status == "FAILED":
            notification_type = "task_failed"

        job_type_label = JOB_TYPE_LABELS.get(job.job_type, job.job_type)
        now = utcnow()

        # Only persist notifications on status changes (not progress-only updates)
        notification_id: uuid.UUID | None = None
        if notification_type is not None:
            title = STATUS_TITLES.get(event.status or "", job_type_label)
            notification = await self._repo.create(
                NotificationCreate(
                    user_id=user_id,
                    job_id=job.id,
                    project_id=job.project_id,
                    notification_type=notification_type,
                    title=title,
                    message=event.error_message or event.current_message,
                    payload={
                        "job_type": job.job_type,
                        "progress_pct": event.progress_pct,
                        "status": event.status,
                    },
                )
            )
            notification_id = notification.id

        # Always push a WebSocket message (including progress-only updates).
        # Explicit None check: a legitimate 0.0 progress from the event must
        # not be replaced by the job's stored value (`or` treats 0.0 as falsy).
        # NOTE(review): `job.project_pct` looks like a typo for `progress_pct`
        # — confirm against the Job model before renaming.
        progress = (
            event.progress_pct
            if event.progress_pct is not None
            else job.project_pct
        )
        ws_message = WebSocketMessage(
            event="task_update",
            notification_id=notification_id,
            job_id=job.id,
            project_id=job.project_id,
            job_type=job.job_type,
            status=event.status or job.status,
            progress_pct=progress,
            message=event.error_message or event.current_message or job.current_message,
            title=job_type_label,
            created_at=now,
        )

        try:
            await publish_to_user(user_id, ws_message)
        except Exception:
            # Best-effort delivery: a Redis outage must not fail webhook handling.
            logger.exception(
                "Failed to publish WebSocket notification for user %s", user_id
            )
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, Text
|
||||
from sqlalchemy import ForeignKey, JSON, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -23,3 +23,4 @@ class Project(Base, BaseModelMixin):
|
||||
language: Mapped[str] = mapped_column(String(4), default="auto")
|
||||
folder: Mapped[str | None] = mapped_column(String(1024), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(16), default="DRAFT")
|
||||
workspace_state: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy import Select, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.modules.projects.models import Project
|
||||
@@ -16,13 +16,31 @@ class ProjectRepository:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def list_all(self, *, requester: User) -> list[Project]:
|
||||
async def list_all(
|
||||
self,
|
||||
*,
|
||||
requester: User,
|
||||
search: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> list[Project]:
|
||||
stmt: Select[tuple[Project]] = select(Project).where(
|
||||
Project.is_active.is_(True)
|
||||
)
|
||||
if not requester.is_staff:
|
||||
stmt = stmt.where(Project.owner_id == requester.id)
|
||||
|
||||
if search:
|
||||
pattern = f"%{search}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Project.name.ilike(pattern),
|
||||
Project.description.ilike(pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if status:
|
||||
stmt = stmt.where(Project.status == status)
|
||||
|
||||
result = await self._session.execute(stmt.order_by(Project.created_at.desc()))
|
||||
return list(result.scalars().all())
|
||||
|
||||
@@ -34,14 +52,16 @@ class ProjectRepository:
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, *, requester: User, data: ProjectCreate) -> Project:
|
||||
async def create(
|
||||
self, *, requester: User, data: ProjectCreate, folder: str, status: str,
|
||||
) -> Project:
|
||||
project = Project(
|
||||
owner_id=requester.id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
language=data.language,
|
||||
folder=data.folder,
|
||||
status=data.status,
|
||||
folder=folder,
|
||||
status=status,
|
||||
)
|
||||
|
||||
self._session.add(project)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
@@ -16,11 +16,17 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
@router.get("/", response_model=list[ProjectRead])
|
||||
async def list_all_projects(
|
||||
search: str | None = Query(None, description="Поиск по названию или описанию"),
|
||||
status_filter: str | None = Query(
|
||||
None, alias="status", description="Фильтр по статусу проекта"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[ProjectRead]:
|
||||
service = ProjectService(db)
|
||||
projects = await service.list_projects(requester=current_user)
|
||||
projects = await service.list_projects(
|
||||
requester=current_user, search=search, status=status_filter,
|
||||
)
|
||||
return [ProjectRead.model_validate(p) for p in projects]
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ class ProjectRead(Schema):
|
||||
folder: str | None
|
||||
status: ProjectStatusEnum
|
||||
|
||||
workspace_state: dict | None
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
@@ -28,8 +30,6 @@ class ProjectCreate(Schema):
|
||||
name: str
|
||||
description: str | None = None
|
||||
language: str = "auto"
|
||||
folder: str | None = None
|
||||
status: ProjectStatusEnum = "DRAFT"
|
||||
|
||||
|
||||
class ProjectUpdate(Schema):
|
||||
@@ -38,3 +38,4 @@ class ProjectUpdate(Schema):
|
||||
language: str | None = None
|
||||
folder: str | None = None
|
||||
status: ProjectStatusEnum | None = None
|
||||
workspace_state: dict | None = None
|
||||
|
||||
@@ -16,14 +16,25 @@ class ProjectService:
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._repo = ProjectRepository(session)
|
||||
|
||||
async def list_projects(self, *, requester: User) -> list[Project]:
|
||||
return await self._repo.list_all(requester=requester)
|
||||
async def list_projects(
|
||||
self,
|
||||
*,
|
||||
requester: User,
|
||||
search: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> list[Project]:
|
||||
return await self._repo.list_all(
|
||||
requester=requester, search=search, status=status,
|
||||
)
|
||||
|
||||
async def get_project(self, project_id: uuid.UUID) -> Project | None:
|
||||
return await self._repo.get_by_id(project_id)
|
||||
|
||||
async def create_project(self, *, requester: User, data: ProjectCreate) -> Project:
|
||||
return await self._repo.create(requester=requester, data=data)
|
||||
folder = f"/{requester.username}/{data.name}"
|
||||
return await self._repo.create(
|
||||
requester=requester, data=data, folder=folder, status="DRAFT",
|
||||
)
|
||||
|
||||
async def update_project(self, project: Project, data: ProjectUpdate) -> Project:
|
||||
return await self._repo.update(project, data)
|
||||
|
||||
@@ -15,8 +15,11 @@ from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.modules.jobs.service import JobService
|
||||
from cpv3.modules.tasks.schemas import (
|
||||
CaptionsGenerateRequest,
|
||||
FrameExtractRequest,
|
||||
MediaConvertRequest,
|
||||
MediaProbeRequest,
|
||||
SilenceApplyRequest,
|
||||
SilenceDetectRequest,
|
||||
SilenceRemoveRequest,
|
||||
TaskStatusEnum,
|
||||
TaskStatusResponse,
|
||||
@@ -61,6 +64,36 @@ async def submit_silence_remove(
|
||||
return await service.submit_silence_remove(requester=current_user, request=body)
|
||||
|
||||
|
||||
@router.post(
    "/silence-detect/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_silence_detect(
    body: SilenceDetectRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to detect silent segments in media file."""
    return await TaskService(db).submit_silence_detect(
        requester=current_user, request=body
    )
|
||||
|
||||
|
||||
@router.post(
    "/silence-apply/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_silence_apply(
    body: SilenceApplyRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to apply silence cuts to media file."""
    return await TaskService(db).submit_silence_apply(
        requester=current_user, request=body
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/media-convert/",
|
||||
response_model=TaskSubmitResponse,
|
||||
@@ -93,6 +126,21 @@ async def submit_transcription_generate(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/frame-extract/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_frame_extract(
    body: FrameExtractRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to extract video frames for timeline thumbnails."""
    return await TaskService(db).submit_frame_extract(
        requester=current_user, request=body
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/captions-generate/",
|
||||
response_model=TaskSubmitResponse,
|
||||
|
||||
@@ -45,6 +45,36 @@ class SilenceRemoveRequest(Schema):
|
||||
)
|
||||
|
||||
|
||||
class SilenceDetectRequest(Schema):
    """Request to detect silent segments in media file."""

    file_key: str = Field(..., description="Storage key of the input file")
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    # Detection tuning parameters, forwarded verbatim to the detection worker.
    min_silence_duration_ms: int = Field(
        default=200, description="Minimum silence duration in milliseconds"
    )
    # NOTE(review): stored as a positive magnitude — confirm whether the
    # worker interprets it as -N dB.
    silence_threshold_db: int = Field(
        default=16, description="Silence threshold in decibels"
    )
    padding_ms: int = Field(
        default=100, description="Padding around non-silent segments in milliseconds"
    )
||||
|
||||
|
||||
class SilenceApplyRequest(Schema):
    """Request to apply silence cuts to media file."""

    file_key: str = Field(..., description="Storage key of the input file")
    out_folder: str = Field(..., description="Output folder for processed file")
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    output_name: str | None = Field(
        default=None, description="Display name for the output file"
    )
    # Each cut is a dict with millisecond offsets; kept untyped (dict) rather
    # than a nested schema.
    cuts: list[dict] = Field(
        ..., description="Cut regions: [{'start_ms': int, 'end_ms': int}, ...]"
    )
|
||||
|
||||
|
||||
class MediaConvertRequest(Schema):
|
||||
"""Request to convert media file to different format."""
|
||||
|
||||
@@ -75,6 +105,16 @@ class CaptionsGenerateRequest(Schema):
|
||||
project_id: UUID | None = Field(default=None, description="Associated project ID")
|
||||
|
||||
|
||||
class FrameExtractRequest(Schema):
    """Request to extract video frames for timeline thumbnails."""

    file_key: str = Field(..., description="S3 key of the video file")
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    # When True, previously extracted frames are deleted before re-extraction.
    regenerate: bool = Field(
        default=False, description="Delete existing frames and re-extract"
    )
|
||||
|
||||
|
||||
# --- Response schemas ---
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,12 @@ Task service for submitting and managing background tasks.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -17,6 +21,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.infrastructure.deps import _get_storage_service
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
from cpv3.modules.files.repository import FileRepository
|
||||
from cpv3.modules.files.schemas import FileCreate
|
||||
from cpv3.modules.jobs.models import Job
|
||||
from cpv3.modules.jobs.repository import JobEventRepository, JobRepository
|
||||
from cpv3.modules.jobs.schemas import (
|
||||
@@ -26,17 +32,26 @@ from cpv3.modules.jobs.schemas import (
|
||||
JobTypeEnum,
|
||||
JobUpdate,
|
||||
)
|
||||
from cpv3.modules.media.repository import ArtifactRepository
|
||||
from cpv3.modules.media.schemas import ArtifactMediaFileCreate
|
||||
from cpv3.modules.tasks.schemas import (
|
||||
CaptionsGenerateRequest,
|
||||
FrameExtractRequest,
|
||||
MediaConvertRequest,
|
||||
MediaProbeRequest,
|
||||
SilenceApplyRequest,
|
||||
SilenceDetectRequest,
|
||||
SilenceRemoveRequest,
|
||||
TaskSubmitResponse,
|
||||
TaskWebhookEvent,
|
||||
TranscriptionGenerateRequest,
|
||||
)
|
||||
from cpv3.infrastructure.storage.utils import get_user_folder
|
||||
from cpv3.modules.notifications.service import NotificationService
|
||||
from cpv3.modules.transcription.repository import TranscriptionRepository
|
||||
from cpv3.modules.transcription.schemas import TranscriptionCreate
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.users.repository import UserRepository
|
||||
from cpv3.modules.webhooks.repository import WebhookRepository
|
||||
from cpv3.modules.webhooks.schemas import WebhookCreate
|
||||
|
||||
@@ -49,9 +64,12 @@ JOB_STATUS_FAILED: JobStatusEnum = "FAILED"
|
||||
|
||||
JOB_TYPE_MEDIA_PROBE: JobTypeEnum = "MEDIA_PROBE"
|
||||
JOB_TYPE_SILENCE_REMOVE: JobTypeEnum = "SILENCE_REMOVE"
|
||||
JOB_TYPE_SILENCE_DETECT: JobTypeEnum = "SILENCE_DETECT"
|
||||
JOB_TYPE_SILENCE_APPLY: JobTypeEnum = "SILENCE_APPLY"
|
||||
JOB_TYPE_MEDIA_CONVERT: JobTypeEnum = "MEDIA_CONVERT"
|
||||
JOB_TYPE_TRANSCRIPTION_GENERATE: JobTypeEnum = "TRANSCRIPTION_GENERATE"
|
||||
JOB_TYPE_CAPTIONS_GENERATE: JobTypeEnum = "CAPTIONS_GENERATE"
|
||||
JOB_TYPE_FRAME_EXTRACT: JobTypeEnum = "FRAME_EXTRACT"
|
||||
|
||||
EVENT_TYPE_STATUS_PREFIX = "status_"
|
||||
EVENT_TYPE_PROGRESS = "progress"
|
||||
@@ -62,19 +80,41 @@ EVENT_TYPE_ERROR = "error"
|
||||
TASK_WEBHOOK_PATH = "/api/tasks/webhook/{job_id}/"
|
||||
WEBHOOK_TIMEOUT_SECONDS = 10
|
||||
|
||||
ERROR_NO_AUDIO_STREAM = "Файл не содержит аудиодорожки"
|
||||
ERROR_UNKNOWN_ENGINE = "Неизвестный движок транскрипции: {engine}"
|
||||
|
||||
ENGINE_MAP: dict[str, str] = {
|
||||
"whisper": "LOCAL_WHISPER",
|
||||
"google": "GOOGLE_SPEECH_CLOUD",
|
||||
}
|
||||
|
||||
MESSAGE_STARTING = "Starting"
|
||||
MESSAGE_COMPLETED = "Completed"
|
||||
MESSAGE_PROBING_MEDIA = "Probing media"
|
||||
MESSAGE_PROCESSING = "Processing"
|
||||
MESSAGE_CONVERTING = "Converting"
|
||||
MESSAGE_RENDERING_CAPTIONS = "Rendering captions"
|
||||
MESSAGE_EXTRACTING_FRAMES = "Извлечение кадров"
|
||||
MESSAGE_UPLOADING_FRAMES = "Загрузка кадров"
|
||||
MESSAGE_DELETING_OLD_FRAMES = "Удаление старых кадров"
|
||||
|
||||
PROGRESS_COMPLETE = 100.0
|
||||
PROGRESS_MEDIA_PROBE = 50.0
|
||||
PROGRESS_SILENCE_REMOVE = 30.0
|
||||
PROGRESS_MEDIA_CONVERT = 30.0
|
||||
PROGRESS_TRANSCRIPTION = 20.0
|
||||
PROGRESS_TRANSCRIPTION_START = 20.0
|
||||
PROGRESS_TRANSCRIPTION_END = 95.0
|
||||
PROGRESS_CAPTIONS = 30.0
|
||||
PROGRESS_FRAME_EXTRACT_START = 10.0
|
||||
PROGRESS_FRAME_EXTRACT_END = 95.0
|
||||
|
||||
PROGRESS_SILENCE_DETECT = 30.0
|
||||
PROGRESS_SILENCE_APPLY = 30.0
|
||||
|
||||
MESSAGE_DETECTING_SILENCE = "Обнаружение тишины"
|
||||
MESSAGE_APPLYING_CUTS = "Применение вырезок"
|
||||
|
||||
PROGRESS_THROTTLE_SECONDS = 3.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dramatiq broker setup
|
||||
@@ -95,6 +135,18 @@ def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _parse_frame_rate(rate_str: str) -> float | None:
|
||||
"""Parse ffprobe frame rate string like '30/1' or '30000/1001'."""
|
||||
try:
|
||||
if "/" in rate_str:
|
||||
num, den = rate_str.split("/")
|
||||
den_val = int(den)
|
||||
return round(int(num) / den_val, 3) if den_val else None
|
||||
return float(rate_str)
|
||||
except (ValueError, ZeroDivisionError):
|
||||
return None
|
||||
|
||||
|
||||
def _build_webhook_url(job_id: uuid.UUID) -> str:
|
||||
"""Build the internal webhook URL for task updates."""
|
||||
settings = get_settings()
|
||||
@@ -267,6 +319,136 @@ def silence_remove_actor(
|
||||
raise
|
||||
|
||||
|
||||
@dramatiq.actor(max_retries=3, min_backoff=1000)
def silence_detect_actor(
    job_id: str,
    webhook_url: str,
    file_key: str,
    min_silence_duration_ms: int,
    silence_threshold_db: int,
    padding_ms: int,
) -> None:
    """Detect silent segments in media file.

    Runs in a Dramatiq worker; all job state is reported back to the API via
    ``_send_webhook_event`` rather than touching the database directly. On
    failure the exception is re-raised so Dramatiq retries (up to 3 times),
    which re-emits the initial RUNNING event on each attempt.
    """
    # Local import — presumably to keep heavy media dependencies out of the
    # web process's import graph; confirm before moving to module level.
    from cpv3.modules.media.service import detect_silence

    job_uuid = uuid.UUID(job_id)
    # Announce the job as RUNNING before any work starts.
    _send_webhook_event(
        webhook_url,
        TaskWebhookEvent(
            status=JOB_STATUS_RUNNING,
            current_message=MESSAGE_STARTING,
            started_at=_utc_now(),
        ),
    )

    try:
        storage = _get_storage_service()
        # Intermediate progress ping before the (potentially long) detection.
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                current_message=MESSAGE_DETECTING_SILENCE,
                progress_pct=PROGRESS_SILENCE_DETECT,
            ),
        )
        # detect_silence is a coroutine; _run_async drives it to completion
        # from this synchronous worker context.
        result = _run_async(
            detect_silence(
                storage,
                file_key=file_key,
                min_silence_duration_ms=min_silence_duration_ms,
                silence_threshold_db=silence_threshold_db,
                padding_ms=padding_ms,
            )
        )
        # Success: 100% with the detection result as output_data.
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_DONE,
                current_message=MESSAGE_COMPLETED,
                progress_pct=PROGRESS_COMPLETE,
                output_data=result,
                finished_at=_utc_now(),
            ),
        )
    except Exception as exc:
        logger.exception("silence_detect_actor failed: %s", job_uuid)
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_FAILED,
                error_message=str(exc),
                finished_at=_utc_now(),
            ),
        )
        # Re-raise so Dramatiq applies its retry/backoff policy.
        raise
|
||||
|
||||
|
||||
@dramatiq.actor(max_retries=3, min_backoff=1000)
def silence_apply_actor(
    job_id: str,
    webhook_url: str,
    file_key: str,
    out_folder: str,
    cuts: list[dict],
    output_name: str | None,
) -> None:
    """Apply silence cuts to media file.

    Runs in a Dramatiq worker; job state updates flow back through
    ``_send_webhook_event``. Failures are re-raised so Dramatiq retries
    (up to 3 times), re-emitting the RUNNING event on each attempt.
    """
    # Local import — presumably to keep heavy media dependencies out of the
    # web process's import graph; confirm before moving to module level.
    from cpv3.modules.media.service import apply_silence_cuts

    job_uuid = uuid.UUID(job_id)
    # Announce the job as RUNNING before any work starts.
    _send_webhook_event(
        webhook_url,
        TaskWebhookEvent(
            status=JOB_STATUS_RUNNING,
            current_message=MESSAGE_STARTING,
            started_at=_utc_now(),
        ),
    )

    try:
        storage = _get_storage_service()
        # Intermediate progress ping before the (potentially long) re-encode.
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                current_message=MESSAGE_APPLYING_CUTS,
                progress_pct=PROGRESS_SILENCE_APPLY,
            ),
        )
        # apply_silence_cuts is a coroutine; _run_async drives it to
        # completion from this synchronous worker context.
        result = _run_async(
            apply_silence_cuts(
                storage,
                file_key=file_key,
                out_folder=out_folder,
                cuts=cuts,
                output_name=output_name,
            )
        )
        # Success: report the produced file's location and size.
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_DONE,
                current_message=MESSAGE_COMPLETED,
                progress_pct=PROGRESS_COMPLETE,
                output_data={
                    "file_path": result.file_path,
                    "file_url": result.file_url,
                    "file_size": result.file_size,
                },
                finished_at=_utc_now(),
            ),
        )
    except Exception as exc:
        logger.exception("silence_apply_actor failed: %s", job_uuid)
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_FAILED,
                error_message=str(exc),
                finished_at=_utc_now(),
            ),
        )
        # Re-raise so Dramatiq applies its retry/backoff policy.
        raise
|
||||
|
||||
|
||||
@dramatiq.actor(max_retries=3, min_backoff=1000)
|
||||
def media_convert_actor(
|
||||
job_id: str,
|
||||
@@ -356,19 +538,60 @@ def transcription_generate_actor(
|
||||
)
|
||||
|
||||
try:
|
||||
from cpv3.modules.media.service import probe_media
|
||||
|
||||
storage = _get_storage_service()
|
||||
|
||||
probe = _run_async(probe_media(storage, file_key=file_key))
|
||||
has_audio = any(s.codec_type == "audio" for s in probe.streams)
|
||||
if not has_audio:
|
||||
raise ValueError(ERROR_NO_AUDIO_STREAM)
|
||||
|
||||
# Extract probe metadata for artifact creation
|
||||
duration_seconds = float(probe.format.duration) if probe.format and probe.format.duration else 0.0
|
||||
video_stream = next((s for s in probe.streams if s.codec_type == "video"), None)
|
||||
probe_meta = {
|
||||
"duration_seconds": duration_seconds,
|
||||
"frame_rate": _parse_frame_rate(video_stream.r_frame_rate) if video_stream and video_stream.r_frame_rate else None,
|
||||
"width": video_stream.width if video_stream else None,
|
||||
"height": video_stream.height if video_stream else None,
|
||||
}
|
||||
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
current_message=f"Transcribing ({engine})",
|
||||
progress_pct=PROGRESS_TRANSCRIPTION,
|
||||
current_message=f"Транскрибирование ({engine})",
|
||||
progress_pct=PROGRESS_TRANSCRIPTION_START,
|
||||
),
|
||||
)
|
||||
|
||||
last_report_time = time.monotonic()
|
||||
|
||||
def _on_whisper_progress(pct: float) -> None:
|
||||
nonlocal last_report_time
|
||||
now = time.monotonic()
|
||||
if now - last_report_time < PROGRESS_THROTTLE_SECONDS:
|
||||
return
|
||||
last_report_time = now
|
||||
mapped = PROGRESS_TRANSCRIPTION_START + (
|
||||
pct / 100.0
|
||||
) * (PROGRESS_TRANSCRIPTION_END - PROGRESS_TRANSCRIPTION_START)
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
current_message=f"Транскрибирование ({engine})",
|
||||
progress_pct=round(mapped, 1),
|
||||
),
|
||||
)
|
||||
|
||||
if engine == "whisper":
|
||||
document = _run_async(
|
||||
transcribe_with_whisper(
|
||||
storage, file_key=file_key, model_name=model, language=language
|
||||
storage,
|
||||
file_key=file_key,
|
||||
model_name=model,
|
||||
language=language,
|
||||
on_progress=_on_whisper_progress,
|
||||
)
|
||||
)
|
||||
elif engine == "google":
|
||||
@@ -379,7 +602,7 @@ def transcription_generate_actor(
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown engine: {engine}")
|
||||
raise ValueError(ERROR_UNKNOWN_ENGINE.format(engine=engine))
|
||||
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
@@ -387,7 +610,22 @@ def transcription_generate_actor(
|
||||
status=JOB_STATUS_DONE,
|
||||
current_message=MESSAGE_COMPLETED,
|
||||
progress_pct=PROGRESS_COMPLETE,
|
||||
output_data={"document": document.model_dump(mode="json")},
|
||||
output_data={
|
||||
"document": document.model_dump(mode="json"),
|
||||
"probe": probe_meta,
|
||||
},
|
||||
finished_at=_utc_now(),
|
||||
),
|
||||
)
|
||||
except (ValueError, RuntimeError) as exc:
|
||||
logger.exception(
|
||||
"transcription_generate_actor failed (non-transient): %s", job_uuid
|
||||
)
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
status=JOB_STATUS_FAILED,
|
||||
error_message=str(exc),
|
||||
finished_at=_utc_now(),
|
||||
),
|
||||
)
|
||||
@@ -463,6 +701,115 @@ def captions_generate_actor(
|
||||
raise
|
||||
|
||||
|
||||
@dramatiq.actor(max_retries=2, min_backoff=2000)
|
||||
def frame_extract_actor(
|
||||
job_id: str,
|
||||
webhook_url: str,
|
||||
file_key: str,
|
||||
frames_folder: str,
|
||||
regenerate: bool,
|
||||
) -> None:
|
||||
"""Extract video frames at 1fps for timeline thumbnails."""
|
||||
from cpv3.modules.media.service import (
|
||||
delete_frames,
|
||||
extract_frames,
|
||||
read_frames_metadata,
|
||||
)
|
||||
|
||||
job_uuid = uuid.UUID(job_id)
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
status=JOB_STATUS_RUNNING,
|
||||
current_message=MESSAGE_STARTING,
|
||||
started_at=_utc_now(),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
storage = _get_storage_service()
|
||||
|
||||
# Delete old frames if regenerating
|
||||
if regenerate:
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
current_message=MESSAGE_DELETING_OLD_FRAMES,
|
||||
progress_pct=PROGRESS_FRAME_EXTRACT_START,
|
||||
),
|
||||
)
|
||||
old_meta = _run_async(
|
||||
read_frames_metadata(storage, frames_folder=frames_folder)
|
||||
)
|
||||
if old_meta is not None:
|
||||
_run_async(
|
||||
delete_frames(
|
||||
storage,
|
||||
frames_folder=frames_folder,
|
||||
frame_count=old_meta.frame_count,
|
||||
)
|
||||
)
|
||||
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
current_message=MESSAGE_EXTRACTING_FRAMES,
|
||||
progress_pct=PROGRESS_FRAME_EXTRACT_START,
|
||||
),
|
||||
)
|
||||
|
||||
last_report_time = time.monotonic()
|
||||
|
||||
def _on_progress(current: int, total: int) -> None:
|
||||
nonlocal last_report_time
|
||||
now = time.monotonic()
|
||||
if now - last_report_time < PROGRESS_THROTTLE_SECONDS:
|
||||
return
|
||||
last_report_time = now
|
||||
pct = current / total
|
||||
mapped = PROGRESS_FRAME_EXTRACT_START + pct * (
|
||||
PROGRESS_FRAME_EXTRACT_END - PROGRESS_FRAME_EXTRACT_START
|
||||
)
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
current_message=MESSAGE_UPLOADING_FRAMES,
|
||||
progress_pct=round(mapped, 1),
|
||||
),
|
||||
)
|
||||
|
||||
metadata = _run_async(
|
||||
extract_frames(
|
||||
storage,
|
||||
file_key=file_key,
|
||||
frames_folder=frames_folder,
|
||||
on_progress=_on_progress,
|
||||
)
|
||||
)
|
||||
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
status=JOB_STATUS_DONE,
|
||||
current_message=MESSAGE_COMPLETED,
|
||||
progress_pct=PROGRESS_COMPLETE,
|
||||
output_data=metadata.model_dump(mode="json"),
|
||||
finished_at=_utc_now(),
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("frame_extract_actor failed: %s", job_uuid)
|
||||
_send_webhook_event(
|
||||
webhook_url,
|
||||
TaskWebhookEvent(
|
||||
status=JOB_STATUS_FAILED,
|
||||
error_message=str(exc),
|
||||
finished_at=_utc_now(),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task Service
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -557,8 +904,193 @@ class TaskService:
|
||||
await self._event_repo.create(
|
||||
JobEventCreate(job_id=job.id, event_type=event_type, payload=payload)
|
||||
)
|
||||
|
||||
# Save artifacts BEFORE sending notifications so data exists when frontend refetches
|
||||
if (
|
||||
job.job_type == JOB_TYPE_TRANSCRIPTION_GENERATE
|
||||
and event.status == JOB_STATUS_DONE
|
||||
):
|
||||
try:
|
||||
await self._save_transcription_artifacts(job)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to save transcription artifacts for job %s", job_id
|
||||
)
|
||||
|
||||
if job.job_type == JOB_TYPE_MEDIA_CONVERT and event.status == JOB_STATUS_DONE:
|
||||
try:
|
||||
await self._save_convert_artifacts(job)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to save convert artifacts for job %s", job_id
|
||||
)
|
||||
|
||||
# Push real-time notification via WebSocket (after artifacts are persisted)
|
||||
if job.user_id is not None:
|
||||
try:
|
||||
notification_service = NotificationService(self._session)
|
||||
await notification_service.create_task_notification(
|
||||
user_id=job.user_id, job=job, event=event
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to create notification for job %s", job_id)
|
||||
|
||||
return job
|
||||
|
||||
async def _save_transcription_artifacts(self, job: Job) -> None:
|
||||
"""Create Transcription, ArtifactMediaFile and File records."""
|
||||
input_data = job.input_data or {}
|
||||
output_data = job.output_data or {}
|
||||
|
||||
file_key: str = input_data["file_key"]
|
||||
project_id: uuid.UUID | None = (
|
||||
uuid.UUID(input_data["project_id"]) if input_data.get("project_id") else None
|
||||
)
|
||||
engine_raw: str = input_data.get("engine", "whisper")
|
||||
language: str | None = input_data.get("language")
|
||||
|
||||
document: dict = output_data["document"]
|
||||
|
||||
# Resolve user
|
||||
user_repo = UserRepository(self._session)
|
||||
user = await user_repo.get_by_id(job.user_id) # type: ignore[arg-type]
|
||||
if user is None:
|
||||
logger.warning("User %s not found, skipping artifact save", job.user_id)
|
||||
return
|
||||
|
||||
# Find or create source File record
|
||||
file_repo = FileRepository(self._session)
|
||||
source_file = await file_repo.get_by_path(file_key)
|
||||
if source_file is None:
|
||||
source_file = await file_repo.create(
|
||||
requester=user,
|
||||
data=FileCreate(
|
||||
project_id=project_id,
|
||||
original_filename=file_key.rsplit("/", 1)[-1],
|
||||
path=file_key,
|
||||
storage_backend="S3",
|
||||
mime_type="application/octet-stream",
|
||||
size_bytes=0,
|
||||
is_uploaded=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Upload document JSON to S3
|
||||
storage = _get_storage_service()
|
||||
user_folder = get_user_folder(user)
|
||||
json_bytes = json.dumps(document, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
# Build display name: "Транскрипция <video_name>.json"
|
||||
video_stem = Path(source_file.original_filename).stem
|
||||
transcription_filename = f"Транскрипция {video_stem}.json"
|
||||
|
||||
artifact_key = await storage.upload_fileobj(
|
||||
fileobj=io.BytesIO(json_bytes),
|
||||
file_name=transcription_filename,
|
||||
folder=f"{user_folder}/artifacts",
|
||||
gen_name=True,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Create File record for the JSON artifact (no project_id — only reachable via artifact)
|
||||
json_file = await file_repo.create(
|
||||
requester=user,
|
||||
data=FileCreate(
|
||||
project_id=None,
|
||||
original_filename=transcription_filename,
|
||||
path=artifact_key,
|
||||
storage_backend="S3",
|
||||
mime_type="application/json",
|
||||
size_bytes=len(json_bytes),
|
||||
file_format="json",
|
||||
is_uploaded=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create ArtifactMediaFile (no media_file_id — transcription is not a media file)
|
||||
artifact_repo = ArtifactRepository(self._session)
|
||||
artifact = await artifact_repo.create(
|
||||
data=ArtifactMediaFileCreate(
|
||||
project_id=project_id,
|
||||
file_id=json_file.id,
|
||||
media_file_id=None,
|
||||
artifact_type="TRANSCRIPTION_JSON",
|
||||
),
|
||||
)
|
||||
|
||||
# Create Transcription record
|
||||
transcription_repo = TranscriptionRepository(self._session)
|
||||
engine_db = ENGINE_MAP.get(engine_raw, "LOCAL_WHISPER")
|
||||
await transcription_repo.create(
|
||||
data=TranscriptionCreate(
|
||||
project_id=project_id,
|
||||
source_file_id=source_file.id,
|
||||
artifact_id=artifact.id,
|
||||
engine=engine_db, # type: ignore[arg-type]
|
||||
language=language,
|
||||
document=document,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info("Saved transcription artifacts for job %s", job.id)
|
||||
|
||||
async def _save_convert_artifacts(self, job: Job) -> None:
|
||||
"""Create File and ArtifactMediaFile records for converted MP4."""
|
||||
input_data = job.input_data or {}
|
||||
output_data = job.output_data or {}
|
||||
|
||||
file_key: str = input_data["file_key"]
|
||||
project_id: uuid.UUID | None = (
|
||||
uuid.UUID(input_data["project_id"]) if input_data.get("project_id") else None
|
||||
)
|
||||
|
||||
file_path: str = output_data["file_path"]
|
||||
file_size: int = output_data.get("file_size", 0)
|
||||
|
||||
# Resolve user
|
||||
user_repo = UserRepository(self._session)
|
||||
user = await user_repo.get_by_id(job.user_id) # type: ignore[arg-type]
|
||||
if user is None:
|
||||
logger.warning("User %s not found, skipping convert artifact save", job.user_id)
|
||||
return
|
||||
|
||||
# Derive output filename from source file
|
||||
file_repo = FileRepository(self._session)
|
||||
source_file = await file_repo.get_by_path(file_key)
|
||||
if source_file is not None:
|
||||
stem = Path(source_file.original_filename).stem
|
||||
else:
|
||||
stem = Path(file_key).stem
|
||||
converted_filename = f"{stem}.mp4"
|
||||
|
||||
# Create File record for the converted MP4 (no project_id — only reachable via artifact)
|
||||
converted_file = await file_repo.create(
|
||||
requester=user,
|
||||
data=FileCreate(
|
||||
project_id=None,
|
||||
original_filename=converted_filename,
|
||||
path=file_path,
|
||||
storage_backend="S3",
|
||||
mime_type="video/mp4",
|
||||
size_bytes=file_size,
|
||||
file_format="mp4",
|
||||
is_uploaded=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create ArtifactMediaFile record
|
||||
artifact_repo = ArtifactRepository(self._session)
|
||||
await artifact_repo.create(
|
||||
data=ArtifactMediaFileCreate(
|
||||
project_id=project_id,
|
||||
file_id=converted_file.id,
|
||||
media_file_id=None,
|
||||
artifact_type="CONVERTED_VIDEO",
|
||||
),
|
||||
)
|
||||
|
||||
logger.info("Saved convert artifacts for job %s", job.id)
|
||||
|
||||
async def submit_media_probe(
|
||||
self, *, requester: User, request: MediaProbeRequest
|
||||
) -> TaskSubmitResponse:
|
||||
@@ -576,6 +1108,12 @@ class TaskService:
|
||||
self, *, requester: User, request: SilenceRemoveRequest
|
||||
) -> TaskSubmitResponse:
|
||||
"""Submit silence removal task."""
|
||||
user_folder = get_user_folder(requester)
|
||||
resolved_folder = (
|
||||
f"{user_folder}/{request.out_folder}"
|
||||
if request.out_folder
|
||||
else f"{user_folder}/output_files"
|
||||
)
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_SILENCE_REMOVE,
|
||||
@@ -584,17 +1122,65 @@ class TaskService:
|
||||
actor=silence_remove_actor,
|
||||
actor_kwargs={
|
||||
"file_key": request.file_key,
|
||||
"out_folder": request.out_folder,
|
||||
"out_folder": resolved_folder,
|
||||
"min_silence_duration_ms": request.min_silence_duration_ms,
|
||||
"silence_threshold_db": request.silence_threshold_db,
|
||||
"padding_ms": request.padding_ms,
|
||||
},
|
||||
)
|
||||
|
||||
async def submit_silence_detect(
|
||||
self, *, requester: User, request: SilenceDetectRequest
|
||||
) -> TaskSubmitResponse:
|
||||
"""Submit silence detection task."""
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_SILENCE_DETECT,
|
||||
project_id=request.project_id,
|
||||
input_data=request.model_dump(mode="json"),
|
||||
actor=silence_detect_actor,
|
||||
actor_kwargs={
|
||||
"file_key": request.file_key,
|
||||
"min_silence_duration_ms": request.min_silence_duration_ms,
|
||||
"silence_threshold_db": request.silence_threshold_db,
|
||||
"padding_ms": request.padding_ms,
|
||||
},
|
||||
)
|
||||
|
||||
async def submit_silence_apply(
|
||||
self, *, requester: User, request: SilenceApplyRequest
|
||||
) -> TaskSubmitResponse:
|
||||
"""Submit silence apply task."""
|
||||
user_folder = get_user_folder(requester)
|
||||
resolved_folder = (
|
||||
f"{user_folder}/{request.out_folder}"
|
||||
if request.out_folder
|
||||
else f"{user_folder}/output_files"
|
||||
)
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_SILENCE_APPLY,
|
||||
project_id=request.project_id,
|
||||
input_data=request.model_dump(mode="json"),
|
||||
actor=silence_apply_actor,
|
||||
actor_kwargs={
|
||||
"file_key": request.file_key,
|
||||
"out_folder": resolved_folder,
|
||||
"cuts": request.cuts,
|
||||
"output_name": request.output_name,
|
||||
},
|
||||
)
|
||||
|
||||
async def submit_media_convert(
|
||||
self, *, requester: User, request: MediaConvertRequest
|
||||
) -> TaskSubmitResponse:
|
||||
"""Submit media conversion task."""
|
||||
user_folder = get_user_folder(requester)
|
||||
resolved_folder = (
|
||||
f"{user_folder}/{request.out_folder}"
|
||||
if request.out_folder
|
||||
else f"{user_folder}/output_files"
|
||||
)
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_MEDIA_CONVERT,
|
||||
@@ -603,7 +1189,7 @@ class TaskService:
|
||||
actor=media_convert_actor,
|
||||
actor_kwargs={
|
||||
"file_key": request.file_key,
|
||||
"out_folder": request.out_folder,
|
||||
"out_folder": resolved_folder,
|
||||
"output_format": request.output_format,
|
||||
},
|
||||
)
|
||||
@@ -626,6 +1212,28 @@ class TaskService:
|
||||
},
|
||||
)
|
||||
|
||||
async def submit_frame_extract(
|
||||
self, *, requester: User, request: FrameExtractRequest
|
||||
) -> TaskSubmitResponse:
|
||||
"""Submit frame extraction task."""
|
||||
from cpv3.modules.media.service import get_frames_folder
|
||||
|
||||
user_folder = get_user_folder(requester)
|
||||
frames_folder = get_frames_folder(user_folder, request.file_key)
|
||||
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_FRAME_EXTRACT,
|
||||
project_id=request.project_id,
|
||||
input_data=request.model_dump(mode="json"),
|
||||
actor=frame_extract_actor,
|
||||
actor_kwargs={
|
||||
"file_key": request.file_key,
|
||||
"frames_folder": frames_folder,
|
||||
"regenerate": request.regenerate,
|
||||
},
|
||||
)
|
||||
|
||||
async def submit_captions_generate(
|
||||
self, *, requester: User, request: CaptionsGenerateRequest
|
||||
) -> TaskSubmitResponse:
|
||||
@@ -635,6 +1243,13 @@ class TaskService:
|
||||
if transcription is None:
|
||||
raise ValueError(f"Transcription {request.transcription_id} not found")
|
||||
|
||||
user_folder = get_user_folder(requester)
|
||||
resolved_folder = (
|
||||
f"{user_folder}/{request.folder}"
|
||||
if request.folder
|
||||
else f"{user_folder}/output_files"
|
||||
)
|
||||
|
||||
return await self._submit_task(
|
||||
requester=requester,
|
||||
job_type=JOB_TYPE_CAPTIONS_GENERATE,
|
||||
@@ -643,7 +1258,7 @@ class TaskService:
|
||||
actor=captions_generate_actor,
|
||||
actor_kwargs={
|
||||
"video_s3_path": request.video_s3_path,
|
||||
"folder": request.folder,
|
||||
"folder": resolved_folder,
|
||||
"transcription_json": transcription.document,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -32,6 +32,14 @@ class TranscriptionRepository:
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_artifact_id(self, artifact_id: uuid.UUID) -> Transcription | None:
|
||||
result = await self._session.execute(
|
||||
select(Transcription)
|
||||
.where(Transcription.artifact_id == artifact_id)
|
||||
.where(Transcription.is_active.is_(True))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, data: TranscriptionCreate) -> Transcription:
|
||||
transcription = Transcription(
|
||||
project_id=data.project_id,
|
||||
|
||||
@@ -67,6 +67,21 @@ async def retrieve_transcription_entry(
|
||||
return TranscriptionRead.model_validate(transcription)
|
||||
|
||||
|
||||
@router.get("/transcriptions/by-artifact/{artifact_id}/", response_model=TranscriptionRead)
|
||||
async def retrieve_transcription_by_artifact(
|
||||
artifact_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> TranscriptionRead:
|
||||
_ = current_user
|
||||
repo = TranscriptionRepository(db)
|
||||
transcription = await repo.get_by_artifact_id(artifact_id)
|
||||
if transcription is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
|
||||
|
||||
return TranscriptionRead.model_validate(transcription)
|
||||
|
||||
|
||||
@router.patch("/transcriptions/{transcription_id}/", response_model=TranscriptionRead)
|
||||
async def patch_transcription_entry(
|
||||
transcription_id: uuid.UUID,
|
||||
|
||||
@@ -240,11 +240,15 @@ def _make_document_from_segments(
|
||||
return Document(segments=result_segments)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[float], None]
|
||||
|
||||
|
||||
def _whisper_transcribe_sync(
|
||||
*,
|
||||
local_file_path: str,
|
||||
model_name: str,
|
||||
language: str | None,
|
||||
on_progress: ProgressCallback | None = None,
|
||||
) -> Document:
|
||||
import whisper # type: ignore[import-untyped]
|
||||
|
||||
@@ -267,14 +271,35 @@ def _whisper_transcribe_sync(
|
||||
probs = cast(dict[str, float], probs_raw)
|
||||
language = max(probs, key=lambda k: probs[k])
|
||||
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=False,
|
||||
)
|
||||
if on_progress is not None:
|
||||
from unittest.mock import patch
|
||||
from tqdm import tqdm as _orig_tqdm
|
||||
|
||||
class _ProgressTqdm(_orig_tqdm):
|
||||
def update(self, n=1):
|
||||
super().update(n)
|
||||
if self.total:
|
||||
on_progress(min(self.n / self.total * 100.0, 100.0))
|
||||
|
||||
with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=False,
|
||||
)
|
||||
on_progress(100.0)
|
||||
else:
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=None,
|
||||
)
|
||||
|
||||
parsed = WhisperResult.model_validate(result)
|
||||
|
||||
@@ -296,6 +321,7 @@ async def transcribe_with_whisper(
|
||||
file_key: str,
|
||||
model_name: str = "tiny",
|
||||
language: str | None = None,
|
||||
on_progress: ProgressCallback | None = None,
|
||||
) -> Document:
|
||||
tmp = await storage.download_to_temp(file_key)
|
||||
try:
|
||||
@@ -304,6 +330,7 @@ async def transcribe_with_whisper(
|
||||
local_file_path=tmp.path,
|
||||
model_name=model_name,
|
||||
language=language,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
|
||||
@@ -71,6 +71,10 @@ class UserRepository:
|
||||
await self._session.refresh(user)
|
||||
return user
|
||||
|
||||
async def update_password(self, user: User, new_hash: str) -> None:
|
||||
user.password_hash = new_hash
|
||||
await self._session.commit()
|
||||
|
||||
async def deactivate(self, user: User) -> None:
|
||||
user.is_active = False
|
||||
await self._session.commit()
|
||||
|
||||
@@ -2,17 +2,21 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from jwt import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.infrastructure.auth import get_current_user
|
||||
from cpv3.infrastructure.deps import get_storage
|
||||
from cpv3.infrastructure.security import create_token, decode_token
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
from cpv3.infrastructure.storage.base import StorageService
|
||||
from cpv3.db.session import get_db
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.users.schemas import (
|
||||
PasswordChange,
|
||||
TokenRefresh,
|
||||
TokenRefreshResponse,
|
||||
UserCreate,
|
||||
@@ -28,6 +32,21 @@ users_router = APIRouter(prefix="/api/users", tags=["Users"])
|
||||
auth_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _is_s3_key(value: str) -> bool:
|
||||
"""Return True if value looks like a bare S3 key, not a full URL."""
|
||||
parsed = urlparse(value)
|
||||
return not parsed.scheme and not parsed.netloc
|
||||
|
||||
|
||||
async def _resolve_avatar(user: User, storage: StorageService) -> UserRead:
|
||||
"""Build UserRead with a fresh presigned avatar URL."""
|
||||
data = UserRead.model_validate(user)
|
||||
if data.avatar:
|
||||
if _is_s3_key(data.avatar):
|
||||
data.avatar = await storage.url(data.avatar)
|
||||
return data
|
||||
|
||||
|
||||
def _issue_tokens(user: User) -> tuple[str, str]:
|
||||
settings = get_settings()
|
||||
|
||||
@@ -49,10 +68,11 @@ def _issue_tokens(user: User) -> tuple[str, str]:
|
||||
async def list_all_users(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> list[UserRead]:
|
||||
service = UserService(db)
|
||||
users = await service.list_users(requester=current_user)
|
||||
return [UserRead.model_validate(u) for u in users]
|
||||
return [await _resolve_avatar(u, storage) for u in users]
|
||||
|
||||
|
||||
@users_router.post("/", response_model=UserRead, status_code=status.HTTP_201_CREATED)
|
||||
@@ -60,6 +80,7 @@ async def create_user_endpoint(
|
||||
body: UserCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRead:
|
||||
service = UserService(db)
|
||||
try:
|
||||
@@ -67,12 +88,29 @@ async def create_user_endpoint(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
return UserRead.model_validate(user)
|
||||
return await _resolve_avatar(user, storage)
|
||||
|
||||
|
||||
@users_router.get("/me/", response_model=UserRead)
|
||||
async def me(current_user: User = Depends(get_current_user)) -> UserRead:
|
||||
return UserRead.model_validate(current_user)
|
||||
async def me(
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRead:
|
||||
return await _resolve_avatar(current_user, storage)
|
||||
|
||||
|
||||
@users_router.post("/me/change-password/", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def change_password(
|
||||
body: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Response:
|
||||
service = UserService(db)
|
||||
try:
|
||||
await service.change_password(current_user, body.current_password, body.new_password)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@users_router.get("/{user_id}/", response_model=UserRead)
|
||||
@@ -80,6 +118,7 @@ async def retrieve_user(
|
||||
user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRead:
|
||||
service = UserService(db)
|
||||
user = await service.get_user_by_id(user_id)
|
||||
@@ -89,7 +128,7 @@ async def retrieve_user(
|
||||
if not current_user.is_staff and user.id != current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
return UserRead.model_validate(user)
|
||||
return await _resolve_avatar(user, storage)
|
||||
|
||||
|
||||
@users_router.patch("/{user_id}/", response_model=UserRead)
|
||||
@@ -98,6 +137,7 @@ async def patch_user(
|
||||
body: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRead:
|
||||
service = UserService(db)
|
||||
user = await service.get_user_by_id(user_id)
|
||||
@@ -112,7 +152,7 @@ async def patch_user(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
return UserRead.model_validate(user)
|
||||
return await _resolve_avatar(user, storage)
|
||||
|
||||
|
||||
@users_router.delete("/{user_id}/", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@@ -136,7 +176,11 @@ async def delete_user(
|
||||
@auth_router.post(
|
||||
"/register", response_model=UserRegisterResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def register(body: UserRegister, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
|
||||
async def register(
|
||||
body: UserRegister,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRegisterResponse:
|
||||
service = UserService(db)
|
||||
try:
|
||||
user = await service.register_user(body)
|
||||
@@ -144,18 +188,24 @@ async def register(body: UserRegister, db: AsyncSession = Depends(get_db)) -> Us
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
access, refresh = _issue_tokens(user)
|
||||
return UserRegisterResponse(user=UserRead.model_validate(user), access=access, refresh=refresh)
|
||||
user_read = await _resolve_avatar(user, storage)
|
||||
return UserRegisterResponse(user=user_read, access=access, refresh=refresh)
|
||||
|
||||
|
||||
@auth_router.post("/login", response_model=UserRegisterResponse)
|
||||
async def login(body: UserLogin, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
|
||||
async def login(
|
||||
body: UserLogin,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
storage: StorageService = Depends(get_storage),
|
||||
) -> UserRegisterResponse:
|
||||
service = UserService(db)
|
||||
user = await service.authenticate(body.username, body.password)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
|
||||
access, refresh = _issue_tokens(user)
|
||||
return UserRegisterResponse(user=UserRead.model_validate(user), access=access, refresh=refresh)
|
||||
user_read = await _resolve_avatar(user, storage)
|
||||
return UserRegisterResponse(user=user_read, access=access, refresh=refresh)
|
||||
|
||||
|
||||
@auth_router.post("/refresh", response_model=TokenRefreshResponse)
|
||||
|
||||
@@ -66,6 +66,11 @@ class UserRegisterResponse(Schema):
|
||||
refresh: str
|
||||
|
||||
|
||||
class PasswordChange(Schema):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class TokenRefresh(Schema):
|
||||
refresh: str
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import uuid
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.infrastructure.security import verify_password
|
||||
from cpv3.infrastructure.security import hash_password, verify_password
|
||||
from cpv3.modules.users.models import User
|
||||
from cpv3.modules.users.repository import UserRepository
|
||||
from cpv3.modules.users.schemas import UserCreate, UserRegister, UserUpdate
|
||||
@@ -40,6 +40,12 @@ class UserService:
|
||||
async def deactivate_user(self, user: User) -> None:
|
||||
await self._repo.deactivate(user)
|
||||
|
||||
async def change_password(self, user: User, current_password: str, new_password: str) -> None:
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise ValueError("Current password is incorrect")
|
||||
new_hash = hash_password(new_password)
|
||||
await self._repo.update_password(user, new_hash)
|
||||
|
||||
async def authenticate(self, username: str, password: str) -> User | None:
|
||||
user = await self._repo.get_by_username(username)
|
||||
if user is None:
|
||||
@@ -87,6 +93,13 @@ async def deactivate_user(session: AsyncSession, user: User) -> None:
|
||||
await service.deactivate_user(user)
|
||||
|
||||
|
||||
async def change_password(
|
||||
session: AsyncSession, user: User, current_password: str, new_password: str
|
||||
) -> None:
|
||||
service = UserService(session)
|
||||
await service.change_password(user, current_password, new_password)
|
||||
|
||||
|
||||
async def authenticate(session: AsyncSession, username: str, password: str) -> User | None:
|
||||
service = UserService(session)
|
||||
return await service.authenticate(username, password)
|
||||
|
||||
Reference in New Issue
Block a user