@@ -383,6 +383,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
383383
384384 state_dict = model .state_dict ()
385385
386+ # Get num_local_experts from model config
387+ num_local_experts = getattr (model .config , "num_local_experts" , 32 )
388+ hidden_size = getattr (model .config , "hidden_size" , 2880 )
389+
386390 for name , module in model .named_modules ():
387391 if (
388392 isinstance (module , Mxfp4GptOssExperts )
@@ -392,7 +396,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
392396 state_dict [f"{ name } .gate_up_proj_blocks" ] = (
393397 module .gate_up_proj .storage .layout .unswizzle_data (module .gate_up_proj .storage .data )
394398 .transpose (- 1 , - 2 )
395- .reshape (32 , - 1 , 90 , 16 )
399+ .reshape (num_local_experts , - 1 , 90 , 16 )
396400 )
397401 state_dict [f"{ name } .gate_up_proj_scales" ] = (
398402 module .gate_up_proj_precision_config .weight_scale .storage .layout .unswizzle_data (
@@ -402,7 +406,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
402406 state_dict [f"{ name } .down_proj_blocks" ] = (
403407 module .down_proj .storage .layout .unswizzle_data (module .down_proj .storage .data )
404408 .transpose (- 1 , - 2 )
405- .reshape (32 , 2880 , 90 , - 1 )
409+ .reshape (num_local_experts , hidden_size , 90 , - 1 )
406410 )
407411 state_dict [f"{ name } .down_proj_scales" ] = (
408412 module .down_proj_precision_config .weight_scale .storage .layout .unswizzle_data (
0 commit comments