add deepep and test, modify test_moe_quant and test_attention_quant #349

Open
qiushi13 wants to merge 13 commits into master from heqiushi/moe_quant

qiushi13 (Collaborator) commented Jun 5, 2026

No description provided.

github-actions Bot commented Jun 5, 2026

Claude Code Review

Verdict: Request changes -- Combine path has a correctness bug in the per-rank slice of the global scatter and a likely-incorrect quant rounding rule.

Summary

Adds a torch reference implementation of DeepEP-style MoE dispatch/combine plus accuracy tests, and extends the quant MoE test to a distributed xops-vs-torch comparison. The torch backend is intended as a reference for the xops kernel.

Must fix

  • [BLOCKER] Wrong slice of global scatter in combine -- mojo_opset/core/operators/deepep.py:269-271 -- local_global_pos = global_scatter[offset : offset + q_len*top_k] assumes the global tensor is laid out as [rank0_tokens, rank1_tokens, ...]. The dispatch path constructs global_top_k via all_gather_into_tensor of top_k_indices (shape [q_len, top_k]), which gives that exact layout only if every rank has equal q_len and all_gather_into_tensor concatenates along dim 0. More importantly, global_scatter indexes into global_expand, which is ordered by per-rank r_per_rank slicing of the padded gather, not by token order. The two index spaces do not match, so combine will pull the wrong rows. Reconstruct the mapping consistently (e.g. compute pack_index/inverse the same way dispatch does) rather than slicing by rank * q_len * top_k.
  • [BLOCKER] Quant rounding does not implement round-half-to-even / standard rounding -- mojo_opset/core/operators/deepep.py:120-122 -- floor(|x| + 0.5) * sign(x) rounds half away from zero, whereas production int8 quant kernels typically round to nearest even (which is also what torch.round does). Since this is the reference the xops kernel is checked against, confirm the kernel's exact rounding rule and match it; otherwise the dispatch comparison will be off by one quantization step on exact .5 ties, which the atol=0 in the test will flag.
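
The rounding concern above can be illustrated without torch. The following is a minimal pure-Python sketch, not the PR's code: `round_half_away` mirrors the `floor(|x| + 0.5) * sign(x)` expression quoted in the review, and `round_half_even` uses Python's built-in `round()`, which (like `torch.round`) sends ties to the nearest even integer. Both function names are hypothetical.

```python
import math


def round_half_away(x: float) -> float:
    """floor(|x| + 0.5) * sign(x): exact .5 ties move away from zero."""
    return math.copysign(math.floor(abs(x) + 0.5), x)


def round_half_even(x: float) -> float:
    """Built-in round(): exact .5 ties go to the nearest even integer."""
    return float(round(x))


# Only exact ties can differ between the two rules.
ties = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]
away = [round_half_away(v) for v in ties]
even = [round_half_even(v) for v in ties]

# Inputs on which the two rules disagree by one integer step.
mismatches = [v for v, a, e in zip(ties, away, even) if a != e]
print(mismatches)
```

The two rules disagree on every other tie, so a bit-exact comparison between a reference and a kernel only holds if both use the same tie rule.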

Suggestions

  • [MAJOR] expand_scale placeholder is uninitialized memory -- mojo_opset/core/operators/deepep.py:124-126 -- torch.empty returns garbage; the test masks this with inf tolerance, but any caller reading it (or a future test without the mask) will see nondeterministic values. Use torch.zeros or document and return None.
  • [MAJOR] .item() syncs in a hot path -- mojo_opset/core/operators/deepep.py:108-109,247-251 -- Multiple int(cum[...].item()) calls force device-host syncs per call. Acceptable for a torch reference, but worth a comment so nobody copies this pattern into a real backend.
  • [MAJOR] Variable-size all_gather pads to max_r then trims -- mojo_opset/core/operators/deepep.py:240-260 -- This works but is O(world_size * max_r * hidden) memory; for the 384-expert/3072-hidden case in tests this is non-trivial. Consider all_gather_object of sizes + a single all_to_all_single instead, or at least note the cost.
  • [MAJOR] Distributed tests will hard-fail without NPU+hccl -- mojo_opset/tests/accuracy/operators/test_deepep.py:46-48 -- Skip predicates rely on torch.npu existing; on a non-NPU CI the torch.npu.device_count() call will raise before the skip. Guard with hasattr(torch, "npu").
  • [MINOR] output/output_scale/output_size args accepted but unused -- mojo_opset/core/operators/deepep.py:69-72 -- Either honor them (copy into provided buffers) or drop from the signature; silently ignoring user-provided output buffers is surprising.
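
For the skip-predicate suggestion above, one possible shape of the guard is sketched below. It assumes only what the review states (that `torch.npu.device_count()` is the call used by the skip predicate); the helper name `npu_device_count` and the flag `HAS_MULTI_NPU` are hypothetical and not part of the PR.

```python
import importlib


def npu_device_count() -> int:
    """Return the number of visible NPUs, or 0 if torch or torch.npu is unavailable."""
    try:
        torch = importlib.import_module("torch")
    except (ImportError, OSError):
        return 0
    # torch.npu only exists when an NPU plugin is installed, so look it up defensively.
    npu = getattr(torch, "npu", None)
    if npu is None:
        return 0
    try:
        return int(npu.device_count())
    except Exception:
        # A present but unusable runtime is treated the same as "no devices".
        return 0


# Evaluated once at import time; safe on machines without an NPU.
HAS_MULTI_NPU = npu_device_count() >= 2
```

A test module could then pass `not HAS_MULTI_NPU` to `pytest.mark.skipif`, so collection on a non-NPU CI runner skips the distributed tests instead of raising.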

Nits

  • [NIT] Non-English comment in test -- mojo_opset/tests/accuracy/operators/test_moe_quant.py:670 -- "非整除8时精度不对..." (roughly: "accuracy is wrong when not evenly divisible by 8...") should be English to match the rest of the codebase.
  • [NIT] Unrelated test-config addition -- mojo_opset/tests/accuracy/operators/test_attention_quant.py:528 -- Adding ("ABAB", 4, 2047) is unrelated to DeepEP; consider splitting.
  • [NIT] Layering -- mojo_opset/tests/accuracy/operators/test_deepep.py:50 -- from mojo_opset_ext.backends.xpu_ops... inside core tests couples accuracy tests to a specific backend module path; at minimum guard the import.

Notes

  • [CHECK] mojo_opset/core/operators/deepep.py:88-90 -- dist.all_gather_into_tensor requires the output tensor to be contiguous and to hold exactly world_size * input.numel() elements; confirm that the dtype of top_k_indices after the .to(torch.int32) cast is consistent across ranks (it is, but worth a sanity check for int64 callers).
  • [CHECK] mojo_opset/core/operators/moe.py:126 -- Adding top_k=self.top_k to the experts kwargs may break existing torch backend MojoQuantExperts signatures that don't accept it; verify all registered backends accept top_k.

gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces DeepEP-style cross-rank MoE dispatch and combine operators (MojoDeepEPDispatch and MojoDeepEPCombine), along with corresponding accuracy and distributed tests. It also updates the MoE operator to pass top_k to MojoQuantExperts and expands the test coverage for quantized MoE and attention operators. Feedback highlights a potential division-by-zero issue in the dispatch operator when calculating expand_scale and suggests avoiding monkey-patching _dispatch_up_proj_inv_smooth_scale in the tests.

Comment on lines +157 to +158
expand_scale = smoothed.abs().amax(-1, keepdim=True) / 127.0
x = smoothed / expand_scale

critical

There's a potential division-by-zero issue here. If a row in smoothed is all zeros, expand_scale will be zero, leading to 0.0 / 0.0 which results in NaN. This can propagate through the model and cause correctness issues.

You should use a safe division pattern to handle this case. Using torch.nan_to_num is a concise way to ensure that any NaN resulting from 0/0 is converted to 0, which is the correct behavior here.

Suggested change
- expand_scale = smoothed.abs().amax(-1, keepdim=True) / 127.0
- x = smoothed / expand_scale
+ expand_scale = smoothed.abs().amax(-1, keepdim=True) / 127.0
+ x = torch.nan_to_num(smoothed / expand_scale)
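
An alternative to converting NaN after the fact is to keep the scale away from zero before dividing, which is the `clamp_min(eps)` guard mentioned elsewhere in this review thread. The sketch below is a pure-Python stand-in for one row of the per-row int8 quantization, not the PR's code: `quantize_row` and `eps` are hypothetical names, and the rounding uses the `floor(|x| + 0.5) * sign(x)` rule quoted from the reference.

```python
import math


def quantize_row(row, eps=1e-12):
    """Per-row symmetric int8 quantization with a scale guarded against zero."""
    amax = max((abs(v) for v in row), default=0.0)
    # Plays the role of clamp_min(eps): an all-zero row gets scale=eps, not 0.
    scale = max(amax / 127.0, eps)
    # Round half away from zero, as in the reference expression.
    q = [int(math.copysign(math.floor(abs(v / scale) + 0.5), v)) for v in row]
    return q, scale


zero_q, zero_scale = quantize_row([0.0, 0.0, 0.0])
q, scale = quantize_row([254.0, -127.0, 1.0])
print(zero_q, zero_scale, q, scale)
```

With the guard, an all-zero row quantizes to all zeros and the returned scale stays finite and positive, so later dequantization (`q * scale`) is also well defined.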

).to(device)
op.load_state_dict(state_dict)
if is_xops_backend:
op._dispatch_up_proj_inv_smooth_scale = op.experts.up_proj_quantize.inv_smooth_scale

medium

You are monkey-patching the _dispatch_up_proj_inv_smooth_scale attribute onto the op instance. This attribute is not declared in the MojoQuantMoE class __init__ method, which can make the code harder to understand and maintain. It suggests a hidden dependency for the xops backend.

While this might be a necessary workaround for testing, consider a cleaner approach for passing this data to the operator. For example, you could pass it as a keyword argument to the __init__ method if the backend implementation supports it, or add a dedicated setter method on the operator. This would make the data flow more explicit.
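
A minimal sketch of the setter-based alternative is shown below. `QuantMoEOp` is a hypothetical stand-in, not the real `MojoQuantMoE`, and plain lists stand in for tensors; the point is only that the attribute is declared in `__init__` and written through one explicit entry point.

```python
from typing import Optional, Sequence


class QuantMoEOp:
    """Hypothetical stand-in for the operator under test."""

    def __init__(self, dispatch_inv_smooth_scale: Optional[Sequence[float]] = None):
        # Declared up front so the backend dependency is visible to readers.
        self._dispatch_inv_smooth_scale = (
            None if dispatch_inv_smooth_scale is None else list(dispatch_inv_smooth_scale)
        )

    def set_dispatch_inv_smooth_scale(self, scale: Sequence[float]) -> None:
        """Explicit setter for callers that build the scale after construction."""
        if len(scale) == 0:
            raise ValueError("scale must not be empty")
        self._dispatch_inv_smooth_scale = list(scale)

    @property
    def dispatch_inv_smooth_scale(self):
        return self._dispatch_inv_smooth_scale


op = QuantMoEOp()
op.set_dispatch_inv_smooth_scale([0.5, 0.25])
print(op.dispatch_inv_smooth_scale)
```

Whether the real operator should take the value in its constructor or through a setter depends on when the backend needs it, which the review does not specify.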

github-actions Bot commented Jun 8, 2026

Claude Code Review

Verdict: Request changes -- Torch reference for DeepEP combine has a global-scatter indexing bug that will break correctness for group_size > 1.

Summary

Adds MojoDeepEPDispatch / MojoDeepEPCombine operators with a torch reference implementation backed by torch.distributed, plus distributed accuracy tests for them and an extended quantized MoE EP test path. The torch reference is intended as a correctness oracle for the xops backend.

Must fix

  • [BLOCKER] Combine global scatter offset is wrong -- mojo_opset/core/operators/deepep.py:259-265 -- offset = self.rank * q_len * top_k slices global_scatter, which is the inverse of an argsort over global_flat (ordered by expert id). Positions in global_scatter follow the (token, k) input order, which is token-major across ranks (rank 0 tokens, then rank 1, ...) because all_gather_into_tensor concatenates rank-major, so the slice bounds themselves are fine. The values are the problem: global_scatter[i] is a position in a single global expert-stable sort, but global_expand is built by slicing each rank's expert-sorted output and concatenating the slices rank-major, which is not the same ordering. The gathered reduction will pick wrong rows. Reconstruct global_expand to match global_sort_perm order (sort by expert id globally), e.g. by computing per-(rank, local-expert) segment offsets and reordering, before indexing with global_scatter.
  • [BLOCKER] Dispatch int8 quant breaks on all-zero rows -- mojo_opset/core/operators/deepep.py:160-163 -- floor(|x| + 0.5) * sign(x) handles negatives correctly (values in (-0.5, 0) give 0 * -1, i.e. zero). The problem is an exactly-zero smoothed row: amax = 0 makes expand_scale zero, so smoothed / expand_scale is 0/0 = NaN. Guard expand_scale against zero (clamp_min(eps)). Separately, pick the tie rule deliberately: (x + 0.5 * sign(x)).trunc() rounds half away from zero, while torch.round rounds half to even, so confirm which one matches the kernel's documented "±1" tolerance.
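
The index-space argument in the combine blocker rests on one property: an inverse-of-argsort index only recovers the input order when the array it indexes is laid out in exactly that argsort order. The sketch below shows this on made-up data in pure Python. The names `sort_perm`, `scatter` and `expanded` loosely mirror `global_sort_perm`, `global_scatter` and `global_expand` from the review, but this is an illustration of the property, not a reproduction of the PR's code.

```python
def stable_argsort(keys):
    """Indices that sort keys ascending; equal keys keep their input order."""
    return sorted(range(len(keys)), key=lambda i: keys[i])


def inverse_permutation(perm):
    """inv[perm[j]] = j, i.e. where each input position lands after sorting."""
    inv = [0] * len(perm)
    for sorted_pos, input_pos in enumerate(perm):
        inv[input_pos] = sorted_pos
    return inv


# Expert id per (token, k) slot in token-major order (made-up values).
flat_experts = [2, 0, 3, 1, 1, 3, 0, 2]
rows = [f"slot{i}" for i in range(len(flat_experts))]

sort_perm = stable_argsort(flat_experts)
scatter = inverse_permutation(sort_perm)
expanded = [rows[i] for i in sort_perm]  # rows in global expert-sorted order

# Correct layout: indexing with the inverse recovers the input order.
recovered = [expanded[scatter[i]] for i in range(len(rows))]

# Any other layout of the same rows breaks the round trip.
misordered = expanded[::-1]
broken = [misordered[scatter[i]] for i in range(len(rows))]
print(recovered == rows, broken == rows)
```

An assertion of this round-trip form on a small multi-rank case would be one way to check whether the combine path's layout and index agree.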

Suggestions

  • [MAJOR] Dispatch returns int64 where the xops wrapper returns int32 -- mojo_opset/core/operators/deepep.py:175-176 -- expert_token_cnt_cumsum is int64 here, but the test notes that the xops wrapper returns int32; document this contract and pick one dtype to avoid downstream surprises.
  • [MAJOR] expand_scale is uninitialized when no smooth_scale -- mojo_opset/core/operators/deepep.py:166-168 -- torch.empty returns garbage; use torch.zeros or ones so consumers that read it unconditionally do not get NaNs.
  • [MAJOR] .item() calls force device sync on hot path -- mojo_opset/core/operators/deepep.py:148-149,236-241 -- multiple int(cum[...].item()) calls per forward will serialize against the collective; acceptable for a torch reference but worth a comment that this is reference-only.
  • [MAJOR] Layering: core tests import mojo_opset_ext.backends.xpu_ops -- mojo_opset/tests/accuracy/operators/test_deepep.py:46, mojo_opset/tests/accuracy/operators/test_moe_quant.py:294 -- tests in core reach into a backend-specific module; either move these tests under the backend's test tree or guard the import behind a backend check helper.
  • [MINOR] Unused parameters in dispatch forward -- mojo_opset/core/operators/deepep.py:84-86 -- output, output_scale, output_size are accepted but never used; either wire them through or drop them to avoid misleading callers.

Nits

  • [NIT] Non-ASCII em-dashes in docstrings -- mojo_opset/core/operators/deepep.py:39-58 -- mixed unicode punctuation in user-facing docs; harmless but inconsistent with the rest of the repo.
  • [NIT] Chinese comment in test file -- mojo_opset/tests/accuracy/operators/test_moe_quant.py:298 -- translate to English for consistency.
  • [NIT] result_queue.empty() race -- mojo_opset/tests/accuracy/operators/test_deepep.py:74-76, test_moe_quant.py:265-267 -- prefer draining by expected count (world_size) rather than empty() polling, which is unreliable.
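
For the `result_queue.empty()` nit, draining by expected count can look like the sketch below. It assumes only that the queue exposes `get(timeout=...)` and raises `queue.Empty` on timeout, which holds for both `queue.Queue` and `multiprocessing.Queue`; the helper name `drain_results` and the default timeout are hypothetical.

```python
import queue


def drain_results(result_queue, expected: int, timeout: float = 60.0):
    """Collect exactly `expected` results, failing loudly if a worker never reports."""
    results = []
    for _ in range(expected):
        try:
            # Blocks until an item arrives, so there is no empty() race.
            results.append(result_queue.get(timeout=timeout))
        except queue.Empty:
            raise RuntimeError(
                f"only {len(results)} of {expected} workers reported within {timeout}s"
            ) from None
    return results


# Usage with an in-process queue standing in for the test's result queue.
demo_queue = queue.Queue()
for rank in range(4):
    demo_queue.put((rank, "ok"))
collected = drain_results(demo_queue, expected=4, timeout=1.0)
print(collected)
```

Passing `world_size` as `expected` turns a missing worker result into an explicit failure rather than a silently shorter result list.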

Notes

  • [CHECK] The dispatch torch path uses dist.all_gather_into_tensor on top_k_indices cast to int32 -- confirm hccl supports int32 all-gather on the target platform; otherwise cast to int64 for the collective and back afterwards.
  • [CHECK] In combine world_size>1 branch, expert_token_count is treated as global (per-rank input contains global counts) -- the dispatch returns local expert_token_count (line 95), so feeding dispatch's output directly into combine on multi-rank will mis-reconstruct sorted_local_experts. Verify the expected contract end-to-end.
