Enable transformers as a backend
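This commit adds an optional Hugging Face transformers backend: load_model() gains a use_transformers flag, and a new TransformerWhisperAsOpenAIWhisper class wraps a transformers Whisper checkpoint behind the attribute names whisper-timestamped expects from an openai-whisper model. A minimal usage sketch, assuming the package's usual entry points (load_model, load_audio, transcribe) and a 16 kHz-compatible audio file:

    import whisper_timestamped as whisper

    # "small" is mapped to the Hub checkpoint "openai/whisper-small"
    model = whisper.load_model("small", device="cpu", use_transformers=True)
    audio = whisper.load_audio("speech.wav")
    result = whisper.transcribe(model, audio, language="en")
    print(result["text"])

The same backend can be selected on the command line with the new --use_transformers flag.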
Jeronymous committed Jan 26, 2024
1 parent 713626e commit db03b65
Showing 1 changed file with 255 additions and 8 deletions.
263 changes: 255 additions & 8 deletions whisper_timestamped/transcribe.py
@@ -229,6 +229,10 @@ def transcribe_timestamped(
if fp16 is None:
fp16 = model.device != torch.device("cpu")

# TODO: implement efficient approach with transformers
if is_transformer_model(model):
naive_approach = True

# Safety check
input_stride = N_FRAMES // model.dims.n_audio_ctx
time_precision = input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -1035,7 +1039,7 @@ def hook_output_logits(layer, ins, outs):

n_mels = model.dims.n_mels if hasattr(model.dims, "n_mels") else 80

-attention_weights = [[] for _ in range(min(word_alignement_most_top_layers,len(model.decoder.blocks)))]
+attention_weights = [[] for _ in range(min(word_alignement_most_top_layers, len(model.decoder.blocks)))]

try:

@@ -1047,9 +1051,15 @@ def hook_output_logits(layer, ins, outs):
for i, block in enumerate(model.decoder.blocks):
if i < nblocks - word_alignement_most_top_layers:
continue
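# With the transformers backend, the hooked encoder_attn module returns
# softmax-normalized attention probabilities, while openai-whisper's
# cross_attn hook yields raw QK scores; taking the log here keeps the
# downstream softmax-based alignment comparable across the two backends.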
def hook(layer, ins, outs, index=j):
if is_transformer_model(model):
attention_weights[index] = outs[1].log()
else:
attention_weights[index] = outs[1]
all_hooks.append(
block.cross_attn.register_forward_hook(
-lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[-1])
+hook
+# lambda layer, ins, outs, index=j: attention_weights.__setitem__(index, outs[1])
)
)
j += 1
@@ -1159,12 +1169,20 @@ def hook_output_logits(layer, ins, outs):
last_token_check = tokens[-1]
tokens = tokens[:-1]

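# The SOT sequence is (<|startoftranscript|>, <|language|>, <|task|>);
# rebuild it so the language token matches the requested language.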
sot_sequence = tokenizer.sot_sequence
if language:
assert len(sot_sequence) == 3
sot_sequence = (
sot_sequence[0],
tokenizer.to_language_token(language),
sot_sequence[2],
)
tokens = [
-*tokenizer.sot_sequence,
+*sot_sequence,
tokenizer.timestamp_begin,
] + tokens

-i_start = len(tokenizer.sot_sequence)
+i_start = len(sot_sequence)

with torch.no_grad():
logprobs = model(mfcc, torch.Tensor(tokens).int().to(model.device).unsqueeze(0))
@@ -1234,8 +1252,10 @@ def hook_output_logits(layer, ins, outs):
segment_tokens_check.append(last_token_check)
if trust_whisper_timestamps:
if segment_tokens_check != segment["tokens"]:
-assert len(segment_tokens_check) < len(segment["tokens"]) and segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
-f"Got inconsistent tokens: {tokenizer.decode(segment_tokens_check)} != {tokenizer.decode(segment['tokens'])}"
+assert len(segment_tokens_check) < len(segment["tokens"]), \
+f"Second should be longer: '{tokenizer.decode_with_timestamps(segment_tokens_check)}' should be included in '{tokenizer.decode_with_timestamps(segment['tokens'])}'"
+assert segment_tokens_check[:-1] == segment["tokens"][:len(segment_tokens_check)-1], \
+f"Got inconsistent tokens: {tokenizer.decode_with_timestamps(segment_tokens_check)} != {tokenizer.decode_with_timestamps(segment['tokens'])}"
segment["tokens"] = segment_tokens_check
segment["text"] = tokenizer.decode(segment["tokens"])
# else: TODO
@@ -1293,6 +1313,10 @@ def print_timestamped(w):


def get_logit_filters(model, whisper_options, prompt = None):
if is_transformer_model(model):
# import transformers
# transformers.WhisperTimeStampLogitsProcessor
raise NotImplementedError("TODO")
decoding_options = get_decoding_options(whisper_options)
if "initial_prompt" in decoding_options:
prompt0 = decoding_options.pop("initial_prompt")
@@ -1324,6 +1348,15 @@ def get_decoding_options(whisper_options):
])

def get_tokenizer(model, task="transcribe", language="en"):
if is_transformer_model(model):
tokenizer = model.tokenizer
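# Mimic openai-whisper's tokenizer: expose the same
# (<|startoftranscript|>, <|language|>, <|task|>) sot_sequence triple.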
tokenizer.sot_sequence = (
tokenizer.sot,
tokenizer.to_language_token(language or "en"),
tokenizer.to_task_token(task),
)
return tokenizer
try:
return whisper.tokenizer.get_tokenizer(
model.is_multilingual,
@@ -2260,7 +2293,7 @@ def _get_alignment_heads(model_name, num_layers, num_heads):
def _get_number_of_parameters(model):
num_parameters = 0
for name, p in model.named_parameters():
if name in ["decoder.proj_out.weight"]:
if name in ["decoder.proj_out.weight", "model.encoder.embed_positions.weight"]:
continue
num_parameters += p.numel()
return num_parameters
@@ -2271,7 +2304,20 @@ def load_model(
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
use_transformers: bool = False,
):
if use_transformers:
import transformers
if name in whisper.available_models():
name = f"openai/whisper-{name}"
# TODO: use download_root
# TODO: does in_memory make sense?
generation_config = transformers.GenerationConfig.from_pretrained(name)
processor = transformers.WhisperProcessor.from_pretrained(name)
model = transformers.WhisperForConditionalGeneration.from_pretrained(name)
model = model.to(device)
return TransformerWhisperAsOpenAIWhisper(model, processor, generation_config)

extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None

if name in whisper.available_models() or extension == ".pt":
@@ -2359,7 +2405,206 @@ def torch_load(model_path):
hf_state_dict = torch.load(model_path, map_location="cpu")
return hf_state_dict

# Helpers to use a transformers model in place of an openai-whisper model

class TransformerWhisperAsOpenAIWhisper:
"""
Wrapper to use a transformers model as a whisper model (at least in whisper-timestamped)
"""

def __init__(self, model, processor, generation_config):

self.model = model # transformers.WhisperForConditionalGeneration
self.processor = processor # transformers.WhisperProcessor
self.generation_config = generation_config # transformers.GenerationConfig
self.device = model.device

# Dimensions
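# Note: 1500 and 448 below are the fixed audio/text context sizes of all
# released Whisper checkpoints (transformers' max_source_positions and
# max_target_positions).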
self.dims = whisper.model.ModelDimensions(
n_mels = model.get_encoder().get_input_embeddings().in_channels,
n_audio_ctx = 1500,
n_audio_state = model.get_encoder().get_input_embeddings().out_channels,
n_audio_head = model.get_encoder().layers[0].self_attn.num_heads,
n_audio_layer = len(model.get_encoder().layers),
n_vocab = model.get_decoder().get_input_embeddings().num_embeddings,
n_text_ctx = 448,
n_text_state = model.get_decoder().get_input_embeddings().embedding_dim,
n_text_head = model.get_decoder().layers[0].self_attn.num_heads,
n_text_layer = len(model.get_decoder().layers),
)

# Tokenization
self.tokenizer = processor.tokenizer
(
self.tokenizer.sot,
self.tokenizer.eot,
self.tokenizer.timestamp_begin,
self.tokenizer.no_speech,
) = self.tokenizer.convert_tokens_to_ids([
"<|startoftranscript|>",
"<|endoftext|>",
"<|0.00|>",
"<|nospeech|>",
])
self.tokenizer.all_language_tokens = self.tokenizer.convert_tokens_to_ids([
t for t in self.tokenizer.additional_special_tokens if len(t) in [6,7]
])
self.tokenizer.to_language_token = lambda language: self.generation_config.lang_to_id["<|" + language + "|>"]
self.tokenizer.to_task_token = lambda task: self.generation_config.task_to_id[task]
self.tokenizer.to_timestamp_token = lambda t: self.tokenizer.encode(f"<|{t:0.2f}|>", add_special_tokens=False)[0]
self.tokenizer.decode_with_timestamps = lambda tokens: self.tokenizer.decode(tokens, decode_with_timestamps=True)

# Access to layers (renamed attributes)
self.decoder = self.model.get_decoder()
self.decoder.ln = self.decoder.layer_norm
self.decoder.token_embedding = self.decoder.embed_tokens
self.decoder.blocks = self.decoder.layers
for block in self.decoder.blocks:
block.cross_attn = block.encoder_attn

# From the config
self.is_multilingual = generation_config.is_multilingual # (self.tokenizer.sot != 50257)

# Alignment heads
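# generation_config.alignment_heads is a list of [layer, head] pairs;
# convert it into the boolean sparse mask format that openai-whisper
# stores as model.alignment_heads.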
if hasattr(generation_config, "alignment_heads"):
a = generation_config.alignment_heads
self.alignment_heads = torch.sparse_coo_tensor(np.array(a).transpose(), [True]*len(a)).coalesce().to(self.device)

def named_parameters(self):
return self.model.named_parameters()

def transcribe(self, audio, **kwargs):
features = self.processor(
audio,
return_tensors="pt",
sampling_rate=16_000,
truncation=False,
).input_features.to(self.device)

# TODO: double check that this is correct
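# (beam_size -> num_beams should map directly; openai-whisper's best_of
# counts sampled candidates, so top_k is only an approximation and
# num_return_sequences with sampling may be the closer equivalent)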
for k in "temperature", "beam_size", "best_of":
if k in kwargs:
k2 = {
"beam_size": "num_beams",
"best_of": "top_k",
}.get(k, k)
setattr(self.generation_config, k2, kwargs[k])

output = self.model.generate(
features,
return_dict_in_generate = True,
return_segments = True,
return_timestamps = True,
return_token_timestamps = False, # Note: transformers has its own (concurrent) token-timestamp implementation, not used here
max_length = self.dims.n_text_ctx,
is_multilingual = self.is_multilingual,
task = kwargs.get("task", "transcribe"),
language = kwargs.get("language"),
prompt_ids = kwargs.get("initial_prompt"),
generation_config = self.generation_config,
)

output_dict = {}

language_detected = None
if "segments" in output:
# Several segments
full_text = ""
segments = []
id = -1
previous_end = 0
for segment in output["segments"]:
for alternative in segment:
id += 1
tokens = alternative["tokens"]
text = self.tokenizer.decode(tokens, skip_special_tokens=True)

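# Timestamp tokens encode multiples of 0.02 s from <|0.00|>, so
# (token - timestamp_begin) * AUDIO_TIME_PER_TOKEN converts a timestamp
# token to seconds, relative to the running segment offset.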
token_timestamps = [(i, t.item()) for i, t in enumerate(tokens) if t >= self.tokenizer.timestamp_begin]
if len(token_timestamps):
assert len(token_timestamps) == 2, f"Got unexpected number of timestamps: {token_timestamps}"
i_start, token_start = token_timestamps[0]
i_end, token_end = token_timestamps[1]
tokens = tokens[i_start+1:i_end]
offset = max(previous_end, alternative["start"].item())
start = offset + (token_start - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
end = offset + (token_end - self.tokenizer.timestamp_begin) * AUDIO_TIME_PER_TOKEN
previous_end = end
else:
start = max(previous_end, alternative["start"].item())
end = alternative["end"].item()
assert end >= start, f"Got end < start ({end} < {start})"
previous_end = end
token_start = self.tokenizer.to_timestamp_token(start)
token_end = self.tokenizer.to_timestamp_token(end)
# Accumulate
segments.append({
"id": id,
"start": start,
"end": end,
"text": text,
"tokens": [token_start] + tokens.tolist() + [token_end],
# "seek": 0,
# "temperature": 0.0,
# "avg_logprob": -0.6982866287231445,
# "compression_ratio": 0.5294117647058824,
# "no_speech_prob": 0.019023602828383446
})
if full_text:
full_text += " "
full_text += text

output_dict = {
"text": full_text,
"segments": segments,
}
if not kwargs.get("language"):
language_detected = self.tokenizer.decode(output["segments"][0][0]["tokens"][0])
else:
# One segment only
tokens = output.sequences[0]
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
tokens = tokens.tolist()
i_sot = tokens.index(self.tokenizer.sot)
i_eot = tokens.index(self.tokenizer.eot) if self.tokenizer.eot in tokens else len(tokens)
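# Skip the <|startoftranscript|>, language and task tokens; the first and
# last remaining tokens are the segment's opening and closing timestamps.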
i_start = i_sot+3
i_end = i_eot-1
start = self.tokenizer.decode([tokens[i_start]], decode_with_timestamps=True)
end = self.tokenizer.decode([tokens[i_end]], decode_with_timestamps=True)
start = float(start[2:-2])
end = float(end[2:-2])

output_dict = {
"text": text,
"segments": [{
"id": 0,
"start": start,
"end": end,
"text": text,
"tokens": tokens[i_start:i_end+1],
# "seek": 0,
# "temperature": 0.0,
# "avg_logprob": -0.6982866287231445,
# "compression_ratio": 0.5294117647058824,
# "no_speech_prob": 0.019023602828383446
}]
}
if not kwargs.get("language"):
language_detected = self.tokenizer.decode([tokens[i_sot+1]])

if language_detected is not None:
assert len(language_detected) in [6,7], f"Unexpected language detected: {language_detected}"
language_detected = language_detected[2:-2]
output_dict["language"] = language_detected

return output_dict

def __call__(self, mfcc, tokens):
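# output_attentions=True lets the forward hooks registered on the decoder
# blocks capture cross-attention weights; only the logits are returned,
# matching openai-whisper's model(mfcc, tokens) call signature.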
output = self.model(mfcc, decoder_input_ids=tokens, output_attentions=True)
return output.logits


def is_transformer_model(model):
return isinstance(model, TransformerWhisperAsOpenAIWhisper)


# Credit: https://github.com/openai/whisper/discussions/830
@@ -2500,6 +2745,7 @@ def get_do_write(output_format):
parser.add_argument('--model', help=f"name of the Whisper model to use. Examples: {', '.join(whisper.available_models())}", default="small")
parser.add_argument("--model_dir", default=None, help="the path to save model files; uses ~/.cache/whisper by default", type=str)
parser.add_argument("--device", default=get_default_device(), help="device to use for PyTorch inference")
parser.add_argument("--use_transformers", default=False, help="whether to use transformers (instead of openai-whisper) backend", type=str2bool)
parser.add_argument("--output_dir", "-o", default=None, help="directory to save the outputs", type=str)
valid_formats = ["txt", "vtt", "srt", "tsv", "csv", "json"]
def str2output_formats(string):
@@ -2590,8 +2836,9 @@ def __call__(self, parser, namespace, values, option_string=None):
force_cudnn_initialization(device)

output_format = args.pop("output_format")
use_transformers = args.pop("use_transformers")

-model = load_model(model, device=device, download_root=model_dir)
+model = load_model(model, device=device, download_root=model_dir, use_transformers=use_transformers)

plot_word_alignment = args.pop("plot")

