# Commit 269724d553:
# MP4/webm/m4a files are now auto-converted to WAV (PCM_S16LE) via ffmpeg
# before uploading to SaluteSpeech API. Follows the same pattern as Google
# Speech's _convert_local_to_ogg.
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from tempfile import NamedTemporaryFile
|
|
from typing import Callable, cast
|
|
|
|
import anyio
|
|
import httpx
|
|
|
|
from cpv3.infrastructure.settings import get_settings
|
|
from cpv3.infrastructure.storage.base import StorageService
|
|
from cpv3.modules.transcription.constants import (
|
|
FIRST_LINE_IN_DOCUMENT,
|
|
FIRST_LINE_IN_SEGMENT,
|
|
FIRST_SEGMENT_IN_DOCUMENT,
|
|
FIRST_WORD_IN_DOCUMENT,
|
|
FIRST_WORD_IN_LINE,
|
|
FIRST_WORD_IN_SEGMENT,
|
|
LAST_LINE_IN_DOCUMENT,
|
|
LAST_LINE_IN_SEGMENT,
|
|
LAST_SEGMENT_IN_DOCUMENT,
|
|
LAST_WORD_IN_DOCUMENT,
|
|
LAST_WORD_IN_LINE,
|
|
LAST_WORD_IN_SEGMENT,
|
|
)
|
|
from cpv3.modules.transcription.schemas import (
|
|
Document,
|
|
GoogleSpeechResult,
|
|
GoogleSpeechSegment,
|
|
GoogleSpeechWord,
|
|
LineNode,
|
|
SaluteSpeechSegment,
|
|
SaluteSpeechWord,
|
|
SegmentNode,
|
|
Tag,
|
|
TimeRange,
|
|
WhisperResult,
|
|
WhisperSegment,
|
|
WhisperWord,
|
|
WordNode,
|
|
WordOptions,
|
|
)
|
|
|
|
|
|
# ---------------------------------- SaluteSpeech Constants ----------------------------------

# OAuth token endpoint (Sber NGW gateway) and REST base URL of the SmartSpeech API.
SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1"
# Async-recognition polling cadence and overall deadline (seconds).
SALUTE_POLL_INTERVAL_SECONDS = 5.0
SALUTE_POLL_TIMEOUT_SECONDS = 600
# Refresh the cached OAuth token this many seconds before its expiry.
SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60

# File extension -> SaluteSpeech "audio_encoding" option value. Extensions not
# listed here are converted to WAV before upload (see _convert_to_wav_sync).
SALUTE_ENCODING_MAP: dict[str, str] = {
    ".mp3": "MP3",
    ".wav": "PCM_S16LE",
    ".ogg": "opus",
    ".flac": "FLAC",
}

# File extension -> HTTP Content-Type header used for the upload request.
SALUTE_CONTENT_TYPE_MAP: dict[str, str] = {
    ".mp3": "audio/mpeg",
    ".wav": "audio/wav",
    ".ogg": "audio/ogg",
    ".flac": "audio/flac",
}

# Short language code -> SaluteSpeech locale identifier.
SALUTE_LANGUAGE_MAP: dict[str, str] = {
    "ru": "ru-RU",
    "en": "en-US",
}

# User-facing error message templates (Russian), filled via str.format.
ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}"
ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}"
ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}"
ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech"
ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}"

# Process-wide OAuth token cache shared across worker threads, guarded by a lock.
_salute_token_lock = threading.Lock()
_salute_token: str | None = None
# Monotonic-clock deadline for the cached token; 0.0 means "no token cached yet".
_salute_token_expires_at: float = 0.0

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DocumentBuilder:
    """Builds a processed ``Document`` tree (segments -> lines -> words) from
    raw engine segments, attaching positional structure tags at every level."""

    def compute_segment_lines(
        self,
        segment: WhisperSegment | GoogleSpeechSegment | SaluteSpeechSegment,
        max_chars_per_line: int,
    ) -> list[LineNode]:
        """Greedily wrap the segment's words into lines.

        Words are joined by single spaces; a line never exceeds
        *max_chars_per_line* characters unless a single word is longer than the
        limit by itself. Whitespace-only words are dropped.
        """
        words = segment.words or []
        lines: list[list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord]] = []
        cur_line: list[WhisperWord | GoogleSpeechWord | SaluteSpeechWord] = []
        cur_len = 0

        for w in words:
            text = (w.word or "").strip()
            if not text:
                # Skip empty / whitespace-only tokens entirely.
                continue

            # +1 accounts for the joining space when the line already has words.
            extra = len(text) + (1 if cur_line else 0)

            if cur_line and cur_len + extra > max_chars_per_line:
                # Word would overflow the current line: flush and start a new one.
                lines.append(cur_line)
                cur_line, cur_len = [w], len(text)
            else:
                cur_line.append(w)
                cur_len += extra

        if cur_line:
            lines.append(cur_line)

        result_lines: list[LineNode] = []
        for rline in lines:
            # Line time span covers from the first word's start to the last
            # word's end. NOTE: local name shadows the module-level ``time``.
            time = TimeRange(start=rline[0].start, end=rline[-1].end)

            word_nodes = [
                WordNode(
                    text=(rword.word or "").strip(),
                    time=TimeRange(start=rword.start, end=rword.end),
                    semantic_tags=[],
                    structure_tags=[],
                )
                for rword in rline
            ]

            line_node = LineNode(
                text=" ".join((rword.word or "") for rword in rline).strip(),
                semantic_tags=[],
                structure_tags=[],
                time=time,
                words=word_nodes,
            )
            result_lines.append(line_node)

        return result_lines

    def process_line(
        self,
        line: LineNode,
        is_first_line_in_document: bool,
        is_last_line_in_document: bool,
        is_first_line_in_segment: bool,
        is_last_line_in_segment: bool,
    ) -> list[WordNode]:
        """Return copies of the line's words with positional structure tags.

        Tags are assigned in the fixed order of the rule table below; a word
        may receive several tags (e.g. first-in-document implies also
        first-in-line when it matches both rules).
        """
        words: list[WordNode] = []
        for idx, word in enumerate(line.words):
            is_first = idx == 0
            is_last = idx == len(line.words) - 1

            # (condition, tag) pairs — order here fixes tag order on the word.
            rules = [
                (is_first_line_in_document and is_first, FIRST_WORD_IN_DOCUMENT),
                (is_last_line_in_document and is_last, LAST_WORD_IN_DOCUMENT),
                (is_first_line_in_segment and is_first, FIRST_WORD_IN_SEGMENT),
                (is_last_line_in_segment and is_last, LAST_WORD_IN_SEGMENT),
                (is_first, FIRST_WORD_IN_LINE),
                (is_last, LAST_WORD_IN_LINE),
            ]

            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]

            # model_copy keeps the original word immutable for callers.
            new_word = word.model_copy(update={"structure_tags": structure_tags})
            words.append(new_word)

        return words

    def process_segment(
        self,
        segment: SegmentNode,
        is_first_segment_in_document: bool,
        is_last_segment_in_document: bool,
    ) -> list[LineNode]:
        """Return copies of the segment's lines with structure tags applied,
        recursing into each line's words via :meth:`process_line`."""
        lines: list[LineNode] = []
        for idx, line in enumerate(segment.lines):
            is_first = idx == 0
            is_last = idx == len(segment.lines) - 1

            # (condition, tag) pairs — order fixes tag order on the line.
            rules = [
                (is_first_segment_in_document and is_first, FIRST_LINE_IN_DOCUMENT),
                (is_last_segment_in_document and is_last, LAST_LINE_IN_DOCUMENT),
                (is_first, FIRST_LINE_IN_SEGMENT),
                (is_last, LAST_LINE_IN_SEGMENT),
            ]

            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]

            # Document-level word flags only apply to the boundary lines of the
            # boundary segments.
            words = self.process_line(
                line,
                is_first_line_in_document=is_first_segment_in_document and is_first,
                is_last_line_in_document=is_last_segment_in_document and is_last,
                is_first_line_in_segment=is_first,
                is_last_line_in_segment=is_last,
            )

            new_line = line.model_copy(
                update={"structure_tags": structure_tags, "words": words}
            )
            lines.append(new_line)

        return lines

    def process_document(self, document: Document) -> Document:
        """Return a new Document with structure tags attached at the segment,
        line, and word levels. The input document is not mutated."""
        segments: list[SegmentNode] = []

        for idx, segment in enumerate(document.segments):
            structure_tags: list[Tag] = []
            is_first_segment_in_document = idx == 0
            is_last_segment_in_document = idx == len(document.segments) - 1

            if is_first_segment_in_document:
                structure_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT))
            if is_last_segment_in_document:
                structure_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT))

            lines = self.process_segment(
                segment, is_first_segment_in_document, is_last_segment_in_document
            )
            new_segment = segment.model_copy(
                update={"lines": lines, "structure_tags": structure_tags}
            )
            segments.append(new_segment)

        return Document(segments=segments)
|
|
|
|
|
|
async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]:
    """Convert *input_path* to mono 16 kHz Opus-in-OGG via ffmpeg.

    Returns ``(output_path, cleanup)`` where *cleanup* removes the temporary
    output file. Raises ``RuntimeError`` with ffmpeg's stderr on failure.

    Fix: the temporary output file is now removed when ffmpeg fails — the
    original leaked it because the error was raised before the cleanup
    callback was ever returned to the caller.
    """
    import os

    # Create the destination path up front; ffmpeg overwrites it (-y).
    with NamedTemporaryFile(suffix=".ogg", delete=False) as out:
        out_path = out.name

    def _cleanup() -> None:
        if os.path.exists(out_path):
            os.remove(out_path)

    proc = await asyncio.create_subprocess_exec(
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-c:a",
        "libopus",
        "-b:a",
        "24k",
        "-vn",  # drop any video stream
        "-ac",
        "1",  # downmix to mono
        "-ar",
        "16000",  # 16 kHz sample rate (matches the recognizer config)
        out_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        _cleanup()  # don't leak the temp file on conversion failure
        raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")

    return out_path, _cleanup
|
|
|
|
|
|
def _make_document_from_segments(
    builder: DocumentBuilder,
    segments: list[WhisperSegment] | list[GoogleSpeechSegment] | list[SaluteSpeechSegment],
    *,
    max_line_width: int,
) -> Document:
    """Wrap raw engine segments into an (untagged) ``Document``.

    Each segment's words are wrapped into lines of at most *max_line_width*
    characters by *builder*; structure tags are left empty here and filled in
    later by ``DocumentBuilder.process_document``.
    """
    nodes = [
        SegmentNode(
            text=seg.text.strip(),
            semantic_tags=[],
            structure_tags=[],
            time=TimeRange(start=seg.start, end=seg.end),
            lines=builder.compute_segment_lines(seg, max_line_width),
        )
        for seg in segments
    ]
    return Document(segments=nodes)
|
|
|
|
|
|
# Progress callback: invoked with a completion percentage in the range 0.0–100.0.
ProgressCallback = Callable[[float], None]
|
|
|
|
|
|
def _whisper_transcribe_sync(
    *,
    local_file_path: str,
    model_name: str,
    language: str | None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Transcribe *local_file_path* with OpenAI Whisper (blocking).

    Intended to run in a worker thread (see ``transcribe_with_whisper``).
    When *language* is ``None`` it is auto-detected first. When *on_progress*
    is given, whisper's internal tqdm progress bar is monkey-patched so that
    progress percentages are forwarded to the callback.
    """
    import whisper  # type: ignore[import-untyped]

    settings = get_settings()
    # Make sure the model cache directory exists before downloading weights.
    settings.transcription_models_dir.mkdir(parents=True, exist_ok=True)

    builder = DocumentBuilder()

    model = whisper.load_model(
        model_name, download_root=str(settings.transcription_models_dir)
    )

    if language is None:
        # Auto-detect the language from a padded/trimmed sample of the audio,
        # picking the most probable language from the detector's distribution.
        audio = whisper.load_audio(local_file_path)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(
            model.device
        )
        _, probs_raw = model.detect_language(mel)
        probs = cast(dict[str, float], probs_raw)
        language = max(probs, key=lambda k: probs[k])

    if on_progress is not None:
        from unittest.mock import patch

        from tqdm import tqdm as _orig_tqdm

        # tqdm subclass that forwards each update as an overall percentage.
        class _ProgressTqdm(_orig_tqdm):
            def update(self, n=1):
                super().update(n)
                if self.total:
                    on_progress(min(self.n / self.total * 100.0, 100.0))

        # Swap the tqdm class used inside whisper.transcribe for this call only.
        with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
            result = whisper.transcribe(
                audio=whisper.load_audio(local_file_path),
                model=model,
                word_timestamps=True,
                temperature=0.2,
                language=language,
                verbose=False,
            )
        on_progress(100.0)
    else:
        result = whisper.transcribe(
            audio=whisper.load_audio(local_file_path),
            model=model,
            word_timestamps=True,
            temperature=0.2,
            language=language,
            # NOTE(review): this branch uses verbose=None while the progress
            # branch uses verbose=False — confirm the distinction is intentional.
            verbose=None,
        )

    parsed = WhisperResult.model_validate(result)

    words_options = WordOptions(
        highlight_words=True,
        max_line_width=32,
        max_line_count=2,
    )

    document = _make_document_from_segments(
        builder, parsed.segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
|
|
|
|
|
|
async def transcribe_with_whisper(
    storage: StorageService,
    *,
    file_key: str,
    model_name: str = "tiny",
    language: str | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Download *file_key* from storage and transcribe it with Whisper.

    The blocking transcription runs in a worker thread; the temporary
    download is always removed afterwards.
    """
    tmp = await storage.download_to_temp(file_key)

    def _run() -> Document:
        return _whisper_transcribe_sync(
            local_file_path=tmp.path,
            model_name=model_name,
            language=language,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run)
    finally:
        tmp.cleanup()
|
|
|
|
|
|
def _google_transcribe_sync(
    *, ogg_bytes: bytes, language_codes: list[str]
) -> GoogleSpeechResult:
    """Run Google Cloud Speech long-running recognition on OGG/Opus bytes.

    Blocking; intended for a worker thread. The first entry of
    *language_codes* is the primary language, the rest are alternatives.
    Results without word timings are dropped.
    """
    from google.cloud import speech

    settings = get_settings()

    client: speech.SpeechClient = speech.SpeechClient.from_service_account_file(
        str(settings.google_service_key_path)
    )

    audio = speech.RecognitionAudio(content=ogg_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
        sample_rate_hertz=16000,
        language_code=language_codes[0],
        alternative_language_codes=(
            language_codes[1:] if len(language_codes) > 1 else []
        ),
        model="latest_long",
        enable_word_time_offsets=True,
    )

    operation = client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=600)

    segments: list[GoogleSpeechSegment] = []
    text_parts: list[str] = []

    for result in response.results:
        best = result.alternatives[0]

        words = [
            GoogleSpeechWord(
                word=info.word,
                start=info.start_time.total_seconds(),
                end=info.end_time.total_seconds(),
            )
            for info in best.words
        ]

        # Skip alternatives that carry no word-level timing information.
        if not words:
            continue

        text_parts.append(best.transcript)
        segments.append(
            GoogleSpeechSegment(
                text=best.transcript,
                start=words[0].start,
                end=words[-1].end,
                words=words,
            )
        )

    return GoogleSpeechResult(
        text=" ".join(text_parts).strip(),
        segments=segments,
        language=language_codes[0],
    )
|
|
|
|
|
|
async def transcribe_with_google_speech(
    storage: StorageService,
    *,
    file_key: str,
    language_codes: list[str] | None = None,
) -> Document:
    """Download *file_key*, convert it to OGG/Opus, and transcribe it with
    Google Cloud Speech, returning a fully tagged ``Document``.

    Defaults to Russian with English as an alternative. Both the downloaded
    input and the converted OGG file are always removed.
    """
    langs = language_codes or ["ru-RU", "en-US"]

    builder = DocumentBuilder()
    words_options = WordOptions()

    input_tmp = await storage.download_to_temp(file_key)
    try:
        ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path)
        try:
            content = Path(ogg_path).read_bytes()

            def _run() -> GoogleSpeechResult:
                return _google_transcribe_sync(
                    ogg_bytes=content, language_codes=langs
                )

            result = await anyio.to_thread.run_sync(_run)

            document = _make_document_from_segments(
                builder, result.segments, max_line_width=words_options.max_line_width
            )
            return builder.process_document(document)
        finally:
            ogg_cleanup()
    finally:
        input_tmp.cleanup()
|
|
|
|
|
|
# ---------------------------------- SaluteSpeech Engine ----------------------------------
|
|
|
|
|
|
def _parse_salute_time(s: str) -> float:
|
|
"""Parse SaluteSpeech timestamp string '0.480s' → 0.48."""
|
|
return float(s.rstrip("s"))
|
|
|
|
|
|
def _get_salute_access_token(client: httpx.Client) -> str:
    """Get or refresh SaluteSpeech OAuth token. Thread-safe.

    The token is cached in module-level globals; the entire check-and-refresh
    cycle (including the network call) runs under the lock so only one thread
    refreshes at a time.
    """
    global _salute_token, _salute_token_expires_at
    with _salute_token_lock:
        # Reuse the cached token while it is still comfortably within its
        # lifetime; the margin avoids handing out a nearly-expired token.
        if _salute_token and time.monotonic() < (
            _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS
        ):
            return _salute_token

        settings = get_settings()
        response = client.post(
            SALUTE_AUTH_URL,
            headers={
                "Authorization": f"Basic {settings.salute_auth_key}",
                # Fresh unique request id sent with each auth call.
                "RqUID": str(uuid.uuid4()),
                "Content-Type": "application/x-www-form-urlencoded",
            },
            content=f"scope={settings.salute_scope}",
        )
        if response.status_code != 200:
            raise RuntimeError(
                ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200])
            )
        data = response.json()
        _salute_token = data["access_token"]
        # "expires_at" is a wall-clock epoch timestamp in milliseconds —
        # TODO confirm against the API docs. Converted to a monotonic-clock
        # deadline so later system clock adjustments don't affect caching.
        expires_in_seconds = (data["expires_at"] / 1000) - time.time()
        _salute_token_expires_at = time.monotonic() + expires_in_seconds
        return _salute_token
|
|
|
|
|
|
def _upload_salute_audio(
    client: httpx.Client, token: str, audio_data: bytes, content_type: str
) -> str:
    """Upload raw audio bytes to SaluteSpeech and return the ``request_file_id``
    assigned by the service. Raises ``RuntimeError`` on a non-200 response."""
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": content_type,
    }
    response = client.post(
        f"{SALUTE_API_BASE}/data:upload",
        headers=headers,
        content=audio_data,
        timeout=120.0,  # uploads can be large; allow extra time
    )
    if response.status_code != 200:
        detail = response.text[:200]
        raise RuntimeError(ERROR_SALUTE_UPLOAD_FAILED.format(detail=detail))
    payload = response.json()
    return payload["result"]["request_file_id"]
|
|
|
|
|
|
def _create_salute_task(
    client: httpx.Client,
    token: str,
    file_id: str,
    *,
    language: str,
    model: str,
    audio_encoding: str,
    sample_rate: int,
) -> str:
    """Create an async recognition task for an uploaded file and return its
    task id. Raises ``RuntimeError`` on a non-200 response."""
    options = {
        "audio_encoding": audio_encoding,
        "sample_rate": sample_rate,
        "language": language,
        "model": model,
        "channels_count": 1,
        "hypotheses_count": 1,
    }
    response = client.post(
        f"{SALUTE_API_BASE}/speech:async_recognize",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json={"options": options, "request_file_id": file_id},
    )
    if response.status_code != 200:
        detail = response.text[:200]
        raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=detail))
    return response.json()["result"]["id"]
|
|
|
|
|
|
def _poll_salute_task(
    client: httpx.Client,
    token: str,
    task_id: str,
    job_uuid: uuid.UUID | None,
    on_progress: ProgressCallback | None,
) -> str:
    """Poll a SaluteSpeech recognition task until it completes.

    Returns the ``response_file_id`` on success. Raises ``TimeoutError`` when
    the overall deadline is exceeded and ``RuntimeError`` when the task ends
    in ERROR. When *job_uuid* is given, job cancellation is checked on every
    poll iteration.
    """
    started_at = time.monotonic()
    while True:
        elapsed = time.monotonic() - started_at
        if elapsed > SALUTE_POLL_TIMEOUT_SECONDS:
            raise TimeoutError(ERROR_SALUTE_TIMEOUT)

        if job_uuid is not None:
            # Local import: the tasks service is only needed when polling on
            # behalf of a job.
            from cpv3.modules.tasks.service import _raise_if_job_cancelled

            _raise_if_job_cancelled(job_uuid)

        poll_response = client.get(
            f"{SALUTE_API_BASE}/task:get",
            params={"id": task_id},
            headers={"Authorization": f"Bearer {token}"},
        )
        poll_response.raise_for_status()
        task = poll_response.json()["result"]
        task_status = task["status"]

        if task_status == "DONE":
            return task["response_file_id"]
        if task_status == "ERROR":
            detail = task.get("error", "unknown error")
            raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=detail))

        if on_progress is not None:
            # Synthetic progress based on elapsed time, capped at 95% so the
            # caller can mark 100% only once results are actually in.
            on_progress(min(elapsed / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0))

        time.sleep(SALUTE_POLL_INTERVAL_SECONDS)
|
|
|
|
|
|
def _download_salute_result(
    client: httpx.Client, token: str, response_file_id: str
) -> list[dict]:
    """Download and decode the recognition result JSON for a finished task."""
    result_response = client.get(
        f"{SALUTE_API_BASE}/data:download",
        params={"response_file_id": response_file_id},
        headers={"Authorization": f"Bearer {token}"},
        timeout=60.0,
    )
    result_response.raise_for_status()
    return result_response.json()
|
|
|
|
|
|
def _build_document_from_salute_result(
    raw_channels: list[dict], *, language: str
) -> Document:
    """Convert SaluteSpeech result JSON (one dict per channel) into a fully
    tagged ``Document``.

    NOTE(review): *language* is accepted but not used anywhere in the body —
    confirm whether it should influence processing or be removed.
    """
    builder = DocumentBuilder()
    words_options = WordOptions()

    # Flatten all channels' results into one segment list, parsing the
    # "0.480s"-style timestamps along the way.
    all_segments = [
        SaluteSpeechSegment(
            text=item.get("text", ""),
            start=_parse_salute_time(item["start"]),
            end=_parse_salute_time(item["end"]),
            words=[
                SaluteSpeechWord(
                    word=alignment["word"],
                    start=_parse_salute_time(alignment["start"]),
                    end=_parse_salute_time(alignment["end"]),
                )
                for alignment in item.get("word_alignments", [])
            ],
        )
        for channel_data in raw_channels
        for item in channel_data.get("results", [])
    ]

    document = _make_document_from_segments(
        builder, all_segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
|
|
|
|
|
|
def _convert_to_wav_sync(input_path: str, sample_rate: int = 16000) -> tuple[str, Callable[[], None]]:
    """Convert any audio/video to WAV (PCM signed 16-bit LE) using ffmpeg. Sync version.

    Returns ``(output_path, cleanup)`` where *cleanup* removes the temporary
    WAV file. Raises ``RuntimeError`` with ffmpeg's stderr on failure.

    Fix: the temporary output file is now removed when ffmpeg fails — the
    original raised before the cleanup callback was ever returned, leaking
    the temp file.
    """
    import os
    import subprocess

    # Reserve a destination path; ffmpeg overwrites it (-y).
    with NamedTemporaryFile(suffix=".wav", delete=False) as out:
        out_path = out.name

    def _cleanup() -> None:
        if os.path.exists(out_path):
            os.remove(out_path)

    result = subprocess.run(
        [
            "ffmpeg", "-y", "-i", input_path,
            "-vn", "-ac", "1", "-ar", str(sample_rate),  # no video, mono, target rate
            "-acodec", "pcm_s16le",
            out_path,
        ],
        capture_output=True,
    )
    if result.returncode != 0:
        _cleanup()  # don't leak the temp file on conversion failure
        raise RuntimeError(f"ffmpeg failed: {result.stderr.decode(errors='ignore')}")

    return out_path, _cleanup
|
|
|
|
|
|
def _salute_transcribe_sync(
    *,
    local_file_path: str,
    language: str | None,
    model: str,
    sample_rate: int,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Synchronous SaluteSpeech transcription (runs in Dramatiq worker thread).

    Pipeline: authenticate -> upload audio -> create async recognition task ->
    poll until done (with optional cancellation/progress hooks) -> download
    result -> build a tagged Document.
    """
    settings = get_settings()

    # Derive the API encoding / upload content type from the file extension.
    ext = Path(local_file_path).suffix.lower()
    audio_encoding = SALUTE_ENCODING_MAP.get(ext)
    content_type = SALUTE_CONTENT_TYPE_MAP.get(ext)

    # Convert unsupported formats (mp4, webm, m4a, etc.) to WAV via ffmpeg
    cleanup_fn: Callable[[], None] | None = None
    if not audio_encoding or not content_type:
        wav_path, cleanup_fn = _convert_to_wav_sync(local_file_path, sample_rate)
        local_file_path = wav_path
        audio_encoding = "PCM_S16LE"
        content_type = "audio/wav"

    # Unsupported/missing language codes fall back to Russian.
    salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU")

    try:
        # Optional custom CA bundle for TLS verification against Sber endpoints.
        verify = str(settings.salute_ca_cert_path) if settings.salute_ca_cert_path else True
        with httpx.Client(verify=verify, timeout=30.0) as client:
            token = _get_salute_access_token(client)

            with open(local_file_path, "rb") as f:
                audio_data = f.read()

            file_id = _upload_salute_audio(client, token, audio_data, content_type)
            task_id = _create_salute_task(
                client,
                token,
                file_id,
                language=salute_language,
                model=model,
                audio_encoding=audio_encoding,
                sample_rate=sample_rate,
            )
            response_file_id = _poll_salute_task(
                client, token, task_id, job_id, on_progress
            )
            raw_result = _download_salute_result(client, token, response_file_id)

        return _build_document_from_salute_result(raw_result, language=salute_language)
    finally:
        # Remove the converted WAV (if any), even when recognition fails.
        if cleanup_fn is not None:
            cleanup_fn()
|
|
|
|
|
async def transcribe_with_salute_speech(
    storage: StorageService,
    *,
    file_key: str,
    language: str | None = None,
    model: str = "general",
    sample_rate: int = 16000,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Download *file_key* from storage and transcribe it via SaluteSpeech.

    The blocking API workflow runs in a worker thread; the temporary
    download is always removed afterwards.
    """
    tmp = await storage.download_to_temp(file_key)

    def _run() -> Document:
        return _salute_transcribe_sync(
            local_file_path=tmp.path,
            language=language,
            model=model,
            sample_rate=sample_rate,
            job_id=job_id,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(_run)
    finally:
        tmp.cleanup()
|