feat(backend): implement SaluteSpeech transcription engine
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Callable, cast
|
from typing import Callable, cast
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
import httpx
|
||||||
|
|
||||||
from cpv3.infrastructure.settings import get_settings
|
from cpv3.infrastructure.settings import get_settings
|
||||||
from cpv3.infrastructure.storage.base import StorageService
|
from cpv3.infrastructure.storage.base import StorageService
|
||||||
@@ -29,6 +35,7 @@ from cpv3.modules.transcription.schemas import (
|
|||||||
GoogleSpeechWord,
|
GoogleSpeechWord,
|
||||||
LineNode,
|
LineNode,
|
||||||
SaluteSpeechSegment,
|
SaluteSpeechSegment,
|
||||||
|
SaluteSpeechWord,
|
||||||
SegmentNode,
|
SegmentNode,
|
||||||
Tag,
|
Tag,
|
||||||
TimeRange,
|
TimeRange,
|
||||||
@@ -40,6 +47,46 @@ from cpv3.modules.transcription.schemas import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------- SaluteSpeech Constants ----------------------------------

# OAuth endpoint used to obtain access tokens, and the base URL of the REST API.
SALUTE_AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
SALUTE_API_BASE = "https://smartspeech.sber.ru/rest/v1"

# Delay between task status polls and the overall polling deadline.
SALUTE_POLL_INTERVAL_SECONDS = 5.0
SALUTE_POLL_TIMEOUT_SECONDS = 600
# A cached token is refreshed this many seconds before its recorded expiry.
SALUTE_TOKEN_REFRESH_MARGIN_SECONDS = 60

# File extension -> value sent as the ``audio_encoding`` recognition option.
# All values use the upper-case enum spelling (".ogg" was previously "opus").
SALUTE_ENCODING_MAP: dict[str, str] = {
    ".mp3": "MP3",
    ".wav": "PCM_S16LE",
    ".ogg": "OPUS",
    ".flac": "FLAC",
}

# File extension -> Content-Type header used when uploading the audio.
SALUTE_CONTENT_TYPE_MAP: dict[str, str] = {
    ".mp3": "audio/mpeg",
    ".wav": "audio/wav",
    ".ogg": "audio/ogg",
    ".flac": "audio/flac",
}

# Short language code -> locale string sent as the ``language`` option.
SALUTE_LANGUAGE_MAP: dict[str, str] = {
    "ru": "ru-RU",
    "en": "en-US",
}

# Error message templates (runtime strings, intentionally in Russian).
ERROR_SALUTE_AUTH_FAILED = "Ошибка авторизации SaluteSpeech: {detail}"
ERROR_SALUTE_UPLOAD_FAILED = "Ошибка загрузки файла в SaluteSpeech: {detail}"
ERROR_SALUTE_TASK_FAILED = "Ошибка распознавания SaluteSpeech: {detail}"
ERROR_SALUTE_TIMEOUT = "Превышено время ожидания распознавания SaluteSpeech"
ERROR_SALUTE_UNSUPPORTED_FORMAT = "Неподдерживаемый формат аудио для SaluteSpeech: {ext}"

# Module-level token cache; every read and write happens under the lock.
_salute_token_lock = threading.Lock()
_salute_token: str | None = None
# Expiry of the cached token on the time.monotonic() clock.
_salute_token_expires_at: float = 0.0

logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DocumentBuilder:
|
class DocumentBuilder:
|
||||||
def compute_segment_lines(
|
def compute_segment_lines(
|
||||||
self,
|
self,
|
||||||
@@ -430,3 +477,267 @@ async def transcribe_with_google_speech(
|
|||||||
ogg_cleanup()
|
ogg_cleanup()
|
||||||
finally:
|
finally:
|
||||||
input_tmp.cleanup()
|
input_tmp.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------- SaluteSpeech Engine ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_salute_time(s: str) -> float:
    """Parse a SaluteSpeech timestamp string such as ``'0.480s'`` into seconds.

    Args:
        s: Timestamp text with an optional trailing ``s`` unit suffix.
            Surrounding whitespace is ignored.

    Returns:
        The timestamp as a float number of seconds.

    Raises:
        ValueError: If the text left after removing the suffix is not a
            valid float literal.
    """
    # Strip whitespace first: otherwise a trailing space or newline hides the
    # "s" suffix from rstrip() and float() rejects the value.
    return float(s.strip().rstrip("s"))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_salute_access_token(client: httpx.Client) -> str:
    """Return a SaluteSpeech OAuth access token, refreshing the cache if needed.

    The token is cached in module-level globals guarded by
    ``_salute_token_lock``. A cached token is reused until
    ``SALUTE_TOKEN_REFRESH_MARGIN_SECONDS`` before its recorded expiry.

    Args:
        client: HTTP client used to call the OAuth endpoint.

    Returns:
        The access token string.

    Raises:
        RuntimeError: If the endpoint answers with a non-200 status, or with
            a body that lacks a usable ``access_token`` / ``expires_at``.
    """
    global _salute_token, _salute_token_expires_at

    # The lock is held across the HTTP call, so concurrent callers wait for a
    # single refresh instead of each requesting a token of their own.
    with _salute_token_lock:
        refresh_deadline = (
            _salute_token_expires_at - SALUTE_TOKEN_REFRESH_MARGIN_SECONDS
        )
        if _salute_token and time.monotonic() < refresh_deadline:
            return _salute_token

        settings = get_settings()
        response = client.post(
            SALUTE_AUTH_URL,
            headers={
                "Authorization": f"Basic {settings.salute_auth_key}",
                "RqUID": str(uuid.uuid4()),
                "Content-Type": "application/x-www-form-urlencoded",
            },
            content=f"scope={settings.salute_scope}",
        )
        if response.status_code != 200:
            raise RuntimeError(
                ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200])
            )

        # Parse both fields before touching the cache, so a malformed body
        # can never leave a new token paired with a stale expiry.
        try:
            data = response.json()
            access_token = data["access_token"]
            # NOTE(review): expires_at is treated as epoch milliseconds — confirm.
            expires_in_seconds = (data["expires_at"] / 1000) - time.time()
        except (KeyError, TypeError, ValueError) as exc:
            raise RuntimeError(
                ERROR_SALUTE_AUTH_FAILED.format(detail=response.text[:200])
            ) from exc

        _salute_token = access_token
        # Convert the wall-clock expiry to the monotonic clock used by the check above.
        _salute_token_expires_at = time.monotonic() + expires_in_seconds
        return _salute_token
|
||||||
|
|
||||||
|
|
||||||
|
def _upload_salute_audio(
    client: httpx.Client, token: str, audio_data: bytes, content_type: str
) -> str:
    """Send raw audio bytes to the SaluteSpeech upload endpoint.

    Args:
        client: HTTP client used for the request.
        token: Bearer access token.
        audio_data: Complete audio payload.
        content_type: Value for the ``Content-Type`` request header.

    Returns:
        The ``request_file_id`` taken from the response's ``result`` object.

    Raises:
        RuntimeError: If the endpoint responds with a status other than 200.
    """
    upload_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": content_type,
    }
    # Uploads get a longer timeout than the client default.
    upload_response = client.post(
        f"{SALUTE_API_BASE}/data:upload",
        headers=upload_headers,
        content=audio_data,
        timeout=120.0,
    )
    if upload_response.status_code == 200:
        return upload_response.json()["result"]["request_file_id"]

    failure_detail = upload_response.text[:200]
    raise RuntimeError(ERROR_SALUTE_UPLOAD_FAILED.format(detail=failure_detail))
|
||||||
|
|
||||||
|
|
||||||
|
def _create_salute_task(
    client: httpx.Client,
    token: str,
    file_id: str,
    *,
    language: str,
    model: str,
    audio_encoding: str,
    sample_rate: int,
) -> str:
    """Start an asynchronous recognition task for an uploaded file.

    Args:
        client: HTTP client used for the request.
        token: Bearer access token.
        file_id: ``request_file_id`` of the previously uploaded audio.
        language: Locale string sent as the ``language`` option.
        model: Recognition model name.
        audio_encoding: Value for the ``audio_encoding`` option.
        sample_rate: Value for the ``sample_rate`` option.

    Returns:
        The task id taken from the response's ``result`` object.

    Raises:
        RuntimeError: If the endpoint responds with a status other than 200.
    """
    # channels_count and hypotheses_count are always sent as 1.
    recognition_options = {
        "audio_encoding": audio_encoding,
        "sample_rate": sample_rate,
        "language": language,
        "model": model,
        "channels_count": 1,
        "hypotheses_count": 1,
    }
    payload = {"options": recognition_options, "request_file_id": file_id}
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    task_response = client.post(
        f"{SALUTE_API_BASE}/speech:async_recognize",
        headers=request_headers,
        json=payload,
    )
    if task_response.status_code == 200:
        return task_response.json()["result"]["id"]

    failure_detail = task_response.text[:200]
    raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=failure_detail))
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_salute_task(
    client: httpx.Client,
    token: str,
    task_id: str,
    job_uuid: uuid.UUID | None,
    on_progress: ProgressCallback | None,
) -> str:
    """Poll a recognition task until it finishes and return its result file id.

    Args:
        client: HTTP client used for the status requests.
        token: Bearer access token.
        task_id: Id of the task to poll.
        job_uuid: When given, the job's cancellation state is checked before
            every poll via ``_raise_if_job_cancelled``.
        on_progress: Optional callback that receives an estimated percentage
            after every non-terminal poll.

    Returns:
        The ``response_file_id`` of the finished task.

    Raises:
        TimeoutError: If the task has not finished within
            ``SALUTE_POLL_TIMEOUT_SECONDS``.
        RuntimeError: If the task reports status ``ERROR`` or ``CANCELED``.
        httpx.HTTPStatusError: If a status request returns an error response.
    """
    raise_if_cancelled = None
    if job_uuid is not None:
        # Resolved once, not on every loop iteration.
        # NOTE(review): presumably a local import to avoid a circular import — confirm.
        from cpv3.modules.tasks.service import _raise_if_job_cancelled

        raise_if_cancelled = _raise_if_job_cancelled

    start = time.monotonic()
    while True:
        elapsed = time.monotonic() - start
        if elapsed > SALUTE_POLL_TIMEOUT_SECONDS:
            raise TimeoutError(ERROR_SALUTE_TIMEOUT)

        if raise_if_cancelled is not None:
            raise_if_cancelled(job_uuid)

        response = client.get(
            f"{SALUTE_API_BASE}/task:get",
            params={"id": task_id},
            headers={"Authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
        result = response.json()["result"]
        status = result["status"]

        if status == "DONE":
            return result["response_file_id"]
        if status == "ERROR":
            error_msg = result.get("error", "unknown error")
            raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=error_msg))
        if status == "CANCELED":
            # A cancelled task never reaches DONE; fail now instead of
            # polling until the timeout expires.
            error_msg = result.get("error", "task canceled")
            raise RuntimeError(ERROR_SALUTE_TASK_FAILED.format(detail=error_msg))

        if on_progress is not None:
            # The service reports no real progress, so estimate it from the
            # elapsed time and cap it at 95%.
            pct = min(elapsed / SALUTE_POLL_TIMEOUT_SECONDS * 100, 95.0)
            on_progress(pct)

        time.sleep(SALUTE_POLL_INTERVAL_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
|
def _download_salute_result(
    client: httpx.Client, token: str, response_file_id: str
) -> list[dict]:
    """Fetch the recognition result and return its decoded JSON body.

    Args:
        client: HTTP client used for the request.
        token: Bearer access token.
        response_file_id: Id of the result file reported by the finished task.

    Returns:
        The parsed JSON payload of the download response.

    Raises:
        httpx.HTTPStatusError: If the endpoint returns an error response.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    query = {"response_file_id": response_file_id}
    download_response = client.get(
        f"{SALUTE_API_BASE}/data:download",
        params=query,
        headers=auth_headers,
        timeout=60.0,
    )
    download_response.raise_for_status()
    payload = download_response.json()
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _build_document_from_salute_result(
    raw_channels: list[dict], *, language: str
) -> Document:
    """Turn a SaluteSpeech result payload into a processed ``Document``.

    Every entry under each channel's ``results`` list becomes one
    ``SaluteSpeechSegment``, with its ``word_alignments`` converted to
    ``SaluteSpeechWord`` items.

    Args:
        raw_channels: Decoded result JSON, one dict per channel.
        language: Accepted for interface compatibility; this function does
            not currently read it.

    Returns:
        The document returned by ``DocumentBuilder.process_document``.
    """
    builder = DocumentBuilder()
    words_options = WordOptions()

    def to_segment(item: dict) -> SaluteSpeechSegment:
        # Words are converted before the segment-level fields are read.
        segment_words = [
            SaluteSpeechWord(
                word=alignment["word"],
                start=_parse_salute_time(alignment["start"]),
                end=_parse_salute_time(alignment["end"]),
            )
            for alignment in item.get("word_alignments", [])
        ]
        segment_text = item.get("text", "")
        started_at = _parse_salute_time(item["start"])
        ended_at = _parse_salute_time(item["end"])
        return SaluteSpeechSegment(
            text=segment_text,
            start=started_at,
            end=ended_at,
            words=segment_words,
        )

    # Segments from all channels are flattened into one list, in payload order.
    all_segments: list[SaluteSpeechSegment] = [
        to_segment(item)
        for channel in raw_channels
        for item in channel.get("results", [])
    ]

    document = _make_document_from_segments(
        builder, all_segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
|
||||||
|
|
||||||
|
|
||||||
|
def _salute_transcribe_sync(
    *,
    local_file_path: str,
    language: str | None,
    model: str,
    sample_rate: int,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Run the full SaluteSpeech pipeline for a local audio file (blocking).

    Steps: get a token, upload the file, create a recognition task, poll it
    to completion, download the result and convert it to a ``Document``.

    Args:
        local_file_path: Path of the audio file; its extension selects the
            encoding and content type.
        language: Short language code; values missing from
            ``SALUTE_LANGUAGE_MAP`` (including ``None``) fall back to ``ru-RU``.
        model: Recognition model name.
        sample_rate: Sample rate passed to the recognition task.
        job_id: Optional job id forwarded to the polling step.
        on_progress: Optional progress callback forwarded to the polling step.

    Returns:
        The document built from the recognition result.

    Raises:
        ValueError: If the file extension has no encoding or content-type mapping.
    """
    settings = get_settings()

    suffix = Path(local_file_path).suffix.lower()
    audio_encoding = SALUTE_ENCODING_MAP.get(suffix)
    content_type = SALUTE_CONTENT_TYPE_MAP.get(suffix)
    if not (audio_encoding and content_type):
        raise ValueError(ERROR_SALUTE_UNSUPPORTED_FORMAT.format(ext=suffix))

    salute_language = SALUTE_LANGUAGE_MAP.get(language or "", "ru-RU")

    # Use the configured CA bundle when one is set, otherwise default verification.
    if settings.salute_ca_cert_path:
        verify = str(settings.salute_ca_cert_path)
    else:
        verify = True

    with httpx.Client(verify=verify, timeout=30.0) as http:
        access_token = _get_salute_access_token(http)

        # The whole file is read into memory before it is uploaded.
        payload = Path(local_file_path).read_bytes()

        uploaded_id = _upload_salute_audio(http, access_token, payload, content_type)
        recognition_task_id = _create_salute_task(
            http,
            access_token,
            uploaded_id,
            language=salute_language,
            model=model,
            audio_encoding=audio_encoding,
            sample_rate=sample_rate,
        )
        result_file_id = _poll_salute_task(
            http, access_token, recognition_task_id, job_id, on_progress
        )
        raw_result = _download_salute_result(http, access_token, result_file_id)

    # Conversion needs no network access, so it runs after the client is closed.
    return _build_document_from_salute_result(raw_result, language=salute_language)
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe_with_salute_speech(
    storage: StorageService,
    *,
    file_key: str,
    language: str | None = None,
    model: str = "general",
    sample_rate: int = 16000,
    job_id: uuid.UUID | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
    """Transcribe a stored audio file with SaluteSpeech without blocking the event loop.

    The file is downloaded to a temporary location, the blocking pipeline is
    run in a worker thread, and the temporary file is cleaned up afterwards
    whether or not transcription succeeded.

    Args:
        storage: Storage service used to download the file.
        file_key: Key of the audio file in storage.
        language: Optional short language code.
        model: Recognition model name.
        sample_rate: Sample rate passed to the recognition task.
        job_id: Optional job id forwarded to the synchronous pipeline.
        on_progress: Optional progress callback forwarded to the synchronous pipeline.

    Returns:
        The document produced by the synchronous pipeline.
    """
    downloaded = await storage.download_to_temp(file_key)

    def run_blocking() -> Document:
        # Closure binding the keyword arguments; run_sync calls it with no arguments.
        return _salute_transcribe_sync(
            local_file_path=downloaded.path,
            language=language,
            model=model,
            sample_rate=sample_rate,
            job_id=job_id,
            on_progress=on_progress,
        )

    try:
        return await anyio.to_thread.run_sync(run_blocking)
    finally:
        downloaded.cleanup()
|
||||||
|
|||||||
Reference in New Issue
Block a user