From 4bb49aef7e572b3979c08e6f526189e4ed25fcc0 Mon Sep 17 00:00:00 2001 From: Vlad Tudorie Date: Wed, 20 May 2026 08:23:05 +0300 Subject: [PATCH] fix(load): wire MTP sanitize-preservation patch into VLM inference path MTP heads loaded onto VLM-path checkpoints (any model with vision_config, e.g. all Qwen3.6 PARO models) were silently stripped at load, leaving the MTPModule at random init -> 0% accept rate. Root cause: mlx-vlm's stock Model.sanitize() unconditionally drops every `mtp.*` key. The sanitize-preservation override (apply_mlx_vlm_mtp_patch in patches/mlx_vlm_mtp/qwen35_vlm_model.py) was only wired into the oQ quantization path, never the inference load path. maybe_apply_pre_load_patches called only apply_mlx_vlm_mtp_runtime_patch (attaches the MTPModule) but not the sanitize patch, so the runtime head was created then left unfilled. LLM-path checkpoints (no vision_config -> BatchedEngine) were unaffected because the mlx-lm sanitize patch (qwen35_model.py) already preserves MTP. Fix: call apply_mlx_vlm_mtp_patch() alongside apply_mlx_vlm_mtp_runtime_patch() in the mtp_enabled VLM branch. Verified on z-lab/Qwen3.6-27B-PARO + an injected BF16 MTP head: weights-with-mtp in the load dict went 0 -> 15, loaded params bit-match source (maxabsdiff=0.0), and accept rate went 0% -> 68.8%. --- omlx/utils/model_loading.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py index a7d3f277a..cfd6ff7da 100644 --- a/omlx/utils/model_loading.py +++ b/omlx/utils/model_loading.py @@ -158,11 +158,23 @@ def maybe_apply_pre_load_patches( if mtp_enabled: try: from ..patches.mlx_vlm_mtp import ( + apply_mlx_vlm_mtp_patch, apply_mlx_vlm_mtp_runtime_patch, ) except Exception: pass else: + # Sanitize-preservation patch MUST run too: the stock + # mlx-vlm Model.sanitize strips every ``mtp.*`` key, so + # without this the MTPModule loads at random init (0% + # accept). Previously only wired into the oQ path; needed + # on the inference load path as well for VLM checkpoints + # that ship MTP heads (e.g. PARO + injected guru87 head). + if apply_mlx_vlm_mtp_patch(): + logger.info( + "mlx-vlm MTP sanitize patch applied for %s", + model_name, + ) if apply_mlx_vlm_mtp_runtime_patch(): logger.info( "mlx-vlm runtime MTP patch applied for %s",