diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c918677..d89cca6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,6 @@ jobs: reg = BenchmarkRegistry() reg.discover() ids = reg.list_ids() - assert len(ids) == 39, f'Expected 39 benchmarks, got {len(ids)}' + assert len(ids) == 40, f'Expected 40 benchmarks, got {len(ids)}' " python scripts/run_benchmarks.py --list diff --git a/README.md b/README.md index 4171c4a..ba66269 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # GDB: GraphicDesignBench -**GDB** evaluates vision-language models on professional graphic design tasks — layout reasoning, typography, SVG editing, template matching, animation. 39 benchmarks across 7 domains, built on the [Lica dataset](https://github.com/lica-world/lica-dataset) (1,148 real design layouts). +**GDB** evaluates vision-language models on professional graphic design tasks — layout reasoning, typography, SVG editing, template matching, animation. 40 benchmarks across 7 domains, built on the [Lica dataset](https://github.com/lica-world/lica-dataset) (1,148 real design layouts). **Paper:** [arXiv:2604.04192](https://arxiv.org/abs/2604.04192)  |  **Dataset:** [HuggingFace](https://huggingface.co/datasets/lica-world/GDB)  |  **Blog:** [lica.world](https://lica.world/blog/gdb-real-world-benchmark-for-graphic-design) @@ -16,7 +16,7 @@ Each task is either **understanding** or **generation**: | svg | 8 | 8 | SVG reasoning and editing (perceptual and semantic Q/A, bug fixing, optimization, style editing) and generation (text-to-SVG, image-to-SVG, combined input) | | template | 5 | 5 | Template matching, retrieval, clustering, and generation (style completion, color transfer) | | temporal | 8 | 6 | Keyframe ordering; motion type classification; video/component duration and start-time estimation; generation (animation parameters, motion trajectory, short-form video) | -| typography | 12 | 8 | Font family, color, size/weight/alignment/letter spacing/line height, style ranges, curvature, rotation, and generation (styled text element, styled text rendering to layout) | +| typography | 13 | 9 | Font family, color, size/weight/alignment/letter spacing/line height, style ranges, curvature, rotation, and generation (styled text element, styled text rendering to layout, text removal/background inpainting as `image-6`) | ## Setup @@ -92,7 +92,7 @@ python scripts/run_benchmarks.py --benchmarks svg-1 \ --provider hf --device auto \ --dataset-root data/gdb-dataset -# Diffusion / image generation (defaults to FLUX.2 klein 4B) +# Diffusion / image generation (defaults to FLUX.2 klein 9B) python scripts/run_benchmarks.py --benchmarks layout-1 \ --provider diffusion \ --dataset-root data/gdb-dataset @@ -109,7 +109,7 @@ python -m pip install --no-deps --ignore-requires-python \ python scripts/run_benchmarks.py --benchmarks layout-1 layout-3 layout-8 typography-7 typography-8 \ --provider custom \ --custom-entry gdb.models.local_models:Flux2Model \ - --custom-init-kwargs '{"model_name":"flux.2-klein-4b"}' \ + --custom-init-kwargs '{"model_name":"flux.2-klein-9b"}' \ --custom-modality image_generation \ --dataset-root data/gdb-dataset ``` @@ -132,7 +132,7 @@ helm-summarize --suite gdb-eval helm-server --suite gdb-eval ``` -All 39 benchmarks are available. See [integrations/helm/](integrations/helm/) for details. +All 40 benchmarks are available. See [integrations/helm/](integrations/helm/) for details. ### API keys @@ -186,12 +186,13 @@ GDB/ ├── src/gdb/ │ ├── tasks/ # @benchmark classes — one file per domain │ │ ├── category.py # category-1, category-2 +│ │ ├── image.py # compatibility shim (re-exports image-6) │ │ ├── layout.py # layout-1 … layout-8 │ │ ├── lottie.py # lottie-1, lottie-2 │ │ ├── svg.py # svg-1 … svg-8 │ │ ├── template.py # template-1 … template-5 │ │ ├── temporal.py # temporal-1 … temporal-6 -│ │ └── typography.py # typography-1 … typography-8 +│ │ └── typography.py # typography-1 … typography-8 + image-6 implementation │ ├── models/ # Provider wrappers (OpenAI, Anthropic, Gemini, HF, vLLM) │ ├── metrics/ # Reusable metric functions (IoU, FID, SSIM, LPIPS, edit distance) │ ├── evaluation/ diff --git a/integrations/helm/README.md b/integrations/helm/README.md index 6b66995..a53e6ac 100644 --- a/integrations/helm/README.md +++ b/integrations/helm/README.md @@ -1,6 +1,6 @@ # lica-gdb-helm -HELM integration for [GDB (GraphicDesignBench)](https://github.com/lica-world/GDB) — run all 39 GDB benchmarks through Stanford CRFM's [HELM](https://github.com/stanford-crfm/helm) framework. +HELM integration for [GDB (GraphicDesignBench)](https://github.com/lica-world/GDB) — run all 40 GDB benchmarks through Stanford CRFM's [HELM](https://github.com/stanford-crfm/helm) framework. ## Install @@ -35,7 +35,7 @@ helm-server --suite gdb-eval ## Available benchmarks -All 39 GDB benchmarks are available. Pass any benchmark ID: +All 40 GDB benchmarks are available. Pass any benchmark ID: | Domain | Benchmark IDs | |--------|--------------| @@ -44,7 +44,7 @@ All 39 GDB benchmarks are available. Pass any benchmark ID: | SVG | `svg-1` through `svg-8` | | Template | `template-1` through `template-5` | | Temporal | `temporal-1` through `temporal-6` | -| Typography | `typography-1` through `typography-8` | +| Typography | `typography-1` through `typography-8`, `image-6` | | Lottie | `lottie-1`, `lottie-2` | ## Options diff --git a/integrations/helm/src/gdb_helm/_benchmark_info.py b/integrations/helm/src/gdb_helm/_benchmark_info.py index 6f9faa0..3f66c48 100644 --- a/integrations/helm/src/gdb_helm/_benchmark_info.py +++ b/integrations/helm/src/gdb_helm/_benchmark_info.py @@ -73,6 +73,7 @@ class BenchmarkInfo: "typography-6": BenchmarkInfo(method="generation_multimodal", max_tokens=256, has_images=True), # -- typography: generation -- + "image-6": BenchmarkInfo(method="generation", max_tokens=0, has_images=True, image_gen=True), "typography-7": BenchmarkInfo(method="generation", max_tokens=0, has_images=True, image_gen=True), "typography-8": BenchmarkInfo(method="generation", max_tokens=0, image_gen=True), diff --git a/integrations/helm/src/gdb_helm/scenarios.py b/integrations/helm/src/gdb_helm/scenarios.py index 27ba4da..74bb191 100644 --- a/integrations/helm/src/gdb_helm/scenarios.py +++ b/integrations/helm/src/gdb_helm/scenarios.py @@ -1,6 +1,6 @@ """HELM Scenario that wraps any GDB benchmark. -One parameterized class handles all 39 benchmarks by delegating data loading +One parameterized class handles all 40 benchmarks by delegating data loading and prompt construction to the ``gdb`` package. """ diff --git a/scripts/README.md b/scripts/README.md index b91279d..8b68532 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -46,7 +46,7 @@ python scripts/run_benchmarks.py --benchmarks svg-6 \ --provider vllm --model-id Qwen/Qwen3-VL-4B-Instruct --top-k 20 --top-p 0.8 \ --dataset-root data/gdb-dataset -# Diffusion / image generation (defaults to FLUX.2 klein 4B) +# Diffusion / image generation (defaults to FLUX.2 klein 9B) python scripts/run_benchmarks.py --benchmarks layout-1 \ --provider diffusion \ --dataset-root data/gdb-dataset @@ -69,7 +69,7 @@ python -m pip install --no-deps --ignore-requires-python \ python scripts/run_benchmarks.py --benchmarks layout-1 layout-3 layout-8 typography-7 typography-8 \ --provider custom \ --custom-entry gdb.models.local_models:Flux2Model \ - --custom-init-kwargs '{"model_name":"flux.2-klein-4b"}' \ + --custom-init-kwargs '{"model_name":"flux.2-klein-9b"}' \ --custom-modality image_generation \ --dataset-root data/gdb-dataset @@ -100,7 +100,7 @@ from Hugging Face and can use either environment tokens (`HF_TOKEN`, `HF_HUB_TOKEN`) or an existing cached login/token file. The default local text/VLM model ID is now `Qwen/Qwen3-VL-4B-Instruct` for both -`hf` and `vllm`, and the default `diffusion` model ID is `flux.2-klein-4b`. +`hf` and `vllm`, and the default `diffusion` model ID is `flux.2-klein-9b`. ### Batch submit/collect (~50% cheaper) diff --git a/scripts/run_benchmarks.py b/scripts/run_benchmarks.py index 5d29a0c..4069a4a 100644 --- a/scripts/run_benchmarks.py +++ b/scripts/run_benchmarks.py @@ -84,7 +84,7 @@ "anthropic": "claude-sonnet-4-20250514", "hf": "Qwen/Qwen3-VL-4B-Instruct", "vllm": "Qwen/Qwen3-VL-4B-Instruct", - "diffusion": "flux.2-klein-4b", + "diffusion": "flux.2-klein-9b", "custom": "custom-entrypoint", } diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py index 3c590e2..536823c 100644 --- a/scripts/upload_to_hf.py +++ b/scripts/upload_to_hf.py @@ -442,7 +442,7 @@ def generate_dataset_card(config_names: Optional[List[str]] = None) -> str: # GDB: GraphicDesignBench -39 benchmarks for evaluating vision-language models on graphic design tasks — layout, typography, SVG, template matching, animation. Built on 1,148 real design layouts from the [Lica dataset](https://lica.world). +40 benchmarks for evaluating vision-language models on graphic design tasks — layout, typography, SVG, template matching, animation. Built on 1,148 real design layouts from the [Lica dataset](https://lica.world). **Paper:** [arXiv:2604.04192](https://arxiv.org/abs/2604.04192)  |  **Code:** [github.com/lica-world/GDB](https://github.com/lica-world/GDB)  |  **Blog:** [lica.world](https://lica.world/blog/gdb-real-world-benchmark-for-graphic-design) diff --git a/src/gdb/metrics/__init__.py b/src/gdb/metrics/__init__.py index fb9d950..97c546c 100644 --- a/src/gdb/metrics/__init__.py +++ b/src/gdb/metrics/__init__.py @@ -1,6 +1,6 @@ """Shared metric implementations for GDB benchmarks.""" -from .core import edit_distance, fid, iou, lpips_score, ssim +from .core import edit_distance, fid, iou, lpips_score, psnr, ssim from .text import normalize_font_name __all__ = [ @@ -8,6 +8,7 @@ "fid", "iou", "lpips_score", + "psnr", "normalize_font_name", "ssim", ] diff --git a/src/gdb/metrics/core.py b/src/gdb/metrics/core.py index 6fb59be..9b167bf 100644 --- a/src/gdb/metrics/core.py +++ b/src/gdb/metrics/core.py @@ -76,6 +76,27 @@ def edit_distance(source: str, target: str) -> float: # --------------------------------------------------------------------------- +def psnr(pred: Any, gt: Any) -> float: + """Peak signal-to-noise ratio. + + If the optional third-party ``evaluation.image`` module is available, + delegates to it; otherwise uses ``scikit-image``. + """ + try: + from evaluation.image import psnr as _psnr + + return _psnr(pred, gt) + except ImportError: + pass + + try: + from skimage.metrics import peak_signal_noise_ratio + except ImportError: + raise _missing_extra("scikit-image", "metrics") + + return float(peak_signal_noise_ratio(gt, pred)) + + def ssim(pred: Any, gt: Any) -> float: """Structural similarity index. diff --git a/src/gdb/metrics/remove_metric.py b/src/gdb/metrics/remove_metric.py new file mode 100644 index 0000000..105461a --- /dev/null +++ b/src/gdb/metrics/remove_metric.py @@ -0,0 +1,206 @@ +"""ReMOVE metric helper for object/text erasure quality. + +This adapts the public ReMOVE reference implementation into a reusable class +for gdb. +""" + +from __future__ import annotations + +import logging +import urllib.request +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch.nn.functional import cosine_similarity + +logger = logging.getLogger(__name__) + +DEFAULT_SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" +DEFAULT_SAM_CACHE_DIR = Path.home() / ".cache" / "gdb" / "checkpoints" +DEFAULT_SAM_CHECKPOINT_PATH = DEFAULT_SAM_CACHE_DIR / "sam_vit_h_4b8939.pth" + + +def ensure_sam_checkpoint(path: Optional[Union[str, Path]] = None) -> Path: + """Resolve SAM checkpoint path and download it when missing.""" + + if path is None: + target = DEFAULT_SAM_CHECKPOINT_PATH + else: + target = Path(path).expanduser() + if target.exists() and target.is_dir(): + target = target / DEFAULT_SAM_CHECKPOINT_PATH.name + elif target.suffix == "": + target = target / DEFAULT_SAM_CHECKPOINT_PATH.name + + if target.exists(): + return target.resolve() + + target.parent.mkdir(parents=True, exist_ok=True) + logger.info("Downloading SAM checkpoint for ReMOVE: %s", target) + urllib.request.urlretrieve(DEFAULT_SAM_CHECKPOINT_URL, str(target)) + return target.resolve() + + +def _find_smallest_bounding_square(binary_image: np.ndarray) -> Optional[Tuple[int, int, int]]: + """Return (x, y, size) for the smallest square that encloses the mask.""" + + if binary_image.ndim == 3: + binary_image = np.asarray(Image.fromarray(binary_image).convert("L")) + + white_pixels = np.argwhere(binary_image == 255) + if white_pixels.size == 0: + return None + + min_row = int(np.min(white_pixels[:, 0])) + max_row = int(np.max(white_pixels[:, 0])) + min_col = int(np.min(white_pixels[:, 1])) + max_col = int(np.max(white_pixels[:, 1])) + + width = max_col - min_col + 1 + height = max_row - min_row + 1 + size = max(width, height) + + margin = 16 + h, w = binary_image.shape[:2] + pad = margin if ( + min_col - margin >= 0 + and min_row - margin >= 0 + and max_row + margin < h + and max_col + margin < w + ) else max(min(min_col, min_row, h - 1 - max_row, w - 1 - max_col), 0) + + x = max(min_col - pad, 0) + y = max(min_row - pad, 0) + size = size + 2 * pad + + # Clamp so the final crop remains in-bounds. + size = min(size, h - y, w - x) + return x, y, size + + +@dataclass +class RemoveMetricEvaluator: + """Compute ReMOVE score for an inpainted image and removal mask.""" + + sam_checkpoint: Union[str, Path] + model_type: str = "vit_h" + device: Optional[str] = None + crop: bool = True + + def __post_init__(self) -> None: + self._setup_predictor() + + def _setup_predictor(self) -> None: + try: + from segment_anything import sam_model_registry + from segment_anything.predictor import SamPredictor + except ImportError as exc: + raise ImportError( + "segment-anything is required for ReMOVE metric. " + "Install with: pip install git+https://github.com/facebookresearch/segment-anything.git" + ) from exc + + checkpoint = Path(self.sam_checkpoint).expanduser() + if not checkpoint.exists(): + raise FileNotFoundError(f"SAM checkpoint not found: {checkpoint}") + + resolved_device = str(self.device or ("cuda" if torch.cuda.is_available() else "cpu")) + if resolved_device.startswith("cuda") and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested for ReMOVE but CUDA is not available.") + + sam = sam_model_registry[self.model_type](checkpoint=str(checkpoint)) + sam = sam.to(resolved_device) + sam.eval() + + self.predictor = SamPredictor(sam) + self.device = resolved_device + + def _get_mask_embeddings(self, image_np: np.ndarray, masks: list[np.ndarray]) -> list[torch.Tensor]: + if hasattr(self.predictor, "get_aggregate_features"): + embeddings = self.predictor.get_aggregate_features(image_np, masks) + return [self._ensure_tensor(e) for e in embeddings] + return self._aggregate_features(image_np, masks) + + @staticmethod + def _ensure_tensor(value: torch.Tensor | np.ndarray) -> torch.Tensor: + if torch.is_tensor(value): + return value + return torch.as_tensor(value) + + def _aggregate_features(self, image_np: np.ndarray, masks: list[np.ndarray]) -> list[torch.Tensor]: + with torch.no_grad(): + self.predictor.set_image(image_np) + features = self.predictor.get_image_embedding() + + embeddings: list[torch.Tensor] = [] + for mask in masks: + mask_tensor = torch.as_tensor(mask) + if mask_tensor.ndim == 2: + mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0) + elif mask_tensor.ndim == 3: + mask_tensor = mask_tensor.unsqueeze(0) + + mask_tensor = mask_tensor.to(features.device) + if mask_tensor.shape[-2:] != features.shape[-2:]: + mask_tensor = F.interpolate( + mask_tensor.float(), + size=features.shape[-2:], + mode="nearest", + ) + mask_bool = mask_tensor > 0 + expanded_mask = mask_bool.expand_as(features) + + if not expanded_mask.any(): + embedding = torch.zeros((1, features.shape[1]), device=features.device) + else: + masked_features = features[expanded_mask] + embedding = masked_features.view(1, features.shape[1], -1).mean(dim=2) + + embeddings.append(embedding.detach()) + + return embeddings + + @torch.no_grad() + def score(self, image: Image.Image, mask: Image.Image) -> Optional[float]: + """Return ReMOVE score in [-1, 1], or None for empty masks.""" + + image_np = np.asarray(image.convert("RGB")) + mask_np = np.asarray(mask.convert("L")) + + if mask_np.max() == 0: + return None + + crop_info = _find_smallest_bounding_square(mask_np) if self.crop else None + if crop_info is not None: + x, y, size = crop_info + x2 = min(x + size, image_np.shape[1]) + y2 = min(y + size, image_np.shape[0]) + image_np = image_np[y:y2, x:x2] + mask_np = mask_np[y:y2, x:x2] + + # ReMOVE reference implementation computes foreground/background embeddings + # with a 64x64 binary mask. + mask_fg = ( + np.asarray(Image.fromarray(mask_np).resize((64, 64), Image.NEAREST)) + .reshape(1, 1, 64, 64) + // 255 + ).astype(np.uint8) + mask_bg = 1 - mask_fg + + embeddings = self._get_mask_embeddings(image_np, [mask_fg, mask_bg]) + if len(embeddings) != 2: + raise RuntimeError("Unexpected embedding count returned by SAM predictor for ReMOVE.") + + fg = embeddings[0] + bg = embeddings[1] + if fg.device != bg.device: + bg = bg.to(fg.device) + + score = cosine_similarity(fg, bg).item() + return float(np.clip(score, -1.0, 1.0)) + diff --git a/src/gdb/models/api_models.py b/src/gdb/models/api_models.py index c5ac3ae..fa2f674 100644 --- a/src/gdb/models/api_models.py +++ b/src/gdb/models/api_models.py @@ -82,6 +82,9 @@ class OpenAIImageModel(BaseModel): """OpenAI image model wrapper for generation/editing workflows.""" modality = Modality.IMAGE_GENERATION + supports_image_output = True + supports_image_input = True + supports_mask_editing = True def __init__( self, @@ -300,6 +303,9 @@ def __init__( self.model_id = self._resolve_model_id(model_id) self.name = self.model_id self.modality = self._infer_modality(self.model_id) + self.supports_image_output = self.modality == Modality.IMAGE_GENERATION + self.supports_image_input = True + self.supports_mask_editing = self.modality == Modality.IMAGE_GENERATION self.temperature = temperature self.max_tokens = max_tokens self.adaptive_image_config = adaptive_image_config diff --git a/src/gdb/models/local_models.py b/src/gdb/models/local_models.py index 74d1253..9e1cc70 100644 --- a/src/gdb/models/local_models.py +++ b/src/gdb/models/local_models.py @@ -20,6 +20,7 @@ logger = logging.getLogger(__name__) + def _is_fatal_load_error(exc: BaseException) -> bool: """True for errors that should abort the model-loader fallback loop. @@ -538,10 +539,12 @@ def _load(self) -> None: return try: from vllm_omni import Omni # type: ignore[reportMissingImports] - except ImportError: + except ImportError as exc: raise ImportError( - 'vllm-omni is required for diffusion models. Install with: pip install -e ".[vllm-omni]"' - ) + "vllm-omni is required for diffusion models. " + 'Install with: pip install -e ".[vllm-omni]". ' + f"Import error: {exc}" + ) from exc self._omni = Omni(model=self.model_id) @@ -789,7 +792,7 @@ class Flux2Model(BaseModel): python scripts/run_benchmarks.py --benchmarks layout-8 \ --provider custom \ --custom-entry gdb.models.local_models:Flux2Model \ - --custom-init-kwargs '{"model_name":"flux.2-klein-4b"}' \ + --custom-init-kwargs '{"model_name":"flux.2-klein-9b"}' \ --custom-modality image_generation """ @@ -807,13 +810,12 @@ class Flux2Model(BaseModel): def __init__( self, - model_name: str = "flux.2-klein-4b", + model_name: str = "flux.2-klein-9b", device: str = "cuda", num_steps: Optional[int] = None, guidance: Optional[float] = None, seed: Optional[int] = None, debug_mode: bool = False, - preserve_unmasked_regions: bool = True, default_width: int = 1024, default_height: int = 1024, **kwargs: Any, @@ -830,7 +832,6 @@ def __init__( self.guidance = float(guidance) if guidance is not None else None self.seed = int(seed) if seed is not None else None self.debug_mode = bool(debug_mode) - self.preserve_unmasked_regions = bool(preserve_unmasked_regions) self.default_width = max(64, int(default_width)) self.default_height = max(64, int(default_height)) self._bundle: Optional[Dict[str, Any]] = None @@ -1159,33 +1160,6 @@ def _resolve_guidance(self, model_info: Dict[str, Any]) -> float: ) return self.guidance - def _compose_masked_output(self, inp: ModelInput, input_images: List[Any], generated: Any) -> Any: - if not self.preserve_unmasked_regions or not input_images: - return generated - - if (inp.metadata or {}).get("skip_mask_composition", False): - return generated - - mask = self._coerce_pil_image((inp.metadata or {}).get("mask")) - if mask is None: - return generated - - try: - from PIL import Image # type: ignore[reportMissingImports] - except ImportError: - return generated - - base = input_images[0].convert("RGB") - output = generated.convert("RGB") - if output.size != base.size: - output = output.resize(base.size, Image.Resampling.LANCZOS) - mask_l = mask.convert("L") - if mask_l.size != base.size: - mask_l = mask_l.resize(base.size, Image.Resampling.NEAREST) - if mask_l.getextrema() == (255, 255): - return generated - return Image.composite(output, base, mask_l) - def predict(self, inp: ModelInput) -> ModelOutput: bundle = self._ensure_loaded() torch = bundle["torch"] @@ -1261,8 +1235,6 @@ def predict(self, inp: ModelInput) -> ModelOutput: generated = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) - generated = self._compose_masked_output(inp, input_images, generated) - return ModelOutput( images=[generated], raw={ diff --git a/src/gdb/tasks/image.py b/src/gdb/tasks/image.py new file mode 100644 index 0000000..41fa5d5 --- /dev/null +++ b/src/gdb/tasks/image.py @@ -0,0 +1,8 @@ +"""Backward-compatible shim for image-domain task imports. + +Text removal (`image-6`) is implemented in `gdb.tasks.typography`. +""" + +from gdb.tasks.typography import TextRemoval + +__all__ = ["TextRemoval"] diff --git a/src/gdb/tasks/typography.py b/src/gdb/tasks/typography.py index 6e3450c..ea75e87 100644 --- a/src/gdb/tasks/typography.py +++ b/src/gdb/tasks/typography.py @@ -1,4 +1,4 @@ -"""Typography benchmarks: typography-1 through typography-8. +"""Typography benchmarks: typography-1 through typography-8 (+ image-6). Data contract for these tasks: ``samples.csv`` in the ``--data`` directory with columns ``sample_id``, ``prompt``, ``image_path``, ``expected_output``. @@ -28,6 +28,7 @@ from gdb.base import BaseBenchmark, BenchmarkMeta, TaskType, benchmark from gdb.metrics.core import edit_distance, lpips_score +from gdb.metrics.core import fid as fid_metric from gdb.metrics.core import ssim as ssim_metric from gdb.metrics.text import normalize_font_name from gdb.utils.data_helpers import build_vision_input, load_csv_samples @@ -748,19 +749,46 @@ def _parse_json_cell(value: Any) -> Optional[Any]: def build_model_input(self, sample: Dict[str, Any], *, modality: Any = None) -> Any: from gdb.models.base import ModelInput + metadata: Dict[str, Any] = { + "benchmark_id": self.meta.id, + "task": "g10_styled_text_element_generation", + "text": str(sample.get("text") or ""), + "style_spec": sample.get("style_spec") or {}, + "prompt": str(sample.get("prompt") or ""), + } + width, height = self._read_image_size(sample.get("ground_truth_image")) + if width > 0 and height > 0: + metadata["target_width"] = width + metadata["target_height"] = height + return ModelInput( text=str(sample.get("prompt") or ""), images=[], - metadata={ - "benchmark_id": self.meta.id, - "task": "g10_styled_text_element_generation", - "mask": str(sample.get("mask") or ""), - "text": str(sample.get("text") or ""), - "style_spec": sample.get("style_spec") or {}, - "prompt": str(sample.get("prompt") or ""), - }, + metadata=metadata, ) + @staticmethod + def _read_image_size(image_like: Any) -> Tuple[int, int]: + try: + from PIL import Image + except ImportError: + return (0, 0) + + try: + if isinstance(image_like, (str, Path)): + p = Path(image_like) + if p.exists(): + with Image.open(p) as img: + return int(img.size[0]), int(img.size[1]) + elif isinstance(image_like, (bytes, bytearray)): + with Image.open(io.BytesIO(image_like)) as img: + return int(img.size[0]), int(img.size[1]) + elif isinstance(image_like, Image.Image): + return int(image_like.size[0]), int(image_like.size[1]) + except Exception: + return (0, 0) + return (0, 0) + def parse_model_output(self, output: Any) -> Any: if output is None: return None @@ -2539,3 +2567,1009 @@ def _compose_inpaint_prompt(cls, *, text: str, style_spec: Dict[str, Any]) -> st ) return "\n".join(lines) + +# =========================================================================== +# image-6 (Text Removal) — implemented in typography domain module +# =========================================================================== + + +@benchmark +class TextRemoval(BaseBenchmark): + """image-6 -- Remove text and inpaint background cleanly (G16-style). + + Data format: + - JSON manifest (array or ``{"samples": [...]}``) + - CSV manifest (header row + one sample per row) + + Each sample needs ``input_image`` and ``ground_truth_image``; ``mask`` is + strongly recommended, but can be inferred from common path conventions + (e.g. ``text_removal/mask/.png`` or ``/input/`` -> ``/mask/``). + Optional fields: ``forbidden_texts`` and ``prompt``. Image paths are + resolved relative to the manifest file's directory. + + Accepts either a direct path to a manifest file or a directory + containing ``text_removal_manifest.json`` or ``text_removal_manifest.csv``. + """ + + pipeline_implemented = True + + meta = BenchmarkMeta( + id="image-6", + name="Text Removal & Background Inpainting", + task_type=TaskType.GENERATION, + domain="typography", + data_subpath="image/image-6-text-removal", + description="Remove text and inpaint the underlying background cleanly", + input_spec="Layout image with text components masked (+ optional text mask)", + output_spec="Clean image with text removed and background reconstructed", + metrics=[ + "psnr", + "ssim", + "lpips", + "dino_score", + "clip_score", + "fid", + "fid_coverage", + "ocr_text_absence", + "ocr_coverage", + "bbox_text_absence", + "bbox_coverage", + "remove", + "remove_coverage", + ], + ) + + DEFAULT_PROMPT = "Remove all text and reconstruct the background naturally." + PROMPT_PREFIX = ( + "You are an expert design retoucher specialized in text removal and " + "background inpainting." + ) + PROMPT_SIGNATURE = ( + "Task: remove all visible text while preserving non-text visual content." + ) + BBOX_TEXT_ABSENCE_ENABLED_ENV = "GDB_IMAGE6_USE_BBOX_DETECTOR" + BBOX_TEXT_ABSENCE_MAX_PHRASES_ENV = "GDB_IMAGE6_BBOX_MAX_PHRASES" + BBOX_TEXT_ABSENCE_MAX_CHARS_ENV = "GDB_IMAGE6_BBOX_MAX_CHARS" + _remove_evaluator_bundle: Any = None + _fid_inception_bundle: Any = None + + @classmethod + def _looks_like_composed_prompt(cls, text: str) -> bool: + raw = str(text or "") + return ( + cls.PROMPT_SIGNATURE in raw + and "Hard constraints (must satisfy all):" in raw + ) + + @classmethod + def compose_model_prompt( + cls, + *, + user_prompt: str = "", + forbidden_texts: Optional[List[str]] = None, + ) -> str: + objective = str(user_prompt or "").strip() or cls.DEFAULT_PROMPT + forbidden = [str(t).strip() for t in (forbidden_texts or []) if str(t).strip()] + + lines = [ + cls.PROMPT_PREFIX, + cls.PROMPT_SIGNATURE, + "", + f"Objective: {objective}", + "", + "Input semantics:", + "- Image #1 is the original layout image.", + "- A binary text mask is provided by the task runtime.", + "- White mask pixels indicate where text is likely present.", + "- Full-image regeneration is allowed; strict unmasked preservation is not required.", + ] + + if forbidden: + lines.extend( + [ + "", + "Texts that must be absent in the final output:", + ] + ) + for text in forbidden[:40]: + lines.append(f'- "{text}"') + + lines.extend( + [ + "", + "Hard constraints (must satisfy all):", + "- Remove all visible text traces, prioritizing masked regions.", + "- Keep overall layout semantics, style, and composition coherent.", + "- Reconstruct the background naturally with coherent texture/lighting.", + "- Keep canvas size/aspect ratio consistent with the input image.", + "- Output one final generated image only (no explanation text).", + ] + ) + return "\n".join(lines) + + @classmethod + def _resolve_model_prompt( + cls, + *, + user_prompt: str, + forbidden_texts: Optional[List[str]] = None, + ) -> str: + raw = str(user_prompt or "").strip() + if raw and cls._looks_like_composed_prompt(raw): + return raw + return cls.compose_model_prompt( + user_prompt=raw or cls.DEFAULT_PROMPT, + forbidden_texts=forbidden_texts, + ) + + def _resolve(self, base_dir: Path, value: str) -> str: + """Resolve a path relative to the manifest's directory.""" + p = Path(value) + if p.is_absolute(): + return str(p) + return str((base_dir / p).resolve()) + + def load_data( + self, + data_dir: Union[str, Path], + *, + n: Optional[int] = None, + dataset_root: Union[str, Path], + ) -> List[Dict[str, Any]]: + path = Path(data_dir).resolve() + if path.is_dir(): + json_manifest = path / "text_removal_manifest.json" + csv_manifest = path / "text_removal_manifest.csv" + if json_manifest.exists(): + path = json_manifest + elif csv_manifest.exists(): + path = csv_manifest + else: + raise FileNotFoundError( + "Text removal manifest not found under directory: " + f"{path}. Expected text_removal_manifest.json or text_removal_manifest.csv" + ) + if not path.exists(): + raise FileNotFoundError(f"Text removal manifest not found: {path}") + + rows = self._load_manifest_rows(path) + + base_dir = path.parent + samples: List[Dict[str, Any]] = [] + used_sample_ids: set[str] = set() + for i, row in enumerate(rows): + if not isinstance(row, dict): + logger.warning("Invalid sample at index %d (expected object/dict), skipping", i) + continue + + input_image = self._first_nonempty_value( + row, + ("input_image", "masked_image", "image", "image_path"), + ) + mask = self._first_nonempty_value( + row, + ("mask", "text_mask", "mask_path"), + ) + gt_image = self._first_nonempty_value( + row, + ("ground_truth_image", "target_image", "ground_truth", "expected_output"), + ) + + raw_sample_id = ( + self._first_nonempty_value( + row, + ("sample_id", "id", "layout_id", "source_layout_id", "sampleId"), + ) + ) + if not raw_sample_id and isinstance(input_image, str): + raw_sample_id = Path(input_image).stem + sample_id = str(raw_sample_id).strip() if raw_sample_id else f"text_removal_{i:03d}" + if not sample_id: + sample_id = f"text_removal_{i:03d}" + + if not mask: + inferred_mask = self._infer_mask_path( + base_dir=base_dir, + sample_id=sample_id, + input_image=input_image, + ) + if inferred_mask: + mask = inferred_mask + + if not input_image or not mask or not gt_image: + logger.warning("Incomplete sample at index %d, skipping", i) + continue + + forbidden_texts = self._parse_forbidden_texts( + self._first_nonempty_value( + row, + ("forbidden_texts", "forbidden_text", "texts"), + ) + ) + + if sample_id in used_sample_ids: + suffix = 2 + candidate = f"{sample_id}__{suffix}" + while candidate in used_sample_ids: + suffix += 1 + candidate = f"{sample_id}__{suffix}" + sample_id = candidate + used_sample_ids.add(sample_id) + + raw_prompt = str(self._first_nonempty_value(row, ("prompt",)) or "") + prompt = self._resolve_model_prompt( + user_prompt=self._decode_prompt_field(raw_prompt), + forbidden_texts=forbidden_texts, + ) + + samples.append({ + "sample_id": sample_id, + "ground_truth": { + "image": self._resolve(base_dir, gt_image), + "mask": self._resolve(base_dir, mask), + "forbidden_texts": [str(t) for t in forbidden_texts], + "prompt": prompt, + }, + "input_image": self._resolve(base_dir, input_image), + "mask": self._resolve(base_dir, mask), + "forbidden_texts": [str(t) for t in forbidden_texts], + "prompt": prompt, + }) + + if n is not None: + samples = samples[:n] + return samples + + @staticmethod + def _load_manifest_rows(path: Path) -> List[Dict[str, Any]]: + suffix = path.suffix.lower() + if suffix == ".csv": + with open(path, "r", encoding="utf-8-sig", newline="") as f: + reader = csv.DictReader(f) + if not reader.fieldnames: + raise ValueError(f"CSV manifest has no header row: {path}") + rows: List[Dict[str, Any]] = [] + for row in reader: + if not isinstance(row, dict): + continue + cleaned: Dict[str, Any] = {} + for key, value in row.items(): + header = str(key or "").strip() + if not header: + continue + cleaned[header] = value + rows.append(cleaned) + return rows + + with open(path, "r", encoding="utf-8") as f: + payload = json.load(f) + rows = payload.get("samples") if isinstance(payload, dict) else payload + if not isinstance(rows, list): + raise ValueError(f"Manifest must be a list or dict with 'samples': {path}") + return rows + + @staticmethod + def _first_nonempty_value(row: Dict[str, Any], keys: Tuple[str, ...]) -> Any: + for key in keys: + if key not in row: + continue + value = row.get(key) + if value is None: + continue + if isinstance(value, str): + text = value.strip() + if text: + return text + continue + if isinstance(value, (list, dict)): + if value: + return value + continue + return value + return None + + @staticmethod + def _parse_forbidden_texts(raw: Any) -> List[str]: + if raw is None: + return [] + if isinstance(raw, list): + return [str(t).strip() for t in raw if str(t).strip()] + + text = str(raw).strip() + if not text: + return [] + + if text.startswith("["): + try: + decoded = json.loads(text) + if isinstance(decoded, list): + return [str(t).strip() for t in decoded if str(t).strip()] + except Exception: + pass + + for sep in ("|||", "|", ";", "\n", ","): + if sep in text: + values = [part.strip() for part in text.split(sep) if part.strip()] + if values: + return values + return [text] + + @classmethod + def _infer_mask_path( + cls, + *, + base_dir: Path, + sample_id: str, + input_image: Any, + ) -> str: + sid = str(sample_id or "").strip() + if sid: + for rel in ( + f"text_removal/mask/{sid}.png", + f"mask/{sid}.png", + f"masks/{sid}.png", + ): + resolved = cls._resolve_existing_path(base_dir, rel) + if resolved: + return resolved + + raw_input = str(input_image or "").strip() + if not raw_input: + return "" + + for src, dst in ( + ("/input/", "/mask/"), + ("\\input\\", "\\mask\\"), + ("/masked_layout/", "/mask/"), + ("\\masked_layout\\", "\\mask\\"), + ): + if src in raw_input: + candidate = raw_input.replace(src, dst) + resolved = cls._resolve_existing_path(base_dir, candidate) + if resolved: + return resolved + return "" + + @staticmethod + def _resolve_existing_path(base_dir: Path, raw: str) -> str: + text = str(raw or "").strip() + if not text: + return "" + as_path = Path(text) + if as_path.is_file(): + return str(as_path.resolve()) + rel_path = (base_dir / text).resolve() + if rel_path.is_file(): + return str(rel_path) + return "" + + @staticmethod + def _decode_prompt_field(raw: Any) -> str: + """Decode one-line CSV escaped newlines into runtime prompt newlines.""" + text = str(raw or "") + if not text.strip(): + return "" + text = text.replace("\\r\\n", "\n").replace("\\n", "\n") + return text.strip() + + def build_model_input(self, sample: Dict[str, Any], *, modality: Any = None) -> Any: + from gdb.models.base import ModelInput + + prompt = self._resolve_model_prompt( + user_prompt=str(sample.get("prompt") or ""), + forbidden_texts=self._parse_forbidden_texts(sample.get("forbidden_texts")), + ) + metadata: Dict[str, Any] = { + "mask": sample["mask"], + "task": "text_removal", + "benchmark_id": self.meta.id, + "sample_id": str(sample.get("sample_id") or ""), + } + width, height = self._read_image_size(sample.get("input_image")) + if width > 0 and height > 0: + metadata["target_width"] = width + metadata["target_height"] = height + + return ModelInput( + text=prompt, + images=[sample["input_image"]], + metadata=metadata, + ) + + @staticmethod + def _read_image_size(image_like: Any) -> Tuple[int, int]: + try: + from PIL import Image + except ImportError: + return (0, 0) + + try: + if isinstance(image_like, (str, Path)): + p = Path(image_like) + if p.exists(): + with Image.open(p) as img: + return int(img.size[0]), int(img.size[1]) + elif isinstance(image_like, (bytes, bytearray)): + with Image.open(io.BytesIO(image_like)) as img: + return int(img.size[0]), int(img.size[1]) + elif isinstance(image_like, Image.Image): + return int(image_like.size[0]), int(image_like.size[1]) + except Exception: + return (0, 0) + return (0, 0) + + def parse_model_output(self, output: Any) -> Any: + """Return the first generated image, path-like payload, or None.""" + if output is None: + return None + images = getattr(output, "images", None) + if isinstance(images, list) and images: + return images[0] + if isinstance(output, dict): + for key in ("image", "image_path", "prediction", "output_image"): + if key in output: + return output[key] + if isinstance(output, (str, Path, bytes, bytearray)): + return output + return None + + def evaluate(self, predictions: List[Any], ground_truth: List[Any]) -> Dict[str, float]: + """Evaluate reconstruction quality + OCR-confirmed text absence.""" + from gdb.metrics.core import psnr as metric_psnr + from gdb.metrics.core import ssim as metric_ssim + from gdb.tasks.layout import LayerAwareObjectInsertion + + psnr_scores: List[float] = [] + ssim_scores: List[float] = [] + lpips_scores: List[float] = [] + dino_scores: List[float] = [] + clip_scores: List[float] = [] + fid_real_features: List[np.ndarray] = [] + fid_gen_features: List[np.ndarray] = [] + text_absence_scores: List[float] = [] + bbox_absence_scores: List[float] = [] + remove_scores: List[float] = [] + + for pred_raw, gt_raw in zip(predictions, ground_truth): + gt_bundle = self._normalise_gt_bundle(gt_raw) + pred_image_like = self._extract_image_like(pred_raw) + gt_image_like = self._extract_image_like(gt_bundle["image"]) + + pred_img = self._to_rgb_array(pred_image_like) + gt_img = self._to_rgb_array(gt_image_like) + if pred_img is None or gt_img is None: + continue + + pred_img_native = pred_img.copy() + pred_img = self._resize_to_match(pred_img, gt_img.shape[:2]) + + try: + psnr_scores.append(float(metric_psnr(pred_img, gt_img))) + except Exception: + psnr_scores.append(self._fallback_psnr(pred_img, gt_img)) + + try: + ssim_scores.append(float(metric_ssim(pred_img, gt_img))) + except Exception: + ssim_scores.append(self._fallback_ssim(pred_img, gt_img)) + + lpips = LayerAwareObjectInsertion._lpips_distance(pred_img, gt_img) + if isinstance(lpips, float) and math.isfinite(lpips): + lpips_scores.append(lpips) + + dino = LayerAwareObjectInsertion._dino_similarity(pred_img, gt_img) + if isinstance(dino, float) and math.isfinite(dino): + dino_scores.append(dino) + + # Image-6 clip_score is defined as image-image similarity + # between generated output and ground-truth target. + clip = LayerAwareObjectInsertion._clip_image_similarity(pred_img, gt_img) + if isinstance(clip, float) and math.isfinite(clip): + clip_scores.append(clip) + + real_feat = self._inception_feature(gt_img) + gen_feat = self._inception_feature(pred_img) + if real_feat is not None and gen_feat is not None: + fid_real_features.append(real_feat) + fid_gen_features.append(gen_feat) + + absence = self._ocr_text_absence_score( + pred_img, + gt_bundle["forbidden_texts"], + gt_bundle["mask"], + ) + if absence is not None: + text_absence_scores.append(absence) + + bbox_absence = self._bbox_text_absence_score( + prediction_image=pred_img, + forbidden_texts=gt_bundle["forbidden_texts"], + mask_like=gt_bundle["mask"], + sample_id=str(gt_bundle.get("sample_id", "")), + ) + if bbox_absence is not None: + bbox_absence_scores.append(bbox_absence) + + remove_score = self._remove_score(pred_img_native, gt_bundle["mask"]) + if isinstance(remove_score, float) and math.isfinite(remove_score): + remove_scores.append(remove_score) + + n = len(psnr_scores) or 1 + fid_score = float("nan") + if len(fid_real_features) >= 2 and len(fid_gen_features) >= 2: + try: + fid_score = float(fid_metric(np.stack(fid_real_features), np.stack(fid_gen_features))) + # Numerical noise from sqrtm can produce tiny negative values. + if math.isfinite(fid_score): + fid_score = max(0.0, fid_score) + except Exception: + fid_score = float("nan") + return { + "psnr": sum(psnr_scores) / n, + "ssim": sum(ssim_scores) / n, + "lpips": (sum(lpips_scores) / len(lpips_scores) if lpips_scores else float("nan")), + "lpips_coverage": len(lpips_scores) / n, + "dino_score": (sum(dino_scores) / len(dino_scores) if dino_scores else float("nan")), + "dino_coverage": len(dino_scores) / n, + "clip_score": (sum(clip_scores) / len(clip_scores) if clip_scores else float("nan")), + "clipscore": (sum(clip_scores) / len(clip_scores) if clip_scores else float("nan")), + "clip_coverage": len(clip_scores) / n, + "fid": fid_score, + "fid_coverage": len(fid_real_features) / n, + "ocr_text_absence": ( + sum(text_absence_scores) / len(text_absence_scores) + if text_absence_scores + else float("nan") + ), + "ocr_coverage": len(text_absence_scores) / n, + "bbox_text_absence": ( + sum(bbox_absence_scores) / len(bbox_absence_scores) + if bbox_absence_scores + else float("nan") + ), + "bbox_coverage": len(bbox_absence_scores) / n, + "remove": (sum(remove_scores) / len(remove_scores) if remove_scores else float("nan")), + "remove_coverage": len(remove_scores) / n, + } + + @staticmethod + def _normalise_gt_bundle(raw: Any) -> Dict[str, Any]: + if isinstance(raw, dict): + image = raw.get("image", raw.get("ground_truth_image", raw)) + forbidden = raw.get("forbidden_texts") or raw.get("texts") or [] + if isinstance(forbidden, str): + forbidden = [forbidden] + mask = raw.get("mask") or raw.get("text_mask") + prompt = str(raw.get("prompt", "")) + sample_id = str(raw.get("sample_id", "")).strip() + return { + "image": image, + "forbidden_texts": forbidden, + "mask": mask, + "prompt": prompt, + "sample_id": sample_id, + } + + return {"image": raw, "forbidden_texts": [], "mask": None, "prompt": "", "sample_id": ""} + + @staticmethod + def _extract_image_like(value: Any) -> Any: + if isinstance(value, dict): + for key in ("image", "output_image", "predicted_image", "path"): + if key in value: + return value[key] + + # Support ModelOutput-like objects without importing model modules. + images = getattr(value, "images", None) + if images: + return images[0] + + return value + + @staticmethod + def _to_rgb_array(image_like: Any) -> Optional[np.ndarray]: + if isinstance(image_like, np.ndarray): + arr = image_like + else: + try: + from PIL import Image + except ImportError: + return None + + pil: Optional[Image.Image] = None + if isinstance(image_like, Image.Image): + pil = image_like + elif isinstance(image_like, (str, Path)): + path_text = str(image_like).strip() + if not path_text: + return None + path_obj = Path(path_text) + if not path_obj.exists() or path_obj.is_dir(): + return None + pil = Image.open(path_obj) + elif isinstance(image_like, (bytes, bytearray)): + pil = Image.open(io.BytesIO(image_like)) + + if pil is None: + return None + arr = np.asarray(pil.convert("RGB")) + + if arr.ndim == 2: + arr = np.stack([arr, arr, arr], axis=-1) + elif arr.ndim == 3 and arr.shape[2] == 4: + arr = arr[:, :, :3] + elif arr.ndim != 3: + return None + + if arr.dtype != np.uint8: + arr = np.clip(arr, 0, 255).astype(np.uint8) + + return arr + + @staticmethod + def _resize_to_match(image: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray: + if image.shape[:2] == target_hw: + return image + + try: + from PIL import Image + + resized = Image.fromarray(image).resize( + (target_hw[1], target_hw[0]), + Image.BILINEAR, + ) + return np.asarray(resized) + except ImportError: + return np.resize(image, (target_hw[0], target_hw[1], image.shape[2])) + + @staticmethod + def _fallback_psnr(pred: np.ndarray, gt: np.ndarray) -> float: + pred_f = pred.astype(np.float32) + gt_f = gt.astype(np.float32) + mse = float(np.mean((pred_f - gt_f) ** 2)) + if mse == 0.0: + return float("inf") + return float(20.0 * math.log10(255.0) - 10.0 * math.log10(mse)) + + @staticmethod + def _fallback_ssim(pred: np.ndarray, gt: np.ndarray) -> float: + # Lightweight fallback when skimage is unavailable. + pred_f = pred.astype(np.float32) + gt_f = gt.astype(np.float32) + mse = float(np.mean((pred_f - gt_f) ** 2)) + return float(max(0.0, 1.0 - (mse / (255.0 ** 2)))) + + @staticmethod + def _normalise_text(raw: str) -> str: + compact = re.sub(r"[^a-z0-9]+", " ", str(raw).lower()) + return re.sub(r"\s+", " ", compact).strip() + + @classmethod + def _mask_to_region(cls, image: np.ndarray, mask_like: Any) -> np.ndarray: + if mask_like is None: + return image + + mask = cls._to_gray_mask(mask_like, image.shape[:2]) + if mask is None: + return image + + ys, xs = np.where(mask > 127) + if ys.size == 0 or xs.size == 0: + return image + + y1, y2 = int(ys.min()), int(ys.max()) + 1 + x1, x2 = int(xs.min()), int(xs.max()) + 1 + return image[y1:y2, x1:x2] + + @staticmethod + def _to_gray_mask(mask_like: Any, target_hw: tuple[int, int]) -> Optional[np.ndarray]: + if isinstance(mask_like, np.ndarray): + mask = mask_like + else: + try: + from PIL import Image + except ImportError: + return None + + pil: Optional[Image.Image] = None + if isinstance(mask_like, Image.Image): + pil = mask_like + elif isinstance(mask_like, (str, Path)): + if not Path(mask_like).exists(): + return None + pil = Image.open(mask_like) + elif isinstance(mask_like, (bytes, bytearray)): + pil = Image.open(io.BytesIO(mask_like)) + + if pil is None: + return None + mask = np.asarray(pil.convert("L")) + + if mask.ndim == 3: + mask = mask[:, :, 0] + if mask.shape[:2] != target_hw: + try: + from PIL import Image + + mask = np.asarray( + Image.fromarray(mask.astype(np.uint8)).resize( + (target_hw[1], target_hw[0]), + Image.NEAREST, + ) + ) + except ImportError: + mask = np.resize(mask, target_hw) + return mask.astype(np.uint8) + + @classmethod + def _ocr_text_absence_score( + cls, + prediction_image: np.ndarray, + forbidden_texts: List[str], + mask_like: Any, + ) -> Optional[float]: + if not forbidden_texts: + return 1.0 + + ocr_text = cls._run_ocr(cls._mask_to_region(prediction_image, mask_like)) + if ocr_text is None: + return None + + normalised_ocr = cls._normalise_text(ocr_text) + if not normalised_ocr: + return 1.0 + + for phrase in forbidden_texts: + needle = cls._normalise_text(phrase) + if not needle: + continue + + # Exact phrase check first. + if needle in normalised_ocr: + return 0.0 + + # Token-level fallback for OCR spacing/punctuation variation. + tokens = [tok for tok in needle.split(" ") if len(tok) >= 3] + if tokens and all(tok in normalised_ocr for tok in tokens): + return 0.0 + + return 1.0 + + @staticmethod + def _env_flag_enabled(name: str, default: bool = False) -> bool: + raw = str(os.environ.get(name, "")).strip().lower() + if not raw: + return default + return raw in {"1", "true", "yes", "on"} + + @staticmethod + def _env_int(name: str, default: int) -> int: + raw = str(os.environ.get(name, "")).strip() + if not raw: + return default + try: + return int(raw) + except Exception: + return default + + @staticmethod + def _box_area(box: Tuple[int, int, int, int]) -> int: + x1, y1, x2, y2 = box + return max(0, x2 - x1) * max(0, y2 - y1) + + @classmethod + def _box_iou(cls, a: Tuple[int, int, int, int], b: Tuple[int, int, int, int]) -> float: + ax1, ay1, ax2, ay2 = a + bx1, by1, bx2, by2 = b + ix1 = max(ax1, bx1) + iy1 = max(ay1, by1) + ix2 = min(ax2, bx2) + iy2 = min(ay2, by2) + inter = cls._box_area((ix1, iy1, ix2, iy2)) + union = cls._box_area(a) + cls._box_area(b) - inter + if union <= 0: + return 0.0 + return float(inter / union) + + @classmethod + def _prepare_bbox_forbidden_texts(cls, forbidden_texts: List[str]) -> List[str]: + max_phrases = max(1, cls._env_int(cls.BBOX_TEXT_ABSENCE_MAX_PHRASES_ENV, 10)) + max_chars = max(80, cls._env_int(cls.BBOX_TEXT_ABSENCE_MAX_CHARS_ENV, 1200)) + out: List[str] = [] + seen: set[str] = set() + char_budget = 0 + for raw in forbidden_texts: + text = re.sub(r"\s+", " ", str(raw or "")).strip() + if not text: + continue + key = cls._normalise_text(text) + if not key or key in seen: + continue + if char_budget + len(text) > max_chars and out: + break + seen.add(key) + out.append(text) + char_budget += len(text) + if len(out) >= max_phrases: + break + return out + + @classmethod + def _bbox_query_text(cls, forbidden_texts: List[str]) -> str: + lines = [ + "Detect any remaining visible text that matches one of the following forbidden texts.", + "Return a bbox for one matching occurrence if found.", + "", + "Forbidden texts:", + ] + for idx, phrase in enumerate(forbidden_texts, start=1): + lines.append(f"{idx}. {phrase}") + return "\n".join(lines) + + @classmethod + def _bbox_text_absence_score( + cls, + *, + prediction_image: np.ndarray, + forbidden_texts: List[str], + mask_like: Any, + sample_id: str = "", + ) -> Optional[float]: + if not cls._env_flag_enabled(cls.BBOX_TEXT_ABSENCE_ENABLED_ENV, default=False): + return None + prepared = cls._prepare_bbox_forbidden_texts(forbidden_texts) + if not prepared: + return 1.0 + + # Reuse typography-7 bbox detector stack (gpt-5.4 by default). + if StyledTextGeneration._get_bbox_detector_model() is None: + return None + + mask = cls._to_gray_mask(mask_like, prediction_image.shape[:2]) + mask_bbox = StyledTextGeneration._mask_bbox(mask) + query_text = cls._bbox_query_text(prepared) + bbox = StyledTextGeneration._detect_text_bbox_llm( + image=prediction_image, + expected_text=query_text, + mask_bbox=mask_bbox, + sample_id=f"image-6|{sample_id}", + ) + if bbox is None: + return 1.0 + if mask_bbox is None: + return 0.0 + # Ignore detections that do not overlap editable text area. + return 0.0 if cls._box_iou(bbox, mask_bbox) > 0.0 else 1.0 + + @staticmethod + def _run_ocr(image: np.ndarray) -> Optional[str]: + try: + import pytesseract + from PIL import Image + except ImportError: + return None + + try: + return str(pytesseract.image_to_string(Image.fromarray(image), config="--psm 6")) + except Exception: + return None + + @classmethod + def _inception_feature(cls, image: np.ndarray) -> Optional[np.ndarray]: + """Extract commonly used Inception-v3 pool3(2048) feature for FID.""" + if cls._fid_inception_bundle is None: + try: + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + # Prefer pytorch-fid's Inception (common FID implementation). + try: + from pytorch_fid.inception import InceptionV3 + + block = InceptionV3.BLOCK_INDEX_BY_DIM[2048] + model = InceptionV3([block]).to(device).eval() + cls._fid_inception_bundle = ("pytorch_fid", model, torch, device) + except Exception: + # Fallback: torchvision Inception-v3 pool3 feature. + from torchvision.models import Inception_V3_Weights, inception_v3 + + model = inception_v3( + weights=Inception_V3_Weights.IMAGENET1K_V1, + aux_logits=False, + ) + model.fc = torch.nn.Identity() + model = model.to(device).eval() + cls._fid_inception_bundle = ("torchvision", model, torch, device) + except Exception as exc: + logger.info("Inception FID feature extractor unavailable: %s", exc) + cls._fid_inception_bundle = False + + if not cls._fid_inception_bundle: + return None + + mode, model, torch, device = cls._fid_inception_bundle + try: + img = image + if img.dtype != np.uint8: + img = np.clip(img, 0, 255).astype(np.uint8) + img = np.array(img, copy=True) + + x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + x = torch.nn.functional.interpolate( + x, + size=(299, 299), + mode="bilinear", + align_corners=False, + ).to(device) + + with torch.no_grad(): + if mode == "pytorch_fid": + feats = model(x)[0] + feats = feats.squeeze(-1).squeeze(-1) + else: + mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1) + x = (x - mean) / std + feats = model(x) + + vec = feats.detach().cpu().numpy().reshape(-1).astype(np.float64) + if vec.size != 2048: + return None + if not np.all(np.isfinite(vec)): + return None + return vec + except Exception: + return None + + @classmethod + def _remove_score(cls, prediction_image: np.ndarray, mask_like: Any) -> float: + if mask_like is None: + return float("nan") + + if cls._remove_evaluator_bundle is None: + try: + from gdb.metrics.remove_metric import ( + DEFAULT_SAM_CHECKPOINT_PATH, + RemoveMetricEvaluator, + ensure_sam_checkpoint, + ) + + checkpoint_override = os.environ.get("GDB_REMOVE_SAM_CHECKPOINT") + checkpoint = ensure_sam_checkpoint(checkpoint_override or DEFAULT_SAM_CHECKPOINT_PATH) + disable_crop = os.environ.get("GDB_REMOVE_DISABLE_CROP", "") + crop = str(disable_crop).strip().lower() not in {"1", "true", "yes", "on"} + + cls._remove_evaluator_bundle = RemoveMetricEvaluator( + sam_checkpoint=str(checkpoint), + model_type=os.environ.get("GDB_REMOVE_MODEL_TYPE", "vit_h"), + device=os.environ.get("GDB_REMOVE_DEVICE") or None, + crop=crop, + ) + except Exception as exc: + logger.info("ReMOVE metric unavailable, returning NaN: %s", exc) + cls._remove_evaluator_bundle = False + + if not cls._remove_evaluator_bundle: + return float("nan") + + try: + from PIL import Image + + pred_u8 = prediction_image + if pred_u8.dtype != np.uint8: + pred_u8 = np.clip(pred_u8, 0, 255).astype(np.uint8) + + mask = cls._to_gray_mask(mask_like, pred_u8.shape[:2]) + if mask is None: + return float("nan") + + pred_pil = Image.fromarray(pred_u8, mode="RGB") + mask_pil = Image.fromarray(mask.astype(np.uint8), mode="L") + score = cls._remove_evaluator_bundle.score(pred_pil, mask_pil) + if score is None: + return float("nan") + return float(score) + except Exception as exc: + logger.debug("ReMOVE metric failed for a sample: %s", exc) + return float("nan") +