From db03b65935a07fe646135af36e359ecd361d2683 Mon Sep 17 00:00:00 2001
From: Jeronymous
Date: Fri, 26 Jan 2024 09:57:42 +0100
Subject: [PATCH] Enable transformers as a backend

---
 whisper_timestamped/transcribe.py | 263 +++++++++++++++++++++++++++++-
 1 file changed, 255 insertions(+), 8 deletions(-)

diff --git a/whisper_timestamped/transcribe.py b/whisper_timestamped/transcribe.py
index 345338f..ce4c28b 100755
--- a/whisper_timestamped/transcribe.py
+++ b/whisper_timestamped/transcribe.py
@@ -229,6 +229,10 @@ def transcribe_timestamped(
     if fp16 is None:
         fp16 = model.device != torch.device("cpu")
 
+    # TODO: implement efficient approach with transformers
+    if is_transformer_model(model):
+        naive_approach = True
+
     # Safety check
     input_stride = N_FRAMES // model.dims.n_audio_ctx
     time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -1035,7 +1039,7 @@ def hook_output_logits(layer, ins, outs):
 
     n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80
 
-    attention_weights = [[] for _ in range(min(word_alignement_most_top_layers,len(model.decoder.blocks)))]
+    attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]
 
     try:
 
@@ -1047,9 +1051,15 @@ def hook_output_logits(layer, ins, outs):
         for i, block in enumerate(model.decoder.blocks):
             if i < nblocks - word_alignement_most_top_layers:
                 continue
+            def hook(layer, ins, outs, index=j):
+                if is_transformer_model(model):
+                    attention_weights[index] = outs[1].log()
+                else:
+                    attention_weights[index] = outs[1]
             all_hooks.append(
                 block.cross_attn.register_forward_hook(
-                    lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[-1])
+                    hook
+                    # lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[1])
                 )
             )
             j += 1
@@ -1159,12 +1169,20 @@ def hook_output_logits(layer, ins, outs):
             last_token_check = tokens[-1]
             tokens = tokens[:-1]
 
+        sot_sequence = tokenizer.sot_sequence
+        if language:
+            assert len(sot_sequence) == 3
+            sot_sequence = (
+                sot_sequence[0],
+                tokenizer.to_language_token(language),
+                sot_sequence[2],
+            )
         tokens = [
-            *tokenizer.sot_sequence,
+            *sot_sequence,
             tokenizer.timestamp_begin,
         ] + tokens
 
-        i_start = len(tokenizer.sot_sequence)
+        i_start = len(sot_sequence)
 
         with torch.no_grad():
             logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
@@ -1234,8 +1252,10 @@ def hook_output_logits(layer, ins, outs):
                 segment_tokens_check.append(last_token_check)
             if trust_whisper_timestamps:
                 if segment_tokens_check != segment["tokens"]:
-                    assert len(segment_tokens_check) < len(segment["tokens"]) and segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
-                        f"Got inconsistent tokens: {tokenizer.decode(segment_tokens_check)} != {tokenizer.decode(segment['tokens'])}"
+                    assert len(segment_tokens_check) < len(segment["tokens"]), \
+                        f"First should be longer by one token: '{tokenizer.decode_with_timestamps(segment_tokens_check)}' should include '{tokenizer.decode_with_timestamps(segment['tokens'])}'"
+                    assert segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
+                        f"Got inconsistent tokens: {tokenizer.decode_with_timestamps(segment_tokens_check)} != {tokenizer.decode_with_timestamps(segment['tokens'])}"
                     segment["tokens"] = segment_tokens_check
                     segment["text"] = tokenizer.decode(segment["tokens"])
                 # else: TODO
@@ -1293,6 +1313,10 @@ def print_timestamped(w):
 
 
 def get_logit_filters(model, whisper_options, prompt = None):
+    if is_transformer_model(model):
+        # import transformers
+        # transformers.WhisperTimeStampLogitsProcessor
+        raise NotImplementedError("TODO")
     decoding_options = get_decoding_options(whisper_options)
     if "initial_prompt" in decoding_options:
         prompt0 = decoding_options.pop("initial_prompt")
@@ -1324,6 +1348,15 @@ def get_decoding_options(whisper_options):
     ])
 
 def get_tokenizer(model, task="transcribe", language="en"):
+    if is_transformer_model(model):
+        tokenizer = model.tokenizer
+        tokenizer.sot_sequence = (
+            tokenizer.sot,
+            tokenizer.to_language_token(language or "en"),
+            tokenizer.to_task_token(task),
+        )
+        tokenizer.sot_sequence
+        return model.tokenizer
     try:
         return whisper.tokenizer.get_tokenizer(
            model.is_multilingual,
@@ -2260,7 +2293,7 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
 def _get_number_of_parameters(model):
     num_parameters = 0
     for name, p in model.named_parameters():
-        if name in ["decoder.proj_out.weight"]:
+        if name in ["decoder.proj_out.weight", "model.encoder.embed_positions.weight"]:
             continue
         num_parameters += p.numel()
     return num_parameters
@@ -2271,7 +2304,20 @@ def load_model(
     device: Optional[Union[str, torch.device]] = None,
     download_root: str = None,
     in_memory: bool = False,
+    use_transformers: bool = False,
 ):
+    if use_transformers:
+        import transformers
+        if name in whisper.available_models():
+            name = f"openai/whisper-{name}"
+        # TODO: use download_root
+        # TODO: does in_memory makes sense?
+        generation_config = transformers.GenerationConfig.from_pretrained(name)
+        processor = transformers.WhisperProcessor.from_pretrained(name)
+        model = transformers.WhisperForConditionalGeneration.from_pretrained(name)
+        model = model.to(device)
+        return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config)
+
     extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None
 
     if name in whisper.available_models() or extension == ".pt":
@@ -2359,7 +2405,206 @@ def torch_load(model_path):
     hf_state_dict = torch.load(model_path, map_location="cpu")
     return hf_state_dict
 
 
+# Some helpers to manage transformers/openai-whisper model
+
+class TransformerWhisperAsOpenAIWhisper:
+    """
+    Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped)
+    """
+
+    def __init__(self, model, processor, generation_config):
+
+        self.model = model # transformers.WhisperForConditionalGeneration
+        self.processor = processor # transformers.WhisperProcessor
+        self.generation_config = generation_config # transformers.GenerationConfig
+        self.device = model.device
+
+        # Dimensions
+        self.dims = whisper.model.ModelDimensions(
+            n_mels = model.get_encoder().get_input_embeddings().in_channels,
+            n_audio_ctx = 1500,
+            n_audio_state = model.get_encoder().get_input_embeddings().out_channels,
+            n_audio_head = model.get_encoder().layers[0].self_attn.num_heads,
+            n_audio_layer = len(model.get_encoder().layers),
+            n_vocab = model.get_decoder().get_input_embeddings().num_embeddings,
+            n_text_ctx = 448,
+            n_text_state = model.get_decoder().get_input_embeddings().embedding_dim,
+            n_text_head = model.get_decoder().layers[0].self_attn.num_heads,
+            n_text_layer = len(model.get_decoder().layers),
+        )
+
+        # Tokenization
+        self.tokenizer = processor.tokenizer
+        (
+            self.tokenizer.sot,
+            self.tokenizer.eot,
+            self.tokenizer.timestamp_begin,
+            self.tokenizer.no_speech,
+        ) = self.tokenizer.convert_tokens_to_ids([
+            "<|startoftranscript|>",
+            "<|endoftext|>",
+            "<|0.00|>",
+            "<|nospeech|>",
+        ])
+        self.tokenizer.all_language_tokens = self.tokenizer.convert_tokens_to_ids([
+            t for t in self.tokenizer.additional_special_tokens if len(t) in [6,7]
+        ])
+        self.tokenizer.to_language_token = lambda language: self.generation_config.lang_to_id["<|" + language + "|>"]
+        self.tokenizer.to_task_token = lambda task: self.generation_config.task_to_id[task]
+        self.tokenizer.to_timestamp_token = lambda t: self.tokenizer.encode(f"<|{t:0.2f}|>", add_special_tokens=False)[0]
+        self.tokenizer.decode_with_timestamps = lambda tokens: self.tokenizer.decode(tokens, decode_with_timestamps=True)
+
+        # Access to layers (renamed attributes)
+        self.decoder = self.model.get_decoder()
+        self.decoder.ln = self.decoder.layer_norm
+        self.decoder.token_embedding = self.decoder.embed_tokens
+        self.decoder.blocks = self.decoder.layers
+        for block in self.decoder.blocks:
+            block.cross_attn = block.encoder_attn
+
+        # From the config
+        self.is_multilingual = generation_config.is_multilingual # (self.tokenizer.sot != 50257)
+
+        # Alignment heads
+        if hasattr(generation_config, "alignment_heads"):
+            a = generation_config.alignment_heads
+            self.alignment_heads = torch.sparse_coo_tensor(np.array(a).transpose(), [True]*len(a)).coalesce().to(self.device)
+
+    def named_parameters(self):
+        return self.model.named_parameters()
+
+    def transcribe(self, audio, **kwargs):
+        features = self.processor(
+            audio,
+            return_tensors="pt",
+            sampling_rate=16_000,
+            truncation=False,
+        ).input_features.to(self.device)
+
+        # TODO: double check that this is correct
+        for k in "temperature", "beam_size", "best_of":
+            if k in kwargs:
+                k2= {
+                    "beam_size": "num_beams",
+                    "best_of": "top_k",
+                }.get(k, k)
+                setattr(self.generation_config, k2, kwargs[k])
+
+        output = self.model.generate(
+            features,
+            return_dict_in_generate = True,
+            return_segments = True,
+            return_timestamps = True,
+            return_token_timestamps = False, # Note: concurrent token timestamps by transformers
+            max_length = self.dims.n_text_ctx,
+            is_multilingual = self.is_multilingual,
+            task = kwargs.get("task", "transcribe"),
+            language = kwargs.get("language"),
+            prompt_ids = kwargs.get("initial_prompt"),
+            generation_config = self.generation_config,
+        )
+
+        output_dict = {}
+
+        language_detected = None
+        if "segments" in output:
+            # Several segments
+            full_text = ""
+            segments = []
+            id = -1
+            previous_end = 0
+            for segment in output["segments"]:
+                for alternative in segment:
+                    id += 1
+                    tokens = alternative["tokens"]
+                    text = self.tokenizer.decode(tokens, skip_special_tokens=True)
+
+                    token_timestamps = [(i, t.item()) for i, t in enumerate(tokens) if t >= self.tokenizer.timestamp_begin]
+                    if len(token_timestamps):
+                        assert len(token_timestamps) == 2, f"Got unexpected number of timestamps: {token_timestamps}"
+                        i_start, token_start = token_timestamps[0]
+                        i_end, token_end = token_timestamps[1]
+                        tokens = tokens[i_start+1:i_end]
+                        offset = max(previous_end, alternative["start"].item())
+                        start = offset + (token_start - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
+                        end = offset + (token_end - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
+                        previous_end = end
+                    else:
+                        start = max(previous_end, alternative["start"].item())
+                        end = alternative["end"].item()
+                        assert end >= start, f"Got end < start ({end} < {start})"
+                        previous_end = end
+                        token_start = self.tokenizer.to_timestamp_token(start)
+                        token_end = self.tokenizer.to_timestamp_token(end)
+                    # Accumulate
+                    segments.append({
+                        "id": id,
+                        "start": start,
+                        "end": end,
+                        "text": text,
+                        "tokens": [token_start] + tokens.tolist() + [token_end],
+                        # "seek": 0,
+                        # "temperature": 0.0,
+                        # "avg_logprob": -0.6982866287231445,
+                        # "compression_ratio": 0.5294117647058824,
+                        # "no_speech_prob": 0.019023602828383446
+                    })
+                    if full_text:
+                        full_text += " "
+                    full_text += text
+
+            output_dict = {
+                "text": full_text,
+                "segments": segments,
+            }
+            if not kwargs.get("language"):
+                language_detected = self.tokenizer.decode(output["segments"][0][0]["tokens"][0])
+        else:
+            # One segment only
+            tokens = output.sequences[0]
+            text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
+            tokens = tokens.tolist()
+            i_sot = tokens.index(self.tokenizer.sot)
+            i_eot = tokens.index(self.tokenizer.eot) if self.tokenizer.eot in tokens else len(tokens)
+            i_start = i_sot+3
+            i_end = i_eot-1
+            start = self.tokenizer.decode([tokens[i_start]], decode_with_timestamps=True)
+            end = self.tokenizer.decode([tokens[i_end]], decode_with_timestamps=True)
+            start = float(start[2:-2])
+            end = float(end[2:-2])
+
+            output_dict = {
+                "text": text,
+                "segments": [{
+                    "id": 0,
+                    "start": start,
+                    "end": end,
+                    "text": text,
+                    "tokens": tokens[i_start:i_end+1],
+                    # "seek": 0,
+                    # "temperature": 0.0,
+                    # "avg_logprob": -0.6982866287231445,
+                    # "compression_ratio": 0.5294117647058824,
+                    # "no_speech_prob": 0.019023602828383446
+                }]
+            }
+            if not kwargs.get("language"):
+                language_detected = self.tokenizer.decode([tokens[i_sot+1]])
+
+        if language_detected is not None:
+            assert len(language_detected) in [6,7], f"Unexpected language detected: {language_detected}"
+            language_detected = language_detected[2:-2]
+            output_dict["language"] = language_detected
+
+        return output_dict
+
+    def __call__(self, mfcc, tokens):
+        output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True)
+        return output.logits
+
+def is_transformer_model(model):
+    return isinstance(model, TransformerWhisperAsOpenAIWhisper)
 
 
 # Credit: https://github.com/openai/whisper/discussions/830
@@ -2500,6 +2745,7 @@ def get_do_write(output_format):
     parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
     parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
     parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
+    parser.add_argument("--use_transformers", default=False, help="whether to use transformers (instead of openai-whisper) backend", type=str2bool)
    parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
     valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
     def str2output_formats(string):
@@ -2590,8 +2836,9 @@ def __call__(self, parser, namespace, values, option_string=None):
         force_cudnn_initialization(device)
 
     output_format = args.pop("output_format")
+    use_transformers = args.pop("use_transformers")
 
-    model = load_model(model, device=device, download_root=model_dir)
+    model = load_model(model, device=device, download_root=model_dir, use_transformers=use_transformers)
 
     plot_word_alignment = args.pop("plot")
 
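
A minimal usage sketch of the new backend (not part of the patch itself; it assumes the existing whisper_timestamped Python API with load_audio, load_model and transcribe stays unchanged, and the audio file name and model size are only illustrative):

    import whisper_timestamped as whisper

    # The new flag routes model loading through transformers:
    # load_model(..., use_transformers=True) returns a TransformerWhisperAsOpenAIWhisper
    # wrapper around transformers.WhisperForConditionalGeneration instead of an
    # openai-whisper model.
    audio = whisper.load_audio("example.wav")  # illustrative input file
    model = whisper.load_model("small", device="cpu", use_transformers=True)

    # The rest of the API is unchanged: word-level timestamps come back in the result.
    result = whisper.transcribe(model, audio, language="en")
    for segment in result["segments"]:
        for word in segment.get("words", []):
            print(word["text"], word["start"], word["end"])

On the command line, the same backend would be selected with the new flag added above, e.g. whisper_timestamped example.wav --model small --use_transformers True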