
[FEATURE SUPPORT] Add Triton decode support with KV-cache APIs#271

Merged
LoserCheems merged 59 commits into main from support-triton-decode on Apr 26, 2026

Conversation

@LoserCheems
Collaborator

Summary

  • Add Triton-based decode support for flash sparse attention across dense, sparse, and gated attention variants.
  • Extend the library with decode paths that support KV cache inputs for autoregressive inference workloads.
  • This branch also standardizes decode naming and parameter handling so the new kernels fit the existing Triton interface more cleanly.

Design

  • The feature adds dedicated Triton decode kernels for dense, sparse, and gated attention instead of overloading the existing forward path.
  • Decode-specific launch configuration and grid selection are handled separately to match single-token / KV-cache inference behavior.
  • Kernel and launch caching were consolidated through shared cache utilities so device-specific configuration can be reused efficiently across decode and existing Triton paths.
  • Alternatives considered:
    • Reusing the forward kernels directly for decode. This would have kept the surface smaller, but it would not model KV-cache decode behavior cleanly and would make launch/config specialization harder.
    • Implementing decode support only for one attention mode first. This was rejected in favor of keeping dense, sparse, and gated interfaces aligned.
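The consolidated cache-utility pattern described above can be sketched in plain Python. The names (`get_device_arch`, `get_launch_config`) and the tile sizes below are illustrative stand-ins for the sketch, not the repository's actual API; the real code would query CUDA device properties instead of a lookup table.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_device_arch(device: str) -> tuple:
    # In real code this would query torch.cuda.get_device_capability(device);
    # here a fake table keeps the sketch self-contained.
    fake_archs = {"cuda:0": (8, 0), "cuda:1": (9, 0)}
    return fake_archs[device]

@lru_cache(maxsize=None)
def get_launch_config(device: str, arch: tuple, is_decode: bool) -> dict:
    # Decode processes a single query token, so it favors small tiles along
    # the query dimension; newer architectures get wider KV tiles.
    block_m = 16 if is_decode else 128
    block_n = 64 if arch >= (8, 0) else 32
    return {"BLOCK_M": block_m, "BLOCK_N": block_n}
```

Because both helpers are memoized on hashable keys, repeated launches on the same device reuse the cached configuration object rather than recomputing it, which is the efficiency benefit the design bullet refers to.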

Changes

  • Add Triton decode kernels for:
    • dense attention
    • sparse attention
    • gated attention
  • Add public KV-cache decode APIs for:
    • flash_dense_attn_with_kvcache_func
    • flash_dense_attn_varlen_with_kvcache_func
    • flash_sparse_attn_with_kvcache_func
    • flash_sparse_attn_varlen_with_kvcache_func
    • flash_gated_attn_with_kvcache_func
    • flash_gated_attn_varlen_with_kvcache_func
  • Export the new KV-cache functions from the package top-level API.
  • Add support for optional preallocated output and LSE buffers in decode functions.
  • Rename forward-combine terminology to decode-combine for consistency with the new execution path.
  • Standardize decode-related parameter naming and simplify decode call signatures, including cleanup of unused parameters in varlen decode flows.
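The optional preallocated-output convention can be illustrated with a toy stand-in. `decode_step` and its doubling "kernel" are hypothetical and only mimic the calling convention: callers in a tight autoregressive loop pass `out` to avoid a per-step allocation, and the function validates or allocates as needed.

```python
def decode_step(query, out=None):
    # Allocate a fresh output only when the caller did not supply one.
    if out is None:
        out = [0.0] * len(query)
    elif len(out) != len(query):
        raise ValueError("preallocated `out` has the wrong shape")
    for i, q in enumerate(query):
        out[i] = q * 2.0  # trivial stand-in for the attention computation
    return out
```

The function returns the caller's buffer unchanged in identity, so loop code can keep reusing the same tensor across decode steps.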

Implementation Notes

  • New Triton decode kernels were introduced for dense, sparse, and gated attention, with separate base and variable-length decode paths.
  • KV-cache support is implemented as dedicated decode entry points rather than as a thin wrapper over the regular forward path.
  • Shared cache utilities were added/refined to cache Triton launchers, launch configs, grid factories, and device-architecture-aware kernel setup.
  • Input validation was extended with decode-specific checks, including validation for optional output tensors.
  • Several internal refactors were included to improve decode readability and consistency:
    • keyword-argument based decode calls
    • decode-specific launch/grid helpers
    • unified naming such as scale_log2 and decode-combine terminology
    • removal of unused decode parameters and simplified tensor-shape handling
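The `scale_log2` naming mentioned above is consistent with a common flash-attention trick: replacing `exp(score * scale)` with `exp2(score * scale_log2)`, where `scale_log2` folds `log2(e)` into the softmax scale, because `exp2` maps to a fast hardware instruction. Whether this exact formulation matches the repository's kernels is an assumption; the sketch below just shows the two forms are numerically equivalent.

```python
import math

LOG2_E = math.log2(math.e)

def softmax_weight_exp(score: float, scale: float) -> float:
    # Natural-exponent form used in textbook softmax.
    return math.exp(score * scale)

def softmax_weight_exp2(score: float, scale: float) -> float:
    # Base-2 form: fold log2(e) into the scale once, then use exp2.
    scale_log2 = scale * LOG2_E
    return 2.0 ** (score * scale_log2)
```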

Tests

  • Added and updated decode-focused tests for:
    • dense base decode
    • dense varlen decode
    • sparse base decode
    • sparse varlen decode
    • gated base decode
    • gated varlen decode
  • Tests cover both normal decode execution and paths using preallocated output buffers.
  • Benchmark coverage was also updated to exercise KV-cache decode variants.
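The "normal execution vs. preallocated buffers" coverage described above can be sketched as a comparison pattern: run the decode function once with fresh outputs and once into caller-owned buffers, then require identical results. `toy_decode` is a stand-in for the real kernel, not the repository's test code.

```python
def toy_decode(q, kv, out=None):
    if out is None:
        out = [0.0] * len(q)
    for i in range(len(q)):
        out[i] = q[i] + sum(kv) / len(kv)  # trivial stand-in for attention
    return out

def check_prealloc_matches_fresh(q, kv):
    fresh = toy_decode(q, kv)
    buf = [0.0] * len(q)
    pre = toy_decode(q, kv, out=buf)
    assert pre is buf     # kernel wrote into the caller's buffer
    assert fresh == pre   # identical numerics on both paths
    return True
```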

Docs

  • No user-facing documentation or example files were changed in this branch.
  • The change is currently covered by code-level API exposure and tests.

Checklist

LoserCheems and others added 30 commits April 22, 2026 16:33
…management and add get_dec_grid function for decode kernel
… architecture retrieval and add get_dec_dense_launch_config for decode dense kernel
… stride handling and output summation

Co-authored-by: Copilot <[email protected]>
LoserCheems and others added 25 commits April 26, 2026 00:30
…evice checks for cumulative sequence lengths and sequences

Co-authored-by: Copilot <[email protected]>
…rs and simplify tensor shapes

Co-authored-by: Copilot <[email protected]>
…ers and simplify tensor shapes

Co-authored-by: Copilot <[email protected]>
…rs and simplify tensor shapes

Co-authored-by: Copilot <[email protected]>
…ers and update tensor shape descriptions

Co-authored-by: Copilot <[email protected]>
Copilot AI review requested due to automatic review settings April 26, 2026 03:47
Contributor

Copilot AI left a comment


Pull request overview

This PR adds Triton-based decode (single-token) attention support with KV-cache inputs across dense/sparse/gated variants, and introduces shared caching utilities for Triton kernel compilation/launch configuration to improve reuse across forward/decode paths.

Changes:

  • Add new Triton decode kernels + public *_with_kvcache_func APIs for dense/sparse/gated (base + varlen KV).
  • Introduce shared Triton caching utilities (compiled kernel cache, launch-config/grid caching) and apply them across kernels.
  • Add decode correctness tests (including preallocated output/LSE buffers) and update decode benchmarks to use KV-cache APIs.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 5 comments.

File Description
tests/test_utils.py Adds decode test helpers calling new KV-cache decode APIs and reference checks.
tests/test_dense_base_decode.py New dense base decode correctness test (optionally preallocated buffers).
tests/test_dense_varlen_decode.py New dense varlen decode correctness test (optionally preallocated buffers).
tests/test_sparse_base_decode.py New sparse base decode correctness test (optionally preallocated buffers).
tests/test_sparse_varlen_decode.py New sparse varlen decode correctness test (optionally preallocated buffers).
tests/test_gated_base_decode.py New gated base decode correctness test (optionally preallocated buffers).
tests/test_gated_varlen_decode.py New gated varlen decode correctness test (optionally preallocated buffers).
tests/benchmark_decode.py Switch decode benchmarks to KV-cache decode APIs and decode-shaped inputs.
flash_sparse_attn/ops/triton/utils.py Removes old get_arch; adds caching to num_splits_heuristic and adjusts heuristic.
flash_sparse_attn/ops/triton/cache_utils.py New shared caching helpers (device arch/SMs, launch-config/grid caching, compiled kernel caching wrapper).
flash_sparse_attn/ops/triton/launch_template.py Refactors launch-config selection to accept (device, arch) and caches results; adds decode + combine launch configs.
flash_sparse_attn/ops/triton/launch_grid.py Adds cached grid factories; adds decode grid + decode-combine grid; adds forward-combine grid.
flash_sparse_attn/ops/triton/assert_inputs.py Extends validation for decode + optional outputs; refactors fwd/bwd validation to accept (device, arch) and supports seqused_*.
flash_sparse_attn/ops/triton/interface.py Exposes new public KV-cache decode APIs for dense/sparse/gated (base + varlen).
flash_sparse_attn/ops/triton/flash_dense_fwd.py Adds caching wrappers and forward-combine usage; refactors to use (device, arch); adds is_split_kv parameter.
flash_sparse_attn/ops/triton/flash_dense_dec.py New dense decode kernels (base + varlen KV) and decode entry points.
flash_sparse_attn/ops/triton/flash_dense_bwd.py Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks.
flash_sparse_attn/ops/triton/flash_sparse_fwd.py Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks.
flash_sparse_attn/ops/triton/flash_sparse_dec.py New sparse decode kernels (base + varlen KV) and decode entry points.
flash_sparse_attn/ops/triton/flash_sparse_bwd.py Adds caching wrappers; refactors to use (device, arch); cache-modifier tweaks.
flash_sparse_attn/ops/triton/flash_gated_fwd.py Adds caching wrappers; refactors to use (device, arch); fixes launch-config selector for gated; cache-modifier tweaks.
flash_sparse_attn/ops/triton/flash_gated_dec.py New gated decode kernels (base + varlen KV) and decode entry points.
flash_sparse_attn/ops/triton/flash_gated_bwd.py Adds caching wrappers; refactors to use (device, arch); fixes launch-config selector for gated; cache-modifier tweaks.
flash_sparse_attn/ops/triton/flash_dec_combine.py Refactors decode-combine kernel shape/launch; adds caching wrappers and (device, arch) launch config; updates stride handling.
flash_sparse_attn/ops/triton/flash_fwd_combine.py New forward split-KV combine kernel and launcher using cached launch/grid utilities.
flash_sparse_attn/ops/triton/flash_bwd_preprocess.py Wraps preprocess Triton kernel with compiled-kernel caching.
flash_sparse_attn/ops/triton/flash_bwd_postprocess.py Wraps postprocess Triton kernel with compiled-kernel caching.
flash_sparse_attn/__init__.py Exports new top-level KV-cache decode APIs.
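The `is_split_kv` parameter noted for flash_dense_fwd.py corresponds to the decode-detection condition visible in the review excerpts further down: a call is treated as split-KV decode exactly when a single query token attends over a longer cached KV sequence. A self-contained sketch of that condition:

```python
def detect_split_kv(seqlen_q: int, seqlen_k: int) -> bool:
    # Single new token (seqlen_q == 1) attending over a longer KV cache:
    # parallelism must come from splitting the KV sequence across SMs.
    return seqlen_q == 1 and seqlen_q != seqlen_k
```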



:return num_sms: Number of streaming multiprocessors.
"""
return torch.cuda.get_device_properties(device).multi_processor_count
Comment on lines +618 to 626
is_split_kv: bool = False,
pack_gqa: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
num_SMs = torch.cuda.get_device_properties(query.device).multi_processor_count
device = query.device
arch = cache_utils.get_device_arch(device)
num_SMs = cache_utils.get_device_num_sms(device)
batch_size, seqlen_q, num_heads_q, head_dim = query.shape
_, seqlen_k, num_heads_kv, _ = key.shape
is_split_kv = seqlen_q == 1 and seqlen_q != seqlen_k
window_size_left, window_size_right = window_size
Comment on lines +774 to 787
is_split_kv: bool = False,
pack_gqa: bool = False,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
num_SMs = torch.cuda.get_device_properties(query.device).multi_processor_count
device = query.device
arch = cache_utils.get_device_arch(device)
num_SMs = cache_utils.get_device_num_sms(device)
total_seqlen_q, num_heads_q, head_dim = query.shape
_, num_heads_kv, _ = key.shape
batch_size = cu_seqlens_q.shape[0] - 1
seqlen_q = max_seqlen_q
seqlen_k = max_seqlen_k
is_split_kv = seqlen_q == 1 and seqlen_q != seqlen_k
window_size_left, window_size_right = window_size
Comment on lines +165 to +166
# inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum)
inv_sum = 1.0 / e_sum
Comment on lines +166 to +167
# inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum)
inv_sum = 1.0 / e_sum
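The excerpts above show a guarded reciprocal left commented out in favor of plain `1.0 / e_sum`. The guard matters when a softmax row is fully masked: the denominator `e_sum` can then be 0.0 (or NaN), and the guarded form returns 0.0 instead of propagating inf/NaN into the output. A pure-Python stand-in for the `tl.where` expression:

```python
def safe_inv_sum(e_sum: float) -> float:
    # e_sum != e_sum is the standard NaN check (NaN compares unequal to itself).
    if e_sum == 0.0 or e_sum != e_sum:
        return 0.0
    return 1.0 / e_sum
```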
LoserCheems merged commit 6b009de into main on Apr 26, 2026