MojoFusedNormRoPESageQuantStore: fused RoPE + KV-Quant with Key per-token Quant operator #358
Open
NASA1473 wants to merge 10 commits into
Pull request overview
Adds a new experimental operator, MojoFusedNormRoPESageQuantStore, intended to fuse QK-RMSNorm + RoPE + static int8 KV quant + paged KV store, with an optional SAGE-style per-token int8 key + per-token scale snapshot (and ixformer fused-kernel support).
Changes:
- Introduce the `MojoFusedNormRoPESageQuantStore` torch reference implementation (norm + RoPE + KV static quant + paged store + optional per-token dynamic key quant).
- Add an ixformer backend implementation using `ixformer.functions` fused kernels (including the SAGE kernel variant).
- Add accuracy and reference-contract tests, and export the operator via `mojo_opset.experimental`.
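To make the difference between the static KV quant and the optional per-token dynamic key quant concrete, here is a minimal pure-Python sketch. It is not the PR's torch or ixformer code: the function names, the symmetric [-127, 127] range, and the abs-max scale rule are all assumptions chosen for illustration.

```python
def quantize_static(row, scale):
    """Static int8 quant: one scale fixed ahead of time (assumed calibrated)."""
    return [max(-127, min(127, round(x / scale))) for x in row]


def quantize_per_token(rows):
    """Dynamic per-token int8 quant: one scale per row, from that row's abs-max.

    Assumption: symmetric range [-127, 127]; the real kernel may differ.
    """
    quantized, scales = [], []
    for row in rows:
        amax = max(abs(x) for x in row)
        scale = amax / 127.0 if amax > 0 else 1.0  # avoid division by zero
        quantized.append([max(-127, min(127, round(x / scale))) for x in row])
        scales.append(scale)
    return quantized, scales


def dequantize(qrow, scale):
    """Recover approximate float values from int8 codes and their scale."""
    return [q * scale for q in qrow]


# Two "tokens" (rows over head_dim) with very different magnitudes.
keys = [[0.5, -1.0, 0.25, 2.0], [0.01, -0.02, 0.005, 0.0]]

per_token_q, per_token_scales = quantize_per_token(keys)

# A single static scale sized for the large token crushes the small one.
static_scale = 2.0 / 127.0
static_small = quantize_static(keys[1], static_scale)
```

With one shared scale the small-magnitude token collapses to a handful of codes, while the per-token path uses the full int8 range for each row. That is the usual motivation for keeping a per-token scale next to the int8 key.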
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| mojo_opset/experimental/operators/fused_norm_rope_sage_quant_store.py | New fused operator (torch reference path) with optional SAGE per-token key quant and optional extra paged-cache stores. |
| mojo_opset/backends/ixformer/operators/fused_norm_rope_sage_quant_store.py | Ixformer backend implementation calling fused ixformer kernels (SWA stream + full stream with optional SAGE). |
| mojo_opset/tests/accuracy/operators/test_fused_norm_rope_sage_quant_store.py | New accuracy/reference tests for output contract, quant math, and determinism. |
| mojo_opset/experimental/operators/\_\_init\_\_.py | Exports the new operator from the experimental operators package. |
| mojo_opset/experimental/\_\_init\_\_.py | Exports the new operator from mojo_opset.experimental. |
Comment on lines +72 to +74

    # SAGE: dynamic per-token int8 quant of the full key (over head_dim).
    if self.enable_sage:
        self.sage_full_k_quantize = MojoDynamicQuant._registry.get(self._backend)(quant_dtype=quant_dtype)

Comment on lines +219 to +225

    if (
        self.enable_sage
        and full_key_pt_int8 is not None
        and cu_q_lens is not None
        and sage_full_k_pt_cache is not None
        and sage_full_k_pt_scale_cache is not None
    ):

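The condition above gates the store of the per-token int8 key and its scale into the two SAGE paged caches. As a rough illustration of what a paged store does, the sketch below writes tokens into block-organised caches through a block table. The helper name `paged_store`, the list-of-lists cache layout, and the block size are hypothetical, not the operator's actual tensor layout.

```python
def paged_store(cache, block_table, start_pos, tokens, block_size):
    """Write `tokens` at logical positions start_pos, start_pos + 1, ...

    cache:       list of physical blocks, each a list of `block_size` slots
    block_table: logical block index -> physical block index (one sequence)
    """
    for offset, token in enumerate(tokens):
        pos = start_pos + offset
        physical_block = block_table[pos // block_size]
        cache[physical_block][pos % block_size] = token


block_size = 4
num_blocks = 3
int8_cache = [[None] * block_size for _ in range(num_blocks)]
scale_cache = [[None] * block_size for _ in range(num_blocks)]

# Logical block 0 lives in physical block 2, logical block 1 in physical block 0.
block_table = [2, 0]

# Three new tokens appended after an existing context of length 3.
new_keys_int8 = [[12, -127], [127, 3], [-5, 127]]
new_key_scales = [0.01, 0.02, 0.03]
context_len = 3

paged_store(int8_cache, block_table, context_len, new_keys_int8, block_size)
paged_store(scale_cache, block_table, context_len, new_key_scales, block_size)
```

The int8 codes and their scales go through the same block table, so a reader that resolves a token position finds both at the same (block, slot) coordinates.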
Comment on lines +28 to +35

    if self.head_dim != 128:
        raise NotImplementedError(
            f"ixformer fused kernel only supports head_dim=128, got {self.head_dim}"
        )
    if not (self.use_query_norm and self.use_key_norm):
        raise NotImplementedError(
            f"ixformer fused kernel only supports use_query_norm and use_key_norm, got {self.use_query_norm} and {self.use_key_norm}"
        )

Comment on lines +180 to +190

    # --- Full stream: all-in-one fused kernel (+ per-token int8 K for SAGE) ---
    if self.enable_sage:
        (full_q_out, full_key_q, full_val_q,
         full_key_pt_int8, full_key_pt_scale) = self._run_sage_stream_update_kv(
            full_query, full_key, full_value, full_wq, full_wk,
            full_ks, full_vs,
            cos, sin, rotary_dim,
            full_key_cache, full_value_cache,
            block_tables, cu_q_lens, context_kv_lens, eps,
            sage_full_k_pt_cache, sage_full_k_pt_scale_cache,
        )

Comment on lines +124 to +132

    @pytest.mark.parametrize("num_heads_swa_q, num_heads_swa_k, num_heads_full_q, num_heads_full_k, head_dim, rope_dim", CONFIGS)
    @pytest.mark.parametrize("batch_size, q_lens_val, context_kv_lens_val", SEQ_CONFIGS)
    @pytest.mark.parametrize("update_kv", [True, False])
    @bypass_not_implemented
    def test_diff_vs_torch_no_sage(
        num_heads_swa_q, num_heads_swa_k, num_heads_full_q, num_heads_full_k, head_dim, rope_dim,
        batch_size, q_lens_val, context_kv_lens_val,
        update_kv,
    ):

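The test above is wrapped in `@bypass_not_implemented`, whose definition is not shown on this page. One plausible reading, given that the ixformer backend raises `NotImplementedError` for unsupported configurations, is a decorator that turns that exception into a skip. The sketch below is an assumption about its behaviour, with a deliberately different name (`bypass_not_implemented_sketch`) and the standard library's `unittest.SkipTest` standing in for whatever the real helper uses.

```python
import functools
import unittest


def bypass_not_implemented_sketch(test_fn):
    """Hypothetical stand-in: skip a test when the backend raises NotImplementedError."""

    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        try:
            return test_fn(*args, **kwargs)
        except NotImplementedError as exc:
            # Report the unsupported configuration as a skip, not a failure.
            raise unittest.SkipTest(f"backend does not support this config: {exc}")

    return wrapper


@bypass_not_implemented_sketch
def check_head_dim(head_dim):
    """Toy check mirroring the head_dim guard quoted in the backend snippet."""
    if head_dim != 128:
        raise NotImplementedError(f"only head_dim=128 is supported, got {head_dim}")
    return "ran"
```

Under this reading, parametrized cases the fused kernel cannot handle are reported as skipped, while every supported case is still compared against the torch reference.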
…to single_kv_store
Adds MojoFusedNormRoPESageQuantStore, fusing QK-Norm, RoPE, int8 K/V quant, and paged store into one op, with optional SAGE per-token int8 key + scale stored inline via the ixformer rms_norm_sage_qk_rotary_embedding kernel.
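For readers unfamiliar with the first two fused stages, the following pure-Python sketch applies RMSNorm and then RoPE to a single head vector. It is an illustration under assumptions, not the operator's code: the half-rotation pairing (dimension `i` with `i + rotary_dim // 2`), the base of 10000, and the function names are all chosen here, and the real kernel receives precomputed `cos`/`sin` tables and may use a different pairing.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: scale x by the reciprocal of its root-mean-square, then by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rope_half_rotate(x, position, rotary_dim, base=10000.0):
    """Rotate the first `rotary_dim` dims; the rest pass through unchanged.

    Assumption: dimension i is paired with i + rotary_dim // 2.
    """
    half = rotary_dim // 2
    out = list(x)
    for i in range(half):
        theta = position * base ** (-2.0 * i / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


head = [3.0, 4.0, 0.0, 0.0, 1.0, 2.0]
normed = rms_norm(head, weight=[1.0] * len(head))
rotated = rope_half_rotate(normed, position=5, rotary_dim=4)
```

Because each pair is rotated by a plain 2D rotation, the vector's length is preserved. The int8 quantization step that follows therefore sees the same magnitude range as the normalized input.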