[FEATURE SUPPORT] Add Triton decode support with KV-cache APIs #271
Merged
LoserCheems merged 59 commits into main on Apr 26, 2026
Conversation
Commits (messages truncated by the page view; all co-authored by Copilot <[email protected]>):

- …clarity
- …rieval
- …s retrieval
- …ton's launcher expectations
- …laration order
- …streamline caching mechanisms
- …kernel wrapping
- …gement
- …management and add get_dec_grid function for decode kernel
- … architecture retrieval and add get_dec_dense_launch_config for decode dense kernel
- … performance and clarity
- …roved device handling
- … stride handling and output summation
- …rameter handling and clarity
- …roved parameter handling and clarity
- …e handling with cache modifiers
- …able length support
- …hitecture parameters
- …ntion kernels
- …rchitecture parameters
- …, and gated attention
- …utput buffer support
- …evice checks for cumulative sequence lengths and sequences
- …rs and simplify tensor shapes
- …ers and update tensor shape descriptions
- …unused parameters
- …eadability
- …nsistency
Contributor
Pull request overview
This PR adds Triton-based decode (single-token) attention support with KV-cache inputs across dense/sparse/gated variants, and introduces shared caching utilities for Triton kernel compilation/launch configuration to improve reuse across forward/decode paths.
Changes:
- Add new Triton decode kernels plus public `*_with_kvcache_func` APIs for dense/sparse/gated (base + varlen KV).
- Introduce shared Triton caching utilities (compiled-kernel cache, launch-config/grid caching) and apply them across kernels.
- Add decode correctness tests (including preallocated output/LSE buffers) and update decode benchmarks to use KV-cache APIs.
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/test_utils.py | Adds decode test helpers calling new KV-cache decode APIs and reference checks. |
| tests/test_dense_base_decode.py | New dense base decode correctness test (optionally preallocated buffers). |
| tests/test_dense_varlen_decode.py | New dense varlen decode correctness test (optionally preallocated buffers). |
| tests/test_sparse_base_decode.py | New sparse base decode correctness test (optionally preallocated buffers). |
| tests/test_sparse_varlen_decode.py | New sparse varlen decode correctness test (optionally preallocated buffers). |
| tests/test_gated_base_decode.py | New gated base decode correctness test (optionally preallocated buffers). |
| tests/test_gated_varlen_decode.py | New gated varlen decode correctness test (optionally preallocated buffers). |
| tests/benchmark_decode.py | Switch decode benchmarks to KV-cache decode APIs and decode-shaped inputs. |
| flash_sparse_attn/ops/triton/utils.py | Removes old get_arch; adds caching to num_splits_heuristic and adjusts heuristic. |
| flash_sparse_attn/ops/triton/cache_utils.py | New shared caching helpers (device arch/SMs, launch-config/grid caching, compiled kernel caching wrapper). |
| flash_sparse_attn/ops/triton/launch_template.py | Refactors launch-config selection to accept (device, arch) and caches results; adds decode + combine launch configs. |
| flash_sparse_attn/ops/triton/launch_grid.py | Adds cached grid factories; adds decode grid + decode-combine grid; adds forward-combine grid. |
| flash_sparse_attn/ops/triton/assert_inputs.py | Extends validation for decode + optional outputs; refactors fwd/bwd validation to accept (device, arch) and supports seqused_*. |
| flash_sparse_attn/ops/triton/interface.py | Exposes new public KV-cache decode APIs for dense/sparse/gated (base + varlen). |
| flash_sparse_attn/ops/triton/flash_dense_fwd.py | Adds caching wrappers and forward-combine usage; refactors to use (device, arch); adds is_split_kv parameter. |
| flash_sparse_attn/ops/triton/flash_dense_dec.py | New dense decode kernels (base + varlen KV) and decode entry points. |
| flash_sparse_attn/ops/triton/flash_dense_bwd.py | Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks. |
| flash_sparse_attn/ops/triton/flash_sparse_fwd.py | Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks. |
| flash_sparse_attn/ops/triton/flash_sparse_dec.py | New sparse decode kernels (base + varlen KV) and decode entry points. |
| flash_sparse_attn/ops/triton/flash_sparse_bwd.py | Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks. |
| flash_sparse_attn/ops/triton/flash_gated_fwd.py | Adds caching wrappers; refactors to use (device, arch); fixes launch-config selector for gated; cache-modifier tweaks. |
| flash_sparse_attn/ops/triton/flash_gated_dec.py | New gated decode kernels (base + varlen KV) and decode entry points. |
| flash_sparse_attn/ops/triton/flash_gated_bwd.py | Adds caching wrappers; refactors to use (device, arch); fixes launch-config selector for gated; cache-modifier tweaks. |
| flash_sparse_attn/ops/triton/flash_dec_combine.py | Refactors decode-combine kernel shape/launch; adds caching wrappers and (device, arch) launch config; updates stride handling. |
| flash_sparse_attn/ops/triton/flash_fwd_combine.py | New forward split-KV combine kernel and launcher using cached launch/grid utilities. |
| flash_sparse_attn/ops/triton/flash_bwd_preprocess.py | Wraps preprocess Triton kernel with compiled-kernel caching. |
| flash_sparse_attn/ops/triton/flash_bwd_postprocess.py | Wraps postprocess Triton kernel with compiled-kernel caching. |
| flash_sparse_attn/__init__.py | Exports new top-level KV-cache decode APIs. |
```python
    :return num_sms: Number of streaming multiprocessors.
    """
    return torch.cuda.get_device_properties(device).multi_processor_count
```
Comment on lines +618 to 626:

```diff
     is_split_kv: bool = False,
     pack_gqa: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, float]:
-    num_SMs = torch.cuda.get_device_properties(query.device).multi_processor_count
+    device = query.device
+    arch = cache_utils.get_device_arch(device)
+    num_SMs = cache_utils.get_device_num_sms(device)
     batch_size, seqlen_q, num_heads_q, head_dim = query.shape
     _, seqlen_k, num_heads_kv, _ = key.shape
+    is_split_kv = seqlen_q == 1 and seqlen_q != seqlen_k
     window_size_left, window_size_right = window_size
```
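The `is_split_kv` condition in the diff above fires exactly in the decode case: one new query token attending to a longer cached KV sequence. A minimal sketch of that check, together with a split-count heuristic of the general kind `num_splits_heuristic` in `utils.py` implements (the exact formula here is an assumption, not the PR's): with few query blocks, splitting the KV axis creates extra parallel work to fill the SMs.

```python
import math

def is_split_kv(seqlen_q: int, seqlen_k: int) -> bool:
    # Decode: one new query token attending to a longer cached KV sequence.
    return seqlen_q == 1 and seqlen_q != seqlen_k

def num_splits_heuristic(total_blocks: int, num_sms: int, max_splits: int = 8) -> int:
    """Split the KV axis only when batch*heads blocks alone underfill the GPU."""
    if total_blocks >= num_sms:
        return 1  # already enough independent blocks to occupy every SM
    return min(max_splits, math.ceil(num_sms / total_blocks))
```

Partial results from the splits are then merged by the combine kernel described in the file table.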
Comment on lines +774 to 787:

```diff
     is_split_kv: bool = False,
     pack_gqa: bool = False,
     seqused_q: Optional[torch.Tensor] = None,
     seqused_k: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, float]:
-    num_SMs = torch.cuda.get_device_properties(query.device).multi_processor_count
+    device = query.device
+    arch = cache_utils.get_device_arch(device)
+    num_SMs = cache_utils.get_device_num_sms(device)
     total_seqlen_q, num_heads_q, head_dim = query.shape
     _, num_heads_kv, _ = key.shape
     batch_size = cu_seqlens_q.shape[0] - 1
     seqlen_q = max_seqlen_q
     seqlen_k = max_seqlen_k
     is_split_kv = seqlen_q == 1 and seqlen_q != seqlen_k
     window_size_left, window_size_right = window_size
```
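The varlen path above packs all sequences into one `[total_seqlen, heads, dim]` tensor and recovers `batch_size` as `cu_seqlens_q.shape[0] - 1`. A small sketch of that layout convention (helper name hypothetical; the PR's kernels index the packed tensor directly rather than materializing slices):

```python
def unpack_varlen(cu_seqlens, packed):
    """Split a packed [total_seqlen, ...] batch back into per-sequence slices.

    cu_seqlens holds cumulative offsets: sequence i occupies
    packed[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    batch_size = len(cu_seqlens) - 1  # mirrors cu_seqlens_q.shape[0] - 1
    return [packed[cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(batch_size)]
```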
Comment on lines +165 to +166:

```python
    # inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum)
    inv_sum = 1.0 / e_sum
```
Comment on lines +166 to +167:

```python
    # inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum)
    inv_sum = 1.0 / e_sum
```
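The commented-out `tl.where` guard in both snippets protects the softmax normalization: if a row is fully masked, the denominator `e_sum` is 0 (or NaN after earlier arithmetic), and the unguarded `1.0 / e_sum` propagates inf/NaN into the output, whereas the guard zeroes such rows. A scalar Python sketch of the same guard (illustrative only; the kernel applies it elementwise with `tl.where`):

```python
def safe_inv(e_sum: float) -> float:
    """Guarded reciprocal of the softmax denominator.

    Mirrors the commented-out Triton guard: zero or NaN denominators
    (fully masked rows) map to 0.0 instead of inf/NaN.
    """
    if e_sum == 0.0 or e_sum != e_sum:  # e_sum != e_sum catches NaN
        return 0.0
    return 1.0 / e_sum
```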
…hanced flexibility
Summary
Design
Changes
Implementation Notes
Tests
Docs
Checklist