Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions tests/test_mlx_lm_mtp_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,21 +331,22 @@ def test_generation_batch_is_patched(self):
assert hasattr(GenerationBatch, "_omlx_mtp_patched")

def test_is_mtp_eligible_requires_mtp_forward_and_solo_batch(self):
from omlx.patches import mlx_lm_mtp
from omlx.patches.mlx_lm_mtp.batch_generator import _is_mtp_eligible

class _NonMtpModel:
pass

class _MtpModelWithoutHead:
"""Has the patched method but no actual MTP head attached
(mtp_enabled was False when this hypothetical model loaded)."""
(config did not declare an MTP head when this model loaded)."""

def mtp_forward(self, *_):
pass

class _MtpModel:
"""Has both the method and the attached head — i.e. mtp_enabled
was True at load time."""
"""Has both the method and the attached head — i.e. the model
class was patched and the head was attached at load time."""

def __init__(self):
self.mtp = object() # placeholder for an actual MTPModule
Expand All @@ -358,18 +359,34 @@ def __init__(self, model, uids):
self.model = model
self.uids = uids

# Non-MTP model never triggers the MTP path.
assert _is_mtp_eligible(_GenBatch(_NonMtpModel(), uids=[1])) is False
# Has mtp_forward but no attached head → still off (mtp_enabled was False).
assert (
_is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1])) is False
)
# Has both method and head + batch=1 → triggers the path.
assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is True
# MTP model with batch=2 falls back to standard step.
assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False
# Empty batch never triggers.
assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[])) is False
prior_active = mlx_lm_mtp.is_mtp_active()
try:
# Head is attached but the per-load mtp_active flag is off
# (e.g. VLM runtime patches attach the head unconditionally so that
# weight loading matches, while inference-time MTP stays disabled).
mlx_lm_mtp.set_mtp_active(False)
assert (
_is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is False
)

mlx_lm_mtp.set_mtp_active(True)
# Non-MTP model never triggers the MTP path.
assert _is_mtp_eligible(_GenBatch(_NonMtpModel(), uids=[1])) is False
# Has mtp_forward but no attached head → still off.
assert (
_is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1]))
is False
)
# Has both method and head + batch=1 + flag on → triggers the path.
assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is True
# MTP model with batch=2 falls back to standard step.
assert (
_is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False
)
# Empty batch never triggers.
assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[])) is False
finally:
mlx_lm_mtp.set_mtp_active(prior_active)


# ---------------------------------------------------------------------------
Expand Down
14 changes: 8 additions & 6 deletions tests/test_vlm_torch_free_image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,28 +320,30 @@ def test_patch_wraps_target_processors():
fake_transformers = types.ModuleType("transformers")
fake_transformers.AutoImageProcessor = fake_aip

# Build two fake mlx-vlm processor modules.
# Build two fake mlx-vlm processor modules. Module paths and class names
# must match the (module_path, cls_name) tuples in vlm.py's
# _patch_torch_free_image_processor.
class FakeGlmOcrProcessor:
@classmethod
def from_pretrained(cls, path, **kwargs):
return "glm"

class FakeDotsOcrProcessor:
class FakeDotsVLProcessor:
@classmethod
def from_pretrained(cls, path, **kwargs):
return "dots"

glm_mod = types.ModuleType("mlx_vlm.models.glm_ocr.processing")
glm_mod.GlmOcrProcessor = FakeGlmOcrProcessor
dots_mod = types.ModuleType("mlx_vlm.models.dots_ocr.processing")
dots_mod.DotsOcrProcessor = FakeDotsOcrProcessor
dots_mod = types.ModuleType("mlx_vlm.models.dots_ocr.processing_dots_ocr")
dots_mod.DotsVLProcessor = FakeDotsVLProcessor

real_import = importlib.import_module

def fake_import(name, *args, **kwargs):
if name == "mlx_vlm.models.glm_ocr.processing":
return glm_mod
if name == "mlx_vlm.models.dots_ocr.processing":
if name == "mlx_vlm.models.dots_ocr.processing_dots_ocr":
return dots_mod
return real_import(name, *args, **kwargs)

Expand All @@ -353,5 +355,5 @@ def fake_import(name, *args, **kwargs):
FakeGlmOcrProcessor.from_pretrained, "_omlx_torch_free_patched", False
)
assert getattr(
FakeDotsOcrProcessor.from_pretrained, "_omlx_torch_free_patched", False
FakeDotsVLProcessor.from_pretrained, "_omlx_torch_free_patched", False
)