MojoFusedNormRoPESageQuantStore: fused RoPE + KV-Quant with Key per-token Quant operator #358

Open
NASA1473 wants to merge 10 commits into XPU-Forces:dev/m13_ilu from NASA1473:fused_sage_quant

Conversation

@NASA1473

Collaborator

Adds MojoFusedNormRoPESageQuantStore, fusing QK-Norm, RoPE, int8 K/V quant, and paged store into one op, with optional SAGE per-token int8 key + scale stored inline via the ixformer rms_norm_sage_qk_rotary_embedding kernel.

gemini-code-assist[bot]

This comment was marked as low quality.

@NASA1473 changed the title from "Add MojoFusedNormRoPESageQuantStore: fused RoPE + KV-Quant with Key per-token Quant operator" to "MojoFusedNormRoPESageQuantStore: fused RoPE + KV-Quant with Key per-token Quant operator" on Jun 12, 2026
@NASA1473 requested a review from Copilot on June 12, 2026 07:52

Copilot AI left a comment

Pull request overview

Adds a new experimental operator, MojoFusedNormRoPESageQuantStore, intended to fuse QK-RMSNorm + RoPE + static int8 KV quant + paged KV store, with an optional SAGE-style per-token int8 key + per-token scale snapshot (and ixformer fused-kernel support).

Changes:

  • Introduce MojoFusedNormRoPESageQuantStore torch reference implementation (norm + RoPE + KV static quant + paged store + optional per-token dynamic key quant).
  • Add ixformer backend implementation using ixformer.functions fused kernels (including the SAGE kernel variant).
  • Add accuracy and reference-contract tests, and export the operator via mojo_opset.experimental.
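The order of operations in the reference path (norm, then RoPE, then static int8 quant) can be sketched in plain Python for a single key head vector. This is only an illustration, not the PR's implementation: the function names, the half-split RoPE pairing, and the `round(x / scale)` quant convention are all assumptions, and the real operator works on batched tensors.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one head vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def rope_half_rotate(x, cos, sin, rotary_dim):
    """Rotate the first rotary_dim elements, pairing i with i + rotary_dim // 2.

    ASSUMPTION: half-split pairing; the PR may use interleaved pairs instead.
    """
    half = rotary_dim // 2
    out = list(x)  # elements past rotary_dim pass through unchanged
    for i in range(half):
        a, b = x[i], x[i + half]
        out[i] = a * cos[i] - b * sin[i]
        out[i + half] = b * cos[i] + a * sin[i]
    return out

def static_quant_int8(x, scale):
    """ASSUMPTION: q = clamp(round(x / scale), -128, 127) with a fixed scale."""
    return [max(-128, min(127, round(v / scale))) for v in x]

def fused_reference(key, norm_weight, cos, sin, rotary_dim, k_scale, eps=1e-6):
    """Hypothetical single-vector reference: norm -> RoPE -> static int8 quant."""
    normed = rms_norm(key, norm_weight, eps)
    rotated = rope_half_rotate(normed, cos, sin, rotary_dim)
    return rotated, static_quant_int8(rotated, k_scale)
```

A fused kernel would be expected to match a reference like this up to rounding, which is presumably what the accuracy tests in the PR compare.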

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| mojo_opset/experimental/operators/fused_norm_rope_sage_quant_store.py | New fused operator (torch reference path) with optional SAGE per-token key quant and optional extra paged-cache stores. |
| mojo_opset/backends/ixformer/operators/fused_norm_rope_sage_quant_store.py | Ixformer backend implementation calling fused ixformer kernels (SWA stream + full stream with optional SAGE). |
| mojo_opset/tests/accuracy/operators/test_fused_norm_rope_sage_quant_store.py | New accuracy/reference tests for output contract, quant math, and determinism. |
| mojo_opset/experimental/operators/__init__.py | Exports the new operator from the experimental operators package. |
| mojo_opset/experimental/__init__.py | Exports the new operator from mojo_opset.experimental. |
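The "paged-cache store" mentioned above can be pictured as a scatter through a block table. The sketch below is a generic illustration of paged KV storage in plain Python, not the code in this PR: `paged_store`, `block_size`, and `start_pos` are hypothetical names, and the cache layout used by the real operator is not shown in this conversation.

```python
def paged_store(cache, block_table, block_size, start_pos, new_tokens):
    """Write new_tokens at logical positions start_pos, start_pos + 1, ...

    cache:       list of physical blocks, each a list of block_size slots
    block_table: logical block index -> physical block id for one sequence
    ASSUMPTION: start_pos plays the role of the number of tokens already cached.
    """
    for offset, token in enumerate(new_tokens):
        pos = start_pos + offset
        physical_block = block_table[pos // block_size]  # which page holds pos
        cache[physical_block][pos % block_size] = token  # slot inside the page
    return cache

# Usage: a sequence whose logical blocks 0 and 1 live in physical blocks 3 and 1.
cache = [[None, None] for _ in range(4)]
paged_store(cache, block_table=[3, 1], block_size=2, start_pos=1, new_tokens=["k1", "k2"])
```

With SAGE enabled, the per-token int8 key and its scale would be stored through the same kind of lookup into their own caches.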

Comment on lines +72 to +74
    # SAGE: dynamic per-token int8 quant of the full key (over head_dim).
    if self.enable_sage:
        self.sage_full_k_quantize = MojoDynamicQuant._registry.get(self._backend)(quant_dtype=quant_dtype)
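For readers unfamiliar with per-token dynamic quantization, here is a minimal plain-Python sketch of the idea: each token's key vector gets its own scale, computed from that token's values at run time. It assumes symmetric absmax scaling; whether `MojoDynamicQuant` uses exactly this formula is not stated in this conversation, and `dynamic_quant_per_token` is a hypothetical name.

```python
def dynamic_quant_per_token(key_token, qmax=127):
    """Quantize one token's key vector (length head_dim) to int8 with its own scale.

    ASSUMPTION: symmetric absmax scaling, scale = max|x| / qmax.
    Returns (int8 values as Python ints, per-token scale).
    """
    absmax = max(abs(v) for v in key_token)
    scale = absmax / qmax if absmax > 0 else 1.0  # avoid dividing by zero
    q = [max(-qmax, min(qmax, round(v / scale))) for v in key_token]
    return q, scale

def dequant(q, scale):
    """Recover an approximation of the original vector."""
    return [v * scale for v in q]
```

Unlike the static K/V quant, which reuses a precomputed scale, this scale must be stored next to the int8 key so that attention can dequantize it later, which is why the operator keeps both a key cache and a scale cache for SAGE.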
Comment on lines +219 to +225
    if (
        self.enable_sage
        and full_key_pt_int8 is not None
        and cu_q_lens is not None
        and sage_full_k_pt_cache is not None
        and sage_full_k_pt_scale_cache is not None
    ):
Comment on lines +28 to +35
    if self.head_dim != 128:
        raise NotImplementedError(
            f"ixformer fused kernel only supports head_dim=128, got {self.head_dim}"
        )
    if not (self.use_query_norm and self.use_key_norm):
        raise NotImplementedError(
            f"ixformer fused kernel only supports use_query_norm and use_key_norm, got {self.use_query_norm} and {self.use_key_norm}"
        )
Comment on lines +180 to +190
    # --- Full stream: all-in-one fused kernel (+ per-token int8 K for SAGE) ---
    if self.enable_sage:
        (full_q_out, full_key_q, full_val_q,
         full_key_pt_int8, full_key_pt_scale) = self._run_sage_stream_update_kv(
            full_query, full_key, full_value, full_wq, full_wk,
            full_ks, full_vs,
            cos, sin, rotary_dim,
            full_key_cache, full_value_cache,
            block_tables, cu_q_lens, context_kv_lens, eps,
            sage_full_k_pt_cache, sage_full_k_pt_scale_cache,
        )
Comment on lines +124 to +132
    @pytest.mark.parametrize("num_heads_swa_q, num_heads_swa_k, num_heads_full_q, num_heads_full_k, head_dim, rope_dim", CONFIGS)
    @pytest.mark.parametrize("batch_size, q_lens_val, context_kv_lens_val", SEQ_CONFIGS)
    @pytest.mark.parametrize("update_kv", [True, False])
    @bypass_not_implemented
    def test_diff_vs_torch_no_sage(
        num_heads_swa_q, num_heads_swa_k, num_heads_full_q, num_heads_full_k, head_dim, rope_dim,
        batch_size, q_lens_val, context_kv_lens_val,
        update_kv,
    ):
3 participants