Skip to content

Commit

Permalink
fix wrong order of token slice
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang committed Jan 22, 2024
1 parent ab08201 commit 46605ea
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion egs/aishell/ASR/whisper/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,9 @@ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
with torch.set_grad_enabled(is_training):
encoder_out = model.encoder(feature)
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
loss = decoder_criterion(text_logits, target_tokens.to(device))
text_logits = text_logits[:, ignore_prefix_size:, :]
target_tokens = target_tokens[:, ignore_prefix_size:]
loss = decoder_criterion(text_logits, target_tokens.to(device))

assert loss.requires_grad == is_training

Expand Down

0 comments on commit 46605ea

Please sign in to comment.