# main_backend/cpv3/modules/transcription/service.py
from __future__ import annotations

import asyncio
from tempfile import NamedTemporaryFile
from typing import Callable, cast

import anyio

from cpv3.infrastructure.settings import get_settings
from cpv3.infrastructure.storage.base import StorageService
from cpv3.modules.transcription.constants import (
    FIRST_LINE_IN_DOCUMENT,
    FIRST_LINE_IN_SEGMENT,
    FIRST_SEGMENT_IN_DOCUMENT,
    FIRST_WORD_IN_DOCUMENT,
    FIRST_WORD_IN_LINE,
    FIRST_WORD_IN_SEGMENT,
    LAST_LINE_IN_DOCUMENT,
    LAST_LINE_IN_SEGMENT,
    LAST_SEGMENT_IN_DOCUMENT,
    LAST_WORD_IN_DOCUMENT,
    LAST_WORD_IN_LINE,
    LAST_WORD_IN_SEGMENT,
)
from cpv3.modules.transcription.schemas import (
    Document,
    GoogleSpeechResult,
    GoogleSpeechSegment,
    GoogleSpeechWord,
    LineNode,
    SaluteSpeechSegment,
    SegmentNode,
    Tag,
    TimeRange,
    WhisperResult,
    WhisperSegment,
    WhisperWord,
    WordNode,
    WordOptions,
)


class DocumentBuilder:
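    """Builds the transcription document tree (segments -> lines -> words) and
    attaches structure tags marking first/last words, lines, and segments.
    """
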
    def compute_segment_lines(
        self,
        segment: WhisperSegment | GoogleSpeechSegment | SaluteSpeechSegment,
        max_chars_per_line: int,
    ) -> list[LineNode]:
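        """Split the segment's words into lines of at most ``max_chars_per_line``
        characters (counting joining spaces) and wrap them into LineNode/WordNode
        objects with their time ranges."""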
        words = segment.words or []
        lines: list[list[WhisperWord | GoogleSpeechWord]] = []
        cur_line: list[WhisperWord | GoogleSpeechWord] = []
        cur_len = 0
        for w in words:
            text = (w.word or "").strip()
            if not text:
                continue
            extra = len(text) + (1 if cur_line else 0)
            if cur_line and cur_len + extra > max_chars_per_line:
                lines.append(cur_line)
                cur_line, cur_len = [w], len(text)
            else:
                cur_line.append(w)
                cur_len += extra
        if cur_line:
            lines.append(cur_line)
        result_lines: list[LineNode] = []
        for rline in lines:
            time = TimeRange(start=rline[0].start, end=rline[-1].end)
            word_nodes = [
                WordNode(
                    text=(rword.word or "").strip(),
                    time=TimeRange(start=rword.start, end=rword.end),
                    semantic_tags=[],
                    structure_tags=[],
                )
                for rword in rline
            ]
            line_node = LineNode(
                text=" ".join((rword.word or "") for rword in rline).strip(),
                semantic_tags=[],
                structure_tags=[],
                time=time,
                words=word_nodes,
            )
            result_lines.append(line_node)
        return result_lines

    def process_line(
        self,
        line: LineNode,
        is_first_line_in_document: bool,
        is_last_line_in_document: bool,
        is_first_line_in_segment: bool,
        is_last_line_in_segment: bool,
    ) -> list[WordNode]:
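        """Return copies of the line's words with structure tags marking the
        first/last word of the line, segment, and document."""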
        words: list[WordNode] = []
        for idx, word in enumerate(line.words):
            is_first = idx == 0
            is_last = idx == len(line.words) - 1
            rules = [
                (is_first_line_in_document and is_first, FIRST_WORD_IN_DOCUMENT),
                (is_last_line_in_document and is_last, LAST_WORD_IN_DOCUMENT),
                (is_first_line_in_segment and is_first, FIRST_WORD_IN_SEGMENT),
                (is_last_line_in_segment and is_last, LAST_WORD_IN_SEGMENT),
                (is_first, FIRST_WORD_IN_LINE),
                (is_last, LAST_WORD_IN_LINE),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            new_word = word.model_copy(update={"structure_tags": structure_tags})
            words.append(new_word)
        return words

    def process_segment(
        self,
        segment: SegmentNode,
        is_first_segment_in_document: bool,
        is_last_segment_in_document: bool,
    ) -> list[LineNode]:
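        """Return copies of the segment's lines with structure tags marking the
        first/last line of the segment and document, tagging their words as well."""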
        lines: list[LineNode] = []
        for idx, line in enumerate(segment.lines):
            is_first = idx == 0
            is_last = idx == len(segment.lines) - 1
            rules = [
                (is_first_segment_in_document and is_first, FIRST_LINE_IN_DOCUMENT),
                (is_last_segment_in_document and is_last, LAST_LINE_IN_DOCUMENT),
                (is_first, FIRST_LINE_IN_SEGMENT),
                (is_last, LAST_LINE_IN_SEGMENT),
            ]
            structure_tags = [
                Tag(name=tag_name) for condition, tag_name in rules if condition
            ]
            words = self.process_line(
                line,
                is_first_line_in_document=is_first_segment_in_document and is_first,
                is_last_line_in_document=is_last_segment_in_document and is_last,
                is_first_line_in_segment=is_first,
                is_last_line_in_segment=is_last,
            )
            new_line = line.model_copy(
                update={"structure_tags": structure_tags, "words": words}
            )
            lines.append(new_line)
        return lines

    def process_document(self, document: Document) -> Document:
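        """Return a new Document whose segments carry first/last-segment tags and
        whose lines and words carry the corresponding boundary tags."""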
        segments: list[SegmentNode] = []
        for idx, segment in enumerate(document.segments):
            structure_tags: list[Tag] = []
            is_first_segment_in_document = idx == 0
            is_last_segment_in_document = idx == len(document.segments) - 1
            if is_first_segment_in_document:
                structure_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT))
            if is_last_segment_in_document:
                structure_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT))
            lines = self.process_segment(
                segment, is_first_segment_in_document, is_last_segment_in_document
            )
            new_segment = segment.model_copy(
                update={"lines": lines, "structure_tags": structure_tags}
            )
            segments.append(new_segment)
        return Document(segments=segments)


async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]:
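    """Convert a local audio file to mono 16 kHz Ogg/Opus (~24 kbps) with ffmpeg.

    Returns the converted file's path together with a cleanup callback that removes it.
    """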
    # Reserve an output path only; the handle is closed immediately, ffmpeg writes
    # the actual file, and the returned callback deletes it when the caller is done.
    with NamedTemporaryFile(suffix=".ogg", delete=False) as out:
        out_path = out.name
    proc = await asyncio.create_subprocess_exec(
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-c:a",
        "libopus",
        "-b:a",
        "24k",
        "-vn",
        "-ac",
        "1",
        "-ar",
        "16000",
        out_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")

    def _cleanup() -> None:
        import os

        if os.path.exists(out_path):
            os.remove(out_path)

    return out_path, _cleanup


def _make_document_from_segments(
    builder: DocumentBuilder,
    segments: list[WhisperSegment] | list[GoogleSpeechSegment] | list[SaluteSpeechSegment],
    *,
    max_line_width: int,
) -> Document:
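    """Convert raw recognizer segments into a Document of SegmentNode objects,
    splitting each segment's words into lines of at most ``max_line_width`` characters."""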
    result_segments: list[SegmentNode] = []
    for segment in segments:
        lines = builder.compute_segment_lines(segment, max_line_width)
        time = TimeRange(start=segment.start, end=segment.end)
        segment_node = SegmentNode(
            text=segment.text.strip(),
            semantic_tags=[],
            structure_tags=[],
            time=time,
            lines=lines,
        )
        result_segments.append(segment_node)
    return Document(segments=result_segments)


ProgressCallback = Callable[[float], None]


def _whisper_transcribe_sync(
    *,
    local_file_path: str,
    model_name: str,
    language: str | None,
    on_progress: ProgressCallback | None = None,
) -> Document:
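    """Blocking Whisper transcription of a local audio file.

    Auto-detects the language when none is given and, if ``on_progress`` is provided,
    forwards Whisper's progress as a percentage.
    """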
    import whisper  # type: ignore[import-untyped]

    settings = get_settings()
    settings.transcription_models_dir.mkdir(parents=True, exist_ok=True)
    builder = DocumentBuilder()
    model = whisper.load_model(
        model_name, download_root=str(settings.transcription_models_dir)
    )
    if language is None:
        # Auto-detect the language from a padded/trimmed sample of the audio.
        audio = whisper.load_audio(local_file_path)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(
            model.device
        )
        _, probs_raw = model.detect_language(mel)
        probs = cast(dict[str, float], probs_raw)
        language = max(probs, key=lambda k: probs[k])
    if on_progress is not None:
        from unittest.mock import patch

        from tqdm import tqdm as _orig_tqdm

        # whisper.transcribe drives a tqdm progress bar; patch it so each update
        # is forwarded to on_progress as a percentage.
        class _ProgressTqdm(_orig_tqdm):
            def update(self, n=1):
                super().update(n)
                if self.total:
                    on_progress(min(self.n / self.total * 100.0, 100.0))

        with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
            result = whisper.transcribe(
                audio=whisper.load_audio(local_file_path),
                model=model,
                word_timestamps=True,
                temperature=0.2,
                language=language,
                verbose=False,
            )
        on_progress(100.0)
    else:
        result = whisper.transcribe(
            audio=whisper.load_audio(local_file_path),
            model=model,
            word_timestamps=True,
            temperature=0.2,
            language=language,
            verbose=None,
        )
    parsed = WhisperResult.model_validate(result)
    words_options = WordOptions(
        highlight_words=True,
        max_line_width=32,
        max_line_count=2,
    )
    document = _make_document_from_segments(
        builder, parsed.segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)


async def transcribe_with_whisper(
    storage: StorageService,
    *,
    file_key: str,
    model_name: str = "tiny",
    language: str | None = None,
    on_progress: ProgressCallback | None = None,
) -> Document:
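    """Download ``file_key`` from storage and run Whisper transcription in a worker
    thread, returning the tagged Document."""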
    tmp = await storage.download_to_temp(file_key)
    try:
        return await anyio.to_thread.run_sync(
            lambda: _whisper_transcribe_sync(
                local_file_path=tmp.path,
                model_name=model_name,
                language=language,
                on_progress=on_progress,
            )
        )
    finally:
        tmp.cleanup()


def _google_transcribe_sync(
    *, ogg_bytes: bytes, language_codes: list[str]
) -> GoogleSpeechResult:
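    """Blocking long-running Google Cloud Speech-to-Text recognition of Ogg/Opus
    bytes, grouping word-level timings into GoogleSpeechSegment results."""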
    from google.cloud import speech

    settings = get_settings()
    client: speech.SpeechClient = speech.SpeechClient.from_service_account_file(
        str(settings.google_service_key_path)
    )
    audio = speech.RecognitionAudio(content=ogg_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
        sample_rate_hertz=16000,
        language_code=language_codes[0],
        alternative_language_codes=(
            language_codes[1:] if len(language_codes) > 1 else []
        ),
        model="latest_long",
        enable_word_time_offsets=True,
    )
    operation = client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=600)
    segments: list[GoogleSpeechSegment] = []
    full_text = ""
    for result in response.results:
        alternative = result.alternatives[0]
        words: list[GoogleSpeechWord] = []
        for word_info in alternative.words:
            words.append(
                GoogleSpeechWord(
                    word=word_info.word,
                    start=word_info.start_time.total_seconds(),
                    end=word_info.end_time.total_seconds(),
                )
            )
        if words:
            segment_text = alternative.transcript
            full_text += segment_text + " "
            segments.append(
                GoogleSpeechSegment(
                    text=segment_text,
                    start=words[0].start,
                    end=words[-1].end,
                    words=words,
                )
            )
    return GoogleSpeechResult(
        text=full_text.strip(), segments=segments, language=language_codes[0]
    )


async def transcribe_with_google_speech(
    storage: StorageService,
    *,
    file_key: str,
    language_codes: list[str] | None = None,
) -> Document:
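    """Download ``file_key`` from storage, convert it to mono Ogg/Opus, transcribe it
    with Google Speech-to-Text (defaults to ru-RU with en-US fallback), and return
    the tagged Document."""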
    language_codes = language_codes or ["ru-RU", "en-US"]
    builder = DocumentBuilder()
    words_options = WordOptions()
    input_tmp = await storage.download_to_temp(file_key)
    try:
        ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path)
        try:
            with open(ogg_path, "rb") as f:
                content = f.read()
            result = await anyio.to_thread.run_sync(
                lambda: _google_transcribe_sync(
                    ogg_bytes=content, language_codes=language_codes
                )
            )
            document = _make_document_from_segments(
                builder, result.segments, max_line_width=words_options.max_line_width
            )
            return builder.process_document(document)
        finally:
            ogg_cleanup()
    finally:
        input_tmp.cleanup()