fix(load): wire MTP sanitize-preservation patch into VLM inference path #1320
Open
sangemaru wants to merge 1 commit into
MTP heads loaded onto VLM-path checkpoints (any model with vision_config, e.g. all Qwen3.6 PARO models) were silently stripped at load, leaving the MTPModule at random init -> 0% accept rate.

Root cause: mlx-vlm's stock Model.sanitize() unconditionally drops every `mtp.*` key. The sanitize-preservation override (apply_mlx_vlm_mtp_patch in patches/mlx_vlm_mtp/qwen35_vlm_model.py) was only wired into the oQ quantization path, never the inference load path. maybe_apply_pre_load_patches called only apply_mlx_vlm_mtp_runtime_patch (attaches the MTPModule) but not the sanitize patch, so the runtime head was created then left unfilled. LLM-path checkpoints (no vision_config -> BatchedEngine) were unaffected because the mlx-lm sanitize patch (qwen35_model.py) already preserves MTP.

Fix: call apply_mlx_vlm_mtp_patch() alongside apply_mlx_vlm_mtp_runtime_patch() in the mtp_enabled VLM branch.

Verified on z-lab/Qwen3.6-27B-PARO + an injected BF16 MTP head: weights-with-mtp in the load dict went 0 -> 15, loaded params bit-match source (maxabsdiff=0.0), and accept rate went 0% -> 68.8%.
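The difference between the two sanitize behaviours described in the commit message can be sketched on a toy weight dict. This is a minimal illustration only, not the mlx-vlm or oMLX source: the function names `stock_sanitize` and `preserving_sanitize`, the helper `count_mtp`, and the example key `mtp.norm.weight` are hypothetical, and plain Python lists stand in for tensors.

```python
# Toy stand-ins for illustration; NOT the real mlx-vlm or oMLX implementations.

def stock_sanitize(weights):
    """Mimic the described stock behaviour: drop every key starting with 'mtp.'."""
    return {k: v for k, v in weights.items() if not k.startswith("mtp.")}


def preserving_sanitize(weights):
    """Mimic the described override: keep 'mtp.*' keys, remapped to 'language_model.mtp.*'."""
    out = {}
    for key, value in weights.items():
        if key.startswith("mtp."):
            out["language_model." + key] = value
        else:
            out[key] = value
    return out


def count_mtp(weights):
    """Count keys that belong to the MTP head, with or without the remapped prefix."""
    return sum(1 for k in weights if k.startswith(("mtp.", "language_model.mtp.")))


# Hypothetical checkpoint contents (key names are assumptions, values are dummies).
checkpoint = {
    "language_model.model.embed_tokens.weight": [0.1, 0.2],
    "mtp.fc.weight": [0.3, 0.4],
    "mtp.norm.weight": [1.0, 1.0],
}

print(count_mtp(stock_sanitize(checkpoint)))       # 0: the head has nothing to load
print(count_mtp(preserving_sanitize(checkpoint)))  # 2: both MTP tensors survive
```

With the stock variant, a module that expects `language_model.mtp.*` parameters finds none in the load dict and keeps whatever it was initialised with, which matches the random-init symptom described here.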
Problem
MTP heads loaded onto VLM-path checkpoints (any model with `vision_config` → loads via `VLMBatchedEngine`, which includes all Qwen3.6 PARO models) are silently stripped at load, leaving the `MTPModule` at random init → 0% accept rate.

The model still generates correct output (the base path is fine), but speculative decoding is dead weight: the MTP head runs, produces garbage drafts from random-init weights, and every draft is rejected.
Root cause
mlx-vlm's stock `Model.sanitize()` (`mlx_vlm/models/qwen3_5/qwen3_5.py`) unconditionally drops every `mtp.*` key.

oMLX already has a sanitize-preservation override for exactly this: `apply_mlx_vlm_mtp_patch()` in `omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py`, which keeps `mtp.*` keys and remaps them to `language_model.mtp.*`. But it was only wired into the oQ-quantization path, never the inference load path. `maybe_apply_pre_load_patches()` in `omlx/utils/model_loading.py` called only `apply_mlx_vlm_mtp_runtime_patch()` (which attaches the `MTPModule` to `LanguageModel.__init__`), so the head module was created but then left unfilled because the stock sanitize stripped its weights before `load_weights()`.

LLM-path checkpoints (no `vision_config` → `BatchedEngine`) were unaffected, because the mlx-lm sanitize patch (`omlx/patches/mlx_lm_mtp/qwen35_model.py`) already preserves MTP on that path. The bug only bit VLM-arch checkpoints that ship MTP heads.

Fix
Call `apply_mlx_vlm_mtp_patch()` alongside `apply_mlx_vlm_mtp_runtime_patch()` in the `mtp_enabled` VLM branch of `maybe_apply_pre_load_patches()`. Two-import + one-call addition; no behavior change for any path that wasn't already broken.

Verification
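Tested on `z-lab/Qwen3.6-27B-PARO` (declares `text_config.mtp_num_hidden_layers=1`) with an injected BF16 MTP head (`mtp.safetensors` from `guru87/Qwen3.6-27B-MTP`, lifted verbatim from `Qwen/Qwen3.6-27B`).

Instrumented `paroquant.inference.backends.mlx.load.load` right before `model.load_weights(...)`:

- `mtp.*` keys in the load dict: 0 before the fix, 15 after
- Loaded MTP params: random init before the fix (`fc.weight` std≈0.006, `input_layernorm` all-ones), bit-match with the source after (`maxabsdiff=0.000000`)
- Accept rate: 0% before the fix, 68.8% after

A log line is added on success.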
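The two load-time checks reported above (counting MTP keys in the load dict and comparing loaded parameters against the source) could be reproduced with something like the following. This is a hedged sketch rather than the instrumentation actually used: `count_mtp_keys`, `max_abs_diff`, and the toy parameter values are hypothetical, and plain Python lists stand in for the real tensors.

```python
def count_mtp_keys(load_dict):
    """Count entries in the load dict that target the MTP head."""
    return sum(1 for name in load_dict if ".mtp." in name or name.startswith("mtp."))


def max_abs_diff(loaded, source):
    """Largest element-wise absolute difference over every parameter in `source`."""
    worst = 0.0
    for name, src_values in source.items():
        got = loaded[name]  # a KeyError here means the parameter never loaded
        if len(got) != len(src_values):
            raise ValueError(f"shape mismatch for {name}")
        for a, b in zip(got, src_values):
            worst = max(worst, abs(a - b))
    return worst


# Toy data: one source tensor, a faithful copy, and an unfilled stand-in.
source_head = {"language_model.mtp.fc.weight": [0.25, -0.5, 0.125]}
loaded_ok = {name: list(values) for name, values in source_head.items()}
loaded_unfilled = {"language_model.mtp.fc.weight": [0.0, 0.0, 0.0]}

print(count_mtp_keys(loaded_ok))                                    # 1
print(f"maxabsdiff={max_abs_diff(loaded_ok, source_head):.6f}")     # maxabsdiff=0.000000
print(f"maxabsdiff={max_abs_diff(loaded_unfilled, source_head):.6f}")  # maxabsdiff=0.500000
```

A result of exactly zero is a stricter signal than a small tolerance: it indicates the weights passed through loading without any cast or requantization, which is what "bit-match" implies.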
Note (not part of this fix)
On M1 Max, MTP remains net-negative vs no-MTP for steady-state decode regardless of this fix, because the sampler runs serially per accepted token (tracked separately in #1311). This PR is strictly about correctness — making the MTP head actually load on the VLM path. Whether MTP is a win is a separate, hardware/runtime question.
Environment: oMLX 0.3.9.dev2, M1 Max 64 GB, macOS 25.4, mlx-lm 0.31.3, paroquant[mlx] 0.1.15.