Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion omlx/engine/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ def _strip_audio_config_if_orphaned(model_dir: Path):

def _patched(path, **kwargs):
cfg = original(path, **kwargs)

from ..utils.model_loading import expand_per_layer_quant_keys

expand_per_layer_quant_keys(cfg)

if cfg.get("audio_config") is None:
return cfg
try:
Expand Down Expand Up @@ -399,6 +404,62 @@ def _patched(path, **kwargs):
_vu.load_config = original


# Weight-key prefixes used by ``_remap_nested_visual_on_load``: any key that
# starts with ``_NESTED_VIS_PREFIX`` has that prefix replaced by
# ``_VISION_TOWER_PREFIX`` before the weights are handed to ``load_weights``.
_NESTED_VIS_PREFIX = "language_model.model.visual."
_VISION_TOWER_PREFIX = "vision_tower."


@contextlib.contextmanager
def _remap_nested_visual_on_load(model_dir: Path):
    """Remap ``language_model.model.visual.*`` → ``vision_tower.*`` during
    ``load_model`` for MLX-format models where sanitize is skipped.

    mlx-vlm's ``load_model`` skips ``Model.sanitize`` when the safetensors
    metadata declares ``format=mlx``. oQ output is MLX-format, so the
    nested-visual key fixup that sanitize normally applies never fires.
    This context manager wraps ``load_model`` to intercept the weight list
    and perform the remap before ``nn.Module.load_weights`` is called.

    Two patches are layered:

    * ``mlx_vlm.utils.load_model`` is replaced for the lifetime of the
      ``with`` block and restored on exit, even if the body raises.
    * ``mlx.nn.Module.load_weights`` is replaced only while a wrapped
      ``load_model`` call is executing and restored as soon as it returns
      or raises.

    Scoped to a single ``vlm_load(...)`` call.

    Args:
        model_dir: Directory of the model being loaded. Not read by this
            function; accepted so the call site can pass the same path it
            gives to the neighbouring context managers.

    Yields:
        None. The patches are active while the ``with`` body runs.
    """
    import mlx_vlm.utils as _vu

    original_load_model = _vu.load_model

    def _patched_load_model(model_path, lazy=False, **kwargs):
        import mlx.nn as _nn

        orig_load_weights = _nn.Module.load_weights

        # The first parameter is named ``file_or_weights`` to match the
        # upstream ``Module.load_weights`` signature, so callers that pass it
        # by keyword keep working while the patch is installed.
        def _remapping_load_weights(self, file_or_weights, *args, **kw):
            # A string is a checkpoint path that the original method reads
            # itself; there are no in-memory keys to rewrite here.
            if isinstance(file_or_weights, str):
                return orig_load_weights(self, file_or_weights, *args, **kw)

            # Iterating a dict yields keys only, which cannot be unpacked
            # into (name, value) pairs; walk its items instead.
            if isinstance(file_or_weights, dict):
                pairs = file_or_weights.items()
            else:
                pairs = file_or_weights

            prefix_len = len(_NESTED_VIS_PREFIX)
            remapped = []
            n_remapped = 0
            for key, value in pairs:
                if key.startswith(_NESTED_VIS_PREFIX):
                    key = _VISION_TOWER_PREFIX + key[prefix_len:]
                    n_remapped += 1
                remapped.append((key, value))

            if n_remapped:
                logger.info(
                    "remap_nested_visual_on_load: remapped %d keys "
                    "'language_model.model.visual.*' -> 'vision_tower.*'",
                    n_remapped,
                )
            # Always forward the materialised list: the input may have been
            # a one-shot iterable that the loop above has already consumed.
            return orig_load_weights(self, remapped, *args, **kw)

        _nn.Module.load_weights = _remapping_load_weights
        try:
            return original_load_model(model_path, lazy=lazy, **kwargs)
        finally:
            # Undo the class-level patch no matter how load_model exits so
            # unrelated modules never see the remapping wrapper.
            _nn.Module.load_weights = orig_load_weights

    _vu.load_model = _patched_load_model
    try:
        yield
    finally:
        _vu.load_model = original_load_model


# Models that only support a single image per request
SINGLE_IMAGE_ONLY_MODELS = {
"llava_next",
Expand Down Expand Up @@ -605,7 +666,8 @@ async def start(self) -> None:
def _load_vlm_sync():
_patch_video_processor_bug()
_patch_torch_free_image_processor()
with _strip_audio_config_if_orphaned(Path(self._model_name)):
with _strip_audio_config_if_orphaned(Path(self._model_name)), \
_remap_nested_visual_on_load(Path(self._model_name)):
custom_loaded = maybe_load_custom_quantization(
self._model_name,
is_vlm=True,
Expand Down
158 changes: 69 additions & 89 deletions omlx/oq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2257,48 +2257,12 @@ def quantize_oq_streaming(

cb("loading", 12.0)

sanitize_fn = _build_model_sanitizer(config, text_only=text_only)
# When preserve_mtp is True, the patched sanitize functions
# (mlx_lm_mtp/qwen35_model.py and mlx_vlm_mtp/qwen35_vlm_model.py)
# keep mtp.* in the output and apply the +1 RMSNorm shift to MTP
# norms. No stash/merge wrapper needed — the patch covers both paths.
if sanitize_fn is not None:
try:
plan = _discover_sanitize_plan(sanitize_fn, all_weights)
all_weights = _DiscoveredPlan(plan, all_weights)
logger.info(
f"oQ{oq_level:g}: discovered streaming sanitize plan, "
f"{len(all_weights)} output tensors"
)
except Exception as e:
if _model_exceeds_ram:
# Silent skip used to produce broken artifacts (see #1204):
# the source layout (e.g. fused experts.gate_up_proj) never
# got remapped to inference-expected names, and load failed
# with "Received N parameters not in model". Hard-fail so the
# caller sees the cause immediately.
raise RuntimeError(
f"oQ{oq_level:g}: streaming sanitize-plan discovery "
f"failed ({e}) and the eager fallback is unsafe with "
f"model size {_model_bytes / 1e9:.1f} GB exceeding "
f"{int(_MAX_MODEL_RAM_FRACTION * 100)}% of system RAM "
f"({_system_ram / 1e9:.1f} GB). Run on a machine with "
"enough RAM, or extend _TrackedTensor to cover the "
"indexing pattern the sanitize uses."
) from e
logger.warning(
f"Streaming discovery failed ({e}), falling back to eager sanitize"
)
try:
all_weights = sanitize_fn(all_weights)
logger.info(f"oQ{oq_level:g}: eager sanitize applied, {len(all_weights)} tensors")
except Exception as e2:
logger.warning(f"Sanitize failed ({e2}), using original names")

config["_oq_non_quantizable"] = _build_non_quantizable_set(config)

cb("loading", 15.0)

# --- Sensitivity measurement (before sanitize-plan discovery) ---------
# Must run before _build_model_sanitizer + _discover_sanitize_plan,
# because the discovery pass feeds _TrackedTensor proxies through
# Model.sanitize which corrupts mutable state in the MTP sanitize
# patch (weights.pop on tracked objects). Running sensitivity first
# ensures vlm_load_model sees a pristine patch chain.
if sensitivity_model_path:
logger.info(f"oQ{oq_level:g}: measuring sensitivity via proxy model")
sensitivity_map = _measure_sensitivity_from_quantized_model(
Expand Down Expand Up @@ -2362,6 +2326,44 @@ def quantize_oq_streaming(
"calibration data, or layer discovery), and either fix it or "
"pass an explicit sensitivity_model_path."
)

cb("loading", 15.0)

# --- Sanitize-plan discovery ------------------------------------------
sanitize_fn = _build_model_sanitizer(config, text_only=text_only)
# When preserve_mtp is True, the patched sanitize functions
# (mlx_lm_mtp/qwen35_model.py and mlx_vlm_mtp/qwen35_vlm_model.py)
# keep mtp.* in the output and apply the +1 RMSNorm shift to MTP
# norms. No stash/merge wrapper needed — the patch covers both paths.
if sanitize_fn is not None:
try:
plan = _discover_sanitize_plan(sanitize_fn, all_weights)
all_weights = _DiscoveredPlan(plan, all_weights)
logger.info(
f"oQ{oq_level:g}: discovered streaming sanitize plan, "
f"{len(all_weights)} output tensors"
)
except Exception as e:
if _model_exceeds_ram:
raise RuntimeError(
f"oQ{oq_level:g}: streaming sanitize-plan discovery "
f"failed ({e}) and the eager fallback is unsafe with "
f"model size {_model_bytes / 1e9:.1f} GB exceeding "
f"{int(_MAX_MODEL_RAM_FRACTION * 100)}% of system RAM "
f"({_system_ram / 1e9:.1f} GB). Run on a machine with "
"enough RAM, or extend _TrackedTensor to cover the "
"indexing pattern the sanitize uses."
) from e
logger.warning(
f"Streaming discovery failed ({e}), falling back to eager sanitize"
)
try:
all_weights = sanitize_fn(all_weights)
logger.info(f"oQ{oq_level:g}: eager sanitize applied, {len(all_weights)} tensors")
except Exception as e2:
logger.warning(f"Sanitize failed ({e2}), using original names")

config["_oq_non_quantizable"] = _build_non_quantizable_set(config)
config["_oq_sensitivity_map"] = {
str(k): v for k, v in sensitivity_map.items()
}
Expand Down Expand Up @@ -2645,7 +2647,7 @@ def quantize_oq_streaming(
"c4": "C4 (Web Crawl)",
"code": "Code (StarCoder)",
"multilingual": "Multilingual (CulturaX)",
"code_multilingual": "Code + Multilingual",
"code_multilingual": "Code + Multilingual + Reasoning",
}


Expand Down Expand Up @@ -2712,7 +2714,7 @@ def _load_builtin_calibration(tokenizer, dataset: str, num_samples: int,

if dataset == "code_multilingual":
texts = []
for key in ("code", "en", "ko", "zh", "ja", "tool_calling"):
for key in ("code", "en", "ko", "zh", "ja", "tool_calling", "reasoning"):
texts.extend(all_data.get(key, []))
elif dataset == "code":
texts = all_data.get("code", []) + all_data.get("en", [])
Expand Down Expand Up @@ -3056,57 +3058,35 @@ def _measure_sensitivity(
num_samples=32, seq_length=256,
):
"""Measure sensitivity by loading model temporarily. Used by streaming path."""
is_vlm = "vision_config" in config
from omlx.utils.model_loading import maybe_apply_pre_load_patches

# Apply the same MTP runtime patches that production load and the
# main quantize path use. Sanitize patches are already global (from
# _build_model_sanitizer above), so loaded weights arrive with
# ``language_model.mtp.*`` keys; without the runtime patch the
# mlx-vlm LanguageModel.__init__ never attaches ``self.mtp`` and
# load_weights rejects the MTP tensors with "parameters not in model".
try:
from omlx.patches.mlx_lm_mtp import (
apply_mlx_lm_mtp_patch,
is_mtp_active,
set_mtp_active,
)
_have_lm_patch = apply_mlx_lm_mtp_patch()
except Exception:
_have_lm_patch = False
is_mtp_active = None
set_mtp_active = None
# Reuse the centralised pre-load dispatch so every current and future
# patch (MTP, DeepSeek V4, nested-visual, load_config, …) is applied
# exactly as in the production load path. model_settings is not
# passed — sensitivity runs on the *source* checkpoint which may not
# have MTP weights yet; maybe_apply_pre_load_patches without settings
# installs patches for sanitize correctness but leaves mtp_active
# False so Model.__init__ won't try to attach a missing MTP head.
maybe_apply_pre_load_patches(model_path)

if is_vlm:
try:
from omlx.patches.mlx_vlm_mtp import apply_mlx_vlm_mtp_runtime_patch
apply_mlx_vlm_mtp_runtime_patch()
except Exception as e:
logger.debug(f"mlx-vlm runtime MTP patch skipped: {e}")

prev_active = is_mtp_active() if _have_lm_patch else False
is_vlm = "vision_config" in config
try:
if _have_lm_patch:
set_mtp_active(True)
try:
if is_vlm:
from mlx_vlm.utils import load_model as vlm_load_model
if is_vlm:
from mlx_vlm.utils import load_model as vlm_load_model

model = vlm_load_model(Path(model_path), lazy=True)
from mlx_lm import load as lm_load
model = vlm_load_model(Path(model_path), lazy=True)
from mlx_lm.tokenizer_utils import load as load_tokenizer

_, tokenizer = lm_load(model_path, lazy=True)
else:
from mlx_lm import load as lm_load
tokenizer = load_tokenizer(Path(model_path))
else:
from mlx_lm import load as lm_load

model, tokenizer = lm_load(model_path, lazy=True)
except Exception as e:
logger.error(
f"Sensitivity measurement: model load failed ({e})"
)
return {}
finally:
if _have_lm_patch:
set_mtp_active(prev_active)
model, tokenizer = lm_load(model_path, lazy=True)
except Exception as e:
logger.error(
f"Sensitivity measurement: model load failed ({e})"
)
return {}

sensitivity = _measure_sensitivity_from_model(
model, tokenizer, config, oq_level,
Expand Down
3 changes: 3 additions & 0 deletions omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def _unfuse_layer_experts(prefix):
# mlx-lm model hierarchy. See qwen35_vlm_model.py for why.
key = "language_model." + key

if key.startswith("language_model.model.visual."):
key = "vision_tower." + key[len("language_model.model.visual."):]

if "conv1d.weight" in key and value.shape[-1] != 1:
# mx.moveaxis goes through the streaming-discovery
# monkey-patch in omlx.oq when called with _TrackedTensor;
Expand Down
3 changes: 3 additions & 0 deletions omlx/patches/mlx_vlm_mtp/qwen35_moe_vlm_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,9 @@ def _discover_mtp_layers(prefix_root: str):
elif key.startswith("mtp."):
key = "language_model." + key

if key.startswith("language_model.model.visual."):
key = "vision_tower." + key[len("language_model.model.visual."):]

if "conv1d.weight" in key and value.shape[-1] != 1:
# Use the module-level mx.moveaxis so it goes through the
# streaming-discovery monkey-patch (in ``omlx.oq``) when
Expand Down
3 changes: 3 additions & 0 deletions omlx/patches/mlx_vlm_mtp/qwen35_vlm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def sanitize(self, weights):
# vs e.g. 6-bit packed weight on disk → shape error.
key = "language_model." + key

if key.startswith("language_model.model.visual."):
key = "vision_tower." + key[len("language_model.model.visual."):]

if "conv1d.weight" in key and value.shape[-1] != 1:
value = value.moveaxis(2, 1)
if should_shift_norm_weights and any(
Expand Down
19 changes: 8 additions & 11 deletions omlx/patches/qwen3_6_nested_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ class currently has) and from ``omlx.oq._build_model_sanitizer``.
_NESTED_PREFIX = "language_model.model.visual."
_TARGET_PREFIX = "vision_tower."

_class_patch_applied = False


def _rewrite_key(key: str) -> str:
if key.startswith(_NESTED_PREFIX):
Expand Down Expand Up @@ -64,20 +62,18 @@ def patched_sanitize(self, weights):
)
return out

patched_sanitize._omlx_nested_visual_wrapped = True
return patched_sanitize


def apply_qwen3_6_nested_visual_patch() -> bool:
"""Install the sanitize wrapper on mlx-vlm's Qwen3_5MoE VLM Model class.

Idempotent. Returns True on first successful application, False if the
module is unavailable or the patch was already applied.
Idempotent: skips if the current ``Model.sanitize`` already carries the
``_omlx_nested_visual_wrapped`` marker. Uses a function-attribute marker
instead of a module-level flag so that if another patch replaces
``Model.sanitize`` (e.g. MTP runtime), this wrapper can re-apply.
"""
global _class_patch_applied

if _class_patch_applied:
return False

try:
from mlx_vlm.models.qwen3_5_moe import qwen3_5_moe as qwen3_5_moe_module
except ImportError:
Expand All @@ -94,6 +90,9 @@ def apply_qwen3_6_nested_visual_patch() -> bool:
logger.debug("qwen3_6_nested_visual: Model has no sanitize attr")
return False

if getattr(original_sanitize, "_omlx_nested_visual_wrapped", False):
return False

try:
import inspect

Expand All @@ -103,12 +102,10 @@ def apply_qwen3_6_nested_visual_patch() -> bool:
"qwen3_6_nested_visual: upstream sanitize already handles "
"nested visual; skipping"
)
_class_patch_applied = True
return False
except (OSError, TypeError):
pass

model_cls.sanitize = _make_patched_sanitize(original_sanitize)
_class_patch_applied = True
logger.info("qwen3_6_nested_visual: patched mlx_vlm.qwen3_5_moe Model.sanitize")
return True
Loading