From e69da79be28e62877f4c5186041ab520507b67d5 Mon Sep 17 00:00:00 2001 From: heqiushi Date: Fri, 5 Jun 2026 16:28:51 +0800 Subject: [PATCH 01/10] add deepep and test, modify test_moe_quant and test_attention_quant --- mojo_opset/core/__init__.py | 4 + mojo_opset/core/operators/deepep.py | 307 ++++++++++++++++++ mojo_opset/core/operators/moe.py | 1 + .../operators/test_attention_quant.py | 1 + .../tests/accuracy/operators/test_deepep.py | 280 ++++++++++++++++ .../accuracy/operators/test_moe_quant.py | 234 ++++++++++--- 6 files changed, 781 insertions(+), 46 deletions(-) create mode 100644 mojo_opset/core/operators/deepep.py create mode 100644 mojo_opset/tests/accuracy/operators/test_deepep.py diff --git a/mojo_opset/core/__init__.py b/mojo_opset/core/__init__.py index decd01dfa..d2ed8aa16 100644 --- a/mojo_opset/core/__init__.py +++ b/mojo_opset/core/__init__.py @@ -57,6 +57,8 @@ from .operators.quantize import MojoStaticQuant """ moe """ +from .operators.deepep import MojoDeepEPCombine +from .operators.deepep import MojoDeepEPDispatch from .operators.moe import MojoExperts from .operators.moe import MojoMoE from .operators.moe import MojoMoECombine @@ -155,6 +157,8 @@ "MojoMoECombine", "MojoQuantExperts", "MojoQuantMoE", + "MojoDeepEPDispatch", + "MojoDeepEPCombine", "MojoLayerNorm", "MojoRMSNorm", diff --git a/mojo_opset/core/operators/deepep.py b/mojo_opset/core/operators/deepep.py new file mode 100644 index 000000000..c97889dc6 --- /dev/null +++ b/mojo_opset/core/operators/deepep.py @@ -0,0 +1,307 @@ +"""DeepEP-style cross-rank MoE dispatch / combine operators. + +These operators model the all-to-all token-routing protocol used by DeepEP: +the dispatch step routes (token, expert) pairs to their target rank, sorts +by local expert, and (optionally) fuses smooth + per-token int8 quantization; +the combine step reverses the routing and reduces by top-k gates. 
+ +The torch backend runs purely local at ``group_size == 1``; for +``group_size > 1`` it uses ``torch.distributed`` collectives to reconstruct +the global state and slice for the local-expert range, and therefore +requires an initialized process group. The xops backend provides a real +symmetric-memory implementation for production use. +""" + +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +from mojo_opset.core.operator import MojoOperator + + +def _local_dispatch_indices( + top_k_indices: torch.Tensor, + num_experts: int, + top_k: int, +): + """Compute local routing indices from per-rank top_k_indices. + + Returns: + sort_perm: int64 [BS*top_k] argsort of flat top_k_indices, stable. + scatter_index: int32 [BS, top_k] inverse permutation. + expert_token_count: int32 [num_experts] local per-expert token count. + """ + flat = top_k_indices.reshape(-1).to(torch.int64) + sort_perm = flat.argsort(stable=True) + scatter_index = sort_perm.argsort(stable=True).reshape(-1, top_k).to(torch.int32) + expert_token_count = torch.bincount(flat, minlength=num_experts).to(torch.int32) + return sort_perm, scatter_index, expert_token_count + + +class MojoDeepEPDispatch(MojoOperator): + """All-to-all MoE token dispatch with optional fused per-token int8 quantization. + + Init params (kernel binds these on init — must match the paired combine): + - num_experts (int): Total experts across all ranks. Must be divisible by group_size. + - top_k (int): Top-k experts per token. + - group_size (int): Number of ranks (a.k.a. ep_size). Defaults to 1. + - rank (int): Local rank id in [0, group_size). Defaults to 0. + - buffer_size (int): Symmetric-memory scratch in bytes. + + Forward returns a 6-tuple ``(expand_hidden_states, expert_token_cnt_per_rank, + expert_token_cnt_cumsum, expand_scale, scatter_index, expert_token_count)``: + - expand_hidden_states: [R, hidden] — int8 if fused per-token quant ran (smooth_scale provided), + otherwise input dtype. 
R = total (token, expert) pairs landing on this rank, + sorted by local expert id. + - expert_token_cnt_per_rank: [local_experts] int32 — non-cumsum count per local expert. Sums to R. + - expert_token_cnt_cumsum: [local_experts] int64 — end-offsets (cumsum without leading 0). + - expand_scale: [R, 1] float32 — per-token quant scale (one entry per row of expand_hidden_states); uninitialized placeholder when smooth_scale is None. + - scatter_index: [q_len, top_k] int32 — local routing index, consumed by combine. + - expert_token_count: [num_experts] int32 — local per-expert token count, consumed by combine. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + group_size: int = 1, + rank: int = 0, + buffer_size: int = 256 * 1024 * 1024, + **kwargs, + ): + super().__init__(**kwargs) + if num_experts % group_size != 0: + raise ValueError( + f"MojoDeepEPDispatch: num_experts must be divisible by group_size, " + f"got num_experts={num_experts}, group_size={group_size}." + ) + self.num_experts = num_experts + self.top_k = top_k + self.group_size = group_size + self.rank = rank + self.buffer_size = buffer_size + self.local_experts = num_experts // group_size + self.start_expert_id = rank * self.local_experts + + def forward( + self, + hidden_states: torch.Tensor, + top_k_gates: torch.Tensor, + top_k_indices: torch.Tensor, + smooth_scale: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_k = top_k_indices.size(-1) + if top_k != self.top_k: + raise ValueError(f"top_k_indices last dim must be {self.top_k}, got {top_k}.") + + q_len = hidden_states.size(0) + device = hidden_states.device + top_k_indices = top_k_indices.reshape(-1, top_k).to(torch.int32) + + # Local routing — returned as-is for combine to consume. 
+ _, scatter_index, expert_token_count = _local_dispatch_indices( + top_k_indices, self.num_experts, top_k + ) + + # Dispatch is defined on the global token tensor; for group_size>1 we + # all_gather to reconstruct it, then slice for this rank's experts. + if self.group_size == 1: + global_hidden = hidden_states + global_top_k = top_k_indices + else: + global_hidden = torch.empty( + self.group_size * q_len, + hidden_states.size(1), + dtype=hidden_states.dtype, + device=device, + ) + dist.all_gather_into_tensor(global_hidden, hidden_states.contiguous()) + global_top_k = torch.empty( + self.group_size * q_len, + top_k, + dtype=top_k_indices.dtype, + device=device, + ) + dist.all_gather_into_tensor(global_top_k, top_k_indices.contiguous()) + + global_q = global_hidden.size(0) + global_flat = global_top_k.reshape(-1).to(torch.int64) + global_sort_perm = global_flat.argsort(stable=True) + global_token_idx = ( + torch.arange(global_q, device=device, dtype=torch.int64) + .unsqueeze(1) + .expand(global_q, top_k) + .reshape(-1) + ) + pack_index = global_token_idx[global_sort_perm] + expand_global = global_hidden[pack_index] + sorted_experts_global = global_flat[global_sort_perm] + + global_expert_count = torch.bincount(global_flat, minlength=self.num_experts).to(torch.int32) + cum = global_expert_count.to(torch.int64).cumsum(0) + start_e = self.start_expert_id + end_e = start_e + self.local_experts + start_t = 0 if start_e == 0 else int(cum[start_e - 1].item()) + end_t = int(cum[end_e - 1].item()) + + expand_local = expand_global[start_t:end_t] + sorted_experts_local = sorted_experts_global[start_t:end_t] + + if smooth_scale is not None: + # Fused smooth + per-token int8 quant. 
+ smoothed = expand_local.float() * smooth_scale[sorted_experts_local].float() + expand_scale = smoothed.abs().amax(-1, keepdim=True) / 127.0 + x = smoothed / expand_scale + expand_out = torch.clamp( + torch.floor(x.abs() + 0.5) * x.sign(), -128, 127 + ).to(torch.int8) + else: + expand_out = expand_local + expand_scale = torch.empty( + (expand_local.size(0), 1), dtype=torch.float32, device=device + ) + + expert_token_cnt_per_rank = global_expert_count[start_e:end_e] + expert_token_cnt_cumsum = expert_token_cnt_per_rank.to(torch.int64).cumsum(0) + + return ( + expand_out, + expert_token_cnt_per_rank, + expert_token_cnt_cumsum, + expand_scale, + scatter_index, + expert_token_count, + ) + + +class MojoDeepEPCombine(MojoOperator): + """All-to-all MoE expert-output combine — reverse of MojoDeepEPDispatch. + + Gathers per-local-expert outputs from all ranks, weights by top-k gates, and + scatters back to the original [q_len, hidden] layout. + + Init params (must match the paired MojoDeepEPDispatch): + - num_experts, top_k, group_size, rank, buffer_size. + + Forward args: + - expert_outputs: [R, hidden] — local experts' outputs (sorted by local expert id). + - top_k_gates: [q_len, top_k] — gating weights for the top-k reduction. + - scatter_index: [q_len, top_k] int32 — from the paired dispatch. + - expert_token_count: [num_experts] int32 — this rank's local per-expert count from the paired dispatch (summed across ranks inside combine when group_size > 1). + - q_len (int): Original token count — sizes the output buffer. + - output: optional pre-allocated [q_len, hidden] tensor. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + group_size: int = 1, + rank: int = 0, + buffer_size: int = 256 * 1024 * 1024, + **kwargs, + ): + super().__init__(**kwargs) + if num_experts % group_size != 0: + raise ValueError( + f"MojoDeepEPCombine: num_experts must be divisible by group_size, " + f"got num_experts={num_experts}, group_size={group_size}." 
+ ) + self.num_experts = num_experts + self.top_k = top_k + self.group_size = group_size + self.rank = rank + self.buffer_size = buffer_size + self.local_experts = num_experts // group_size + self.start_expert_id = rank * self.local_experts + + def forward( + self, + expert_outputs: torch.Tensor, + top_k_gates: torch.Tensor, + scatter_index: torch.Tensor, + expert_token_count: torch.Tensor, + q_len: int, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + top_k = top_k_gates.size(-1) + if top_k != self.top_k: + raise ValueError(f"top_k_gates last dim must be {self.top_k}, got {top_k}.") + + top_k_gates = top_k_gates.reshape(-1, top_k) + scatter_index = scatter_index.reshape(-1, top_k).to(torch.int64) + device = expert_outputs.device + + if self.group_size == 1: + # All experts are local — scatter_index directly indexes expert_outputs. + gathered = expert_outputs[scatter_index.reshape(-1)].reshape(q_len, top_k, -1) + else: + # Recover local top_k_indices from (scatter_index, expert_token_count): + # sorted_local_experts[p] is the expert id at local sort position p, + # so top_k_indices_local[t, k] = sorted_local_experts[scatter_index[t, k]]. + sorted_local_experts = torch.repeat_interleave( + torch.arange(self.num_experts, device=device, dtype=torch.int64), + expert_token_count.to(torch.int64), + ) + top_k_indices_local = sorted_local_experts[scatter_index.reshape(-1)].reshape( + q_len, top_k + ).to(torch.int32) + + # all_gather top_k_indices across ranks to rebuild the global routing. + global_top_k = torch.empty( + self.group_size * q_len, + top_k, + dtype=top_k_indices_local.dtype, + device=device, + ) + dist.all_gather_into_tensor(global_top_k, top_k_indices_local.contiguous()) + + # Variable-size all_gather of expert_outputs: pad to max R, gather, trim. 
+ global_expert_count = expert_token_count.clone() + dist.all_reduce(global_expert_count, op=dist.ReduceOp.SUM) + cum_global = global_expert_count.to(torch.int64).cumsum(0) + r_per_rank = [] + for r in range(self.group_size): + s = r * self.local_experts + e = s + self.local_experts + start = 0 if s == 0 else int(cum_global[s - 1].item()) + end = int(cum_global[e - 1].item()) + r_per_rank.append(end - start) + max_r = max(r_per_rank) if r_per_rank else 0 + + padded = torch.zeros( + max(max_r, 1), + expert_outputs.size(1), + dtype=expert_outputs.dtype, + device=device, + ) + if expert_outputs.size(0) > 0: + padded[: expert_outputs.size(0)] = expert_outputs + gathered_padded = [torch.zeros_like(padded) for _ in range(self.group_size)] + dist.all_gather(gathered_padded, padded) + global_expand = torch.cat( + [gathered_padded[r][: r_per_rank[r]] for r in range(self.group_size)], + dim=0, + ) + + # Global scatter (inverse of the global stable sort by expert id). + global_flat = global_top_k.reshape(-1).to(torch.int64) + global_sort_perm = global_flat.argsort(stable=True) + global_scatter = global_sort_perm.argsort(stable=True) + + # Slice the global scatter for this rank's (token, k) pairs. 
+ offset = self.rank * q_len * top_k + local_global_pos = global_scatter[offset : offset + q_len * top_k] + gathered = global_expand[local_global_pos].reshape(q_len, top_k, -1) + + combined = (gathered.float() * top_k_gates.float().unsqueeze(-1)).sum(dim=1) + result = combined.to(expert_outputs.dtype) + + if output is not None: + output.copy_(result) + return output + return result diff --git a/mojo_opset/core/operators/moe.py b/mojo_opset/core/operators/moe.py index cd2753ea1..3de245d1c 100644 --- a/mojo_opset/core/operators/moe.py +++ b/mojo_opset/core/operators/moe.py @@ -123,6 +123,7 @@ def __init__( up_weight_dtype=up_weight_dtype, down_quant_group_size=down_quant_group_size, down_weight_dtype=down_weight_dtype, + top_k=self.top_k, **kwargs, ) self.combine = MojoMoECombine._registry.get(self._backend)(multiply_by_gates=True, **kwargs) diff --git a/mojo_opset/tests/accuracy/operators/test_attention_quant.py b/mojo_opset/tests/accuracy/operators/test_attention_quant.py index d800ef1a9..b63679388 100644 --- a/mojo_opset/tests/accuracy/operators/test_attention_quant.py +++ b/mojo_opset/tests/accuracy/operators/test_attention_quant.py @@ -525,6 +525,7 @@ def test_paged_decode_gqa_with_kv_dequant( [ ("ABAB", 4, 255), ("AABB", 4, 1023), + ("ABAB", 4, 2047), ], ) @pytest.mark.parametrize("query_dtype, context_dtype, compute_dtype", diff --git a/mojo_opset/tests/accuracy/operators/test_deepep.py b/mojo_opset/tests/accuracy/operators/test_deepep.py new file mode 100644 index 000000000..d0323c817 --- /dev/null +++ b/mojo_opset/tests/accuracy/operators/test_deepep.py @@ -0,0 +1,280 @@ +"""Accuracy tests for MojoDeepEPDispatch / MojoDeepEPCombine.""" + +import os +import socket +import traceback + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from mojo_opset import MojoDeepEPCombine +from mojo_opset import MojoDeepEPDispatch +from mojo_opset.tests.utils import auto_switch_platform +from mojo_opset.tests.utils import 
bypass_not_implemented +from mojo_opset.utils.platform import get_torch_device + + +# --------------------------------------------------------------------------- +# Shared helpers. +# --------------------------------------------------------------------------- + + +def _make_global_inputs(world_size, num_tokens_sp, hidden, num_experts, top_k, dtype, device): + global_tokens = num_tokens_sp * world_size + if dtype == torch.int8: + hidden_states = torch.randint(-128, 127, (global_tokens, hidden), dtype=torch.int8, device=device) + else: + hidden_states = torch.randn(global_tokens, hidden, dtype=dtype, device=device) + gating = torch.rand(global_tokens, num_experts, dtype=torch.float32, device=device) + top_k_logits, top_k_indices = torch.topk(gating, top_k) + top_k_gates = torch.nn.functional.softmax(top_k_logits, dim=-1) + return hidden_states, top_k_gates, top_k_indices.to(torch.int32) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +def _xops_skip_if_unsupported(num_experts, world_size): + if world_size < 1: + pytest.skip("MOJO_XOPS_TEST_WORLD_SIZE must be >= 1") + if num_experts % world_size != 0: + pytest.skip(f"num_experts={num_experts} must be divisible by world_size={world_size}") + if torch.npu.device_count() < world_size: + pytest.skip(f"Need {world_size} NPU devices, got {torch.npu.device_count()}") + local_experts = num_experts // world_size + from mojo_opset_ext.backends.xpu_ops.operators.moe import is_deep_ep_local_experts_supported + + if world_size > 1 and not is_deep_ep_local_experts_supported(local_experts): + pytest.skip( + f"DeepEPMoe kernels require local_experts==1 or local_experts%8==0, got {local_experts}" + ) + + +def _run_distributed(case_args, world_size, worker): + ctx = mp.get_context("forkserver") + port = _find_free_port() + result_queue = ctx.Queue() + processes = [] + for rank in range(world_size): + process = ctx.Process( + 
target=worker, + args=(rank, world_size, port, result_queue, case_args), + ) + process.start() + processes.append(process) + + for process in processes: + process.join() + + results = [] + while not result_queue.empty(): + results.append(result_queue.get()) + errors = [(rank, error) for rank, error in results if error is not None] + if errors: + message = "\n".join(f"[Rank {rank}]\n{error}" for rank, error in errors) + pytest.fail(f"Distributed DeepEP test failed:\n{message}") + + for rank, process in enumerate(processes): + if process.exitcode != 0: + pytest.fail(f"[Rank {rank}] exited with code {process.exitcode}") + + +# (num_experts, top_k, hidden, num_tokens_sp) — kept moderate to keep CI cost bounded. +deep_ep_cases = [ + (8, 2, 256, 16), + (16, 4, 512, 32), + (64, 8, 1024, 64), + (384, 8, 3072, 64), +] + + +# --------------------------------------------------------------------------- +# Dispatch — single unified test driving forward_diff_with for ops vs torch. +# --------------------------------------------------------------------------- + + +def _dispatch_compare(rank, world_size, port, queue, case_args): + """Ops-vs-torch dispatch comparison. 
Sets up hccl + symmetric memory when + world_size>1; ``queue`` is the multiprocess result queue (``None`` for an + in-process single-rank call).""" + shmem_manager = None + try: + if world_size > 1: + import torch_npu + from mojo_opset.runtime import MojoSymmetricMemoryManager + + torch_npu.npu.set_device(rank) + dist.init_process_group( + backend="hccl", + rank=rank, + world_size=world_size, + init_method=f"tcp://127.0.0.1:{port}", + ) + shmem_manager = MojoSymmetricMemoryManager.get_or_create( + backend="xops", shmem_heap_size_mb=2048 + ) + shmem_manager.get_backend_manager() + + num_tokens_sp, hidden, top_k, num_experts, dtype, use_smooth_scale = case_args + device = get_torch_device() + torch.manual_seed(0) + global_hidden, global_gates, global_indices = _make_global_inputs( + world_size, num_tokens_sp, hidden, num_experts, top_k, dtype, device + ) + smooth_scale = ( + torch.rand(num_experts, hidden, dtype=torch.float32, device=device) + 0.5 + if use_smooth_scale + else None + ) + s = rank * num_tokens_sp + e = s + num_tokens_sp + local_hidden = global_hidden[s:e].contiguous() + local_gates = global_gates[s:e].contiguous() + local_indices = global_indices[s:e].contiguous() + + op = MojoDeepEPDispatch( + num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + ).to(device) + op_ref = MojoDeepEPDispatch._registry.get("torch")( + num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + ).to(device) + + # Return tuple = (expand_hidden_states, expert_token_cnt_per_rank, + # expert_token_cnt_cumsum, expand_scale, scatter_index, expert_token_count). + # Index 3 (expand_scale) is a meaningless placeholder when smooth_scale is None; + # widen its tolerance so the rest of the tuple still gates the comparison. 
+ scale_tol = 1e-5 if use_smooth_scale else float("inf") + op.forward_diff_with( + op_ref, + local_hidden, + local_gates, + local_indices, + smooth_scale=smooth_scale, + atol=(0, 0, 0, scale_tol, 0, 0), + rtol=(0, 0, 0, scale_tol, 0, 0), + ) + + if queue is not None: + queue.put((rank, None)) + except Exception: + if queue is not None: + queue.put((rank, traceback.format_exc())) + else: + raise + finally: + if shmem_manager is not None: + shmem_manager.close() + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.mark.parametrize("num_experts, top_k, hidden, num_tokens_sp", deep_ep_cases) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.int8]) +@pytest.mark.parametrize("use_smooth_scale", [False, True]) +@pytest.mark.parametrize("world_size", [2, 4, 8]) +@auto_switch_platform() +@bypass_not_implemented +def test_deep_ep_dispatch(world_size, num_experts, top_k, hidden, num_tokens_sp, dtype, use_smooth_scale): + """Compare active backend's dispatch with torch backend via forward_diff_with.""" + if dtype == torch.int8 and use_smooth_scale: + pytest.skip("int8 input + per_token quant is not supported by the kernel.") + if os.environ.get("MOJO_BACKEND", "").strip().lower() != "xops": + pytest.skip("ops-vs-torch comparison requires MOJO_BACKEND=xops") + + _xops_skip_if_unsupported(num_experts, world_size) + case_args = (num_tokens_sp, hidden, top_k, num_experts, dtype, use_smooth_scale) + _run_distributed(case_args, world_size, _dispatch_compare) + + +# --------------------------------------------------------------------------- +# Combine — single unified test driving forward_diff_with for ops vs torch. +# --------------------------------------------------------------------------- + + +def _combine_compare(rank, world_size, port, queue, case_args): + """Ops-vs-torch combine comparison. 
Sets up hccl + symmetric memory when + world_size>1; ``queue`` is the multiprocess result queue (``None`` for an + in-process single-rank call).""" + shmem_manager = None + try: + if world_size > 1: + import torch_npu + from mojo_opset.runtime import MojoSymmetricMemoryManager + + torch_npu.npu.set_device(rank) + dist.init_process_group( + backend="hccl", + rank=rank, + world_size=world_size, + init_method=f"tcp://127.0.0.1:{port}", + ) + shmem_manager = MojoSymmetricMemoryManager.get_or_create( + backend="xops", shmem_heap_size_mb=2048 + ) + shmem_manager.get_backend_manager() + + num_tokens_sp, hidden, top_k, num_experts, dtype = case_args + device = get_torch_device() + torch.manual_seed(0) + global_hidden, global_gates, global_indices = _make_global_inputs( + world_size, num_tokens_sp, hidden, num_experts, top_k, dtype, device + ) + s = rank * num_tokens_sp + e = s + num_tokens_sp + local_hidden = global_hidden[s:e].contiguous() + local_gates = global_gates[s:e].contiguous() + local_indices = global_indices[s:e].contiguous() + + # Build deterministic combine inputs by running torch dispatch. 
+ dispatch_op = MojoDeepEPDispatch._registry.get("torch")( + num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + ).to(device) + expand, _, _, _, scatter_index, expert_token_count = dispatch_op( + local_hidden, local_gates, local_indices + ) + + op = MojoDeepEPCombine( + num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + ).to(device) + op_ref = MojoDeepEPCombine._registry.get("torch")( + num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + ).to(device) + + op.forward_diff_with( + op_ref, + expand, local_gates, scatter_index, expert_token_count, num_tokens_sp, + atol=2**-6, rtol=2**-6, mixed_tol=True, + ) + + if queue is not None: + queue.put((rank, None)) + except Exception: + if queue is not None: + queue.put((rank, traceback.format_exc())) + else: + raise + finally: + if shmem_manager is not None: + shmem_manager.close() + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.mark.parametrize("num_experts, top_k, hidden, num_tokens_sp", deep_ep_cases) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_size", [2, 4, 8]) +@auto_switch_platform() +@bypass_not_implemented +def test_deep_ep_combine(world_size, num_experts, top_k, hidden, num_tokens_sp, dtype): + """Compare active backend's combine with torch backend via forward_diff_with.""" + if os.environ.get("MOJO_BACKEND", "").strip().lower() != "xops": + pytest.skip("ops-vs-torch comparison requires MOJO_BACKEND=xops") + + _xops_skip_if_unsupported(num_experts, world_size) + case_args = (num_tokens_sp, hidden, top_k, num_experts, dtype) + _run_distributed(case_args, world_size, _combine_compare) diff --git a/mojo_opset/tests/accuracy/operators/test_moe_quant.py b/mojo_opset/tests/accuracy/operators/test_moe_quant.py index 63e34e0ec..3b43ce002 100644 --- a/mojo_opset/tests/accuracy/operators/test_moe_quant.py +++ b/mojo_opset/tests/accuracy/operators/test_moe_quant.py @@ -1,8 +1,13 @@ from typing 
import Union +import math import os +import socket +import traceback import pytest import torch +import torch.distributed as dist +import torch.multiprocessing as mp from mojo_opset import MojoQuantExperts from mojo_opset import MojoQuantMoE @@ -135,16 +140,171 @@ def _make_quant_weights( down_quant_group_size: int, down_weight_dtype: Union[torch.dtype, str], ): - up_weight_fp = torch.randn(num_experts, intermediate_size * 2, hidden_size, dtype=torch.float32) * 0.01 - down_weight_fp = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.float32) * 0.01 + up_weight_fp = torch.randn(num_experts, intermediate_size * 2, hidden_size, dtype=torch.float32) * ( + 1.0 / math.sqrt(hidden_size) + ) + down_weight_fp = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.float32) * ( + 1.0 / math.sqrt(intermediate_size) + ) up_weight, up_weight_scale = _quantize_weight_per_group(up_weight_fp, up_quant_group_size, up_weight_dtype) down_weight, down_weight_scale = _quantize_weight_per_group(down_weight_fp, down_quant_group_size, down_weight_dtype) return up_weight, up_weight_scale.bfloat16(), down_weight, down_weight_scale.bfloat16() +def _find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +def _run_quant_moe_backend_case( + num_experts, + top_k, + hidden_size, + intermediate_size, + num_tokens, + up_weight_dtype, + up_quant_group_size, + down_weight_dtype, + down_quant_group_size, + dtype, + *, + rank=0, + world_size=1, +): + device = get_torch_device() + is_xops_backend = os.environ.get("MOJO_BACKEND", "").strip().lower() == "xops" + if is_xops_backend: + torch.manual_seed(0) + # Real EP scenario: every rank generates the SAME global input via identical + # rng sequence, then takes its own slice. The torch reference op consumes the + # full global input; each rank compares its xops output against the matching + # slice of the torch ref output. 
+ global_tokens = num_tokens * world_size + global_hidden_states = torch.randn(global_tokens, hidden_size, dtype=dtype, device=device) + local_hidden_states = global_hidden_states[ + rank * num_tokens : (rank + 1) * num_tokens + ].contiguous() + gate_weight = torch.randn(hidden_size, num_experts, dtype=torch.float32, device=device) * 0.2 + fc1_input_smooth_scale = torch.rand(num_experts, hidden_size, dtype=torch.float32, device=device) + 0.5 + fc2_input_smooth_scale = torch.rand(num_experts, intermediate_size, dtype=torch.float32, device=device) + 0.5 + up_weight, up_weight_scale, down_weight, down_weight_scale = _make_quant_weights( + num_experts, + hidden_size, + intermediate_size, + up_quant_group_size, + up_weight_dtype, + down_quant_group_size, + down_weight_dtype, + ) + + state_dict = { + "gating.gate_weight": gate_weight, + "experts.up_proj_weight": up_weight.to(device), + "experts.down_proj_weight": down_weight.to(device), + "experts.up_proj_weight_scale": up_weight_scale.to(device), + "experts.down_proj_weight_scale": down_weight_scale.to(device), + "experts.up_proj_quantize.inv_smooth_scale": (1.0 / fc1_input_smooth_scale).to(device), + "experts.down_proj_quantize.inv_smooth_scale": (1.0 / fc2_input_smooth_scale).to(device), + } + ref_state_dict = {k: v.clone() for k, v in state_dict.items()} + + op = MojoQuantMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + quant_dtype=torch.int8, + up_quant_group_size=up_quant_group_size, + up_weight_dtype=up_weight_dtype, + down_quant_group_size=down_quant_group_size, + down_weight_dtype=down_weight_dtype, + **({"group_size": world_size, "rank": rank} if is_xops_backend else {}), + ).to(device) + op_ref = MojoQuantMoE._registry.get("torch")( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + quant_dtype=torch.int8, + up_quant_group_size=up_quant_group_size, + up_weight_dtype=up_weight_dtype, + 
down_quant_group_size=down_quant_group_size, + down_weight_dtype=down_weight_dtype, + ).to(device) + op.load_state_dict(state_dict) + if is_xops_backend: + op._dispatch_up_proj_inv_smooth_scale = op.experts.up_proj_quantize.inv_smooth_scale + op.experts.prepare_runtime_weight_layout() + op_ref.load_state_dict(ref_state_dict) + + from mojo_opset.utils.acc import check_tol_diff + + actual = op(local_hidden_states.clone()) + expected_full = op_ref(global_hidden_states.clone()) + expected_local = expected_full[rank * num_tokens : (rank + 1) * num_tokens] + if is_xops_backend: + check_tol_diff(actual, expected_local, atol=6e-2, rtol=2**-6, ptol=0.97, mixed_tol=False) + else: + check_tol_diff(actual, expected_local, atol=1e-2, rtol=1e-2, ptol=1.0, mixed_tol=True) + + +def _quant_moe_backend_worker(rank, world_size, port, result_queue, case_args): + shmem_manager = None + try: + import torch_npu + from mojo_opset.runtime import MojoSymmetricMemoryManager + + torch_npu.npu.set_device(rank) + init_method = f"tcp://127.0.0.1:{port}" + dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method) + shmem_manager = MojoSymmetricMemoryManager.get_or_create(backend="xops", shmem_heap_size_mb=2048) + shmem_manager.get_backend_manager() + _run_quant_moe_backend_case(*case_args, rank=rank, world_size=world_size) + result_queue.put((rank, None)) + except Exception: + result_queue.put((rank, traceback.format_exc())) + finally: + if shmem_manager is not None: + shmem_manager.close() + if dist.is_initialized(): + dist.destroy_process_group() + + +def _run_quant_moe_backend_distributed(case_args, world_size): + ctx = mp.get_context("forkserver") + port = _find_free_port() + result_queue = ctx.Queue() + processes = [] + for rank in range(world_size): + process = ctx.Process( + target=_quant_moe_backend_worker, + args=(rank, world_size, port, result_queue, case_args), + ) + process.start() + processes.append(process) + + for process in processes: + 
process.join() + + results = [] + while not result_queue.empty(): + results.append(result_queue.get()) + errors = [(rank, error) for rank, error in results if error is not None] + if errors: + message = "\n".join(f"[Rank {rank}]\n{error}" for rank, error in errors) + pytest.fail(f"Distributed quant MoE backend test failed:\n{message}") + + for rank, process in enumerate(processes): + if process.exitcode != 0: + pytest.fail(f"[Rank {rank}] exited with code {process.exitcode}") + + quant_moe_backend_cases = [ (16, 2, 512, 1280, 33), (24, 4, 512, 1280, 97), + (64, 8, 512, 1280, 64), + (128, 8, 512, 1280, 128), ] @@ -484,54 +644,36 @@ def test_quant_moe_backend( down_quant_group_size, dtype, ): - device = get_torch_device() - - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) - gate_weight = torch.randn(hidden_size, num_experts, dtype=torch.float32, device=device) * 0.2 - fc1_input_smooth_scale = torch.rand(num_experts, hidden_size, dtype=torch.float32, device=device) + 0.5 - fc2_input_smooth_scale = torch.rand(num_experts, intermediate_size, dtype=torch.float32, device=device) + 0.5 - up_weight, up_weight_scale, down_weight, down_weight_scale = _make_quant_weights( + case_args = ( num_experts, + top_k, hidden_size, intermediate_size, - up_quant_group_size, + num_tokens, up_weight_dtype, - down_quant_group_size, + up_quant_group_size, down_weight_dtype, + down_quant_group_size, + dtype, ) - state_dict = { - "gating.gate_weight": gate_weight, - "experts.up_proj_weight": up_weight.to(device), - "experts.down_proj_weight": down_weight.to(device), - "experts.up_proj_weight_scale": up_weight_scale.to(device), - "experts.down_proj_weight_scale": down_weight_scale.to(device), - "experts.up_proj_quantize.inv_smooth_scale": (1.0 / fc1_input_smooth_scale).to(device), - "experts.down_proj_quantize.inv_smooth_scale": (1.0 / fc2_input_smooth_scale).to(device), - } - - op = MojoQuantMoE( - num_experts=num_experts, - top_k=top_k, - 
hidden_size=hidden_size, - intermediate_size=intermediate_size, - quant_dtype=torch.int8, - up_quant_group_size=up_quant_group_size, - up_weight_dtype=up_weight_dtype, - down_quant_group_size=down_quant_group_size, - down_weight_dtype=down_weight_dtype, - ).to(device) - op_ref = MojoQuantMoE._registry.get("torch")( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - quant_dtype=torch.int8, - up_quant_group_size=up_quant_group_size, - up_weight_dtype=up_weight_dtype, - down_quant_group_size=down_quant_group_size, - down_weight_dtype=down_weight_dtype, - ).to(device) - op.load_state_dict(state_dict) - op_ref.load_state_dict({k: v.clone() for k, v in state_dict.items()}) - op.forward_diff_with(op_ref, hidden_states, mixed_tol=True) + if os.environ.get("MOJO_BACKEND", "").strip().lower() == "xops": + world_size = int(os.environ.get("MOJO_XOPS_TEST_WORLD_SIZE", "8")) + if dtype != torch.bfloat16: + pytest.skip("XopsQuantExperts W4A8 path currently returns bfloat16 output only.") + if world_size < 1: + pytest.skip("MOJO_XOPS_TEST_WORLD_SIZE must be >= 1") + if num_experts % world_size != 0: + pytest.skip(f"num_experts={num_experts} must be divisible by {world_size=}") + if torch.npu.device_count() < world_size: + pytest.skip(f"Need {world_size} NPU devices, got {torch.npu.device_count()}") + local_experts = num_experts // world_size + from mojo_opset_ext.backends.xpu_ops.operators.moe import is_deep_ep_local_experts_supported + # 非整除8时精度不对,在这里skip。分布式用NotImplementedError没法正确skip + if world_size > 1 and not is_deep_ep_local_experts_supported(local_experts): + pytest.skip( + f"DeepEPMoe kernels require local_experts==1 or local_experts%8==0, got {local_experts}" + ) + _run_quant_moe_backend_distributed(case_args, world_size) + else: + _run_quant_moe_backend_case(*case_args) From 9a528472b09a297ea54f78e89d96582d16083d25 Mon Sep 17 00:00:00 2001 From: heqiushi Date: Fri, 5 Jun 2026 18:05:26 +0800 Subject: [PATCH 
02/10] update test_deepep.py --- .../tests/accuracy/operators/test_deepep.py | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/mojo_opset/tests/accuracy/operators/test_deepep.py b/mojo_opset/tests/accuracy/operators/test_deepep.py index d0323c817..1963d8fc7 100644 --- a/mojo_opset/tests/accuracy/operators/test_deepep.py +++ b/mojo_opset/tests/accuracy/operators/test_deepep.py @@ -13,6 +13,7 @@ from mojo_opset import MojoDeepEPDispatch from mojo_opset.tests.utils import auto_switch_platform from mojo_opset.tests.utils import bypass_not_implemented +from mojo_opset.utils.acc import check_tol_diff from mojo_opset.utils.platform import get_torch_device @@ -139,25 +140,32 @@ def _dispatch_compare(rank, world_size, port, queue, case_args): op = MojoDeepEPDispatch( num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, + quant_mode="per_token" if use_smooth_scale else "none", ).to(device) op_ref = MojoDeepEPDispatch._registry.get("torch")( num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, ).to(device) - # Return tuple = (expand_hidden_states, expert_token_cnt_per_rank, - # expert_token_cnt_cumsum, expand_scale, scatter_index, expert_token_count). - # Index 3 (expand_scale) is a meaningless placeholder when smooth_scale is None; - # widen its tolerance so the rest of the tuple still gates the comparison. - scale_tol = 1e-5 if use_smooth_scale else float("inf") - op.forward_diff_with( - op_ref, - local_hidden, - local_gates, - local_indices, - smooth_scale=smooth_scale, - atol=(0, 0, 0, scale_tol, 0, 0), - rtol=(0, 0, 0, scale_tol, 0, 0), - ) + # xops dispatch returns an upper-bound buffer (q_len*group_size*top_k rows); + # only the first R rows are valid where R = sum(expert_token_cnt_per_rank). + # The torch reference returns exactly R rows. Trim before comparing index 0 / 3. 
+ out = op.forward(local_hidden, local_gates, local_indices, smooth_scale=smooth_scale) + ref = op_ref.forward(local_hidden, local_gates, local_indices, smooth_scale=smooth_scale) + + r_actual = int(out[1].sum().item()) + r_ref = ref[0].size(0) + assert r_actual == r_ref, f"R mismatch: xops={r_actual}, ref={r_ref}" + + # idx 0: int8 quant rounding can differ by ±1 between hw kernel and torch ref. + hidden_atol = 1 if use_smooth_scale else 0 + torch.testing.assert_close(out[0][:r_actual], ref[0], atol=hidden_atol, rtol=0) + torch.testing.assert_close(out[1], ref[1], atol=0, rtol=0) + # xops cumsum dtype is int32 (wrapper init), torch ref is int64 — values must agree. + torch.testing.assert_close(out[2].to(torch.int64), ref[2], atol=0, rtol=0) + if use_smooth_scale: + torch.testing.assert_close(out[3][:r_actual], ref[3], atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out[4], ref[4], atol=0, rtol=0) + torch.testing.assert_close(out[5], ref[5], atol=0, rtol=0) if queue is not None: queue.put((rank, None)) @@ -245,11 +253,24 @@ def _combine_compare(rank, world_size, port, queue, case_args): num_experts=num_experts, top_k=top_k, group_size=world_size, rank=rank, ).to(device) - op.forward_diff_with( - op_ref, + # xops combine asserts expand.size(0) >= q_len * top_k (deep_ep.cpp:468). + # Torch dispatch returns R-sized expand; pad to the required upper bound for + # xops, but keep the R-sized version for torch combine which expects it. 
+ upper = num_tokens_sp * world_size * top_k + if expand.size(0) < upper: + padded = torch.zeros(upper, expand.size(1), dtype=expand.dtype, device=device) + padded[: expand.size(0)] = expand + expand_for_xops = padded + else: + expand_for_xops = expand + + xops_out = op.forward( + expand_for_xops, local_gates, scatter_index, expert_token_count, num_tokens_sp, + ) + ref_out = op_ref.forward( expand, local_gates, scatter_index, expert_token_count, num_tokens_sp, - atol=2**-6, rtol=2**-6, mixed_tol=True, ) + check_tol_diff(xops_out, ref_out, mixed_tol=True) if queue is not None: queue.put((rank, None)) From 8121473c502bd2a4af3abe2bcc5e259b29f80271 Mon Sep 17 00:00:00 2001 From: heqiushi Date: Thu, 11 Jun 2026 14:38:47 +0800 Subject: [PATCH 03/10] =?UTF-8?q?fix:=20=E5=85=A5=E5=8F=82=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0pg,=20deepep.py=E4=BB=8Ecore=E7=A7=BB=E5=88=B0experime?= =?UTF-8?q?ntal?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mojo_opset/core/__init__.py | 4 ---- mojo_opset/core/operators/moe.py | 3 +++ mojo_opset/experimental/__init__.py | 4 ++++ .../operators/deepep.py | 19 +++++++++++++------ .../tests/accuracy/operators/test_deepep.py | 4 ++-- 5 files changed, 22 insertions(+), 12 deletions(-) rename mojo_opset/{core => experimental}/operators/deepep.py (94%) diff --git a/mojo_opset/core/__init__.py b/mojo_opset/core/__init__.py index d2ed8aa16..decd01dfa 100644 --- a/mojo_opset/core/__init__.py +++ b/mojo_opset/core/__init__.py @@ -57,8 +57,6 @@ from .operators.quantize import MojoStaticQuant """ moe """ -from .operators.deepep import MojoDeepEPCombine -from .operators.deepep import MojoDeepEPDispatch from .operators.moe import MojoExperts from .operators.moe import MojoMoE from .operators.moe import MojoMoECombine @@ -157,8 +155,6 @@ "MojoMoECombine", "MojoQuantExperts", "MojoQuantMoE", - "MojoDeepEPDispatch", - "MojoDeepEPCombine", "MojoLayerNorm", "MojoRMSNorm", diff --git a/mojo_opset/core/operators/moe.py 
b/mojo_opset/core/operators/moe.py index 3de245d1c..9281d3d92 100644 --- a/mojo_opset/core/operators/moe.py +++ b/mojo_opset/core/operators/moe.py @@ -1,6 +1,7 @@ from typing import Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -80,6 +81,7 @@ def __init__( up_weight_dtype: Union[torch.dtype, str] = torch.int8, down_quant_group_size: int = -1, down_weight_dtype: Union[torch.dtype, str] = torch.int8, + process_group: Optional[dist.ProcessGroup] = None, **kwargs, ): super().__init__() @@ -105,6 +107,7 @@ def __init__( self.up_weight_dtype = up_weight_dtype self.down_quant_group_size = down_quant_group_size self.down_weight_dtype = down_weight_dtype + self.process_group = process_group self.gating = MojoMoEGating._registry.get(self._backend)( hidden_size=self.hidden_size, diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index 2434ed2f7..3af3f88c4 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -20,6 +20,8 @@ from .operators.attention import MojoPrefillNSA from .operators.attention_gate import MojoFusedAttnOutputGate from .operators.attention import MojoPagedPrefillSageGQA +from .operators.deepep import MojoDeepEPCombine +from .operators.deepep import MojoDeepEPDispatch from .operators.gemm import MojoQuantBatchGemmReduceSum from .operators.indexer import MojoIndexer from .operators.indexer import MojoLightningIndexer @@ -61,4 +63,6 @@ "MojoGridRoPE", "MojoStoreLowrank", "MojoIndexer", + "MojoDeepEPDispatch", + "MojoDeepEPCombine", ] diff --git a/mojo_opset/core/operators/deepep.py b/mojo_opset/experimental/operators/deepep.py similarity index 94% rename from mojo_opset/core/operators/deepep.py rename to mojo_opset/experimental/operators/deepep.py index c97889dc6..7fae7b523 100644 --- a/mojo_opset/core/operators/deepep.py +++ b/mojo_opset/experimental/operators/deepep.py @@ -48,6 +48,9 @@ class 
MojoDeepEPDispatch(MojoOperator): - group_size (int): Number of ranks (a.k.a. ep_size). Defaults to 1. - rank (int): Local rank id in [0, group_size). Defaults to 0. - buffer_size (int): Symmetric-memory scratch in bytes. + - process_group (dist.ProcessGroup, optional): EP communication group used by all + collectives. When None, collectives run on the default group (dist.group.WORLD). + Caller is responsible for passing group_size/rank consistent with the group. Forward returns a 6-tuple ``(expand_hidden_states, expert_token_cnt_per_rank, expert_token_cnt_cumsum, expand_scale, scatter_index, expert_token_count)``: @@ -68,6 +71,7 @@ def __init__( group_size: int = 1, rank: int = 0, buffer_size: int = 256 * 1024 * 1024, + process_group: Optional[dist.ProcessGroup] = None, **kwargs, ): super().__init__(**kwargs) @@ -76,6 +80,7 @@ def __init__( f"MojoDeepEPDispatch: num_experts must be divisible by group_size, " f"got num_experts={num_experts}, group_size={group_size}." ) + self.process_group = process_group self.num_experts = num_experts self.top_k = top_k self.group_size = group_size @@ -119,14 +124,14 @@ def forward( dtype=hidden_states.dtype, device=device, ) - dist.all_gather_into_tensor(global_hidden, hidden_states.contiguous()) + dist.all_gather_into_tensor(global_hidden, hidden_states.contiguous(), group=self.process_group) global_top_k = torch.empty( self.group_size * q_len, top_k, dtype=top_k_indices.dtype, device=device, ) - dist.all_gather_into_tensor(global_top_k, top_k_indices.contiguous()) + dist.all_gather_into_tensor(global_top_k, top_k_indices.contiguous(), group=self.process_group) global_q = global_hidden.size(0) global_flat = global_top_k.reshape(-1).to(torch.int64) @@ -185,7 +190,7 @@ class MojoDeepEPCombine(MojoOperator): scatters back to the original [q_len, hidden] layout. Init params (must match the paired MojoDeepEPDispatch): - - num_experts, top_k, group_size, rank, buffer_size. 
+ - num_experts, top_k, group_size, rank, buffer_size, process_group. Forward args: - expert_outputs: [R, hidden] — local experts' outputs (sorted by local expert id). @@ -203,6 +208,7 @@ def __init__( group_size: int = 1, rank: int = 0, buffer_size: int = 256 * 1024 * 1024, + process_group: Optional[dist.ProcessGroup] = None, **kwargs, ): super().__init__(**kwargs) @@ -211,6 +217,7 @@ def __init__( f"MojoDeepEPCombine: num_experts must be divisible by group_size, " f"got num_experts={num_experts}, group_size={group_size}." ) + self.process_group = process_group self.num_experts = num_experts self.top_k = top_k self.group_size = group_size @@ -258,11 +265,11 @@ def forward( dtype=top_k_indices_local.dtype, device=device, ) - dist.all_gather_into_tensor(global_top_k, top_k_indices_local.contiguous()) + dist.all_gather_into_tensor(global_top_k, top_k_indices_local.contiguous(), group=self.process_group) # Variable-size all_gather of expert_outputs: pad to max R, gather, trim. global_expert_count = expert_token_count.clone() - dist.all_reduce(global_expert_count, op=dist.ReduceOp.SUM) + dist.all_reduce(global_expert_count, op=dist.ReduceOp.SUM, group=self.process_group) cum_global = global_expert_count.to(torch.int64).cumsum(0) r_per_rank = [] for r in range(self.group_size): @@ -282,7 +289,7 @@ def forward( if expert_outputs.size(0) > 0: padded[: expert_outputs.size(0)] = expert_outputs gathered_padded = [torch.zeros_like(padded) for _ in range(self.group_size)] - dist.all_gather(gathered_padded, padded) + dist.all_gather(gathered_padded, padded, group=self.process_group) global_expand = torch.cat( [gathered_padded[r][: r_per_rank[r]] for r in range(self.group_size)], dim=0, diff --git a/mojo_opset/tests/accuracy/operators/test_deepep.py b/mojo_opset/tests/accuracy/operators/test_deepep.py index 1963d8fc7..ff25d3179 100644 --- a/mojo_opset/tests/accuracy/operators/test_deepep.py +++ b/mojo_opset/tests/accuracy/operators/test_deepep.py @@ -9,8 +9,8 @@ import 
torch.distributed as dist import torch.multiprocessing as mp -from mojo_opset import MojoDeepEPCombine -from mojo_opset import MojoDeepEPDispatch +from mojo_opset.experimental import MojoDeepEPCombine +from mojo_opset.experimental import MojoDeepEPDispatch from mojo_opset.tests.utils import auto_switch_platform from mojo_opset.tests.utils import bypass_not_implemented from mojo_opset.utils.acc import check_tol_diff From 659292c73af2308f650d353e387afbf61ed6a676 Mon Sep 17 00:00:00 2001 From: "wangnan.light" Date: Thu, 11 Jun 2026 15:35:44 +0800 Subject: [PATCH 04/10] feat: add MojoRotaryEmbedding --- .../operators/position_embedding.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/mojo_opset/experimental/operators/position_embedding.py b/mojo_opset/experimental/operators/position_embedding.py index a5f7dea44..d9f55cb09 100644 --- a/mojo_opset/experimental/operators/position_embedding.py +++ b/mojo_opset/experimental/operators/position_embedding.py @@ -1,5 +1,6 @@ import math from typing import List +from typing import Optional import torch @@ -118,7 +119,152 @@ def forward( return y.type_as(x) +class MojoRotaryEmbedding(MojoOperator): + """ + Apply RoPE to packed QKV tensors inplace. + """ + + def __init__(self, query_head_num: int, kv_head_num: int, rope_dim: int, interleaved: bool = False, **kwargs): + """ + Args: + query_head_num (int): Number of query heads packed at the front of + ``qkv_input``. + kv_head_num (int): Number of key heads and value heads. ``qkv_input`` + is expected to contain ``query_head_num + 2 * kv_head_num`` heads + in ``[Q, K, V]`` order. + rope_dim (int): Number of tail head channels to rotate. + interleaved (bool): Whether to use interleaved RoPE rotation. When + ``True``, implementations may consume ``interleave_offset`` in + ``forward``. + **kwargs: Tensor factory kwargs passed to ``MojoOperator``. 
+ """ + super().__init__(**kwargs) + if query_head_num <= 0: + raise ValueError(f"query_head_num must be positive, got {query_head_num}.") + if kv_head_num <= 0: + raise ValueError(f"kv_head_num must be positive, got {kv_head_num}.") + if rope_dim <= 0: + raise ValueError(f"rope_dim must be positive, got {rope_dim}.") + self.query_head_num = int(query_head_num) + self.kv_head_num = int(kv_head_num) + self.rope_dim = int(rope_dim) + self.interleaved = bool(interleaved) + + @staticmethod + def _rotate(x: torch.Tensor, interleaved: bool) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + y = torch.empty_like(x) + x1 = x[..., ::2] + x2 = x[..., 1::2] + y[..., ::2] = -x2 + y[..., 1::2] = x1 + return y + + def _apply_rope_by_seq( + self, + tensor: torch.Tensor, + sin_embeds: torch.Tensor, + cos_embeds: torch.Tensor, + kv_len: torch.Tensor, + cu_seq_lens: torch.Tensor, + rope_dim: int, + ) -> None: + head_dim = tensor.size(-1) + rope_offset = head_dim - rope_dim + batch_size = kv_len.numel() + + for batch_idx in range(batch_size): + start = int(cu_seq_lens[batch_idx].item()) + end = int(cu_seq_lens[batch_idx + 1].item()) + seq_len = end - start + position = int(kv_len[batch_idx].item()) + x_rope = tensor[start:end, :, rope_offset : rope_offset + rope_dim].float() + sin = sin_embeds[position : position + seq_len].float().unsqueeze(1) + cos = cos_embeds[position : position + seq_len].float().unsqueeze(1) + tensor[start:end, :, rope_offset : rope_offset + rope_dim] = ( + self._rotate(x_rope, self.interleaved) * sin + x_rope * cos + ).to(tensor.dtype) + + def forward( + self, + qkv_input: torch.Tensor, + sin_embeds: torch.Tensor, + cos_embeds: torch.Tensor, + kv_len: torch.Tensor, + cu_seq_lens: Optional[torch.Tensor] = None, + interleave_offset: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Apply rotary position embedding to packed QKV. 
+ + Args: + qkv_input (torch.Tensor): Packed QKV tensor with shape + ``[total_seq_len, (query_head_num + 2 * kv_head_num) * head_dim]``. + Q heads come first, followed by K heads and then V heads. Q/K are + rotated in the tail ``rope_dim`` channels of each head; V is + preserved. + sin_embeds (torch.Tensor): Sine table with shape + ``[max_position_embeddings, rope_dim]``. Dtype should match + ``qkv_input``. + cos_embeds (torch.Tensor): Cosine table with the same shape and dtype + contract as ``sin_embeds``. + kv_len (torch.Tensor): Per-sequence starting position / existing KV + length, shape ``[batch_size]``. + cu_seq_lens (Optional[torch.Tensor]): Cumulative query lengths for + packed varlen input, shape ``[batch_size + 1]``. The torch + fallback currently requires it. + interleave_offset (Optional[torch.Tensor]): Offset table used by the + interleaved rotation path by implementations that need explicit + offsets. It is not used by the torch fallback reference + implementation. + + Returns: + torch.Tensor: Packed QKV tensor with the same shape as ``qkv_input``. + Implementations may return ``qkv_input`` updated in-place or a newly + constructed tensor. + """ + if qkv_input.dim() != 2: + raise NotImplementedError("Torch MojoRotaryEmbedding currently supports packed 2D qkv_input only.") + if cu_seq_lens is None: + raise NotImplementedError("Torch MojoRotaryEmbedding currently requires cu_seq_lens.") + if sin_embeds.shape != cos_embeds.shape: + raise ValueError( + f"sin_embeds and cos_embeds must have same shape, got {tuple(sin_embeds.shape)} and {tuple(cos_embeds.shape)}." 
+ ) + if sin_embeds.dim() != 2: + raise ValueError(f"sin_embeds must be 2D [max_position_embeddings, rope_dim], got {tuple(sin_embeds.shape)}.") + if kv_len.dim() != 1: + raise ValueError(f"kv_len must be 1D [batch_size], got {tuple(kv_len.shape)}.") + if cu_seq_lens.dim() != 1 or cu_seq_lens.numel() != kv_len.numel() + 1: + raise ValueError( + f"cu_seq_lens must be 1D with size batch_size + 1, got shape {tuple(cu_seq_lens.shape)} for kv_len {tuple(kv_len.shape)}." + ) + + if sin_embeds.size(-1) != self.rope_dim: + raise ValueError(f"sin_embeds last dim must match rope_dim={self.rope_dim}, got {sin_embeds.size(-1)}.") + total_head_num = self.query_head_num + 2 * self.kv_head_num + if qkv_input.size(-1) % total_head_num != 0: + raise ValueError( + f"qkv hidden dim {qkv_input.size(-1)} must be divisible by total head count {total_head_num}." + ) + head_dim = qkv_input.size(-1) // total_head_num + if head_dim < self.rope_dim: + raise ValueError(f"head_dim must be >= rope_dim, got head_dim={head_dim}, rope_dim={self.rope_dim}.") + + qkv = qkv_input.view(qkv_input.size(0), total_head_num, head_dim) + q = qkv[:, : self.query_head_num, :] + k = qkv[:, self.query_head_num : self.query_head_num + self.kv_head_num, :] + + self._apply_rope_by_seq(q, sin_embeds, cos_embeds, kv_len, cu_seq_lens, self.rope_dim) + self._apply_rope_by_seq(k, sin_embeds, cos_embeds, kv_len, cu_seq_lens, self.rope_dim) + return qkv_input + + __all__ = [ "MojoRelativeEmbedding", "MojoGridRoPE", + "MojoRotaryEmbedding", ] From de0222b2af497fbb5d1dfd42451c8006d3bce038 Mon Sep 17 00:00:00 2001 From: HJzhang-sjtu Date: Thu, 11 Jun 2026 16:39:49 +0800 Subject: [PATCH 05/10] feat: add MojoGatherRopeStore --- mojo_opset/experimental/__init__.py | 4 + mojo_opset/experimental/operators/__init__.py | 4 + mojo_opset/experimental/operators/kv_cache.py | 126 ++++++++++++++++++ 3 files changed, 134 insertions(+) diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index 
2434ed2f7..0a5b48143 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -23,6 +23,7 @@ from .operators.gemm import MojoQuantBatchGemmReduceSum from .operators.indexer import MojoIndexer from .operators.indexer import MojoLightningIndexer +from .operators.kv_cache import MojoGatherRopeStore from .operators.kv_cache import MojoStorePagedMLAKVCache from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .operators.moe import MojoMoEInitRoutingDynamicQuant @@ -30,6 +31,7 @@ from .operators.normalization import MojoGroupLayerNorm from .operators.position_embedding import MojoGridRoPE from .operators.position_embedding import MojoRelativeEmbedding +from .operators.position_embedding import MojoRotaryEmbedding from .operators.store_lowrank import MojoStoreLowrank __all__ = [ @@ -52,6 +54,7 @@ "MojoPagedDecodeSWAWithKVDequant", "MojoFusedAttnOutputGate", "MojoPagedPrefillSageGQA", + "MojoGatherRopeStore", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", @@ -59,6 +62,7 @@ "MojoChannelRMSNorm", "MojoRelativeEmbedding", "MojoGridRoPE", + "MojoRotaryEmbedding", "MojoStoreLowrank", "MojoIndexer", ] diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index d0c747c3d..b2118a840 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -15,6 +15,7 @@ from .gemm import MojoQuantBatchGemmReduceSum from .indexer import MojoIndexer from .indexer import MojoLightningIndexer +from .kv_cache import MojoGatherRopeStore from .kv_cache import MojoStorePagedMLAKVCache from .moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .moe import MojoMoEInitRoutingDynamicQuant @@ -22,6 +23,7 @@ from .normalization import MojoGroupLayerNorm from .position_embedding import MojoGridRoPE from .position_embedding import MojoRelativeEmbedding +from .position_embedding import 
MojoRotaryEmbedding from .store_lowrank import MojoStoreLowrank __all__ = [ @@ -41,6 +43,7 @@ "MojoPagedDecodeGQAWithKVDequant", "MojoPagedPrefillSWAWithKVDequant", "MojoPagedDecodeSWAWithKVDequant", + "MojoGatherRopeStore", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", @@ -48,6 +51,7 @@ "MojoChannelRMSNorm", "MojoRelativeEmbedding", "MojoGridRoPE", + "MojoRotaryEmbedding", "MojoStoreLowrank", "MojoQuantBatchGemmReduceSum", ] diff --git a/mojo_opset/experimental/operators/kv_cache.py b/mojo_opset/experimental/operators/kv_cache.py index 1066c3dc1..5240563c0 100644 --- a/mojo_opset/experimental/operators/kv_cache.py +++ b/mojo_opset/experimental/operators/kv_cache.py @@ -1,3 +1,4 @@ +from typing import Optional from typing import Tuple import torch @@ -103,6 +104,131 @@ def forward( return compressed_kv_cache, k_pe_cache +class MojoGatherRopeStore(MojoOperator): + """Gather paged key cache blocks, apply RoPE, and optionally store the result.""" + + def __init__(self): + super().__init__() + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + @staticmethod + def _reshape_scale(scale: torch.Tensor, kv_head_num: int, head_dim: int) -> torch.Tensor: + if scale.dim() == 1: + if kv_head_num != 1 or scale.size(0) != head_dim: + raise ValueError( + f"1D scale requires kv_head_num=1 and shape [{head_dim}], got kv_head_num={kv_head_num}, shape={tuple(scale.shape)}." 
+ ) + return scale.reshape(1, 1, 1, head_dim) + if scale.dim() == 2: + if scale.shape != (kv_head_num, head_dim): + raise ValueError(f"2D scale must have shape [{kv_head_num}, {head_dim}], got {tuple(scale.shape)}.") + return scale.reshape(1, kv_head_num, 1, head_dim) + raise ValueError(f"scale must be 1D or 2D, got shape {tuple(scale.shape)}.") + + def forward( + self, + key: torch.Tensor, + rope_cache: Optional[torch.Tensor], + dequant_scale: Optional[torch.Tensor], + kv_idx: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + quant_scale: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + key: Paged key cache with shape ``[num_blocks, kv_head_num, page_size, head_dim]``. + Dtype must be ``torch.int8`` or ``torch.bfloat16``. + rope_cache: Optional cache updated in-place with the RoPE result. When + present, shape and dtype must match ``key``. + dequant_scale: Scale used to dequantize int8 ``key``. Shape is + ``[kv_head_num, head_dim]`` or ``[head_dim]`` when ``kv_head_num == 1``. + kv_idx: Page index table with shape ``[batch_size, max_block_nums]``. + Invalid entries are ``-1``. + sin: RoPE sine table with shape ``[max_block_nums * page_size, rope_head_dim]``. + cos: RoPE cosine table with the same shape as ``sin``. + quant_scale: Scale used to quantize int8 ``rope_cache``. + + Returns: + BF16 tensor with the same shape as ``key``. Entries corresponding to + invalid ``kv_idx`` pages are unspecified. 
+ """ + if key.dim() != 4: + raise ValueError(f"key must be 4D [num_blocks, kv_head_num, page_size, head_dim], got {tuple(key.shape)}.") + if key.dtype not in (torch.int8, torch.bfloat16): + raise TypeError(f"key must be torch.int8 or torch.bfloat16, got {key.dtype}.") + if rope_cache is not None: + if rope_cache.shape != key.shape: + raise ValueError(f"rope_cache shape must match key, got {tuple(rope_cache.shape)} and {tuple(key.shape)}.") + if rope_cache.dtype != key.dtype: + raise TypeError(f"rope_cache dtype must match key, got {rope_cache.dtype} and {key.dtype}.") + if kv_idx.dim() != 2: + raise ValueError(f"kv_idx must be 2D [batch_size, max_block_nums], got {tuple(kv_idx.shape)}.") + if kv_idx.dtype != torch.int64: + raise TypeError(f"kv_idx must be torch.int64, got {kv_idx.dtype}.") + if sin.shape != cos.shape or sin.dim() != 2: + raise ValueError(f"sin and cos must be matching 2D tensors, got {tuple(sin.shape)} and {tuple(cos.shape)}.") + if sin.dtype != torch.bfloat16 or cos.dtype != torch.bfloat16: + raise TypeError(f"sin and cos must be torch.bfloat16, got {sin.dtype} and {cos.dtype}.") + + _, kv_head_num, page_size, head_dim = key.shape + max_block_nums = kv_idx.size(1) + rope_head_dim = sin.size(1) + if rope_head_dim >= head_dim: + raise ValueError(f"rope_head_dim must be smaller than head_dim, got {rope_head_dim} and {head_dim}.") + if rope_head_dim % 2 != 0: + raise ValueError(f"rope_head_dim must be even, got {rope_head_dim}.") + + nope_head_dim = head_dim - rope_head_dim + output = torch.empty(key.shape, dtype=torch.bfloat16, device=key.device) + valid_idx = kv_idx.reshape(-1) + valid_idx = valid_idx[valid_idx != -1] + + if key.dtype == torch.int8: + if dequant_scale is None: + raise ValueError("dequant_scale is required when key dtype is torch.int8.") + key_work = key[valid_idx].float() * self._reshape_scale(dequant_scale, kv_head_num, head_dim) + else: + key_work = key[valid_idx].to(torch.bfloat16) + + global_block_id = 0 + for batch_id in 
range(kv_idx.size(0)): + for block_id in range(max_block_nums): + page_id = int(kv_idx[batch_id, block_id].item()) + if page_id < 0: + continue + + cur_block = key_work[global_block_id] + rope_block = cur_block[:, :, nope_head_dim:] + block_sin = sin[block_id * page_size : (block_id + 1) * page_size].float().reshape(1, page_size, -1) + block_cos = cos[block_id * page_size : (block_id + 1) * page_size].float().reshape(1, page_size, -1) + + out_block = torch.empty(kv_head_num, page_size, head_dim, dtype=torch.bfloat16, device=key.device) + out_block[:, :, :nope_head_dim] = cur_block[:, :, :nope_head_dim].to(torch.bfloat16) + out_block[:, :, nope_head_dim:] = ( + self._rotate(rope_block.float()) * block_sin + rope_block.float() * block_cos + ).to(torch.bfloat16) + output[page_id] = out_block + + if rope_cache is not None: + if rope_cache.dtype == torch.int8: + if quant_scale is None: + raise ValueError("quant_scale is required when rope_cache dtype is torch.int8.") + quant = out_block.float() * self._reshape_scale(quant_scale, kv_head_num, head_dim).squeeze(0) + rope_cache[page_id] = torch.clamp(quant, -128, 127).to(torch.int8) + else: + rope_cache[page_id] = out_block + + global_block_id += 1 + + return output + + __all__ = [ + "MojoGatherRopeStore", "MojoStorePagedMLAKVCache", ] From bc5ed60e7fe065f9cb7e4c712cf1decf565308a4 Mon Sep 17 00:00:00 2001 From: HJzhang-sjtu Date: Thu, 11 Jun 2026 18:21:11 +0800 Subject: [PATCH 06/10] feat: add MojoPagedAttentionStoreKvCache --- mojo_opset/experimental/__init__.py | 2 + mojo_opset/experimental/operators/__init__.py | 2 + mojo_opset/experimental/operators/kv_cache.py | 116 ++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index 0a5b48143..a00e1c1f7 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -24,6 +24,7 @@ from .operators.indexer import MojoIndexer from .operators.indexer import 
MojoLightningIndexer from .operators.kv_cache import MojoGatherRopeStore +from .operators.kv_cache import MojoPagedAttentionStoreKvCache from .operators.kv_cache import MojoStorePagedMLAKVCache from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .operators.moe import MojoMoEInitRoutingDynamicQuant @@ -55,6 +56,7 @@ "MojoFusedAttnOutputGate", "MojoPagedPrefillSageGQA", "MojoGatherRopeStore", + "MojoPagedAttentionStoreKvCache", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index b2118a840..4333b1ac5 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -16,6 +16,7 @@ from .indexer import MojoIndexer from .indexer import MojoLightningIndexer from .kv_cache import MojoGatherRopeStore +from .kv_cache import MojoPagedAttentionStoreKvCache from .kv_cache import MojoStorePagedMLAKVCache from .moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .moe import MojoMoEInitRoutingDynamicQuant @@ -44,6 +45,7 @@ "MojoPagedPrefillSWAWithKVDequant", "MojoPagedDecodeSWAWithKVDequant", "MojoGatherRopeStore", + "MojoPagedAttentionStoreKvCache", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", diff --git a/mojo_opset/experimental/operators/kv_cache.py b/mojo_opset/experimental/operators/kv_cache.py index 5240563c0..0cdfb7b55 100644 --- a/mojo_opset/experimental/operators/kv_cache.py +++ b/mojo_opset/experimental/operators/kv_cache.py @@ -228,7 +228,123 @@ def forward( return output +class MojoPagedAttentionStoreKvCache(MojoOperator): + """Store packed paged-attention K/V states into paged KV cache.""" + + def __init__(self): + super().__init__() + + def forward( + self, + qkv: torch.Tensor, + key_cache: torch.Tensor, + value_cache: Optional[torch.Tensor], + block_table: torch.Tensor, + 
seq_len: torch.Tensor, + kv_len: torch.Tensor, + k_scale: Optional[torch.Tensor], + v_scale: Optional[torch.Tensor], + query_head_num: int, + kv_head_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + qkv: Packed Q/K/V tensor with shape + ``[sum(seq_len), (query_head_num + kv_head_num + maybe value heads) * head_dim]``. + key_cache: Paged key cache, shape ``[num_blocks, kv_head_num, block_size, head_dim]``. + value_cache: Optional paged value cache with the same shape as ``key_cache``. + block_table: Logical-to-physical block table, shape ``[batch_size, max_blocks]``. + Negative block ids stop storing for that sequence. + seq_len: Current sequence lengths, shape ``[batch_size]``. + kv_len: Existing KV lengths before storing, shape ``[batch_size]``. + k_scale: Optional int8 key quantization scale, shape ``[kv_head_num, head_dim]``. + v_scale: Optional int8 value quantization scale, shape ``[kv_head_num, head_dim]``. + query_head_num: Number of query heads packed before key heads. + kv_head_num: Number of key/value heads. + + Returns: + Tuple of updated ``(key_cache, value_cache)``. If ``value_cache`` is + ``None``, the second return value is ``key_cache`` for API compatibility. + """ + if qkv.dim() != 2: + raise ValueError(f"qkv must be 2D, got {tuple(qkv.shape)}.") + if key_cache.dim() != 4: + raise ValueError(f"key_cache must be 4D, got {tuple(key_cache.shape)}.") + if value_cache is not None and value_cache.shape != key_cache.shape: + raise ValueError( + f"value_cache shape must match key_cache, got {tuple(value_cache.shape)} and {tuple(key_cache.shape)}." + ) + if block_table.dim() != 2: + raise ValueError(f"block_table must be 2D, got {tuple(block_table.shape)}.") + if seq_len.dim() != 1 or kv_len.dim() != 1 or seq_len.numel() != kv_len.numel(): + raise ValueError( + f"seq_len and kv_len must be 1D tensors with same length, got {tuple(seq_len.shape)} and {tuple(kv_len.shape)}." 
+ ) + + _, _, block_size, head_dim = key_cache.shape + has_value = value_cache is not None + query_head_num = int(query_head_num) + kv_head_num = int(kv_head_num) + process_seq_len = 0 + + for batch_id in range(seq_len.numel()): + now_seq_len = int(seq_len[batch_id].item()) + now_kv_len = int(kv_len[batch_id].item()) + now_block_table = block_table[batch_id] + + key_start = query_head_num * head_dim + key_end = (query_head_num + kv_head_num) * head_dim + now_key = qkv[process_seq_len : process_seq_len + now_seq_len, key_start:key_end] + now_key = now_key.reshape(-1, kv_head_num, head_dim).transpose(1, 0) + + if has_value: + value_start = key_end + now_value = qkv[process_seq_len : process_seq_len + now_seq_len, value_start:] + now_value = now_value.reshape(-1, kv_head_num, head_dim).transpose(1, 0) + + if key_cache.dtype == torch.int8: + if k_scale is None: + raise ValueError("k_scale is required when key_cache dtype is torch.int8.") + now_key = torch.clamp(torch.round(now_key.to(torch.float32) * k_scale.unsqueeze(1)), -128, 127).to( + torch.int8 + ) + if has_value: + if v_scale is None: + raise ValueError("v_scale is required when value_cache dtype is torch.int8.") + now_value = torch.clamp( + torch.round(now_value.to(torch.float32) * v_scale.unsqueeze(1)), -128, 127 + ).to(torch.int8) + + start_block_table_idx = now_kv_len // block_size + block_offset = now_kv_len % block_size + remain_seq_len = now_seq_len + kv_offset = 0 + + for block_id in now_block_table[start_block_table_idx:]: + block_id = int(block_id.item()) + if block_id < 0: + break + store_kv_len = min(block_size - block_offset, remain_seq_len) + key_cache[block_id, :, block_offset : block_offset + store_kv_len, :] = now_key[ + :, kv_offset : kv_offset + store_kv_len, : + ] + if has_value: + value_cache[block_id, :, block_offset : block_offset + store_kv_len, :] = now_value[ + :, kv_offset : kv_offset + store_kv_len, : + ] + block_offset = 0 + kv_offset += store_kv_len + remain_seq_len -= store_kv_len 
+ if remain_seq_len <= 0: + break + + process_seq_len += now_seq_len + + return key_cache, value_cache if has_value else key_cache + + __all__ = [ "MojoGatherRopeStore", + "MojoPagedAttentionStoreKvCache", "MojoStorePagedMLAKVCache", ] From d5b789ef102fd83bdca9fef8ac4a9f1a06918f5b Mon Sep 17 00:00:00 2001 From: HJzhang-sjtu Date: Thu, 11 Jun 2026 19:50:01 +0800 Subject: [PATCH 07/10] feat: add MojoPagedCacheDequant --- mojo_opset/experimental/__init__.py | 2 + mojo_opset/experimental/operators/__init__.py | 2 + mojo_opset/experimental/operators/kv_cache.py | 64 +++++++++++++++++++ 3 files changed, 68 insertions(+) diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index a00e1c1f7..8e83e05e2 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -25,6 +25,7 @@ from .operators.indexer import MojoLightningIndexer from .operators.kv_cache import MojoGatherRopeStore from .operators.kv_cache import MojoPagedAttentionStoreKvCache +from .operators.kv_cache import MojoPagedCacheDequant from .operators.kv_cache import MojoStorePagedMLAKVCache from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize from .operators.moe import MojoMoEInitRoutingDynamicQuant @@ -57,6 +58,7 @@ "MojoPagedPrefillSageGQA", "MojoGatherRopeStore", "MojoPagedAttentionStoreKvCache", + "MojoPagedCacheDequant", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index 4333b1ac5..c2724c220 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -17,6 +17,7 @@ from .indexer import MojoLightningIndexer from .kv_cache import MojoGatherRopeStore from .kv_cache import MojoPagedAttentionStoreKvCache +from .kv_cache import MojoPagedCacheDequant from .kv_cache import MojoStorePagedMLAKVCache from .moe import 
MojoFusedSwiGLUMoEScaleDynamicQuantize from .moe import MojoMoEInitRoutingDynamicQuant @@ -46,6 +47,7 @@ "MojoPagedDecodeSWAWithKVDequant", "MojoGatherRopeStore", "MojoPagedAttentionStoreKvCache", + "MojoPagedCacheDequant", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", diff --git a/mojo_opset/experimental/operators/kv_cache.py b/mojo_opset/experimental/operators/kv_cache.py index 0cdfb7b55..cfd3232b8 100644 --- a/mojo_opset/experimental/operators/kv_cache.py +++ b/mojo_opset/experimental/operators/kv_cache.py @@ -343,8 +343,72 @@ def forward( return key_cache, value_cache if has_value else key_cache +class MojoPagedCacheDequant(MojoOperator): + """Dequantize an int8 paged KV cache.""" + + def __init__(self): + super().__init__() + + @staticmethod + def _reshape_scale(scale: torch.Tensor, head_num: int, head_dim: int) -> torch.Tensor: + if scale.dim() == 1: + if head_num != 1 or scale.size(0) != head_dim: + raise ValueError( + f"1D dequant_scale requires head_num=1 and shape [{head_dim}], got head_num={head_num}, shape={tuple(scale.shape)}." + ) + return scale.reshape(1, 1, 1, head_dim) + if scale.dim() == 2: + if scale.shape != (head_num, head_dim): + raise ValueError(f"2D dequant_scale must have shape [{head_num}, {head_dim}], got {tuple(scale.shape)}.") + return scale.reshape(1, head_num, 1, head_dim) + raise ValueError(f"dequant_scale must be 1D or 2D, got shape {tuple(scale.shape)}.") + + def forward( + self, + quantized_cache: torch.Tensor, + dequant_scale: torch.Tensor, + block_table: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + quantized_cache: Int8 paged cache with shape + ``[num_blocks, head_num, block_size, head_dim]``. + dequant_scale: Dequantization scale with shape ``[head_num, head_dim]`` + or ``[head_dim]`` when ``head_num == 1``. + block_table: Logical-to-physical block table with shape + ``[batch_size, max_blocks]``. Entries with ``-1`` are ignored by + paged cache implementations. 
+ + Returns: + BF16 dequantized cache with the same shape as ``quantized_cache``. + """ + if quantized_cache.dim() != 4: + raise ValueError( + f"quantized_cache must be 4D [num_blocks, head_num, block_size, head_dim], got {tuple(quantized_cache.shape)}." + ) + if quantized_cache.dtype != torch.int8: + raise TypeError(f"quantized_cache must be torch.int8, got {quantized_cache.dtype}.") + if block_table.dim() != 2: + raise ValueError(f"block_table must be 2D, got {tuple(block_table.shape)}.") + if block_table.dtype != torch.int64: + raise TypeError(f"block_table must be torch.int64, got {block_table.dtype}.") + + num_blocks, head_num, _, head_dim = quantized_cache.shape + scale = self._reshape_scale(dequant_scale, head_num, head_dim) + output = torch.empty(quantized_cache.shape, dtype=torch.bfloat16, device=quantized_cache.device) + valid_block_ids = block_table[block_table != -1] + if valid_block_ids.numel() == 0: + return output + if int(valid_block_ids.min().item()) < 0 or int(valid_block_ids.max().item()) >= num_blocks: + raise ValueError(f"block_table contains block ids outside [-1, {num_blocks}).") + + output[valid_block_ids] = (quantized_cache[valid_block_ids].float() * scale).to(torch.bfloat16) + return output + + __all__ = [ "MojoGatherRopeStore", + "MojoPagedCacheDequant", "MojoPagedAttentionStoreKvCache", "MojoStorePagedMLAKVCache", ] From 7abb604004519a6ec8bb76a8a7720357640f720a Mon Sep 17 00:00:00 2001 From: HJzhang-sjtu Date: Fri, 12 Jun 2026 11:58:53 +0800 Subject: [PATCH 08/10] feat: add MojoFusedAttnGateConcat --- mojo_opset/experimental/__init__.py | 2 + mojo_opset/experimental/operators/__init__.py | 2 + .../experimental/operators/attention_gate.py | 61 +++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index 8e83e05e2..374515c33 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -18,6 +18,7 @@ from 
.operators.attention import MojoPagedPrefillSWAWithKVDequant from .operators.attention import MojoPrefillMLA from .operators.attention import MojoPrefillNSA +from .operators.attention_gate import MojoFusedAttnGateConcat from .operators.attention_gate import MojoFusedAttnOutputGate from .operators.attention import MojoPagedPrefillSageGQA from .operators.gemm import MojoQuantBatchGemmReduceSum @@ -54,6 +55,7 @@ "MojoPagedDecodeGQAWithKVDequant", "MojoPagedPrefillSWAWithKVDequant", "MojoPagedDecodeSWAWithKVDequant", + "MojoFusedAttnGateConcat", "MojoFusedAttnOutputGate", "MojoPagedPrefillSageGQA", "MojoGatherRopeStore", diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index c2724c220..90c4de27c 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -11,6 +11,7 @@ from .attention import MojoPagedPrefillSWAWithKVDequant from .attention import MojoPrefillMLA from .attention import MojoPrefillNSA +from .attention_gate import MojoFusedAttnGateConcat from .attention_gate import MojoFusedAttnOutputGate from .gemm import MojoQuantBatchGemmReduceSum from .indexer import MojoIndexer @@ -30,6 +31,7 @@ __all__ = [ "MojoRotateActivation", + "MojoFusedAttnGateConcat", "MojoFusedAttnOutputGate", "MojoIndexer", "MojoLightningIndexer", diff --git a/mojo_opset/experimental/operators/attention_gate.py b/mojo_opset/experimental/operators/attention_gate.py index 7c9627205..ec8204422 100644 --- a/mojo_opset/experimental/operators/attention_gate.py +++ b/mojo_opset/experimental/operators/attention_gate.py @@ -115,3 +115,64 @@ def extra_repr(self) -> str: f"head_dim={self.head_dim}, " f"bias={self.full_gate_bias is not None}" ) + + +class MojoFusedAttnGateConcat(MojoOperator): + """Apply full/SWA attention gates and concatenate the gated outputs.""" + + def __init__(self): + super().__init__() + + def forward( + self, + full_attn_out: torch.Tensor, + full_attn_gate_score: 
torch.Tensor, + swa_attn_out: torch.Tensor, + swa_attn_gate_score: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + full_attn_out: Full attention output with shape + ``[total_seq, full_head_num, head_dim]``. + full_attn_gate_score: Gate score for full attention with shape + ``[total_seq, full_head_num]``. + swa_attn_out: SWA attention output with shape + ``[total_seq, swa_head_num, head_dim]``. + swa_attn_gate_score: Gate score for SWA attention with shape + ``[total_seq, swa_head_num]``. + + Returns: + Gated concatenated attention output with shape + ``[total_seq, full_head_num + swa_head_num, head_dim]`` and the + same dtype as ``full_attn_out``. + """ + if swa_attn_gate_score is None: + raise ValueError("swa_attn_gate_score is required.") + if full_attn_out.dim() != 3: + raise ValueError(f"full_attn_out must be 3D, got {tuple(full_attn_out.shape)}.") + if swa_attn_out.dim() != 3: + raise ValueError(f"swa_attn_out must be 3D, got {tuple(swa_attn_out.shape)}.") + if full_attn_gate_score.dim() != 2: + raise ValueError(f"full_attn_gate_score must be 2D, got {tuple(full_attn_gate_score.shape)}.") + if swa_attn_gate_score.dim() != 2: + raise ValueError(f"swa_attn_gate_score must be 2D, got {tuple(swa_attn_gate_score.shape)}.") + + total_seq, full_head_num, head_dim = full_attn_out.shape + swa_total_seq, swa_head_num, swa_head_dim = swa_attn_out.shape + if swa_total_seq != total_seq or swa_head_dim != head_dim: + raise ValueError( + "full_attn_out and swa_attn_out must have matching total_seq and head_dim, " + f"got {tuple(full_attn_out.shape)} and {tuple(swa_attn_out.shape)}." + ) + if full_attn_gate_score.shape != (total_seq, full_head_num): + raise ValueError( + f"full_attn_gate_score must have shape [{total_seq}, {full_head_num}], got {tuple(full_attn_gate_score.shape)}." 
+ ) + if swa_attn_gate_score.shape != (total_seq, swa_head_num): + raise ValueError( + f"swa_attn_gate_score must have shape [{total_seq}, {swa_head_num}], got {tuple(swa_attn_gate_score.shape)}." + ) + + full_out = full_attn_out.float() * torch.sigmoid(full_attn_gate_score.float()).unsqueeze(-1) + swa_out = swa_attn_out.float() * torch.sigmoid(swa_attn_gate_score.float()).unsqueeze(-1) + return torch.cat((full_out, swa_out), dim=1).to(full_attn_out.dtype) From af1b5fe4fd8ea757bc8ed807c27f0f6756254d78 Mon Sep 17 00:00:00 2001 From: HJzhang-sjtu Date: Fri, 12 Jun 2026 13:52:28 +0800 Subject: [PATCH 09/10] fix: add assert and fix c8 quant in golden --- mojo_opset/experimental/operators/attention_gate.py | 2 ++ mojo_opset/experimental/operators/kv_cache.py | 2 +- mojo_opset/experimental/operators/position_embedding.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mojo_opset/experimental/operators/attention_gate.py b/mojo_opset/experimental/operators/attention_gate.py index ec8204422..7a907bdbd 100644 --- a/mojo_opset/experimental/operators/attention_gate.py +++ b/mojo_opset/experimental/operators/attention_gate.py @@ -146,6 +146,8 @@ def forward( ``[total_seq, full_head_num + swa_head_num, head_dim]`` and the same dtype as ``full_attn_out``. 
""" + if full_attn_gate_score is None: + raise ValueError("full_attn_gate_score is required.") if swa_attn_gate_score is None: raise ValueError("swa_attn_gate_score is required.") if full_attn_out.dim() != 3: diff --git a/mojo_opset/experimental/operators/kv_cache.py b/mojo_opset/experimental/operators/kv_cache.py index cfd3232b8..c8e6595a6 100644 --- a/mojo_opset/experimental/operators/kv_cache.py +++ b/mojo_opset/experimental/operators/kv_cache.py @@ -218,7 +218,7 @@ def forward( if rope_cache.dtype == torch.int8: if quant_scale is None: raise ValueError("quant_scale is required when rope_cache dtype is torch.int8.") - quant = out_block.float() * self._reshape_scale(quant_scale, kv_head_num, head_dim).squeeze(0) + quant = torch.round(out_block.float() * self._reshape_scale(quant_scale, kv_head_num, head_dim).squeeze(0)) rope_cache[page_id] = torch.clamp(quant, -128, 127).to(torch.int8) else: rope_cache[page_id] = out_block diff --git a/mojo_opset/experimental/operators/position_embedding.py b/mojo_opset/experimental/operators/position_embedding.py index d9f55cb09..fec94c71f 100644 --- a/mojo_opset/experimental/operators/position_embedding.py +++ b/mojo_opset/experimental/operators/position_embedding.py @@ -226,6 +226,8 @@ def forward( Implementations may return ``qkv_input`` updated in-place or a newly constructed tensor. 
""" + if not qkv_input.is_contiguous(): + raise NotImplementedError("Torch MojoRotaryEmbedding currently supports qkv_input is contiguous only.") if qkv_input.dim() != 2: raise NotImplementedError("Torch MojoRotaryEmbedding currently supports packed 2D qkv_input only.") if cu_seq_lens is None: From 53e83cfbd55e5c456d89127f438ddfcad90492d4 Mon Sep 17 00:00:00 2001 From: "gaoyujia.01" Date: Fri, 12 Jun 2026 18:16:03 +0800 Subject: [PATCH 10/10] feat: add fused ag scale quant and qk rmsnorm --- mojo_opset/experimental/__init__.py | 4 + mojo_opset/experimental/operators/__init__.py | 4 + .../operators/compute_with_comm.py | 113 +++++++++++++ .../experimental/operators/normalization.py | 77 +++++++++ .../operators/test_compute_with_comm_quant.py | 148 ++++++++++++++++++ .../accuracy/operators/test_normalization.py | 58 +++++++ 6 files changed, 404 insertions(+) create mode 100644 mojo_opset/experimental/operators/compute_with_comm.py diff --git a/mojo_opset/experimental/__init__.py b/mojo_opset/experimental/__init__.py index 2434ed2f7..9db0dad3d 100755 --- a/mojo_opset/experimental/__init__.py +++ b/mojo_opset/experimental/__init__.py @@ -20,6 +20,7 @@ from .operators.attention import MojoPrefillNSA from .operators.attention_gate import MojoFusedAttnOutputGate from .operators.attention import MojoPagedPrefillSageGQA +from .operators.compute_with_comm import MojoFusedAGScaleQuant from .operators.gemm import MojoQuantBatchGemmReduceSum from .operators.indexer import MojoIndexer from .operators.indexer import MojoLightningIndexer @@ -28,6 +29,7 @@ from .operators.moe import MojoMoEInitRoutingDynamicQuant from .operators.normalization import MojoChannelRMSNorm from .operators.normalization import MojoGroupLayerNorm +from .operators.normalization import MojoQKInplaceRMSNorm from .operators.position_embedding import MojoGridRoPE from .operators.position_embedding import MojoRelativeEmbedding from .operators.store_lowrank import MojoStoreLowrank @@ -51,12 +53,14 @@ 
"MojoPagedPrefillSWAWithKVDequant", "MojoPagedDecodeSWAWithKVDequant", "MojoFusedAttnOutputGate", + "MojoFusedAGScaleQuant", "MojoPagedPrefillSageGQA", "MojoStorePagedMLAKVCache", "MojoMoEInitRoutingDynamicQuant", "MojoFusedSwiGLUMoEScaleDynamicQuantize", "MojoGroupLayerNorm", "MojoChannelRMSNorm", + "MojoQKInplaceRMSNorm", "MojoRelativeEmbedding", "MojoGridRoPE", "MojoStoreLowrank", diff --git a/mojo_opset/experimental/operators/__init__.py b/mojo_opset/experimental/operators/__init__.py index d0c747c3d..3d49ad9ce 100644 --- a/mojo_opset/experimental/operators/__init__.py +++ b/mojo_opset/experimental/operators/__init__.py @@ -12,6 +12,7 @@ from .attention import MojoPrefillMLA from .attention import MojoPrefillNSA from .attention_gate import MojoFusedAttnOutputGate +from .compute_with_comm import MojoFusedAGScaleQuant from .gemm import MojoQuantBatchGemmReduceSum from .indexer import MojoIndexer from .indexer import MojoLightningIndexer @@ -20,6 +21,7 @@ from .moe import MojoMoEInitRoutingDynamicQuant from .normalization import MojoChannelRMSNorm from .normalization import MojoGroupLayerNorm +from .normalization import MojoQKInplaceRMSNorm from .position_embedding import MojoGridRoPE from .position_embedding import MojoRelativeEmbedding from .store_lowrank import MojoStoreLowrank @@ -27,6 +29,7 @@ __all__ = [ "MojoRotateActivation", "MojoFusedAttnOutputGate", + "MojoFusedAGScaleQuant", "MojoIndexer", "MojoLightningIndexer", "MojoPrefillMLA", @@ -46,6 +49,7 @@ "MojoFusedSwiGLUMoEScaleDynamicQuantize", "MojoGroupLayerNorm", "MojoChannelRMSNorm", + "MojoQKInplaceRMSNorm", "MojoRelativeEmbedding", "MojoGridRoPE", "MojoStoreLowrank", diff --git a/mojo_opset/experimental/operators/compute_with_comm.py b/mojo_opset/experimental/operators/compute_with_comm.py new file mode 100644 index 000000000..6341c8c30 --- /dev/null +++ b/mojo_opset/experimental/operators/compute_with_comm.py @@ -0,0 +1,113 @@ +from typing import Optional + +import torch +import torch.distributed as 
dist +import torch.nn.functional as F +from torch.distributed.distributed_c10d import _get_default_group + +from mojo_opset.core.operator import MojoOperator + + +def _is_dist_initialized() -> bool: + return dist.is_available() and dist.is_initialized() + + +class MojoFusedAGScaleQuant(MojoOperator): + def __init__( + self, + *, + team_size: int = 1, + quant_mode: str = "per_token", + norm_mode: str = "none", + eps: float = 1e-5, + max_tokens: Optional[int] = None, + process_group: Optional[dist.ProcessGroup] = None, + comm_context=None, + **kwargs, + ): + """ + Fused AllGather-scale exchange + optional RMSNorm + per-token int8 quantization. + + Args: + team_size (int): Communication team size. + quant_mode (str): Quantization mode. Only ``"per_token"`` is supported. + norm_mode (str): Normalization mode. Supports ``"none"`` and ``"rmsnorm"``. + eps (float): Epsilon for RMSNorm. + max_tokens (Optional[int]): Maximum token count expected by backend + implementations that initialize communication buffers in ``__init__``. + process_group (Optional[ProcessGroup]): Distributed group for the torch reference. + ``None`` means the default group. + comm_context: Optional runtime/context object for backend implementations. 
+ """ + super().__init__(**kwargs) + if quant_mode not in ["per_token"]: + raise NotImplementedError(f"quant_mode {quant_mode} not supported") + if norm_mode not in ["none", "rmsnorm"]: + raise NotImplementedError(f"norm_mode {norm_mode} not supported") + if team_size < 1: + raise ValueError(f"team_size must be positive, but got {team_size}") + if max_tokens is not None and max_tokens < 1: + raise ValueError(f"max_tokens must be positive, but got {max_tokens}") + + self.team_size = team_size + self.quant_mode = quant_mode + self.norm_mode = norm_mode + self.eps = eps + self.max_tokens = max_tokens + self.process_group = process_group + self.comm_context = comm_context + + def _team_max_scale(self, scale: torch.Tensor) -> torch.Tensor: + if self.team_size == 1 or not _is_dist_initialized(): + return scale + + process_group = self.process_group or _get_default_group() + world_size = dist.get_world_size(group=process_group) + if world_size == 1: + return scale + if world_size != self.team_size: + raise ValueError(f"process group world size must match team_size={self.team_size}, but got {world_size}") + + gathered = [torch.empty_like(scale) for _ in range(world_size)] + dist.all_gather(gathered, scale.contiguous(), group=process_group) + return torch.stack(gathered, dim=0).amax(dim=0) + + def forward( + self, + input: torch.Tensor, + quant_scale: torch.Tensor, + norm_weight: Optional[torch.Tensor] = None, + ): + if input.dim() not in [3, 4]: + raise ValueError(f"input must be 3-D or 4-D, but got dim={input.dim()}") + + head_num = input.shape[-2] + head_dim = input.shape[-1] + hidden_size = head_num * head_dim + if quant_scale.numel() != hidden_size: + raise ValueError(f"quant_scale numel must be {hidden_size}, but got {quant_scale.numel()}") + if self.norm_mode == "rmsnorm" and norm_weight is not None and norm_weight.numel() != head_dim: + raise ValueError(f"norm_weight numel must be {head_dim}, but got {norm_weight.numel()}") + + input_fp = input.float() + if 
self.norm_mode == "rmsnorm": + weight = norm_weight.float() if norm_weight is not None else None + input_fp = F.rms_norm(input_fp, (head_dim,), weight=weight, eps=self.eps) + + rows = input_fp.numel() // hidden_size + if self.max_tokens is not None and rows > self.max_tokens: + raise ValueError(f"input token count {rows} exceeds max_tokens={self.max_tokens}") + scaled = input_fp.reshape(rows, hidden_size) * quant_scale.float().reshape(1, hidden_size) + scale = scaled.abs().amax(dim=-1).clamp(min=1e-12) / 127 + scale = self._team_max_scale(scale) + quantized = torch.clamp(torch.round(scaled / scale.unsqueeze(-1)), -128, 127).to(torch.int8) + + return quantized, scale + + def extra_repr(self) -> str: + return ( + f"{self.team_size=}, {self.quant_mode=}, {self.norm_mode=}, {self.eps=}, {self.max_tokens=}" + ).replace("self.", "") + + +__all__ = ["MojoFusedAGScaleQuant"] diff --git a/mojo_opset/experimental/operators/normalization.py b/mojo_opset/experimental/operators/normalization.py index 2f44e8edf..f2076ff0b 100644 --- a/mojo_opset/experimental/operators/normalization.py +++ b/mojo_opset/experimental/operators/normalization.py @@ -92,7 +92,84 @@ def extra_repr(self) -> str: ) +class MojoQKInplaceRMSNorm(MojoOperator): + def __init__( + self, + q_heads: int, + kv_heads: int, + qk_head_dim: int = 128, + v_head_dim: int = None, + eps: float = 1e-5, + **kwargs, + ): + """ + Initialize in-place RMSNorm parameters for packed QKV tensors. + + Args: + q_heads (int): Number of query heads in the packed QKV tensor. + kv_heads (int): Number of key/value heads in the packed QKV tensor. + qk_head_dim (int, default=128): Per-head dimension for query and key. + v_head_dim (int, optional): Per-head dimension for value. Defaults to ``qk_head_dim``. + eps (float, default=1e-5): Epsilon added for numerical stability. + **kwargs: The keyword arguments of torch.empty, such as device and dtype. 
+ """ + super().__init__(**kwargs) + self.q_heads = q_heads + self.kv_heads = kv_heads + self.qk_head_dim = qk_head_dim + self.v_head_dim = qk_head_dim if v_head_dim is None else v_head_dim + self.hidden_size = (q_heads + kv_heads) * qk_head_dim + kv_heads * self.v_head_dim + self.q_rms_weight = torch.nn.Parameter(torch.empty(qk_head_dim, **self.tensor_factory_kwargs)) + self.k_rms_weight = torch.nn.Parameter(torch.empty(qk_head_dim, **self.tensor_factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, qkv: torch.Tensor) -> torch.Tensor: + """ + Apply RMSNorm to Q and K segments of a packed QKV tensor and write the results in-place. + + Args: + qkv (torch.Tensor): 2-D packed QKV tensor shaped + ``(tokens, (q_heads + kv_heads) * qk_head_dim + kv_heads * v_head_dim)``. + + Returns: + torch.Tensor: The same QKV tensor after in-place Q/K normalization. The V segment is unchanged. + """ + if qkv.dim() != 2: + raise ValueError(f"qkv must be a 2-D packed tensor, but got dim={qkv.dim()}") + if qkv.shape[-1] != self.hidden_size: + raise ValueError(f"qkv last dimension must be {self.hidden_size}, but got {qkv.shape[-1]}") + + q_end = self.q_heads * self.qk_head_dim + k_end = q_end + self.kv_heads * self.qk_head_dim + + query = qkv[:, :q_end].reshape(-1, self.q_heads, self.qk_head_dim) + key = qkv[:, q_end:k_end].reshape(-1, self.kv_heads, self.qk_head_dim) + query_norm = F.rms_norm( + query, + (self.qk_head_dim,), + weight=self.q_rms_weight, + eps=self.variance_epsilon, + ) + key_norm = F.rms_norm( + key, + (self.qk_head_dim,), + weight=self.k_rms_weight, + eps=self.variance_epsilon, + ) + + qkv[:, :q_end].copy_(query_norm.reshape(qkv.shape[0], q_end)) + qkv[:, q_end:k_end].copy_(key_norm.reshape(qkv.shape[0], k_end - q_end)) + return qkv + + def extra_repr(self) -> str: + return ( + f"{self.q_heads=}, {self.kv_heads=}, {self.qk_head_dim=}, " + f"{self.v_head_dim=}, {self.variance_epsilon=}" + ).replace("self.", "") + + __all__ = [ "MojoGroupLayerNorm", 
"MojoChannelRMSNorm", + "MojoQKInplaceRMSNorm", ] diff --git a/mojo_opset/tests/accuracy/operators/test_compute_with_comm_quant.py b/mojo_opset/tests/accuracy/operators/test_compute_with_comm_quant.py index ffa3f41b2..567d4905d 100644 --- a/mojo_opset/tests/accuracy/operators/test_compute_with_comm_quant.py +++ b/mojo_opset/tests/accuracy/operators/test_compute_with_comm_quant.py @@ -1,5 +1,8 @@ import os import socket +import subprocess +import sys +import tempfile import pytest import torch @@ -8,10 +11,16 @@ from mojo_opset import MojoAll2AllQuantGemm from mojo_opset import MojoQuantGemmAll2All +from mojo_opset.experimental import MojoFusedAGScaleQuant from mojo_opset.tests.utils import bypass_not_implemented +from mojo_opset.utils.platform import get_dist_backend +from mojo_opset.utils.platform import get_platform torch.manual_seed(42) +_PLATFORM = get_platform() +COMM_BACKEND = get_dist_backend() +DEVICE = _PLATFORM if _PLATFORM in ("npu", "mlu") else "cpu" def _free_port(): @@ -128,3 +137,142 @@ def test_all2all_quant_gemm_gloo(): nprocs=world_size, join=True, ) + + +def _is_dist_env() -> bool: + return "RANK" in os.environ and "WORLD_SIZE" in os.environ + + +def _device_count() -> int: + if DEVICE == "cpu": + return 1 + device_module = getattr(torch, DEVICE, None) + if device_module is None or not hasattr(device_module, "device_count"): + return 0 + return int(device_module.device_count()) + + +def _to_dev(t: torch.Tensor) -> torch.Tensor: + return t.to(DEVICE) if DEVICE != "cpu" else t + + +def _synchronize_device(): + if DEVICE == "cpu": + return + device_module = getattr(torch, DEVICE, None) + if device_module is not None and hasattr(device_module, "synchronize"): + device_module.synchronize() + + +def _init_dist(): + rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + if DEVICE == "npu": + import torch_npu # noqa: F401 + + if DEVICE != "cpu": + device_module = getattr(torch, DEVICE) + if hasattr(device_module, "set_device"): + 
device_module.set_device(rank) + + if not dist.is_initialized(): + dist.init_process_group(COMM_BACKEND) + return dist.get_rank(), dist.get_world_size() + + +def _run_torchrun_test(test_fn_name: str, *fn_args, nproc: int = 2, timeout: int = 600): + with tempfile.TemporaryDirectory() as tmp_dir: + script_path = os.path.join(tmp_dir, "run_dist_test.py") + args_repr = ", ".join(repr(arg) for arg in fn_args) + with open(script_path, "w", encoding="utf-8") as f: + f.write( + "from mojo_opset.tests.accuracy.operators.test_compute_with_comm_quant " + f"import {test_fn_name}\n" + f"{test_fn_name}({args_repr})\n" + ) + + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={nproc}", + "--master_addr=127.0.0.1", + f"--master_port={_free_port()}", + script_path, + ] + subprocess.run(cmd, check=True, timeout=timeout) + + +def _fused_ag_scale_quant_team_inputs(world_size: int, token_num: int, head_num: int, head_dim: int, dtype): + generator = torch.Generator(device="cpu") + generator.manual_seed(2028) + return [ + torch.randn(token_num, head_num, head_dim, dtype=torch.float32, generator=generator).to(dtype) + for _ in range(world_size) + ] + + +def _dist_fused_ag_scale_quant(norm_mode: str): + rank, world_size = _init_dist() + try: + eps = 1e-5 + token_num = 5 + head_num = 2 + head_dim = 128 + hidden_size = head_num * head_dim + dtype = torch.bfloat16 + + inputs = _fused_ag_scale_quant_team_inputs(world_size, token_num, head_num, head_dim, dtype) + quant_scale = torch.ones(hidden_size, dtype=torch.float32, device="cpu") + norm_weight = ( + torch.ones(head_dim, dtype=torch.float32, device="cpu") + if norm_mode == "rmsnorm" + else None + ) + + input = _to_dev(inputs[rank].contiguous()) + quant_scale = _to_dev(quant_scale) + norm_weight = _to_dev(norm_weight) if norm_weight is not None else None + + op = MojoFusedAGScaleQuant( + team_size=world_size, + norm_mode=norm_mode, + eps=eps, + max_tokens=token_num, + ) + op_ref = 
MojoFusedAGScaleQuant._registry.get("torch")( + team_size=world_size, + norm_mode=norm_mode, + eps=eps, + max_tokens=token_num, + ) + + scale_atol = 1e-6 if norm_mode == "none" else 5e-4 + scale_rtol = 1e-6 if norm_mode == "none" else 5e-3 + op.forward_diff_with( + op_ref, + input, + quant_scale, + norm_weight, + atol=(1, scale_atol), + rtol=(0, scale_rtol), + ) + finally: + _synchronize_device() + _destroy_pg() + + +@pytest.mark.parametrize("team_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("norm_mode", ["none", "rmsnorm"]) +@bypass_not_implemented +def test_fused_ag_scale_quant(team_size, norm_mode): + if _device_count() < team_size: + raise NotImplementedError(f"{team_size} {DEVICE} devices are required for this test") + + if _is_dist_env(): + _dist_fused_ag_scale_quant(norm_mode) + else: + _run_torchrun_test( + "_dist_fused_ag_scale_quant", + norm_mode, + nproc=team_size, + ) diff --git a/mojo_opset/tests/accuracy/operators/test_normalization.py b/mojo_opset/tests/accuracy/operators/test_normalization.py index d44fd1929..cccf8b05f 100644 --- a/mojo_opset/tests/accuracy/operators/test_normalization.py +++ b/mojo_opset/tests/accuracy/operators/test_normalization.py @@ -16,6 +16,7 @@ from mojo_opset import MojoRMSNormQuant from mojo_opset.experimental import MojoChannelRMSNorm from mojo_opset.experimental import MojoGroupLayerNorm +from mojo_opset.experimental import MojoQKInplaceRMSNorm torch.manual_seed(43) @@ -334,6 +335,63 @@ def test_channel_rmsnorm(x, norm_size, channel_first, images): norm.forward_diff_with(norm_ref, x, atol=atol, rtol=rtol) +@pytest.mark.parametrize("tokens", [4, 17]) +@pytest.mark.parametrize("q_heads,kv_heads", [(2, 1), (4, 2)]) +@pytest.mark.parametrize("v_head_dim", [64, 128, 256]) +@pytest.mark.parametrize("dtype", dtypes) +@bypass_not_implemented +def test_qk_inplace_rmsnorm(tokens, q_heads, kv_heads, v_head_dim, dtype): + qk_head_dim = 128 + hidden_size = (q_heads + kv_heads) * qk_head_dim + kv_heads * v_head_dim + qkv = 
torch.randn(size=(tokens, hidden_size), dtype=dtype) + qkv_ref = qkv.clone() + qkv_before = qkv.clone() + + op = MojoQKInplaceRMSNorm( + q_heads=q_heads, + kv_heads=kv_heads, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, + device=qkv.device, + dtype=qkv.dtype, + ) + op_ref = ( + MojoQKInplaceRMSNorm._registry.get("torch")( + q_heads=q_heads, + kv_heads=kv_heads, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, + ) + .to(qkv.device) + .to(qkv.dtype) + ) + + with torch.no_grad(): + q_weight = torch.randn(size=(qk_head_dim,), dtype=torch.float32, device=qkv.device) + k_weight = torch.randn(size=(qk_head_dim,), dtype=torch.float32, device=qkv.device) + op.q_rms_weight.copy_(q_weight) + op.k_rms_weight.copy_(k_weight) + op_ref.q_rms_weight.copy_(q_weight) + op_ref.k_rms_weight.copy_(k_weight) + + qkv_ptr = qkv.data_ptr() + out = op(qkv) + out_ref = op_ref(qkv_ref) + + assert out.data_ptr() == qkv_ptr + assert qkv.data_ptr() == qkv_ptr + + q_end = q_heads * qk_head_dim + k_end = q_end + kv_heads * qk_head_dim + torch.testing.assert_close(out[:, k_end:], qkv_before[:, k_end:], atol=0, rtol=0) + + if dtype == torch.float16: + atol, rtol = 2e-2, 1e-2 + else: + atol, rtol = 5e-2, 1e-2 + torch.testing.assert_close(out.float(), out_ref.float(), atol=atol, rtol=rtol) + + # =========================================================================== # NormQuant tests # ===========================================================================