Skip to content

feat(glm_moe_dsa): experimental bucketed sparse-MLA fwd/bwd kernels#2432

Draft
S1ro1 wants to merge 1 commit into
mainfrom
feat/mla-bucketed-bwd
Draft

feat(glm_moe_dsa): experimental bucketed sparse-MLA fwd/bwd kernels#2432
S1ro1 wants to merge 1 commit into
mainfrom
feat/mla-bucketed-bwd

Conversation

@S1ro1
Copy link
Copy Markdown
Collaborator

@S1ro1 S1ro1 commented May 6, 2026

Summary

Adds a bucketed alternative to the single-kernel TileLang sparse-MLA fwd/bwd, gated by PRIME_RL_EXPERIMENTAL_MLA_BWD=1. Default behaviour is unchanged.

The bucketed backend partitions queries by their number of valid KV indices into the buckets (256, 512, 1024, 1536, 2048) and dispatches a size-specialized fwd/bwd kernel per bucket. Pays a per-bucket compile up front, then earns it back when most queries don't need the full topk.

What's in this PR

  • src/prime_rl/trainer/models/kernels/sparse_mla_bucketed.py — new module: sparse_mla_fwd_bucketed + sparse_mla_bwd_bucketed. Reuses preprocess/postprocess from sparse_mla_bwd.py and sparse_mla_fwd_interface from sparse_mla_fwd.py.
  • src/prime_rl/trainer/models/kernels/sparse_mla_bwd.py / sparse_mla_fwd.py — extended kernel APIs (block_size, split_store, use_gemm_v1, gemm-v1 path) needed by the bucketed wrapper.
  • src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py — single-env-var dispatch:
    • PRIME_RL_EXPERIMENTAL_MLA_BWD unset / 0 / false → existing TileLang path (no change).
    • set to 1 / true / yes / on → bucketed path.

Notes

  • Bucketed fwd/bwd currently assume batch_size=1 (asserted at the entry point).
  • Default-disabled, so this PR is functionally a no-op for any existing config that doesn't set the env var.

Adds an alternative sparse-MLA implementation that buckets queries by their
number of valid KV indices (256, 512, 1024, 1536, 2048) and dispatches a
size-specialized fwd/bwd kernel per bucket instead of one max-topk kernel
for everything. Pays a per-bucket compile up front, then earns it back when
most queries don't need the full topk.

Gated by PRIME_RL_EXPERIMENTAL_MLA_BWD=1 — default behaviour is unchanged
and still uses the existing single-kernel TileLang path. The bucketed
backend lives in sparse_mla_bucketed.py and reuses preprocess/postprocess
from sparse_mla_bwd.py + sparse_mla_fwd_interface from sparse_mla_fwd.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant