Fix Qwen2Audio flash attention mask format for generation #41843
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@remi-or ready for review 🙂
We should move away from this way of creating masks. It will make our lives harder to maintain.
run-slow: qwen2_audio

This comment contains run-slow, running the specified jobs: models: ['models/qwen2_audio']
The nccl error is unrelated, known to fail atm
LGTM
Can you push an empty commit? Can't merge with red CI

[For maintainers] Suggested jobs to run (before merge): run-slow: qwen2_audio
```python
dummy_embeds = torch.zeros(
    (batch_size, max_seq_len, 1),
    dtype=self.audio_tower.conv1.weight.dtype,
    device=self.audio_tower.conv1.weight.device,
)
```
Ah sorry maybe one last nit: can we change the device/dtype here? Inputs_embeds should suffice?
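For illustration, here is a minimal, self-contained sketch of what that suggestion could look like; the concrete values are made up and only the pattern of taking `dtype`/`device` from `inputs_embeds` follows the comment above, so this is not the PR's actual diff:

```python
import torch

# Hypothetical illustration of the nit: derive dtype/device from inputs_embeds
# instead of reaching into self.audio_tower.conv1.weight (values are made up).
batch_size, max_seq_len, hidden_size = 2, 10, 8
inputs_embeds = torch.randn(batch_size, max_seq_len, hidden_size, dtype=torch.float16)

dummy_embeds = torch.zeros(
    (batch_size, max_seq_len, 1),
    dtype=inputs_embeds.dtype,
    device=inputs_embeds.device,
)
print(dummy_embeds.dtype, dummy_embeds.device)  # torch.float16, same device as inputs_embeds
```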
Thx a lot 🤗
What does this PR fix?
This PR fixes the `test_eager_matches_fa2_generate` test failure for Qwen2Audio by using the `create_bidirectional_mask` utility function to properly handle attention masks across different attention implementations.

The Qwen2Audio model was manually creating a 4D attention mask with `-inf` values for the audio encoder, regardless of the attention implementation being used. This caused issues with Flash Attention 2/3, which requires a 2D boolean mask (shape `(batch_size, seq_len)`) with `1` for valid tokens and `0` for padding.
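To make the two formats concrete, here is a small illustrative sketch contrasting the 2D mask Flash Attention expects with a hand-built 4D additive `-inf` mask; the names and shapes are assumptions for demonstration only, not code from this PR:

```python
import torch

batch_size, seq_len = 2, 6

# 2D mask expected by Flash Attention 2/3: 1 for valid tokens, 0 for padding.
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 0, 0],
                                  [1, 1, 1, 1, 1, 1]])

# Hand-built 4D additive mask of the kind eager/SDPA attention can consume:
# 0.0 where attention is allowed, -inf at padded positions, shaped
# (batch_size, 1, seq_len, seq_len) so it can be added to attention scores.
additive_4d = torch.zeros(batch_size, 1, seq_len, seq_len)
additive_4d = additive_4d.masked_fill(
    attention_mask_2d[:, None, None, :] == 0, float("-inf")
)

print(attention_mask_2d.shape)  # torch.Size([2, 6])
print(additive_4d.shape)        # torch.Size([2, 1, 6, 6])
```

Passing the 4D additive variant to a Flash Attention backend is what broke generation here, which is why the fix switches to a mask-creation utility that produces the right format per attention implementation.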