
Commit 7a21148

mix-placement
1 parent d252e36 commit 7a21148

File tree: 9 files changed, +388 −26 lines


vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ class AscendConfig:
 
     def __init__(self, vllm_config):
         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+        self.mix_placement = additional_config.get("mix_placement",False)
         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})
         self.torchair_graph_config = TorchairGraphConfig(

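Note: the new flag is read from vLLM's additional_config dictionary. A minimal usage sketch, assuming the standard offline LLM entry point; the model name is a placeholder and not part of this commit:

    # Hypothetical usage sketch: enabling mix_placement via additional_config.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",            # placeholder model, illustration only
        additional_config={"mix_placement": True},  # picked up in AscendConfig.__init__
    )
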
vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 18 additions & 4 deletions
@@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
             json.dump(record, f, indent=4)
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
-        self.expert_map_per_layer[layer_id] = updated_expert_map.clone()
-        self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()
+        pad_len = self.expert_map_per_layer[layer_id].shape[0] - updated_expert_map.shape[0]
+        updated_expert_map_padded = torch.nn.functional.pad(
+            updated_expert_map,
+            pad=(0,pad_len),
+            mode='constant',
+            value=-1
+        )
+        self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded)
+        self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
     def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                                 buffer_tensor_id):
         for expert_tensor, buffer_tensor in zip(
                 self.expert_param_per_layer[layer_id][local_expert_to_replace],
                 self.buffer_tensor_list[buffer_tensor_id]):
-            expert_tensor = buffer_tensor.clone()
+            expert_tensor.copy_(buffer_tensor)
             logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         if self.log2phy_map_per_layer[layer_id] is not None:
-            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
+            pad_len = self.log2phy_map_per_layer[layer_id].shape[0] - updated_log2phy_map.shape[0]
+            updated_log2phy_map_padded = torch.nn.functional.pad(
+                updated_log2phy_map,
+                pad=(0,pad_len),
+                mode='constant',
+                value=-1
+            )
+            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map_padded)
 
     def global2local(self, placement: torch.Tensor,
                      E_local: int) -> torch.Tensor:

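Note: the padding above right-pads the shorter updated map with -1 so it can be copied in place into the preallocated per-layer tensor. A standalone sketch of that behavior, with made-up tensor sizes and values:

    import torch

    # Illustration only: a preallocated map of length 8 and a shorter update of length 6.
    expert_map = torch.full((8,), -1, dtype=torch.int32)
    updated = torch.tensor([3, 0, -1, 1, 2, -1], dtype=torch.int32)

    pad_len = expert_map.shape[0] - updated.shape[0]
    updated_padded = torch.nn.functional.pad(updated, pad=(0, pad_len),
                                             mode='constant', value=-1)
    expert_map.copy_(updated_padded)   # in-place copy keeps the original storage
    print(updated_padded)              # tensor([ 3,  0, -1,  1,  2, -1, -1, -1], dtype=torch.int32)
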
vllm_ascend/eplb/core/eplb_device_transfer_loader.py

Lines changed: 0 additions & 4 deletions
@@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info,
             )
             return
 
-        # If neither send nor receive task is needed for this layer on this rank, return
-        if not (expert_send_info or expert_recv_info):
-            return
-
         self.updated_expert_map = updated_expert_map
 
         self.layer_id = layer_id

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 15 additions & 0 deletions
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
                    routed_scaling_factor=1.0,
                    e_score_correction_bias: Optional[torch.Tensor] = None,
                    indices_type: Optional[torch.dtype] = None,
+                   mix_placement: Optional[bool] = False,
+                   num_logical_experts: int = -1,
                    global_num_experts: int = -1):
     """
     Fused experts with select experts.
@@ -87,6 +89,19 @@
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts,
     )
+    if mix_placement:
+        pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
+                                           num_logical_experts,
+                                           dtype=topk_ids.dtype,
+                                           device=topk_ids.device)
+
+        pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
+                                               0.4,
+                                               dtype=topk_weights.dtype,
+                                               device=topk_weights.device)
+        topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
+        topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
+                                 dim=1)
     return topk_weights, topk_ids
 

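Note: the mix_placement branch above appends one extra routing column that always points at a trailing shared-expert slot (id num_logical_experts) with a fixed weight of 0.4. A standalone sketch of the shape change, using made-up routing values:

    import torch

    # Illustration only: 4 tokens, top-2 routing over 8 logical experts.
    num_logical_experts = 8
    topk_ids = torch.randint(0, num_logical_experts, (4, 2))
    topk_weights = torch.rand(4, 2)

    pad_ids = torch.full((topk_ids.shape[0], 1), num_logical_experts,
                         dtype=topk_ids.dtype, device=topk_ids.device)
    pad_weights = torch.full((topk_weights.shape[0], 1), 0.4,
                             dtype=topk_weights.dtype, device=topk_weights.device)

    topk_ids = torch.cat([topk_ids, pad_ids], dim=1)              # shape (4, 3), last column == 8
    topk_weights = torch.cat([topk_weights, pad_weights], dim=1)  # shape (4, 3), last column == 0.4
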
vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 30 additions & 16 deletions
@@ -170,10 +170,10 @@ def __init__(self, *args, **kwargs):
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
-        ascend_config = get_ascend_config()
-        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+        self.ascend_config = get_ascend_config()
+        self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+        self.expert_map_path = self.ascend_config.expert_map_path
+        self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
             moe_quant_params["intermediate_size_full"] = intermediate_size
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
 
         setup_moe_comm_method(self.moe_config)
         self.quant_type = self._get_quant_type()
@@ -275,7 +275,7 @@ def get_map(self):
         return self.expert_map
 
     def get_log2phy_map(self):
-        return self.logical_to_physical_map
+        return self.log2phy
 
     def clear_moe_load(self):
         if self.moe_load is not None:
@@ -428,8 +428,8 @@ def __init__(
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
-        ascend_config = get_ascend_config()
-        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.ascend_config = get_ascend_config()
+        self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -457,11 +457,19 @@ def forward(
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out, fused_out = AscendFusedMoE.forward(
-            self,
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        if self._shared_experts is None:
+            fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+            shared_out = None
+        else:
+            shared_out, fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
         return shared_out, fused_out
 
     def forward_impl(self, hidden_states: torch.Tensor,
@@ -475,7 +483,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
             # Use a separate stream to run shared experts.
             # Note that currently we only support calculations in separate streams with aclgraph.
             # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+            if self._shared_experts is None:
+                shared_out = None
+            else:
+                shared_out = self._shared_experts(hidden_states)
 
         fused_output = AscendFusedMoE.forward_impl(
             self,
@@ -490,6 +501,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
         if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                and not shared_expert_dp_enabled():
+                and not shared_expert_dp_enabled() and shared_out is not None:
             shared_out = tensor_model_parallel_all_reduce(shared_out)
-        return shared_out, fused_output
+        if shared_out is None:
+            return fused_output
+        else:
+            return shared_out, fused_output

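Note: with the last hunk, forward_impl now returns a (shared_out, fused_output) tuple only when shared experts exist, and just the fused output otherwise. A hedged sketch of how a caller might handle both return shapes; the call site and variable names here are illustrative, not taken from this commit:

    # Illustrative caller-side handling of the two possible return shapes.
    out = moe_layer.forward_impl(hidden_states, router_logits)
    if isinstance(out, tuple):
        shared_out, fused_out = out        # shared experts are configured
    else:
        shared_out, fused_out = None, out  # no shared experts on this layer
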
vllm_ascend/ops/fused_moe/moe_mlp.py

Lines changed: 3 additions & 1 deletion
@@ -112,14 +112,16 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
     if quantized_hidden_states is not None:
         dispose_tensor(quantized_hidden_states)
     # act_fn: swiglu
+    group_diff = torch.diff(group_list)
+    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
     hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
         x=hidden_states,
         weight_scale=w1_scale,
         activation_scale=pertoken_scale,
         bias=None,
         quant_scale=None,
         quant_offset=None,
-        group_index=group_list,
+        group_index=new_group,
         activate_left=True,
         quant_mode=1,
     )

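Note: the new group_index is built from consecutive differences of group_list, with the first difference repeated at the front. A tiny sketch of what that computation yields for a made-up 1-D group_list (values are illustrative only):

    import torch

    # Illustration only: a made-up group_list, e.g. running token counts per expert group.
    group_list = torch.tensor([3, 5, 9])

    group_diff = torch.diff(group_list)                                    # tensor([2, 4])
    new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0)
    print(new_group)                                                       # tensor([2, 2, 4])
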
vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -138,3 +138,4 @@
 # Future Plan:
 # Remove this patch when adapted vllm version contains the above PR.
 #
+from vllm_ascend.patch.worker import patch_deepseekv3
