mix-placement #4470

base: main
```diff
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
                    routed_scaling_factor=1.0,
                    e_score_correction_bias: Optional[torch.Tensor] = None,
                    indices_type: Optional[torch.dtype] = None,
+                   mix_placement: Optional[bool] = False,
+                   num_logical_experts: int = -1,
                    global_num_experts: int = -1):
     """
     Fused experts with select experts.
@@ -87,6 +89,19 @@ def select_experts(hidden_states: torch.Tensor,
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts,
     )
+    if mix_placement:
+        pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
+                                           num_logical_experts,
+                                           dtype=topk_ids.dtype,
+                                           device=topk_ids.device)
+        pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
+                                               0.4,
+                                               dtype=topk_weights.dtype,
+                                               device=topk_weights.device)
+        topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
+        topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
+                                 dim=1)
     return topk_weights, topk_ids
```

**Contributor** commented on lines +98 to +101:

> The value `0.4` is hardcoded here as the shared expert weight.
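The `mix_placement` padding above appends one extra "shared expert" slot to every token's top-k routing results. A minimal sketch of that transformation, with plain Python lists standing in for tensors (`torch.full` plus `torch.cat` along `dim=1` become per-row appends) and all values illustrative:

```python
# Sketch of the mix_placement padding in select_experts, with plain
# Python lists standing in for tensors. Values are illustrative only.
num_logical_experts = 8      # id assigned to the shared-expert slot
shared_expert_weight = 0.4   # the hardcoded weight from the PR

topk_ids = [[0, 2], [1, 3]]              # per-token routed expert ids
topk_weights = [[0.7, 0.3], [0.6, 0.4]]  # matching routing weights

# Append one shared-expert column to every token's top-k results,
# mirroring torch.cat([topk_ids, pad_shared_expert_ids], dim=1).
topk_ids = [row + [num_logical_experts] for row in topk_ids]
topk_weights = [row + [shared_expert_weight] for row in topk_weights]

print(topk_ids)      # [[0, 2, 8], [1, 3, 8]]
print(topk_weights)  # [[0.7, 0.3, 0.4], [0.6, 0.4, 0.4]]
```

After this padding, downstream fused-expert kernels see `top_k + 1` experts per token, with the shared expert addressed by the id `num_logical_experts`.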
```diff
@@ -112,14 +112,16 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
     if quantized_hidden_states is not None:
         dispose_tensor(quantized_hidden_states)
     # act_fn: swiglu
+    group_diff = torch.diff(group_list)
+    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
     hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
         x=hidden_states,
         weight_scale=w1_scale,
         activation_scale=pertoken_scale,
         bias=None,
         quant_scale=None,
         quant_offset=None,
-        group_index=group_list,
+        group_index=new_group,
         activate_left=True,
         quant_mode=1,
     )
```

**Contributor** commented on lines +115 to +116:

> The calculation of `new_group` is incorrect. `group_list` is a cumulative sum of group sizes, so `torch.diff(group_list)` recovers only the sizes of the second group onward; prepending `group_diff[0]` duplicates the second group's size instead of restoring the first group's size, which is `group_list[0]` itself. The correct way to get group sizes from a cumulative sum tensor is to keep the first cumulative value as the first size.
>
> Suggested change:
>
> ```diff
> -    group_diff = torch.diff(group_list)
> -    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
> +    new_group = torch.cat([group_list[:1], torch.diff(group_list)], dim=0)
> ```
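The cumulative-sum arithmetic behind this review comment can be checked with a few lines of plain Python standing in for `torch.diff` and `torch.cat`:

```python
# Why prepending group_diff[0] is wrong: group_list is a cumulative sum
# of group sizes, so diff() drops the first group's size entirely.
group_list = [3, 5, 9]  # cumsum of the true group sizes [3, 2, 4]

# Equivalent of torch.diff(group_list):
group_diff = [b - a for a, b in zip(group_list, group_list[1:])]  # [2, 4]

buggy = [group_diff[0]] + group_diff    # [2, 2, 4]: first size is wrong
correct = [group_list[0]] + group_diff  # [3, 2, 4]: matches true sizes

# Sanity check: the recovered sizes must sum back to the last cumsum.
assert sum(correct) == group_list[-1]
print(buggy, correct)
```

The same check fails for `buggy` whenever the first two groups differ in size, which is why the suggested change prepends `group_list[:1]` rather than `group_diff[0]`.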
**Contributor** commented:

> The in-place copy `self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)` will raise a `RuntimeError` if `updated_expert_map` has a different shape than `self.expert_map_per_layer_cpu[layer_id]`. The logic that pads `updated_expert_map` for the device tensor `self.expert_map_per_layer` suggests that shape mismatches are expected. The CPU-side map should be handled in a way that accommodates shape changes to avoid crashes; reassigning the tensor, as was done previously, is a safer approach.
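The distinction the reviewer draws, that an in-place copy requires matching shapes while rebinding the reference does not, can be sketched as follows. Here `copy_inplace` is a hypothetical stand-in that mimics the shape check of `torch.Tensor.copy_`, and plain lists replace the per-layer map tensors:

```python
# Sketch of the review point: in-place copy demands matching shapes,
# while rebinding accepts a new shape. copy_inplace is a hypothetical
# stand-in for torch.Tensor.copy_'s shape check; lists replace tensors.
def copy_inplace(dst, src):
    if len(dst) != len(src):  # mimic copy_'s shape-mismatch error
        raise RuntimeError("shape mismatch in copy_")
    dst[:] = src

expert_map_cpu = {0: [0, 1, 2, 3]}    # hypothetical per-layer CPU map
updated_expert_map = [0, 1, 2, 3, 4]  # map grew after a rebalance

try:
    copy_inplace(expert_map_cpu[0], updated_expert_map)
except RuntimeError:
    # Safer: rebind the reference, as the review recommends.
    expert_map_cpu[0] = updated_expert_map

print(expert_map_cpu[0])  # [0, 1, 2, 3, 4]
```

Rebinding does lose the in-place semantics (other references to the old tensor no longer see updates), so whichever approach is kept should be applied consistently to both the CPU-side and device-side maps.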
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)will raise aRuntimeErrorifupdated_expert_maphas a different shape thanself.expert_map_per_layer_cpu[layer_id]. The logic for paddingupdated_expert_mapfor the device tensorself.expert_map_per_layersuggests that shape mismatches are expected. The CPU-side map should be handled in a way that accommodates shape changes to avoid crashes. Reassigning the tensor, as was done previously, is a safer approach.