diff --git a/mtplx/artifacts.py b/mtplx/artifacts.py
index 509872b..ed8d609 100644
--- a/mtplx/artifacts.py
+++ b/mtplx/artifacts.py
@@ -48,6 +48,59 @@ def is_mtp_key(key: str) -> bool:
     return any(text.startswith(prefix) for prefix in MTP_KEY_PREFIXES)
 
 
+def _num_mtp_layers(config: dict[str, Any]) -> int:
+    tcfg = text_config(config)
+    return int(
+        tcfg.get("mtp_num_hidden_layers")
+        or tcfg.get("num_nextn_predict_layers")
+        or config.get("num_nextn_predict_layers")
+        or 0
+    )
+
+
+def _moe_mtp_expected_key_set(config: dict[str, Any]) -> tuple[set[str], int, str]:
+    """Expected MTP keys for a Qwen3.5-MoE (bf16) MTP head.
+
+    The MoE MTP layer mirrors a Qwen3.5-MoE decoder layer: full self-attention
+    plus a SparseMoeBlock (router ``gate`` + ``num_experts`` per-expert MLPs and
+    a ``shared_expert`` / ``shared_expert_gate``). Experts are stored one tensor
+    per expert in the sidecar, so the expected count scales with ``num_experts``
+    instead of the dense head's fixed 15.
+    """
+    tcfg = text_config(config)
+    num_experts = int(tcfg.get("num_experts") or 0)
+    n_layers = max(_num_mtp_layers(config), 1)
+    keys: set[str] = {
+        "mtp.fc.weight",
+        "mtp.norm.weight",
+        "mtp.pre_fc_norm_embedding.weight",
+        "mtp.pre_fc_norm_hidden.weight",
+    }
+    for li in range(n_layers):
+        base = f"mtp.layers.{li}"
+        keys.update(
+            {
+                f"{base}.input_layernorm.weight",
+                f"{base}.post_attention_layernorm.weight",
+                f"{base}.self_attn.q_proj.weight",
+                f"{base}.self_attn.k_proj.weight",
+                f"{base}.self_attn.v_proj.weight",
+                f"{base}.self_attn.o_proj.weight",
+                f"{base}.self_attn.q_norm.weight",
+                f"{base}.self_attn.k_norm.weight",
+                f"{base}.mlp.gate.weight",
+                f"{base}.mlp.shared_expert.gate_proj.weight",
+                f"{base}.mlp.shared_expert.up_proj.weight",
+                f"{base}.mlp.shared_expert.down_proj.weight",
+                f"{base}.mlp.shared_expert_gate.weight",
+            }
+        )
+        for e in range(num_experts):
+            for proj in ("gate_proj", "up_proj", "down_proj"):
+                keys.add(f"{base}.mlp.experts.{e}.{proj}.weight")
+    return keys, len(keys), "bf16-moe"
+
+
 def _mtp_expected_key_set(
     config: dict[str, Any],
     *,
@@ -55,6 +108,8 @@ def _mtp_expected_key_set(
 ) -> tuple[set[str], int, str]:
     mtp_quant = config.get("mtplx_mtp_quantization", {})
     prequantized = isinstance(mtp_quant, dict) and bool(mtp_quant.get("prequantized"))
+    if not prequantized and int(text_config(config).get("num_experts") or 0) > 0:
+        return _moe_mtp_expected_key_set(config)
     quant_policy = str(mtp_quant.get("policy") or "") if isinstance(mtp_quant, dict) else ""
     normalized = {normalize_mtp_key(key) for key in keys}
     if prequantized and quant_policy == "all":
diff --git a/mtplx/mtp_patch.py b/mtplx/mtp_patch.py
index f5b82ba..c95d2aa 100644
--- a/mtplx/mtp_patch.py
+++ b/mtplx/mtp_patch.py
@@ -113,6 +113,61 @@ def predicate(path: str, module: Any):
     nn.quantize(mtp, class_predicate=predicate)
 
 
+def _stack_mtp_moe_experts(
+    weights: dict[str, Any],
+    config: dict[str, Any],
+) -> dict[str, Any]:
+    """Stack per-expert MTP MLP weights into SwitchGLU (``switch_mlp``) layout.
+
+    Qwen3.5-MoE ships an MoE MTP head whose experts are stored one tensor per
+    expert (``layers.{i}.mlp.experts.{e}.{proj}.{leaf}``), but mlx-lm's
+    ``SparseMoeBlock`` expects a single stacked tensor per projection
+    (``layers.{i}.mlp.switch_mlp.{proj}.{leaf}`` of shape ``[num_experts, ...]``).
+    This mirrors the stacking mlx-lm performs for the main decoder layers.
+
+    No-op for dense MTP heads (``num_experts <= 0`` or no per-expert keys), so
+    the existing Qwen3-Next dense path is unaffected.
+
+    ``weights`` is updated in place and returned. Every projection is checked
+    for a complete set of ``num_experts`` tensors before anything is popped, so
+    a ``KeyError`` naming the first absent tensor leaves ``weights`` untouched.
+    """
+    tcfg = text_config(config)
+    num_experts = int(tcfg.get("num_experts") or 0)
+    if num_experts <= 0:
+        return weights
+
+    import mlx.core as mx
+
+    # NOTE(review): this patch adds _num_mtp_layers to mtplx/artifacts.py only;
+    # confirm the name is already in scope in this module.
+    n_layers = max(_num_mtp_layers(config), 1)
+    groups: list[tuple[str, list[str]]] = []
+    for li in range(n_layers):
+        prefix = f"layers.{li}.mlp"
+        if f"{prefix}.experts.0.gate_proj.weight" not in weights:
+            continue
+        for proj in ("gate_proj", "up_proj", "down_proj"):
+            for leaf in ("weight", "scales", "biases"):
+                names = [
+                    f"{prefix}.experts.{e}.{proj}.{leaf}" for e in range(num_experts)
+                ]
+                if names[0] not in weights:
+                    continue
+                missing = [name for name in names if name not in weights]
+                if missing:
+                    raise KeyError(
+                        f"incomplete MoE MTP expert set: expected {num_experts} "
+                        f"tensors for {prefix}.experts.*.{proj}.{leaf}, "
+                        f"first missing is {missing[0]}"
+                    )
+                groups.append((f"{prefix}.switch_mlp.{proj}.{leaf}", names))
+    # Mutate only after every group validated, so a failure changes nothing.
+    for target, names in groups:
+        weights[target] = mx.stack([weights.pop(name) for name in names], axis=0)
+    return weights
+
+
 def _finalize_mtp_weights(
     raw_mtp: dict[str, Any],
     config: dict[str, Any],
@@ -137,7 +192,7 @@ def _finalize_mtp_weights(
             if value.ndim == 1 and any(key.endswith(suffix) for suffix in _RMSNORM_SUFFIXES):
                 if float(value.mean().item()) < 0.5:
                     weights[key] = value + 1.0
-        return weights
+        return _stack_mtp_moe_experts(weights, config)
 
     weights: dict[str, Any] = {}
     processed: set[str] = set()
@@ -165,7 +220,7 @@ def _finalize_mtp_weights(
         if value.ndim == 1 and any(key.endswith(suffix) for suffix in _RMSNORM_SUFFIXES):
             if float(value.mean().item()) < 0.5:
                 weights[key] = value + 1.0
-    return weights
+    return _stack_mtp_moe_experts(weights, config)
 
 
 def _strip_mtp_namespace(key: str) -> str:
diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py
index b63b612..d54b8bc 100644
--- a/tests/test_artifacts.py
+++ b/tests/test_artifacts.py
@@ -246,6 +246,66 @@ def test_all_prequantized_mtp_sidecar_accepts_quantized_fc_tensors(tmp_path):
     assert result.compatibility["recommended_profile"] == "performance-cold"
 
 
+def test_qwen3_5_moe_mtp_sidecar_passes_moe_tensor_gate(tmp_path):
+    num_experts = 4
+    (tmp_path / "config.json").write_text(
+        json.dumps(
+            {
+                "architectures": ["Qwen3_5MoeForConditionalGeneration"],
+                "model_type": "qwen3_5_moe",
+                "text_config": {
+                    "model_type": "qwen3_5_moe_text",
+                    "mtp_num_hidden_layers": 1,
+                    "num_experts": num_experts,
+                    "hidden_size": 256,
+                    "num_hidden_layers": 8,
+                    "vocab_size": 1000,
+                },
+                "mlx_lm_extra_tensors": {"mtp_file": "mtp.safetensors"},
+            }
+        )
+    )
+    base = "mtp.layers.0"
+    keys = {
+        "mtp.fc.weight",
+        "mtp.norm.weight",
+        "mtp.pre_fc_norm_embedding.weight",
+        "mtp.pre_fc_norm_hidden.weight",
+        f"{base}.input_layernorm.weight",
+        f"{base}.post_attention_layernorm.weight",
+        f"{base}.self_attn.q_proj.weight",
+        f"{base}.self_attn.k_proj.weight",
+        f"{base}.self_attn.v_proj.weight",
+        f"{base}.self_attn.o_proj.weight",
+        f"{base}.self_attn.q_norm.weight",
+        f"{base}.self_attn.k_norm.weight",
+        f"{base}.mlp.gate.weight",
+        f"{base}.mlp.shared_expert.gate_proj.weight",
+        f"{base}.mlp.shared_expert.up_proj.weight",
+        f"{base}.mlp.shared_expert.down_proj.weight",
+        f"{base}.mlp.shared_expert_gate.weight",
+    }
+    for e in range(num_experts):
+        for proj in ("gate_proj", "up_proj", "down_proj"):
+            keys.add(f"{base}.mlp.experts.{e}.{proj}.weight")
+    save_file(
+        {key: np.ones((1,), dtype=np.float32) for key in keys},
+        tmp_path / "mtp.safetensors",
+    )
+
+    result = inspect_model(tmp_path)
+    assert result.mtp is not None
+    assert result.mtp.sidecar_format == "bf16-moe"
+    # 4 global + 13 per-layer structural + 3 projections per expert.
+    assert result.mtp.tensor_count == 4 + 13 + 3 * num_experts
+    assert result.mtp.missing_expected_keys == ()
+    assert result.mtp.extra_keys == ()
+    assert result.mtp.passes_tensor_gate is True
+    assert result.passes_primary_gate is True
+    assert result.compatibility["can_run"] is True
+    assert result.compatibility["arch_id"] == "qwen3-next-mtp"
+
+
 def test_qwen_mtp_without_runtime_contract_is_family_runnable(monkeypatch, tmp_path):
     from mtplx import artifacts
     from mtplx.artifacts import MTPInspection
diff --git a/tests/test_mtp_patch.py b/tests/test_mtp_patch.py
index 3bc2779..e3b3e21 100644
--- a/tests/test_mtp_patch.py
+++ b/tests/test_mtp_patch.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from mtplx.mtp_patch import MTPContract
+from mtplx.mtp_patch import MTPContract, _stack_mtp_moe_experts
 
 
 def test_mtp_contract_reads_config_quant_defaults() -> None:
@@ -45,3 +45,50 @@ def test_mtp_contract_cli_bits_override_config_bits() -> None:
 def test_mtp_contract_rejects_unknown_quant_policy() -> None:
     with pytest.raises(ValueError, match="mtp_quant_policy"):
         MTPContract(mtp_quant_policy="mystery").validate()
+
+
+def test_stack_mtp_moe_experts_stacks_per_expert_into_switch_mlp() -> None:
+    mx = pytest.importorskip("mlx.core")
+    num_experts, out_dim, in_dim = 4, 6, 8
+    config = {"text_config": {"num_experts": num_experts, "mtp_num_hidden_layers": 1}}
+    weights = {"layers.0.input_layernorm.weight": mx.ones((in_dim,))}
+    for e in range(num_experts):
+        for proj, out in (("gate_proj", out_dim), ("up_proj", out_dim), ("down_proj", in_dim)):
+            cols = in_dim if proj != "down_proj" else out_dim
+            weights[f"layers.0.mlp.experts.{e}.{proj}.weight"] = mx.full((out, cols), float(e))
+
+    stacked = _stack_mtp_moe_experts(weights, config)
+
+    # Per-expert keys are consumed; stacked switch_mlp keys appear with a leading expert axis.
+    assert not any(".experts." in k for k in stacked)
+    assert tuple(stacked["layers.0.mlp.switch_mlp.gate_proj.weight"].shape) == (num_experts, out_dim, in_dim)
+    assert tuple(stacked["layers.0.mlp.switch_mlp.down_proj.weight"].shape) == (num_experts, in_dim, out_dim)
+    # Stacking preserves per-expert ordering (expert e was filled with value e).
+    gate = stacked["layers.0.mlp.switch_mlp.gate_proj.weight"]
+    for e in range(num_experts):
+        assert float(gate[e, 0, 0].item()) == float(e)
+    # Non-expert tensors pass through untouched.
+    assert "layers.0.input_layernorm.weight" in stacked
+
+
+def test_stack_mtp_moe_experts_is_noop_for_dense_head() -> None:
+    mx = pytest.importorskip("mlx.core")
+    config = {"text_config": {"num_experts": 0, "mtp_num_hidden_layers": 1}}
+    weights = {"layers.0.mlp.gate_proj.weight": mx.ones((4, 4))}
+    out = _stack_mtp_moe_experts(weights, config)
+    assert "layers.0.mlp.switch_mlp.gate_proj.weight" not in out
+    assert "layers.0.mlp.gate_proj.weight" in out
+
+
+def test_stack_mtp_moe_experts_rejects_incomplete_expert_set() -> None:
+    mx = pytest.importorskip("mlx.core")
+    config = {"text_config": {"num_experts": 2, "mtp_num_hidden_layers": 1}}
+    weights = {
+        f"layers.0.mlp.experts.0.{proj}.weight": mx.ones((2, 2))
+        for proj in ("gate_proj", "up_proj", "down_proj")
+    }
+    before = set(weights)
+    with pytest.raises(KeyError, match=r"experts\.1\.gate_proj\.weight"):
+        _stack_mtp_moe_experts(weights, config)
+    # Validation runs before any pop, so the mapping is left intact.
+    assert set(weights) == before