Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CorridorKeyModule/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,16 @@ def create_engine(
backend: str | None = None,
device: str | None = None,
img_size: int = DEFAULT_IMG_SIZE,
model_precision: torch.dtype = torch.float16,
tile_size: int | None = DEFAULT_MLX_TILE_SIZE,
overlap: int = DEFAULT_MLX_TILE_OVERLAP,
):
"""Factory: returns an engine with process_frame() matching the Torch contract.

Args:
model_precision: Torch only. Model weight dtype. Defaults to
torch.float16 for lower VRAM. Override with torch.float32 if
numerical parity with eager CPU inference is required.
tile_size: MLX only — tile size for tiled inference (default 512).
Set to None to disable tiling and use full-frame inference.
overlap: MLX only — overlap pixels between tiles (default 64).
Expand All @@ -398,5 +402,5 @@ def create_engine(

logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device)
return CorridorKeyEngine(
checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=torch.float16
checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=model_precision
)
14 changes: 7 additions & 7 deletions backend/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,25 +284,25 @@ def _ensure_model(self, needed: _ActiveModel) -> None:
self._active_model = needed

def _get_engine(self):
"""Lazy-load the CorridorKey inference engine."""
"""Lazy-load the CorridorKey inference engine via the backend factory."""
self._ensure_model(_ActiveModel.INFERENCE)

if self._engine is not None:
return self._engine

try:
from CorridorKeyModule.backend import TORCH_EXT, _discover_checkpoint
from CorridorKeyModule.inference_engine import CorridorKeyEngine
import torch

from CorridorKeyModule.backend import create_engine
except ImportError as exc:
raise RuntimeError("CorridorKeyModule is not installed. Run: uv sync") from exc

ckpt_path = _discover_checkpoint(TORCH_EXT)
logger.info(f"Loading checkpoint: {os.path.basename(ckpt_path)}")
t0 = time.monotonic()
self._engine = CorridorKeyEngine(
checkpoint_path=ckpt_path,
self._engine = create_engine(
backend="torch",
device=self._device,
img_size=2048,
model_precision=torch.float32,
)
logger.info(f"Engine loaded in {time.monotonic() - t0:.1f}s")
return self._engine
Expand Down
64 changes: 64 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Unit tests for CorridorKeyModule.backend — no GPU/MLX required."""

import errno
import inspect
import logging
import os
from unittest import mock

import numpy as np
import pytest
import torch

from CorridorKeyModule.backend import (
BACKEND_ENV_VAR,
Expand Down Expand Up @@ -344,3 +346,65 @@ def test_value_ranges(self, mlx_raw_output):
# (same behavior as Torch engine — linear_to_srgb doesn't clamp)
for key in ("comp", "processed"):
assert result[key].min() >= 0.0, f"{key} has negative values"


# --- create_engine model_precision parameter ---


class TestCreateEngineModelPrecision:
    """create_engine exposes and forwards a model_precision parameter for Torch."""

    def test_signature_has_model_precision(self):
        """The create_engine factory accepts a model_precision argument."""
        from CorridorKeyModule.backend import create_engine

        assert "model_precision" in inspect.signature(create_engine).parameters

    def test_default_model_precision_is_float16(self):
        """model_precision defaults to torch.float16 so CLI and service stay aligned."""
        from CorridorKeyModule.backend import create_engine

        sig = inspect.signature(create_engine)
        default = sig.parameters["model_precision"].default
        assert default is torch.float16

    def test_model_precision_forwarded_to_torch_engine(self, tmp_path):
        """The requested model_precision reaches the CorridorKeyEngine constructor."""
        from CorridorKeyModule.backend import create_engine

        checkpoint = tmp_path / "model.safetensors"
        checkpoint.touch()
        dir_patch = mock.patch("CorridorKeyModule.backend.CHECKPOINT_DIR", str(tmp_path))
        # NOTE(review): patching the class at its definition site assumes
        # create_engine resolves CorridorKeyEngine lazily at call time rather
        # than via a module-top import in backend.py — confirm against backend.py.
        engine_patch = mock.patch("CorridorKeyModule.inference_engine.CorridorKeyEngine")
        with dir_patch, engine_patch as engine_cls:
            create_engine(backend="torch", device="cpu", model_precision=torch.float32)
            engine_cls.assert_called_once()
            forwarded = engine_cls.call_args.kwargs
            assert forwarded["model_precision"] is torch.float32


# --- Service routes engine construction through the factory ---


class TestServiceEngineRouting:
    """CorridorKeyService._get_engine delegates engine construction to create_engine."""

    def test_get_engine_calls_create_engine_with_torch_and_fp32(self):
        """_get_engine hands construction to the factory with backend="torch" and FP32.

        The service passes torch.float32 explicitly rather than relying on the
        factory default (FP16, the CLI's VRAM tradeoff), preserving its
        quality-first behavior of FP32 weights with mixed_precision autocast
        on the ops where that is safe.
        """
        from backend.service import CorridorKeyService

        svc = CorridorKeyService()
        with mock.patch("CorridorKeyModule.backend.create_engine") as factory:
            factory.return_value = object()
            svc._get_engine()

        factory.assert_called_once()
        call_kwargs = factory.call_args.kwargs
        assert call_kwargs["backend"] == "torch"
        assert call_kwargs["device"] == "cpu"
        assert call_kwargs["img_size"] == 2048
        assert call_kwargs["model_precision"] is torch.float32