feat(glm_moe_dsa): experimental bucketed sparse-MLA fwd/bwd kernels #2432
Draft
S1ro1 wants to merge 1 commit into
Adds an alternative sparse-MLA implementation that buckets queries by their number of valid KV indices (256, 512, 1024, 1536, 2048) and dispatches a size-specialized fwd/bwd kernel per bucket instead of one max-topk kernel for everything. Pays a per-bucket compile up front, then earns it back when most queries don't need the full topk.

Gated by PRIME_RL_EXPERIMENTAL_MLA_BWD=1; default behaviour is unchanged and still uses the existing single-kernel TileLang path.

The bucketed backend lives in sparse_mla_bucketed.py and reuses preprocess/postprocess from sparse_mla_bwd.py + sparse_mla_fwd_interface from sparse_mla_fwd.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
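To make the bucketing rule concrete, here is a minimal sketch in plain Python. It is not the code from `sparse_mla_bucketed.py`: the helper names `assign_bucket` and `group_queries` are hypothetical, and it assumes each query is routed to the smallest bucket that can hold its valid KV count, which the description implies but does not spell out.

```python
from bisect import bisect_left
from collections import defaultdict

# Bucket sizes as listed in the PR description.
BUCKETS = (256, 512, 1024, 1536, 2048)


def assign_bucket(num_valid, buckets=BUCKETS):
    """Return the smallest bucket size that can hold num_valid KV indices.

    Assumption (not confirmed by the PR): queries go to the smallest
    bucket whose size is >= their number of valid KV indices.
    """
    if num_valid < 0 or num_valid > buckets[-1]:
        raise ValueError(f"num_valid={num_valid} outside [0, {buckets[-1]}]")
    # bisect_left finds the first bucket size that is >= num_valid.
    return buckets[bisect_left(buckets, num_valid)]


def group_queries(valid_counts, buckets=BUCKETS):
    """Map bucket size -> positions of the queries assigned to that bucket."""
    groups = defaultdict(list)
    for query_pos, num_valid in enumerate(valid_counts):
        groups[assign_bucket(num_valid, buckets)].append(query_pos)
    return dict(groups)
```

Under this assumption, each group would be handed to the kernel specialized for its bucket size, so a query with 300 valid indices runs through the 512 kernel rather than the 2048 one.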
Summary
Adds a bucketed alternative to the single-kernel TileLang sparse-MLA fwd/bwd, gated by `PRIME_RL_EXPERIMENTAL_MLA_BWD=1`. Default behaviour is unchanged.

The bucketed backend partitions queries by their number of valid KV indices into the buckets `(256, 512, 1024, 1536, 2048)` and dispatches a size-specialized fwd/bwd kernel per bucket. Pays a per-bucket compile up front, then earns it back when most queries don't need the full topk.

What's in this PR
- `src/prime_rl/trainer/models/kernels/sparse_mla_bucketed.py`: new module with `sparse_mla_fwd_bucketed` + `sparse_mla_bwd_bucketed`. Reuses `preprocess` / `postprocess` from `sparse_mla_bwd.py` and `sparse_mla_fwd_interface` from `sparse_mla_fwd.py`.
- `src/prime_rl/trainer/models/kernels/sparse_mla_bwd.py` / `sparse_mla_fwd.py`: extended kernel APIs (`block_size`, `split_store`, `use_gemm_v1`, gemm-v1 path) needed by the bucketed wrapper.
- `src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py`: single-env-var dispatch:
  - `PRIME_RL_EXPERIMENTAL_MLA_BWD` unset / `0` / `false` → existing TileLang path (no change).
  - `1` / `true` / `yes` / `on` → bucketed path.

Notes
- `batch_size=1` (asserted at the entry point).
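For reference, the env-var dispatch listed under `sparse_mla_attention.py` could be parsed roughly as follows. This is a sketch, not the PR's code: the helper name `use_bucketed_mla` is hypothetical, and both the case-insensitive matching and the choice to treat any unlisted value as disabled are assumptions.

```python
import os

# Values the PR description lists as enabling the bucketed path.
_TRUTHY = {"1", "true", "yes", "on"}


def use_bucketed_mla(environ=None):
    """Return True when PRIME_RL_EXPERIMENTAL_MLA_BWD selects the bucketed path.

    Assumptions (not confirmed by the PR): matching ignores case and
    surrounding whitespace, and any value outside _TRUTHY falls back to
    the existing TileLang path.
    """
    env = os.environ if environ is None else environ
    value = env.get("PRIME_RL_EXPERIMENTAL_MLA_BWD", "")
    return value.strip().lower() in _TRUTHY
```

Passing a plain dict as `environ` lets the rule be checked without touching the process environment, for example `use_bucketed_mla({"PRIME_RL_EXPERIMENTAL_MLA_BWD": "on"})`.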