feat: add DeepSeek-V4 inference on Ascend with TP/EP/DP/CP parallel support#375
feat: add DeepSeek-V4 inference on Ascend with TP/EP/DP/CP parallel support#375Song-begin wants to merge 47 commits into
Conversation
…opset into tongbowen/dpskv4_ascend_multibatch_v1
# Conflicts: # examples/llm_inference.py # mojo_opset/modeling/deepseekv4/mojo_deepseek_v4.py
There was a problem hiding this comment.
Code Review
This pull request introduces AscendC backend support, custom operators (such as MojoCompressor, MojoHcPost, MojoHcPre, MojoScatterNdUpdateAsc, and MojoRMSNormDynamicQuant), and multi-node/multi-card LLM inference scripts optimized for DeepSeek-V4. Key feedback includes fixing a logic bug in llm_inference.py that builds the model twice when --transformers is enabled, vectorizing a loop in MojoScatterNdUpdateAsc to avoid performance-degrading host-device synchronizations, removing the evaluation-at-import anti-pattern of torch.npu.current_device() in a function signature, and replacing hardcoded NPU device indices (13) in the sparse attention tests to prevent failures on standard hardware.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| device: str = f"npu:{torch.npu.current_device()}", | ||
|
|
||
| ): | ||
| return torch.ops.custom.npu_sparse_attn_sharedkv_metadata( |
There was a problem hiding this comment.
Using torch.npu.current_device() as a default argument in the function signature is a Python anti-pattern. Default arguments are evaluated once at module import time, not at function execution time. This can cause errors if the NPU is not yet initialized when the module is imported, or it will incorrectly lock the default device to npu:0 (or the import-time active device) for all ranks in a multi-NPU environment.
| device: str = f"npu:{torch.npu.current_device()}", | |
| ): | |
| return torch.ops.custom.npu_sparse_attn_sharedkv_metadata( | |
| device: str = None, | |
| ): | |
| if device is None: | |
| device = f"npu:{torch.npu.current_device()}" | |
| return torch.ops.custom.npu_sparse_attn_sharedkv_metadata( |
| else: | ||
| model = build_model_from_hf( |
There was a problem hiding this comment.
When --transformers is enabled, the model is built on lines 1717-1723. However, because AutoModelForCausalLM does not have a load_weights attribute, the execution falls into the else block on line 1870, causing the model to be built a second time. This redundant model creation can lead to high memory usage or Out-Of-Memory (OOM) errors.
| else: | |
| model = build_model_from_hf( | |
| elif not args.transformers: | |
| model = build_model_from_hf( |
| for i in range(idx.numel()): | ||
| j = int(idx[i].item()) | ||
| if j >= 0: | ||
| var[j, :] = update[i, :] | ||
| return var |
There was a problem hiding this comment.
Using a Python loop with .item() on NPU tensors causes host-device synchronization at every iteration. For large inputs (e.g., u = 8192 in tests), this will result in thousands of syncs, severely degrading performance. This can be fully vectorized using PyTorch indexing, which runs entirely on the device and is orders of magnitude faster.
| for i in range(idx.numel()): | |
| j = int(idx[i].item()) | |
| if j >= 0: | |
| var[j, :] = update[i, :] | |
| return var | |
| idx = indices.to(dtype=torch.int64).view(-1) | |
| valid_mask = idx >= 0 | |
| if valid_mask.any(): | |
| var[idx[valid_mask]] = update[valid_mask] |
| def test_mojo_sparse_attn_sharedkv_metadata(): | ||
| if not (hasattr(torch, 'npu') and torch.npu.is_available()): | ||
| pytest.skip("NPU not available!") | ||
|
|
| layout_kv='PA_ND', | ||
| has_ori_kv=True, | ||
| has_cmp_kv=False, | ||
| device='npu:13' |
| if not (hasattr(torch, 'npu') and torch.npu.is_available()): | ||
| pytest.skip("NPU not available!") | ||
|
|
||
| torch.npu.set_device(13) |
| layout_kv='PA_ND', | ||
| has_ori_kv=True, | ||
| has_cmp_kv=False, | ||
| device='npu:13' |
Features