From 0c5413cbe2ffdd2a9d4deb39d13aba972fd33e1a Mon Sep 17 00:00:00 2001 From: Jinyang Li Date: Wed, 13 May 2026 14:00:52 -0400 Subject: [PATCH] fix(load): VLM model loading fixes for oQ-quantized checkpoints - Expand per-layer quant keys for VLM model-tree paths so quantization config matches the MLX model parameter hierarchy - Centralise pre-load patches in oQ _measure_sensitivity - Remap nested visual keys (language_model.model.visual.* -> vision_tower.*) for MLX-format VLM models where mlx-vlm skips Model.sanitize - Fix nested-visual patch idempotency: use function-attribute marker instead of module-level flag - Add inline nested-visual post-fixup in MTP sanitize functions Co-Authored-By: Claude Opus 4.6 (1M context) --- omlx/engine/vlm.py | 64 ++++++- omlx/oq.py | 158 ++++++++---------- .../mlx_vlm_mtp/qwen35_moe_vlm_model.py | 3 + .../mlx_vlm_mtp/qwen35_moe_vlm_runtime.py | 3 + omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py | 3 + omlx/patches/qwen3_6_nested_visual.py | 19 +-- omlx/utils/model_loading.py | 59 +++++++ tests/test_oq.py | 51 ++++++ 8 files changed, 259 insertions(+), 101 deletions(-) diff --git a/omlx/engine/vlm.py b/omlx/engine/vlm.py index a4ccb1910..d5c22f76d 100644 --- a/omlx/engine/vlm.py +++ b/omlx/engine/vlm.py @@ -365,6 +365,11 @@ def _strip_audio_config_if_orphaned(model_dir: Path): def _patched(path, **kwargs): cfg = original(path, **kwargs) + + from ..utils.model_loading import expand_per_layer_quant_keys + + expand_per_layer_quant_keys(cfg) + if cfg.get("audio_config") is None: return cfg try: @@ -399,6 +404,62 @@ def _patched(path, **kwargs): _vu.load_config = original +_NESTED_VIS_PREFIX = "language_model.model.visual." +_VISION_TOWER_PREFIX = "vision_tower." + + +@contextlib.contextmanager +def _remap_nested_visual_on_load(model_dir: Path): + """Remap ``language_model.model.visual.*`` → ``vision_tower.*`` during + ``load_model`` for MLX-format models where sanitize is skipped. + + mlx-vlm's ``load_model`` skips ``Model.sanitize`` when the safetensors + metadata declares ``format=mlx``. oQ output is MLX-format, so the + nested-visual key fixup that sanitize normally applies never fires. + This context manager wraps ``load_model`` to intercept the weight dict + and perform the remap before ``nn.Module.load_weights`` is called. + + Scoped to a single ``vlm_load(...)`` call. + """ + import mlx_vlm.utils as _vu + + original_load_model = _vu.load_model + + def _patched_load_model(model_path, lazy=False, **kwargs): + import mlx.nn as _nn + + orig_load_weights = _nn.Module.load_weights + def _remapping_load_weights(self, weights_items, *args, **kw): + if isinstance(weights_items, str): + return orig_load_weights(self, weights_items, *args, **kw) + remapped = [] + n = 0 + for k, v in weights_items: + if k.startswith(_NESTED_VIS_PREFIX): + k = _VISION_TOWER_PREFIX + k[len(_NESTED_VIS_PREFIX):] + n += 1 + remapped.append((k, v)) + if n: + logger.info( + "remap_nested_visual_on_load: remapped %d keys " + "'language_model.model.visual.*' -> 'vision_tower.*'", + n, + ) + return orig_load_weights(self, remapped, *args, **kw) + + _nn.Module.load_weights = _remapping_load_weights + try: + return original_load_model(model_path, lazy, **kwargs) + finally: + _nn.Module.load_weights = orig_load_weights + + _vu.load_model = _patched_load_model + try: + yield + finally: + _vu.load_model = original_load_model + + # Models that only support a single image per request SINGLE_IMAGE_ONLY_MODELS = { "llava_next", @@ -605,7 +666,8 @@ async def start(self) -> None: def _load_vlm_sync(): _patch_video_processor_bug() _patch_torch_free_image_processor() - with _strip_audio_config_if_orphaned(Path(self._model_name)): + with _strip_audio_config_if_orphaned(Path(self._model_name)), \ + _remap_nested_visual_on_load(Path(self._model_name)): custom_loaded = maybe_load_custom_quantization( self._model_name, is_vlm=True, diff --git a/omlx/oq.py b/omlx/oq.py index 22df8fc10..6c273ff82 100644 --- a/omlx/oq.py +++ b/omlx/oq.py @@ -2257,48 +2257,12 @@ def quantize_oq_streaming( cb("loading", 12.0) - sanitize_fn = _build_model_sanitizer(config, text_only=text_only) - # When preserve_mtp is True, the patched sanitize functions - # (mlx_lm_mtp/qwen35_model.py and mlx_vlm_mtp/qwen35_vlm_model.py) - # keep mtp.* in the output and apply the +1 RMSNorm shift to MTP - # norms. No stash/merge wrapper needed — the patch covers both paths. - if sanitize_fn is not None: - try: - plan = _discover_sanitize_plan(sanitize_fn, all_weights) - all_weights = _DiscoveredPlan(plan, all_weights) - logger.info( - f"oQ{oq_level:g}: discovered streaming sanitize plan, " - f"{len(all_weights)} output tensors" - ) - except Exception as e: - if _model_exceeds_ram: - # Silent skip used to produce broken artifacts (see #1204): - # the source layout (e.g. fused experts.gate_up_proj) never - # got remapped to inference-expected names, and load failed - # with "Received N parameters not in model". Hard-fail so the - # caller sees the cause immediately. - raise RuntimeError( - f"oQ{oq_level:g}: streaming sanitize-plan discovery " - f"failed ({e}) and the eager fallback is unsafe with " - f"model size {_model_bytes / 1e9:.1f} GB exceeding " - f"{int(_MAX_MODEL_RAM_FRACTION * 100)}% of system RAM " - f"({_system_ram / 1e9:.1f} GB). Run on a machine with " - "enough RAM, or extend _TrackedTensor to cover the " - "indexing pattern the sanitize uses." - ) from e - logger.warning( - f"Streaming discovery failed ({e}), falling back to eager sanitize" - ) - try: - all_weights = sanitize_fn(all_weights) - logger.info(f"oQ{oq_level:g}: eager sanitize applied, {len(all_weights)} tensors") - except Exception as e2: - logger.warning(f"Sanitize failed ({e2}), using original names") - - config["_oq_non_quantizable"] = _build_non_quantizable_set(config) - - cb("loading", 15.0) - + # --- Sensitivity measurement (before sanitize-plan discovery) --------- + # Must run before _build_model_sanitizer + _discover_sanitize_plan, + # because the discovery pass feeds _TrackedTensor proxies through + # Model.sanitize which corrupts mutable state in the MTP sanitize + # patch (weights.pop on tracked objects). Running sensitivity first + # ensures vlm_load_model sees a pristine patch chain. if sensitivity_model_path: logger.info(f"oQ{oq_level:g}: measuring sensitivity via proxy model") sensitivity_map = _measure_sensitivity_from_quantized_model( @@ -2362,6 +2326,44 @@ def quantize_oq_streaming( "calibration data, or layer discovery), and either fix it or " "pass an explicit sensitivity_model_path." ) + + cb("loading", 15.0) + + # --- Sanitize-plan discovery ------------------------------------------ + sanitize_fn = _build_model_sanitizer(config, text_only=text_only) + # When preserve_mtp is True, the patched sanitize functions + # (mlx_lm_mtp/qwen35_model.py and mlx_vlm_mtp/qwen35_vlm_model.py) + # keep mtp.* in the output and apply the +1 RMSNorm shift to MTP + # norms. No stash/merge wrapper needed — the patch covers both paths. + if sanitize_fn is not None: + try: + plan = _discover_sanitize_plan(sanitize_fn, all_weights) + all_weights = _DiscoveredPlan(plan, all_weights) + logger.info( + f"oQ{oq_level:g}: discovered streaming sanitize plan, " + f"{len(all_weights)} output tensors" + ) + except Exception as e: + if _model_exceeds_ram: + raise RuntimeError( + f"oQ{oq_level:g}: streaming sanitize-plan discovery " + f"failed ({e}) and the eager fallback is unsafe with " + f"model size {_model_bytes / 1e9:.1f} GB exceeding " + f"{int(_MAX_MODEL_RAM_FRACTION * 100)}% of system RAM " + f"({_system_ram / 1e9:.1f} GB). Run on a machine with " + "enough RAM, or extend _TrackedTensor to cover the " + "indexing pattern the sanitize uses." + ) from e + logger.warning( + f"Streaming discovery failed ({e}), falling back to eager sanitize" + ) + try: + all_weights = sanitize_fn(all_weights) + logger.info(f"oQ{oq_level:g}: eager sanitize applied, {len(all_weights)} tensors") + except Exception as e2: + logger.warning(f"Sanitize failed ({e2}), using original names") + + config["_oq_non_quantizable"] = _build_non_quantizable_set(config) config["_oq_sensitivity_map"] = { str(k): v for k, v in sensitivity_map.items() } @@ -2645,7 +2647,7 @@ def quantize_oq_streaming( "c4": "C4 (Web Crawl)", "code": "Code (StarCoder)", "multilingual": "Multilingual (CulturaX)", - "code_multilingual": "Code + Multilingual", + "code_multilingual": "Code + Multilingual + Reasoning", } @@ -2712,7 +2714,7 @@ def _load_builtin_calibration(tokenizer, dataset: str, num_samples: int, if dataset == "code_multilingual": texts = [] - for key in ("code", "en", "ko", "zh", "ja", "tool_calling"): + for key in ("code", "en", "ko", "zh", "ja", "tool_calling", "reasoning"): texts.extend(all_data.get(key, [])) elif dataset == "code": texts = all_data.get("code", []) + all_data.get("en", []) @@ -3056,57 +3058,35 @@ def _measure_sensitivity( num_samples=32, seq_length=256, ): """Measure sensitivity by loading model temporarily. Used by streaming path.""" - is_vlm = "vision_config" in config + from omlx.utils.model_loading import maybe_apply_pre_load_patches - # Apply the same MTP runtime patches that production load and the - # main quantize path use. Sanitize patches are already global (from - # _build_model_sanitizer above), so loaded weights arrive with - # ``language_model.mtp.*`` keys; without the runtime patch the - # mlx-vlm LanguageModel.__init__ never attaches ``self.mtp`` and - # load_weights rejects the MTP tensors with "parameters not in model". - try: - from omlx.patches.mlx_lm_mtp import ( - apply_mlx_lm_mtp_patch, - is_mtp_active, - set_mtp_active, - ) - _have_lm_patch = apply_mlx_lm_mtp_patch() - except Exception: - _have_lm_patch = False - is_mtp_active = None - set_mtp_active = None + # Reuse the centralised pre-load dispatch so every current and future + # patch (MTP, DeepSeek V4, nested-visual, load_config, …) is applied + # exactly as in the production load path. model_settings is not + # passed — sensitivity runs on the *source* checkpoint which may not + # have MTP weights yet; maybe_apply_pre_load_patches without settings + # installs patches for sanitize correctness but leaves mtp_active + # False so Model.__init__ won't try to attach a missing MTP head. + maybe_apply_pre_load_patches(model_path) - if is_vlm: - try: - from omlx.patches.mlx_vlm_mtp import apply_mlx_vlm_mtp_runtime_patch - apply_mlx_vlm_mtp_runtime_patch() - except Exception as e: - logger.debug(f"mlx-vlm runtime MTP patch skipped: {e}") - - prev_active = is_mtp_active() if _have_lm_patch else False + is_vlm = "vision_config" in config try: - if _have_lm_patch: - set_mtp_active(True) - try: - if is_vlm: - from mlx_vlm.utils import load_model as vlm_load_model + if is_vlm: + from mlx_vlm.utils import load_model as vlm_load_model - model = vlm_load_model(Path(model_path), lazy=True) - from mlx_lm import load as lm_load + model = vlm_load_model(Path(model_path), lazy=True) + from mlx_lm.tokenizer_utils import load as load_tokenizer - _, tokenizer = lm_load(model_path, lazy=True) - else: - from mlx_lm import load as lm_load + tokenizer = load_tokenizer(Path(model_path)) + else: + from mlx_lm import load as lm_load - model, tokenizer = lm_load(model_path, lazy=True) - except Exception as e: - logger.error( - f"Sensitivity measurement: model load failed ({e})" - ) - return {} - finally: - if _have_lm_patch: - set_mtp_active(prev_active) + model, tokenizer = lm_load(model_path, lazy=True) + except Exception as e: + logger.error( + f"Sensitivity measurement: model load failed ({e})" + ) + return {} sensitivity = _measure_sensitivity_from_model( model, tokenizer, config, oq_level, diff --git a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_model.py b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_model.py index fd51138a0..23cc9d917 100644 --- a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_model.py +++ b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_model.py @@ -127,6 +127,9 @@ def _unfuse_layer_experts(prefix): # mlx-lm model hierarchy. See qwen35_vlm_model.py for why. key = "language_model." + key + if key.startswith("language_model.model.visual."): + key = "vision_tower." + key[len("language_model.model.visual."):] + if "conv1d.weight" in key and value.shape[-1] != 1: # mx.moveaxis goes through the streaming-discovery # monkey-patch in omlx.oq when called with _TrackedTensor; diff --git a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py index d3dd9b576..42ea7e9f6 100644 --- a/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py +++ b/omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py @@ -423,6 +423,9 @@ def _discover_mtp_layers(prefix_root: str): elif key.startswith("mtp."): key = "language_model." + key + if key.startswith("language_model.model.visual."): + key = "vision_tower." + key[len("language_model.model.visual."):] + if "conv1d.weight" in key and value.shape[-1] != 1: # Use the module-level mx.moveaxis so it goes through the # streaming-discovery monkey-patch (in ``omlx.oq``) when diff --git a/omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py b/omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py index 0a1050b6e..c28849c54 100644 --- a/omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py +++ b/omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py @@ -90,6 +90,9 @@ def sanitize(self, weights): # vs e.g. 6-bit packed weight on disk → shape error. key = "language_model." + key + if key.startswith("language_model.model.visual."): + key = "vision_tower." + key[len("language_model.model.visual."):] + if "conv1d.weight" in key and value.shape[-1] != 1: value = value.moveaxis(2, 1) if should_shift_norm_weights and any( diff --git a/omlx/patches/qwen3_6_nested_visual.py b/omlx/patches/qwen3_6_nested_visual.py index c4d27045d..df548e61e 100644 --- a/omlx/patches/qwen3_6_nested_visual.py +++ b/omlx/patches/qwen3_6_nested_visual.py @@ -35,8 +35,6 @@ class currently has) and from ``omlx.oq._build_model_sanitizer``. _NESTED_PREFIX = "language_model.model.visual." _TARGET_PREFIX = "vision_tower." -_class_patch_applied = False - def _rewrite_key(key: str) -> str: if key.startswith(_NESTED_PREFIX): @@ -64,20 +62,18 @@ def patched_sanitize(self, weights): ) return out + patched_sanitize._omlx_nested_visual_wrapped = True return patched_sanitize def apply_qwen3_6_nested_visual_patch() -> bool: """Install the sanitize wrapper on mlx-vlm's Qwen3_5MoE VLM Model class. - Idempotent. Returns True on first successful application, False if the - module is unavailable or the patch was already applied. + Idempotent: skips if the current ``Model.sanitize`` already carries the + ``_omlx_nested_visual_wrapped`` marker. Uses a function-attribute marker + instead of a module-level flag so that if another patch replaces + ``Model.sanitize`` (e.g. MTP runtime), this wrapper can re-apply. """ - global _class_patch_applied - - if _class_patch_applied: - return False - try: from mlx_vlm.models.qwen3_5_moe import qwen3_5_moe as qwen3_5_moe_module except ImportError: @@ -94,6 +90,9 @@ def apply_qwen3_6_nested_visual_patch() -> bool: logger.debug("qwen3_6_nested_visual: Model has no sanitize attr") return False + if getattr(original_sanitize, "_omlx_nested_visual_wrapped", False): + return False + try: import inspect @@ -103,12 +102,10 @@ def apply_qwen3_6_nested_visual_patch() -> bool: "qwen3_6_nested_visual: upstream sanitize already handles " "nested visual; skipping" ) - _class_patch_applied = True return False except (OSError, TypeError): pass model_cls.sanitize = _make_patched_sanitize(original_sanitize) - _class_patch_applied = True logger.info("qwen3_6_nested_visual: patched mlx_vlm.qwen3_5_moe Model.sanitize") return True diff --git a/omlx/utils/model_loading.py b/omlx/utils/model_loading.py index 6e5a92660..a7d3f277a 100644 --- a/omlx/utils/model_loading.py +++ b/omlx/utils/model_loading.py @@ -10,6 +10,63 @@ logger = logging.getLogger(__name__) +_VLM_TEXT_PREFIX = "language_model." + +_MLX_LM_LOAD_CONFIG_PATCHED = False + + +def expand_per_layer_quant_keys(cfg: dict) -> dict: + """Add ``language_model.``-prefixed variants of per-layer quantization keys. + + oQ writes per-layer overrides keyed by safetensors tensor base name + (e.g. ``"lm_head"``), but ``nn.quantize``'s class_predicate receives + model-tree paths (``"language_model.lm_head"``). Without the prefixed + variant the lookup misses and the global bits are used, causing a + shape mismatch at ``load_weights``. + + Mutates *cfg* in place and returns it for convenience. + """ + for config_key in ("quantization", "quantization_config"): + quant = cfg.get(config_key) + if not isinstance(quant, dict): + continue + extras: dict[str, dict] = {} + for key, val in quant.items(): + if not isinstance(val, dict): + continue + prefixed = _VLM_TEXT_PREFIX + key + if not key.startswith(_VLM_TEXT_PREFIX) and prefixed not in quant: + extras[prefixed] = val + elif key.startswith(_VLM_TEXT_PREFIX): + short = key[len(_VLM_TEXT_PREFIX):] + if short not in quant: + extras[short] = val + if extras: + quant.update(extras) + return cfg + + +def _patch_mlx_lm_load_config() -> None: + """Wrap ``mlx_lm.utils.load_config`` to expand per-layer quant keys.""" + global _MLX_LM_LOAD_CONFIG_PATCHED + if _MLX_LM_LOAD_CONFIG_PATCHED: + return + + try: + import mlx_lm.utils as _lu + except ImportError: + return + + _original = _lu.load_config + + def _patched(model_path, *args, **kwargs): + cfg = _original(model_path, *args, **kwargs) + expand_per_layer_quant_keys(cfg) + return cfg + + _lu.load_config = _patched + _MLX_LM_LOAD_CONFIG_PATCHED = True + def maybe_apply_pre_load_patches( model_name: str, @@ -36,6 +93,8 @@ def maybe_apply_pre_load_patches( set_mtp_active(False) + _patch_mlx_lm_load_config() + config_path = Path(model_name) / "config.json" if not config_path.exists(): return diff --git a/tests/test_oq.py b/tests/test_oq.py index cd04384d4..48c1d74fa 100644 --- a/tests/test_oq.py +++ b/tests/test_oq.py @@ -2234,3 +2234,54 @@ def test_import_fails_graceful_degradation(self, tmp_path, monkeypatch): assert isinstance(result, Path) mock_convert.assert_called_once() + + +# ============================================================================= +# Test built-in calibration data +# ============================================================================= + + +class TestBuiltinCalibration: + + def _load_json(self): + import json + p = Path(__file__).parent.parent / "omlx" / "oq_calibration_data.json" + with open(p, encoding="utf-8") as f: + return json.load(f) + + def test_json_has_all_required_categories(self): + data = self._load_json() + required = { + "code", "en", "ko", "zh", "ja", "tool_calling", "reasoning", + "mixed", "chat", "bartowski", + } + assert required.issubset(set(data.keys())), ( + f"Missing categories: {required - set(data.keys())}" + ) + + def test_json_minimum_samples_per_category(self): + data = self._load_json() + for key, texts in data.items(): + assert len(texts) >= 40, ( + f"Category {key!r} has only {len(texts)} samples" + ) + + def test_json_texts_are_nonempty_strings(self): + data = self._load_json() + for key, texts in data.items(): + for t in texts[:5]: + assert isinstance(t, str) and len(t) > 0, ( + f"Category {key!r} has invalid entry" + ) + + def test_total_sample_count(self): + data = self._load_json() + total = sum(len(v) for v in data.values()) + assert total >= 3000, f"Only {total} total samples" + + def test_code_multilingual_key_list_matches_json(self): + """The code_multilingual loop keys should all exist in the JSON.""" + data = self._load_json() + code_multi_keys = ("code", "en", "ko", "zh", "ja", "tool_calling", "reasoning") + for key in code_multi_keys: + assert key in data, f"code_multilingual references {key!r} but JSON is missing it"