From a6d11e2081e78c33bfec76c9e35f914906012bc9 Mon Sep 17 00:00:00 2001 From: Che Ruan Date: Wed, 26 Nov 2025 19:14:10 +0800 Subject: [PATCH 1/3] mix-placement Signed-off-by: Che Ruan --- vllm_ascend/ascend_config.py | 1 + vllm_ascend/eplb/adaptor/vllm_adaptor.py | 22 +- .../eplb/core/eplb_device_transfer_loader.py | 4 - vllm_ascend/ops/fused_moe/experts_selector.py | 15 + vllm_ascend/ops/fused_moe/fused_moe.py | 46 ++- vllm_ascend/ops/fused_moe/moe_mlp.py | 4 +- vllm_ascend/patch/__init__.py | 1 + vllm_ascend/patch/worker/patch_deepseekv3.py | 319 ++++++++++++++++++ vllm_ascend/quantization/w8a8_dynamic.py | 2 +- 9 files changed, 388 insertions(+), 26 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_deepseekv3.py diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 16d16a4d7c8..72f9ca91965 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -34,6 +34,7 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} + self.mix_placement = additional_config.get("mix_placement",False) torchair_graph_config = additional_config.get("torchair_graph_config", {}) self.torchair_graph_config = TorchairGraphConfig( diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 726763013f4..1fb17c42fc8 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str): json.dump(record, f, indent=4) def do_update_expert_map(self, layer_id, updated_expert_map): - self.expert_map_per_layer[layer_id] = updated_expert_map.clone() - self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone() + pad_len = self.expert_map_per_layer[layer_id].shape[0] - updated_expert_map.shape[0] + updated_expert_map_padded = torch.nn.functional.pad( + updated_expert_map, + pad=(0,pad_len), + mode='constant', + value=-1 + ) + self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): for expert_tensor, buffer_tensor in zip( self.expert_param_per_layer[layer_id][local_expert_to_replace], self.buffer_tensor_list[buffer_tensor_id]): - expert_tensor = buffer_tensor.clone() + expert_tensor.copy_(buffer_tensor) logger.debug(f"Expert tensor shape is :{expert_tensor.shape}") def do_update_log2phy_map(self, layer_id, updated_log2phy_map): if self.log2phy_map_per_layer[layer_id] is not None: - self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map) + pad_len = self.log2phy_map_per_layer[layer_id].shape[0] - updated_log2phy_map.shape[0] + updated_log2phy_map_padded = torch.nn.functional.pad( + updated_log2phy_map, + pad=(0,pad_len), + mode='constant', + value=-1 + ) + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map_padded) def global2local(self, placement: torch.Tensor, E_local: int) -> torch.Tensor: diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index 5c676cddb8f..ce1c3d73325 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info, ) return - # If neither send nor receive task is needed for this 
layer on this rank, return - if not (expert_send_info or expert_recv_info): - return - self.updated_expert_map = updated_expert_map self.layer_id = layer_id diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index eb3fc848c8e..85f9a128ae1 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor, routed_scaling_factor=1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, + mix_placement: Optional[bool] = False, + num_logical_experts: int = -1, global_num_experts: int = -1): """ Fused experts with select experts. @@ -87,6 +89,19 @@ def select_experts(hidden_states: torch.Tensor, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts, ) + if mix_placement: + pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1), + num_logical_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + + pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1), + 0.4, + dtype=topk_weights.dtype, + device=topk_weights.device) + topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_shared_expert_weights], + dim=1) return topk_weights, topk_ids diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index b9667abbccb..5155165af0f 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -170,10 +170,10 @@ def __init__(self, *args, **kwargs): self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.expert_map_path = ascend_config.expert_map_path - self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.ascend_config = get_ascend_config() + self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path + self.expert_map_path = self.ascend_config.expert_map_path + self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert self.global_num_experts = num_experts + self.global_redundant_expert_num if self.custom_routing_function is None and self.e_score_correction_bias is not None: vllm_config = get_current_vllm_config() @@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp setup_moe_comm_method(self.moe_config) self.quant_type = self._get_quant_type() @@ -275,7 +275,7 @@ def get_map(self): return self.expert_map def get_log2phy_map(self): - return self.logical_to_physical_map + return self.log2phy def clear_moe_load(self): if self.moe_load is not None: @@ -428,8 +428,8 @@ def __init__( self._shared_experts = shared_experts self.use_overlapped = use_overlapped self.shared_expert_stream = None - ascend_config = get_ascend_config() - self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert + self.ascend_config = get_ascend_config() + self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert if enable_sp(): 
logger.info_once( "Sequence parallelism is enabled, shared experts are replicated for best performance." @@ -457,11 +457,19 @@ def forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - shared_out, fused_out = AscendFusedMoE.forward( - self, - hidden_states=hidden_states, - router_logits=router_logits, - ) + if self._shared_experts is None: + fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + shared_out = None + else: + shared_out, fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) return shared_out, fused_out def forward_impl(self, hidden_states: torch.Tensor, @@ -475,7 +483,10 @@ def forward_impl(self, hidden_states: torch.Tensor, # Use a separate stream to run shared experts. # Note that currently we only support calculations in separate streams with aclgraph. # Communication operations in another stream might cause unknown errors. - shared_out = self._shared_experts(hidden_states) + if self._shared_experts is None: + shared_out = None + else: + shared_out = self._shared_experts(hidden_states) fused_output = AscendFusedMoE.forward_impl( self, @@ -490,6 +501,9 @@ def forward_impl(self, hidden_states: torch.Tensor, forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ - and not shared_expert_dp_enabled(): + and not shared_expert_dp_enabled() and shared_out is not None: shared_out = tensor_model_parallel_all_reduce(shared_out) - return shared_out, fused_output + if shared_out is None: + return fused_output + else: + return shared_out, fused_output diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 07ba732f199..35920c7b569 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -112,6 +112,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) # act_fn: swiglu + group_diff = torch.diff(group_list) + new_group = torch.cat([group_diff[0].unsqueeze(0), group_diff], dim=0) hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( x=hidden_states, weight_scale=w1_scale, @@ -119,7 +121,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, bias=None, quant_scale=None, quant_offset=None, - group_index=group_list, + group_index=new_group, activate_left=True, quant_mode=1, ) diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 1b346de6744..83fcf6aa8c4 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -138,3 +138,4 @@ # Future Plan: # Remove this patch when adapted vllm version contains the above PR. 
# +from vllm_ascend.patch.worker import patch_deepseekv3 \ No newline at end of file diff --git a/vllm_ascend/patch/worker/patch_deepseekv3.py b/vllm_ascend/patch/worker/patch_deepseekv3.py new file mode 100644 index 00000000000..a32643be178 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_deepseekv3.py @@ -0,0 +1,319 @@ +import typing +from typing import Iterable + +import torch +import vllm +from torch import nn +from transformers import DeepseekV2Config, DeepseekV3Config +import torch.distributed as dist +from vllm.distributed import ( + get_ep_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) + +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from collections.abc import Callable, Iterable +# from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, get_spec_layer_idx_from_weight_name, DeepseekV2MLP, DeepseekV2MoE +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.config import ParallelConfig +from vllm.config import get_current_vllm_config +from vllm_ascend.ascend_config import get_ascend_config +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, + AscendSharedFusedMoE) + +class AscendDeepseekV2MoE(DeepseekV2MoE,nn.Module): + def __init__( + self, + config: DeepseekV2Config | DeepseekV3Config, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. 
+ eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + ascend_config = get_ascend_config() + mix_placement = getattr(ascend_config,"mix_placement",False) + if ( + config.n_shared_experts is None + or mix_placement + ): + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not mix_placement + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + ascend_config = get_ascend_config() + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. 
+ if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + return final_hidden_states.view(num_tokens, hidden_dim) + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + self.vllm_config=get_current_vllm_config() + ascend_config = get_ascend_config() + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + mix_placement = getattr(ascend_config, "mix_placement", False) + + + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + + ( + self.config.n_shared_experts + if mix_placement + else 0 + ), + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + + is_fuse_shared_experts_layer = ( + mix_placement + and ("mlp.shared_experts" in name) + ) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if ("mlp.experts." 
in name) and name not in params_dict: + continue + if is_fuse_shared_experts_layer: + continue + name_mapped = name.replace(weight_name, param_name) + + if (param_name == "fused_qkv_a_proj") and name_mapped not in params_dict: + continue + else: + name = name_mapped + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict.keys(): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + num_chunks = 1 + if is_fuse_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fuse_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[j * chunk_size : (j + 1) * chunk_size, :] + else: + weight_to_load = loaded_weight[:, j * chunk_size : (j + 1) * chunk_size] + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + is_expert_weight = True + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + if name_mapped not in params_dict.keys(): + continue + param = params_dict[name_mapped] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fuse_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + continue + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict.keys(): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if not is_fuse_shared_experts_layer: + loaded_params.add(name) + return loaded_params + + +vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = AscendDeepseekV2MoE +DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights \ No newline at end of file diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6b7d6b0875c..589b7519dee 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -238,7 +238,7 @@ def apply( hidden_states=x, pertoken_scale=pertoken_scale, w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale_fp32, + w1_scale=layer.w13_weight_scale.to(torch.float32), w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, From da75120aeb9408401026be9e1baebc7444363d4f Mon Sep 17 00:00:00 2001 From: Mercykid-bash Date: Mon, 1 Dec 2025 16:22:15 +0800 Subject: [PATCH 2/3] Update fused_moe.py --- vllm_ascend/ops/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 5155165af0f..e6413be5af7 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -190,8 +190,8 @@ def __init__(self, *args, **kwargs): self.expert_load_balancer = ExpertLoadBalancer( self.expert_map_path, num_experts) self.expert_load_balancer.check_expert_map_tensor() - self.global_redundant_expert_num = ( - self.expert_load_balancer.get_global_redundant_expert_num()) + # self.global_redundant_expert_num = ( + # self.expert_load_balancer.get_global_redundant_expert_num()) self.global_num_experts = num_experts + self.global_redundant_expert_num try: self.local_num_experts, self.expert_map = ( From 31921401de02d2b678f5e1241e4cbbd6edb2961b Mon Sep 17 00:00:00 2001 From: Mercykid-bash Date: Mon, 1 Dec 2025 16:53:06 +0800 Subject: [PATCH 3/3] Update patch_deepseekv3.py --- vllm_ascend/patch/worker/patch_deepseekv3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_deepseekv3.py b/vllm_ascend/patch/worker/patch_deepseekv3.py index a32643be178..59cfa76a633 100644 --- a/vllm_ascend/patch/worker/patch_deepseekv3.py +++ b/vllm_ascend/patch/worker/patch_deepseekv3.py @@ -175,7 +175,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - self.vllm_config=get_current_vllm_config() ascend_config = get_ascend_config() stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), @@ -316,4 +315,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = AscendDeepseekV2MoE -DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights \ No newline at end of file +DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights
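
Notes on the main mechanisms this series introduces, with small self-contained Python sketches.

Patch 1 reads a new `mix_placement` flag from `vllm_config.additional_config` in `AscendConfig.__init__`, defaulting to False. A minimal sketch of turning it on when building an engine; the `additional_config=` keyword on `LLM(...)` is an assumption based on how other vllm-ascend options (for example `torchair_graph_config`) are usually passed, not something this patch adds:

    # Sketch only: enable the flag that ascend_config.py reads.
    # Assumes vLLM's LLM entrypoint forwards `additional_config` to VllmConfig.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",            # any DeepSeek-V2/V3 MoE checkpoint
        additional_config={"mix_placement": True},  # fuse shared experts into the routed pool
    )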
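
In experts_selector.py, `mix_placement` appends one extra routing slot per token so the shared expert (placed at logical id `num_logical_experts`) is always selected with the fixed weight 0.4 that the patch hard-codes. The append step in isolation, as a standalone sketch:

    import torch

    def append_shared_expert(topk_ids, topk_weights, shared_expert_id, shared_weight=0.4):
        """Route every token to one extra (shared) expert slot.

        The 0.4 weight matches the constant used in the patch; whether it should
        be configurable is left open here.
        """
        n = topk_ids.shape[0]
        ids_col = torch.full((n, 1), shared_expert_id,
                             dtype=topk_ids.dtype, device=topk_ids.device)
        w_col = torch.full((n, 1), shared_weight,
                           dtype=topk_weights.dtype, device=topk_weights.device)
        return (torch.cat([topk_ids, ids_col], dim=1),
                torch.cat([topk_weights, w_col], dim=1))

    # usage: 4 tokens, top-2 routing over 8 routed experts, shared expert id = 8
    ids = torch.randint(0, 8, (4, 2))
    w = torch.rand(4, 2)
    ids, w = append_shared_expert(ids, w, shared_expert_id=8)
    print(ids.shape, w.shape)  # torch.Size([4, 3]) torch.Size([4, 3])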
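
vllm_adaptor.py now updates `expert_map_per_layer` and `log2phy_map_per_layer` in place: the incoming map is padded with -1 up to the stored tensor's length and written with `copy_` instead of rebinding the attribute, so the device tensor keeps its storage and fixed shape. The same pattern as a standalone helper, assuming the stored map is at least as long as the update:

    import torch

    def pad_and_copy_(dst: torch.Tensor, src: torch.Tensor, fill: int = -1) -> None:
        """Copy `src` into `dst` in place, padding the tail with `fill`.

        Mirrors the update pattern in do_update_expert_map / do_update_log2phy_map:
        `dst` keeps its storage, so existing references to it stay valid.
        Assumes dst.shape[0] >= src.shape[0].
        """
        pad_len = dst.shape[0] - src.shape[0]
        padded = torch.nn.functional.pad(src, pad=(0, pad_len),
                                         mode="constant", value=fill)
        dst.copy_(padded)

    # usage
    dst = torch.full((8,), -1, dtype=torch.int32)
    src = torch.tensor([3, 0, 5, 1], dtype=torch.int32)
    pad_and_copy_(dst, src)
    print(dst)  # tensor([ 3,  0,  5,  1, -1, -1, -1, -1], dtype=torch.int32)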
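
`do_update_expert_weight` changes `expert_tensor = buffer_tensor.clone()` to `expert_tensor.copy_(buffer_tensor)`. Rebinding the loop variable never touches the parameter storage the model actually reads; `copy_` writes into it. A toy demonstration of the difference:

    import torch

    params = [torch.zeros(4)]
    buf = torch.ones(4)

    # rebinding: the stored tensor is untouched
    for p in params:
        p = buf.clone()
    print(params[0])  # tensor([0., 0., 0., 0.])

    # in-place copy: the underlying storage is updated
    for p in params:
        p.copy_(buf)
    print(params[0])  # tensor([1., 1., 1., 1.])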
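
In patch_deepseekv3.py's load_weights, each `mlp.shared_experts.*` checkpoint tensor is split into `n_shared_experts` chunks (along dim 1 for down_proj, dim 0 otherwise) and renamed onto the extra routed-expert slots `mlp.experts.{n_routed_experts + j}`. A sketch of just that split/rename step with toy shapes (the real tensors come from the checkpoint shards):

    import torch

    def split_shared_expert(name, weight, n_routed_experts, n_shared_experts):
        """Yield (new_name, chunk) pairs that map fused shared-expert weights
        onto extra routed-expert slots, mirroring the mix_placement load path."""
        split_dim = 1 if "down_proj.weight" in name else 0
        total = weight.shape[split_dim]
        assert total % n_shared_experts == 0, "shared expert dim must split evenly"
        chunk = total // n_shared_experts
        for j in range(n_shared_experts):
            index = [slice(None)] * weight.dim()
            index[split_dim] = slice(j * chunk, (j + 1) * chunk)
            new_name = name.replace("mlp.shared_experts",
                                    f"mlp.experts.{n_routed_experts + j}")
            yield new_name, weight[tuple(index)]

    # usage with a toy gate_proj weight fused from 2 shared experts
    w = torch.randn(2 * 128, 64)  # (n_shared * intermediate, hidden), toy sizes
    for new_name, chunk in split_shared_expert(
            "model.layers.3.mlp.shared_experts.gate_proj.weight",
            w, n_routed_experts=8, n_shared_experts=2):
        print(new_name, tuple(chunk.shape))
    # model.layers.3.mlp.experts.8.gate_proj.weight (128, 64)
    # model.layers.3.mlp.experts.9.gate_proj.weight (128, 64)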