[webgpu] Optimize FlashAttention for prefill #25395
base: main
Conversation
This PR enhances unidirectional `FlashAttention` by applying causal masking inside the main loop. The optimization eliminates unnecessary memory loads by skipping future entries in the KV cache. Testing on Lunar Lake shows up to a 20% performance improvement for `phi-4-mini-accuracy4` (with a prompt length of 4096). Similar gains were observed for other models, including `Qwen3-0.6B-accuracy4`. This PR also switches to the more readable `unidirectional` attribute instead of `is_gqa` to control causal masking.
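As a rough illustration of the change (a minimal sketch assuming a tiled main loop; `tile_size`, the loop structure, and the early `break` are placeholders, and only the `seq_causal_length` expression and uniform names come from the actual shader):

```wgsl
// Sketch of causal masking inside the FlashAttention main loop (not the real kernel).
let tile_size = 16u;  // placeholder tile width, not the kernel's actual tiling
for (var k_start = 0u; k_start < uniforms.total_sequence_length; k_start += tile_size) {
  // Causal boundary for this query row; mirrors the select in the diff below.
  let seq_causal_length = select(uniforms.total_sequence_length,
                                 uniforms.past_sequence_length + q_idx_global + 1,
                                 uniforms.is_unidirectional > 0);
  if (k_start >= seq_causal_length) {
    break;  // future KV-cache tiles are never loaded, saving memory traffic
  }
  // load K/V tile, compute qk scores, apply the online-softmax update ...
}
```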
Lunar Lake, Phi-4-mini-accuracy4: [benchmark chart]
@sushraja-msft @qjia7 pls take a look.
LGTM with nits.
@@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx);
 }

-let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0);
+let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0);
nit: Is it better if we move this assignment out of the for loop?
Sure. Let me address it in follow-up PRs to keep this one focused.
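For reference, a sketch of the hoisted form the nit suggests (loop structure and tile size are again placeholders): since the causal boundary depends only on values that are fixed for a given query row, it can be computed once before the K loop.

```wgsl
// seq_causal_length is loop-invariant for this query row, so compute it once up front.
let seq_causal_length = select(uniforms.total_sequence_length,
                               uniforms.past_sequence_length + q_idx_global + 1,
                               uniforms.is_unidirectional > 0);
let tile_size = 16u;  // placeholder tile width
for (var k_start = 0u; k_start < seq_causal_length; k_start += tile_size) {
  // load K/V tile, compute qk scores, apply the online-softmax update ...
}
```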
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Description

This PR enhances unidirectional `FlashAttention` by applying causal masking inside the main loop. The optimization eliminates unnecessary memory loads by skipping future entries in the KV cache. Testing on Lunar Lake shows up to a 20% performance improvement for `phi-4-mini-accuracy4` (with a prompt length of 4096). Similar gains were observed for other models, including `Qwen3-0.6B-accuracy4`. This PR also switches to the more readable `unidirectional` attribute instead of `is_gqa` to control causal masking.

Motivation and Context
See above.