fix(load): wire MTP sanitize-preservation patch into VLM inference path #1320

Open
sangemaru wants to merge 1 commit into jundot:main from sangemaru:fix/vlm-mtp-sanitize-inference-load

@sangemaru

Problem

MTP heads shipped with VLM-path checkpoints (any model with vision_config → loads via VLMBatchedEngine, which includes all Qwen3.6 PARO models) are silently stripped at load time, leaving the MTPModule at random init → 0% accept rate.

The model still generates correct output (the base path is fine), but speculative decoding is dead weight: the MTP head runs, produces garbage drafts from random-init weights, and every draft is rejected.

Root cause

mlx-vlm's stock Model.sanitize() (mlx_vlm/models/qwen3_5/qwen3_5.py) unconditionally drops every mtp.* key:

# ignore mtp weights
weights = {key: value for key, value in weights.items() if "mtp." not in key}

oMLX already has a sanitize-preservation override for exactly this — apply_mlx_vlm_mtp_patch() in omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py, which keeps mtp.* keys and remaps them to language_model.mtp.*. But it was only wired into the oQ-quantization path, never the inference load path.
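
The difference between the two sanitize behaviors can be sketched on a plain dict. This is a minimal illustration, not the code in either library: `stock_sanitize` mirrors the filter quoted above, while `preserving_sanitize`, its prefix-based remap, and the full key names in `checkpoint` are assumptions based only on the description in this PR.

```python
def stock_sanitize(weights):
    # Same filter as the snippet quoted above: any key containing "mtp." is dropped.
    return {k: v for k, v in weights.items() if "mtp." not in k}


def preserving_sanitize(weights):
    # Hypothetical stand-in for the oMLX override: keep MTP keys and move
    # top-level "mtp.*" under "language_model.mtp.*". The real patch may
    # handle more cases than this.
    out = {}
    for key, value in weights.items():
        if key.startswith("mtp."):
            key = "language_model." + key
        out[key] = value
    return out


# Illustrative key names and placeholder values, not a real checkpoint.
checkpoint = {
    "language_model.model.embed_tokens.weight": "w0",
    "mtp.fc.weight": "w1",
    "mtp.layers.0.input_layernorm.weight": "w2",
}
stock = stock_sanitize(checkpoint)
kept = preserving_sanitize(checkpoint)
```

With the stock filter only the non-MTP key survives, so a head created at init time has nothing to load. With the preserving variant all three entries survive and the MTP ones sit under the `language_model.mtp.` prefix.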

maybe_apply_pre_load_patches() in omlx/utils/model_loading.py called only apply_mlx_vlm_mtp_runtime_patch() (which attaches the MTPModule to LanguageModel.__init__) — so the head module was created but then left unfilled because the stock sanitize stripped its weights before load_weights().

LLM-path checkpoints (no vision_config → BatchedEngine) were unaffected, because the mlx-lm sanitize patch (omlx/patches/mlx_lm_mtp/qwen35_model.py) already preserves MTP on that path. The bug only bit VLM-arch checkpoints that ship MTP heads.

Fix

Call apply_mlx_vlm_mtp_patch() alongside apply_mlx_vlm_mtp_runtime_patch() in the mtp_enabled VLM branch of maybe_apply_pre_load_patches(). Two-import + one-call addition; no behavior change for any path that wasn't already broken.
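
The shape of the change can be sketched with stand-in functions. Everything here is a simplified assumption: the two patch functions are stubs that only record that they ran, the `is_vlm` and `mtp_enabled` parameters are a hypothetical signature, and the call order is illustrative. The real `maybe_apply_pre_load_patches()` in omlx/utils/model_loading.py may take different arguments.

```python
applied = []  # records which patches ran, in order


def apply_mlx_vlm_mtp_runtime_patch():
    # Stub: the real function attaches the MTPModule in LanguageModel.__init__.
    applied.append("runtime")


def apply_mlx_vlm_mtp_patch():
    # Stub: the real function installs the sanitize-preservation override.
    applied.append("sanitize")


def maybe_apply_pre_load_patches(is_vlm, mtp_enabled):
    # Hypothetical signature, for illustration only.
    if is_vlm and mtp_enabled:
        apply_mlx_vlm_mtp_runtime_patch()
        apply_mlx_vlm_mtp_patch()  # the call this PR adds


maybe_apply_pre_load_patches(is_vlm=True, mtp_enabled=True)
```

Both patches are gated on the same condition, so paths that do not enable MTP on a VLM checkpoint never reach the new call.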

Verification

Tested on z-lab/Qwen3.6-27B-PARO (declares text_config.mtp_num_hidden_layers=1) with an injected BF16 MTP head (mtp.safetensors from guru87/Qwen3.6-27B-MTP, lifted verbatim from Qwen/Qwen3.6-27B).

Measurements were taken by instrumenting paroquant.inference.backends.mlx.load.load right before model.load_weights(...):

| Metric | Before fix | After fix |
| --- | --- | --- |
| mtp.* keys in the load dict | 0 | 15 |
| loaded MTP params vs source weights | random init (fc.weight std≈0.006, input_layernorm all-ones) | bit-exact (maxabsdiff=0.000000) |
| MTP accept rate (smoke) | 0% | 68.8% |
| 32K warm accept rate | | 61.5% |
| 64K warm accept rate | | 75.2% |
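
For readers who want to reproduce checks of this kind, the three quantities can be sketched in plain Python. This is not the instrumentation used for this PR: the helper names are hypothetical, tensors are represented as flat lists of floats instead of MLX arrays, and the accept rate is assumed to be accepted draft tokens divided by drafted tokens.

```python
def count_mtp_keys(weights):
    # Number of entries in the load dict whose key mentions "mtp.".
    return sum(1 for key in weights if "mtp." in key)


def max_abs_diff(loaded, source):
    # Largest element-wise absolute difference between two equal-length
    # flat sequences; 0.0 means the values match exactly.
    if len(loaded) != len(source):
        raise ValueError("length mismatch")
    return max((abs(a - b) for a, b in zip(loaded, source)), default=0.0)


def accept_rate(accepted, drafted):
    # Assumed definition: fraction of drafted tokens that were accepted.
    return accepted / drafted if drafted else 0.0


# Toy data with illustrative key names, not values from the real checkpoint.
load_dict = {"language_model.mtp.fc.weight": [0.25, -0.5], "lm_head.weight": [1.0]}
source = {"mtp.fc.weight": [0.25, -0.5]}
n_mtp = count_mtp_keys(load_dict)
diff = max_abs_diff(load_dict["language_model.mtp.fc.weight"], source["mtp.fc.weight"])
```

A key count of zero before `load_weights` is the quickest signal that sanitize removed the head, while a zero maximum difference confirms the loaded parameters match the source file.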

Log line added on success:

mlx-vlm MTP sanitize patch applied for /path/to/Qwen3.6-27B-PARO

Note (not part of this fix)

On M1 Max, MTP remains net-negative vs no-MTP for steady-state decode regardless of this fix, because the sampler runs serially per accepted token (tracked separately in #1311). This PR is strictly about correctness — making the MTP head actually load on the VLM path. Whether MTP is a win is a separate, hardware/runtime question.

Environment: oMLX 0.3.9.dev2, M1 Max 64 GB, macOS 25.4, mlx-lm 0.31.3, paroquant[mlx] 0.1.15.

MTP heads loaded onto VLM-path checkpoints (any model with vision_config,
e.g. all Qwen3.6 PARO models) were silently stripped at load, leaving the
MTPModule at random init -> 0% accept rate.

Root cause: mlx-vlm's stock Model.sanitize() unconditionally drops every
`mtp.*` key. The sanitize-preservation override (apply_mlx_vlm_mtp_patch
in patches/mlx_vlm_mtp/qwen35_vlm_model.py) was only wired into the oQ
quantization path, never the inference load path. maybe_apply_pre_load_patches
called only apply_mlx_vlm_mtp_runtime_patch (attaches the MTPModule) but not
the sanitize patch, so the runtime head was created then left unfilled.

LLM-path checkpoints (no vision_config -> BatchedEngine) were unaffected
because the mlx-lm sanitize patch (qwen35_model.py) already preserves MTP.

Fix: call apply_mlx_vlm_mtp_patch() alongside apply_mlx_vlm_mtp_runtime_patch()
in the mtp_enabled VLM branch.

Verified on z-lab/Qwen3.6-27B-PARO + an injected BF16 MTP head: weights-with-mtp
in the load dict went 0 -> 15, loaded params bit-match source (maxabsdiff=0.0),
and accept rate went 0% -> 68.8%.