[webgpu] Optimize FlashAttention for prefill #25395


Open · daijh wants to merge 2 commits into main

Conversation

daijh (Contributor) commented Jul 15, 2025

Description

This PR enhances unidirectional `FlashAttention` by applying causal masking inside the main loop. This optimization eliminates unnecessary memory loads by skipping future entries in the KV cache.

Testing on Lunar Lake shows up to a 20% prefill performance improvement for `phi-4-mini-accuracy4` with a 4096-token prompt. Similar gains were also observed for other models, including `Qwen3-0.6B-accuracy4`.

The PR also replaces `is_gqa` with the more readable `unidirectional` attribute to control causal masking.
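As a rough sketch of the idea (illustrative only, not the exact kernel in this PR; `TILE_SIZE` and the loop shape are placeholders), causal masking inside the main loop lets the kernel stop before any "future" K/V tiles are loaded:

```wgsl
// Illustrative sketch, not the exact shader in this PR.
// For unidirectional (causal) attention, a query at position q_idx_global only
// attends to keys up to past_sequence_length + q_idx_global, so K/V tiles past
// that bound can be skipped instead of being loaded and then masked out.
for (var k_start = 0u; k_start < uniforms.total_sequence_length; k_start += TILE_SIZE) {
  let seq_causal_length = select(uniforms.total_sequence_length,
                                 uniforms.past_sequence_length + q_idx_global + 1,
                                 uniforms.is_unidirectional > 0);
  if (k_start >= seq_causal_length) {
    break;  // every remaining tile is in the future for this query row
  }
  // load the K/V tile, compute QK^T, update the online softmax, accumulate output
}
```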

Motivation and Context

See above.

daijh (Contributor, Author) commented Jul 15, 2025

Lunar Lake, Phi-4-mini-accuracy4:

| Prompt | Default Prefill Speed (tps) | Opt Prefill Speed (tps) | Improvement |
|---|---|---|---|
| Prompt-1024 | 561.40 | 593.76 | 5.76% |
| Prompt-2048 | 498.60 | 549.59 | 10.23% |
| Prompt-3072 | 465.98 | 537.64 | 15.38% |
| Prompt-4096 | 430.61 | 513.40 | 19.23% |

daijh (Contributor, Author) commented Jul 15, 2025

@sushraja-msft @qjia7 Please take a look.

cc @jchen10 @xhcao

qjia7 (Contributor) previously approved these changes and left a comment, Jul 15, 2025:

LGTM with nits.

```diff
@@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx);
 }

-let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0);
+let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0);
```

nit: Is it better if we move this assignment out of the for loop?

daijh (Contributor, Author) replied:

Sure. Let me address it in follow-up PRs to keep this one focused.
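For illustration, a hypothetical sketch of that follow-up (not code in this PR; `TILE_SIZE` is again a placeholder), with the loop-invariant bound computed once before the loop:

```wgsl
// Hypothetical follow-up sketch: seq_causal_length does not depend on k_start,
// so it can be computed once per query row instead of on every loop iteration.
let seq_causal_length = select(uniforms.total_sequence_length,
                               uniforms.past_sequence_length + q_idx_global + 1,
                               uniforms.is_unidirectional > 0);
for (var k_start = 0u; k_start < seq_causal_length; k_start += TILE_SIZE) {
  // load the K/V tile, compute QK^T, update the online softmax, accumulate output
}
```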

daijh (Contributor, Author) commented Jul 16, 2025

@guschmue @fs-eire
Please take a look.

fs-eire (Contributor) commented Jul 22, 2025

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline


Azure Pipelines successfully started running 5 pipeline(s).
