@@ -170,10 +170,10 @@ def __init__(self, *args, **kwargs):
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
-        ascend_config = get_ascend_config()
-        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+        self.ascend_config = get_ascend_config()
+        self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+        self.expert_map_path = self.ascend_config.expert_map_path
+        self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
             moe_quant_params["intermediate_size_full"] = intermediate_size
         self.quant_method.create_weights(layer=self, **moe_quant_params)

-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp

         setup_moe_comm_method(self.moe_config)
         self.quant_type = self._get_quant_type()
@@ -275,7 +275,7 @@ def get_map(self):
         return self.expert_map

     def get_log2phy_map(self):
-        return self.logical_to_physical_map
+        return self.log2phy

     def clear_moe_load(self):
         if self.moe_load is not None:
@@ -428,8 +428,8 @@ def __init__(
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
-        ascend_config = get_ascend_config()
-        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.ascend_config = get_ascend_config()
+        self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -457,11 +457,19 @@ def forward(
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out, fused_out = AscendFusedMoE.forward(
-            self,
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        if self._shared_experts is None:
+            fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+            shared_out = None
+        else:
+            shared_out, fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
         return shared_out, fused_out

     def forward_impl(self, hidden_states: torch.Tensor,
@@ -475,7 +483,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
             # Use a separate stream to run shared experts.
             # Note that currently we only support calculations in separate streams with aclgraph.
             # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+            if self._shared_experts is None:
+                shared_out = None
+            else:
+                shared_out = self._shared_experts(hidden_states)

         fused_output = AscendFusedMoE.forward_impl(
             self,
@@ -490,6 +501,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
         if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                and not shared_expert_dp_enabled():
+                and not shared_expert_dp_enabled() and shared_out is not None:
             shared_out = tensor_model_parallel_all_reduce(shared_out)
-        return shared_out, fused_output
+        if shared_out is None:
+            return fused_output
+        else:
+            return shared_out, fused_output
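
With this change, `forward` still returns a `(shared_out, fused_out)` tuple, but `shared_out` may be `None` when the layer was built without shared experts, and `forward_impl` returns only `fused_output` in that case. A minimal caller-side sketch of how the optional shared output could be handled; `combine_moe_outputs` and the additive merge of shared and routed outputs are illustrative assumptions, not code from this patch:

import torch


def combine_moe_outputs(moe_layer: torch.nn.Module,
                        hidden_states: torch.Tensor,
                        router_logits: torch.Tensor) -> torch.Tensor:
    # Assumed contract: the layer returns a tuple whose first element is None
    # when no shared experts are configured.
    shared_out, fused_out = moe_layer(hidden_states, router_logits)
    if shared_out is None:
        # No shared experts: the routed-expert output is already the result.
        return fused_out
    # Shared experts present: merge with the routed output (additive merge is
    # an assumption for illustration).
    return shared_out + fused_out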