Skip to content

Tongbowen/dpskv4 ascend#322

Open
bowentong-HW wants to merge 46 commits into
dev/dpskv4_ascendfrom
tongbowen/dpskv4_ascend
Open

Tongbowen/dpskv4 ascend#322
bowentong-HW wants to merge 46 commits into
dev/dpskv4_ascendfrom
tongbowen/dpskv4_ascend

Conversation

@bowentong-HW

Copy link
Copy Markdown
Collaborator

add deepseek v4 infer:

  1. Support EP8 inference;
  2. Support multi-batch inference;
  3. Support npugraph-ex;
  4. Integrate cann-receips operators;
  5. only support w8a8 int8

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for DeepSeek-V4 inference, adding several custom NPU operators, distributed execution capabilities, and graph mode compilation via torch.compile. The feedback highlights several areas for improvement, including the removal of hardcoded device indices and debug print statements, addressing performance bottlenecks caused by frequent host-device synchronization, and ensuring exception safety when modifying global PyTorch settings. Additionally, the reviewer pointed out redundant tensor initializations and operations that could be simplified for better efficiency.

layout_kv: str = 'PA_ND',
has_ori_kv: bool = True,
has_cmp_kv: bool = False,
device: str = 'npu:0',

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The device index is hardcoded to 'npu:0'. This will cause issues in multi-NPU environments where the current process might be assigned to a different device. Use f"npu:{torch.npu.current_device()}" or a device parameter passed from the model.

Suggested change
device: str = 'npu:0',
device: str = f"npu:{torch.npu.current_device()}",

Comment thread examples/llm_inference.py Outdated
input_ids = torch.full((len(encoded), max_len), pad_token_id, dtype=torch.long)
attention_mask = torch.zeros((len(encoded), max_len), dtype=torch.bool)
for idx, ids in enumerate(encoded):
flat = ids.squeeze(0).cpu()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving tensors to CPU within a loop during batch padding can significantly degrade performance due to frequent synchronization between the host and the NPU. It is more efficient to perform these operations on the device or pre-process the data before the inference loop.

Comment thread examples/llm_inference.py Outdated
Comment on lines +487 to +490
torch.set_default_dtype(torch.bfloat16)
with no_init_weights():
model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank)
torch.set_default_dtype(origin_dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Changing the global default dtype using torch.set_default_dtype is not exception-safe here. If an error occurs during model construction, the global state will remain altered, potentially affecting subsequent operations. Consider using a try...finally block or a context manager to ensure the original dtype is restored.

Suggested change
torch.set_default_dtype(torch.bfloat16)
with no_init_weights():
model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank)
torch.set_default_dtype(origin_dtype)
origin_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(torch.bfloat16)
with no_init_weights():
model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank)
finally:
torch.set_default_dtype(origin_dtype)

Comment on lines +46 to +50
query = query.clone().contiguous()
key = key.clone().contiguous()
weights = weights.clone().contiguous() if weights is not None else None
query_dequant_scale = query_dequant_scale.clone().contiguous()
key_dequant_scale = key_dequant_scale.clone().contiguous()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling .clone().contiguous() is redundant because contiguous() already returns a copy if the tensor is not already contiguous. If the tensor is already contiguous, clone() creates an unnecessary deep copy. Using just .contiguous() is more efficient.

Suggested change
query = query.clone().contiguous()
key = key.clone().contiguous()
weights = weights.clone().contiguous() if weights is not None else None
query_dequant_scale = query_dequant_scale.clone().contiguous()
key_dequant_scale = key_dequant_scale.clone().contiguous()
query = query.contiguous()
key = key.contiguous()
weights = weights.contiguous() if weights is not None else None
query_dequant_scale = query_dequant_scale.contiguous()
key_dequant_scale = key_dequant_scale.contiguous()

Comment thread mojo_opset/core/operators/hc_post.py Outdated
Returns:
Output tensor with the same shape as residual.
"""
print('qqqq')

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

A debug print statement print('qqqq') was left in the production code. Please remove it to keep the logs clean.

Comment thread mojo_opset/core/operators/indexer.py Outdated
Comment on lines +451 to +455
index_score = torch.zeros(
(batch_size, q_seq_len, k_seq_len),
dtype=torch.float32,
device=query.device,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The index_score tensor is initialized twice with the same parameters. The first initialization at line 445 is redundant.

Comment on lines +113 to +114
return y_out, expert_idx_out, norm_out_fp32
return y_out, expert_idx_out, norm_out_fp32

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if not out_flag check is redundant because both branches return the exact same values. The logic can be simplified to a single return statement.

Suggested change
return y_out, expert_idx_out, norm_out_fp32
return y_out, expert_idx_out, norm_out_fp32
return y_out, expert_idx_out, norm_out_fp32

if not (hasattr(torch, 'npu') and torch.npu.is_available()):
pytest.skip("NPU not available!")

torch.npu.set_device(13)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding the NPU device index to 13 will cause the test to fail on systems with fewer than 14 NPUs. Use torch.npu.set_device(0) or detect an available device dynamically.

Suggested change
torch.npu.set_device(13)
torch.npu.set_device(0)

@bowentong-HW bowentong-HW force-pushed the tongbowen/dpskv4_ascend branch from 101146e to 772adf9 Compare May 30, 2026 07:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants