diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py
index e3251318..fa803ab8 100644
--- a/CorridorKeyModule/backend.py
+++ b/CorridorKeyModule/backend.py
@@ -372,12 +372,16 @@ def create_engine(
     backend: str | None = None,
     device: str | None = None,
     img_size: int = DEFAULT_IMG_SIZE,
+    model_precision: torch.dtype = torch.float16,
     tile_size: int | None = DEFAULT_MLX_TILE_SIZE,
     overlap: int = DEFAULT_MLX_TILE_OVERLAP,
 ):
     """Factory: returns an engine with process_frame() matching the Torch contract.
 
     Args:
+        model_precision: Torch only. Model weight dtype. Defaults to
+            torch.float16 for lower VRAM. Override with torch.float32 if
+            numerical parity with eager CPU inference is required.
         tile_size: MLX only — tile size for tiled inference (default 512).
             Set to None to disable tiling and use full-frame inference.
         overlap: MLX only — overlap pixels between tiles (default 64).
@@ -398,5 +402,5 @@ def create_engine(
 
     logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device)
     return CorridorKeyEngine(
-        checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=torch.float16
+        checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=model_precision
     )
diff --git a/backend/service.py b/backend/service.py
index b7b1eb9f..81a37eb6 100644
--- a/backend/service.py
+++ b/backend/service.py
@@ -284,25 +284,25 @@ def _ensure_model(self, needed: _ActiveModel) -> None:
         self._active_model = needed
 
     def _get_engine(self):
-        """Lazy-load the CorridorKey inference engine."""
+        """Lazy-load the CorridorKey inference engine via the backend factory."""
        self._ensure_model(_ActiveModel.INFERENCE)
         if self._engine is not None:
             return self._engine
 
         try:
-            from CorridorKeyModule.backend import TORCH_EXT, _discover_checkpoint
-            from CorridorKeyModule.inference_engine import CorridorKeyEngine
+            import torch
+
+            from CorridorKeyModule.backend import create_engine
         except ImportError as exc:
             raise RuntimeError("CorridorKeyModule is not installed.\nRun: uv sync") from exc
-        ckpt_path = _discover_checkpoint(TORCH_EXT)
-        logger.info(f"Loading checkpoint: {os.path.basename(ckpt_path)}")
 
         t0 = time.monotonic()
-        self._engine = CorridorKeyEngine(
-            checkpoint_path=ckpt_path,
+        self._engine = create_engine(
+            backend="torch",
             device=self._device,
             img_size=2048,
+            model_precision=torch.float32,
         )
         logger.info(f"Engine loaded in {time.monotonic() - t0:.1f}s")
         return self._engine
 
diff --git a/tests/test_backend.py b/tests/test_backend.py
index 64e6b055..01ebc864 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -1,12 +1,14 @@
 """Unit tests for CorridorKeyModule.backend — no GPU/MLX required."""
 
 import errno
+import inspect
 import logging
 import os
 from unittest import mock
 
 import numpy as np
 import pytest
+import torch
 
 from CorridorKeyModule.backend import (
     BACKEND_ENV_VAR,
@@ -344,3 +346,65 @@ def test_value_ranges(self, mlx_raw_output):
         # (same behavior as Torch engine — linear_to_srgb doesn't clamp)
         for key in ("comp", "processed"):
             assert result[key].min() >= 0.0, f"{key} has negative values"
+
+
+# --- create_engine model_precision parameter ---
+
+
+class TestCreateEngineModelPrecision:
+    """create_engine exposes and forwards a model_precision parameter for Torch."""
+
+    def test_signature_has_model_precision(self):
+        """create_engine's signature includes a model_precision parameter."""
+        from CorridorKeyModule.backend import create_engine
+
+        params = inspect.signature(create_engine).parameters
+        assert "model_precision" in params
+
+    def test_default_model_precision_is_float16(self):
+        """Default model_precision is torch.float16 to keep CLI and service aligned."""
+        from CorridorKeyModule.backend import create_engine
+
+        params = inspect.signature(create_engine).parameters
+        assert params["model_precision"].default is torch.float16
+
+    def test_model_precision_forwarded_to_torch_engine(self, tmp_path):
+        """create_engine forwards the requested model_precision to CorridorKeyEngine."""
+        from CorridorKeyModule.backend import create_engine
+
+        ckpt = tmp_path / "model.safetensors"
+        ckpt.touch()
+        with mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path)):
+            with mock.patch("CorridorKeyModule.inference_engine.CorridorKeyEngine") as mock_engine_cls:
+                create_engine(backend="torch", device="cpu", model_precision=torch.float32)
+        mock_engine_cls.assert_called_once()
+        kwargs = mock_engine_cls.call_args.kwargs
+        assert kwargs["model_precision"] is torch.float32
+
+
+# --- Service routes engine construction through the factory ---
+
+
+class TestServiceEngineRouting:
+    """CorridorKeyService._get_engine delegates engine construction to create_engine."""
+
+    def test_get_engine_calls_create_engine_with_torch_and_fp32(self):
+        """_get_engine invokes the factory with explicit backend="torch" and FP32 precision.
+
+        FP32 is passed explicitly to preserve the service's existing quality-first
+        behavior (FP32 weights plus mixed_precision autocast for speed on safe ops).
+        The factory's default FP16 matches the CLI's long-standing VRAM tradeoff.
+        """
+        from backend.service import CorridorKeyService
+
+        service = CorridorKeyService()
+        with mock.patch("CorridorKeyModule.backend.create_engine") as mock_factory:
+            mock_factory.return_value = object()
+            service._get_engine()
+
+        mock_factory.assert_called_once()
+        kwargs = mock_factory.call_args.kwargs
+        assert kwargs["backend"] == "torch"
+        assert kwargs["device"] == "cpu"
+        assert kwargs["img_size"] == 2048
+        assert kwargs["model_precision"] is torch.float32