new features

This commit is contained in:
Daniil
2026-02-27 23:33:56 +03:00
parent 937e58859a
commit dc04efe0fb
41 changed files with 2067 additions and 141 deletions
+6
View File
@@ -33,6 +33,12 @@ class FileRepository:
return None
return file
async def get_by_path(self, path: str) -> File | None:
    """Return the non-deleted ``File`` row whose ``path`` matches, or ``None``.

    Soft-deleted rows (``is_deleted`` true) are excluded; raises if more than
    one live row shares the path (``scalar_one_or_none`` semantics).
    """
    result = await self._session.execute(
        select(File).where(File.path == path, File.is_deleted.is_(False))
    )
    return result.scalar_one_or_none()
async def create(self, *, requester: User, data: FileCreate) -> File:
file = File(
owner_id=requester.id,
+10 -5
View File
@@ -19,6 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.deps import get_storage
from cpv3.infrastructure.storage.base import StorageService
from cpv3.infrastructure.storage.utils import get_user_folder
from cpv3.infrastructure.settings import get_settings
from cpv3.db.session import get_db
from cpv3.modules.files.schemas import (
@@ -32,7 +33,7 @@ from cpv3.modules.users.models import User
router = APIRouter(prefix="/api/files", tags=["Files"])
MAX_MB_SIZE = 100
MAX_MB_SIZE = 1024
@router.post(
@@ -44,8 +45,6 @@ async def upload_file(
current_user: User = Depends(get_current_user),
storage: StorageService = Depends(get_storage),
) -> FileInfoResponse:
_ = current_user
# Validate max file size (matches old behavior).
file.file.seek(0, 2)
size_bytes = file.file.tell()
@@ -58,10 +57,13 @@ async def upload_file(
detail=f"File size exceeds the maximum limit of {MAX_MB_SIZE} MB.",
)
user_folder = get_user_folder(current_user)
resolved_folder = f"{user_folder}/{folder}" if folder else f"{user_folder}/user_upload"
key = await storage.upload_fileobj(
fileobj=file.file,
file_name=file.filename or "upload.bin",
folder=folder,
folder=resolved_folder,
gen_name=True,
content_type=file.content_type,
)
@@ -81,7 +83,10 @@ async def get_file_info(
current_user: User = Depends(get_current_user),
storage: StorageService = Depends(get_storage),
) -> FileInfoResponse:
_ = current_user
if not current_user.is_staff:
user_prefix = f"{get_user_folder(current_user)}/"
if not file_path.startswith(user_prefix):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
if not await storage.exists(file_path):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
+3
View File
@@ -11,9 +11,12 @@ JobStatusEnum = Literal["PENDING", "RUNNING", "FAILED", "CANCELLED", "DONE"]
JobTypeEnum = Literal[
"MEDIA_PROBE",
"SILENCE_REMOVE",
"SILENCE_DETECT",
"SILENCE_APPLY",
"MEDIA_CONVERT",
"TRANSCRIPTION_GENERATE",
"CAPTIONS_GENERATE",
"FRAME_EXTRACT",
]
+5 -2
View File
@@ -46,8 +46,11 @@ class ArtifactMediaFile(Base, BaseModelMixin):
file_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True), ForeignKey("files.id", ondelete="RESTRICT"), nullable=True, index=True
)
media_file_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), ForeignKey("media_files.id", ondelete="RESTRICT"), index=True
media_file_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("media_files.id", ondelete="RESTRICT"),
nullable=True,
index=True,
)
artifact_type: Mapped[str] = mapped_column(String(32), default="TRANSCRIPTION_JSON")
+48 -5
View File
@@ -1,6 +1,8 @@
from __future__ import annotations
import math
import uuid
from os import path
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,11 +10,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.deps import get_storage
from cpv3.infrastructure.storage.base import StorageService
from cpv3.infrastructure.storage.utils import get_user_folder
from cpv3.db.session import get_db
from cpv3.modules.media.schemas import (
ArtifactMediaFileCreate,
ArtifactMediaFileRead,
ArtifactMediaFileUpdate,
FrameItem,
FrameRangeResponse,
MediaConverterParams,
MediaFileCreate,
MediaFileRead,
@@ -20,7 +25,13 @@ from cpv3.modules.media.schemas import (
MediaProbeSchema,
MediaSilencerParams,
)
from cpv3.modules.media.service import convert_to_mp4, probe_media, remove_silence
from cpv3.modules.media.service import (
convert_to_mp4,
get_frames_folder,
probe_media,
read_frames_metadata,
remove_silence,
)
from cpv3.modules.media.repository import ArtifactRepository, MediaFileRepository
from cpv3.modules.files.schemas import FileInfoResponse
from cpv3.modules.users.models import User
@@ -46,12 +57,13 @@ async def silence_remove(
current_user: User = Depends(get_current_user),
storage: StorageService = Depends(get_storage),
) -> FileInfoResponse:
_ = current_user
user_folder = get_user_folder(current_user)
resolved_folder = f"{user_folder}/{body.folder}" if body.folder else f"{user_folder}/output_files"
info = await remove_silence(
storage,
file_key=body.file_path,
out_folder=body.folder,
out_folder=resolved_folder,
min_silence_duration_ms=body.min_silence_duration_ms,
silence_threshold_db=body.silence_threshold_db,
padding_ms=body.padding_ms,
@@ -71,9 +83,10 @@ async def convert(
current_user: User = Depends(get_current_user),
storage: StorageService = Depends(get_storage),
) -> FileInfoResponse:
_ = current_user
user_folder = get_user_folder(current_user)
resolved_folder = f"{user_folder}/{body.folder}" if body.folder else f"{user_folder}/output_files"
info = await convert_to_mp4(storage, file_key=body.file_path, out_folder=body.folder)
info = await convert_to_mp4(storage, file_key=body.file_path, out_folder=resolved_folder)
return FileInfoResponse(
file_path=info.file_path,
file_url=info.file_url,
@@ -82,6 +95,36 @@ async def convert(
)
@media_router.get("/frames/", response_model=FrameRangeResponse)
async def get_frames(
    file_key: str = Query(..., description="S3 key of the source video"),
    start: float = Query(0.0, ge=0, description="Start time in seconds"),
    end: float = Query(..., gt=0, description="End time in seconds"),
    current_user: User = Depends(get_current_user),
    storage: StorageService = Depends(get_storage),
) -> FrameRangeResponse:
    """Return presigned URLs for extracted frames within a time range."""
    # Frames live in a per-user folder whose leaf name is derived from a hash
    # of file_key, so the lookup needs no database round-trip.
    user_folder = get_user_folder(current_user)
    frames_folder = get_frames_folder(user_folder, file_key)
    metadata = await read_frames_metadata(storage, frames_folder=frames_folder)
    if metadata is None:
        # No extraction has run for this file yet — empty response, not a 404.
        return FrameRangeResponse(interval=1.0, frames=[])
    interval = metadata.interval
    # Frame files are 1-based (%06d.jpg starting at 000001); frame i covers
    # timestamp (i - 1) * interval.
    first_index = max(1, math.floor(start / interval) + 1)
    # NOTE(review): the "+ 1" includes one frame *past* `end` (e.g. end=2.5s,
    # interval=1s -> last frame at t=3.0s) — confirm this is intended.
    last_index = min(metadata.frame_count, math.ceil(end / interval) + 1)
    frames: list[FrameItem] = []
    for i in range(first_index, last_index + 1):
        # assumes S3 keys are "/"-delimited; path.join is OS-dependent — TODO confirm deploy OS
        key = path.join(frames_folder, f"{i:06d}.jpg")
        timestamp = (i - 1) * interval
        url = await storage.url(key)
        frames.append(FrameItem(timestamp=timestamp, url=url))
    return FrameRangeResponse(interval=interval, frames=frames)
@mediafiles_router.get("/mediafiles/", response_model=list[MediaFileRead])
async def list_mediafiles(
current_user: User = Depends(get_current_user),
+29 -2
View File
@@ -12,9 +12,11 @@ from cpv3.common.schemas import Schema
ArtifactTypeEnum = Literal[
"TRANSCRIPTION_JSON",
"SILENCE_REMOVED_VIDEO",
"CONVERTED_VIDEO",
"THUMBNAIL",
"AUDIO_PROXY",
"RENDERED_VIDEO",
"FRAME_SPRITES",
]
@@ -60,7 +62,7 @@ class ArtifactMediaFileRead(Schema):
id: UUID
project_id: UUID | None
file_id: UUID | None
media_file_id: UUID
media_file_id: UUID | None
artifact_type: ArtifactTypeEnum
@@ -74,7 +76,7 @@ class ArtifactMediaFileRead(Schema):
class ArtifactMediaFileCreate(Schema):
project_id: UUID | None = None
file_id: UUID | None = None
media_file_id: UUID
media_file_id: UUID | None = None
artifact_type: ArtifactTypeEnum
@@ -148,3 +150,28 @@ class MediaSilencerParams(Schema):
class MediaConverterParams(Schema):
file_path: str
folder: str = ""
class FrameSpriteMetadata(Schema):
    """Metadata stored in ArtifactMediaFile.meta for extracted frames."""

    frame_count: int  # number of extracted JPEGs (files are 1-based %06d.jpg)
    interval: float  # seconds between consecutive frames (1 / extraction fps)
    width: int  # pixel width of one frame thumbnail
    height: int  # pixel height of one frame thumbnail
    folder_key: str  # S3 folder holding the frames and meta.json
    source_file_key: str  # S3 key of the source video the frames came from
class FrameItem(Schema):
    """Single frame in a range query response."""

    timestamp: float  # frame time in seconds from the start of the video
    url: str  # presigned URL for the frame image
class FrameRangeResponse(Schema):
    """Response for GET /api/media/frames/ range query."""

    interval: float  # seconds between frames; 1.0 placeholder when none exist
    frames: list[FrameItem]  # frames covering the requested time range
+299 -2
View File
@@ -1,15 +1,30 @@
from __future__ import annotations
import asyncio
import glob as glob_mod
import hashlib
import io
import json
from os import path
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, mkdtemp
from typing import Callable
import anyio
from cpv3.infrastructure.storage.base import StorageService
from cpv3.infrastructure.storage.types import FileInfo
from cpv3.modules.media.schemas import MediaProbeSchema
from cpv3.modules.media.schemas import FrameSpriteMetadata, MediaProbeSchema
FRAME_WIDTH_PX = 128
FRAME_FPS = 1
FRAME_JPEG_QUALITY = 5
FRAMES_META_FILENAME = "meta.json"
def get_frames_folder(user_folder: str, file_key: str) -> str:
    """Return the deterministic S3 folder that holds frames for *file_key*.

    The leaf folder name is the first 16 hex characters of the SHA-256 digest
    of the source key, so repeated extractions of the same source always map
    to the same folder.
    """
    digest = hashlib.sha256(file_key.encode()).hexdigest()
    return path.join(user_folder, "frames", digest[:16])
async def probe_media(storage: StorageService, *, file_key: str) -> MediaProbeSchema:
@@ -68,6 +83,160 @@ def _compute_non_silent_segments(
return segments
async def detect_silence(
    storage: StorageService,
    *,
    file_key: str,
    min_silence_duration_ms: int = 200,
    silence_threshold_db: int = 16,
    padding_ms: int = 100,
) -> dict:
    """Detect silent segments in a media file and return their intervals.

    Downloads ``file_key`` from storage, measures loudness with pydub, and
    inverts the detected non-silent spans into silent spans.

    Returns a dict with:
        ``silent_segments``: list of ``{"start_ms": int, "end_ms": int}``
        ``duration_ms``: total audio duration in milliseconds
        ``file_key``: echo of the input key
    """
    input_tmp = await storage.download_to_temp(file_key)
    try:
        # Imported lazily so the module imports even without pydub installed.
        from pydub import AudioSegment  # type: ignore[import-untyped]

        # Decoding is CPU-bound; run in a worker thread to keep the loop free.
        audio: AudioSegment = await anyio.to_thread.run_sync(
            lambda: AudioSegment.from_file(input_tmp.path)
        )
        duration_ms = len(audio)
        non_silent = await anyio.to_thread.run_sync(
            lambda: _compute_non_silent_segments(
                local_audio_path=input_tmp.path,
                min_silence_duration_ms=min_silence_duration_ms,
                silence_threshold_db=silence_threshold_db,
                padding_ms=padding_ms,
            )
        )
        # Invert non-silent segments to get silent segments: every gap between
        # consecutive non-silent spans (and before/after them) is silence.
        silent_segments: list[dict[str, int]] = []
        prev_end = 0
        for start_ms, end_ms in non_silent:
            if start_ms > prev_end:
                silent_segments.append({"start_ms": prev_end, "end_ms": start_ms})
            prev_end = end_ms
        if prev_end < duration_ms:
            silent_segments.append({"start_ms": prev_end, "end_ms": duration_ms})
        return {
            "silent_segments": silent_segments,
            "duration_ms": duration_ms,
            "file_key": file_key,
        }
    finally:
        # Always remove the downloaded temp file, even on decode failure.
        input_tmp.cleanup()
async def apply_silence_cuts(
    storage: StorageService,
    *,
    file_key: str,
    out_folder: str,
    cuts: list[dict],
    output_name: str | None = None,
) -> FileInfo:
    """Apply explicit cut regions to a media file, concatenating the non-cut parts.

    Args:
        storage: storage backend used for download/upload.
        file_key: S3 key of the source media.
        out_folder: destination prefix; the result is written under
            ``<out_folder>/silent/<name>``.
        cuts: list of ``{"start_ms": int, "end_ms": int}`` regions to remove;
            may be unsorted and overlapping.
        output_name: optional output file name; defaults to the source basename.

    Returns:
        ``FileInfo`` for the uploaded result, or for the original file when
        the cuts cover the whole duration (nothing left to keep).

    Raises:
        RuntimeError: if the ffmpeg re-encode fails.
    """
    input_tmp = await storage.download_to_temp(file_key)
    try:
        from pydub import AudioSegment  # type: ignore[import-untyped]

        # Decode only to learn the total duration; the actual cutting is ffmpeg.
        audio: AudioSegment = await anyio.to_thread.run_sync(
            lambda: AudioSegment.from_file(input_tmp.path)
        )
        duration_ms = len(audio)
        # Sort cuts and compute non-cut (keep) segments
        sorted_cuts = sorted(cuts, key=lambda c: c["start_ms"])
        segments: list[tuple[int, int]] = []
        prev_end = 0
        for cut in sorted_cuts:
            cut_start = max(0, cut["start_ms"])
            cut_end = min(duration_ms, cut["end_ms"])
            if cut_start > prev_end:
                segments.append((prev_end, cut_start))
            # max() keeps prev_end monotonic even when cuts overlap.
            prev_end = max(prev_end, cut_end)
        if prev_end < duration_ms:
            segments.append((prev_end, duration_ms))
        if not segments:
            # Cuts cover the entire file; nothing to render — return the source.
            return await storage.get_file_info(file_key)
        # Reserve a temp output path; ffmpeg writes to it after the handle closes.
        with NamedTemporaryFile(
            suffix=path.splitext(file_key)[1] or ".mp4", delete=False
        ) as out:
            out_path = out.name
        try:
            # One "-ss/-t ... -i" input per keep-segment, then concat them all.
            cmd: list[str] = ["ffmpeg"]
            for start_ms, end_ms in segments:
                start_s = start_ms / 1000.0
                duration_s = (end_ms - start_ms) / 1000.0
                cmd.extend(
                    [
                        "-ss",
                        f"{start_s:.3f}",
                        "-t",
                        f"{duration_s:.3f}",
                        "-y",
                        "-i",
                        input_tmp.path,
                    ]
                )
            seg_count = len(segments)
            # assumes every input has both a video and an audio stream — TODO confirm
            parts = [f"[{i}:v:0][{i}:a:0]" for i in range(seg_count)]
            filter_complex = "".join(parts) + f"concat=n={seg_count}:v=1:a=1[v][a]"
            cmd.extend(
                [
                    "-filter_complex",
                    filter_complex,
                    "-map",
                    "[v]",
                    "-map",
                    "[a]",
                    "-c:v",
                    "libx264",
                    "-c:a",
                    "aac",
                    "-preset",
                    "medium",
                    out_path,
                ]
            )
            proc = await asyncio.create_subprocess_exec(
                *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )
            _, stderr = await proc.communicate()
            if proc.returncode != 0:
                raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")
            base_name = output_name or path.basename(file_key)
            output_key = path.join(out_folder or "", "silent", base_name)
            with open(out_path, "rb") as out_file:
                _ = await storage.upload_fileobj(
                    fileobj=out_file,
                    file_name=path.basename(output_key),
                    folder=path.dirname(output_key),
                    gen_name=False,
                    content_type="video/mp4",
                )
            return await storage.get_file_info(output_key)
        finally:
            # Remove the rendered temp file regardless of upload success.
            import os

            if os.path.exists(out_path):
                os.remove(out_path)
    finally:
        input_tmp.cleanup()
async def remove_silence(
storage: StorageService,
*,
@@ -264,3 +433,131 @@ async def convert_to_ogg_temp(
_ = filename_without_ext
return out_path, _cleanup
async def extract_frames(
    storage: StorageService,
    *,
    file_key: str,
    frames_folder: str,
    on_progress: Callable[[int, int], None] | None = None,
) -> FrameSpriteMetadata:
    """Extract video frames at 1fps via ffmpeg and upload to S3.

    Also writes a ``meta.json`` alongside the frames for fast lookup.
    Returns metadata about the extracted frames.

    Args:
        storage: storage backend for download/upload.
        file_key: S3 key of the source video.
        frames_folder: destination S3 folder for the frame JPEGs and meta.json.
        on_progress: optional synchronous callback ``(uploaded, total)``
            invoked after each frame upload.

    Raises:
        RuntimeError: if ffmpeg fails or produces no frames.
    """
    input_tmp = await storage.download_to_temp(file_key)
    tmp_dir = mkdtemp(prefix="frames_")
    try:
        # One frame per second, scaled to a fixed width ("-1" keeps aspect ratio).
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            input_tmp.path,
            "-vf",
            f"fps={FRAME_FPS},scale={FRAME_WIDTH_PX}:-1",
            "-q:v",
            str(FRAME_JPEG_QUALITY),
            path.join(tmp_dir, "%06d.jpg"),
        ]
        proc = await asyncio.create_subprocess_exec(
            *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"ffmpeg frame extraction failed: {stderr.decode(errors='ignore')}")
        frame_files = sorted(glob_mod.glob(path.join(tmp_dir, "*.jpg")))
        frame_count = len(frame_files)
        if frame_count == 0:
            raise RuntimeError("No frames extracted from video")
        # Read first frame dimensions via ffprobe (avoids PIL dependency)
        probe_proc = await asyncio.create_subprocess_exec(
            "ffprobe",
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=width,height",
            "-of", "json",
            frame_files[0],
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        probe_stdout, _ = await probe_proc.communicate()
        probe_data = json.loads(probe_stdout.decode())
        stream = probe_data.get("streams", [{}])[0]
        # Fall back to the configured width if ffprobe reports nothing.
        width = stream.get("width", FRAME_WIDTH_PX)
        height = stream.get("height", FRAME_WIDTH_PX)
        # Upload each frame to S3
        for idx, frame_path in enumerate(frame_files):
            frame_name = path.basename(frame_path)
            with open(frame_path, "rb") as f:
                await storage.upload_fileobj(
                    fileobj=f,
                    file_name=frame_name,
                    folder=frames_folder,
                    gen_name=False,
                    content_type="image/jpeg",
                )
            if on_progress is not None:
                # NOTE(review): callback runs on the event-loop thread — keep it cheap.
                on_progress(idx + 1, frame_count)
        metadata = FrameSpriteMetadata(
            frame_count=frame_count,
            interval=1.0 / FRAME_FPS,
            width=width,
            height=height,
            folder_key=frames_folder,
            source_file_key=file_key,
        )
        # Write metadata JSON to S3 for fast lookup by the frames endpoint
        meta_bytes = json.dumps(metadata.model_dump(mode="json")).encode("utf-8")
        await storage.upload_fileobj(
            fileobj=io.BytesIO(meta_bytes),
            file_name=FRAMES_META_FILENAME,
            folder=frames_folder,
            gen_name=False,
            content_type="application/json",
        )
        return metadata
    finally:
        # Remove both the downloaded source and the local frames directory.
        import shutil

        input_tmp.cleanup()
        shutil.rmtree(tmp_dir, ignore_errors=True)
async def read_frames_metadata(
    storage: StorageService, *, frames_folder: str
) -> FrameSpriteMetadata | None:
    """Read frame extraction metadata from S3. Returns None if not found.

    ``None`` means no extraction has completed for this folder yet; callers
    use it to distinguish "not extracted" from "extracted but empty".
    """
    meta_key = path.join(frames_folder, FRAMES_META_FILENAME)
    # Existence check first so a missing meta.json is a clean None, not an error.
    if not await storage.exists(meta_key):
        return None
    raw = await storage.read(meta_key)
    return FrameSpriteMetadata.model_validate(json.loads(raw))
async def delete_frames(
    storage: StorageService, *, frames_folder: str, frame_count: int
) -> None:
    """Delete all frame files and metadata from S3 for a given folder.

    Best-effort cleanup: individual delete failures are deliberately
    swallowed so one missing or failing key does not abort deletion of the
    remaining frames.
    """
    # Frame files are numbered 1..frame_count as %06d.jpg (see extract_frames).
    for i in range(1, frame_count + 1):
        key = path.join(frames_folder, f"{i:06d}.jpg")
        try:
            await storage.delete(key)
        except Exception:
            pass
    # Delete metadata file
    meta_key = path.join(frames_folder, FRAMES_META_FILENAME)
    try:
        await storage.delete(meta_key)
    except Exception:
        pass
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
import uuid
from sqlalchemy import Boolean, ForeignKey, JSON, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from cpv3.db.base import Base, BaseModelMixin
class Notification(Base, BaseModelMixin):
    """A user-facing notification, optionally linked to a job and/or project."""

    __tablename__ = "notifications"

    # Recipient; notifications are removed with the user (CASCADE).
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    # Originating job; reference is nulled if the job row is deleted.
    job_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True), ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True
    )
    # Related project; reference is nulled if the project row is deleted.
    project_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("projects.id", ondelete="SET NULL"),
        nullable=True,
    )
    # e.g. "task_progress" / "task_complete" / "task_failed"
    notification_type: Mapped[str] = mapped_column(String(32))
    title: Mapped[str] = mapped_column(String(255))
    message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Free-form extra data (e.g. job_type, progress_pct, status).
    payload: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    is_read: Mapped[bool] = mapped_column(Boolean, default=False)
+79
View File
@@ -0,0 +1,79 @@
from __future__ import annotations
import uuid
from sqlalchemy import Select, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.modules.notifications.models import Notification
from cpv3.modules.notifications.schemas import NotificationCreate
class NotificationRepository:
    """Data-access layer for ``Notification`` rows.

    All mutating methods commit the session themselves.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def list_for_user(
        self,
        user_id: uuid.UUID,
        *,
        limit: int = 50,
        unread_only: bool = False,
    ) -> list[Notification]:
        """Return the user's active notifications, newest first, capped at *limit*."""
        stmt: Select[tuple[Notification]] = (
            select(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_active.is_(True))
            .order_by(Notification.created_at.desc())
            .limit(limit)
        )
        if unread_only:
            stmt = stmt.where(Notification.is_read.is_(False))
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def create(self, data: NotificationCreate) -> Notification:
        """Insert a notification and return the refreshed row."""
        notification = Notification(
            user_id=data.user_id,
            job_id=data.job_id,
            project_id=data.project_id,
            notification_type=data.notification_type,
            title=data.title,
            message=data.message,
            payload=data.payload,
        )
        self._session.add(notification)
        await self._session.commit()
        await self._session.refresh(notification)
        return notification

    async def mark_read(self, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
        """Mark one of the user's notifications read; True if a row matched.

        NOTE(review): unlike list_for_user/count_unread this does not filter
        on ``is_active`` — confirm soft-deleted rows should still be markable.
        """
        result = await self._session.execute(
            update(Notification)
            .where(Notification.id == notification_id)
            .where(Notification.user_id == user_id)
            .values(is_read=True)
        )
        await self._session.commit()
        return result.rowcount > 0  # type: ignore[union-attr]

    async def mark_all_read(self, user_id: uuid.UUID) -> int:
        """Mark every unread notification of the user read; returns rows updated."""
        result = await self._session.execute(
            update(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_read.is_(False))
            .values(is_read=True)
        )
        await self._session.commit()
        return result.rowcount  # type: ignore[return-value]

    async def count_unread(self, user_id: uuid.UUID) -> int:
        """Count the user's active, unread notifications."""
        result = await self._session.execute(
            select(func.count())
            .select_from(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_active.is_(True))
            .where(Notification.is_read.is_(False))
        )
        return result.scalar_one()
+102
View File
@@ -0,0 +1,102 @@
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
from jwt import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.db.session import SessionLocal, get_db
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.security import decode_token
from cpv3.modules.notifications.repository import NotificationRepository
from cpv3.modules.notifications.schemas import NotificationRead
from cpv3.modules.notifications.service import subscribe_and_forward
from cpv3.modules.users.models import User
from cpv3.modules.users.repository import UserRepository
router = APIRouter(prefix="/api/notifications", tags=["notifications"])
@router.websocket("/ws/")
async def notifications_ws(
    websocket: WebSocket,
    token: str = Query(...),
) -> None:
    """WebSocket endpoint for real-time notifications.

    Authenticates with a JWT access token passed as a query parameter
    (browser WebSocket clients cannot set headers), then forwards the user's
    Redis pub/sub channel to the socket until disconnect. Every auth failure
    closes the socket with policy-violation code 1008.
    """
    try:
        payload = decode_token(token)
    except (ExpiredSignatureError, InvalidTokenError):
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    # Reject non-access (e.g. refresh) tokens.
    if payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    sub = payload.get("sub")
    if not sub:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    try:
        user_id = uuid.UUID(str(sub))
    except ValueError:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    # Short-lived session just for the user lookup; it is not held open for
    # the lifetime of the WebSocket connection.
    async with SessionLocal() as db:
        user_repo = UserRepository(db)
        user = await user_repo.get_by_id(user_id)
        if user is None or not user.is_active:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return
    await websocket.accept()
    try:
        await subscribe_and_forward(websocket, user_id)
    except WebSocketDisconnect:
        # Normal client disconnect — nothing to clean up here.
        pass
@router.get("/", response_model=list[NotificationRead])
async def list_notifications(
    unread_only: bool = Query(False),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[NotificationRead]:
    """List the current user's notifications, optionally unread only.

    Results are newest-first and capped by the repository's default limit.
    """
    repo = NotificationRepository(db)
    items = await repo.list_for_user(current_user.id, unread_only=unread_only)
    return [NotificationRead.model_validate(n) for n in items]
@router.post("/{notification_id}/read/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_notification_read(
    notification_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark one of the current user's notifications as read.

    Returns 404 when the notification does not exist or belongs to another
    user (the repository scopes the update to ``current_user.id``).
    """
    repo = NotificationRepository(db)
    found = await repo.mark_read(notification_id, current_user.id)
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
@router.post("/read-all/", status_code=status.HTTP_204_NO_CONTENT)
async def mark_all_read(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Mark all of the current user's unread notifications as read."""
    repo = NotificationRepository(db)
    await repo.mark_all_read(current_user.id)
@router.get("/unread-count/")
async def unread_count(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict[str, int]:
    """Return the number of active, unread notifications for the current user."""
    repo = NotificationRepository(db)
    count = await repo.count_unread(current_user.id)
    return {"count": count}
+53
View File
@@ -0,0 +1,53 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from uuid import UUID
from cpv3.common.schemas import Schema
NotificationTypeEnum = Literal["task_progress", "task_complete", "task_failed"]
class NotificationCreate(Schema):
    """Input payload for persisting a notification."""

    user_id: UUID  # recipient
    job_id: UUID | None = None  # originating job, if any
    project_id: UUID | None = None  # related project, if any
    notification_type: NotificationTypeEnum
    title: str
    message: str | None = None
    payload: dict | None = None  # free-form extra data (progress, status, ...)
class NotificationRead(Schema):
    """API representation of a stored notification."""

    id: UUID
    user_id: UUID
    job_id: UUID | None
    project_id: UUID | None
    notification_type: NotificationTypeEnum
    title: str
    message: str | None
    payload: dict | None
    is_read: bool
    created_at: datetime
    updated_at: datetime
class NotificationUpdate(Schema):
    """Partial update; only the read flag is mutable through the API."""

    is_read: bool | None = None
class WebSocketMessage(Schema):
    """JSON shape pushed over WebSocket.

    All fields except ``event`` are optional so the same shape can carry both
    persisted notifications and progress-only updates.
    """

    event: str  # e.g. "task_update"
    notification_id: UUID | None = None  # None for progress-only pushes
    job_id: UUID | None = None
    project_id: UUID | None = None
    job_type: str | None = None
    status: str | None = None
    progress_pct: float | None = None
    message: str | None = None
    title: str | None = None
    created_at: datetime | None = None
+152
View File
@@ -0,0 +1,152 @@
from __future__ import annotations
import json
import logging
import uuid
import redis.asyncio as aioredis
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.db.base import utcnow
from cpv3.infrastructure.settings import get_settings
from cpv3.modules.jobs.models import Job
from cpv3.modules.notifications.repository import NotificationRepository
from cpv3.modules.notifications.schemas import (
NotificationCreate,
NotificationTypeEnum,
WebSocketMessage,
)
from cpv3.modules.tasks.schemas import TaskWebhookEvent
logger = logging.getLogger(__name__)
# Human-readable (Russian) labels per job type, used as notification titles.
# NOTE(review): FRAME_EXTRACT has no entry — callers fall back to the raw
# job_type string; confirm whether that is intended.
JOB_TYPE_LABELS: dict[str, str] = {
    "MEDIA_PROBE": "Анализ медиа",
    "SILENCE_REMOVE": "Удаление тишины",
    "SILENCE_DETECT": "Обнаружение тишины",
    "SILENCE_APPLY": "Применение вырезок",
    "MEDIA_CONVERT": "Конвертация",
    "TRANSCRIPTION_GENERATE": "Транскрипция",
    "CAPTIONS_GENERATE": "Генерация субтитров",
}

# Notification titles per job status; statuses absent here fall back to the
# job-type label.
STATUS_TITLES: dict[str, str] = {
    "RUNNING": "Задача запущена",
    "DONE": "Задача завершена",
    "FAILED": "Ошибка выполнения",
}
# ---------------------------------------------------------------------------
# ConnectionManager — singleton for WebSocket pub/sub via Redis
# ---------------------------------------------------------------------------
# Lazily-created, module-wide Redis client shared by publishers and subscribers.
_redis_client: aioredis.Redis | None = None


async def _get_redis() -> aioredis.Redis:
    """Return the shared Redis client, creating it on first use.

    NOTE(review): the lazy init is not guarded by a lock, so two concurrent
    first calls could each create a client (one of them leaked) — confirm
    whether this matters for the deployment.
    """
    global _redis_client
    if _redis_client is None:
        settings = get_settings()
        _redis_client = aioredis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
def _channel_name(user_id: uuid.UUID) -> str:
return f"notifications:{user_id}"
async def publish_to_user(user_id: uuid.UUID, message: WebSocketMessage) -> None:
    """Publish a notification message to a user's Redis channel.

    Any server process subscribed to that channel (see
    ``subscribe_and_forward``) relays the JSON payload to the user's open
    WebSocket connections.
    """
    redis = await _get_redis()
    payload = message.model_dump_json()
    await redis.publish(_channel_name(user_id), payload)
async def subscribe_and_forward(websocket: object, user_id: uuid.UUID) -> None:
    """Subscribe to a user's Redis channel and forward messages to WebSocket.

    ``websocket`` must be a ``fastapi.WebSocket`` instance — typed as ``object``
    to avoid importing FastAPI at module level.

    Blocks until the pub/sub stream ends or the socket send raises (e.g. on
    client disconnect); the Redis subscription is always released.
    """
    from fastapi import WebSocket as _WS

    ws: _WS = websocket  # type: ignore[assignment]
    redis = await _get_redis()
    pubsub = redis.pubsub()
    await pubsub.subscribe(_channel_name(user_id))
    try:
        async for raw_message in pubsub.listen():
            # listen() also yields subscribe/unsubscribe control events; skip them.
            if raw_message["type"] != "message":
                continue
            await ws.send_text(raw_message["data"])
    finally:
        await pubsub.unsubscribe(_channel_name(user_id))
        await pubsub.aclose()
# ---------------------------------------------------------------------------
# NotificationService
# ---------------------------------------------------------------------------
class NotificationService:
    """Creates persisted notifications for job events and pushes them over WS."""

    def __init__(self, session: AsyncSession) -> None:
        self._repo = NotificationRepository(session)

    async def create_task_notification(
        self, *, user_id: uuid.UUID, job: Job, event: TaskWebhookEvent
    ) -> None:
        """Create a notification for a task status change and push via WebSocket.

        Status-change events (RUNNING / DONE / FAILED) are persisted through
        the repository; progress-only events are not stored but are still
        pushed to the user's channel. Publish failures are logged, never
        raised — notification delivery must not break webhook handling.
        """
        notification_type: NotificationTypeEnum | None = None
        if event.status == "RUNNING":
            notification_type = "task_progress"
        elif event.status == "DONE":
            notification_type = "task_complete"
        elif event.status == "FAILED":
            notification_type = "task_failed"
        job_type_label = JOB_TYPE_LABELS.get(job.job_type, job.job_type)
        now = utcnow()
        # Only persist notifications on status changes (not progress-only updates)
        notification_id: uuid.UUID | None = None
        if notification_type is not None:
            title = STATUS_TITLES.get(event.status or "", job_type_label)
            notification = await self._repo.create(
                NotificationCreate(
                    user_id=user_id,
                    job_id=job.id,
                    project_id=job.project_id,
                    notification_type=notification_type,
                    title=title,
                    message=event.error_message or event.current_message,
                    payload={
                        "job_type": job.job_type,
                        "progress_pct": event.progress_pct,
                        "status": event.status,
                    },
                )
            )
            notification_id = notification.id
        # Always push a WebSocket message (including progress-only updates).
        # Explicit None checks: a plain ``or`` would discard a legitimate 0.0%
        # progress (falsy) and replace it with the stale value from the job row.
        # NOTE(review): ``job.project_pct`` looks like a typo for
        # ``job.progress_pct`` — confirm against the Job model.
        progress = event.progress_pct if event.progress_pct is not None else job.project_pct
        current_status = event.status if event.status is not None else job.status
        ws_message = WebSocketMessage(
            event="task_update",
            notification_id=notification_id,
            job_id=job.id,
            project_id=job.project_id,
            job_type=job.job_type,
            status=current_status,
            progress_pct=progress,
            message=event.error_message or event.current_message or job.current_message,
            title=job_type_label,
            created_at=now,
        )
        try:
            await publish_to_user(user_id, ws_message)
        except Exception:
            logger.exception("Failed to publish WebSocket notification for user %s", user_id)
+2 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
from sqlalchemy import ForeignKey, String, Text
from sqlalchemy import ForeignKey, JSON, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
@@ -23,3 +23,4 @@ class Project(Base, BaseModelMixin):
language: Mapped[str] = mapped_column(String(4), default="auto")
folder: Mapped[str | None] = mapped_column(String(1024), nullable=True)
status: Mapped[str] = mapped_column(String(16), default="DRAFT")
workspace_state: Mapped[dict | None] = mapped_column(JSON, nullable=True)
+25 -5
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
from sqlalchemy import Select, select
from sqlalchemy import Select, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.modules.projects.models import Project
@@ -16,13 +16,31 @@ class ProjectRepository:
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def list_all(self, *, requester: User) -> list[Project]:
async def list_all(
self,
*,
requester: User,
search: str | None = None,
status: str | None = None,
) -> list[Project]:
stmt: Select[tuple[Project]] = select(Project).where(
Project.is_active.is_(True)
)
if not requester.is_staff:
stmt = stmt.where(Project.owner_id == requester.id)
if search:
pattern = f"%{search}%"
stmt = stmt.where(
or_(
Project.name.ilike(pattern),
Project.description.ilike(pattern),
)
)
if status:
stmt = stmt.where(Project.status == status)
result = await self._session.execute(stmt.order_by(Project.created_at.desc()))
return list(result.scalars().all())
@@ -34,14 +52,16 @@ class ProjectRepository:
)
return result.scalar_one_or_none()
async def create(self, *, requester: User, data: ProjectCreate) -> Project:
async def create(
self, *, requester: User, data: ProjectCreate, folder: str, status: str,
) -> Project:
project = Project(
owner_id=requester.id,
name=data.name,
description=data.description,
language=data.language,
folder=data.folder,
status=data.status,
folder=folder,
status=status,
)
self._session.add(project)
+8 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.auth import get_current_user
@@ -16,11 +16,17 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
@router.get("/", response_model=list[ProjectRead])
async def list_all_projects(
search: str | None = Query(None, description="Поиск по названию или описанию"),
status_filter: str | None = Query(
None, alias="status", description="Фильтр по статусу проекта"
),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> list[ProjectRead]:
service = ProjectService(db)
projects = await service.list_projects(requester=current_user)
projects = await service.list_projects(
requester=current_user, search=search, status=status_filter,
)
return [ProjectRead.model_validate(p) for p in projects]
+3 -2
View File
@@ -19,6 +19,8 @@ class ProjectRead(Schema):
folder: str | None
status: ProjectStatusEnum
workspace_state: dict | None
is_active: bool
created_at: datetime
updated_at: datetime
@@ -28,8 +30,6 @@ class ProjectCreate(Schema):
name: str
description: str | None = None
language: str = "auto"
folder: str | None = None
status: ProjectStatusEnum = "DRAFT"
class ProjectUpdate(Schema):
@@ -38,3 +38,4 @@ class ProjectUpdate(Schema):
language: str | None = None
folder: str | None = None
status: ProjectStatusEnum | None = None
workspace_state: dict | None = None
+14 -3
View File
@@ -16,14 +16,25 @@ class ProjectService:
def __init__(self, session: AsyncSession) -> None:
self._repo = ProjectRepository(session)
async def list_projects(self, *, requester: User) -> list[Project]:
return await self._repo.list_all(requester=requester)
async def list_projects(
self,
*,
requester: User,
search: str | None = None,
status: str | None = None,
) -> list[Project]:
return await self._repo.list_all(
requester=requester, search=search, status=status,
)
async def get_project(self, project_id: uuid.UUID) -> Project | None:
return await self._repo.get_by_id(project_id)
async def create_project(self, *, requester: User, data: ProjectCreate) -> Project:
return await self._repo.create(requester=requester, data=data)
folder = f"/{requester.username}/{data.name}"
return await self._repo.create(
requester=requester, data=data, folder=folder, status="DRAFT",
)
async def update_project(self, project: Project, data: ProjectUpdate) -> Project:
return await self._repo.update(project, data)
+48
View File
@@ -15,8 +15,11 @@ from cpv3.infrastructure.auth import get_current_user
from cpv3.modules.jobs.service import JobService
from cpv3.modules.tasks.schemas import (
CaptionsGenerateRequest,
FrameExtractRequest,
MediaConvertRequest,
MediaProbeRequest,
SilenceApplyRequest,
SilenceDetectRequest,
SilenceRemoveRequest,
TaskStatusEnum,
TaskStatusResponse,
@@ -61,6 +64,36 @@ async def submit_silence_remove(
return await service.submit_silence_remove(requester=current_user, request=body)
@router.post(
    "/silence-detect/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_silence_detect(
    body: SilenceDetectRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to detect silent segments in media file."""
    # Thin delegation: the service owns job creation and actor dispatch.
    return await TaskService(db).submit_silence_detect(
        requester=current_user, request=body
    )
@router.post(
    "/silence-apply/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_silence_apply(
    body: SilenceApplyRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to apply silence cuts to media file."""
    # Thin delegation: the service owns job creation and actor dispatch.
    return await TaskService(db).submit_silence_apply(
        requester=current_user, request=body
    )
@router.post(
"/media-convert/",
response_model=TaskSubmitResponse,
@@ -93,6 +126,21 @@ async def submit_transcription_generate(
)
@router.post(
    "/frame-extract/",
    response_model=TaskSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def submit_frame_extract(
    body: FrameExtractRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TaskSubmitResponse:
    """Submit a background task to extract video frames for timeline thumbnails."""
    # Thin delegation: the service owns job creation and actor dispatch.
    return await TaskService(db).submit_frame_extract(
        requester=current_user, request=body
    )
@router.post(
"/captions-generate/",
response_model=TaskSubmitResponse,
+40
View File
@@ -45,6 +45,36 @@ class SilenceRemoveRequest(Schema):
)
class SilenceDetectRequest(Schema):
    """Request to detect silent segments in media file.

    Detection tuning defaults are declared on the fields below; the
    detection itself happens in a background worker.
    """

    # Storage key of the media file to analyse.
    file_key: str = Field(..., description="Storage key of the input file")
    # Optional link back to the owning project.
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    min_silence_duration_ms: int = Field(
        default=200, description="Minimum silence duration in milliseconds"
    )
    silence_threshold_db: int = Field(
        default=16, description="Silence threshold in decibels"
    )
    padding_ms: int = Field(
        default=100, description="Padding around non-silent segments in milliseconds"
    )
class SilenceApplyRequest(Schema):
    """Request to apply silence cuts to media file."""

    # Storage key of the source media file.
    file_key: str = Field(..., description="Storage key of the input file")
    # Folder is later prefixed with the requester's user folder by the service.
    out_folder: str = Field(..., description="Output folder for processed file")
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    output_name: str | None = Field(
        default=None, description="Display name for the output file"
    )
    # Regions to remove, as millisecond offsets.
    cuts: list[dict] = Field(
        ..., description="Cut regions: [{'start_ms': int, 'end_ms': int}, ...]"
    )
class MediaConvertRequest(Schema):
"""Request to convert media file to different format."""
@@ -75,6 +105,16 @@ class CaptionsGenerateRequest(Schema):
project_id: UUID | None = Field(default=None, description="Associated project ID")
class FrameExtractRequest(Schema):
    """Request to extract video frames for timeline thumbnails."""

    file_key: str = Field(..., description="S3 key of the video file")
    project_id: UUID | None = Field(default=None, description="Associated project ID")
    # When True the worker deletes previously extracted frames before extracting.
    regenerate: bool = Field(
        default=False, description="Delete existing frames and re-extract"
    )
# --- Response schemas ---
+624 -9
View File
@@ -5,8 +5,12 @@ Task service for submitting and managing background tasks.
from __future__ import annotations
import asyncio
import io
import json
import logging
import time
import uuid
from pathlib import Path
from datetime import datetime, timezone
from typing import Any
@@ -17,6 +21,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.deps import _get_storage_service
from cpv3.infrastructure.settings import get_settings
from cpv3.modules.files.repository import FileRepository
from cpv3.modules.files.schemas import FileCreate
from cpv3.modules.jobs.models import Job
from cpv3.modules.jobs.repository import JobEventRepository, JobRepository
from cpv3.modules.jobs.schemas import (
@@ -26,17 +32,26 @@ from cpv3.modules.jobs.schemas import (
JobTypeEnum,
JobUpdate,
)
from cpv3.modules.media.repository import ArtifactRepository
from cpv3.modules.media.schemas import ArtifactMediaFileCreate
from cpv3.modules.tasks.schemas import (
CaptionsGenerateRequest,
FrameExtractRequest,
MediaConvertRequest,
MediaProbeRequest,
SilenceApplyRequest,
SilenceDetectRequest,
SilenceRemoveRequest,
TaskSubmitResponse,
TaskWebhookEvent,
TranscriptionGenerateRequest,
)
from cpv3.infrastructure.storage.utils import get_user_folder
from cpv3.modules.notifications.service import NotificationService
from cpv3.modules.transcription.repository import TranscriptionRepository
from cpv3.modules.transcription.schemas import TranscriptionCreate
from cpv3.modules.users.models import User
from cpv3.modules.users.repository import UserRepository
from cpv3.modules.webhooks.repository import WebhookRepository
from cpv3.modules.webhooks.schemas import WebhookCreate
@@ -49,9 +64,12 @@ JOB_STATUS_FAILED: JobStatusEnum = "FAILED"
JOB_TYPE_MEDIA_PROBE: JobTypeEnum = "MEDIA_PROBE"
JOB_TYPE_SILENCE_REMOVE: JobTypeEnum = "SILENCE_REMOVE"
JOB_TYPE_SILENCE_DETECT: JobTypeEnum = "SILENCE_DETECT"
JOB_TYPE_SILENCE_APPLY: JobTypeEnum = "SILENCE_APPLY"
JOB_TYPE_MEDIA_CONVERT: JobTypeEnum = "MEDIA_CONVERT"
JOB_TYPE_TRANSCRIPTION_GENERATE: JobTypeEnum = "TRANSCRIPTION_GENERATE"
JOB_TYPE_CAPTIONS_GENERATE: JobTypeEnum = "CAPTIONS_GENERATE"
JOB_TYPE_FRAME_EXTRACT: JobTypeEnum = "FRAME_EXTRACT"
EVENT_TYPE_STATUS_PREFIX = "status_"
EVENT_TYPE_PROGRESS = "progress"
@@ -62,19 +80,41 @@ EVENT_TYPE_ERROR = "error"
TASK_WEBHOOK_PATH = "/api/tasks/webhook/{job_id}/"
WEBHOOK_TIMEOUT_SECONDS = 10
ERROR_NO_AUDIO_STREAM = "Файл не содержит аудиодорожки"
ERROR_UNKNOWN_ENGINE = "Неизвестный движок транскрипции: {engine}"
ENGINE_MAP: dict[str, str] = {
"whisper": "LOCAL_WHISPER",
"google": "GOOGLE_SPEECH_CLOUD",
}
MESSAGE_STARTING = "Starting"
MESSAGE_COMPLETED = "Completed"
MESSAGE_PROBING_MEDIA = "Probing media"
MESSAGE_PROCESSING = "Processing"
MESSAGE_CONVERTING = "Converting"
MESSAGE_RENDERING_CAPTIONS = "Rendering captions"
MESSAGE_EXTRACTING_FRAMES = "Извлечение кадров"
MESSAGE_UPLOADING_FRAMES = "Загрузка кадров"
MESSAGE_DELETING_OLD_FRAMES = "Удаление старых кадров"
PROGRESS_COMPLETE = 100.0
PROGRESS_MEDIA_PROBE = 50.0
PROGRESS_SILENCE_REMOVE = 30.0
PROGRESS_MEDIA_CONVERT = 30.0
PROGRESS_TRANSCRIPTION = 20.0
PROGRESS_TRANSCRIPTION_START = 20.0
PROGRESS_TRANSCRIPTION_END = 95.0
PROGRESS_CAPTIONS = 30.0
PROGRESS_FRAME_EXTRACT_START = 10.0
PROGRESS_FRAME_EXTRACT_END = 95.0
PROGRESS_SILENCE_DETECT = 30.0
PROGRESS_SILENCE_APPLY = 30.0
MESSAGE_DETECTING_SILENCE = "Обнаружение тишины"
MESSAGE_APPLYING_CUTS = "Применение вырезок"
PROGRESS_THROTTLE_SECONDS = 3.0
# ---------------------------------------------------------------------------
# Dramatiq broker setup
@@ -95,6 +135,18 @@ def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _parse_frame_rate(rate_str: str) -> float | None:
"""Parse ffprobe frame rate string like '30/1' or '30000/1001'."""
try:
if "/" in rate_str:
num, den = rate_str.split("/")
den_val = int(den)
return round(int(num) / den_val, 3) if den_val else None
return float(rate_str)
except (ValueError, ZeroDivisionError):
return None
def _build_webhook_url(job_id: uuid.UUID) -> str:
"""Build the internal webhook URL for task updates."""
settings = get_settings()
@@ -267,6 +319,136 @@ def silence_remove_actor(
raise
@dramatiq.actor(max_retries=3, min_backoff=1000)
def silence_detect_actor(
    job_id: str,
    webhook_url: str,
    file_key: str,
    min_silence_duration_ms: int,
    silence_threshold_db: int,
    padding_ms: int,
) -> None:
    """Detect silent segments in media file.

    Runs in a Dramatiq worker. Lifecycle is reported via webhook events:
    RUNNING at start, one progress ping before detection, then DONE with
    the detection result as output_data, or FAILED with the error text.
    On failure the exception is re-raised so Dramatiq's retry policy
    (max_retries=3) applies.
    """
    # Local import — presumably defers heavy media dependencies to worker time;
    # confirm against the project's import conventions.
    from cpv3.modules.media.service import detect_silence

    job_uuid = uuid.UUID(job_id)
    _send_webhook_event(
        webhook_url,
        TaskWebhookEvent(
            status=JOB_STATUS_RUNNING,
            current_message=MESSAGE_STARTING,
            started_at=_utc_now(),
        ),
    )
    try:
        storage = _get_storage_service()
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                current_message=MESSAGE_DETECTING_SILENCE,
                progress_pct=PROGRESS_SILENCE_DETECT,
            ),
        )
        # detect_silence is a coroutine; actors are sync, so drive it here.
        result = _run_async(
            detect_silence(
                storage,
                file_key=file_key,
                min_silence_duration_ms=min_silence_duration_ms,
                silence_threshold_db=silence_threshold_db,
                padding_ms=padding_ms,
            )
        )
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_DONE,
                current_message=MESSAGE_COMPLETED,
                progress_pct=PROGRESS_COMPLETE,
                output_data=result,
                finished_at=_utc_now(),
            ),
        )
    except Exception as exc:
        logger.exception("silence_detect_actor failed: %s", job_uuid)
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_FAILED,
                error_message=str(exc),
                finished_at=_utc_now(),
            ),
        )
        # Re-raise so Dramatiq can retry the message.
        raise
@dramatiq.actor(max_retries=3, min_backoff=1000)
def silence_apply_actor(
    job_id: str,
    webhook_url: str,
    file_key: str,
    out_folder: str,
    cuts: list[dict],
    output_name: str | None,
) -> None:
    """Apply silence cuts to media file.

    Runs in a Dramatiq worker. Reports RUNNING at start, one progress ping,
    then DONE with file_path/file_url/file_size in output_data, or FAILED
    with the error text. Re-raises on failure so Dramatiq retries apply.
    """
    # Local import — presumably defers heavy media dependencies to worker time.
    from cpv3.modules.media.service import apply_silence_cuts

    job_uuid = uuid.UUID(job_id)
    _send_webhook_event(
        webhook_url,
        TaskWebhookEvent(
            status=JOB_STATUS_RUNNING,
            current_message=MESSAGE_STARTING,
            started_at=_utc_now(),
        ),
    )
    try:
        storage = _get_storage_service()
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                current_message=MESSAGE_APPLYING_CUTS,
                progress_pct=PROGRESS_SILENCE_APPLY,
            ),
        )
        # apply_silence_cuts is a coroutine; actors are sync, so drive it here.
        result = _run_async(
            apply_silence_cuts(
                storage,
                file_key=file_key,
                out_folder=out_folder,
                cuts=cuts,
                output_name=output_name,
            )
        )
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_DONE,
                current_message=MESSAGE_COMPLETED,
                progress_pct=PROGRESS_COMPLETE,
                # Flatten the result object into the webhook payload.
                output_data={
                    "file_path": result.file_path,
                    "file_url": result.file_url,
                    "file_size": result.file_size,
                },
                finished_at=_utc_now(),
            ),
        )
    except Exception as exc:
        logger.exception("silence_apply_actor failed: %s", job_uuid)
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_FAILED,
                error_message=str(exc),
                finished_at=_utc_now(),
            ),
        )
        # Re-raise so Dramatiq can retry the message.
        raise
@dramatiq.actor(max_retries=3, min_backoff=1000)
def media_convert_actor(
job_id: str,
@@ -356,19 +538,60 @@ def transcription_generate_actor(
)
try:
from cpv3.modules.media.service import probe_media
storage = _get_storage_service()
probe = _run_async(probe_media(storage, file_key=file_key))
has_audio = any(s.codec_type == "audio" for s in probe.streams)
if not has_audio:
raise ValueError(ERROR_NO_AUDIO_STREAM)
# Extract probe metadata for artifact creation
duration_seconds = float(probe.format.duration) if probe.format and probe.format.duration else 0.0
video_stream = next((s for s in probe.streams if s.codec_type == "video"), None)
probe_meta = {
"duration_seconds": duration_seconds,
"frame_rate": _parse_frame_rate(video_stream.r_frame_rate) if video_stream and video_stream.r_frame_rate else None,
"width": video_stream.width if video_stream else None,
"height": video_stream.height if video_stream else None,
}
_send_webhook_event(
webhook_url,
TaskWebhookEvent(
current_message=f"Transcribing ({engine})",
progress_pct=PROGRESS_TRANSCRIPTION,
current_message=f"Транскрибирование ({engine})",
progress_pct=PROGRESS_TRANSCRIPTION_START,
),
)
last_report_time = time.monotonic()
def _on_whisper_progress(pct: float) -> None:
nonlocal last_report_time
now = time.monotonic()
if now - last_report_time < PROGRESS_THROTTLE_SECONDS:
return
last_report_time = now
mapped = PROGRESS_TRANSCRIPTION_START + (
pct / 100.0
) * (PROGRESS_TRANSCRIPTION_END - PROGRESS_TRANSCRIPTION_START)
_send_webhook_event(
webhook_url,
TaskWebhookEvent(
current_message=f"Транскрибирование ({engine})",
progress_pct=round(mapped, 1),
),
)
if engine == "whisper":
document = _run_async(
transcribe_with_whisper(
storage, file_key=file_key, model_name=model, language=language
storage,
file_key=file_key,
model_name=model,
language=language,
on_progress=_on_whisper_progress,
)
)
elif engine == "google":
@@ -379,7 +602,7 @@ def transcription_generate_actor(
)
)
else:
raise ValueError(f"Unknown engine: {engine}")
raise ValueError(ERROR_UNKNOWN_ENGINE.format(engine=engine))
_send_webhook_event(
webhook_url,
@@ -387,7 +610,22 @@ def transcription_generate_actor(
status=JOB_STATUS_DONE,
current_message=MESSAGE_COMPLETED,
progress_pct=PROGRESS_COMPLETE,
output_data={"document": document.model_dump(mode="json")},
output_data={
"document": document.model_dump(mode="json"),
"probe": probe_meta,
},
finished_at=_utc_now(),
),
)
except (ValueError, RuntimeError) as exc:
logger.exception(
"transcription_generate_actor failed (non-transient): %s", job_uuid
)
_send_webhook_event(
webhook_url,
TaskWebhookEvent(
status=JOB_STATUS_FAILED,
error_message=str(exc),
finished_at=_utc_now(),
),
)
@@ -463,6 +701,115 @@ def captions_generate_actor(
raise
@dramatiq.actor(max_retries=2, min_backoff=2000)
def frame_extract_actor(
    job_id: str,
    webhook_url: str,
    file_key: str,
    frames_folder: str,
    regenerate: bool,
) -> None:
    """Extract video frames at 1fps for timeline thumbnails.

    Runs in a Dramatiq worker. When ``regenerate`` is True, existing frames
    under ``frames_folder`` are deleted first (frame count taken from the
    stored metadata). Progress reports during upload are throttled to at
    most one webhook per PROGRESS_THROTTLE_SECONDS and mapped into the
    [PROGRESS_FRAME_EXTRACT_START, PROGRESS_FRAME_EXTRACT_END] band.
    Re-raises on failure so Dramatiq retries apply.
    """
    # Local import — presumably defers heavy media dependencies to worker time.
    from cpv3.modules.media.service import (
        delete_frames,
        extract_frames,
        read_frames_metadata,
    )

    job_uuid = uuid.UUID(job_id)
    _send_webhook_event(
        webhook_url,
        TaskWebhookEvent(
            status=JOB_STATUS_RUNNING,
            current_message=MESSAGE_STARTING,
            started_at=_utc_now(),
        ),
    )
    try:
        storage = _get_storage_service()

        # Delete old frames if regenerating.
        if regenerate:
            _send_webhook_event(
                webhook_url,
                TaskWebhookEvent(
                    current_message=MESSAGE_DELETING_OLD_FRAMES,
                    progress_pct=PROGRESS_FRAME_EXTRACT_START,
                ),
            )
            old_meta = _run_async(
                read_frames_metadata(storage, frames_folder=frames_folder)
            )
            # No metadata means nothing was extracted before; skip deletion.
            if old_meta is not None:
                _run_async(
                    delete_frames(
                        storage,
                        frames_folder=frames_folder,
                        frame_count=old_meta.frame_count,
                    )
                )

        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                current_message=MESSAGE_EXTRACTING_FRAMES,
                progress_pct=PROGRESS_FRAME_EXTRACT_START,
            ),
        )

        last_report_time = time.monotonic()

        def _on_progress(current: int, total: int) -> None:
            # Throttled progress callback: emits at most one webhook per
            # PROGRESS_THROTTLE_SECONDS to avoid flooding the API.
            nonlocal last_report_time
            now = time.monotonic()
            if now - last_report_time < PROGRESS_THROTTLE_SECONDS:
                return
            last_report_time = now
            pct = current / total
            # Map 0..1 completion into the frame-extract progress band.
            mapped = PROGRESS_FRAME_EXTRACT_START + pct * (
                PROGRESS_FRAME_EXTRACT_END - PROGRESS_FRAME_EXTRACT_START
            )
            _send_webhook_event(
                webhook_url,
                TaskWebhookEvent(
                    current_message=MESSAGE_UPLOADING_FRAMES,
                    progress_pct=round(mapped, 1),
                ),
            )

        # extract_frames is a coroutine; actors are sync, so drive it here.
        metadata = _run_async(
            extract_frames(
                storage,
                file_key=file_key,
                frames_folder=frames_folder,
                on_progress=_on_progress,
            )
        )
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_DONE,
                current_message=MESSAGE_COMPLETED,
                progress_pct=PROGRESS_COMPLETE,
                output_data=metadata.model_dump(mode="json"),
                finished_at=_utc_now(),
            ),
        )
    except Exception as exc:
        logger.exception("frame_extract_actor failed: %s", job_uuid)
        _send_webhook_event(
            webhook_url,
            TaskWebhookEvent(
                status=JOB_STATUS_FAILED,
                error_message=str(exc),
                finished_at=_utc_now(),
            ),
        )
        # Re-raise so Dramatiq can retry the message.
        raise
# ---------------------------------------------------------------------------
# Task Service
# ---------------------------------------------------------------------------
@@ -557,8 +904,193 @@ class TaskService:
await self._event_repo.create(
JobEventCreate(job_id=job.id, event_type=event_type, payload=payload)
)
# Save artifacts BEFORE sending notifications so data exists when frontend refetches
if (
job.job_type == JOB_TYPE_TRANSCRIPTION_GENERATE
and event.status == JOB_STATUS_DONE
):
try:
await self._save_transcription_artifacts(job)
except Exception:
logger.exception(
"Failed to save transcription artifacts for job %s", job_id
)
if job.job_type == JOB_TYPE_MEDIA_CONVERT and event.status == JOB_STATUS_DONE:
try:
await self._save_convert_artifacts(job)
except Exception:
logger.exception(
"Failed to save convert artifacts for job %s", job_id
)
# Push real-time notification via WebSocket (after artifacts are persisted)
if job.user_id is not None:
try:
notification_service = NotificationService(self._session)
await notification_service.create_task_notification(
user_id=job.user_id, job=job, event=event
)
except Exception:
logger.exception("Failed to create notification for job %s", job_id)
return job
    async def _save_transcription_artifacts(self, job: Job) -> None:
        """Create Transcription, ArtifactMediaFile and File records.

        Persists a finished TRANSCRIPTION_GENERATE job's output:
        1. resolves the job's user (skips everything if missing);
        2. finds or creates the source File record by its storage path;
        3. uploads the transcription document JSON to storage;
        4. creates a File record for the JSON, an ArtifactMediaFile
           linking it, and the Transcription row itself.

        Raises KeyError if input_data lacks 'file_key' or output_data
        lacks 'document' — the caller wraps this in a broad try/except.
        """
        input_data = job.input_data or {}
        output_data = job.output_data or {}

        file_key: str = input_data["file_key"]
        project_id: uuid.UUID | None = (
            uuid.UUID(input_data["project_id"]) if input_data.get("project_id") else None
        )
        engine_raw: str = input_data.get("engine", "whisper")
        language: str | None = input_data.get("language")
        document: dict = output_data["document"]

        # Resolve user
        user_repo = UserRepository(self._session)
        user = await user_repo.get_by_id(job.user_id)  # type: ignore[arg-type]
        if user is None:
            logger.warning("User %s not found, skipping artifact save", job.user_id)
            return

        # Find or create source File record
        file_repo = FileRepository(self._session)
        source_file = await file_repo.get_by_path(file_key)
        if source_file is None:
            # Placeholder record: real mime type / size are unknown here.
            source_file = await file_repo.create(
                requester=user,
                data=FileCreate(
                    project_id=project_id,
                    original_filename=file_key.rsplit("/", 1)[-1],
                    path=file_key,
                    storage_backend="S3",
                    mime_type="application/octet-stream",
                    size_bytes=0,
                    is_uploaded=True,
                ),
            )

        # Upload document JSON to S3
        storage = _get_storage_service()
        user_folder = get_user_folder(user)
        json_bytes = json.dumps(document, ensure_ascii=False).encode("utf-8")

        # Build display name: "Транскрипция <video_name>.json"
        video_stem = Path(source_file.original_filename).stem
        transcription_filename = f"Транскрипция {video_stem}.json"

        artifact_key = await storage.upload_fileobj(
            fileobj=io.BytesIO(json_bytes),
            file_name=transcription_filename,
            folder=f"{user_folder}/artifacts",
            gen_name=True,
            content_type="application/json",
        )

        # Create File record for the JSON artifact (no project_id — only reachable via artifact)
        json_file = await file_repo.create(
            requester=user,
            data=FileCreate(
                project_id=None,
                original_filename=transcription_filename,
                path=artifact_key,
                storage_backend="S3",
                mime_type="application/json",
                size_bytes=len(json_bytes),
                file_format="json",
                is_uploaded=True,
            ),
        )

        # Create ArtifactMediaFile (no media_file_id — transcription is not a media file)
        artifact_repo = ArtifactRepository(self._session)
        artifact = await artifact_repo.create(
            data=ArtifactMediaFileCreate(
                project_id=project_id,
                file_id=json_file.id,
                media_file_id=None,
                artifact_type="TRANSCRIPTION_JSON",
            ),
        )

        # Create Transcription record
        transcription_repo = TranscriptionRepository(self._session)
        # Map the API engine name to the DB enum; default to LOCAL_WHISPER.
        engine_db = ENGINE_MAP.get(engine_raw, "LOCAL_WHISPER")
        await transcription_repo.create(
            data=TranscriptionCreate(
                project_id=project_id,
                source_file_id=source_file.id,
                artifact_id=artifact.id,
                engine=engine_db,  # type: ignore[arg-type]
                language=language,
                document=document,
            ),
        )
        logger.info("Saved transcription artifacts for job %s", job.id)
    async def _save_convert_artifacts(self, job: Job) -> None:
        """Create File and ArtifactMediaFile records for converted MP4.

        Persists a finished MEDIA_CONVERT job's output: resolves the job's
        user (skips everything if missing), derives the output filename
        from the source File record (falling back to the storage key's
        stem), then records the converted file and its artifact link.

        Raises KeyError if input_data lacks 'file_key' or output_data
        lacks 'file_path' — the caller wraps this in a broad try/except.
        """
        input_data = job.input_data or {}
        output_data = job.output_data or {}

        file_key: str = input_data["file_key"]
        project_id: uuid.UUID | None = (
            uuid.UUID(input_data["project_id"]) if input_data.get("project_id") else None
        )
        file_path: str = output_data["file_path"]
        file_size: int = output_data.get("file_size", 0)

        # Resolve user
        user_repo = UserRepository(self._session)
        user = await user_repo.get_by_id(job.user_id)  # type: ignore[arg-type]
        if user is None:
            logger.warning("User %s not found, skipping convert artifact save", job.user_id)
            return

        # Derive output filename from source file
        file_repo = FileRepository(self._session)
        source_file = await file_repo.get_by_path(file_key)
        if source_file is not None:
            stem = Path(source_file.original_filename).stem
        else:
            # Source not tracked as a File; fall back to the raw storage key.
            stem = Path(file_key).stem
        converted_filename = f"{stem}.mp4"

        # Create File record for the converted MP4 (no project_id — only reachable via artifact)
        converted_file = await file_repo.create(
            requester=user,
            data=FileCreate(
                project_id=None,
                original_filename=converted_filename,
                path=file_path,
                storage_backend="S3",
                mime_type="video/mp4",
                size_bytes=file_size,
                file_format="mp4",
                is_uploaded=True,
            ),
        )

        # Create ArtifactMediaFile record
        artifact_repo = ArtifactRepository(self._session)
        await artifact_repo.create(
            data=ArtifactMediaFileCreate(
                project_id=project_id,
                file_id=converted_file.id,
                media_file_id=None,
                artifact_type="CONVERTED_VIDEO",
            ),
        )
        logger.info("Saved convert artifacts for job %s", job.id)
async def submit_media_probe(
self, *, requester: User, request: MediaProbeRequest
) -> TaskSubmitResponse:
@@ -576,6 +1108,12 @@ class TaskService:
self, *, requester: User, request: SilenceRemoveRequest
) -> TaskSubmitResponse:
"""Submit silence removal task."""
user_folder = get_user_folder(requester)
resolved_folder = (
f"{user_folder}/{request.out_folder}"
if request.out_folder
else f"{user_folder}/output_files"
)
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_SILENCE_REMOVE,
@@ -584,17 +1122,65 @@ class TaskService:
actor=silence_remove_actor,
actor_kwargs={
"file_key": request.file_key,
"out_folder": request.out_folder,
"out_folder": resolved_folder,
"min_silence_duration_ms": request.min_silence_duration_ms,
"silence_threshold_db": request.silence_threshold_db,
"padding_ms": request.padding_ms,
},
)
async def submit_silence_detect(
self, *, requester: User, request: SilenceDetectRequest
) -> TaskSubmitResponse:
"""Submit silence detection task."""
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_SILENCE_DETECT,
project_id=request.project_id,
input_data=request.model_dump(mode="json"),
actor=silence_detect_actor,
actor_kwargs={
"file_key": request.file_key,
"min_silence_duration_ms": request.min_silence_duration_ms,
"silence_threshold_db": request.silence_threshold_db,
"padding_ms": request.padding_ms,
},
)
async def submit_silence_apply(
self, *, requester: User, request: SilenceApplyRequest
) -> TaskSubmitResponse:
"""Submit silence apply task."""
user_folder = get_user_folder(requester)
resolved_folder = (
f"{user_folder}/{request.out_folder}"
if request.out_folder
else f"{user_folder}/output_files"
)
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_SILENCE_APPLY,
project_id=request.project_id,
input_data=request.model_dump(mode="json"),
actor=silence_apply_actor,
actor_kwargs={
"file_key": request.file_key,
"out_folder": resolved_folder,
"cuts": request.cuts,
"output_name": request.output_name,
},
)
async def submit_media_convert(
self, *, requester: User, request: MediaConvertRequest
) -> TaskSubmitResponse:
"""Submit media conversion task."""
user_folder = get_user_folder(requester)
resolved_folder = (
f"{user_folder}/{request.out_folder}"
if request.out_folder
else f"{user_folder}/output_files"
)
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_MEDIA_CONVERT,
@@ -603,7 +1189,7 @@ class TaskService:
actor=media_convert_actor,
actor_kwargs={
"file_key": request.file_key,
"out_folder": request.out_folder,
"out_folder": resolved_folder,
"output_format": request.output_format,
},
)
@@ -626,6 +1212,28 @@ class TaskService:
},
)
async def submit_frame_extract(
self, *, requester: User, request: FrameExtractRequest
) -> TaskSubmitResponse:
"""Submit frame extraction task."""
from cpv3.modules.media.service import get_frames_folder
user_folder = get_user_folder(requester)
frames_folder = get_frames_folder(user_folder, request.file_key)
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_FRAME_EXTRACT,
project_id=request.project_id,
input_data=request.model_dump(mode="json"),
actor=frame_extract_actor,
actor_kwargs={
"file_key": request.file_key,
"frames_folder": frames_folder,
"regenerate": request.regenerate,
},
)
async def submit_captions_generate(
self, *, requester: User, request: CaptionsGenerateRequest
) -> TaskSubmitResponse:
@@ -635,6 +1243,13 @@ class TaskService:
if transcription is None:
raise ValueError(f"Transcription {request.transcription_id} not found")
user_folder = get_user_folder(requester)
resolved_folder = (
f"{user_folder}/{request.folder}"
if request.folder
else f"{user_folder}/output_files"
)
return await self._submit_task(
requester=requester,
job_type=JOB_TYPE_CAPTIONS_GENERATE,
@@ -643,7 +1258,7 @@ class TaskService:
actor=captions_generate_actor,
actor_kwargs={
"video_s3_path": request.video_s3_path,
"folder": request.folder,
"folder": resolved_folder,
"transcription_json": transcription.document,
},
)
+8
View File
@@ -32,6 +32,14 @@ class TranscriptionRepository:
)
return result.scalar_one_or_none()
async def get_by_artifact_id(self, artifact_id: uuid.UUID) -> Transcription | None:
result = await self._session.execute(
select(Transcription)
.where(Transcription.artifact_id == artifact_id)
.where(Transcription.is_active.is_(True))
)
return result.scalar_one_or_none()
async def create(self, data: TranscriptionCreate) -> Transcription:
transcription = Transcription(
project_id=data.project_id,
+15
View File
@@ -67,6 +67,21 @@ async def retrieve_transcription_entry(
return TranscriptionRead.model_validate(transcription)
@router.get("/transcriptions/by-artifact/{artifact_id}/", response_model=TranscriptionRead)
async def retrieve_transcription_by_artifact(
    artifact_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> TranscriptionRead:
    """Fetch a transcription by its artifact id; 404 when none is active."""
    _ = current_user  # dependency enforces auth; any authenticated user may read
    transcription = await TranscriptionRepository(db).get_by_artifact_id(artifact_id)
    if transcription is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
    return TranscriptionRead.model_validate(transcription)
@router.patch("/transcriptions/{transcription_id}/", response_model=TranscriptionRead)
async def patch_transcription_entry(
transcription_id: uuid.UUID,
+35 -8
View File
@@ -240,11 +240,15 @@ def _make_document_from_segments(
return Document(segments=result_segments)
ProgressCallback = Callable[[float], None]
def _whisper_transcribe_sync(
*,
local_file_path: str,
model_name: str,
language: str | None,
on_progress: ProgressCallback | None = None,
) -> Document:
import whisper # type: ignore[import-untyped]
@@ -267,14 +271,35 @@ def _whisper_transcribe_sync(
probs = cast(dict[str, float], probs_raw)
language = max(probs, key=lambda k: probs[k])
result = whisper.transcribe(
audio=whisper.load_audio(local_file_path),
model=model,
word_timestamps=True,
temperature=0.2,
language=language,
verbose=False,
)
if on_progress is not None:
from unittest.mock import patch
from tqdm import tqdm as _orig_tqdm
class _ProgressTqdm(_orig_tqdm):
def update(self, n=1):
super().update(n)
if self.total:
on_progress(min(self.n / self.total * 100.0, 100.0))
with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
result = whisper.transcribe(
audio=whisper.load_audio(local_file_path),
model=model,
word_timestamps=True,
temperature=0.2,
language=language,
verbose=False,
)
on_progress(100.0)
else:
result = whisper.transcribe(
audio=whisper.load_audio(local_file_path),
model=model,
word_timestamps=True,
temperature=0.2,
language=language,
verbose=None,
)
parsed = WhisperResult.model_validate(result)
@@ -296,6 +321,7 @@ async def transcribe_with_whisper(
file_key: str,
model_name: str = "tiny",
language: str | None = None,
on_progress: ProgressCallback | None = None,
) -> Document:
tmp = await storage.download_to_temp(file_key)
try:
@@ -304,6 +330,7 @@ async def transcribe_with_whisper(
local_file_path=tmp.path,
model_name=model_name,
language=language,
on_progress=on_progress,
)
)
finally:
+4
View File
@@ -71,6 +71,10 @@ class UserRepository:
await self._session.refresh(user)
return user
    async def update_password(self, user: User, new_hash: str) -> None:
        """Set *user*'s password hash and commit the session.

        Expects *new_hash* to be an already-hashed value; no hashing or
        validation happens here.
        """
        user.password_hash = new_hash
        await self._session.commit()
async def deactivate(self, user: User) -> None:
user.is_active = False
await self._session.commit()
+60 -10
View File
@@ -2,17 +2,21 @@ from __future__ import annotations
import uuid
from datetime import timedelta
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Response, status
from jwt import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.deps import get_storage
from cpv3.infrastructure.security import create_token, decode_token
from cpv3.infrastructure.settings import get_settings
from cpv3.infrastructure.storage.base import StorageService
from cpv3.db.session import get_db
from cpv3.modules.users.models import User
from cpv3.modules.users.schemas import (
PasswordChange,
TokenRefresh,
TokenRefreshResponse,
UserCreate,
@@ -28,6 +32,21 @@ users_router = APIRouter(prefix="/api/users", tags=["Users"])
auth_router = APIRouter(prefix="/auth", tags=["auth"])
def _is_s3_key(value: str) -> bool:
"""Return True if value looks like a bare S3 key, not a full URL."""
parsed = urlparse(value)
return not parsed.scheme and not parsed.netloc
async def _resolve_avatar(user: User, storage: StorageService) -> UserRead:
    """Build UserRead with a fresh presigned avatar URL."""
    data = UserRead.model_validate(user)
    # Only bare storage keys need presigning; full URLs pass through as-is.
    if data.avatar and _is_s3_key(data.avatar):
        data.avatar = await storage.url(data.avatar)
    return data
def _issue_tokens(user: User) -> tuple[str, str]:
settings = get_settings()
@@ -49,10 +68,11 @@ def _issue_tokens(user: User) -> tuple[str, str]:
async def list_all_users(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
    storage: StorageService = Depends(get_storage),
) -> list[UserRead]:
    """List all users, each serialized with a freshly presigned avatar URL.

    Fix: the block contained both the pre-image and post-image return
    statements (diff residue); only the avatar-resolving return is kept.
    """
    service = UserService(db)
    users = await service.list_users(requester=current_user)
    # Presigning is async per user, so a plain model_validate list won't do.
    return [await _resolve_avatar(u, storage) for u in users]
@users_router.post("/", response_model=UserRead, status_code=status.HTTP_201_CREATED)
@@ -60,6 +80,7 @@ async def create_user_endpoint(
body: UserCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
storage: StorageService = Depends(get_storage),
) -> UserRead:
service = UserService(db)
try:
@@ -67,12 +88,29 @@ async def create_user_endpoint(
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
return UserRead.model_validate(user)
return await _resolve_avatar(user, storage)
@users_router.get("/me/", response_model=UserRead)
async def me(
    current_user: User = Depends(get_current_user),
    storage: StorageService = Depends(get_storage),
) -> UserRead:
    """Return the authenticated user's profile with a presigned avatar URL.

    Fix: the block contained both the pre-image and post-image function
    definitions (diff residue); only the storage-aware definition is kept.
    """
    return await _resolve_avatar(current_user, storage)
@users_router.post("/me/change-password/", status_code=status.HTTP_204_NO_CONTENT)
async def change_password(
    body: PasswordChange,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Let the authenticated user rotate their own password.

    Responds 204 on success; a rejected current password (or any other
    domain validation failure) surfaces as 400 with the service's message.
    """
    service = UserService(db)
    try:
        await service.change_password(
            current_user, body.current_password, body.new_password
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
        ) from exc
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@users_router.get("/{user_id}/", response_model=UserRead)
@@ -80,6 +118,7 @@ async def retrieve_user(
user_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
storage: StorageService = Depends(get_storage),
) -> UserRead:
service = UserService(db)
user = await service.get_user_by_id(user_id)
@@ -89,7 +128,7 @@ async def retrieve_user(
if not current_user.is_staff and user.id != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
return UserRead.model_validate(user)
return await _resolve_avatar(user, storage)
@users_router.patch("/{user_id}/", response_model=UserRead)
@@ -98,6 +137,7 @@ async def patch_user(
body: UserUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
storage: StorageService = Depends(get_storage),
) -> UserRead:
service = UserService(db)
user = await service.get_user_by_id(user_id)
@@ -112,7 +152,7 @@ async def patch_user(
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
return UserRead.model_validate(user)
return await _resolve_avatar(user, storage)
@users_router.delete("/{user_id}/", status_code=status.HTTP_204_NO_CONTENT)
@@ -136,7 +176,11 @@ async def delete_user(
@auth_router.post(
"/register", response_model=UserRegisterResponse, status_code=status.HTTP_201_CREATED
)
async def register(body: UserRegister, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
async def register(
body: UserRegister,
db: AsyncSession = Depends(get_db),
storage: StorageService = Depends(get_storage),
) -> UserRegisterResponse:
service = UserService(db)
try:
user = await service.register_user(body)
@@ -144,18 +188,24 @@ async def register(body: UserRegister, db: AsyncSession = Depends(get_db)) -> Us
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
access, refresh = _issue_tokens(user)
return UserRegisterResponse(user=UserRead.model_validate(user), access=access, refresh=refresh)
user_read = await _resolve_avatar(user, storage)
return UserRegisterResponse(user=user_read, access=access, refresh=refresh)
@auth_router.post("/login", response_model=UserRegisterResponse)
async def login(body: UserLogin, db: AsyncSession = Depends(get_db)) -> UserRegisterResponse:
async def login(
body: UserLogin,
db: AsyncSession = Depends(get_db),
storage: StorageService = Depends(get_storage),
) -> UserRegisterResponse:
service = UserService(db)
user = await service.authenticate(body.username, body.password)
if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
access, refresh = _issue_tokens(user)
return UserRegisterResponse(user=UserRead.model_validate(user), access=access, refresh=refresh)
user_read = await _resolve_avatar(user, storage)
return UserRegisterResponse(user=user_read, access=access, refresh=refresh)
@auth_router.post("/refresh", response_model=TokenRefreshResponse)
+5
View File
@@ -66,6 +66,11 @@ class UserRegisterResponse(Schema):
refresh: str
class PasswordChange(Schema):
    """Request body for the authenticated change-password endpoint."""

    # Must match the user's existing password; the service verifies it
    # before accepting the new one.
    current_password: str
    # Replacement password; the service stores only its hash.
    new_password: str
class TokenRefresh(Schema):
    """Request body for the token-refresh endpoint."""

    # The refresh token previously issued at login/registration —
    # presumably a JWT, TODO confirm against the token issuer.
    refresh: str
+14 -1
View File
@@ -4,7 +4,7 @@ import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.infrastructure.security import verify_password
from cpv3.infrastructure.security import hash_password, verify_password
from cpv3.modules.users.models import User
from cpv3.modules.users.repository import UserRepository
from cpv3.modules.users.schemas import UserCreate, UserRegister, UserUpdate
@@ -40,6 +40,12 @@ class UserService:
async def deactivate_user(self, user: User) -> None:
await self._repo.deactivate(user)
async def change_password(self, user: User, current_password: str, new_password: str) -> None:
if not verify_password(current_password, user.password_hash):
raise ValueError("Current password is incorrect")
new_hash = hash_password(new_password)
await self._repo.update_password(user, new_hash)
async def authenticate(self, username: str, password: str) -> User | None:
user = await self._repo.get_by_username(username)
if user is None:
@@ -87,6 +93,13 @@ async def deactivate_user(session: AsyncSession, user: User) -> None:
await service.deactivate_user(user)
async def change_password(
    session: AsyncSession, user: User, current_password: str, new_password: str
) -> None:
    """Module-level convenience wrapper around ``UserService.change_password``."""
    await UserService(session).change_password(user, current_password, new_password)
async def authenticate(session: AsyncSession, username: str, password: str) -> User | None:
    """Module-level convenience wrapper: delegate the credential check to ``UserService``."""
    svc = UserService(session)
    return await svc.authenticate(username, password)