new features
This commit is contained in:
@@ -240,11 +240,15 @@ def _make_document_from_segments(
|
||||
return Document(segments=result_segments)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[float], None]
|
||||
|
||||
|
||||
def _whisper_transcribe_sync(
|
||||
*,
|
||||
local_file_path: str,
|
||||
model_name: str,
|
||||
language: str | None,
|
||||
on_progress: ProgressCallback | None = None,
|
||||
) -> Document:
|
||||
import whisper # type: ignore[import-untyped]
|
||||
|
||||
@@ -267,14 +271,35 @@ def _whisper_transcribe_sync(
|
||||
probs = cast(dict[str, float], probs_raw)
|
||||
language = max(probs, key=lambda k: probs[k])
|
||||
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=False,
|
||||
)
|
||||
if on_progress is not None:
|
||||
from unittest.mock import patch
|
||||
from tqdm import tqdm as _orig_tqdm
|
||||
|
||||
class _ProgressTqdm(_orig_tqdm):
|
||||
def update(self, n=1):
|
||||
super().update(n)
|
||||
if self.total:
|
||||
on_progress(min(self.n / self.total * 100.0, 100.0))
|
||||
|
||||
with patch("whisper.transcribe.tqdm.tqdm", _ProgressTqdm):
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=False,
|
||||
)
|
||||
on_progress(100.0)
|
||||
else:
|
||||
result = whisper.transcribe(
|
||||
audio=whisper.load_audio(local_file_path),
|
||||
model=model,
|
||||
word_timestamps=True,
|
||||
temperature=0.2,
|
||||
language=language,
|
||||
verbose=None,
|
||||
)
|
||||
|
||||
parsed = WhisperResult.model_validate(result)
|
||||
|
||||
@@ -296,6 +321,7 @@ async def transcribe_with_whisper(
|
||||
file_key: str,
|
||||
model_name: str = "tiny",
|
||||
language: str | None = None,
|
||||
on_progress: ProgressCallback | None = None,
|
||||
) -> Document:
|
||||
tmp = await storage.download_to_temp(file_key)
|
||||
try:
|
||||
@@ -304,6 +330,7 @@ async def transcribe_with_whisper(
|
||||
local_file_path=tmp.path,
|
||||
model_name=model_name,
|
||||
language=language,
|
||||
on_progress=on_progress,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user