diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 3b2fdae9..897e9684 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -5,7 +5,7 @@
 import math
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union, List
+from typing import Iterable, Union, List, Callable, Optional

 import numpy as np
 import pandas as pd
@@ -119,6 +119,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    on_progress: Callable[[int, int], None] = None
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -147,6 +148,9 @@ def align(
             base_progress = ((sdx + 1) / total_segments) * 100
             percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
             print(f"Progress: {percent_complete:.2f}%...")
+
+        if on_progress:
+            on_progress(sdx + 1, total_segments)

         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 6de94900..deacecab 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,6 +1,8 @@
 import os
-from typing import List, Optional, Union
 from dataclasses import replace
+import warnings
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum

 import ctranslate2
 import faster_whisper
@@ -101,6 +103,12 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for timestamp mode
     # - add support for custom inference kwargs

+    class TranscriptionState(Enum):
+        LOADING_AUDIO = "loading_audio"
+        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+        TRANSCRIBING = "transcribing"
+        FINISHED = "finished"
+
     def __init__(
         self,
         model: WhisperModel,
@@ -195,8 +203,12 @@ def transcribe(
         print_progress=False,
         combined_progress=False,
         verbose=False,
+        on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
+
             audio = load_audio(audio)

         def data(audio, segments):
@@ -214,6 +226,8 @@ def data(audio, segments):
         else:
             waveform = Pyannote.preprocess_audio(audio)
             merge_chunks = Pyannote.merge_chunks
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)

         vad_segments = self.vad_model({"waveform": waveform, "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
@@ -253,16 +267,22 @@ def data(audio, segments):
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
+
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
+
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
                 percent_complete = base_progress / 2 if combined_progress else base_progress
                 print(f"Progress: {percent_complete:.2f}%...")
+
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
+
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
-            if verbose:
-                print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}")
             segments.append(
                 {
                     "text": text,
@@ -271,6 +291,9 @@
                 }
             )

+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.FINISHED)
+
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
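
For reference, a minimal caller-side sketch of how the new callbacks could be wired up once this patch is applied. The model size, device, audio path, and the two callback names below are placeholders chosen for illustration, not part of the diff:

import whisperx

def on_transcribe_progress(state, current=None, total=None):
    # `state` is a FasterWhisperPipeline.TranscriptionState member; `current` and
    # `total` are only supplied while state is TRANSCRIBING.
    if current is not None and total is not None:
        print(f"{state.value}: segment {current}/{total}")
    else:
        print(state.value)

def on_align_progress(current, total):
    # Called once per segment with (segments done, total segments).
    print(f"alignment: segment {current}/{total}")

device = "cpu"
model = whisperx.load_model("small", device, compute_type="int8")
audio = whisperx.load_audio("example.wav")
result = model.transcribe(audio, on_progress=on_transcribe_progress)

align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(result["segments"], align_model, metadata, audio, device,
                         on_progress=on_align_progress)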