diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 23fc12e81427..736d67b1a2ad 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -25,6 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -774,14 +775,19 @@ def forward(
         lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
         # Create mask
         padding_mask = seq_range >= lengths_expand
+        audio_attention_mask_2d = (~padding_mask).to(dtype=torch.long, device=audio_feat_lengths.device)
 
-        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
-            batch_size, 1, max_seq_len, max_seq_len
+        dummy_embeds = torch.zeros(
+            (batch_size, max_seq_len, 1),
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device,
         )
-        audio_attention_mask = audio_attention_mask_.to(
-            dtype=self.audio_tower.conv1.weight.dtype, device=self.audio_tower.conv1.weight.device
+
+        audio_attention_mask = create_bidirectional_mask(
+            config=self.audio_tower.config,
+            input_embeds=dummy_embeds,
+            attention_mask=audio_attention_mask_2d,
         )
-        audio_attention_mask[audio_attention_mask_] = float("-inf")
 
         audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
         selected_audio_feature = audio_outputs.last_hidden_state
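
Context note (not part of the patch): the standalone sketch below, using made-up toy lengths, illustrates how the plain 2D keep/pad mask now handed to `create_bidirectional_mask` relates to the 4D additive `-inf` mask the removed lines built by hand. The `old_float_mask` helper only mirrors the pre-change construction for comparison; in the patched code that expansion is delegated to `create_bidirectional_mask`.

```python
import torch

batch_size, max_seq_len = 2, 4
audio_feat_lengths = torch.tensor([4, 2])  # toy per-sample counts of valid audio frames

# Same construction as the surrounding code: True marks padded frame positions.
seq_range = torch.arange(max_seq_len).unsqueeze(0).expand(batch_size, max_seq_len)
lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
padding_mask = seq_range >= lengths_expand

# Old approach (removed by the patch): broadcast to (batch, 1, seq, seq) and
# write -inf at padded key positions to obtain an additive float mask.
expanded = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
    batch_size, 1, max_seq_len, max_seq_len
)
old_float_mask = torch.zeros(expanded.shape, dtype=torch.float32)
old_float_mask[expanded] = float("-inf")

# New input (added by the patch): a plain 2D mask, 1 = attend, 0 = padding,
# which create_bidirectional_mask turns into whatever format the configured
# attention implementation expects.
new_mask_2d = (~padding_mask).to(dtype=torch.long)

print(new_mask_2d)           # tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
print(old_float_mask[1, 0])  # every row has -inf at the two padded key positions
```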