from __future__ import annotations import asyncio import logging import threading import time import uuid from pathlib import Path from tempfile import NamedTemporaryFile from typing import Callable, cast import anyio import httpx from cpv3.infrastructure.settings import get_settings from cpv3.infrastructure.storage.base import StorageService from cpv3.modules.transcription.constants import ( FIRST_LINE_IN_DOCUMENT, FIRST_LINE_IN_SEGMENT, FIRST_SEGMENT_IN_DOCUMENT, FIRST_WORD_IN_DOCUMENT, FIRST_WORD_IN_LINE, FIRST_WORD_IN_SEGMENT, LAST_LINE_IN_DOCUMENT, LAST_LINE_IN_SEGMENT, LAST_SEGMENT_IN_DOCUMENT, LAST_WORD_IN_DOCUMENT, LAST_WORD_IN_LINE, LAST_WORD_IN_SEGMENT, ) from cpv3.modules.transcription.schemas import ( Document, GoogleSpeechResult, GoogleSpeechSegment, GoogleSpeechWord, LineNode, SaluteSpeechSegment, SaluteSpeechWord, SegmentNode, Tag, TimeRange, WhisperResult, WhisperSegment, WhisperWord, WordNode, WordOptions, ) # ---------------------------------- SaluteSpeech Constants ---------------------------------- SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth" SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1" SALUTE_POLL_INTERVAL_SECONDS = 5.0 SALUTE_POLL_TIMEOUT_SECONDS = 600 SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60 SALUTE_ENCODING_MAP: dict[str, str] = { ".mp3": "MP3", ".wav": "PCM_S16LE", ".ogg": "opus", ".flac": "FLAC", } SALUTE_CONTENT_TYPE_MAP: dict[str, str] = { ".mp3": "audio/mpeg", ".wav": "audio/wav", ".ogg": "audio/ogg", ".flac": "audio/flac", } SALUTE_LANGUAGE_MAP: dict[str, str] = { "ru": "ru-RU", "en": "en-US", } ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}" ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}" ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}" ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech" ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}" _salute_token_lock = threading.Lock() _salute_token: str | None = None _salute_token_expires_at: float = 0.0 logger = logging.getLogger(__name__) class DocumentBuilder: def compute_segment_lines( self, segment: WhisperSegment | GoogleSpeechSegment | SaluteSpeechSegment, max_chars_per_line: int, ) -> list[LineNode]: words = segment.words or [] lines: list[list[WhisperWord | GoogleSpeechWord]] = [] cur_line: list[WhisperWord | GoogleSpeechWord] = [] cur_len = 0 for w in words: text = (w.word or "").strip() if not text: continue extra = len(text) + (1 if cur_line else 0) if cur_line and cur_len + extra > max_chars_per_line: lines.append(cur_line) cur_line, cur_len = [w], len(text) else: cur_line.append(w) cur_len += extra if cur_line: lines.append(cur_line) result_lines: list[LineNode] = [] for rline in lines: time = TimeRange(start=rline[0].start, end=rline[-1].end) word_nodes = [ WordNode( text=(rword.word or "").strip(), time=TimeRange(start=rword.start, end=rword.end), semantic_tags=[], structure_tags=[], ) for rword in rline ] line_node = LineNode( text=" ".join((rword.word or "") for rword in rline).strip(), semantic_tags=[], structure_tags=[], time=time, words=word_nodes, ) result_lines.append(line_node) return result_lines def process_line( self, line: LineNode, is_first_line_in_document: bool, is_last_line_in_document: bool, is_first_line_in_segment: bool, is_last_line_in_segment: bool, ) -> list[WordNode]: words: list[WordNode] = [] for idx, word in enumerate(line.words): is_first = idx == 0 is_last = idx == len(line.words) - 1 rules = [ (is_first_line_in_document and is_first, FIRST_WORD_IN_DOCUMENT), (is_last_line_in_document and is_last, LAST_WORD_IN_DOCUMENT), (is_first_line_in_segment and is_first, FIRST_WORD_IN_SEGMENT), (is_last_line_in_segment and is_last, LAST_WORD_IN_SEGMENT), (is_first, FIRST_WORD_IN_LINE), (is_last, LAST_WORD_IN_LINE), ] structure_tags = [ Tag(name=tag_name) for condition, tag_name in rules if condition ] new_word = word.model_copy(update={"structure_tags": structure_tags}) words.append(new_word) return words def process_segment( self, segment: SegmentNode, is_first_segment_in_document: bool, is_last_segment_in_document: bool, ) -> list[LineNode]: lines: list[LineNode] = [] for idx, line in enumerate(segment.lines): is_first = idx == 0 is_last = idx == len(segment.lines) - 1 rules = [ (is_first_segment_in_document and is_first, FIRST_LINE_IN_DOCUMENT), (is_last_segment_in_document and is_last, LAST_LINE_IN_DOCUMENT), (is_first, FIRST_LINE_IN_SEGMENT), (is_last, LAST_LINE_IN_SEGMENT), ] structure_tags = [ Tag(name=tag_name) for condition, tag_name in rules if condition ] words = self.process_line( line, is_first_line_in_document=is_first_segment_in_document and is_first, is_last_line_in_document=is_last_segment_in_document and is_last, is_first_line_in_segment=is_first, is_last_line_in_segment=is_last, ) new_line = line.model_copy( update={"structure_tags": structure_tags, "words": words} ) lines.append(new_line) return lines def process_document(self, document: Document) -> Document: segments: list[SegmentNode] = [] for idx, segment in enumerate(document.segments): structure_tags: list[Tag] = [] is_first_segment_in_document = idx == 0 is_last_segment_in_document = idx == len(document.segments) - 1 if is_first_segment_in_document: structure_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT)) if is_last_segment_in_document: structure_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT)) lines = self.process_segment( segment, is_first_segment_in_document, is_last_segment_in_document ) new_segment = segment.model_copy( update={"lines": lines, "structure_tags": structure_tags} ) segments.append(new_segment) return Document(segments=segments) async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]: with NamedTemporaryFile(suffix=".ogg", delete=False) as out: out_path = out.name proc = await asyncio.create_subprocess_exec( "ffmpeg", "-y", "-i", input_path, "-c:a", "libopus", "-b:a", "24k", "-vn", "-ac", "1", "-ar", "16000", out_path, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) _, stderr = await proc.communicate() if proc.returncode != 0: raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}") def _cleanup() -> None: import os if os.path.exists(out_path): os.remove(out_path) return out_path, _cleanup def _make_document_from_segments( builder: DocumentBuilder, segments: list[WhisperSegment] | list[GoogleSpeechSegment] | list[SaluteSpeechSegment], *, max_line_width: int, ) -> Document: result_segments: list[SegmentNode] = [] for segment in segments: lines = builder.compute_segment_lines(segment, max_line_width) time = TimeRange(start=segment.start, end=segment.end) segment_node = SegmentNode( text=segment.text.strip(), semantic_tags=[], structure_tags=[], time=time, lines=lines, ) result_segments.append(segment_node) return Document(segments=result_segments) ProgressCallback = Callable[[float], None] def _whisper_transcribe_sync( *, local_file_path: str, model_name: str, language: str | None, on_progress: ProgressCallback | None = None, ) -> Document: import whisper # type: ignore[import-untyped] settings = get_settings() settings.transcription_models_dir.mkdir(parents=True, exist_ok=True) builder = DocumentBuilder() model = whisper.load_model( model_name, download_root=str(settings.transcription_models_dir) ) if language is None: audio = whisper.load_audio(local_file_path) audio = whisper.pad_or_trim(audio) mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to( model.device ) _, probs_raw = model.detect_language(mel) probs = cast(dict[str, float], probs_raw) language = max(probs, key=lambda k: probs[k]) if on_progress is not None: from unittest.mock import patch from tqdm import tqdm as _orig_tqdm class _ProgressTqdm(_orig_tqdm): def update(self, n=1): super().update(n) if self.total: on_progress(min(self.n / self.total * 100.0, 100.0)) with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm): result = whisper.transcribe( audio=whisper.load_audio(local_file_path), model=model, word_timestamps=True, temperature=0.2, language=language, verbose=False, ) on_progress(100.0) else: result = whisper.transcribe( audio=whisper.load_audio(local_file_path), model=model, word_timestamps=True, temperature=0.2, language=language, verbose=None, ) parsed = WhisperResult.model_validate(result) words_options = WordOptions( highlight_words=True, max_line_width=32, max_line_count=2, ) document = _make_document_from_segments( builder, parsed.segments, max_line_width=words_options.max_line_width ) return builder.process_document(document) async def transcribe_with_whisper( storage: StorageService, *, file_key: str, model_name: str = "tiny", language: str | None = None, on_progress: ProgressCallback | None = None, ) -> Document: tmp = await storage.download_to_temp(file_key) try: return await anyio.to_thread.run_sync( lambda: _whisper_transcribe_sync( local_file_path=tmp.path, model_name=model_name, language=language, on_progress=on_progress, ) ) finally: tmp.cleanup() def _google_transcribe_sync( *, ogg_bytes: bytes, language_codes: list[str] ) -> GoogleSpeechResult: from google.cloud import speech settings = get_settings() client: speech.SpeechClient = speech.SpeechClient.from_service_account_file( str(settings.google_service_key_path) ) audio = speech.RecognitionAudio(content=ogg_bytes) config = speech.RecognitionConfig( encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS, sample_rate_hertz=16000, language_code=language_codes[0], alternative_language_codes=( language_codes[1:] if len(language_codes) > 1 else [] ), model="latest_long", enable_word_time_offsets=True, ) operation = client.long_running_recognize(config=config, audio=audio) response = operation.result(timeout=600) segments: list[GoogleSpeechSegment] = [] full_text = "" for result in response.results: alternative = result.alternatives[0] words: list[GoogleSpeechWord] = [] for word_info in alternative.words: words.append( GoogleSpeechWord( word=word_info.word, start=word_info.start_time.total_seconds(), end=word_info.end_time.total_seconds(), ) ) if words: segment_text = alternative.transcript full_text += segment_text + " " segments.append( GoogleSpeechSegment( text=segment_text, start=words[0].start, end=words[-1].end, words=words, ) ) return GoogleSpeechResult( text=full_text.strip(), segments=segments, language=language_codes[0] ) async def transcribe_with_google_speech( storage: StorageService, *, file_key: str, language_codes: list[str] | None = None, ) -> Document: language_codes = language_codes or ["ru-RU", "en-US"] builder = DocumentBuilder() words_options = WordOptions() input_tmp = await storage.download_to_temp(file_key) try: ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path) try: with open(ogg_path, "rb") as f: content = f.read() result = await anyio.to_thread.run_sync( lambda: _google_transcribe_sync( ogg_bytes=content, language_codes=language_codes ) ) document = _make_document_from_segments( builder, result.segments, max_line_width=words_options.max_line_width ) return builder.process_document(document) finally: ogg_cleanup() finally: input_tmp.cleanup() # ---------------------------------- SaluteSpeech Engine ---------------------------------- def _parse_salute_time(s: str) -> float: """Parse SaluteSpeech timestamp string '0.480s' → 0.48.""" return float(s.rstrip("s")) def _get_salute_access_token(client: httpx.Client) -> str: """Get or refresh SaluteSpeech OAuth token. Thread-safe.""" global _salute_token, _salute_token_expires_at with _salute_token_lock: if _salute_token and time.monotonic() < ( _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS ): return _salute_token settings = get_settings() response = client.post( SALUTE_AUTH_URL, headers={ "Authorization": f"Basic {settings.salute_auth_key}", "RqUID": str(uuid.uuid4()), "Content-Type": "application/x-www-form-urlencoded", }, content=f"scope={settings.salute_scope}", ) if response.status_code != 200: raise RuntimeError( ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200]) ) data = response.json() _salute_token = data["access_token"] expires_in_seconds = (data["expires_at"] / 1000) - time.time() _salute_token_expires_at = time.monotonic() + expires_in_seconds return _salute_token def _upload_salute_audio( client: httpx.Client, token: str, audio_data: bytes, content_type: str ) -> str: """Upload audio to SaluteSpeech, return request_file_id.""" response = client.post( f"{SALUTE_API_BASE}/data:upload", headers={ "Authorization": f"Bearer {token}", "Content-Type": content_type, }, content=audio_data, timeout=120.0, ) if response.status_code != 200: raise RuntimeError( ERROR_SALUTE_UPLOAD_FAILED.format(detail=response.text[:200]) ) return response.json()["result"]["request_file_id"] def _create_salute_task( client: httpx.Client, token: str, file_id: str, *, language: str, model: str, audio_encoding: str, sample_rate: int, ) -> str: """Create async recognition task, return task_id.""" body = { "options": { "audio_encoding": audio_encoding, "sample_rate": sample_rate, "language": language, "model": model, "channels_count": 1, "hypotheses_count": 1, }, "request_file_id": file_id, } response = client.post( f"{SALUTE_API_BASE}/speech:async_recognize", headers={ "Authorization": f"Bearer {token}", "Content-Type": "application/json", }, json=body, ) if response.status_code != 200: raise RuntimeError( ERROR_SALUTE_TASK_FAILED.format(detail=response.text[:200]) ) return response.json()["result"]["id"] def _poll_salute_task( client: httpx.Client, token: str, task_id: str, job_uuid: uuid.UUID | None, on_progress: ProgressCallback | None, ) -> str: """Poll task until DONE, return response_file_id.""" start = time.monotonic() while True: elapsed = time.monotonic() - start if elapsed > SALUTE_POLL_TIMEOUT_SECONDS: raise TimeoutError(ERROR_SALUTE_TIMEOUT) if job_uuid is not None: from cpv3.modules.tasks.service import _raise_if_job_cancelled _raise_if_job_cancelled(job_uuid) response = client.get( f"{SALUTE_API_BASE}/task:get", params={"id": task_id}, headers={"Authorization": f"Bearer {token}"}, ) response.raise_for_status() result = response.json()["result"] status = result["status"] if status == "DONE": return result["response_file_id"] if status == "ERROR": error_msg = result.get("error", "unknown error") raise RuntimeError( ERROR_SALUTE_TASK_FAILED.format(detail=error_msg) ) if on_progress is not None: pct = min(elapsed / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0) on_progress(pct) time.sleep(SALUTE_POLL_INTERVAL_SECONDS) def _download_salute_result( client: httpx.Client, token: str, response_file_id: str ) -> list[dict]: """Download recognition result JSON.""" response = client.get( f"{SALUTE_API_BASE}/data:download", params={"response_file_id": response_file_id}, headers={"Authorization": f"Bearer {token}"}, timeout=60.0, ) response.raise_for_status() return response.json() def _build_document_from_salute_result( raw_channels: list[dict], *, language: str ) -> Document: """Convert SaluteSpeech result JSON to Document.""" builder = DocumentBuilder() words_options = WordOptions() all_segments: list[SaluteSpeechSegment] = [] for channel_data in raw_channels: for result_item in channel_data.get("results", []): word_alignments = result_item.get("word_alignments", []) words = [ SaluteSpeechWord( word=w["word"], start=_parse_salute_time(w["start"]), end=_parse_salute_time(w["end"]), ) for w in word_alignments ] text = result_item.get("text", "") seg_start = _parse_salute_time(result_item["start"]) seg_end = _parse_salute_time(result_item["end"]) all_segments.append( SaluteSpeechSegment( text=text, start=seg_start, end=seg_end, words=words, ) ) document = _make_document_from_segments( builder, all_segments, max_line_width=words_options.max_line_width ) return builder.process_document(document) def _convert_to_wav_sync(input_path: str, sample_rate: int = 16000) -> tuple[str, Callable[[], None]]: """Convert any audio/video to WAV (PCM signed 16-bit LE) using ffmpeg. Sync version.""" import os import subprocess with NamedTemporaryFile(suffix=".wav", delete=False) as out: out_path = out.name result = subprocess.run( [ "ffmpeg", "-y", "-i", input_path, "-vn", "-ac", "1", "-ar", str(sample_rate), "-acodec", "pcm_s16le", out_path, ], capture_output=True, ) if result.returncode != 0: raise RuntimeError(f"ffmpeg failed: {result.stderr.decode(errors='ignore')}") def _cleanup() -> None: if os.path.exists(out_path): os.remove(out_path) return out_path, _cleanup def _salute_transcribe_sync( *, local_file_path: str, language: str | None, model: str, sample_rate: int, job_id: uuid.UUID | None = None, on_progress: ProgressCallback | None = None, ) -> Document: """Synchronous SaluteSpeech transcription (runs in Dramatiq worker thread).""" settings = get_settings() ext = Path(local_file_path).suffix.lower() audio_encoding = SALUTE_ENCODING_MAP.get(ext) content_type = SALUTE_CONTENT_TYPE_MAP.get(ext) # Convert unsupported formats (mp4, webm, m4a, etc.) to WAV via ffmpeg cleanup_fn: Callable[[], None] | None = None if not audio_encoding or not content_type: wav_path, cleanup_fn = _convert_to_wav_sync(local_file_path, sample_rate) local_file_path = wav_path audio_encoding = "PCM_S16LE" content_type = "audio/wav" salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU") try: verify = str(settings.salute_ca_cert_path) if settings.salute_ca_cert_path else True with httpx.Client(verify=verify, timeout=30.0) as client: token = _get_salute_access_token(client) with open(local_file_path, "rb") as f: audio_data = f.read() file_id = _upload_salute_audio(client, token, audio_data, content_type) task_id = _create_salute_task( client, token, file_id, language=salute_language, model=model, audio_encoding=audio_encoding, sample_rate=sample_rate, ) response_file_id = _poll_salute_task( client, token, task_id, job_id, on_progress ) raw_result = _download_salute_result(client, token, response_file_id) return _build_document_from_salute_result(raw_result, language=salute_language) finally: if cleanup_fn is not None: cleanup_fn() async def transcribe_with_salute_speech( storage: StorageService, *, file_key: str, language: str | None = None, model: str = "general", sample_rate: int = 16000, job_id: uuid.UUID | None = None, on_progress: ProgressCallback | None = None, ) -> Document: """Async wrapper for SaluteSpeech transcription.""" tmp = await storage.download_to_temp(file_key) try: return await anyio.to_thread.run_sync( lambda: _salute_transcribe_sync( local_file_path=tmp.path, language=language, model=model, sample_rate=sample_rate, job_id=job_id, on_progress=on_progress, ) ) finally: tmp.cleanup()