diff --git a/tests/test_mlx_lm_mtp_patch.py b/tests/test_mlx_lm_mtp_patch.py index 185628935..996d1274f 100644 --- a/tests/test_mlx_lm_mtp_patch.py +++ b/tests/test_mlx_lm_mtp_patch.py @@ -331,6 +331,7 @@ def test_generation_batch_is_patched(self): assert hasattr(GenerationBatch, "_omlx_mtp_patched") def test_is_mtp_eligible_requires_mtp_forward_and_solo_batch(self): + from omlx.patches import mlx_lm_mtp from omlx.patches.mlx_lm_mtp.batch_generator import _is_mtp_eligible class _NonMtpModel: @@ -338,14 +339,14 @@ class _NonMtpModel: class _MtpModelWithoutHead: """Has the patched method but no actual MTP head attached - (mtp_enabled was False when this hypothetical model loaded).""" + (config did not declare an MTP head when this model loaded).""" def mtp_forward(self, *_): pass class _MtpModel: - """Has both the method and the attached head — i.e. mtp_enabled - was True at load time.""" + """Has both the method and the attached head — i.e. the model + class was patched and the head was attached at load time.""" def __init__(self): self.mtp = object() # placeholder for an actual MTPModule @@ -358,18 +359,34 @@ def __init__(self, model, uids): self.model = model self.uids = uids - # Non-MTP model never triggers the MTP path. - assert _is_mtp_eligible(_GenBatch(_NonMtpModel(), uids=[1])) is False - # Has mtp_forward but no attached head → still off (mtp_enabled was False). - assert ( - _is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1])) is False - ) - # Has both method and head + batch=1 → triggers the path. - assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is True - # MTP model with batch=2 falls back to standard step. - assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False - # Empty batch never triggers. - assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[])) is False + prior_active = mlx_lm_mtp.is_mtp_active() + try: + # Head attached but the per-load mtp_active flag is off + # (e.g. VLM runtime patches attach unconditionally so weight + # load matches, while inference-time MTP stays disabled). + mlx_lm_mtp.set_mtp_active(False) + assert ( + _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is False + ) + + mlx_lm_mtp.set_mtp_active(True) + # Non-MTP model never triggers the MTP path. + assert _is_mtp_eligible(_GenBatch(_NonMtpModel(), uids=[1])) is False + # Has mtp_forward but no attached head → still off. + assert ( + _is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1])) + is False + ) + # Has both method and head + batch=1 + flag on → triggers the path. + assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is True + # MTP model with batch=2 falls back to standard step. + assert ( + _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False + ) + # Empty batch never triggers. + assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[])) is False + finally: + mlx_lm_mtp.set_mtp_active(prior_active) # --------------------------------------------------------------------------- diff --git a/tests/test_vlm_torch_free_image_processor.py b/tests/test_vlm_torch_free_image_processor.py index fe03c7b0b..d3548ed89 100644 --- a/tests/test_vlm_torch_free_image_processor.py +++ b/tests/test_vlm_torch_free_image_processor.py @@ -320,28 +320,30 @@ def test_patch_wraps_target_processors(): fake_transformers = types.ModuleType("transformers") fake_transformers.AutoImageProcessor = fake_aip - # Build two fake mlx-vlm processor modules. + # Build two fake mlx-vlm processor modules. Module paths and class names + # must match the (module_path, cls_name) tuples in vlm.py's + # _patch_torch_free_image_processor. class FakeGlmOcrProcessor: @classmethod def from_pretrained(cls, path, **kwargs): return "glm" - class FakeDotsOcrProcessor: + class FakeDotsVLProcessor: @classmethod def from_pretrained(cls, path, **kwargs): return "dots" glm_mod = types.ModuleType("mlx_vlm.models.glm_ocr.processing") glm_mod.GlmOcrProcessor = FakeGlmOcrProcessor - dots_mod = types.ModuleType("mlx_vlm.models.dots_ocr.processing") - dots_mod.DotsOcrProcessor = FakeDotsOcrProcessor + dots_mod = types.ModuleType("mlx_vlm.models.dots_ocr.processing_dots_ocr") + dots_mod.DotsVLProcessor = FakeDotsVLProcessor real_import = importlib.import_module def fake_import(name, *args, **kwargs): if name == "mlx_vlm.models.glm_ocr.processing": return glm_mod - if name == "mlx_vlm.models.dots_ocr.processing": + if name == "mlx_vlm.models.dots_ocr.processing_dots_ocr": return dots_mod return real_import(name, *args, **kwargs) @@ -353,5 +355,5 @@ def fake_import(name, *args, **kwargs): FakeGlmOcrProcessor.from_pretrained, "_omlx_torch_free_patched", False ) assert getattr( - FakeDotsOcrProcessor.from_pretrained, "_omlx_torch_free_patched", False + FakeDotsVLProcessor.from_pretrained, "_omlx_torch_free_patched", False )