diff --git a/mlx_lm/models/nemotron_h.py b/mlx_lm/models/nemotron_h.py
index 353de36c9..3d4636253 100644
--- a/mlx_lm/models/nemotron_h.py
+++ b/mlx_lm/models/nemotron_h.py
@@ -357,6 +357,18 @@ def __init__(self, config: ModelArgs):
         self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
         self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
 
+    def to_quantized(
+        self, group_size: int = 64, bits: int = 4, mode: str = "affine", **kwargs
+    ):
+        """Return this gate unchanged; the router is intentionally not quantized.
+
+        Exposing this hook lets a per-path quantization config name the gate
+        without ``nn.quantize`` failing on a module it cannot quantize. Every
+        argument, including any extra keyword options a caller forwards, is
+        accepted and ignored.
+        """
+        return self
+
     def __call__(self, x):
         return group_expert_select(
             x @ self.weight.T,
diff --git a/tests/test_models.py b/tests/test_models.py
index 6e1fcd96e..ac971cc3b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -842,6 +842,59 @@ def test_gemma4_moe_router_quantizes_to_8bit(self):
         )
         self.assertEqual(config["quantization"]["bits"], 4)
 
+    def test_nemotron_h_moe_gate_quantization(self):
+        from mlx_lm.models import nemotron_h
+
+        args = nemotron_h.ModelArgs.from_dict(
+            {
+                "model_type": "nemotron_h",
+                "vocab_size": 128,
+                "hidden_size": 64,
+                "intermediate_size": 128,
+                "num_hidden_layers": 1,
+                "max_position_embeddings": 128,
+                "num_attention_heads": 4,
+                "num_key_value_heads": 2,
+                "attention_bias": False,
+                "mamba_num_heads": 4,
+                "mamba_head_dim": 16,
+                "mamba_proj_bias": False,
+                "ssm_state_size": 16,
+                "conv_kernel": 3,
+                "n_groups": 1,
+                "mlp_bias": False,
+                "layer_norm_epsilon": 1e-5,
+                "use_bias": False,
+                "use_conv_bias": True,
+                "hybrid_override_pattern": ["E"],
+                "moe_intermediate_size": 64,
+                "n_routed_experts": 4,
+                "n_group": 1,
+                "topk_group": 1,
+                "num_experts_per_tok": 2,
+                "norm_topk_prob": True,
+                "routed_scaling_factor": 1.0,
+            }
+        )
+        model = nemotron_h.Model(args)
+        gate_path = "backbone.layers.0.mixer.gate"
+
+        def class_predicate(path, module):
+            # Mirror utils.py: a per-path quantization config entry yields a
+            # truthy result even for modules that cannot themselves be quantized.
+            if path == gate_path:
+                return {"group_size": 64, "bits": 4}
+            return hasattr(module, "to_quantized")
+
+        # Must not raise: MoEGate exposes a no-op to_quantized().
+        nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)
+
+        gate = model.backbone.layers[0].mixer.gate
+        self.assertIsInstance(gate, nemotron_h.MoEGate)
+        self.assertIs(gate.to_quantized(), gate)
+        # Extra keyword options forwarded by a caller must be ignored too.
+        self.assertIs(gate.to_quantized(bits=8, quantize_input=True), gate)
+
     def test_qwen2_moe(self):
         from mlx_lm.models import qwen2_moe
 