From e5b5819c741ef1159a5251778b4a8f5fa8095555 Mon Sep 17 00:00:00 2001 From: Jeronymous Date: Thu, 15 Feb 2024 18:15:11 +0100 Subject: [PATCH] add options about precision when decoding with transformers (WIP) --- whisper_timestamped/transcribe.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py index 8750725..e7127ee 100644 --- a/whisper_timestamped/transcribe.py +++ b/whisper_timestamped/transcribe.py @@ -2371,11 +2371,22 @@ def load_model( except OSError: generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny") processor = transformers.WhisperProcessor.from_pretrained(name) - model = transformers.WhisperForConditionalGeneration.from_pretrained(name) if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" + precision = torch.float32 + model = transformers.WhisperForConditionalGeneration.from_pretrained( + name, + # load_in_8bit=True, + # load_in_4bit=True, + torch_dtype=precision, + # torch_dtype=torch.bfloat16, + # attn_implementation="flash_attention_2", + # attn_implementation="sdpa", + ) + # model = model.to_bettertransformer() + model = model.to(device) - return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config) + return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config, precision) elif backend not in ["openai", "openai-whisper"]: raise ValueError(f"Got unexpected backend {backend}") @@ -2474,13 +2485,14 @@ class TransformerWhisperAsOpenAIWhisper: Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped) """ - def __init__(self, model, processor, generation_config): + def __init__(self, model, processor, generation_config, precision): self.model = model # transformers.WhisperForConditionalGeneration self.processor = processor # transformers.WhisperProcessor self.generation_config = generation_config # transformers.GenerationConfig self.device = model.device + self.precision = precision # Dimensions model_config = model.config @@ -2609,7 +2621,7 @@ def transcribe(self, audio, use_token_timestamps=False, **kwargs): # Transcribe output = self.model.generate( - features, + features.to(self.precision), **generate_kwargs ) @@ -2759,7 +2771,7 @@ def _iter_segments(self, output, prompt_ids): def __call__(self, mfcc, tokens): - output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True) + output = self.model(mfcc.to(self.precision), decoder_input_ids=tokens, output_attentions=True) return output.logits