Skip to content

Commit

Permalink
add options about precision when decoding with transformers (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeronymous committed Feb 15, 2024
1 parent dabd52c commit e5b5819
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2371,11 +2371,22 @@ def load_model(
except OSError:
generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny")
processor = transformers.WhisperProcessor.from_pretrained(name)
model = transformers.WhisperForConditionalGeneration.from_pretrained(name)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
precision = torch.float32
model = transformers.WhisperForConditionalGeneration.from_pretrained(
name,
# load_in_8bit=True,
# load_in_4bit=True,
torch_dtype=precision,
# torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# attn_implementation="sdpa",
)
# model = model.to_bettertransformer()

model = model.to(device)
return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config)
return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config, precision)

elif backend not in ["openai", "openai-whisper"]:
raise ValueError(f"Got unexpected backend {backend}")
Expand Down Expand Up @@ -2474,13 +2485,14 @@ class TransformerWhisperAsOpenAIWhisper:
Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped)
"""

def __init__(self, model, processor, generation_config):
def __init__(self, model, processor, generation_config, precision):

self.model = model # transformers.WhisperForConditionalGeneration
self.processor = processor # transformers.WhisperProcessor
self.generation_config = generation_config # transformers.GenerationConfig

self.device = model.device
self.precision = precision

# Dimensions
model_config = model.config
Expand Down Expand Up @@ -2609,7 +2621,7 @@ def transcribe(self, audio, use_token_timestamps=False, **kwargs):

# Transcribe
output = self.model.generate(
features,
features.to(self.precision),
**generate_kwargs
)

Expand Down Expand Up @@ -2759,7 +2771,7 @@ def _iter_segments(self, output, prompt_ids):


def __call__(self, mfcc, tokens):
output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True)
output = self.model(mfcc.to(self.precision), decoder_input_ids=tokens, output_attentions=True)
return output.logits


Expand Down

0 comments on commit e5b5819

Please sign in to comment.