diff --git a/cpv3/modules/transcription/service.py b/cpv3/modules/transcription/service.py index 5bb7041..c8ad8ab 100644 --- a/cpv3/modules/transcription/service.py +++ b/cpv3/modules/transcription/service.py @@ -1,10 +1,16 @@ from __future__ import annotations import asyncio +import logging +import threading +import time +import uuid +from pathlib import Path from tempfile import NamedTemporaryFile from typing import Callable, cast import anyio +import httpx from cpv3.infrastructure.settings import get_settings from cpv3.infrastructure.storage.base import StorageService @@ -29,6 +35,7 @@ from cpv3.modules.transcription.schemas import ( GoogleSpeechWord, LineNode, SaluteSpeechSegment, + SaluteSpeechWord, SegmentNode, Tag, TimeRange, @@ -40,6 +47,46 @@ from cpv3.modules.transcription.schemas import ( ) +# ---------------------------------- SaluteSpeech Constants ---------------------------------- + +SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth" +SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1" +SALUTE_POLL_INTERVAL_SECONDS = 5.0 +SALUTE_POLL_TIMEOUT_SECONDS = 600 +SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60 + +SALUTE_ENCODING_MAP: dict[str, str] = { + ".mp3": "MP3", + ".wav": "PCM_S16LE", + ".ogg": "opus", + ".flac": "FLAC", +} + +SALUTE_CONTENT_TYPE_MAP: dict[str, str] = { + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".ogg": "audio/ogg", + ".flac": "audio/flac", +} + +SALUTE_LANGUAGE_MAP: dict[str, str] = { + "ru": "ru-RU", + "en": "en-US", +} + +ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}" +ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}" +ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}" +ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech" +ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}" + +_salute_token_lock = threading.Lock() +_salute_token: str | None = None +_salute_token_expires_at: float = 0.0 + +logger = logging.getLogger(__name__) + + class DocumentBuilder: def compute_segment_lines( self, @@ -430,3 +477,267 @@ async def transcribe_with_google_speech( ogg_cleanup() finally: input_tmp.cleanup() + + +# ---------------------------------- SaluteSpeech Engine ---------------------------------- + + +def _parse_salute_time(s: str) -> float: + """Parse SaluteSpeech timestamp string '0.480s' → 0.48.""" + return float(s.rstrip("s")) + + +def _get_salute_access_token(client: httpx.Client) -> str: + """Get or refresh SaluteSpeech OAuth token. Thread-safe.""" + global _salute_token, _salute_token_expires_at + with _salute_token_lock: + if _salute_token and time.monotonic() < ( + _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS + ): + return _salute_token + + settings = get_settings() + response = client.post( + SALUTE_AUTH_URL, + headers={ + "Authorization": f"Basic {settings.salute_auth_key}", + "RqUID": str(uuid.uuid4()), + "Content-Type": "application/x-www-form-urlencoded", + }, + content=f"scope={settings.salute_scope}", + ) + if response.status_code != 200: + raise RuntimeError( + ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200]) + ) + data = response.json() + _salute_token = data["access_token"] + expires_in_seconds = (data["expires_at"] / 1000) - time.time() + _salute_token_expires_at = time.monotonic() + expires_in_seconds + return _salute_token + + +def _upload_salute_audio( + client: httpx.Client, token: str, audio_data: bytes, content_type: str +) -> str: + """Upload audio to SaluteSpeech, return request_file_id.""" + response = client.post( + f"{SALUTE_API_BASE}/data:upload", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": content_type, + }, + content=audio_data, + timeout=120.0, + ) + if response.status_code != 200: + raise RuntimeError( + ERROR_SALUTE_UPLOAD_FAILED.format(detail=response.text[:200]) + ) + return response.json()["result"]["request_file_id"] + + +def _create_salute_task( + client: httpx.Client, + token: str, + file_id: str, + *, + language: str, + model: str, + audio_encoding: str, + sample_rate: int, +) -> str: + """Create async recognition task, return task_id.""" + body = { + "options": { + "audio_encoding": audio_encoding, + "sample_rate": sample_rate, + "language": language, + "model": model, + "channels_count": 1, + "hypotheses_count": 1, + }, + "request_file_id": file_id, + } + response = client.post( + f"{SALUTE_API_BASE}/speech:async_recognize", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json=body, + ) + if response.status_code != 200: + raise RuntimeError( + ERROR_SALUTE_TASK_FAILED.format(detail=response.text[:200]) + ) + return response.json()["result"]["id"] + + +def _poll_salute_task( + client: httpx.Client, + token: str, + task_id: str, + job_uuid: uuid.UUID | None, + on_progress: ProgressCallback | None, +) -> str: + """Poll task until DONE, return response_file_id.""" + start = time.monotonic() + while True: + elapsed = time.monotonic() - start + if elapsed > SALUTE_POLL_TIMEOUT_SECONDS: + raise TimeoutError(ERROR_SALUTE_TIMEOUT) + + if job_uuid is not None: + from cpv3.modules.tasks.service import _raise_if_job_cancelled + + _raise_if_job_cancelled(job_uuid) + + response = client.get( + f"{SALUTE_API_BASE}/task:get", + params={"id": task_id}, + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + result = response.json()["result"] + status = result["status"] + + if status == "DONE": + return result["response_file_id"] + if status == "ERROR": + error_msg = result.get("error", "unknown error") + raise RuntimeError( + ERROR_SALUTE_TASK_FAILED.format(detail=error_msg) + ) + + if on_progress is not None: + pct = min(elapsed / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0) + on_progress(pct) + + time.sleep(SALUTE_POLL_INTERVAL_SECONDS) + + +def _download_salute_result( + client: httpx.Client, token: str, response_file_id: str +) -> list[dict]: + """Download recognition result JSON.""" + response = client.get( + f"{SALUTE_API_BASE}/data:download", + params={"response_file_id": response_file_id}, + headers={"Authorization": f"Bearer {token}"}, + timeout=60.0, + ) + response.raise_for_status() + return response.json() + + +def _build_document_from_salute_result( + raw_channels: list[dict], *, language: str +) -> Document: + """Convert SaluteSpeech result JSON to Document.""" + builder = DocumentBuilder() + words_options = WordOptions() + + all_segments: list[SaluteSpeechSegment] = [] + + for channel_data in raw_channels: + for result_item in channel_data.get("results", []): + word_alignments = result_item.get("word_alignments", []) + words = [ + SaluteSpeechWord( + word=w["word"], + start=_parse_salute_time(w["start"]), + end=_parse_salute_time(w["end"]), + ) + for w in word_alignments + ] + + text = result_item.get("text", "") + seg_start = _parse_salute_time(result_item["start"]) + seg_end = _parse_salute_time(result_item["end"]) + + all_segments.append( + SaluteSpeechSegment( + text=text, + start=seg_start, + end=seg_end, + words=words, + ) + ) + + document = _make_document_from_segments( + builder, all_segments, max_line_width=words_options.max_line_width + ) + return builder.process_document(document) + + +def _salute_transcribe_sync( + *, + local_file_path: str, + language: str | None, + model: str, + sample_rate: int, + job_id: uuid.UUID | None = None, + on_progress: ProgressCallback | None = None, +) -> Document: + """Synchronous SaluteSpeech transcription (runs in Dramatiq worker thread).""" + settings = get_settings() + + ext = Path(local_file_path).suffix.lower() + audio_encoding = SALUTE_ENCODING_MAP.get(ext) + content_type = SALUTE_CONTENT_TYPE_MAP.get(ext) + if not audio_encoding or not content_type: + raise ValueError(ERROR_SALUTE_UNSUPPORTED_FORMAT.format(ext=ext)) + + salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU") + + verify = str(settings.salute_ca_cert_path) if settings.salute_ca_cert_path else True + with httpx.Client(verify=verify, timeout=30.0) as client: + token = _get_salute_access_token(client) + + with open(local_file_path, "rb") as f: + audio_data = f.read() + + file_id = _upload_salute_audio(client, token, audio_data, content_type) + task_id = _create_salute_task( + client, + token, + file_id, + language=salute_language, + model=model, + audio_encoding=audio_encoding, + sample_rate=sample_rate, + ) + response_file_id = _poll_salute_task( + client, token, task_id, job_id, on_progress + ) + raw_result = _download_salute_result(client, token, response_file_id) + + return _build_document_from_salute_result(raw_result, language=salute_language) + + +async def transcribe_with_salute_speech( + storage: StorageService, + *, + file_key: str, + language: str | None = None, + model: str = "general", + sample_rate: int = 16000, + job_id: uuid.UUID | None = None, + on_progress: ProgressCallback | None = None, +) -> Document: + """Async wrapper for SaluteSpeech transcription.""" + tmp = await storage.download_to_temp(file_key) + try: + return await anyio.to_thread.run_sync( + lambda: _salute_transcribe_sync( + local_file_path=tmp.path, + language=language, + model=model, + sample_rate=sample_rate, + job_id=job_id, + on_progress=on_progress, + ) + ) + finally: + tmp.cleanup()