Skip to content

feat: add DeepSeek-V4 inference on Ascend with TP/EP/DP/CP parallel support#375

Open
Song-begin wants to merge 47 commits into
masterfrom
tongbowen/dpskv4_ascend
Open

feat: add DeepSeek-V4 inference on Ascend with TP/EP/DP/CP parallel support#375
Song-begin wants to merge 47 commits into
masterfrom
tongbowen/dpskv4_ascend

Conversation

@Song-begin

Copy link
Copy Markdown
Collaborator

Features

  • support DeepSeek-V4 model inference on Ascend
  • support DeepSeek-V4 distributed inference with TP / EP / DP / CP
  • support DeepSeek-V4 multi-batch inference
  • support MTP for DeepSeek-V4
  • support graph mode / npugraph-ex execution for DeepSeek-V4
  • support single-node and multi-node inference deployment
  • support AscendC / torch_npu operators required by DeepSeek-V4

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces AscendC backend support, custom operators (such as MojoCompressor, MojoHcPost, MojoHcPre, MojoScatterNdUpdateAsc, and MojoRMSNormDynamicQuant), and multi-node/multi-card LLM inference scripts optimized for DeepSeek-V4. Key feedback includes fixing a logic bug in llm_inference.py that builds the model twice when --transformers is enabled, vectorizing a loop in MojoScatterNdUpdateAsc to avoid performance-degrading host-device synchronizations, removing the evaluation-at-import anti-pattern of torch.npu.current_device() in a function signature, and replacing hardcoded NPU device indices (13) in the sparse attention tests to prevent failures on standard hardware.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +120 to +123
device: str = f"npu:{torch.npu.current_device()}",

):
return torch.ops.custom.npu_sparse_attn_sharedkv_metadata(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using torch.npu.current_device() as a default argument in the function signature is a Python anti-pattern. Default arguments are evaluated once at module import time, not at function execution time. This can cause errors if the NPU is not yet initialized when the module is imported, or it will incorrectly lock the default device to npu:0 (or the import-time active device) for all ranks in a multi-NPU environment.

Suggested change
device: str = f"npu:{torch.npu.current_device()}",
):
return torch.ops.custom.npu_sparse_attn_sharedkv_metadata(
device: str = None,
):
if device is None:
device = f"npu:{torch.npu.current_device()}"
return torch.ops.custom.npu_sparse_attn_sharedkv_metadata(

Comment thread examples/llm_inference.py
Comment on lines +1870 to +1871
else:
model = build_model_from_hf(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When --transformers is enabled, the model is built on lines 1717-1723. However, because AutoModelForCausalLM does not have a load_weights attribute, the execution falls into the else block on line 1870, causing the model to be built a second time. This redundant model creation can lead to high memory usage or Out-Of-Memory (OOM) errors.

Suggested change
else:
model = build_model_from_hf(
elif not args.transformers:
model = build_model_from_hf(

Comment on lines +43 to +47
for i in range(idx.numel()):
j = int(idx[i].item())
if j >= 0:
var[j, :] = update[i, :]
return var

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using a Python loop with .item() on NPU tensors causes host-device synchronization at every iteration. For large inputs (e.g., u = 8192 in tests), this will result in thousands of syncs, severely degrading performance. This can be fully vectorized using PyTorch indexing, which runs entirely on the device and is orders of magnitude faster.

Suggested change
for i in range(idx.numel()):
j = int(idx[i].item())
if j >= 0:
var[j, :] = update[i, :]
return var
idx = indices.to(dtype=torch.int64).view(-1)
valid_mask = idx >= 0
if valid_mask.any():
var[idx[valid_mask]] = update[valid_mask]

def test_mojo_sparse_attn_sharedkv_metadata():
if not (hasattr(torch, 'npu') and torch.npu.is_available()):
pytest.skip("NPU not available!")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding NPU device index 13 will cause the test to fail on standard systems with fewer than 14 NPUs (e.g., standard 8-NPU or 1-NPU nodes). It is better to use device 0 or dynamically query the current device.

Suggested change
torch.npu.set_device(0)

layout_kv='PA_ND',
has_ori_kv=True,
has_cmp_kv=False,
device='npu:13'

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding NPU device index 13 will cause the test to fail on standard systems with fewer than 14 NPUs. It is better to use device 0 or dynamically query the current device.

Suggested change
device='npu:13'
device='npu:0'

if not (hasattr(torch, 'npu') and torch.npu.is_available()):
pytest.skip("NPU not available!")

torch.npu.set_device(13)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding NPU device index 13 will cause the test to fail on standard systems with fewer than 14 NPUs. It is better to use device 0 or dynamically query the current device.

Suggested change
torch.npu.set_device(13)
torch.npu.set_device(0)

layout_kv='PA_ND',
has_ori_kv=True,
has_cmp_kv=False,
device='npu:13'

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding NPU device index 13 will cause the test to fail on standard systems with fewer than 14 NPUs. It is better to use device 0 or dynamically query the current device.

Suggested change
device='npu:13'
device='npu:0'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants