init: new structure + fix lint errors
This commit is contained in:
@@ -0,0 +1,402 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Callable, cast
|
||||
|
||||
import anyio
|
||||
|
||||
from cpv3.infrastructure.settings import get_settings
|
||||
from cpv3.infrastructure.storage.base import StorageService
|
||||
from cpv3.modules.transcription.constants import (
|
||||
FIRST_LINE_IN_DOCUMENT,
|
||||
FIRST_LINE_IN_SEGMENT,
|
||||
FIRST_SEGMENT_IN_DOCUMENT,
|
||||
FIRST_WORD_IN_DOCUMENT,
|
||||
FIRST_WORD_IN_LINE,
|
||||
FIRST_WORD_IN_SEGMENT,
|
||||
LAST_LINE_IN_DOCUMENT,
|
||||
LAST_LINE_IN_SEGMENT,
|
||||
LAST_SEGMENT_IN_DOCUMENT,
|
||||
LAST_WORD_IN_DOCUMENT,
|
||||
LAST_WORD_IN_LINE,
|
||||
LAST_WORD_IN_SEGMENT,
|
||||
)
|
||||
from cpv3.modules.transcription.schemas import (
|
||||
Document,
|
||||
GoogleSpeechResult,
|
||||
GoogleSpeechSegment,
|
||||
GoogleSpeechWord,
|
||||
LineNode,
|
||||
SegmentNode,
|
||||
Tag,
|
||||
TimeRange,
|
||||
WhisperResult,
|
||||
WhisperSegment,
|
||||
WhisperWord,
|
||||
WordNode,
|
||||
WordOptions,
|
||||
)
|
||||
|
||||
|
||||
class DocumentBuilder:
    """Builds tagged ``Document`` trees out of raw transcription segments.

    The pipeline is: wrap a provider segment's words into width-limited
    lines (``compute_segment_lines``), then walk the resulting tree and
    stamp every node with positional structure tags such as
    "first word in line" / "last segment in document".
    """

    def compute_segment_lines(
        self, segment: WhisperSegment | GoogleSpeechSegment, max_chars_per_line: int
    ) -> list[LineNode]:
        """Greedily wrap the segment's words into lines.

        A word is moved to a new line once adding it (plus one separating
        space) would exceed ``max_chars_per_line``; a single over-long word
        still occupies a line of its own. Whitespace-only tokens are
        skipped. Returns untagged ``LineNode``s with word-level timings.
        """
        grouped: list[list[WhisperWord | GoogleSpeechWord]] = []
        bucket: list[WhisperWord | GoogleSpeechWord] = []
        used = 0

        for candidate in segment.words or []:
            stripped = (candidate.word or "").strip()
            if not stripped:
                continue  # drop empty / whitespace-only tokens

            # One separating space is needed unless the line is still empty.
            cost = len(stripped) + (1 if bucket else 0)
            if bucket and used + cost > max_chars_per_line:
                grouped.append(bucket)
                bucket = [candidate]
                used = len(stripped)
            else:
                bucket.append(candidate)
                used += cost

        if bucket:
            grouped.append(bucket)

        nodes: list[LineNode] = []
        for group in grouped:
            children = [
                WordNode(
                    text=(token.word or "").strip(),
                    time=TimeRange(start=token.start, end=token.end),
                    semantic_tags=[],
                    structure_tags=[],
                )
                for token in group
            ]
            nodes.append(
                LineNode(
                    # Join the raw (unstripped) tokens, then trim the ends —
                    # matches how the line text was originally assembled.
                    text=" ".join((token.word or "") for token in group).strip(),
                    semantic_tags=[],
                    structure_tags=[],
                    time=TimeRange(start=group[0].start, end=group[-1].end),
                    words=children,
                )
            )

        return nodes

    def process_line(
        self,
        line: LineNode,
        is_first_line_in_document: bool,
        is_last_line_in_document: bool,
        is_first_line_in_segment: bool,
        is_last_line_in_segment: bool,
    ) -> list[WordNode]:
        """Return copies of the line's words with positional structure tags."""
        last = len(line.words) - 1
        tagged: list[WordNode] = []

        for pos, word in enumerate(line.words):
            head = pos == 0
            tail = pos == last

            # Tag order is significant and mirrors the document hierarchy:
            # document-level first, then segment-level, then line-level.
            tag_names: list[str] = []
            if is_first_line_in_document and head:
                tag_names.append(FIRST_WORD_IN_DOCUMENT)
            if is_last_line_in_document and tail:
                tag_names.append(LAST_WORD_IN_DOCUMENT)
            if is_first_line_in_segment and head:
                tag_names.append(FIRST_WORD_IN_SEGMENT)
            if is_last_line_in_segment and tail:
                tag_names.append(LAST_WORD_IN_SEGMENT)
            if head:
                tag_names.append(FIRST_WORD_IN_LINE)
            if tail:
                tag_names.append(LAST_WORD_IN_LINE)

            tagged.append(
                word.model_copy(
                    update={"structure_tags": [Tag(name=n) for n in tag_names]}
                )
            )

        return tagged

    def process_segment(
        self,
        segment: SegmentNode,
        is_first_segment_in_document: bool,
        is_last_segment_in_document: bool,
    ) -> list[LineNode]:
        """Return copies of the segment's lines with structure tags applied
        to both the lines and their words."""
        last = len(segment.lines) - 1
        out: list[LineNode] = []

        for pos, line in enumerate(segment.lines):
            head = pos == 0
            tail = pos == last

            tag_names: list[str] = []
            if is_first_segment_in_document and head:
                tag_names.append(FIRST_LINE_IN_DOCUMENT)
            if is_last_segment_in_document and tail:
                tag_names.append(LAST_LINE_IN_DOCUMENT)
            if head:
                tag_names.append(FIRST_LINE_IN_SEGMENT)
            if tail:
                tag_names.append(LAST_LINE_IN_SEGMENT)

            words = self.process_line(
                line,
                is_first_line_in_document=is_first_segment_in_document and head,
                is_last_line_in_document=is_last_segment_in_document and tail,
                is_first_line_in_segment=head,
                is_last_line_in_segment=tail,
            )

            out.append(
                line.model_copy(
                    update={
                        "structure_tags": [Tag(name=n) for n in tag_names],
                        "words": words,
                    }
                )
            )

        return out

    def process_document(self, document: Document) -> Document:
        """Return a new document whose segments, lines and words all carry
        positional structure tags. The input document is not mutated."""
        last = len(document.segments) - 1
        rebuilt: list[SegmentNode] = []

        for pos, seg in enumerate(document.segments):
            at_start = pos == 0
            at_end = pos == last

            seg_tags: list[Tag] = []
            if at_start:
                seg_tags.append(Tag(name=FIRST_SEGMENT_IN_DOCUMENT))
            if at_end:
                seg_tags.append(Tag(name=LAST_SEGMENT_IN_DOCUMENT))

            lines = self.process_segment(seg, at_start, at_end)
            rebuilt.append(
                seg.model_copy(
                    update={"lines": lines, "structure_tags": seg_tags}
                )
            )

        return Document(segments=rebuilt)
|
||||
|
||||
|
||||
async def _convert_local_to_ogg(input_path: str) -> tuple[str, Callable[[], None]]:
    """Convert a local audio file to mono 16 kHz Opus-in-OGG using ffmpeg.

    Returns ``(ogg_path, cleanup)`` where ``cleanup`` deletes the temp
    file; the caller is responsible for invoking it when done.

    Raises:
        RuntimeError: if ffmpeg exits non-zero. The temp output file is
            removed before the error propagates, so nothing leaks on
            failure (the original version leaked it).
    """
    import os

    # delete=False: ffmpeg writes to the path after the handle is closed.
    with NamedTemporaryFile(suffix=".ogg", delete=False) as out:
        out_path = out.name

    def _cleanup() -> None:
        if os.path.exists(out_path):
            os.remove(out_path)

    proc = await asyncio.create_subprocess_exec(
        "ffmpeg",
        "-y",  # overwrite the (already created) temp file without asking
        "-i",
        input_path,
        "-c:a",
        "libopus",
        "-b:a",
        "24k",
        "-vn",  # strip any video stream
        "-ac",
        "1",  # mono
        "-ar",
        "16000",  # 16 kHz — matches the sample rate declared to the STT API
        out_path,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _, stderr = await proc.communicate()
    if proc.returncode != 0:
        # Bug fix: remove the temp file before raising so a failed
        # conversion does not leave an orphaned file behind.
        _cleanup()
        raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='ignore')}")

    return out_path, _cleanup
|
||||
|
||||
|
||||
def _make_document_from_segments(
    builder: DocumentBuilder,
    segments: list[WhisperSegment] | list[GoogleSpeechSegment],
    *,
    max_line_width: int,
) -> Document:
    """Assemble raw provider segments into an untagged ``Document``.

    Each segment's words are wrapped into lines no wider than
    ``max_line_width`` via the builder; tagging happens later in
    ``DocumentBuilder.process_document``.
    """
    nodes = [
        SegmentNode(
            text=seg.text.strip(),
            semantic_tags=[],
            structure_tags=[],
            time=TimeRange(start=seg.start, end=seg.end),
            lines=builder.compute_segment_lines(seg, max_line_width),
        )
        for seg in segments
    ]
    return Document(segments=nodes)
|
||||
|
||||
|
||||
def _whisper_transcribe_sync(
    *,
    local_file_path: str,
    model_name: str,
    language: str | None,
) -> Document:
    """Transcribe a local audio file with OpenAI Whisper (blocking).

    Args:
        local_file_path: path to the audio file on disk.
        model_name: Whisper checkpoint name (e.g. "tiny").
        language: ISO language code, or None to auto-detect.

    Returns:
        A fully structure-tagged ``Document``.

    Must run on a worker thread — model loading and inference are
    CPU/GPU heavy.
    """
    import whisper  # type: ignore[import-untyped]

    settings = get_settings()
    settings.transcription_models_dir.mkdir(parents=True, exist_ok=True)

    builder = DocumentBuilder()

    model = whisper.load_model(
        model_name, download_root=str(settings.transcription_models_dir)
    )

    # Perf fix: decode the file once and reuse the waveform. The original
    # called whisper.load_audio twice when language auto-detection ran.
    audio = whisper.load_audio(local_file_path)

    if language is None:
        # Detect the language on a padded/trimmed excerpt of the waveform;
        # the full (untrimmed) audio is still used for transcription below.
        mel = whisper.log_mel_spectrogram(
            whisper.pad_or_trim(audio), n_mels=model.dims.n_mels
        ).to(model.device)
        _, probs_raw = model.detect_language(mel)
        probs = cast(dict[str, float], probs_raw)
        language = max(probs, key=lambda k: probs[k])

    result = whisper.transcribe(
        audio=audio,
        model=model,
        word_timestamps=True,  # needed for per-word TimeRange nodes
        temperature=0.2,
        language=language,
        verbose=False,
    )

    parsed = WhisperResult.model_validate(result)

    words_options = WordOptions(
        highlight_words=True,
        max_line_width=32,
        max_line_count=2,
    )

    document = _make_document_from_segments(
        builder, parsed.segments, max_line_width=words_options.max_line_width
    )
    return builder.process_document(document)
|
||||
|
||||
|
||||
async def transcribe_with_whisper(
    storage: StorageService,
    *,
    file_key: str,
    model_name: str = "tiny",
    language: str | None = None,
) -> Document:
    """Download ``file_key`` from storage and transcribe it with Whisper.

    The blocking transcription runs on a worker thread so the event loop
    stays responsive; the temporary download is always cleaned up.
    """
    download = await storage.download_to_temp(file_key)

    def _run() -> Document:
        return _whisper_transcribe_sync(
            local_file_path=download.path,
            model_name=model_name,
            language=language,
        )

    try:
        return await anyio.to_thread.run_sync(_run)
    finally:
        download.cleanup()
|
||||
|
||||
|
||||
def _google_transcribe_sync(
    *, ogg_bytes: bytes, language_codes: list[str]
) -> GoogleSpeechResult:
    """Transcribe OGG/Opus audio with Google Cloud Speech-to-Text (blocking).

    Args:
        ogg_bytes: mono 16 kHz Opus-in-OGG audio content.
        language_codes: non-empty list; the first entry is the primary
            language, the rest are passed as alternatives.

    Returns:
        A ``GoogleSpeechResult`` built only from results that carry word
        timings; the reported language is the requested primary code (the
        API's detected language is not inspected here).

    Blocks for up to 600 s waiting on the long-running operation — run on
    a worker thread.
    """
    from google.cloud import speech

    settings = get_settings()

    client: speech.SpeechClient = speech.SpeechClient.from_service_account_file(
        str(settings.google_service_key_path)
    )

    audio = speech.RecognitionAudio(content=ogg_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
        sample_rate_hertz=16000,  # must match the ffmpeg conversion settings
        language_code=language_codes[0],
        alternative_language_codes=(
            language_codes[1:] if len(language_codes) > 1 else []
        ),
        model="latest_long",
        enable_word_time_offsets=True,  # per-word timings for WordNode ranges
    )

    operation = client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=600)

    segments: list[GoogleSpeechSegment] = []
    # Idiom fix: collect parts and join once instead of quadratic `+=`.
    text_parts: list[str] = []

    for result in response.results:
        # The first alternative is the API's top-ranked transcript.
        alternative = result.alternatives[0]
        words = [
            GoogleSpeechWord(
                word=word_info.word,
                start=word_info.start_time.total_seconds(),
                end=word_info.end_time.total_seconds(),
            )
            for word_info in alternative.words
        ]

        # Skip results without word timings — segment bounds come from them.
        if words:
            text_parts.append(alternative.transcript)
            segments.append(
                GoogleSpeechSegment(
                    text=alternative.transcript,
                    start=words[0].start,
                    end=words[-1].end,
                    words=words,
                )
            )

    return GoogleSpeechResult(
        text=" ".join(text_parts).strip(),
        segments=segments,
        language=language_codes[0],
    )
|
||||
|
||||
|
||||
async def transcribe_with_google_speech(
    storage: StorageService,
    *,
    file_key: str,
    language_codes: list[str] | None = None,
) -> Document:
    """Download ``file_key``, convert it to OGG/Opus and transcribe it with
    Google Cloud Speech-to-Text.

    Args:
        storage: storage backend used to fetch the source audio.
        file_key: storage key of the audio file.
        language_codes: primary language first, alternatives after;
            defaults to ``["ru-RU", "en-US"]``.

    Both the downloaded source file and the converted OGG file are removed
    regardless of outcome.
    """
    language_codes = language_codes or ["ru-RU", "en-US"]

    builder = DocumentBuilder()
    words_options = WordOptions()

    input_tmp = await storage.download_to_temp(file_key)
    try:
        ogg_path, ogg_cleanup = await _convert_local_to_ogg(input_tmp.path)
        try:

            def _read_and_transcribe() -> GoogleSpeechResult:
                # Fix: the file read happens inside the worker thread too,
                # so the blocking disk I/O no longer stalls the event loop
                # (the original read the file synchronously in the coroutine).
                with open(ogg_path, "rb") as f:
                    content = f.read()
                return _google_transcribe_sync(
                    ogg_bytes=content, language_codes=language_codes
                )

            result = await anyio.to_thread.run_sync(_read_and_transcribe)

            document = _make_document_from_segments(
                builder, result.segments, max_line_width=words_options.max_line_width
            )
            return builder.process_document(document)
        finally:
            ogg_cleanup()
    finally:
        input_tmp.cleanup()
|
||||
Reference in New Issue
Block a user