diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 09eb6cc..a88d677 100644
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -133,7 +133,7 @@ def transcribe_timestamped(
         Whether to compute word confidence.
         If True, a finer confidence for each segment will be computed as well.
 
-    vad: bool or str in ["silero", "silero:3.1", "auditok"]
+    vad: bool or str in ["silero", "silero:3.1", "auditok"] or list of start/end timestamp pairs corresponding to speech (ex: [(0.0, 3.50), (32.43, 36.43)])
         Whether to perform voice activity detection (VAD) on the audio file, to remove silent parts before transcribing with Whisper model.
         This should decrease hallucinations from the Whisper model.
         When set to True, the default VAD algorithm is used (silero).
@@ -279,7 +279,7 @@ def transcribe_timestamped(
 
     if vad:
         audio = get_audio_tensor(audio)
-        audio, convert_timestamps = remove_non_speech(audio, method=vad, plot=plot_word_alignment)
+        audio, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment)
 
     global num_alignment_for_plot
     num_alignment_for_plot = 0
@@ -1844,11 +1844,23 @@ def split_tokens_on_spaces(tokens: torch.Tensor, tokenizer, remove_punctuation_f
     return words, word_tokens, word_tokens_indices
 
 def check_vad_method(method, with_version=False):
+    """
+    Check whether the VAD method is valid and return the method in a consistent format
+
+    method: str or list or True or False
+    """
     if method in [True, "True", "true"]:
         return check_vad_method("silero") # default method
     elif method in [False, "False", "false"]:
         return False
-    elif method.startswith("silero"):
+    elif not isinstance(method, str) and hasattr(method, '__iter__'):
+        # list of explicit timestamps
+        checked_pairs = []
+        for s_e in method:
+            assert len(s_e) == 2, f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs"
+            checked_pairs.append(tuple(s_e))
+        return checked_pairs
+    elif isinstance(method, str) and method.startswith("silero"):
         version = None
         if method != "silero":
             assert method.startswith("silero:"), f"Got unexpected VAD method {method}"
@@ -1869,12 +1881,18 @@ def check_vad_method(method, with_version=False):
         except ImportError:
             raise ImportError("Please install auditok to use the auditok VAD (or use another VAD method)")
     else:
-        raise ValueError(f"Got unexpected VAD method {method}")
+        try:
+            method = eval(method)
+            assert hasattr(method, '__iter__')
+        except:
+            raise ValueError(f"Got unexpected VAD method {method}")
+        return check_vad_method(method, with_version=with_version)
     return method
 
 _silero_vad_model = {}
 _has_onnx = None
 
 def get_vad_segments(audio,
+                     sample_rate=SAMPLE_RATE,
                      output_sample=False,
                      min_speech_duration=0.1,
                      min_silence_duration=0.1,
@@ -1894,12 +1912,17 @@ def get_vad_segments(audio,
         minimum duration (in sec) of a silence segment
     dilatation: float
         how much (in sec) to enlarge each speech segment detected by the VAD
-    method: str
+    method: str or list
         VAD method to use (auditok, silero, silero:v3.1)
     """
 
     global _silero_vad_model, _silero_get_speech_ts, _has_onnx
 
-    if method.startswith("silero"):
+    if isinstance(method, list):
+        # Explicit timestamps
+        segments = [{"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method]
+        dilatation = 0
+
+    elif isinstance(method, str) and method.startswith("silero"):
         version = None
         _, version = check_vad_method(method, True)
@@ -1969,6 +1992,7 @@ def apply_folder_hack():
         audio = audio / max(0.1, audio.abs().max())
 
         segments = _silero_get_speech_ts(audio, _silero_vad_model[version],
+            sampling_rate = sample_rate,
             min_speech_duration_ms = round(min_speech_duration * 1000),
             min_silence_duration_ms = round(min_silence_duration * 1000),
             return_seconds = False,
@@ -1982,11 +2006,11 @@ def apply_folder_hack():
 
         data = (audio.numpy() * 32767).astype(np.int16).tobytes()
 
-        audio_duration = len(audio) / SAMPLE_RATE
+        audio_duration = len(audio) / sample_rate
 
         segments = auditok.split(
             data,
-            sampling_rate=SAMPLE_RATE,        # sampling frequency in Hz
+            sampling_rate=sample_rate,        # sampling frequency in Hz
             channels=1,                       # number of channels
             sample_width=2,                   # number of bytes per sample
             min_dur=min_speech_duration,      # minimum duration of a valid audio event in seconds
@@ -1996,13 +2020,13 @@ def apply_folder_hack():
             drop_trailing_silence=True,
         )
 
-        segments = [{"start": s._meta.start * SAMPLE_RATE, "end": s._meta.end * SAMPLE_RATE} for s in segments]
+        segments = [{"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} for s in segments]
 
     else:
         raise ValueError(f"Got unexpected VAD method {method}")
 
     if dilatation > 0:
-        dilatation = round(dilatation * SAMPLE_RATE)
+        dilatation = round(dilatation * sample_rate)
         new_segments = []
         for seg in segments:
             new_seg = {
@@ -2015,7 +2039,7 @@ def apply_folder_hack():
             new_segments.append(new_seg)
         segments = new_segments
 
-    ratio = 1 if output_sample else 1 / SAMPLE_RATE
+    ratio = 1 if output_sample else 1 / sample_rate
 
     if ratio != 1:
         for seg in segments:
@@ -2031,6 +2055,8 @@ def remove_non_speech(audio,
                       use_sample=False,
                       min_speech_duration=0.1,
                       min_silence_duration=1,
+                      dilatation=0.5,
+                      sample_rate=SAMPLE_RATE,
                       method="silero",
                       plot=False,
                       ):
@@ -2048,6 +2074,8 @@ def remove_non_speech(audio,
         minimum duration (in sec) of a speech segment
     min_silence_duration: float
         minimum duration (in sec) of a silence segment
+    dilatation: float
+        how much (in sec) to enlarge each speech segment detected by the VAD
     method: str
         method to use to remove non-speech segments
     plot: bool or str
@@ -2057,9 +2085,11 @@ def remove_non_speech(audio,
 
     segments = get_vad_segments(
         audio,
+        sample_rate=sample_rate,
        output_sample=True,
         min_speech_duration=min_speech_duration,
         min_silence_duration=min_silence_duration,
+        dilatation=dilatation,
         method=method,
     )
 
@@ -2074,17 +2104,17 @@ def remove_non_speech(audio,
 
         plt.figure()
         max_num_samples = 10000
         step = (audio.shape[-1] // max_num_samples) + 1
-        times = [i*step/SAMPLE_RATE for i in range((audio.shape[-1]-1) // step + 1)]
+        times = [i*step/sample_rate for i in range((audio.shape[-1]-1) // step + 1)]
         plt.plot(times, audio[::step])
         for s, e in segments:
-            plt.axvspan(s/SAMPLE_RATE, e/SAMPLE_RATE, color='red', alpha=0.1)
+            plt.axvspan(s/sample_rate, e/sample_rate, color='red', alpha=0.1)
         if isinstance(plot, str):
             plt.savefig(f"{plot}.VAD.jpg", bbox_inches='tight', pad_inches=0)
         else:
             plt.show()
 
     if not use_sample:
-        segments = [(float(s)/SAMPLE_RATE, float(e)/SAMPLE_RATE) for s,e in segments]
+        segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments]
 
     return audio_speech, lambda t, t2 = None: do_convert_timestamps(segments, t, t2)
@@ -2939,7 +2969,10 @@ def str2output_formats(string):
     parser.add_argument('--language', help=f"language spoken in the audio, specify None to perform language detection.", choices=sorted(whisper.tokenizer.LANGUAGES.keys()) + sorted([k.title() for k in whisper.tokenizer.TO_LANGUAGE_CODE.keys()]), default=None)
     # f"{', '.join(sorted(k+'('+v+')' for k,v in whisper.tokenizer.LANGUAGES.items()))}
-    parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segment before applying Whisper model (removes hallucinations). Can be: True, False, silero, silero:3.1 (or another version), or autitok. Some additional libraries might be needed")
+    parser.add_argument('--vad', default=False, help="whether to run Voice Activity Detection (VAD) to remove non-speech segments before applying the Whisper model (removes hallucinations). "
+        "Can be: True, False, auditok, silero (default when vad=True), silero:3.1 (or another version), or a list of timestamps in seconds (e.g. \"[(0.0, 3.50), (32.43, 36.43)]\"). "
+        "Note: Some additional libraries might be needed (torchaudio and onnxruntime for silero, auditok for auditok)."
+    )
     parser.add_argument('--detect_disfluencies', default=False, help="whether to try to detect disfluencies, marking them as special words [*]", type=str2bool)
     parser.add_argument('--recompute_all_timestamps', default=not TRUST_WHISPER_TIMESTAMP_BY_DEFAULT, help="Do not rely at all on Whisper timestamps (Experimental option: did not bring any improvement, but could be useful in cases where Whipser segment timestamp are wrong by more than 0.5 seconds)", type=str2bool)
     parser.add_argument("--punctuations_with_words", default=True, help="whether to include punctuations in the words", type=str2bool)
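
A quick usage sketch of the new list-based vad option (illustration only, not part of the patch; the file name and model size below are placeholders):

# Hypothetical example: transcribe with explicit speech timestamps instead of
# running a VAD model. check_vad_method() accepts any iterable of (start, end)
# pairs in seconds and normalizes it to a list of tuples; get_vad_segments()
# then converts the pairs to sample indices and skips the dilatation step.
import whisper_timestamped as whisper

model = whisper.load_model("tiny")

result = whisper.transcribe(
    model,
    "audio.wav",                        # placeholder input file
    vad=[(0.0, 3.50), (32.43, 36.43)],  # speech occurs only in these two windows
)

for segment in result["segments"]:
    print(f"[{segment['start']:.2f} -> {segment['end']:.2f}] {segment['text']}")

The same pairs can be passed on the command line as a string, e.g. --vad "[(0.0, 3.50), (32.43, 36.43)]"; the eval() fallback in check_vad_method() turns that string back into a list before validating it.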