diff --git a/dataflow_agent/toolkits/image2drawio/__init__.py b/dataflow_agent/toolkits/image2drawio/__init__.py index b70cabc3..1ac55e45 100644 --- a/dataflow_agent/toolkits/image2drawio/__init__.py +++ b/dataflow_agent/toolkits/image2drawio/__init__.py @@ -9,6 +9,8 @@ save_masked_rgba, bbox_iou_px, ) +from .metric_evaluator import evaluate as metric_evaluate +from .refinement_processor import refine as refinement_refine __all__ = [ "classify_shape", @@ -18,4 +20,6 @@ "sample_fill_stroke", "save_masked_rgba", "bbox_iou_px", + "metric_evaluate", + "refinement_refine", ] diff --git a/dataflow_agent/toolkits/image2drawio/metric_evaluator.py b/dataflow_agent/toolkits/image2drawio/metric_evaluator.py new file mode 100644 index 00000000..e93f8263 --- /dev/null +++ b/dataflow_agent/toolkits/image2drawio/metric_evaluator.py @@ -0,0 +1,692 @@ +""" +metric_evaluator.py — Image2DrawIO quality evaluation module. + +Computes a content-coverage score and detects uncovered "bad regions" +that need fallback rescue. Works with Paper2Any's dict-based element +format (kind/bbox_px/image_path …). + +Core idea: + score = covered_content_pixels / total_content_pixels × 100 + +Three-channel bad-region detection: + Fine — small icons / sub-figures (0.05 %–20 % of image) + Coarse — panels / large images (0.2 %–30 %) + Complex— high-variance areas without base64 (heatmaps, photos …) + +Usage: + from dataflow_agent.toolkits.image2drawio.metric_evaluator import evaluate + + result = evaluate( + image_path="input.png", + elements=[...], # from _build_elements_from_sam3 + text_blocks=[...], # from _text_node + output_dir="outputs/xx", # optional, saves debug images + ) + print(result["score"], result["bad_regions"]) +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np +from dataflow_agent.logger import get_logger + +log = get_logger(__name__) + +# ======================== helpers ======================== + +def _bbox_area(bbox: List[int]) -> int: + return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1]) + + +def _bbox_iou(a: List[int], b: List[int]) -> float: + xa = max(a[0], b[0]) + ya = max(a[1], b[1]) + xb = min(a[2], b[2]) + yb = min(a[3], b[3]) + inter = max(0, xb - xa) * max(0, yb - ya) + area_a = _bbox_area(a) + area_b = _bbox_area(b) + union = area_a + area_b - inter + return inter / union if union > 0 else 0.0 + + +# ======================== configuration ======================== + +DEFAULT_CONFIG: Dict[str, Any] = { + # content mask + "content_threshold": 240, # 更严格:灰度<240即为内容(原245太宽松) + "use_edge_detection": True, + "edge_low": 25, # 更敏感的边缘检测(原30) + "edge_high": 80, # 降低高阈值以捕获更多边缘(原100) + "denoise_kernel": 2, + "min_content_area": 20, # 保留更小的内容区域(原30) + + # fine channel — 检测小图标/人脸/小子图 + "fine_min_ratio": 0.0003, # 更敏感:0.03%(原0.05%) + "fine_max_ratio": 0.25, # 扩大到25%(原20%) + "fine_min_fill": 0.12, # 更宽松的填充率(原0.15) + "fine_max_aspect": 10.0, # 允许更大宽高比(原8.0) + + # coarse channel — 检测版块/大图 + "coarse_min_ratio": 0.001, # 更敏感:0.1%(原0.2%) + "coarse_max_ratio": 0.35, # 扩大到35%(原30%) + "coarse_min_fill": 0.15, # 更宽松的填充率(原0.20) + "coarse_max_aspect": 10.0, # 允许更大宽高比(原8.0) + "coarse_kernel": 7, # 稍大的核以合并相邻内容(原5) + + # NMS / dedup + "nms_iou": 0.25, # 更严格的NMS(原0.3) + "existing_iou": 0.30, # 更积极过滤:与已有元素 IoU>30% 即跳过(原0.45) + "max_covered_ratio": 0.65, # 降低已覆盖容忍度(原0.7) + "min_missing_ratio": 0.03, # 更敏感的漏检内容检测(原0.05) + + # merge + "merge_distance_ratio": 0.08, # 更保守的合并距离(原0.10) + "small_region_threshold": 0.025, # 更小的小区域阈值(原0.03) + + # text protection + "text_pad_px": 15, # 文字 bbox 外扩像素,弥补 OCR 框偏小(原8太小) + "text_overlap_skip": 0.35, # 候选区域被文字覆盖 ≥ 35% 则跳过(原0.5太高,文字区域容易漏过) +} + +# ======================== public API ======================== + +def evaluate( + image_path: str, + elements: List[Dict[str, Any]], + text_blocks: Optional[List[Dict[str, Any]]] = None, + output_dir: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Evaluate how well existing elements cover the image content. + + Returns dict with keys: + score – 0..100 (100 = perfect coverage) + bad_regions – list of {bbox, area, area_ratio, channel, …} + needs_refinement – bool + metrics – detailed numbers for debugging + """ + cfg = {**DEFAULT_CONFIG, **(config or {})} + + cv2_image = cv2.imread(image_path) + if cv2_image is None: + log.error(f"[MetricEvaluator] Cannot read image: {image_path}") + return {"score": 0, "bad_regions": [], "needs_refinement": False, "metrics": {}} + + h, w = cv2_image.shape[:2] + img_area = h * w + + # 1. content mask (foreground pixels) + content_mask = _create_content_mask(cv2_image, cfg) + total_content = int(np.count_nonzero(content_mask)) + + # 2. covered mask (elements that produced actual output) + covered_mask, existing_bboxes = _create_covered_mask( + elements, text_blocks or [], h, w, cfg + ) + + # 3. build text-only mask (with padding) for text-overlap filtering + text_bboxes = _collect_text_bboxes(text_blocks or [], h, w, cfg) + + # 4. pixel coverage + covered_content = int(np.count_nonzero(cv2.bitwise_and(content_mask, covered_mask))) + pixel_coverage = (covered_content / total_content * 100) if total_content > 0 else 100.0 + + # 5. uncovered content + uncovered = cv2.bitwise_and(content_mask, cv2.bitwise_not(covered_mask)) + + # 6. detect bad regions (three channels) + bad_regions = _detect_bad_regions( + cv2_image, content_mask, covered_mask, uncovered, + existing_bboxes, text_bboxes, elements, img_area, cfg, + ) + + # 7. score = 100 − bad-region area ratio (de-duplicated) + bad_mask = np.zeros((h, w), dtype=np.uint8) + for r in bad_regions: + x1, y1, x2, y2 = r["bbox"] + bad_mask[max(0, y1):min(h, y2), max(0, x1):min(w, x2)] = 255 + bad_ratio = float(np.count_nonzero(bad_mask) / img_area * 100) if img_area > 0 else 0.0 + score = max(0.0, 100.0 - bad_ratio) + + needs_refinement = len(bad_regions) > 0 + + metrics = { + "score": round(score, 2), + "pixel_coverage": round(pixel_coverage, 2), + "total_content_px": total_content, + "covered_content_px": covered_content, + "image_area": img_area, + "element_count": len(elements), + "text_block_count": len(text_blocks or []), + "bad_region_count": len(bad_regions), + "bad_region_ratio": round(bad_ratio, 2), + } + + log.info( + f"[MetricEvaluator] score={score:.1f}, " + f"bad_regions={len(bad_regions)}, bad_ratio={bad_ratio:.1f}%" + ) + + # 8. optional visualisation / JSON dump + if output_dir: + _save_debug(cv2_image, covered_mask, uncovered, bad_regions, + metrics, needs_refinement, score, output_dir) + + return { + "score": round(score, 2), + "bad_regions": bad_regions, + "needs_refinement": needs_refinement, + "metrics": metrics, + } + + +# ======================== content mask ======================== + +def _create_content_mask(img: np.ndarray, cfg: dict) -> np.ndarray: + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + h, w = gray.shape + + # threshold + mask_gray = ((gray < cfg["content_threshold"]).astype(np.uint8)) * 255 + + # edge detection (optional) + if cfg.get("use_edge_detection", True): + edges = cv2.Canny(gray, cfg["edge_low"], cfg["edge_high"]) + edges = cv2.dilate(edges, np.ones((5, 5), np.uint8), iterations=2) + mask_gray = cv2.bitwise_or(mask_gray, edges) + + # denoise + ks = cfg.get("denoise_kernel", 2) + if ks > 0: + mask_gray = cv2.morphologyEx(mask_gray, cv2.MORPH_OPEN, np.ones((ks, ks), np.uint8)) + + # remove tiny connected components + min_cc = cfg.get("min_content_area", 20) + if min_cc > 0: + n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_gray, connectivity=8) + clean = np.zeros_like(mask_gray) + for i in range(1, n_labels): + if stats[i, cv2.CC_STAT_AREA] >= min_cc: + clean[labels == i] = 255 + mask_gray = clean + + return mask_gray + + +# ======================== covered mask ======================== + +def _create_covered_mask( + elements: List[Dict], + text_blocks: List[Dict], + h: int, w: int, + cfg: dict, +) -> Tuple[np.ndarray, List[List[int]]]: + """ + Build a mask of regions that already have real output. + + Rules: + - shape with fill/stroke → covered (矢量已还原) + - image with existing image_path → covered + - text block with geometry → covered (带 padding 扩展) + """ + mask = np.zeros((h, w), dtype=np.uint8) + bboxes: List[List[int]] = [] + + for el in elements: + bbox = el.get("bbox_px") + if not bbox or len(bbox) != 4: + continue + kind = el.get("kind", "") + + # image kind: must have a valid file to count + if kind == "image": + ip = el.get("image_path", "") + if not ip or not os.path.exists(ip): + continue + + x1, y1, x2, y2 = [int(v) for v in bbox] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + if x2 > x1 and y2 > y1: + mask[y1:y2, x1:x2] = 255 + bboxes.append([x1, y1, x2, y2]) + + # 文字 bbox 带 padding 写入 covered mask + pad = cfg.get("text_pad_px", 8) + for blk in text_blocks: + geo = blk.get("geometry", {}) + x = int(float(geo.get("x", 0))) + y = int(float(geo.get("y", 0))) + bw = int(float(geo.get("width", 0))) + bh = int(float(geo.get("height", 0))) + x1, y1 = max(0, x - pad), max(0, y - pad) + x2, y2 = min(w, x + bw + pad), min(h, y + bh + pad) + if x2 > x1 and y2 > y1: + mask[y1:y2, x1:x2] = 255 + bboxes.append([x1, y1, x2, y2]) + + return mask, bboxes + + +def _collect_text_bboxes( + text_blocks: List[Dict], h: int, w: int, cfg: dict, +) -> List[List[int]]: + """Collect padded text bboxes for text-overlap filtering.""" + pad = cfg.get("text_pad_px", 8) + bboxes: List[List[int]] = [] + for blk in text_blocks: + geo = blk.get("geometry", {}) + x = int(float(geo.get("x", 0))) + y = int(float(geo.get("y", 0))) + bw = int(float(geo.get("width", 0))) + bh = int(float(geo.get("height", 0))) + x1, y1 = max(0, x - pad), max(0, y - pad) + x2, y2 = min(w, x + bw + pad), min(h, y + bh + pad) + if x2 > x1 and y2 > y1: + bboxes.append([x1, y1, x2, y2]) + return bboxes + + +# ======================== bad-region detection ======================== + +def _detect_bad_regions( + cv2_image: np.ndarray, + content_mask: np.ndarray, + covered_mask: np.ndarray, + uncovered: np.ndarray, + existing_bboxes: List[List[int]], + text_bboxes: List[List[int]], + elements: List[Dict], + img_area: int, + cfg: dict, +) -> List[Dict[str, Any]]: + h, w = cv2_image.shape[:2] + candidates: List[Tuple[List[int], str]] = [] + + # fine channel + for box in _channel_cc(uncovered, img_area, + cfg["fine_min_ratio"], cfg["fine_max_ratio"], + cfg["fine_min_fill"], cfg["fine_max_aspect"]): + candidates.append((box, "fine")) + + # coarse channel + k = cfg["coarse_kernel"] + closed = cv2.morphologyEx(uncovered, cv2.MORPH_CLOSE, + cv2.getStructuringElement(cv2.MORPH_RECT, (k, k))) + for box in _channel_cc(closed, img_area, + cfg["coarse_min_ratio"], cfg["coarse_max_ratio"], + cfg["coarse_min_fill"], cfg["coarse_max_aspect"]): + candidates.append((box, "coarse")) + + # complex channel (high-variance regions without element coverage) + for box in _detect_complex(cv2_image, elements, covered_mask, img_area): + candidates.append((box, "complex")) + + # banner channel — 横幅/标题栏: 宽度 ≥ 图片宽度 40%, 宽高比 > 10 + # 普通通道会因 max_aspect 过滤掉这类区域, 但横跨画面的标题栏应该保留 + for box in _channel_cc(uncovered, img_area, + cfg.get("banner_min_ratio", 0.003), + cfg.get("banner_max_ratio", 0.15), + cfg.get("banner_min_fill", 0.08), + max_aspect=100.0): # 放宽宽高比 + bw = box[2] - box[0] + bh = box[3] - box[1] + # 必须宽度 ≥ 图片宽度的 40%(排除普通窄条噪音) + if bw >= w * 0.4 and bh >= 10: + candidates.append((box, "banner")) + + log.info( + f"[MetricEvaluator] candidates: " + f"fine={sum(1 for _, c in candidates if c == 'fine')}, " + f"coarse={sum(1 for _, c in candidates if c == 'coarse')}, " + f"complex={sum(1 for _, c in candidates if c == 'complex')}, " + f"banner={sum(1 for _, c in candidates if c == 'banner')}" + ) + + # small-box-first NMS + candidates = _nms_small_first(candidates, cfg["nms_iou"]) + + # filter vs existing elements + text overlap + coverage check + regions = _filter_candidates( + candidates, covered_mask, existing_bboxes, text_bboxes, + uncovered, img_area, cfg, + ) + + # merge nearby small regions + merge_dist = min(h, w) * cfg.get("merge_distance_ratio", 0.10) + regions = _merge_nearby(regions, merge_dist, img_area, + cfg.get("small_region_threshold", 0.03)) + + regions.sort(key=lambda r: r["area"], reverse=True) + return regions + + +def _channel_cc( + mask: np.ndarray, img_area: int, + min_ratio: float, max_ratio: float, + min_fill: float, max_aspect: float, +) -> List[List[int]]: + """Connected-component channel: returns list of bboxes.""" + min_a = img_area * min_ratio + max_a = img_area * max_ratio + n, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) + boxes: List[List[int]] = [] + for i in range(1, n): + x, y, rw, rh, cc_area = stats[i] + if rw <= 0 or rh <= 0: + continue + ba = rw * rh + if ba < min_a or ba > max_a: + continue + if max(rw, rh) / max(1, min(rw, rh)) > max_aspect: + continue + if cc_area / ba < min_fill: + continue + boxes.append([int(x), int(y), int(x + rw), int(y + rh)]) + return boxes + + +def _detect_complex( + cv2_image: np.ndarray, + elements: List[Dict], + covered_mask: np.ndarray, + img_area: int, +) -> List[List[int]]: + """Detect high-complexity regions not covered by any element. + + NOTE: text regions have already been painted into covered_mask + (with padding), so they are excluded from `uncovered_hi` via the + bitwise_and(hi, ~covered_mask) step. + """ + h, w = cv2_image.shape[:2] + gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY).astype(np.float32) + + ks = max(21, min(h, w) // 50) + if ks % 2 == 0: + ks += 1 + + local_mean = cv2.blur(gray, (ks, ks)) + local_var = cv2.blur(gray ** 2, (ks, ks)) - local_mean ** 2 + local_var = np.maximum(local_var, 0) + + edges = cv2.Canny(gray.astype(np.uint8), 30, 100) + edge_density = cv2.blur(edges.astype(np.float32), (ks, ks)) + + var_norm = local_var / (local_var.max() + 1e-6) + edge_norm = edge_density / (edge_density.max() + 1e-6) + complexity = var_norm * 0.6 + edge_norm * 0.4 + + thresh = np.percentile(complexity, 75) + hi = (complexity > thresh).astype(np.uint8) * 255 + hi = cv2.morphologyEx(hi, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8)) + + uncovered_hi = cv2.bitwise_and(hi, cv2.bitwise_not(covered_mask)) + uncovered_hi = cv2.morphologyEx(uncovered_hi, cv2.MORPH_OPEN, np.ones((7, 7), np.uint8)) + uncovered_hi = cv2.morphologyEx(uncovered_hi, cv2.MORPH_CLOSE, np.ones((51, 51), np.uint8)) + + min_a = img_area * 0.002 + max_a = img_area * 0.30 + n, _, stats, _ = cv2.connectedComponentsWithStats(uncovered_hi, connectivity=8) + boxes: List[List[int]] = [] + for i in range(1, n): + x, y, rw, rh, _ = stats[i] + ba = rw * rh + if ba < min_a or ba > max_a: + continue + if max(rw, rh) / max(1, min(rw, rh)) > 8: + continue + boxes.append([int(x), int(y), int(x + rw), int(y + rh)]) + return boxes + + +# ======================== NMS / filtering ======================== + +def _nms_small_first( + candidates: List[Tuple[List[int], str]], + iou_thresh: float, +) -> List[Tuple[List[int], str]]: + """Keep smaller boxes, suppress larger overlapping ones.""" + if not candidates: + return [] + items = [(b, c, _bbox_area(b)) for b, c in candidates] + items.sort(key=lambda x: x[2]) # ascending area + keep: List[Tuple[List[int], str]] = [] + suppressed = [False] * len(items) + for i, (bi, ci, _) in enumerate(items): + if suppressed[i]: + continue + keep.append((bi, ci)) + for j in range(i + 1, len(items)): + if not suppressed[j] and _bbox_iou(bi, items[j][0]) > iou_thresh: + suppressed[j] = True + return keep + + +def _text_overlap_ratio(candidate: List[int], text_bboxes: List[List[int]]) -> float: + """计算 candidate 区域被文字 bbox 覆盖的面积占比. + + 将所有与 candidate 相交的文字区域的交集面积累加(使用 mask 去重), + 然后除以 candidate 面积。 + """ + c_area = _bbox_area(candidate) + if c_area <= 0 or not text_bboxes: + return 0.0 + + cx1, cy1, cx2, cy2 = candidate + cw, ch = cx2 - cx1, cy2 - cy1 + + # 用小 mask 精确计算覆盖面积(避免多个文字 bbox 重叠导致重复计算) + tmask = np.zeros((ch, cw), dtype=np.uint8) + for tb in text_bboxes: + # 相交区域(相对于 candidate 的局部坐标) + lx1 = max(0, tb[0] - cx1) + ly1 = max(0, tb[1] - cy1) + lx2 = min(cw, tb[2] - cx1) + ly2 = min(ch, tb[3] - cy1) + if lx2 > lx1 and ly2 > ly1: + tmask[ly1:ly2, lx1:lx2] = 255 + + covered = int(np.count_nonzero(tmask)) + return covered / c_area + + +def _filter_candidates( + candidates: List[Tuple[List[int], str]], + covered_mask: np.ndarray, + existing_bboxes: List[List[int]], + text_bboxes: List[List[int]], + uncovered: np.ndarray, + img_area: int, + cfg: dict, +) -> List[Dict[str, Any]]: + iou_thresh = cfg["existing_iou"] + max_covered = cfg["max_covered_ratio"] + min_missing = cfg["min_missing_ratio"] + text_skip = cfg.get("text_overlap_skip", 0.5) + + regions: List[Dict[str, Any]] = [] + for box, channel in candidates: + x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3]) + area = int(_bbox_area(box)) + is_complex = channel == "complex" + + # skip if high IoU with an existing element + eff_iou = 0.8 if is_complex else iou_thresh + if any(_bbox_iou(box, eb) > eff_iou for eb in existing_bboxes): + continue + + # skip if candidate is mostly covered by union of existing element bboxes + # (handles cases where no single element has high IoU but collectively they cover it) + if not is_complex and existing_bboxes: + cw, ch = x2 - x1, y2 - y1 + if cw > 0 and ch > 0: + union_mask = np.zeros((ch, cw), dtype=np.uint8) + for eb in existing_bboxes: + lx1 = max(0, eb[0] - x1) + ly1 = max(0, eb[1] - y1) + lx2 = min(cw, eb[2] - x1) + ly2 = min(ch, eb[3] - y1) + if lx2 > lx1 and ly2 > ly1: + union_mask[ly1:ly2, lx1:lx2] = 255 + union_covered = float(np.count_nonzero(union_mask)) / (cw * ch) + if union_covered > 0.50: + continue + + # ★ skip if candidate is predominantly text + # 如果候选区域被文字 bbox 覆盖的面积 ≥ text_overlap_skip,则认为是文字区域,跳过 + if text_bboxes and _text_overlap_ratio([x1, y1, x2, y2], text_bboxes) >= text_skip: + continue + + # skip if mostly already covered (except complex) + if not is_complex: + roi = covered_mask[max(0, y1):min(covered_mask.shape[0], y2), + max(0, x1):min(covered_mask.shape[1], x2)] + if roi.size > 0 and float(np.mean(roi > 0)) > max_covered: + continue + + # missing content pixels + roi_unc = uncovered[max(0, y1):min(uncovered.shape[0], y2), + max(0, x1):min(uncovered.shape[1], x2)] + missing_px = int(np.count_nonzero(roi_unc)) if not is_complex else area + + if not is_complex and area > 0 and missing_px < area * min_missing: + continue + + regions.append({ + "bbox": [x1, y1, x2, y2], + "area": area, + "area_ratio": round(area / img_area, 4) if img_area > 0 else 0, + "missing_pixels": missing_px, + "channel": channel, + "reason": "complex_image" if is_complex else "uncovered_content", + "description": ( + f"({x1},{y1})-({x2},{y2}) " + f"{'complex image' if is_complex else 'uncovered'} [{channel}]" + ), + }) + return regions + + +# ======================== merge ======================== + +def _merge_nearby( + regions: List[Dict], + merge_dist: float, + img_area: int, + small_thresh: float, +) -> List[Dict]: + if len(regions) <= 1: + return regions + + large = [r for r in regions if r["area_ratio"] >= small_thresh] + small = [r for r in regions if r["area_ratio"] < small_thresh] + if len(small) <= 1: + return regions + + def _dist(a, b): + dx = max(0, max(a[0], b[0]) - min(a[2], b[2])) + dy = max(0, max(a[1], b[1]) - min(a[3], b[3])) + return max(dx, dy) + + n = len(small) + parent = list(range(n)) + + def _find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + for i in range(n): + for j in range(i + 1, n): + if _dist(small[i]["bbox"], small[j]["bbox"]) < merge_dist: + pi, pj = _find(i), _find(j) + if pi != pj: + parent[pi] = pj + + groups: Dict[int, List[int]] = {} + for i in range(n): + groups.setdefault(_find(i), []).append(i) + + merged: List[Dict] = [] + for indices in groups.values(): + if len(indices) == 1: + merged.append(small[indices[0]]) + else: + bxs = [small[i]["bbox"] for i in indices] + mb = [int(min(b[0] for b in bxs)), int(min(b[1] for b in bxs)), + int(max(b[2] for b in bxs)), int(max(b[3] for b in bxs))] + ma = int(_bbox_area(mb)) + merged.append({ + "bbox": mb, + "area": ma, + "area_ratio": round(ma / img_area, 4) if img_area > 0 else 0, + "missing_pixels": sum(small[i]["missing_pixels"] for i in indices), + "channel": "merged", + "reason": "merged_regions", + "description": f"merged {len(indices)} small regions", + }) + return large + merged + + +# ======================== debug output ======================== + +def _save_debug( + cv2_image, covered_mask, uncovered, bad_regions, + metrics, needs_refinement, score, output_dir, +): + os.makedirs(output_dir, exist_ok=True) + + # visualisation + vis = cv2_image.copy() + overlay = cv2_image.copy() + h, w = vis.shape[:2] + for i, r in enumerate(bad_regions): + x1, y1, x2, y2 = r["bbox"] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 255), -1) + cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255), 3) + label = f"#{i+1} {r['channel']} ({r['area_ratio']*100:.1f}%)" + (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2) + cv2.rectangle(vis, (x1, y1 - th - 8), (x1 + tw + 6, y1), (0, 0, 255), -1) + cv2.putText(vis, label, (x1 + 3, y1 - 4), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + result_img = cv2.addWeighted(overlay, 0.25, vis, 0.75, 0) + cv2.imwrite(str(Path(output_dir) / "metric_eval.png"), result_img) + + # JSON + def _native(o): + if isinstance(o, (np.integer,)): + return int(o) + if isinstance(o, (np.floating,)): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, list): + return [_native(x) for x in o] + if isinstance(o, dict): + return {k: _native(v) for k, v in o.items()} + return o + + with open(str(Path(output_dir) / "metric_eval.json"), "w", encoding="utf-8") as f: + json.dump( + { + "score": round(float(score), 2), + "needs_refinement": bool(needs_refinement), + "metrics": {k: _native(v) for k, v in metrics.items()}, + "bad_regions": [{k: _native(v) for k, v in r.items()} for r in bad_regions], + }, + f, + ensure_ascii=False, + indent=2, + ) + log.info(f"[MetricEvaluator] debug saved to {output_dir}") diff --git a/dataflow_agent/toolkits/image2drawio/refinement_processor.py b/dataflow_agent/toolkits/image2drawio/refinement_processor.py new file mode 100644 index 00000000..b8663725 --- /dev/null +++ b/dataflow_agent/toolkits/image2drawio/refinement_processor.py @@ -0,0 +1,333 @@ +""" +refinement_processor.py — Fallback rescue for uncovered regions. + +Takes bad regions from metric_evaluator, crops them from the original +image, saves as PNG, and returns new element dicts compatible with +the existing Paper2Any _render_xml_node format. + +Strategy (conservative): + - Crop the region from the original image + - Save as PNG file + - Return as kind="image" element with image_path + - Skip regions that are >95% white or too small + +Usage: + from dataflow_agent.toolkits.image2drawio.refinement_processor import refine + + new_elements = refine( + image_path="input.png", + bad_regions=[...], # from metric_evaluator + existing_elements=[...], + output_dir="outputs/xx", + ) + # new_elements are dicts with kind="image", bbox_px, image_path, etc. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Dict, List, Optional + +import cv2 +import numpy as np + +from dataflow_agent.logger import get_logger + +log = get_logger(__name__) + +# ======================== configuration ======================== + +DEFAULT_CONFIG: Dict[str, Any] = { + "min_region_area": 100, # skip regions smaller than this (px) + "min_region_ratio": 0.0005, # skip regions smaller than 0.05% of image + "expand_margin": 5, # expand crop by N pixels each side + "skip_mostly_white": True, # skip regions that are almost all white + "white_threshold": 0.95, # ratio of white pixels to skip + "white_pixel_value": 245, # grayscale > this = "white" + "skip_mostly_text": True, # skip regions that are mostly thin text strokes + "text_stroke_threshold": 0.55, # ratio: if >55% of dark pixels are thin strokes → text(原0.80太严格) +} + + +# ======================== public API ======================== + +def refine( + image_path: str, + bad_regions: List[Dict[str, Any]], + existing_elements: List[Dict[str, Any]], + output_dir: str, + config: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """ + Process bad regions and return new image elements. + + Args: + image_path: path to the original image + bad_regions: list of dicts from metric_evaluator (each has "bbox") + existing_elements: current element list (for ID numbering) + output_dir: directory to save cropped PNGs + + Returns: + List of new element dicts (kind="image") ready for _render_xml_node. + These should be appended to existing_elements by the caller. + """ + cfg = {**DEFAULT_CONFIG, **(config or {})} + + if not bad_regions: + log.info("[Refinement] No bad regions to process") + return [] + + cv2_image = cv2.imread(image_path) + if cv2_image is None: + log.error(f"[Refinement] Cannot read image: {image_path}") + return [] + + h, w = cv2_image.shape[:2] + img_area = h * w + + crop_dir = Path(output_dir) / "refinement_crops" + crop_dir.mkdir(parents=True, exist_ok=True) + + min_area = cfg["min_region_area"] + min_ratio = cfg["min_region_ratio"] + margin = cfg["expand_margin"] + + # Collect existing element bboxes for overlap checking + existing_bboxes: List[List[int]] = [] + for el in existing_elements: + bbox = el.get("bbox_px") + if bbox and len(bbox) == 4: + existing_bboxes.append([int(v) for v in bbox]) + + new_elements: List[Dict[str, Any]] = [] + skipped = 0 + + # Generate IDs that don't collide with existing elements + max_existing_id = 0 + for el in existing_elements: + eid = el.get("id", "") + if isinstance(eid, str): + # extract numeric part from "s42", "i13", etc. + digits = "".join(c for c in eid if c.isdigit()) + if digits: + max_existing_id = max(max_existing_id, int(digits)) + elif isinstance(eid, int): + max_existing_id = max(max_existing_id, eid) + + next_id = max_existing_id + 1 + + for i, region in enumerate(bad_regions): + bbox = region.get("bbox") + if not bbox or len(bbox) != 4: + skipped += 1 + continue + + x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + area = (x2 - x1) * (y2 - y1) + + # size filter + if area < min_area or (img_area > 0 and area < img_area * min_ratio): + log.debug(f"[Refinement] Region {i} too small ({area}px), skip") + skipped += 1 + continue + + # white filter + if cfg.get("skip_mostly_white", True) and _is_mostly_white( + cv2_image, [x1, y1, x2, y2], + cfg["white_pixel_value"], cfg["white_threshold"] + ): + log.debug(f"[Refinement] Region {i} mostly white, skip") + skipped += 1 + continue + + # overlap filter: skip if region is significantly covered by existing elements + if existing_bboxes: + rw, rh = x2 - x1, y2 - y1 + if rw > 0 and rh > 0: + overlap_mask = np.zeros((rh, rw), dtype=np.uint8) + for eb in existing_bboxes: + lx1 = max(0, eb[0] - x1) + ly1 = max(0, eb[1] - y1) + lx2 = min(rw, eb[2] - x1) + ly2 = min(rh, eb[3] - y1) + if lx2 > lx1 and ly2 > ly1: + overlap_mask[ly1:ly2, lx1:lx2] = 255 + overlap_ratio = float(np.count_nonzero(overlap_mask)) / (rw * rh) + if overlap_ratio > 0.40: + log.debug(f"[Refinement] Region {i} overlaps {overlap_ratio:.0%} with existing elements, skip") + skipped += 1 + continue + + # text-stroke filter: skip if region is mostly thin text strokes + # Exception: banner channel regions are OCR-missed titles that MUST be kept + is_banner = region.get("channel") == "banner" + if not is_banner and cfg.get("skip_mostly_text", True) and _is_mostly_text( + cv2_image, [x1, y1, x2, y2], + cfg.get("text_stroke_threshold", 0.70) + ): + log.debug(f"[Refinement] Region {i} mostly text strokes, skip") + skipped += 1 + continue + + # expand margin + cx1 = max(0, x1 - margin) + cy1 = max(0, y1 - margin) + cx2 = min(w, x2 + margin) + cy2 = min(h, y2 + margin) + + # crop and save + crop = cv2_image[cy1:cy2, cx1:cx2] + if crop.size == 0: + skipped += 1 + continue + + crop_path = str(crop_dir / f"refine_{i}.png") + cv2.imwrite(crop_path, crop) + + # build element dict compatible with _render_xml_node + new_elements.append({ + "id": f"r{next_id}", + "kind": "image", + "bbox_px": [cx1, cy1, cx2, cy2], + "image_path": crop_path, + "area": (cx2 - cx1) * (cy2 - cy1), + "group": "refinement", + "prompt": "fallback_crop", + "_source": "refinement", + "_channel": region.get("channel", "unknown"), + }) + next_id += 1 + + log.info( + f"[Refinement] Done: {len(new_elements)} new elements, " + f"{skipped} skipped" + ) + + # save visualisation + if new_elements: + _save_visualization(cv2_image, new_elements, existing_elements, output_dir) + + return new_elements + + +# ======================== helpers ======================== + +def _is_mostly_white( + cv2_image: np.ndarray, + bbox: List[int], + white_value: int = 245, + threshold: float = 0.95, +) -> bool: + """Check if a region is mostly white/empty.""" + x1, y1, x2, y2 = bbox + h, w = cv2_image.shape[:2] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + if x2 <= x1 or y2 <= y1: + return True + + roi = cv2_image[y1:y2, x1:x2] + gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY) + white_count = int(np.count_nonzero(gray > white_value)) + total = gray.size + return (white_count / total) > threshold if total > 0 else True + + +def _is_mostly_text( + cv2_image: np.ndarray, + bbox: List[int], + threshold: float = 0.70, +) -> bool: + """Check if a region is mostly text strokes (dark-on-light OR light-on-dark). + + Detects both: + - Dark text on light background (standard documents) + - Light/white text on dark background (dark-theme panels, banners) + + Heuristic: text = strokes of foreground color on uniform background. + Uses adaptive kernel sizing based on region height to handle both + small body text (thin strokes) and large bold titles (thick strokes). + After morphological opening, text strokes disappear but filled shapes remain. + """ + x1, y1, x2, y2 = bbox + h, w = cv2_image.shape[:2] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + if x2 <= x1 or y2 <= y1: + return False + + roi = cv2_image[y1:y2, x1:x2] + gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY) + total = gray.size + if total < 100: + return False + + roi_h = y2 - y1 + + # Determine polarity: is it dark-on-light or light-on-dark? + light_ratio = float(np.count_nonzero(gray > 200)) / total + dark_ratio = float(np.count_nonzero(gray < 60)) / total + + if light_ratio >= 0.55: + # Case 1: light background, dark text strokes + fg_mask = (gray < 180).astype(np.uint8) * 255 + elif dark_ratio >= 0.55: + # Case 2: dark background, light text strokes + fg_mask = (gray > 80).astype(np.uint8) * 255 + else: + # Mixed / mid-tone → probably not a text region + return False + + fg_count = int(np.count_nonzero(fg_mask)) + if fg_count < 10: + return False # almost no foreground + + # Adaptive kernel: larger regions may have thicker text (bold titles, headers) + # Use ~15% of region height as kernel size, clamped to [3, 11] + ks = max(3, min(11, int(roi_h * 0.15))) + if ks % 2 == 0: + ks += 1 # must be odd + + kernel = np.ones((ks, ks), np.uint8) + opened = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, kernel) + thick_count = int(np.count_nonzero(opened)) + + # thin_ratio = fraction of foreground pixels that are thin (removed by opening) + thin_ratio = 1.0 - (thick_count / fg_count) if fg_count > 0 else 0.0 + + return thin_ratio >= threshold + + +def _save_visualization( + cv2_image: np.ndarray, + new_elements: List[Dict], + existing_elements: List[Dict], + output_dir: str, +): + """Save a debug image showing original + new elements.""" + vis = cv2_image.copy() + h, w = vis.shape[:2] + + # existing elements in blue + for el in existing_elements: + bbox = el.get("bbox_px") + if not bbox or len(bbox) != 4: + continue + x1, y1, x2, y2 = [int(v) for v in bbox] + cv2.rectangle(vis, (x1, y1), (x2, y2), (200, 100, 0), 1) + + # new (refinement) elements in red + for i, el in enumerate(new_elements): + bbox = el.get("bbox_px") + if not bbox or len(bbox) != 4: + continue + x1, y1, x2, y2 = [int(v) for v in bbox] + cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255), 2) + label = f"NEW-{i}" + cv2.putText(vis, label, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + out_path = str(Path(output_dir) / "refinement_result.png") + cv2.imwrite(out_path, vis) + log.info(f"[Refinement] Visualisation saved: {out_path}") diff --git a/dataflow_agent/toolkits/image2drawio/utils.py b/dataflow_agent/toolkits/image2drawio/utils.py index 22b88445..16df6881 100644 --- a/dataflow_agent/toolkits/image2drawio/utils.py +++ b/dataflow_agent/toolkits/image2drawio/utils.py @@ -107,19 +107,6 @@ def sample_fill_stroke(image_bgr: np.ndarray, mask: np.ndarray) -> Tuple[str, st if stroke_pixels.size == 0: stroke_pixels = image_bgr[mask] - # Select darkest quartile by luminance - if stroke_pixels.size > 0: - rgb = stroke_pixels[:, ::-1].astype(np.float32) - lum = 0.2126 * rgb[:, 0] + 0.7152 * rgb[:, 1] + 0.0722 * rgb[:, 2] - if lum.size > 10: - thresh = np.percentile(lum, 25) - sel = stroke_pixels[lum <= thresh] - else: - sel = stroke_pixels - stroke = tuple(np.mean(sel, axis=0).tolist()) - else: - stroke = (0, 0, 0) - # Fill: erode mask to remove border erode_k = max(1, int(min(h, w) * 0.004)) erode_k = min(erode_k, 7) @@ -134,14 +121,42 @@ def sample_fill_stroke(image_bgr: np.ndarray, mask: np.ndarray) -> Tuple[str, st else: fill = (255, 255, 255) + # Stroke: detect real border vs anti-aliased edge + if stroke_pixels.size > 0: + stroke_median = tuple(np.median(stroke_pixels, axis=0).tolist()) + # Compute luminance for stroke candidate and fill + def _lum_bgr(bgr): + return 0.0722 * bgr[0] + 0.7152 * bgr[1] + 0.2126 * bgr[2] + stroke_lum = _lum_bgr(stroke_median) + fill_lum = _lum_bgr(fill) + + # Check if edge pixels contain a distinctly dark border + rgb = stroke_pixels[:, ::-1].astype(np.float32) + lum = 0.2126 * rgb[:, 0] + 0.7152 * rgb[:, 1] + 0.0722 * rgb[:, 2] + dark_ratio = float(np.count_nonzero(lum < 50)) / max(1, len(lum)) + + if dark_ratio > 0.3: + # A significant portion of edge pixels are truly dark → real border + thresh = np.percentile(lum, 25) + sel = stroke_pixels[lum <= thresh] + stroke = tuple(np.mean(sel, axis=0).tolist()) + elif stroke_lum < 30 and fill_lum > 80: + # Stroke looks black but fill is colored → no real border, + # use slightly darkened fill + stroke = tuple(min(255, max(0, c * 0.7)) for c in fill) + else: + stroke = stroke_median + else: + stroke = (0, 0, 0) + return _to_hex(fill), _to_hex(stroke) def extract_text_color(image_bgr: np.ndarray, bbox_px: List[int]) -> str: x1, y1, x2, y2 = bbox_px - x1 = max(0, min(image_bgr.shape[1] - 1, int(x1))) + x1 = max(0, min(image_bgr.shape[1], int(x1))) x2 = max(0, min(image_bgr.shape[1], int(x2))) - y1 = max(0, min(image_bgr.shape[0] - 1, int(y1))) + y1 = max(0, min(image_bgr.shape[0], int(y1))) y2 = max(0, min(image_bgr.shape[0], int(y2))) if x2 <= x1 or y2 <= y1: return "#000000" diff --git a/dataflow_agent/workflow/wf_paper2drawio_sam3.py b/dataflow_agent/workflow/wf_paper2drawio_sam3.py index 29674b82..d2d4efba 100644 --- a/dataflow_agent/workflow/wf_paper2drawio_sam3.py +++ b/dataflow_agent/workflow/wf_paper2drawio_sam3.py @@ -55,6 +55,8 @@ save_masked_rgba, bbox_iou_px, ) +from dataflow_agent.toolkits.image2drawio.metric_evaluator import evaluate as metric_evaluate +from dataflow_agent.toolkits.image2drawio.refinement_processor import refine as refinement_refine from dataflow_agent.utils_common import robust_parse_json from dataflow_agent.workflow.sam3_segment_hint import ( dedupe_prompts, @@ -65,6 +67,7 @@ log = get_logger(__name__) # ==================== SAM3 PROMPTS (ported from Edit-Banana/prompts) ==================== +# 基本图形:覆盖主流流程图/架构图的所有几何元素 SHAPE_PROMPT = [ "rectangle", "rounded rectangle", @@ -73,6 +76,9 @@ "circle", "triangle", "hexagon", + "parallelogram", + "cylinder", + "cloud", ] ARROW_PROMPT = [ @@ -81,6 +87,7 @@ "connector", ] +# 图片类:覆盖各类非矢量化内容 IMAGE_PROMPT = [ "icon", "symbol", @@ -105,7 +112,7 @@ "blob", ] -# 泛化补召回提示词:避免与具体业务词绑定(如 planner/critic/robot) +# 泛化补召回提示词:低阈值兜底,避免与具体业务词绑定 IMAGE_PROMPT_RECALL = [ "illustration", "object", @@ -128,6 +135,8 @@ "container", "filled region", "background", + "section panel", + "title bar", ] SAM3_GROUPS = { @@ -139,22 +148,22 @@ # Thresholds aligned with Edit-Banana config defaults SAM3_GROUP_CONFIG = { - "shape": {"score_threshold": 0.5, "min_area": 200, "priority": 3}, - "arrow": {"score_threshold": 0.45, "min_area": 50, "priority": 4}, - "image": {"score_threshold": 0.5, "min_area": 100, "priority": 2}, - "background": {"score_threshold": 0.25, "min_area": 500, "priority": 1}, + "shape": {"score_threshold": 0.45, "min_area": 150, "priority": 3}, + "arrow": {"score_threshold": 0.40, "min_area": 30, "priority": 4}, + "image": {"score_threshold": 0.45, "min_area": 80, "priority": 2}, + "background": {"score_threshold": 0.20, "min_area": 400, "priority": 1}, } # 第2轮 image 召回配置(低阈值 + 动态最小面积) -SAM3_IMAGE_RECALL_SCORE_THRESHOLD = 0.38 -SAM3_IMAGE_RECALL_MIN_AREA_BASE = 40 -SAM3_IMAGE_RECALL_MIN_AREA_RATIO = 0.00003 -SAM3_IMAGE_RECALL_TRIGGER_MAX_IMAGES = 2 +SAM3_IMAGE_RECALL_SCORE_THRESHOLD = 0.35 +SAM3_IMAGE_RECALL_MIN_AREA_BASE = 30 +SAM3_IMAGE_RECALL_MIN_AREA_RATIO = 0.00002 +SAM3_IMAGE_RECALL_TRIGGER_MAX_IMAGES = 4 # Dedup params aligned with Edit-Banana defaults -SAM3_DEDUP_IOU = 0.7 -SAM3_ARROW_DEDUP_IOU = 0.85 -SAM3_SHAPE_IMAGE_IOU = 0.6 +SAM3_DEDUP_IOU = 0.65 +SAM3_ARROW_DEDUP_IOU = 0.80 +SAM3_SHAPE_IMAGE_IOU = 0.55 MAX_DRAWIO_ELEMENTS = 800 MIN_IMAGE_AREA_RATIO = 0.00001 @@ -950,6 +959,12 @@ def _shape_style( base = "shape=triangle;" elif st in {"hexagon"}: base = "shape=hexagon;perimeter=hexagonPerimeter2;fixedSize=1;" + elif st in {"parallelogram"}: + base = "shape=parallelogram;perimeter=parallelogramPerimeter;fixedSize=1;" + elif st in {"cylinder"}: + base = "shape=cylinder3;boundedLbl=1;backgroundOutline=1;size=15;" + elif st in {"cloud"}: + base = "ellipse;shape=cloud;" elif st in {"container", "rounded rectangle", "rounded_rect", "rounded rectangle"}: base = "rounded=1;" else: @@ -1123,7 +1138,7 @@ def _shape_type_from_prompt(prompt: str) -> str: p = normalize_prompt(prompt) if p in {"rounded rectangle", "rounded_rectangle"}: return "rounded rectangle" - if p in {"rectangle", "square", "panel", "background", "filled region", "title bar", "section_panel"}: + if p in {"rectangle", "square", "panel", "background", "filled region", "title bar", "section_panel", "section panel"}: return "rectangle" if p in {"container"}: return "rounded rectangle" @@ -1135,9 +1150,182 @@ def _shape_type_from_prompt(prompt: str) -> str: return "triangle" if p in {"hexagon"}: return "hexagon" + if p in {"parallelogram"}: + return "parallelogram" + if p in {"cylinder"}: + return "cylinder" + if p in {"cloud"}: + return "cloud" return p or "rectangle" +# ==================== CV BACKGROUND PANEL DETECTION ==================== +# 当 SAM3 没有检测到任何 background 组时,使用 CV 方法补充检测大面积 +# 色块面板(典型的海报/PPT 中的深色或浅色背景面板)。 + +# 面板检测的最小/最大面积比例 +_BG_MIN_AREA_RATIO = 0.02 # ≥ 2% 画面面积 +_BG_MAX_AREA_RATIO = 0.85 # ≤ 85% +_BG_MAX_ASPECT = 12.0 # 最大宽高比 +_BG_MAX_PANELS = 12 # 最多检测 12 个面板 +_BG_IOU_DEDUP = 0.3 # 面板间 IoU 去重阈值 +_BG_EXISTING_IOU = 0.6 # 与已有元素 IoU 去重阈值 +_BG_MIN_CONTAINED = 2 # 面板内部至少包含 N 个已有元素才算"容器" +_BG_SMALL_PANEL_RATIO = 0.08 # 面积 < 8% 的面板必须满足容器条件 + + +def _detect_background_panels_cv( + image_bgr: np.ndarray, + existing_elements: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + 用 CV 方法检测大面积矩形面板作为背景元素。 + + 策略: Canny 边缘 → 轮廓 → 筛选大矩形 → NMS 去重 → 颜色采样 + 适用于海报/PPT 中有明确矩形分区的图片。 + """ + h, w = image_bgr.shape[:2] + img_area = h * w + panels: List[Dict[str, Any]] = [] + + # 收集已有元素的 bbox + existing_bboxes = [] + for el in existing_elements: + bbox = el.get("bbox_px") + if bbox and len(bbox) == 4: + existing_bboxes.append(bbox) + + gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY) + + # 边缘检测 + 膨胀连接 + edges = cv2.Canny(gray, 20, 60) + edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=2) + + contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + # 找大矩形轮廓 + candidates: List[Tuple[List[int], float]] = [] # (bbox, rect_fill) + for cnt in contours: + area = cv2.contourArea(cnt) + if area < img_area * _BG_MIN_AREA_RATIO: + continue + peri = cv2.arcLength(cnt, True) + approx = cv2.approxPolyDP(cnt, 0.02 * peri, True) + if len(approx) < 4 or len(approx) > 8: + continue + x, y, rw, rh = cv2.boundingRect(cnt) + ba = rw * rh + rect_fill = area / ba if ba > 0 else 0 + if rect_fill < 0.5: # 至少 50% 填充 → 近似矩形 + continue + ratio = ba / img_area + if ratio > _BG_MAX_AREA_RATIO: + continue + aspect = max(rw, rh) / max(1, min(rw, rh)) + if aspect > _BG_MAX_ASPECT: + continue + candidates.append(([int(x), int(y), int(x + rw), int(y + rh)], rect_fill)) + + if not candidates: + return [] + + # NMS: 按面积从大到小,去掉高 IoU 重叠的 (内外轮廓去重) + candidates.sort(key=lambda c: _bbox_area(c[0]), reverse=True) + kept: List[List[int]] = [] + for bbox, _ in candidates: + skip = False + for kb in kept: + if bbox_iou_px(bbox, kb) > _BG_IOU_DEDUP: + skip = True + break + if not skip: + kept.append(bbox) + if len(kept) >= _BG_MAX_PANELS: + break + + # 过滤掉与已有元素高度重叠的 + final: List[List[int]] = [] + for bbox in kept: + skip = False + for eb in existing_bboxes: + if bbox_iou_px(bbox, eb) > _BG_EXISTING_IOU: + skip = True + break + if not skip: + final.append(bbox) + + # 构建 shape 元素 — 从边框附近采样颜色 + # 对于较小的面板 (< 8%), 要求内部包含至少 N 个已有元素才算"容器" + # 大面板 (≥ 8%) 通常就是布局面板,可以直接保留 + for idx, bbox in enumerate(final): + x1, y1, x2, y2 = bbox + panel_ratio = _bbox_area(bbox) / img_area + + # 小面板容器验证: 计算内部包含多少个已有元素 + if panel_ratio < _BG_SMALL_PANEL_RATIO: + contained = 0 + for eb in existing_bboxes: + # 元素中心在面板内部 → 算被包含 + ecx = (eb[0] + eb[2]) / 2 + ecy = (eb[1] + eb[3]) / 2 + if x1 <= ecx <= x2 and y1 <= ecy <= y2: + contained += 1 + if contained < _BG_MIN_CONTAINED: + log.debug( + f"[paper2drawio_sam3] CV panel [{x1},{y1},{x2},{y2}] " + f"ratio={panel_ratio*100:.1f}% skipped: only {contained} " + f"contained elements (need ≥{_BG_MIN_CONTAINED})" + ) + continue + x1, y1, x2, y2 = bbox + roi = image_bgr[y1:y2, x1:x2] + if roi.size == 0: + continue + rh_roi, rw_roi = roi.shape[:2] + border_w = max(3, min(rw_roi, rh_roi) // 20) + + # 边框区域像素 + border_pixels = np.concatenate([ + roi[:border_w, :].reshape(-1, 3), # top + roi[-border_w:, :].reshape(-1, 3), # bottom + roi[:, :border_w].reshape(-1, 3), # left + roi[:, -border_w:].reshape(-1, 3), # right + ], axis=0) + + fill_bgr = np.median(border_pixels, axis=0).astype(int) + fill_hex = "#{:02x}{:02x}{:02x}".format( + int(fill_bgr[2]), int(fill_bgr[1]), int(fill_bgr[0]) + ) + # 边框色: 稍微深一点 + darker = np.clip(fill_bgr * 0.7, 0, 255).astype(int) + stroke_hex = "#{:02x}{:02x}{:02x}".format( + int(darker[2]), int(darker[1]), int(darker[0]) + ) + + panels.append({ + "id": f"bg{idx}", + "kind": "shape", + "shape_type": "rectangle", + "bbox_px": bbox, + "fill": fill_hex, + "stroke": stroke_hex, + "text": "", + "text_color": None, + "font_size": None, + "area": _bbox_area(bbox), + "group": "background", + "prompt": "cv_panel", + }) + + if panels: + log.info( + f"[paper2drawio_sam3] CV background panels: " + f"{len(panels)} detected, areas={[round(p['area']/img_area*100,1) for p in panels]}%" + ) + + return panels + + def _sam3_predict_groups( client: Any, image_path: str, @@ -1172,6 +1360,14 @@ def _sam3_predict_groups( image_path=image_path, runs=base_runs, ) + + # Diagnostic: log per-group counts before dedup + _pre_dedup: Dict[str, int] = {} + for item in all_results: + g = str(item.get("group", "unknown")) + _pre_dedup[g] = _pre_dedup.get(g, 0) + 1 + log.info(f"[paper2drawio_sam3] SAM3 raw results (before dedup): {json.dumps(_pre_dedup)}") + all_results = dedup_sam3_results_across_groups( all_results, group_config=SAM3_GROUP_CONFIG, @@ -1434,6 +1630,15 @@ def _refine_low_coverage_image_mask(mask: np.ndarray, bbox: List[int]) -> Tuple[ shapes.sort(key=lambda s: s.get("area", 0), reverse=True) images.sort(key=lambda s: s.get("area", 0), reverse=True) + + # ---- CV fallback: detect background panels not found by SAM3 ---- + bg_count = sum(1 for s in shapes if s.get("group") == "background") + if bg_count == 0: + cv_bg = _detect_background_panels_cv(image_bgr, shapes + images) + if cv_bg: + log.info(f"[paper2drawio_sam3] CV background detection: added {len(cv_bg)} panels") + shapes = cv_bg + shapes # backgrounds go first (rendered at back) + total = len(shapes) + len(images) if total > MAX_DRAWIO_ELEMENTS: keep = max(0, MAX_DRAWIO_ELEMENTS - len(shapes)) @@ -1499,15 +1704,16 @@ async def _text_node(state: Paper2DrawioState) -> Paper2DrawioState: temp_state.request.chat_api_key = api_key try: - vlm_timeout = int(os.getenv("VLM_OCR_TIMEOUT", "120")) + vlm_timeout = int(os.getenv("VLM_OCR_TIMEOUT", "180")) except ValueError: - vlm_timeout = 120 + vlm_timeout = 180 + agent = create_vlm_agent( name="ImageTextBBoxAgent", model_name="qwen-vl-ocr-2025-11-20", chat_api_url=chat_api_url, - max_tokens=4096, vlm_mode="ocr", + max_tokens=8192, additional_params={"input_image": img_path, "timeout": vlm_timeout}, ) new_state = await agent.execute(temp_state) @@ -1640,6 +1846,65 @@ async def _build_elements_node(state: Paper2DrawioState) -> Paper2DrawioState: state.temp_data["fallback_hide_text_blocks"] = fallback_hide_text_blocks return state + async def _evaluate_node(state: Paper2DrawioState) -> Paper2DrawioState: + """Evaluate coverage quality and detect uncovered bad regions.""" + img_path = state.temp_data.get("input_image_path") + if not img_path or not os.path.exists(img_path): + state.temp_data["bad_regions"] = [] + return state + + elements = state.temp_data.get("drawio_elements", []) or [] + text_blocks = state.temp_data.get("text_blocks", []) or [] + base_dir = str(Path(_ensure_result_path(state))) + + eval_result = metric_evaluate( + image_path=img_path, + elements=elements, + text_blocks=text_blocks, + output_dir=base_dir, + ) + + state.temp_data["bad_regions"] = eval_result.get("bad_regions", []) + state.temp_data["eval_score"] = eval_result.get("score", 100) + state.temp_data["needs_refinement"] = eval_result.get("needs_refinement", False) + + log.info( + f"[paper2drawio_sam3] Evaluation: score={eval_result.get('score', 0):.1f}, " + f"bad_regions={len(eval_result.get('bad_regions', []))}, " + f"needs_refinement={eval_result.get('needs_refinement', False)}" + ) + return state + + async def _refine_node(state: Paper2DrawioState) -> Paper2DrawioState: + """Fallback rescue: crop uncovered bad regions as image elements.""" + if not state.temp_data.get("needs_refinement", False): + return state + + img_path = state.temp_data.get("input_image_path") + if not img_path or not os.path.exists(img_path): + return state + + bad_regions = state.temp_data.get("bad_regions", []) + if not bad_regions: + return state + + elements = state.temp_data.get("drawio_elements", []) or [] + base_dir = str(Path(_ensure_result_path(state))) + + new_elements = refinement_refine( + image_path=img_path, + bad_regions=bad_regions, + existing_elements=elements, + output_dir=base_dir, + ) + + if new_elements: + elements.extend(new_elements) + state.temp_data["drawio_elements"] = elements + log.info(f"[paper2drawio_sam3] Refinement: added {len(new_elements)} fallback elements") + + return state + async def _render_xml_node(state: Paper2DrawioState) -> Paper2DrawioState: img_path = state.temp_data.get("input_image_path") if not img_path or not os.path.exists(img_path): @@ -1714,6 +1979,8 @@ async def _render_xml_node(state: Paper2DrawioState) -> Paper2DrawioState: "segment_hint": _segment_hint_node, "sam3": _sam3_node, "build_elements": _build_elements_node, + "evaluate": _evaluate_node, + "refine": _refine_node, "render_xml": _render_xml_node, "_end_": lambda s: s, } @@ -1723,7 +1990,9 @@ async def _render_xml_node(state: Paper2DrawioState) -> Paper2DrawioState: ("text_ocr", "segment_hint"), ("segment_hint", "sam3"), ("sam3", "build_elements"), - ("build_elements", "render_xml"), + ("build_elements", "evaluate"), + ("evaluate", "refine"), + ("refine", "render_xml"), ("render_xml", "_end_"), ]