Support load fused moe weights #3672

Merged 1 commit on Jun 26, 2025
31 changes: 30 additions & 1 deletion lmdeploy/pytorch/models/qwen3_moe.py
@@ -489,6 +489,10 @@ def prepare_inputs_for_generation(
    def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
                             expert_params_mapping: List):
        """Load weight experts."""
        # load fused weights
        if any([k in name for k in ['fused_w1w3', 'fused_w2']]):
            return self._load_weight_fused_experts(name, loaded_weight, params_dict)

        for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
            if weight_name not in name:
                continue
@@ -500,6 +504,31 @@ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_di
            param = params_dict[name]
            load_weight(param, loaded_weight)

    def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter]):
        """Load weight of fused expert weights."""
        num_experts = self.config.num_experts
        fused_gateup_name = 'fused_w1w3'
        fused_down_name = 'fused_w2'
        if fused_gateup_name in name:
            chunk_size = loaded_weight.shape[0] // num_experts

            for expert_id in range(num_experts):
                param_name = name.replace(f'experts.{fused_gateup_name}', 'experts.gate_up')
                param = params_dict[param_name]
                w1 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size // 2)
                w3 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id + chunk_size // 2, length=chunk_size // 2)
                load_weight(param, w1, expert_id=expert_id, shard_id='gate')
                load_weight(param, w3, expert_id=expert_id, shard_id='up')

        elif fused_down_name in name:
            chunk_size = loaded_weight.shape[0] // num_experts

            for expert_id in range(num_experts):
                param_name = name.replace(f'experts.{fused_down_name}', 'experts.down')
                param = params_dict[param_name]
                w2 = loaded_weight.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size)
                load_weight(param, w2, expert_id=expert_id, shard_id='down')

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights."""
        # modify from vllm
@@ -529,7 +558,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                continue
            if self.config.tie_word_embeddings and 'lm_head.weight' in name:
                continue

            name = name.replace('.block_sparse_moe.', '.mlp.')
            if '.experts' in name:
                self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
            else:
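
For reference, here is a minimal standalone sketch (not part of the PR diff above) of the slicing that _load_weight_fused_experts applies to a fused gate/up checkpoint tensor: each expert owns one contiguous chunk along dim 0, with the gate projection (w1) stacked on top of the up projection (w3) inside that chunk. The sizes below are illustrative assumptions, not taken from any real Qwen3 MoE config.

import torch

# Illustrative sizes only (assumptions, not from the model config).
num_experts = 4
intermediate_size = 8
hidden_size = 16

# Fused gate/up weight: one [2 * intermediate_size, hidden_size] chunk per expert,
# stacked along dim 0; w1 (gate) fills the first half of each chunk, w3 (up) the second.
fused_w1w3 = torch.randn(num_experts * 2 * intermediate_size, hidden_size)

chunk_size = fused_w1w3.shape[0] // num_experts  # 2 * intermediate_size
for expert_id in range(num_experts):
    # Same narrow() arithmetic as _load_weight_fused_experts.
    w1 = fused_w1w3.narrow(dim=0, start=chunk_size * expert_id, length=chunk_size // 2)
    w3 = fused_w1w3.narrow(dim=0, start=chunk_size * expert_id + chunk_size // 2, length=chunk_size // 2)
    assert w1.shape == (intermediate_size, hidden_size)
    assert w3.shape == (intermediate_size, hidden_size)

The fused_w2 path in the PR is analogous, except each expert's chunk is taken whole (length=chunk_size) and loaded with shard_id='down'.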