Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions mtplx/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,68 @@ def is_mtp_key(key: str) -> bool:
return any(text.startswith(prefix) for prefix in MTP_KEY_PREFIXES)


def _num_mtp_layers(config: dict[str, Any]) -> int:
    """Return the number of MTP layers declared by ``config``.

    The first truthy value found in the following order wins:

    1. ``mtp_num_hidden_layers`` in the mapping returned by ``text_config``,
    2. ``num_nextn_predict_layers`` in that same mapping,
    3. ``num_nextn_predict_layers`` in the top-level ``config``.

    Returns ``0`` when none of them holds a truthy value.
    """
    text_cfg = text_config(config)
    lookups = (
        (text_cfg, "mtp_num_hidden_layers"),
        (text_cfg, "num_nextn_predict_layers"),
        (config, "num_nextn_predict_layers"),
    )
    for mapping, field in lookups:
        declared = mapping.get(field)
        if declared:
            return int(declared)
    return 0


def _moe_mtp_expected_key_set(config: dict[str, Any]) -> tuple[set[str], int, str]:
    """Build the expected tensor keys for a Qwen3.5-MoE (bf16) MTP head.

    The MoE MTP layer mirrors a Qwen3.5-MoE decoder layer: full self-attention
    followed by a SparseMoeBlock (router ``gate``, ``num_experts`` per-expert
    MLPs, and a ``shared_expert`` / ``shared_expert_gate``). The sidecar stores
    one tensor per expert projection, so the expected count grows with
    ``num_experts`` rather than being the dense head's fixed 15.

    Returns:
        A ``(expected_keys, len(expected_keys), "bf16-moe")`` tuple.
    """
    global_keys = (
        "mtp.fc.weight",
        "mtp.norm.weight",
        "mtp.pre_fc_norm_embedding.weight",
        "mtp.pre_fc_norm_hidden.weight",
    )
    # Per-layer tensors that exist regardless of the expert count.
    layer_suffixes = (
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "self_attn.q_norm.weight",
        "self_attn.k_norm.weight",
        "mlp.gate.weight",
        "mlp.shared_expert.gate_proj.weight",
        "mlp.shared_expert.up_proj.weight",
        "mlp.shared_expert.down_proj.weight",
        "mlp.shared_expert_gate.weight",
    )
    expert_projections = ("gate_proj", "up_proj", "down_proj")

    expert_count = int(text_config(config).get("num_experts") or 0)
    # At least one MTP layer is always expected, even if the config says 0.
    layer_count = max(_num_mtp_layers(config), 1)

    expected: set[str] = set(global_keys)
    for layer_index in range(layer_count):
        layer_prefix = f"mtp.layers.{layer_index}"
        expected.update(f"{layer_prefix}.{suffix}" for suffix in layer_suffixes)
        expected.update(
            f"{layer_prefix}.mlp.experts.{expert}.{proj}.weight"
            for expert in range(expert_count)
            for proj in expert_projections
        )
    return expected, len(expected), "bf16-moe"


def _mtp_expected_key_set(
config: dict[str, Any],
*,
keys: tuple[str, ...] = (),
) -> tuple[set[str], int, str]:
mtp_quant = config.get("mtplx_mtp_quantization", {})
prequantized = isinstance(mtp_quant, dict) and bool(mtp_quant.get("prequantized"))
if not prequantized and int(text_config(config).get("num_experts") or 0) > 0:
return _moe_mtp_expected_key_set(config)
quant_policy = str(mtp_quant.get("policy") or "") if isinstance(mtp_quant, dict) else ""
normalized = {normalize_mtp_key(key) for key in keys}
if prequantized and quant_policy == "all":
Expand Down
43 changes: 41 additions & 2 deletions mtplx/mtp_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,45 @@ def predicate(path: str, module: Any):
nn.quantize(mtp, class_predicate=predicate)


def _stack_mtp_moe_experts(
    weights: dict[str, Any],
    config: dict[str, Any],
) -> dict[str, Any]:
    """Stack per-expert MTP MLP weights into SwitchGLU (``switch_mlp``) layout.

    Qwen3.5-MoE ships an MoE MTP head whose experts are stored one tensor per
    expert (``layers.{i}.mlp.experts.{e}.{proj}.{leaf}``), but mlx-lm's
    ``SparseMoeBlock`` expects a single stacked tensor per projection
    (``layers.{i}.mlp.switch_mlp.{proj}.{leaf}`` of shape ``[num_experts, ...]``).
    This mirrors the stacking mlx-lm performs for the main decoder layers.

    No-op for dense MTP heads (``num_experts <= 0`` or no per-expert keys), so
    the existing Qwen3-Next dense path is unaffected. ``mlx`` is only imported
    when there is actually something to stack.

    Args:
        weights: Mapping of MTP tensor keys to tensors. It is modified in
            place: consumed per-expert entries are removed and the stacked
            ``switch_mlp`` entries are added.
        config: Model config; ``num_experts`` is read from the mapping returned
            by ``text_config`` and the layer count from ``_num_mtp_layers``.

    Returns:
        The same ``weights`` dict that was passed in.

    Raises:
        KeyError: If expert 0 provides a ``{proj}.{leaf}`` tensor but another
            expert in ``range(num_experts)`` does not. The check runs before
            any tensor is moved, so ``weights`` is left untouched on failure.
    """
    tcfg = text_config(config)
    num_experts = int(tcfg.get("num_experts") or 0)
    if num_experts <= 0:
        return weights

    n_layers = max(_num_mtp_layers(config), 1)

    # Phase 1: work out every (target key, source keys) pair without mutating
    # anything, so an incomplete sidecar cannot leave ``weights`` half-stacked.
    plan: list[tuple[str, list[str]]] = []
    missing: list[str] = []
    for li in range(n_layers):
        prefix = f"layers.{li}.mlp"
        # Expert 0's gate_proj weight is the marker for a per-expert layer.
        if f"{prefix}.experts.0.gate_proj.weight" not in weights:
            continue
        for proj in ("gate_proj", "up_proj", "down_proj"):
            # "scales"/"biases" are only stacked when expert 0 carries them.
            for leaf in ("weight", "scales", "biases"):
                if f"{prefix}.experts.0.{proj}.{leaf}" not in weights:
                    continue
                sources = [
                    f"{prefix}.experts.{e}.{proj}.{leaf}" for e in range(num_experts)
                ]
                missing.extend(key for key in sources if key not in weights)
                plan.append((f"{prefix}.switch_mlp.{proj}.{leaf}", sources))

    if missing:
        preview = ", ".join(missing[:5])
        more = f" (+{len(missing) - 5} more)" if len(missing) > 5 else ""
        raise KeyError(
            f"MTP MoE weights are missing {len(missing)} per-expert tensor(s) "
            f"for num_experts={num_experts}: {preview}{more}"
        )
    if not plan:
        return weights

    import mlx.core as mx

    # Phase 2: every source key is known to exist, so popping cannot fail.
    for target, sources in plan:
        weights[target] = mx.stack([weights.pop(key) for key in sources], axis=0)
    return weights


def _finalize_mtp_weights(
raw_mtp: dict[str, Any],
config: dict[str, Any],
Expand All @@ -137,7 +176,7 @@ def _finalize_mtp_weights(
if value.ndim == 1 and any(key.endswith(suffix) for suffix in _RMSNORM_SUFFIXES):
if float(value.mean().item()) < 0.5:
weights[key] = value + 1.0
return weights
return _stack_mtp_moe_experts(weights, config)

weights: dict[str, Any] = {}
processed: set[str] = set()
Expand Down Expand Up @@ -165,7 +204,7 @@ def _finalize_mtp_weights(
if value.ndim == 1 and any(key.endswith(suffix) for suffix in _RMSNORM_SUFFIXES):
if float(value.mean().item()) < 0.5:
weights[key] = value + 1.0
return weights
return _stack_mtp_moe_experts(weights, config)


def _strip_mtp_namespace(key: str) -> str:
Expand Down
60 changes: 60 additions & 0 deletions tests/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,66 @@ def test_all_prequantized_mtp_sidecar_accepts_quantized_fc_tensors(tmp_path):
assert result.compatibility["recommended_profile"] == "performance-cold"


def test_qwen3_5_moe_mtp_sidecar_passes_moe_tensor_gate(tmp_path):
    """A complete per-expert Qwen3.5-MoE MTP sidecar is reported as ``bf16-moe``
    with no missing or extra keys and passes the tensor and primary gates."""
    num_experts = 4
    model_config = {
        "architectures": ["Qwen3_5MoeForConditionalGeneration"],
        "model_type": "qwen3_5_moe",
        "text_config": {
            "model_type": "qwen3_5_moe_text",
            "mtp_num_hidden_layers": 1,
            "num_experts": num_experts,
            "hidden_size": 256,
            "num_hidden_layers": 8,
            "vocab_size": 1000,
        },
        "mlx_lm_extra_tensors": {"mtp_file": "mtp.safetensors"},
    }
    (tmp_path / "config.json").write_text(json.dumps(model_config))

    global_keys = {
        "mtp.fc.weight",
        "mtp.norm.weight",
        "mtp.pre_fc_norm_embedding.weight",
        "mtp.pre_fc_norm_hidden.weight",
    }
    layer_prefix = "mtp.layers.0"
    structural_suffixes = (
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "self_attn.q_norm.weight",
        "self_attn.k_norm.weight",
        "mlp.gate.weight",
        "mlp.shared_expert.gate_proj.weight",
        "mlp.shared_expert.up_proj.weight",
        "mlp.shared_expert.down_proj.weight",
        "mlp.shared_expert_gate.weight",
    )
    layer_keys = {f"{layer_prefix}.{suffix}" for suffix in structural_suffixes}
    expert_keys = {
        f"{layer_prefix}.mlp.experts.{expert}.{proj}.weight"
        for expert in range(num_experts)
        for proj in ("gate_proj", "up_proj", "down_proj")
    }
    sidecar_keys = global_keys | layer_keys | expert_keys

    # Tensor contents are placeholders; the assertions below only look at keys.
    tensors = {key: np.ones((1,), dtype=np.float32) for key in sidecar_keys}
    save_file(tensors, tmp_path / "mtp.safetensors")

    result = inspect_model(tmp_path)
    assert result.mtp is not None
    assert result.mtp.sidecar_format == "bf16-moe"
    # 4 global + 13 per-layer structural + 3 projections per expert.
    expected_tensor_count = len(global_keys) + len(structural_suffixes) + 3 * num_experts
    assert expected_tensor_count == 4 + 13 + 3 * num_experts
    assert result.mtp.tensor_count == expected_tensor_count
    assert result.mtp.missing_expected_keys == ()
    assert result.mtp.extra_keys == ()
    assert result.mtp.passes_tensor_gate is True
    assert result.passes_primary_gate is True
    assert result.compatibility["can_run"] is True
    assert result.compatibility["arch_id"] == "qwen3-next-mtp"


def test_qwen_mtp_without_runtime_contract_is_family_runnable(monkeypatch, tmp_path):
from mtplx import artifacts
from mtplx.artifacts import MTPInspection
Expand Down
35 changes: 34 additions & 1 deletion tests/test_mtp_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from mtplx.mtp_patch import MTPContract
from mtplx.mtp_patch import MTPContract, _stack_mtp_moe_experts


def test_mtp_contract_reads_config_quant_defaults() -> None:
Expand Down Expand Up @@ -45,3 +45,36 @@ def test_mtp_contract_cli_bits_override_config_bits() -> None:
def test_mtp_contract_rejects_unknown_quant_policy() -> None:
    """An unrecognised ``mtp_quant_policy`` raises ``ValueError`` naming the field."""
    bad_contract_kwargs = {"mtp_quant_policy": "mystery"}
    # Construction stays inside the context so an error raised at either
    # construction or validation time is accepted, as before.
    with pytest.raises(ValueError, match="mtp_quant_policy"):
        MTPContract(**bad_contract_kwargs).validate()


def test_stack_mtp_moe_experts_stacks_per_expert_into_switch_mlp() -> None:
    """Per-expert tensors are replaced by ``switch_mlp`` tensors with a leading
    expert axis, in expert order, while unrelated tensors are left in place."""
    mx = pytest.importorskip("mlx.core")
    num_experts, out_dim, in_dim = 4, 6, 8
    config = {"text_config": {"num_experts": num_experts, "mtp_num_hidden_layers": 1}}
    # gate/up project in_dim -> out_dim; down projects back out_dim -> in_dim.
    expert_shapes = {
        "gate_proj": (out_dim, in_dim),
        "up_proj": (out_dim, in_dim),
        "down_proj": (in_dim, out_dim),
    }
    weights = {"layers.0.input_layernorm.weight": mx.ones((in_dim,))}
    for expert in range(num_experts):
        for proj, shape in expert_shapes.items():
            # Fill each expert with its own index so ordering can be verified.
            weights[f"layers.0.mlp.experts.{expert}.{proj}.weight"] = mx.full(
                shape, float(expert)
            )

    result = _stack_mtp_moe_experts(weights, config)

    # Per-expert keys are consumed; stacked switch_mlp keys appear with a leading expert axis.
    assert all(".experts." not in key for key in result)
    gate = result["layers.0.mlp.switch_mlp.gate_proj.weight"]
    down = result["layers.0.mlp.switch_mlp.down_proj.weight"]
    assert tuple(gate.shape) == (num_experts, out_dim, in_dim)
    assert tuple(down.shape) == (num_experts, in_dim, out_dim)
    # Stacking preserves per-expert ordering (expert e was filled with value e).
    for expert in range(num_experts):
        assert float(gate[expert, 0, 0].item()) == float(expert)
    # Non-expert tensors pass through untouched.
    assert "layers.0.input_layernorm.weight" in result


def test_stack_mtp_moe_experts_is_noop_for_dense_head() -> None:
    """With ``num_experts == 0`` the dense MLP key is kept and no
    ``switch_mlp`` key is introduced."""
    mx = pytest.importorskip("mlx.core")
    dense_key = "layers.0.mlp.gate_proj.weight"
    dense_config = {"text_config": {"num_experts": 0, "mtp_num_hidden_layers": 1}}
    dense_weights = {dense_key: mx.ones((4, 4))}

    result = _stack_mtp_moe_experts(dense_weights, dense_config)

    assert "layers.0.mlp.switch_mlp.gate_proj.weight" not in result
    assert dense_key in result