From e495276f69f5ee31caa1209e7a05dae4442be5e4 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Mon, 4 Nov 2024 07:46:54 +0100 Subject: [PATCH] Fixes #221 : workaround that disables SDPA attention in latest version of openai-whisper (20240930) which prevents accessing attention weights --- whisper_timestamped/transcribe.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index 2162ca2..2d62e6b 100755 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -3,7 +3,7 @@ __author__ = "Jérôme Louradour" __credits__ = ["Jérôme Louradour"] __license__ = "GPLv3" -__version__ = "1.15.5" +__version__ = "1.15.6" # Set some environment variables import os @@ -899,8 +899,9 @@ def hook_output_logits(layer, ins, outs): if compute_word_confidence or no_speech_threshold is not None: all_hooks.append(model.decoder.ln.register_forward_hook(hook_output_logits)) - with disable_sdpa(): - transcription = model.transcribe(audio, **whisper_options) + with torch.no_grad(): + with disable_sdpa(): + transcription = model.transcribe(audio, **whisper_options) finally: @@ -1062,8 +1063,9 @@ def hook_output_logits(layer, ins, outs): try: model.alignment_heads = alignment_heads # Avoid exception "AttributeError: 'WhisperUntied' object has no attribute 'alignment_heads'. 
Did you mean: 'set_alignment_heads'?"" - with disable_sdpa(): - transcription = model.transcribe(audio, **whisper_options) + with torch.no_grad(): + with disable_sdpa(): + transcription = model.transcribe(audio, **whisper_options) finally: for hook in all_hooks: hook.remove() @@ -1238,8 +1240,9 @@ def hook(layer, ins, outs, index=j): i_start = len(sot_sequence) with torch.no_grad(): - logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0)) - logprobs = F.log_softmax(logprobs, dim=-1) + with disable_sdpa(): + logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0)) + logprobs = F.log_softmax(logprobs, dim=-1) end_token = tokenizer.timestamp_begin + round(min(N_FRAMES * HOP_LENGTH, end_sample - start_sample) // AUDIO_SAMPLES_PER_TOKEN) tokens = tokens[i_start:] + [end_token]