Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mlx_lm/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,9 @@ def make_cache(self):
return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers]

def sanitize(self, weights):
has_mtp_weights = any("mtp." in k for k in weights)
has_unsanitized_conv1d = any(
"conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items()
)
should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d
weights = {k: v for k, v in weights.items() if "mtp." not in k}

if self.args.tie_word_embeddings:
Expand All @@ -325,7 +323,7 @@ def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if should_shift_norm_weights and any(k.endswith(sfx) for sfx in norm_keys):
if has_unsanitized_conv1d and any(k.endswith(sfx) for sfx in norm_keys):
if v.ndim == 1:
weights[k] = v + 1.0
return weights
Expand Down
20 changes: 12 additions & 8 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,12 +615,11 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
"max_position_embeddings": 64,
}
hf_norm_key = "model.language_model.layers.0.input_layernorm.weight"
hf_conv_key = "model.language_model.layers.0.linear_attn.conv1d.weight"
mlx_norm_key = "language_model.model.layers.0.input_layernorm.weight"
mlx_mtp_key = "language_model.mtp.fc.weight"

for model_type, hf_mtp_key in (
("qwen3_5", "mtp.fc.weights"),
("qwen3_5_moe", "mtp.fc.weight"),
):
for model_type in ("qwen3_5", "qwen3_5_moe"):
module = importlib.import_module(f"mlx_lm.models.{model_type}")
args = module.ModelArgs.from_dict(
{
Expand All @@ -636,18 +635,23 @@ def test_qwen3_5_family_convert_then_load_norm_not_shift_twice(self):
converted = model.sanitize(
{
hf_norm_key: base,
hf_mtp_key: mx.zeros((1,), dtype=mx.float32),
hf_conv_key: mx.zeros((4, 1, 3), dtype=mx.float32),
}
)
self.assertIn(mlx_norm_key, converted)
self.assertTrue(mx.array_equal(converted[mlx_norm_key], base + 1.0))
self.assertFalse(any("mtp." in k for k in converted))

# Simulate load sanitize on already-converted keys.
loaded = model.sanitize(converted)
# Simulate load sanitize on already-converted keys with MTP weights.
loaded = model.sanitize(
{
mlx_norm_key: converted[mlx_norm_key],
mlx_mtp_key: mx.zeros((1,), dtype=mx.float32),
}
)
self.assertTrue(
mx.array_equal(loaded[mlx_norm_key], converted[mlx_norm_key])
)
self.assertFalse(any("mtp." in k for k in loaded))

def test_gemma4_convert_then_load_keeps_language_model_prefix(self):
from mlx_lm.models import gemma4
Expand Down