Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/fdr.mp3
Binary file not shown.
27 changes: 27 additions & 0 deletions tests/test_progress_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import pytest
import torch

import whisper


def test_progress_callback():
    """Check that ``transcribe`` reports progress through ``progress_callback``.

    The callback is expected to receive a percentage (0-100) each time the
    transcription loop advances. The test records every reported value and
    verifies that:

    * at least one update was delivered,
    * every value lies within the 0-100 range,
    * the values never go backwards, and
    * the final update reports completion (100 percent).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model("tiny").to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "fdr.mp3")

    progress = []

    def callback(progress_data):
        # Record each reported percentage so the whole sequence can be checked.
        progress.append(progress_data)

    model.transcribe(
        audio_path,
        language="en",
        # verbose=False only enables the tqdm bar for visual feedback; the
        # progress callback does not depend on it.
        verbose=False,
        progress_callback=callback,
    )

    assert len(progress) > 0
    # Every update must be a valid percentage.
    assert all(0.0 <= value <= 100.0 for value in progress)
    # Progress is cumulative, so it must never decrease between updates.
    assert all(earlier <= later for earlier, later in zip(progress, progress[1:]))
    # The last update must signal completion; approx guards against float noise.
    assert progress[-1] == pytest.approx(100.0)
10 changes: 8 additions & 2 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def transcribe(
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
progress_callback: Optional[callable] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
Expand Down Expand Up @@ -138,6 +139,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
curr_frames = 0
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

if decode_options.get("language", None) is None:
Expand Down Expand Up @@ -262,7 +264,7 @@ def new_segment(

# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
total=content_frames, unit="frame", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
Expand Down Expand Up @@ -505,7 +507,11 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
prompt_reset_since = len(all_tokens)

# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
frames_processed = min(content_frames, seek) - previous_seek
if progress_callback is not None:
curr_frames = frames_processed + curr_frames
progress_callback(curr_frames / content_frames * 100)
pbar.update(frames_processed)

return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
Expand Down