[Feature] Support for Variable Q/KV Sequence Lengths in VSA ThunderKittens kernel #911
base: main
Conversation
…d to be changed. The plan is just to change the C++ API.
Summary of Changes

Hello @alexzms, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the Video Sparse Attention (VSA) system by introducing support for variable Query and Key/Value sequence lengths. It removes the previous limitation that Q and KV must have identical sequence lengths, making the VSA implementation more adaptable to diverse model architectures and data patterns. The modifications primarily involve updating the API and testing infrastructure, as the underlying kernel logic was already robust enough to handle these variations.
Code Review
This pull request successfully updates the Video Sparse Attention (VSA) implementation to support variable sequence lengths for Query and Key/Value tensors. The changes to the CUDA kernel, Python wrappers, and test suite are comprehensive and well executed. The PR adds a new forward-only test and expands existing tests to cover differing Q/KV length scenarios. My review includes a few suggestions to enhance test robustness and improve code clarity through more consistent variable naming across the Python and CUDA layers. I also noted a misleading comment in a test file that should be corrected.
- def block_sparse_kernel_test(Q, K, V, block_sparse_mask, variable_block_sizes, non_pad_index, dO):
+ def block_sparse_kernel_test(Q, K, V, block_sparse_mask, variable_block_sizes, q_non_pad_index, kv_non_pad_index, q_num_blocks, kv_num_blocks, dO):
For clarity, consider renaming variable_block_sizes to kv_variable_block_sizes in the function signature. This parameter is used for the K and V tensors, and the underlying CUDA kernel now expects kv_block_size. This change would make the Python test code more consistent with the C++/CUDA layer. The usage on line 45 should be updated accordingly.
Suggested change:
- def block_sparse_kernel_test(Q, K, V, block_sparse_mask, variable_block_sizes, q_non_pad_index, kv_non_pad_index, q_num_blocks, kv_num_blocks, dO):
+ def block_sparse_kernel_test(Q, K, V, block_sparse_mask, kv_variable_block_sizes, q_non_pad_index, kv_non_pad_index, q_num_blocks, kv_num_blocks, dO):
abs_diff = torch.abs(diff)
results[name]['sum_diff'] += torch.sum(abs_diff).item()
results[name]['sum_abs'] += torch.sum(torch.abs(pt)).item()
rel_max_diff = torch.max(abs_diff) / torch.mean(torch.abs(pt))
The calculation of rel_max_diff could lead to a division by zero if torch.mean(torch.abs(pt)) is zero. This could cause NaN or inf values, making the test results unreliable. It's safer to add a small epsilon to the denominator to prevent this, similar to what's done in test_vsa_forward.py.
Suggested change:
- rel_max_diff = torch.max(abs_diff) / torch.mean(torch.abs(pt))
+ rel_max_diff = torch.max(abs_diff) / (torch.mean(torch.abs(pt)) + 1e-6)
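To illustrate the failure mode the epsilon guards against, here is a minimal standalone sketch; the variable names mirror the test, but the values are made up for demonstration:

```python
# Standalone illustration only; variable names are assumptions, not the test's tensors.
import torch

pt = torch.zeros(4)                      # a reference output that happens to be all zeros
abs_diff = torch.abs(torch.zeros(4) - pt)

unsafe = torch.max(abs_diff) / torch.mean(torch.abs(pt))           # 0 / 0 -> nan
safe = torch.max(abs_diff) / (torch.mean(torch.abs(pt)) + 1e-6)    # 0.0

print(unsafe.item(), safe.item())        # nan 0.0
```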
K: torch.Tensor,
V: torch.Tensor,
block_sparse_mask: torch.Tensor,
variable_block_sizes: torch.Tensor,
For clarity and consistency with the underlying CUDA/Triton kernels, consider renaming variable_block_sizes to kv_variable_block_sizes. This parameter specifically refers to the block sizes of the Key and Value tensors. The usage on line 71 should be updated accordingly.
Suggested change:
- variable_block_sizes: torch.Tensor,
+ kv_variable_block_sizes: torch.Tensor,
- The Triton backend supports different Q/KV logical lengths via padding.
- The SM90 (H100) CUDA backend currently assumes the same number of blocks
  for Q and KV, so we skip this test there.
This PR updates the Video Sparse Attention ThunderKittens implementation to support scenarios where the Query and Key/Value sequence lengths differ (different number of blocks).
Kernel Logic Validity:
- Forward Pass: No internal kernel modifications were necessary. The existing softmax implementation already applies K-masking, ensuring logical correctness even when the KV length differs from the Q length.
- Backward Pass: No internal kernel modifications were necessary. While the $dK$ computation depends on $Q$ and $dO$ across the sequence, padded regions in the upstream gradient are numerically zero, so the existing accumulation logic remains mathematically correct (see the sketch below).
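To make the backward-pass argument concrete, here is a minimal numerical sketch using plain PyTorch autograd rather than the ThunderKittens kernel; the shapes and the dense attention helper are illustrative assumptions, not the PR's API. It checks that zero upstream gradients in padded query rows leave $dK$ and $dV$ for the real keys unchanged.

```python
import torch

torch.manual_seed(0)
d = 16
q_len, q_pad, kv_len = 8, 4, 12          # Q is padded beyond its logical length; KV is not

Q = torch.randn(q_len + q_pad, d, requires_grad=True)
K = torch.randn(kv_len, d, requires_grad=True)
V = torch.randn(kv_len, d, requires_grad=True)

def dense_attn(q, k, v):
    # Plain dense softmax attention, used only as a stand-in for the kernel.
    p = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return p @ v

# Upstream gradient dO is numerically zero in the padded query rows.
dO = torch.randn(q_len + q_pad, d)
dO[q_len:] = 0.0

dense_attn(Q, K, V).backward(dO)
dK_padded, dV_padded = K.grad.clone(), V.grad.clone()

# Reference: run only the valid query rows.
K.grad = V.grad = None
Q_valid = Q[:q_len].detach().requires_grad_(True)
dense_attn(Q_valid, K, V).backward(dO[:q_len])

# Padded query rows contribute nothing to the K/V gradients.
print(torch.allclose(dK_padded, K.grad, atol=1e-6))  # True
print(torch.allclose(dV_padded, V.grad, atol=1e-6))  # True
```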
Interface Updates:
Updated the CUDA host code API to explicitly handle distinct q_seq_len and kv_seq_len arguments, replacing the previous single seq_len assumption.
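For illustration only, a hypothetical call pattern with differing Q and KV block counts might look like the sketch below; the block size, tensor layout, and wrapper name are assumptions, and the actual entry point and signature are defined by the updated host code in this PR.

```python
import torch

BLOCK = 64                                  # assumed tokens per block
q_num_blocks, kv_num_blocks = 4, 6          # Q and KV block counts no longer need to match
B, H, D = 1, 8, 128

Q = torch.randn(B, H, q_num_blocks * BLOCK, D, device="cuda", dtype=torch.bfloat16)
K = torch.randn(B, H, kv_num_blocks * BLOCK, D, device="cuda", dtype=torch.bfloat16)
V = torch.randn_like(K)

# One mask entry per (query block, key/value block) pair.
block_sparse_mask = torch.rand(B, H, q_num_blocks, kv_num_blocks, device="cuda") < 0.5

# Hypothetical wrapper call; the real name and argument list live in this PR's API.
# out = vsa_block_sparse_attn(Q, K, V, block_sparse_mask,
#                             q_seq_len=q_num_blocks * BLOCK,
#                             kv_seq_len=kv_num_blocks * BLOCK)
```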
Testing:
- Added csrc/attn/tests/test_vsa_forward.py for forward-only verification.
- Updated csrc/attn/tests/test_vsa.py to include backward tests.
- Introduced qkdiff test cases to specifically verify correctness when the Q block count differs from the KV block count.

All tests pass. Compilation and installation steps remain consistent with the existing documentation.
Future TODO:
Make the Triton kernel also support this.