Files
main_backend/cpv3/modules/transcription/service.py
T
Daniil 2c9c11fa17 feat(backend): implement SaluteSpeech transcription engine
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 00:04:49 +03:00

744 lines
22 KiB
Python

from __future__ import annotations
import asyncio
import logging
import threading
import time
import uuid
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Callable, cast
import anyio
import httpx
from cpv3.infrastructure.settings import get_settings
from cpv3.infrastructure.storage.base import StorageService
from cpv3.modules.transcription.constants import (
FIRST_LINE_IN_DOCUMENT,
FIRST_LINE_IN_SEGMENT,
FIRST_SEGMENT_IN_DOCUMENT,
FIRST_WORD_IN_DOCUMENT,
FIRST_WORD_IN_LINE,
FIRST_WORD_IN_SEGMENT,
LAST_LINE_IN_DOCUMENT,
LAST_LINE_IN_SEGMENT,
LAST_SEGMENT_IN_DOCUMENT,
LAST_WORD_IN_DOCUMENT,
LAST_WORD_IN_LINE,
LAST_WORD_IN_SEGMENT,
)
from cpv3.modules.transcription.schemas import (
Document,
GoogleSpeechResult,
GoogleSpeechSegment,
GoogleSpeechWord,
LineNode,
SaluteSpeechSegment,
SaluteSpeechWord,
SegmentNode,
Tag,
TimeRange,
WhisperResult,
WhisperSegment,
WhisperWord,
WordNode,
WordOptions,
)
# ---------------------------------- SaluteSpeech Constants ----------------------------------
SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1"
SALUTE_POLL_INTERVAL_SECONDS = 5.0
SALUTE_POLL_TIMEOUT_SECONDS = 600
SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60
SALUTE_ENCODING_MAP: dict[str, str] = {
".mp3": "MP3",
".wav": "PCM_S16LE",
".ogg": "opus",
".flac": "FLAC",
}
SALUTE_CONTENT_TYPE_MAP: dict[str, str] = {
".mp3": "audio/mpeg",
".wav": "audio/wav",
".ogg": "audio/ogg",
".flac": "audio/flac",
}
SALUTE_LANGUAGE_MAP: dict[str, str] = {
"ru": "ru-RU",
"en": "en-US",
}
ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}"
ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}"
ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}"
ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech"
ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}"
_salute_token_lock = threading.Lock()
_salute_token: str | None = None
_salute_token_expires_at: float = 0.0
logger = logging.getLogger(__name__)
class DocumentBuilder:
    """Build and structure-tag the Document tree (segments -> lines -> words).

    ``compute_segment_lines`` wraps one engine segment's words into lines. The
    ``process_*`` methods return copies of the tree with structure tags that
    mark first/last positions at word, line, segment and document level.
    """

    def compute_segment_lines(
        self,
        segment: WhisperSegment | GoogleSpeechSegment | SaluteSpeechSegment,
        max_chars_per_line: int,
    ) -> list[LineNode]:
        """Greedily wrap ``segment.words`` into lines of at most ``max_chars_per_line`` chars.

        Words are measured after stripping surrounding whitespace, and one space
        is counted between neighbours. A single word longer than the limit still
        gets a line of its own; it is never split. Blank words are dropped.

        Returns:
            One untagged ``LineNode`` per wrapped line, timed from its first
            word's start to its last word's end.
        """
        words = segment.words or []
        lines: list[list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord]] = []
        cur_line: list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord] = []
        cur_len = 0
        for w in words:
            text = (w.word or "").strip()
            if not text:
                continue
            # +1 accounts for the separating space when the line is non-empty.
            extra = len(text) + (1 if cur_line else 0)
            if cur_line and cur_len + extra > max_chars_per_line:
                lines.append(cur_line)
                cur_line, cur_len = [w], len(text)
            else:
                cur_line.append(w)
                cur_len += extra
        if cur_line:
            lines.append(cur_line)

        result_lines: list[LineNode] = []
        for rline in lines:
            stripped = [(rword.word or "").strip() for rword in rline]
            word_nodes = [
                WordNode(
                    text=word_text,
                    time=TimeRange(start=rword.start, end=rword.end),
                    semantic_tags=[],
                    structure_tags=[],
                )
                for rword, word_text in zip(rline, stripped)
            ]
            line_node = LineNode(
                # Join the *stripped* texts, exactly as measured above. Joining
                # the raw word values with " " produced double spaces whenever a
                # word carried surrounding whitespace.
                text=" ".join(stripped),
                semantic_tags=[],
                structure_tags=[],
                time=TimeRange(start=rline[0].start, end=rline[-1].end),
                words=word_nodes,
            )
            result_lines.append(line_node)
        return result_lines

    def process_line(
        self,
        line: LineNode,
        is_first_line_in_document: bool,
        is_last_line_in_document: bool,
        is_first_line_in_segment: bool,
        is_last_line_in_segment: bool,
    ) -> list[WordNode]:
        """Return copies of ``line.words`` with positional structure tags set.

        Each word's ``structure_tags`` is replaced by the tags whose condition
        holds. The first/last word of the line always gets the line-level tags,
        and the document/segment-level tags depend on the flags passed in.
        """
        words: list[WordNode] = []
        for idx, word in enumerate(line.words):
            is_first = idx == 0
            is_last = idx == len(line.words) - 1
            rules = [
                (is_first_line_in_document and is_first, FIRST_WORD_IN_DOCUMENT),
                (is_last_line_in_document and is_last, LAST_WORD_IN_DOCUMENT),
                (is_first_line_in_segment and is_first, FIRST_WORD_IN_SEGMENT),
                (is_last_line_in_segment and is_last, LAST_WORD_IN_SEGMENT),
                (is_first, FIRST_WORD_IN_LINE),
                (is_last, LAST_WORD_IN_LINE),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            new_word = word.model_copy(update={"structure_tags": structure_tags})
            words.append(new_word)
        return words

    def process_segment(
        self,
        segment: SegmentNode,
        is_first_segment_in_document: bool,
        is_last_segment_in_document: bool,
    ) -> list[LineNode]:
        """Return copies of ``segment.lines`` with line tags set and words tagged.

        Document-level flags are propagated down so that only the first line
        of the first segment (and the last line of the last segment) receive
        the document-boundary tags.
        """
        lines: list[LineNode] = []
        for idx, line in enumerate(segment.lines):
            is_first = idx == 0
            is_last = idx == len(segment.lines) - 1
            rules = [
                (is_first_segment_in_document and is_first, FIRST_LINE_IN_DOCUMENT),
                (is_last_segment_in_document and is_last, LAST_LINE_IN_DOCUMENT),
                (is_first, FIRST_LINE_IN_SEGMENT),
                (is_last, LAST_LINE_IN_SEGMENT),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            words = self.process_line(
                line,
                is_first_line_in_document=is_first_segment_in_document and is_first,
                is_last_line_in_document=is_last_segment_in_document and is_last,
                is_first_line_in_segment=is_first,
                is_last_line_in_segment=is_last,
            )
            new_line = line.model_copy(
                update={"structure_tags": structure_tags, "words": words}
            )
            lines.append(new_line)
        return lines

    def process_document(self, document: Document) -> Document:
        """Return a new Document whose segments, lines and words are structure-tagged.

        The input document is not modified; tagged copies are built via
        ``model_copy``. Only ``segments`` is carried over into the result.
        """
        segments: list[SegmentNode] = []
        for idx, segment in enumerate(document.segments):
            structure_tags: list[Tag] = []
            is_first_segment_in_document = idx == 0
            is_last_segment_in_document = idx == len(document.segments) - 1
            if is_first_segment_in_document:
                structure_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT))
            if is_last_segment_in_document:
                structure_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT))
            lines = self.process_segment(
                segment, is_first_segment_in_document, is_last_segment_in_document
            )
            new_segment = segment.model_copy(
                update={"lines": lines, "structure_tags": structure_tags}
            )
            segments.append(new_segment)
        return Document(segments=segments)
async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]:
with NamedTemporaryFile(suffix=".ogg", delete=False) as out:
out_path = out.name
proc = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-i",
input_path,
"-c:a",
"libopus",
"-b:a",
"24k",
"-vn",
"-ac",
"1",
"-ar",
"16000",
out_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, stderr = await proc.communicate()
if proc.returncode != 0:
raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")
def _cleanup() -> None:
import os
if os.path.exists(out_path):
os.remove(out_path)
return out_path, _cleanup
def _make_document_from_segments(
    builder: DocumentBuilder,
    segments: list[WhisperSegment] | list[GoogleSpeechSegment] | list[SaluteSpeechSegment],
    *,
    max_line_width: int,
) -> Document:
    """Wrap engine-specific segments into an untagged ``Document``.

    Each input segment becomes one ``SegmentNode``. Its lines come from
    ``builder.compute_segment_lines(..., max_line_width)`` and its text is
    whitespace-stripped. Structure tags are left empty; ``process_document``
    is the step that fills them in.
    """

    def _to_segment_node(seg) -> SegmentNode:
        seg_lines = builder.compute_segment_lines(seg, max_line_width)
        span = TimeRange(start=seg.start, end=seg.end)
        return SegmentNode(
            text=seg.text.strip(),
            semantic_tags=[],
            structure_tags=[],
            time=span,
            lines=seg_lines,
        )

    return Document(segments=[_to_segment_node(seg) for seg in segments])


# Progress callback signature: receives a completion percentage (0-100).
ProgressCallback = Callable[[float], None]
def _whisper_transcribe_sync(
    *,
    local_file_path: str,
    model_name: str,
    language: str | None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Transcribe a local media file with openai-whisper (blocking).

    Args:
        local_file_path: File to transcribe.
        model_name: Whisper model name, loaded with
            ``download_root=settings.transcription_models_dir``.
        language: Language to force. When ``None``, the most probable language
            from ``model.detect_language`` on the padded/trimmed audio is used.
        on_progress: Optional callback that receives a percentage (0-100) as
            whisper's progress bar advances. It is called with 100.0 on success.

    Returns:
        A structure-tagged Document with lines wrapped at 32 characters.
    """
    import whisper  # type: ignore[import-untyped]

    settings = get_settings()
    settings.transcription_models_dir.mkdir(parents=True, exist_ok=True)
    builder = DocumentBuilder()
    model = whisper.load_model(
        model_name, download_root=str(settings.transcription_models_dir)
    )

    # Decode the file once and reuse the samples for both language detection
    # and transcription. The previous version decoded it twice when
    # ``language`` was None.
    audio = whisper.load_audio(local_file_path)

    if language is None:
        window = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(window, n_mels=model.dims.n_mels).to(
            model.device
        )
        _, probs_raw = model.detect_language(mel)
        probs = cast(dict[str, float], probs_raw)
        language = max(probs, key=lambda k: probs[k])

    if on_progress is not None:
        from unittest.mock import patch

        from tqdm import tqdm as _orig_tqdm

        class _ProgressTqdm(_orig_tqdm):
            """tqdm subclass that forwards progress percentages to ``on_progress``."""

            def update(self, n=1):
                super().update(n)
                if self.total:
                    on_progress(min(self.n / self.total * 100.0, 100.0))

        # Swap the tqdm class whisper uses for its progress bar during the call.
        # NOTE(review): if ``whisper.transcribe.tqdm`` is the tqdm module itself,
        # this patch is process-wide while it is active — confirm.
        with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
            result = whisper.transcribe(
                audio=audio,
                model=model,
                word_timestamps=True,
                temperature=0.2,
                language=language,
                # Presumably verbose=False (not None) is what keeps whisper's
                # tqdm bar enabled so the patched class sees updates.
                verbose=False,
            )
        on_progress(100.0)
    else:
        result = whisper.transcribe(
            audio=audio,
            model=model,
            word_timestamps=True,
            temperature=0.2,
            language=language,
            verbose=None,
        )

    parsed = WhisperResult.model_validate(result)
    words_options = WordOptions(
        highlight_words=True,
        max_line_width=32,
        max_line_count=2,
    )
    document = _make_document_from_segments(
        builder, parsed.segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
async def transcribe_with_whisper(
    storage: StorageService,
    *,
    file_key: str,
    model_name: str = "tiny",
    language: str | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Download ``file_key`` from storage and transcribe it with Whisper.

    The blocking transcription runs in a worker thread via
    ``anyio.to_thread.run_sync``. The downloaded temporary copy is cleaned up
    whether or not transcription succeeds.
    """
    downloaded = await storage.download_to_temp(file_key)

    def _run_blocking() -> Document:
        return _whisper_transcribe_sync(
            local_file_path=downloaded.path,
            model_name=model_name,
            language=language,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run_blocking)
    finally:
        downloaded.cleanup()
def _google_transcribe_sync(
    *, ogg_bytes: bytes, language_codes: list[str]
) -> GoogleSpeechResult:
    """Recognise Ogg/Opus audio with Google Cloud Speech-to-Text (blocking).

    Args:
        ogg_bytes: Ogg/Opus audio at 16 kHz (matches ``sample_rate_hertz``
            below), sent inline.
        language_codes: Must be non-empty. The first entry is the primary
            language; the rest are passed as alternative languages.

    Returns:
        One ``GoogleSpeechSegment`` per recognition result that has word
        timings. ``text`` is the space-joined transcripts and ``language`` is
        the primary language code.
    """
    from google.cloud import speech

    settings = get_settings()
    audio = speech.RecognitionAudio(content=ogg_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
        sample_rate_hertz=16000,
        language_code=language_codes[0],
        alternative_language_codes=(
            language_codes[1:] if len(language_codes) > 1 else []
        ),
        model="latest_long",
        enable_word_time_offsets=True,
    )
    # The client owns a gRPC transport. The ``with`` closes it once the
    # operation has finished instead of leaking one channel per transcription.
    with speech.SpeechClient.from_service_account_file(
        str(settings.google_service_key_path)
    ) as client:
        operation = client.long_running_recognize(config=config, audio=audio)
        response = operation.result(timeout=600)

    segments: list[GoogleSpeechSegment] = []
    for result in response.results:
        if not result.alternatives:
            # Nothing recognised for this chunk; indexing [0] would raise.
            continue
        alternative = result.alternatives[0]
        words = [
            GoogleSpeechWord(
                word=word_info.word,
                start=word_info.start_time.total_seconds(),
                end=word_info.end_time.total_seconds(),
            )
            for word_info in alternative.words
        ]
        if not words:
            continue
        segments.append(
            GoogleSpeechSegment(
                text=alternative.transcript,
                start=words[0].start,
                end=words[-1].end,
                words=words,
            )
        )
    full_text = " ".join(segment.text for segment in segments)
    return GoogleSpeechResult(
        text=full_text.strip(), segments=segments, language=language_codes[0]
    )
async def transcribe_with_google_speech(
    storage: StorageService,
    *,
    file_key: str,
    language_codes: list[str] | None = None,
) -> Document:
    """Transcribe a stored media file with Google Cloud Speech-to-Text.

    The file is downloaded and transcoded to Ogg/Opus by
    ``_convert_local_to_ogg``. Recognition then runs in a worker thread. Both
    temporary files are removed afterwards, even on failure.

    Args:
        storage: Storage backend that holds ``file_key``.
        file_key: Key of the media file to transcribe.
        language_codes: Primary language first, alternatives after. When
            ``None`` or empty, defaults to ``["ru-RU", "en-US"]``.
    """
    language_codes = language_codes or ["ru-RU", "en-US"]
    builder = DocumentBuilder()
    words_options = WordOptions()
    input_tmp = await storage.download_to_temp(file_key)
    try:
        ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path)
        try:
            # Read the file inside the worker thread too. A plain open()/read()
            # at this point would block the event loop for the whole file.
            result = await anyio.to_thread.run_sync(
                lambda: _google_transcribe_sync(
                    ogg_bytes=Path(ogg_path).read_bytes(),
                    language_codes=language_codes,
                )
            )
            document = _make_document_from_segments(
                builder, result.segments, max_line_width=words_options.max_line_width
            )
            return builder.process_document(document)
        finally:
            ogg_cleanup()
    finally:
        input_tmp.cleanup()
# ---------------------------------- SaluteSpeech Engine ----------------------------------
def _parse_salute_time(s: str) -> float:
"""Parse SaluteSpeech timestamp string '0.480s' → 0.48."""
return float(s.rstrip("s"))
def _get_salute_access_token(client: httpx.Client) -> str:
    """Return a cached SaluteSpeech OAuth access token, fetching a new one if needed.

    The token and its expiry live in module globals guarded by
    ``_salute_token_lock``. Concurrent threads therefore share one token, and
    only one of them performs the refresh request at a time. The cached token
    is reused until ``SALUTE_TOKEN_REFRESH_MARGIN_SECONDS`` before its expiry.

    Raises:
        RuntimeError: The OAuth endpoint answered with a non-200 status.
    """
    global _salute_token, _salute_token_expires_at

    with _salute_token_lock:
        refresh_deadline = _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS
        if bool(_salute_token) and time.monotonic() < refresh_deadline:
            return _salute_token

        settings = get_settings()
        auth_headers = {
            "Authorization": f"Basic {settings.salute_auth_key}",
            "RqUID": str(uuid.uuid4()),
            "Content-Type": "application/x-www-form-urlencoded",
        }
        auth_response = client.post(
            SALUTE_AUTH_URL,
            headers=auth_headers,
            content=f"scope={settings.salute_scope}",
        )
        if auth_response.status_code != 200:
            raise RuntimeError(
                ERROR_SALUTE_AUTH_FAILED.format(detail=auth_response.text[:200])
            )

        payload = auth_response.json()
        _salute_token = payload["access_token"]
        # ``expires_at`` appears to be an absolute wall-clock time in
        # milliseconds. Convert it to remaining seconds and then to a
        # monotonic deadline, so later wall-clock jumps do not matter.
        seconds_left = payload["expires_at"] / 1000 - time.time()
        _salute_token_expires_at = time.monotonic() + seconds_left
        return _salute_token
def _upload_salute_audio(
    client: httpx.Client, token: str, audio_data: bytes, content_type: str
) -> str:
    """Upload raw audio bytes to SaluteSpeech storage.

    Returns:
        The ``request_file_id`` that identifies the upload in a later
        recognition request.

    Raises:
        RuntimeError: The upload endpoint answered with a non-200 status.
    """
    upload_url = f"{SALUTE_API_BASE}/data:upload"
    upload_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": content_type,
    }
    # A per-request 120 s timeout overrides the client-level setting, since
    # uploads can be large.
    resp = client.post(
        upload_url, headers=upload_headers, content=audio_data, timeout=120.0
    )
    if resp.status_code == 200:
        return resp.json()["result"]["request_file_id"]
    raise RuntimeError(ERROR_SALUTE_UPLOAD_FAILED.format(detail=resp.text[:200]))
def _create_salute_task(
    client: httpx.Client,
    token: str,
    file_id: str,
    *,
    language: str,
    model: str,
    audio_encoding: str,
    sample_rate: int,
) -> str:
    """Start an asynchronous SaluteSpeech recognition task for an uploaded file.

    Args:
        file_id: ``request_file_id`` returned by the upload step.
        language, model, audio_encoding, sample_rate: Recognition options sent
            verbatim in the request body.

    Returns:
        The task id, to be polled with ``task:get``.

    Raises:
        RuntimeError: The endpoint answered with a non-200 status.
    """
    recognition_options = {
        "audio_encoding": audio_encoding,
        "sample_rate": sample_rate,
        "language": language,
        "model": model,
        # One audio channel, and only the single best hypothesis.
        "channels_count": 1,
        "hypotheses_count": 1,
    }
    payload = {"options": recognition_options, "request_file_id": file_id}
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    reply = client.post(
        f"{SALUTE_API_BASE}/speech:async_recognize",
        headers=request_headers,
        json=payload,
    )
    if reply.status_code != 200:
        raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=reply.text[:200]))
    return reply.json()["result"]["id"]
def _poll_salute_task(
    client: httpx.Client,
    token: str,
    task_id: str,
    job_uuid: uuid.UUID | None,
    on_progress: ProgressCallback | None,
) -> str:
    """Poll a SaluteSpeech recognition task until it reaches a terminal status.

    Args:
        task_id: Id returned by ``speech:async_recognize``.
        job_uuid: When set, ``_raise_if_job_cancelled`` is checked before each
            poll, presumably so a user cancellation aborts the wait.
        on_progress: Optional callback. It receives elapsed time as a share of
            ``SALUTE_POLL_TIMEOUT_SECONDS``, capped at 95%.

    Returns:
        The ``response_file_id`` of the finished task.

    Raises:
        TimeoutError: The task did not finish within SALUTE_POLL_TIMEOUT_SECONDS.
        RuntimeError: The task reported ERROR or was cancelled server-side.
        httpx.HTTPStatusError: ``task:get`` returned an error status.
    """
    if job_uuid is not None:
        # Kept as a local import (as before), presumably to avoid an import
        # cycle with the tasks module; hoisted out of the polling loop.
        from cpv3.modules.tasks.service import _raise_if_job_cancelled

    start = time.monotonic()
    while True:
        elapsed = time.monotonic() - start
        if elapsed > SALUTE_POLL_TIMEOUT_SECONDS:
            raise TimeoutError(ERROR_SALUTE_TIMEOUT)
        if job_uuid is not None:
            _raise_if_job_cancelled(job_uuid)

        response = client.get(
            f"{SALUTE_API_BASE}/task:get",
            params={"id": task_id},
            headers={"Authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
        result = response.json()["result"]
        status = result["status"]

        if status == "DONE":
            return result["response_file_id"]
        if status == "ERROR":
            error_msg = result.get("error", "unknown error")
            raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=error_msg))
        if status in ("CANCELED", "CANCELLED"):
            # A cancelled task never reaches DONE. Fail now instead of polling
            # until the timeout.
            raise RuntimeError(
                ERROR_SALUTE_TASK_FAILED.format(detail="task was canceled")
            )

        if on_progress is not None:
            pct = min(elapsed / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0)
            on_progress(pct)
        time.sleep(SALUTE_POLL_INTERVAL_SECONDS)
def _download_salute_result(
    client: httpx.Client, token: str, response_file_id: str
) -> list[dict]:
    """Fetch the recognition result produced by a finished SaluteSpeech task.

    Returns:
        The decoded JSON body (annotated as a list of dicts).

    Raises:
        httpx.HTTPStatusError: The download endpoint returned an error status.
    """
    download = client.get(
        f"{SALUTE_API_BASE}/data:download",
        headers={"Authorization": f"Bearer {token}"},
        params={"response_file_id": response_file_id},
        timeout=60.0,
    )
    download.raise_for_status()
    payload = download.json()
    return payload
def _build_document_from_salute_result(
    raw_channels: list[dict], *, language: str
) -> Document:
    """Convert the downloaded SaluteSpeech result JSON into a tagged Document.

    Every entry of each item's ``results`` list becomes one segment. Its
    ``word_alignments`` become words, and its ``start``/``end`` strings (e.g.
    ``"0.480s"``) are parsed into seconds.

    Args:
        raw_channels: Decoded JSON returned by ``data:download``.
        language: Accepted for interface symmetry; not used here.
    """
    builder = DocumentBuilder()
    words_options = WordOptions()
    all_segments: list[SaluteSpeechSegment] = []
    for channel_data in raw_channels:
        for result_item in channel_data.get("results", []):
            words = [
                SaluteSpeechWord(
                    word=w["word"],
                    start=_parse_salute_time(w["start"]),
                    end=_parse_salute_time(w["end"]),
                )
                for w in result_item.get("word_alignments", [])
            ]
            # Tolerate an explicit null text as well as a missing key.
            text = result_item.get("text") or ""
            if not words and not text.strip():
                # Empty recognitions (presumably silence) would become segments
                # with no lines. Skip them, as the Google path skips
                # word-less results.
                continue
            all_segments.append(
                SaluteSpeechSegment(
                    text=text,
                    start=_parse_salute_time(result_item["start"]),
                    end=_parse_salute_time(result_item["end"]),
                    words=words,
                )
            )
    document = _make_document_from_segments(
        builder, all_segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
def _salute_transcribe_sync(
    *,
    local_file_path: str,
    language: str | None,
    model: str,
    sample_rate: int,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Transcribe a local audio file with SaluteSpeech async recognition (blocking).

    Intended to run in a worker thread (the original note mentions the
    Dramatiq worker). The flow is: OAuth token, upload, create recognition
    task, poll until DONE, download result, build Document.

    Args:
        local_file_path: Audio file. Its lower-cased extension must be a key of
            both ``SALUTE_ENCODING_MAP`` and ``SALUTE_CONTENT_TYPE_MAP``.
        language: Short code mapped through ``SALUTE_LANGUAGE_MAP``. ``None`` or
            an unknown code falls back to ``"ru-RU"``.
        model: SaluteSpeech recognition model name.
        sample_rate: Sample rate sent in the recognition options.
        job_id: Forwarded to polling for cancellation checks.
        on_progress: Forwarded to polling.

    Raises:
        ValueError: Unsupported file extension.
        RuntimeError, TimeoutError, httpx errors: From the individual stages.
    """
    settings = get_settings()
    ext = Path(local_file_path).suffix.lower()
    audio_encoding = SALUTE_ENCODING_MAP.get(ext)
    content_type = SALUTE_CONTENT_TYPE_MAP.get(ext)
    if not audio_encoding or not content_type:
        raise ValueError(ERROR_SALUTE_UNSUPPORTED_FORMAT.format(ext=ext))
    salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU")
    # A custom CA bundle, when configured; otherwise default verification.
    verify = str(settings.salute_ca_cert_path) if settings.salute_ca_cert_path else True

    with httpx.Client(verify=verify, timeout=30.0) as client:
        token = _get_salute_access_token(client)
        with open(local_file_path, "rb") as f:
            audio_data = f.read()
        file_id = _upload_salute_audio(client, token, audio_data, content_type)
        # The raw audio is not needed any more; release it before the
        # potentially long polling phase.
        del audio_data

        # Uploading can take a while. Ask the cache for the token again before
        # each later stage so a token that aged past its refresh margin is
        # renewed rather than reused. When fresh, this is a cached lookup.
        task_id = _create_salute_task(
            client,
            _get_salute_access_token(client),
            file_id,
            language=salute_language,
            model=model,
            audio_encoding=audio_encoding,
            sample_rate=sample_rate,
        )
        response_file_id = _poll_salute_task(
            client, _get_salute_access_token(client), task_id, job_id, on_progress
        )
        raw_result = _download_salute_result(
            client, _get_salute_access_token(client), response_file_id
        )
    return _build_document_from_salute_result(raw_result, language=salute_language)
async def transcribe_with_salute_speech(
    storage: StorageService,
    *,
    file_key: str,
    language: str | None = None,
    model: str = "general",
    sample_rate: int = 16000,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Download ``file_key`` and transcribe it with SaluteSpeech off the event loop.

    The blocking pipeline runs in a worker thread via
    ``anyio.to_thread.run_sync``. The downloaded temporary copy is always
    cleaned up afterwards.
    """
    downloaded = await storage.download_to_temp(file_key)

    def _run_blocking() -> Document:
        return _salute_transcribe_sync(
            local_file_path=downloaded.path,
            language=language,
            model=model,
            sample_rate=sample_rate,
            job_id=job_id,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run_blocking)
    finally:
        downloaded.cleanup()