Tongbowen/dpskv4 ascend#322
Conversation
…opset into tongbowen/dpskv4_ascend_multibatch_v1
There was a problem hiding this comment.
Code Review
This pull request introduces support for DeepSeek-V4 inference, adding several custom NPU operators, distributed execution capabilities, and graph mode compilation via torch.compile. The feedback highlights several areas for improvement, including the removal of hardcoded device indices and debug print statements, addressing performance bottlenecks caused by frequent host-device synchronization, and ensuring exception safety when modifying global PyTorch settings. Additionally, the reviewer pointed out redundant tensor initializations and operations that could be simplified for better efficiency.
| layout_kv: str = 'PA_ND', | ||
| has_ori_kv: bool = True, | ||
| has_cmp_kv: bool = False, | ||
| device: str = 'npu:0', |
There was a problem hiding this comment.
The device index is hardcoded to 'npu:0'. This will cause issues in multi-NPU environments where the current process might be assigned to a different device. Use f"npu:{torch.npu.current_device()}" or a device parameter passed from the model.
| device: str = 'npu:0', | |
| device: str = f"npu:{torch.npu.current_device()}", |
| input_ids = torch.full((len(encoded), max_len), pad_token_id, dtype=torch.long) | ||
| attention_mask = torch.zeros((len(encoded), max_len), dtype=torch.bool) | ||
| for idx, ids in enumerate(encoded): | ||
| flat = ids.squeeze(0).cpu() |
There was a problem hiding this comment.
| torch.set_default_dtype(torch.bfloat16) | ||
| with no_init_weights(): | ||
| model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank) | ||
| torch.set_default_dtype(origin_dtype) |
There was a problem hiding this comment.
Changing the global default dtype using torch.set_default_dtype is not exception-safe here. If an error occurs during model construction, the global state will remain altered, potentially affecting subsequent operations. Consider using a try...finally block or a context manager to ensure the original dtype is restored.
| torch.set_default_dtype(torch.bfloat16) | |
| with no_init_weights(): | |
| model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank) | |
| torch.set_default_dtype(origin_dtype) | |
| origin_dtype = torch.get_default_dtype() | |
| try: | |
| torch.set_default_dtype(torch.bfloat16) | |
| with no_init_weights(): | |
| model = model_class(hf_config, num_layers=args.num_layers, ep_size=ep_size, ep_rank=ep_rank) | |
| finally: | |
| torch.set_default_dtype(origin_dtype) |
| query = query.clone().contiguous() | ||
| key = key.clone().contiguous() | ||
| weights = weights.clone().contiguous() if weights is not None else None | ||
| query_dequant_scale = query_dequant_scale.clone().contiguous() | ||
| key_dequant_scale = key_dequant_scale.clone().contiguous() |
There was a problem hiding this comment.
Calling .clone().contiguous() is redundant because contiguous() already returns a copy if the tensor is not already contiguous. If the tensor is already contiguous, clone() creates an unnecessary deep copy. Using just .contiguous() is more efficient.
| query = query.clone().contiguous() | |
| key = key.clone().contiguous() | |
| weights = weights.clone().contiguous() if weights is not None else None | |
| query_dequant_scale = query_dequant_scale.clone().contiguous() | |
| key_dequant_scale = key_dequant_scale.clone().contiguous() | |
| query = query.contiguous() | |
| key = key.contiguous() | |
| weights = weights.contiguous() if weights is not None else None | |
| query_dequant_scale = query_dequant_scale.contiguous() | |
| key_dequant_scale = key_dequant_scale.contiguous() |
| Returns: | ||
| Output tensor with the same shape as residual. | ||
| """ | ||
| print('qqqq') |
| index_score = torch.zeros( | ||
| (batch_size, q_seq_len, k_seq_len), | ||
| dtype=torch.float32, | ||
| device=query.device, | ||
| ) |
| return y_out, expert_idx_out, norm_out_fp32 | ||
| return y_out, expert_idx_out, norm_out_fp32 |
There was a problem hiding this comment.
| if not (hasattr(torch, 'npu') and torch.npu.is_available()): | ||
| pytest.skip("NPU not available!") | ||
|
|
||
| torch.npu.set_device(13) |
101146e to
772adf9
Compare
# Conflicts: # examples/llm_inference.py # mojo_opset/modeling/deepseekv4/mojo_deepseek_v4.py
add deepseek v4 infer: