Skip to content

Commit bb6028c

Browse files
authored
Fix MXFP4 quantizer to support variable num_local_experts and hidden_size (#41795)
Fix MXFP4 quantizer to support variable num_local_experts
1 parent 7935b86 commit bb6028c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -383,6 +383,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
383383

384384
state_dict = model.state_dict()
385385

386+
# Get num_local_experts and hidden_size from model config (falling back to the previously hard-coded 32 and 2880)
387+
num_local_experts = getattr(model.config, "num_local_experts", 32)
388+
hidden_size = getattr(model.config, "hidden_size", 2880)
389+
386390
for name, module in model.named_modules():
387391
if (
388392
isinstance(module, Mxfp4GptOssExperts)
@@ -392,7 +396,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
392396
state_dict[f"{name}.gate_up_proj_blocks"] = (
393397
module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
394398
.transpose(-1, -2)
395-
.reshape(32, -1, 90, 16)
399+
.reshape(num_local_experts, -1, 90, 16)
396400
)
397401
state_dict[f"{name}.gate_up_proj_scales"] = (
398402
module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
@@ -402,7 +406,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
402406
state_dict[f"{name}.down_proj_blocks"] = (
403407
module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
404408
.transpose(-1, -2)
405-
.reshape(32, 2880, 90, -1)
409+
.reshape(num_local_experts, hidden_size, 90, -1)
406410
)
407411
state_dict[f"{name}.down_proj_scales"] = (
408412
module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(

0 commit comments

Comments
 (0)