diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py index a7d3f277a..cfd6ff7da 100644 --- a/omlx/utils/model_loading.py +++ b/omlx/utils/model_loading.py @@ -158,11 +158,23 @@ def maybe_apply_pre_load_patches( if mtp_enabled: try: from ..patches.mlx_vlm_mtp import ( + apply_mlx_vlm_mtp_patch, apply_mlx_vlm_mtp_runtime_patch, ) except Exception: pass else: + # Sanitize-preservation patch MUST run too: the stock + # mlx-vlm Model.sanitize strips every ``mtp.*`` key, so + # without this the MTPModule loads at random init (0% + # accept). Previously only wired into the oQ path; needed + # on the inference load path as well for VLM checkpoints + # that ship MTP heads (e.g. PARO + injected guru87 head). + if apply_mlx_vlm_mtp_patch(): + logger.info( + "mlx-vlm MTP sanitize patch applied for %s", + model_name, + ) if apply_mlx_vlm_mtp_runtime_patch(): logger.info( "mlx-vlm runtime MTP patch applied for %s",