# Source: main_backend/cpv3/modules/transcription/service.py
# (repository-browser capture: author Daniil, commit 259d3da89f rev 4,
#  2026-04-07 13:42:45 +03:00 — 816 lines, 25 KiB, Python)
from __future__ import annotations
import asyncio
import logging
import ssl
import threading
import time
import uuid
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Callable, cast
import anyio
import httpx
from cpv3.infrastructure.settings import get_settings
from cpv3.infrastructure.storage.base import StorageService
from cpv3.modules.transcription.constants import (
FIRST_LINE_IN_DOCUMENT,
FIRST_LINE_IN_SEGMENT,
FIRST_SEGMENT_IN_DOCUMENT,
FIRST_WORD_IN_DOCUMENT,
FIRST_WORD_IN_LINE,
FIRST_WORD_IN_SEGMENT,
LAST_LINE_IN_DOCUMENT,
LAST_LINE_IN_SEGMENT,
LAST_SEGMENT_IN_DOCUMENT,
LAST_WORD_IN_DOCUMENT,
LAST_WORD_IN_LINE,
LAST_WORD_IN_SEGMENT,
)
from cpv3.modules.transcription.schemas import (
Document,
GoogleSpeechResult,
GoogleSpeechSegment,
GoogleSpeechWord,
LineNode,
SaluteSpeechSegment,
SaluteSpeechWord,
SegmentNode,
Tag,
TimeRange,
WhisperResult,
WhisperSegment,
WhisperWord,
WordNode,
WordOptions,
)
# ---------------------------------- SaluteSpeech Constants ----------------------------------
# OAuth endpoint (Sber NGW) and REST base URL for async recognition.
SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1"
# Polling cadence and overall ceiling for one async recognition task.
SALUTE_POLL_INTERVAL_SECONDS = 5.0
SALUTE_POLL_TIMEOUT_SECONDS = 600
# Refresh the cached OAuth token this many seconds before its reported expiry.
SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60
# Extensions SaluteSpeech accepts natively -> API "audio_encoding" values.
SALUTE_ENCODING_MAP: dict[str, str] = {
    ".mp3": "MP3",
    ".wav": "PCM_S16LE",
    ".ogg": "opus",
    ".flac": "FLAC",
}
# Same extensions -> Content-Type header used when uploading the audio.
SALUTE_CONTENT_TYPE_MAP: dict[str, str] = {
    ".mp3": "audio/mpeg",
    ".wav": "audio/wav",
    ".ogg": "audio/ogg",
    ".flac": "audio/flac",
}
# Short app language codes -> SaluteSpeech locale identifiers.
SALUTE_LANGUAGE_MAP: dict[str, str] = {
    "ru": "ru-RU",
    "en": "en-US",
}
# User-facing error templates (Russian runtime strings — must not be altered).
ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}"
ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}"
ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}"
ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech"
ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}"
ERROR_SALUTE_AUTH_KEY_MISSING = "Не задан SALUTE_AUTH_KEY для авторизации SaluteSpeech"
ERROR_SALUTE_SSL_FAILED = (
    "SSL ошибка при обращении к SaluteSpeech: {detail}. "
    "Если используется корпоративный или локальный сертификат, "
    "укажите путь в SALUTE_CA_CERT_PATH. "
    "Для локальной отладки можно отключить проверку через SALUTE_SSL_VERIFY=false."
)
# Process-wide OAuth token cache, guarded by a lock
# (see _get_salute_access_token; the deadline is on the monotonic clock).
_salute_token_lock = threading.Lock()
_salute_token: str | None = None
_salute_token_expires_at: float = 0.0
logger = logging.getLogger(__name__)
class DocumentBuilder:
    """Builds a tagged ``Document`` tree (segments -> lines -> words) from the
    raw segments produced by any engine (Whisper / Google Speech / SaluteSpeech).

    ``compute_segment_lines`` does greedy word wrapping; the ``process_*``
    methods attach structural position tags (first/last word/line/segment).
    """

    def compute_segment_lines(
        self,
        segment: WhisperSegment | GoogleSpeechSegment | SaluteSpeechSegment,
        max_chars_per_line: int,
    ) -> list[LineNode]:
        """Greedily wrap the segment's words into lines of at most
        ``max_chars_per_line`` characters (one joining space between words).

        Empty/whitespace-only word tokens are skipped. Each resulting line's
        time range spans from its first word's start to its last word's end.
        """
        words = segment.words or []
        # FIX: annotations previously omitted SaluteSpeechWord even though
        # SaluteSpeechSegment is an accepted input type.
        lines: list[list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord]] = []
        cur_line: list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord] = []
        cur_len = 0
        for w in words:
            text = (w.word or "").strip()
            if not text:
                continue
            # +1 accounts for the joining space (absent before the first word).
            extra = len(text) + (1 if cur_line else 0)
            if cur_line and cur_len + extra > max_chars_per_line:
                lines.append(cur_line)
                cur_line, cur_len = [w], len(text)
            else:
                cur_line.append(w)
                cur_len += extra
        if cur_line:
            lines.append(cur_line)
        result_lines: list[LineNode] = []
        for rline in lines:
            # Renamed from ``time`` to avoid shadowing the module-level
            # ``time`` import.
            line_time = TimeRange(start=rline[0].start, end=rline[-1].end)
            word_nodes = [
                WordNode(
                    text=(rword.word or "").strip(),
                    time=TimeRange(start=rword.start, end=rword.end),
                    semantic_tags=[],
                    structure_tags=[],
                )
                for rword in rline
            ]
            line_node = LineNode(
                text=" ".join((rword.word or "") for rword in rline).strip(),
                semantic_tags=[],
                structure_tags=[],
                time=line_time,
                words=word_nodes,
            )
            result_lines.append(line_node)
        return result_lines

    def process_line(
        self,
        line: LineNode,
        is_first_line_in_document: bool,
        is_last_line_in_document: bool,
        is_first_line_in_segment: bool,
        is_last_line_in_segment: bool,
    ) -> list[WordNode]:
        """Return copies of the line's words with structural position tags
        (first/last word in document / segment / line) attached."""
        words: list[WordNode] = []
        for idx, word in enumerate(line.words):
            is_first = idx == 0
            is_last = idx == len(line.words) - 1
            # (condition, tag) pairs; a word can carry several tags at once.
            rules = [
                (is_first_line_in_document and is_first, FIRST_WORD_IN_DOCUMENT),
                (is_last_line_in_document and is_last, LAST_WORD_IN_DOCUMENT),
                (is_first_line_in_segment and is_first, FIRST_WORD_IN_SEGMENT),
                (is_last_line_in_segment and is_last, LAST_WORD_IN_SEGMENT),
                (is_first, FIRST_WORD_IN_LINE),
                (is_last, LAST_WORD_IN_LINE),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            new_word = word.model_copy(update={"structure_tags": structure_tags})
            words.append(new_word)
        return words

    def process_segment(
        self,
        segment: SegmentNode,
        is_first_segment_in_document: bool,
        is_last_segment_in_document: bool,
    ) -> list[LineNode]:
        """Return copies of the segment's lines with structural tags attached,
        recursing into words via :meth:`process_line`."""
        lines: list[LineNode] = []
        for idx, line in enumerate(segment.lines):
            is_first = idx == 0
            is_last = idx == len(segment.lines) - 1
            rules = [
                (is_first_segment_in_document and is_first, FIRST_LINE_IN_DOCUMENT),
                (is_last_segment_in_document and is_last, LAST_LINE_IN_DOCUMENT),
                (is_first, FIRST_LINE_IN_SEGMENT),
                (is_last, LAST_LINE_IN_SEGMENT),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            words = self.process_line(
                line,
                # A line is first-in-document only if it opens the first segment.
                is_first_line_in_document=is_first_segment_in_document and is_first,
                is_last_line_in_document=is_last_segment_in_document and is_last,
                is_first_line_in_segment=is_first,
                is_last_line_in_segment=is_last,
            )
            new_line = line.model_copy(
                update={"structure_tags": structure_tags, "words": words}
            )
            lines.append(new_line)
        return lines

    def process_document(self, document: Document) -> Document:
        """Return a new Document with structural tags attached at every level
        (segments, lines, words). The input document is not mutated."""
        segments: list[SegmentNode] = []
        for idx, segment in enumerate(document.segments):
            structure_tags: list[Tag] = []
            is_first_segment_in_document = idx == 0
            is_last_segment_in_document = idx == len(document.segments) - 1
            if is_first_segment_in_document:
                structure_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT))
            if is_last_segment_in_document:
                structure_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT))
            lines = self.process_segment(
                segment, is_first_segment_in_document, is_last_segment_in_document
            )
            new_segment = segment.model_copy(
                update={"lines": lines, "structure_tags": structure_tags}
            )
            segments.append(new_segment)
        return Document(segments=segments)
async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]:
    """Transcode *input_path* to mono 16 kHz Opus-in-OGG via ffmpeg.

    Returns ``(output_path, cleanup)`` where *cleanup* removes the temp file.
    Raises RuntimeError — after removing the temp file — if ffmpeg fails.
    """
    # delete=False: ffmpeg writes to the path after the handle is closed.
    with NamedTemporaryFile(suffix=".ogg", delete=False) as out:
        out_path = out.name

    def _cleanup() -> None:
        # missing_ok: cleanup may run after the file is already gone.
        Path(out_path).unlink(missing_ok=True)

    proc = await asyncio.create_subprocess_exec(
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-c:a",
        "libopus",
        "-b:a",
        "24k",
        "-vn",
        "-ac",
        "1",
        "-ar",
        "16000",
        out_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        # FIX: don't leak the temp file when the conversion fails.
        _cleanup()
        raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")
    return out_path, _cleanup
def _make_document_from_segments(
    builder: DocumentBuilder,
    segments: list[WhisperSegment] | list[GoogleSpeechSegment] | list[SaluteSpeechSegment],
    *,
    max_line_width: int,
) -> Document:
    """Wrap raw engine segments into an (as yet untagged) Document tree.

    Each raw segment becomes a SegmentNode whose lines are produced by the
    builder's greedy word wrapping at *max_line_width* characters.
    """
    nodes = [
        SegmentNode(
            text=seg.text.strip(),
            semantic_tags=[],
            structure_tags=[],
            time=TimeRange(start=seg.start, end=seg.end),
            lines=builder.compute_segment_lines(seg, max_line_width),
        )
        for seg in segments
    ]
    return Document(segments=nodes)
# Callback invoked with a completion percentage in the range [0.0, 100.0].
ProgressCallback = Callable[[float], None]
def _whisper_transcribe_sync(
    *,
    local_file_path: str,
    model_name: str,
    language: str | None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Run OpenAI Whisper on a local audio file and build a tagged Document.

    Blocking; intended to run in a worker thread. When *language* is None it
    is auto-detected from a single padded/trimmed mel window before the full
    transcription pass.
    """
    import whisper  # type: ignore[import-untyped]

    settings = get_settings()
    settings.transcription_models_dir.mkdir(parents=True, exist_ok=True)
    builder = DocumentBuilder()
    model = whisper.load_model(
        model_name, download_root=str(settings.transcription_models_dir)
    )
    if language is None:
        # Language detection on the first audio window only.
        audio = whisper.load_audio(local_file_path)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(
            model.device
        )
        _, probs_raw = model.detect_language(mel)
        probs = cast(dict[str, float], probs_raw)
        # Pick the language with the highest probability.
        language = max(probs, key=lambda k: probs[k])
    if on_progress is not None:
        # Hook whisper's internal tqdm progress bar to forward percentages.
        from unittest.mock import patch
        from tqdm import tqdm as _orig_tqdm

        class _ProgressTqdm(_orig_tqdm):
            def update(self, n=1):
                super().update(n)
                if self.total:
                    on_progress(min(self.n / self.total * 100.0, 100.0))

        # NOTE: patch target relies on whisper.transcribe importing tqdm as a
        # module — verify when upgrading whisper.
        with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
            result = whisper.transcribe(
                audio=whisper.load_audio(local_file_path),
                model=model,
                word_timestamps=True,
                temperature=0.2,
                language=language,
                verbose=False,
            )
        # Guarantee the caller sees a final 100% even if tqdm under-reported.
        on_progress(100.0)
    else:
        result = whisper.transcribe(
            audio=whisper.load_audio(local_file_path),
            model=model,
            word_timestamps=True,
            temperature=0.2,
            language=language,
            verbose=None,
        )
    parsed = WhisperResult.model_validate(result)
    # Subtitle-style layout options; only max_line_width is used downstream.
    words_options = WordOptions(
        highlight_words=True,
        max_line_width=32,
        max_line_count=2,
    )
    document = _make_document_from_segments(
        builder, parsed.segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
async def transcribe_with_whisper(
    storage: StorageService,
    *,
    file_key: str,
    model_name: str = "tiny",
    language: str | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Download *file_key* from storage and transcribe it with Whisper.

    The blocking work runs in a worker thread; the temp file is always
    cleaned up afterwards.
    """
    tmp = await storage.download_to_temp(file_key)

    def _run() -> Document:
        return _whisper_transcribe_sync(
            local_file_path=tmp.path,
            model_name=model_name,
            language=language,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run)
    finally:
        tmp.cleanup()
def _google_transcribe_sync(
    *, ogg_bytes: bytes, language_codes: list[str]
) -> GoogleSpeechResult:
    """Run Google Cloud Speech long-running recognition on OGG/Opus bytes.

    Blocking; intended to run in a worker thread. The first language code is
    primary, the rest are passed as alternatives. Segments without word
    timings (or without alternatives) are skipped.
    """
    from google.cloud import speech

    settings = get_settings()
    client: speech.SpeechClient = speech.SpeechClient.from_service_account_file(
        str(settings.google_service_key_path)
    )
    audio = speech.RecognitionAudio(content=ogg_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
        sample_rate_hertz=16000,
        language_code=language_codes[0],
        alternative_language_codes=(
            language_codes[1:] if len(language_codes) > 1 else []
        ),
        model="latest_long",
        enable_word_time_offsets=True,
    )
    operation = client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=600)
    segments: list[GoogleSpeechSegment] = []
    full_text = ""
    for result in response.results:
        # FIX: the API may return results with an empty alternatives list;
        # indexing [0] unconditionally raised IndexError.
        if not result.alternatives:
            continue
        alternative = result.alternatives[0]
        words = [
            GoogleSpeechWord(
                word=word_info.word,
                start=word_info.start_time.total_seconds(),
                end=word_info.end_time.total_seconds(),
            )
            for word_info in alternative.words
        ]
        if words:
            segment_text = alternative.transcript
            full_text += segment_text + " "
            segments.append(
                GoogleSpeechSegment(
                    text=segment_text,
                    start=words[0].start,
                    end=words[-1].end,
                    words=words,
                )
            )
    return GoogleSpeechResult(
        text=full_text.strip(), segments=segments, language=language_codes[0]
    )
async def transcribe_with_google_speech(
    storage: StorageService,
    *,
    file_key: str,
    language_codes: list[str] | None = None,
) -> Document:
    """Download *file_key*, convert it to OGG/Opus and transcribe it via
    Google Cloud Speech, returning a fully tagged Document.

    Defaults to Russian with English as an alternative language.
    """
    language_codes = language_codes or ["ru-RU", "en-US"]
    builder = DocumentBuilder()
    words_options = WordOptions()
    input_tmp = await storage.download_to_temp(file_key)
    try:
        ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path)
        try:

            def _run() -> GoogleSpeechResult:
                # FIX: read the (potentially large) OGG inside the worker
                # thread instead of blocking the event loop with open()/read().
                content = Path(ogg_path).read_bytes()
                return _google_transcribe_sync(
                    ogg_bytes=content, language_codes=language_codes
                )

            result = await anyio.to_thread.run_sync(_run)
            document = _make_document_from_segments(
                builder, result.segments, max_line_width=words_options.max_line_width
            )
            return builder.process_document(document)
        finally:
            ogg_cleanup()
    finally:
        input_tmp.cleanup()
# ---------------------------------- SaluteSpeech Engine ----------------------------------
def _parse_salute_time(s: str) -> float:
"""Parse SaluteSpeech timestamp string '0.480s' → 0.48."""
return float(s.rstrip("s"))
def _build_salute_ssl_context() -> ssl.SSLContext:
    """Build SSL context for SaluteSpeech using system trust plus optional custom CA.

    When verification is disabled in settings (local debugging), hostname and
    certificate checks are switched off explicitly.
    """
    settings = get_settings()
    ssl_context = ssl.create_default_context()
    if not settings.salute_ssl_verify:
        # FIX: public equivalent of the private ssl._create_unverified_context();
        # check_hostname must be disabled before setting CERT_NONE.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        return ssl_context
    if settings.salute_ca_cert_path is not None:
        # Augment the system roots with the corporate/local CA bundle.
        ssl_context.load_verify_locations(cafile=str(settings.salute_ca_cert_path))
    return ssl_context
def _get_salute_auth_header_value() -> str:
    """Return the 'Basic <key>' Authorization header for SaluteSpeech OAuth.

    Raises RuntimeError when the auth key is missing/blank in settings.
    """
    key = get_settings().salute_auth_key.strip()
    if not key:
        raise RuntimeError(ERROR_SALUTE_AUTH_KEY_MISSING)
    return "Basic " + key
def _get_salute_access_token(client: httpx.Client) -> str:
    """Get or refresh SaluteSpeech OAuth token. Thread-safe.

    Caches the token in module-level globals; a cached token is reused while
    it is more than SALUTE_TOKEN_REFRESH_MARGIN_SECONDS away from expiry.
    Raises RuntimeError when the OAuth endpoint rejects the request.
    """
    global _salute_token, _salute_token_expires_at
    # The lock is held across the network call so concurrent callers never
    # request more than one token at a time.
    with _salute_token_lock:
        if _salute_token and time.monotonic() < (
            _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS
        ):
            return _salute_token
        settings = get_settings()
        response = client.post(
            SALUTE_AUTH_URL,
            headers={
                "Authorization": _get_salute_auth_header_value(),
                # Sber requires a fresh unique request id per OAuth call.
                "RqUID": str(uuid.uuid4()),
                "Content-Type": "application/x-www-form-urlencoded",
            },
            content=f"scope={settings.salute_scope}",
        )
        if response.status_code != 200:
            raise RuntimeError(
                ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200])
            )
        data = response.json()
        _salute_token = data["access_token"]
        # "expires_at" is a wall-clock epoch in milliseconds; convert it to a
        # remaining lifetime, then anchor the deadline on the monotonic clock
        # so later system-clock changes cannot skew the cache.
        expires_in_seconds = (data["expires_at"] / 1000) - time.time()
        _salute_token_expires_at = time.monotonic() + expires_in_seconds
        return _salute_token
def _upload_salute_audio(
    client: httpx.Client, token: str, audio_data: bytes, content_type: str
) -> str:
    """Upload raw audio bytes to SaluteSpeech and return the request_file_id.

    Raises RuntimeError on any non-200 response.
    """
    upload_response = client.post(
        f"{SALUTE_API_BASE}/data:upload",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": content_type,
        },
        content=audio_data,
        # Uploads of long recordings can be slow; allow a generous timeout.
        timeout=120.0,
    )
    if upload_response.status_code != 200:
        detail = upload_response.text[:200]
        raise RuntimeError(ERROR_SALUTE_UPLOAD_FAILED.format(detail=detail))
    payload = upload_response.json()
    return payload["result"]["request_file_id"]
def _create_salute_task(
    client: httpx.Client,
    token: str,
    file_id: str,
    *,
    language: str,
    model: str,
    audio_encoding: str,
    sample_rate: int,
) -> str:
    """Create an async recognition task for an uploaded file; return task_id.

    Raises RuntimeError on any non-200 response.
    """
    recognition_options = {
        "audio_encoding": audio_encoding,
        "sample_rate": sample_rate,
        "language": language,
        "model": model,
        # Mono audio, single best hypothesis.
        "channels_count": 1,
        "hypotheses_count": 1,
    }
    create_response = client.post(
        f"{SALUTE_API_BASE}/speech:async_recognize",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json={"options": recognition_options, "request_file_id": file_id},
    )
    if create_response.status_code != 200:
        detail = create_response.text[:200]
        raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=detail))
    return create_response.json()["result"]["id"]
def _poll_salute_task(
    client: httpx.Client,
    token: str,
    task_id: str,
    job_uuid: uuid.UUID | None,
    on_progress: ProgressCallback | None,
) -> str:
    """Poll a SaluteSpeech task until DONE and return its response_file_id.

    Raises TimeoutError after SALUTE_POLL_TIMEOUT_SECONDS, RuntimeError when
    the task ends in the ERROR state, and honors job cancellation when a job
    uuid is provided.
    """
    started_at = time.monotonic()
    while True:
        waited = time.monotonic() - started_at
        if waited > SALUTE_POLL_TIMEOUT_SECONDS:
            raise TimeoutError(ERROR_SALUTE_TIMEOUT)
        if job_uuid is not None:
            # Imported lazily to avoid a circular import with the tasks module.
            from cpv3.modules.tasks.service import _raise_if_job_cancelled

            _raise_if_job_cancelled(job_uuid)
        poll_response = client.get(
            f"{SALUTE_API_BASE}/task:get",
            params={"id": task_id},
            headers={"Authorization": f"Bearer {token}"},
        )
        poll_response.raise_for_status()
        task_info = poll_response.json()["result"]
        task_status = task_info["status"]
        if task_status == "DONE":
            return task_info["response_file_id"]
        if task_status == "ERROR":
            raise RuntimeError(
                ERROR_SALUTE_TASK_FAILED.format(
                    detail=task_info.get("error", "unknown error")
                )
            )
        if on_progress is not None:
            # Time-based pseudo-progress, capped at 95% until the task is done.
            on_progress(min(waited / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0))
        time.sleep(SALUTE_POLL_INTERVAL_SECONDS)
def _download_salute_result(
    client: httpx.Client, token: str, response_file_id: str
) -> list[dict]:
    """Download and deserialize the recognition result JSON (one entry per channel)."""
    download_response = client.get(
        f"{SALUTE_API_BASE}/data:download",
        params={"response_file_id": response_file_id},
        headers={"Authorization": f"Bearer {token}"},
        timeout=60.0,
    )
    download_response.raise_for_status()
    return download_response.json()
def _build_document_from_salute_result(
    raw_channels: list[dict], *, language: str
) -> Document:
    """Convert SaluteSpeech result JSON to a fully tagged Document.

    Segments from all channels are flattened into a single list.
    NOTE(review): *language* is currently unused here; kept for interface
    compatibility with callers.
    """
    builder = DocumentBuilder()
    words_options = WordOptions()
    all_segments: list[SaluteSpeechSegment] = []
    for channel_data in raw_channels:
        for result_item in channel_data.get("results", []):
            word_nodes = [
                SaluteSpeechWord(
                    word=alignment["word"],
                    start=_parse_salute_time(alignment["start"]),
                    end=_parse_salute_time(alignment["end"]),
                )
                for alignment in result_item.get("word_alignments", [])
            ]
            all_segments.append(
                SaluteSpeechSegment(
                    text=result_item.get("text", ""),
                    start=_parse_salute_time(result_item["start"]),
                    end=_parse_salute_time(result_item["end"]),
                    words=word_nodes,
                )
            )
    document = _make_document_from_segments(
        builder, all_segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
def _convert_to_wav_sync(input_path: str, sample_rate: int = 16000) -> tuple[str, Callable[[], None]]:
    """Convert any audio/video to WAV (PCM signed 16-bit LE) using ffmpeg. Sync version.

    Blocking; intended for worker threads. Returns ``(output_path, cleanup)``
    and raises RuntimeError — after removing the temp file — if ffmpeg fails.
    """
    import os
    import subprocess

    # delete=False: ffmpeg writes to the path after the handle is closed.
    with NamedTemporaryFile(suffix=".wav", delete=False) as out:
        out_path = out.name

    def _cleanup() -> None:
        if os.path.exists(out_path):
            os.remove(out_path)

    result = subprocess.run(
        [
            "ffmpeg", "-y", "-i", input_path,
            "-vn", "-ac", "1", "-ar", str(sample_rate),
            "-acodec", "pcm_s16le",
            out_path,
        ],
        capture_output=True,
    )
    if result.returncode != 0:
        # FIX: don't leak the temp file when the conversion fails.
        _cleanup()
        raise RuntimeError(f"ffmpeg failed: {result.stderr.decode(errors='ignore')}")
    return out_path, _cleanup
def _salute_transcribe_sync(
    *,
    local_file_path: str,
    language: str | None,
    model: str,
    sample_rate: int,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Synchronous SaluteSpeech transcription (runs in Dramatiq worker thread).

    Pipeline: optional ffmpeg conversion -> OAuth -> upload -> create async
    task -> poll (with cancellation/progress) -> download -> build Document.
    SSL failures are re-raised as RuntimeError with configuration hints.
    """
    ext = Path(local_file_path).suffix.lower()
    audio_encoding = SALUTE_ENCODING_MAP.get(ext)
    content_type = SALUTE_CONTENT_TYPE_MAP.get(ext)
    # Convert unsupported formats (mp4, webm, m4a, etc.) to WAV via ffmpeg
    cleanup_fn: Callable[[], None] | None = None
    if not audio_encoding or not content_type:
        wav_path, cleanup_fn = _convert_to_wav_sync(local_file_path, sample_rate)
        local_file_path = wav_path
        audio_encoding = "PCM_S16LE"
        content_type = "audio/wav"
    # Unknown/unset languages fall back to Russian.
    salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU")
    try:
        ssl_context = _build_salute_ssl_context()
        with httpx.Client(verify=ssl_context, timeout=30.0) as client:
            token = _get_salute_access_token(client)
            with open(local_file_path, "rb") as f:
                audio_data = f.read()
            file_id = _upload_salute_audio(client, token, audio_data, content_type)
            task_id = _create_salute_task(
                client,
                token,
                file_id,
                language=salute_language,
                model=model,
                audio_encoding=audio_encoding,
                sample_rate=sample_rate,
            )
            response_file_id = _poll_salute_task(
                client, token, task_id, job_id, on_progress
            )
            raw_result = _download_salute_result(client, token, response_file_id)
            return _build_document_from_salute_result(raw_result, language=salute_language)
    except ssl.SSLError as exc:
        raise RuntimeError(ERROR_SALUTE_SSL_FAILED.format(detail=str(exc))) from exc
    except httpx.ConnectError as exc:
        # httpx wraps TLS handshake failures in ConnectError; surface the same
        # SSL configuration hint when that is the underlying cause.
        if isinstance(exc.__cause__, ssl.SSLError):
            raise RuntimeError(
                ERROR_SALUTE_SSL_FAILED.format(detail=str(exc.__cause__))
            ) from exc
        raise
    finally:
        # Remove the converted WAV (if any) regardless of outcome.
        if cleanup_fn is not None:
            cleanup_fn()
async def transcribe_with_salute_speech(
    storage: StorageService,
    *,
    file_key: str,
    language: str | None = None,
    model: str = "general",
    sample_rate: int = 16000,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Async wrapper for SaluteSpeech transcription.

    Downloads *file_key* to a temp file, runs the blocking pipeline in a
    worker thread, and always cleans the temp file up afterwards.
    """
    tmp = await storage.download_to_temp(file_key)

    def _run() -> Document:
        return _salute_transcribe_sync(
            local_file_path=tmp.path,
            language=language,
            model=model,
            sample_rate=sample_rate,
            job_id=job_id,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run)
    finally:
        tmp.cleanup()