Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlx_lm/models/nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ def __init__(self, config: ModelArgs):
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))

def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"):
# Router gate weights are tiny and path-dependent; keep them unquantized
# so per-path quantization configs can reference the gate without failing.
return self

def __call__(self, x):
return group_expert_select(
x @ self.weight.T,
Expand Down
51 changes: 51 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,57 @@ def test_gemma4_moe_router_quantizes_to_8bit(self):
)
self.assertEqual(config["quantization"]["bits"], 4)

def test_nemotron_h_moe_gate_quantization(self):
from mlx_lm.models import nemotron_h

args = nemotron_h.ModelArgs.from_dict(
{
"model_type": "nemotron_h",
"vocab_size": 128,
"hidden_size": 64,
"intermediate_size": 128,
"num_hidden_layers": 1,
"max_position_embeddings": 128,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"attention_bias": False,
"mamba_num_heads": 4,
"mamba_head_dim": 16,
"mamba_proj_bias": False,
"ssm_state_size": 16,
"conv_kernel": 3,
"n_groups": 1,
"mlp_bias": False,
"layer_norm_epsilon": 1e-5,
"use_bias": False,
"use_conv_bias": True,
"hybrid_override_pattern": ["E"],
"moe_intermediate_size": 64,
"n_routed_experts": 4,
"n_group": 1,
"topk_group": 1,
"num_experts_per_tok": 2,
"norm_topk_prob": True,
"routed_scaling_factor": 1.0,
}
)
model = nemotron_h.Model(args)
gate_path = "backbone.layers.0.mixer.gate"

def class_predicate(path, module):
# Mirror utils.py: a per-path quantization config entry yields a
# truthy result even for modules that cannot themselves be quantized.
if path == gate_path:
return {"group_size": 64, "bits": 4}
return hasattr(module, "to_quantized")

# Must not raise: MoEGate exposes a no-op to_quantized().
nn.quantize(model, group_size=64, bits=4, class_predicate=class_predicate)

gate = model.backbone.layers[0].mixer.gate
self.assertIsInstance(gate, nemotron_h.MoEGate)
self.assertIs(gate.to_quantized(), gate)

def test_qwen2_moe(self):
from mlx_lm.models import qwen2_moe

Expand Down