diff --git a/README.md b/README.md index 72a399cf..44e32dbb 100644 --- a/README.md +++ b/README.md @@ -294,13 +294,14 @@ Dingo provides **70+ evaluation metrics** across multiple dimensions, combining | **RAG Evaluation** | Faithfulness, Context Precision, Answer Relevancy | RAG system assessment | | **Hallucination Detection** | HHEM-2.1-Open, Factuality Check | Production AI reliability | | **Classification** | Topic categorization, Content labeling | Data organization | -| **Multimodal** | Image-text relevance, VLM quality | Vision-language data | +| **Multimodal** | Image-text relevance, VLM quality, OCR visual evaluation | Vision-language data | | **Security** | PII detection, Perspective API toxicity | Privacy and safety | 📊 **[View Complete Metrics Documentation →](docs/metrics.md)** 📖 **[RAG Evaluation Guide →](docs/rag_evaluation_metrics.md)** | **[中文版](docs/rag_evaluation_metrics_zh.md)** 🔍 **[Hallucination Detection Guide →](docs/hallucination_detection_guide.md)** | **[中文版](docs/hallucination_guide.md)** -✅ **[Factuality Assessment Guide →](docs/factuality_assessment_guide.md)** | **[中文版](docs/factcheck_guide.md)** +✅ **[Factuality Assessment Guide →](docs/factuality_assessment_guide.md)** | **[中文版](docs/factcheck_guide.md)** +👁️ **[VLM Render Judge Guide →](docs/en/vlm_render_judge_guide.md)** | **[中文版](docs/vlm_render_judge_guide.md)** Most metrics are backed by academic research to ensure scientific rigor. diff --git a/dingo/model/llm/agent/tools/__init__.py b/dingo/model/llm/agent/tools/__init__.py index dcdbe098..42606ac2 100644 --- a/dingo/model/llm/agent/tools/__init__.py +++ b/dingo/model/llm/agent/tools/__init__.py @@ -9,6 +9,7 @@ from dingo.model.llm.agent.tools.tool_registry import ToolRegistry, tool_register # Convenience function for getting tools +# Note: Tools are lazily loaded. 
Import from specific module before using ToolRegistry.get()
 get_tool = ToolRegistry.get
 
 __all__ = [
diff --git a/dingo/model/llm/agent/tools/mineru_ocr_tool.py b/dingo/model/llm/agent/tools/mineru_ocr_tool.py
new file mode 100644
index 00000000..d6f708cc
--- /dev/null
+++ b/dingo/model/llm/agent/tools/mineru_ocr_tool.py
@@ -0,0 +1,239 @@
+"""
+MinerU OCR Tool for Agent-Based Evaluation
+
+This tool calls MinerU API (https://mineru.net/apiManage/docs) for initial OCR recognition.
+Used as the first step in the iterative judge-refine workflow.
+"""
+
+import base64
+import os
+import time
+from typing import Any, Dict, Optional
+
+import requests
+from pydantic import BaseModel, Field
+
+from dingo.model.llm.agent.tools.base_tool import BaseTool
+from dingo.model.llm.agent.tools.tool_registry import tool_register
+from dingo.utils import log
+
+
+class MinerUOCRToolConfig(BaseModel):
+    """Configuration for MinerU OCR Tool"""
+    api_key: Optional[str] = Field(
+        default=None,
+        description="MinerU API key (from https://mineru.net/apiManage/docs)"
+    )
+    api_url: str = Field(
+        default="https://mineru.net/api/v4/extract/task",
+        description="MinerU API endpoint URL"
+    )
+    timeout: int = Field(
+        default=120,
+        ge=30,
+        le=600,
+        description="Timeout for API request in seconds"
+    )
+    poll_interval: int = Field(
+        default=3,
+        ge=1,
+        le=30,
+        description="Interval between status polling in seconds"
+    )
+
+
+@tool_register
+class MinerUOCRTool(BaseTool):
+    """
+    MinerU OCR Tool - Call MinerU API for document parsing.
+
+    MinerU (https://mineru.net) provides high-quality document parsing with support for:
+    - Text extraction
+    - Formula recognition (LaTeX)
+    - Table extraction
+    - Layout detection
+
+    API Documentation: https://mineru.net/apiManage/docs
+
+    Configuration:
+        api_key: Your MinerU API key
+        api_url: API endpoint URL (used for both task submission and status polling)
+        timeout: Overall polling deadline in seconds
+        poll_interval: Status polling interval
+
+    Returns:
+        Dict with:
+        - success: bool
+        - content: Extracted text/markdown content
+        - content_type: Type of content extracted
+        - error: Error message if failed
+    """
+
+    name = "mineru_ocr_tool"
+    description = "Call MinerU API for OCR/document parsing"
+    config: MinerUOCRToolConfig = MinerUOCRToolConfig()
+
+    @classmethod
+    def execute(
+        cls,
+        image_path: Optional[str] = None,
+        image_base64: Optional[str] = None,
+        content_type: str = "text",
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Execute MinerU OCR recognition.
+
+        Args:
+            image_path: Path to image file (used only when the file exists)
+            image_base64: Base64 encoded image (alternative to image_path)
+            content_type: Type of content to extract - "text", "formula", "table"
+
+        Returns:
+            Dict with OCR results or error
+        """
+        if not cls.config.api_key:
+            return {
+                'success': False,
+                'error': 'MinerU API key not configured. Get your key from https://mineru.net/apiManage/docs'
+            }
+
+        # Get image data
+        if image_path and os.path.exists(image_path):
+            with open(image_path, 'rb') as f:
+                image_data = base64.b64encode(f.read()).decode('utf-8')
+        elif image_base64:
+            image_data = image_base64
+        else:
+            return {
+                'success': False,
+                'error': 'No image provided. Provide either image_path or image_base64.'
+            }
+
+        try:
+            # Submit extraction task
+            result = cls._submit_and_wait(image_data, content_type)
+            return result
+
+        except requests.Timeout:
+            return {
+                'success': False,
+                'error': f'MinerU API request timed out after {cls.config.timeout}s'
+            }
+        except Exception as e:
+            log.error(f"MinerU OCR failed: {e}")
+            return {
+                'success': False,
+                'error': str(e)
+            }
+
+    @classmethod
+    def _submit_and_wait(cls, image_data: str, content_type: str) -> Dict[str, Any]:
+        """
+        Submit task to MinerU API and wait for result.
+
+        MinerU API is async - we submit a task and poll for completion.
+        """
+        headers = {
+            "Authorization": f"Bearer {cls.config.api_key}",
+            "Content-Type": "application/json"
+        }
+
+        # Submit task
+        submit_payload = {
+            "file": f"data:image/png;base64,{image_data}",
+            "is_ocr": True,
+            "enable_formula": content_type in ["formula", "equation"],
+            "enable_table": content_type == "table",
+        }
+
+        log.info(f"Submitting MinerU OCR task for {content_type}...")
+
+        response = requests.post(
+            cls.config.api_url,
+            headers=headers,
+            json=submit_payload,
+            timeout=30
+        )
+        response.raise_for_status()
+
+        result = response.json()
+
+        if result.get("code") != 0:
+            return {
+                'success': False,
+                'error': f"MinerU API error: {result.get('msg', 'Unknown error')}"
+            }
+
+        task_id = (result.get("data") or {}).get("task_id")  # "data" may be null
+        if not task_id:
+            return {
+                'success': False,
+                'error': "No task_id returned from MinerU API"
+            }
+
+        # Poll for result
+        log.info(f"MinerU task submitted, task_id: {task_id}")
+        return cls._poll_result(task_id, headers)
+
+    @classmethod
+    def _poll_result(cls, task_id: str, headers: Dict) -> Dict[str, Any]:
+        """Poll <config.api_url>/<task_id> until success, failure, or config.timeout elapses."""
+        status_url = f"{cls.config.api_url.rstrip('/')}/{task_id}"
+
+        start_time = time.time()
+
+        while time.time() - start_time < cls.config.timeout:
+            response = requests.get(status_url, headers=headers, timeout=30)
+            response.raise_for_status()
+
+            result = response.json()
+            status = (result.get("data") or {}).get("status")
+
+            if status == "success":
+                # Extract content from result
+                content = cls._extract_content(result.get("data") or {})
+                return {
+                    'success': True,
+                    'content': content,
+                    'task_id': task_id,
+                    'raw_result': result
+                }
+
+            elif status == "failed":
+                return {
+                    'success': False,
+                    'error': f"MinerU task failed: {(result.get('data') or {}).get('msg', 'Unknown error')}"
+                }
+
+            # Still processing, wait and retry
+            log.debug(f"MinerU task {task_id} status: {status}, waiting...")
+            time.sleep(cls.config.poll_interval)
+
+        return {
+            'success': False,
+            'error': f"MinerU task timed out after {cls.config.timeout}s"
+        }
+
+    @classmethod
+    def _extract_content(cls, data: Dict) -> str:
+        """Extract text content from MinerU result (first matching key wins)."""
+        # Try different result formats
+        if "markdown" in data:
+            return data["markdown"]
+        if "text" in data:
+            return data["text"]
+        if "content" in data:
+            return data["content"]
+        if "pages" in data:
+            # Multi-page result
+            pages = data["pages"]
+            contents = []
+            for page in pages:
+                if isinstance(page, dict):
+                    contents.append(page.get("markdown", page.get("text", "")))
+                elif isinstance(page, str):
+                    contents.append(page)
+            return "\n\n".join(contents)
+
+        return str(data)
diff --git a/dingo/model/llm/agent/tools/render_tool.py b/dingo/model/llm/agent/tools/render_tool.py
new file mode 100644
index 00000000..a67fb945
--- /dev/null
+++ b/dingo/model/llm/agent/tools/render_tool.py
@@ -0,0 +1,476 @@
+"""
+Render Tool for OCR Self-Verification
+
+This tool renders text/equation/table content as images for VLM comparison.
+Used by agent-based OCR quality evaluation to implement the "render-judge-refine" loop.
+""" + +import base64 +import io +import os +import re +import shutil +import subprocess +import tempfile +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from dingo.model.llm.agent.tools.base_tool import BaseTool +from dingo.model.llm.agent.tools.tool_registry import tool_register +from dingo.utils import log + +try: + import numpy as np + from PIL import Image, ImageDraw, ImageFont + HAS_PIL = True +except ImportError: + HAS_PIL = False + + +class RenderToolConfig(BaseModel): + """Configuration for RenderTool""" + font_path: Optional[str] = Field( + default=None, + description="Path to font file for text rendering (e.g., simsun.ttc)" + ) + cjk_font: Optional[str] = Field( + default=None, + description="CJK font name for LaTeX rendering (e.g., 'SimSun' on Windows, 'PingFang SC' on macOS, 'Noto Sans CJK SC' on Linux)" + ) + density: int = Field( + default=150, + ge=72, + le=300, + description="Rendering density (DPI) for LaTeX" + ) + timeout: int = Field( + default=60, + ge=10, + le=300, + description="Timeout for LaTeX rendering in seconds" + ) + pad: int = Field( + default=20, + ge=0, + le=100, + description="Padding around rendered content" + ) + + +@tool_register +class RenderTool(BaseTool): + """ + Render text/equation/table content as images. + + This tool converts OCR output (text, LaTeX equations, HTML tables) into + rendered images that can be compared with original document images by VLM. 
+ + Features: + - Text rendering with CJK support + - LaTeX equation rendering via xelatex + - HTML table rendering + + Configuration: + font_path: Path to font file (default: system font) + density: Rendering DPI (default: 150) + timeout: Rendering timeout in seconds (default: 60) + pad: Padding around content (default: 20) + + Returns: + Dict with: + - success: bool + - image_base64: Base64 encoded PNG image + - image_path: Optional path to saved image file + - error: Error message if failed + """ + + name = "render_tool" + description = "Render text, equations, or tables as images for VLM comparison" + config: RenderToolConfig = RenderToolConfig() + + # LaTeX template for rendering (enhanced with MinerU_Metis symbol support) + LATEX_TEMPLATE = r""" +\documentclass[12pt]{article} +\usepackage{geometry} +\usepackage[CJKmath]{xeCJK} +[CJKFONT] +\geometry{paperwidth=[PAPERWIDTH], paperheight=5000cm, margin=1cm} +\pagestyle{empty} +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{wasysym} +\usepackage{unicode-math} +\usepackage{upgreek} +\usepackage{xcolor} +\usepackage{textcomp} +\usepackage{fontspec} +\setmathfont{Latin Modern Math} +\setcounter{MaxMatrixCols}{1000} +\xeCJKDeclareCharClass{CJK}{"0080->"FFFF} +\xeCJKDeclareCharClass{CJK}{"10000->"1FFFF} +\xeCJKDeclareCharClass{CJK}{"20000->"2FFFF} +\xeCJKDeclareCharClass{CJK}{"30000->"3FFFF} +\setlength{\parindent}{0pt} +\setlength{\parskip}{0pt} +\begin{document} +\raggedright +\makeatletter +\makeatother +[CONTENT] +\end{document} +""" + + @classmethod + def execute( + cls, + content: str, + content_type: str = "text", + output_path: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Execute rendering and return image. 
+ + Args: + content: The content to render (text, LaTeX, or HTML) + content_type: Type of content - "text", "equation", or "table" + output_path: Optional path to save rendered image + + Returns: + Dict with: + - success: bool + - image_base64: Base64 encoded image + - image_path: Path to saved image (if output_path provided) + - content_type: Type of rendered content + - error: Error message if failed + """ + if not HAS_PIL: + return { + 'success': False, + 'error': 'PIL (Pillow) is required for rendering. Install with: pip install Pillow' + } + + if not content or not content.strip(): + return { + 'success': False, + 'error': 'Content is empty or None' + } + + try: + # Route to appropriate renderer + if content_type == "equation": + image = cls._render_latex(content) + elif content_type == "table": + image = cls._render_table(content) + else: # default to text + image = cls._render_text(content) + + if image is None: + return { + 'success': False, + 'error': f'Failed to render {content_type} content' + } + + # Convert to base64 + buffer = io.BytesIO() + image.save(buffer, format='PNG') + image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + + result = { + 'success': True, + 'image_base64': image_base64, + 'content_type': content_type, + 'image_size': image.size + } + + # Optionally save to file + if output_path: + os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) + image.save(output_path) + result['image_path'] = output_path + log.info(f"Rendered image saved to: {output_path}") + + return result + + except Exception as e: + log.error(f"Rendering failed: {e}") + return { + 'success': False, + 'error': str(e), + 'content_type': content_type + } + + @classmethod + def _render_text(cls, content: str) -> Optional[Image.Image]: + """ + Render plain text as image. 
+ + Args: + content: Text content to render + + Returns: + PIL Image or None if failed + """ + try: + # Try to load font + font = None + font_size = 24 + + if cls.config.font_path and os.path.exists(cls.config.font_path): + try: + font = ImageFont.truetype(cls.config.font_path, font_size) + except Exception: + pass + + if font is None: + try: + # Try common system fonts (prioritize Western fonts for better symbol support) + font_candidates = [ + '/System/Library/Fonts/Helvetica.ttc', # macOS Helvetica + 'Arial', # Windows/Linux Arial + 'Arial Unicode MS', # Unicode support + 'DejaVuSans', # Linux fallback + 'SimSun' # Chinese font (last resort) + ] + for font_name in font_candidates: + try: + font = ImageFont.truetype(font_name, font_size) + break + except Exception: + continue + except Exception: + font = ImageFont.load_default() + + # Calculate text size + dummy_img = Image.new('RGB', (1, 1), 'white') + draw = ImageDraw.Draw(dummy_img) + + # Handle multiline text + lines = content.split('\n') + max_width = 0 + total_height = 0 + line_heights = [] + + for line in lines: + bbox = draw.textbbox((0, 0), line or ' ', font=font) + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + max_width = max(max_width, width) + line_heights.append(height) + total_height += height + 5 # 5px line spacing + + # Create image with padding + pad = cls.config.pad + img_width = max_width + 2 * pad + img_height = total_height + 2 * pad + + image = Image.new('RGB', (img_width, img_height), 'white') + draw = ImageDraw.Draw(image) + + # Draw text + y_offset = pad + for i, line in enumerate(lines): + draw.text((pad, y_offset), line, font=font, fill='black') + y_offset += line_heights[i] + 5 + + return image + + except Exception as e: + log.error(f"Text rendering failed: {e}") + return None + + @classmethod + def _render_latex(cls, content: str) -> Optional[Image.Image]: + """ + Render LaTeX equation as image using xelatex. 
+ + Args: + content: LaTeX content to render + + Returns: + PIL Image or None if failed + """ + temp_dir = tempfile.mkdtemp(prefix="latex_render_") + + try: + # Prepare content - wrap in math mode if needed + processed_content = cls._preprocess_latex(content) + + # Determine paper width based on content length + char_count = len(content) + if char_count < 100: + paper_width = "20cm" + elif char_count < 300: + paper_width = "40cm" + else: + paper_width = "60cm" + + # Determine CJK font to use + cjk_font_line = "" + if cls.config.cjk_font: + cjk_font_line = f"\\setCJKmainfont{{{cls.config.cjk_font}}}" + else: + # Try to detect system and use appropriate default + import platform + system = platform.system() + if system == "Windows": + cjk_font_line = "\\setCJKmainfont{SimSun}" + elif system == "Darwin": # macOS + cjk_font_line = "\\setCJKmainfont{PingFang SC}" + elif system == "Linux": + cjk_font_line = "\\setCJKmainfont{Noto Sans CJK SC}" + else: + # Fallback: try SimSun, may fail on non-Windows + cjk_font_line = "\\setCJKmainfont{SimSun}" + + # Generate LaTeX file + latex = cls.LATEX_TEMPLATE.replace("[PAPERWIDTH]", paper_width) + latex = latex.replace("[CJKFONT]", cjk_font_line) + latex = latex.replace("[CONTENT]", processed_content) + + tex_file = os.path.join(temp_dir, "formula.tex") + pdf_file = os.path.join(temp_dir, "formula.pdf") + png_file = os.path.join(temp_dir, "formula.png") + + with open(tex_file, "w", encoding="utf-8") as f: + f.write(latex) + + # Compile with xelatex (use list args to prevent shell injection) + xelatex_cmd = [ + "xelatex", + "-interaction=nonstopmode", + f"-output-directory={temp_dir}", + tex_file + ] + result = subprocess.run( + xelatex_cmd, + capture_output=True, + timeout=cls.config.timeout + ) + + if not os.path.exists(pdf_file): + log.error(f"LaTeX compilation failed: {result.stderr.decode()}") + return None + + # Convert PDF to PNG using ImageMagick (use list args to prevent shell injection) + convert_cmd = [ + "magick", + 
"-density", str(cls.config.density), + pdf_file, + "-background", "white", + "-alpha", "remove", + "-quality", "100", + png_file + ] + subprocess.run(convert_cmd, timeout=30) + + if not os.path.exists(png_file): + log.error("PDF to PNG conversion failed") + return None + + # Load and crop image + image = cls._crop_image(png_file) + return image + + except subprocess.TimeoutExpired: + log.error("LaTeX rendering timed out") + return None + except Exception as e: + log.error(f"LaTeX rendering failed: {e}") + return None + finally: + # Cleanup + if os.path.exists(temp_dir): + try: + shutil.rmtree(temp_dir) + except Exception: + pass + + @classmethod + def _preprocess_latex(cls, content: str) -> str: + """ + Preprocess LaTeX content for rendering. + + Args: + content: Raw LaTeX content + + Returns: + Preprocessed LaTeX content + """ + # Check if content already has math delimiters + math_patterns = [ + r'\$\$.*?\$\$', + r'\$.*?\$', + r'\\\(.*?\\\)', + r'\\\[.*?\\\]', + r'\\begin\{equation', + r'\\begin\{align', + ] + + has_math = any(re.search(p, content, re.DOTALL) for p in math_patterns) + + if not has_math: + # Wrap in display math mode + content = f"$${content}$$" + + # Escape special characters in text mode + # (simplified - full implementation would be more complex) + return content + + @classmethod + def _render_table(cls, content: str) -> Optional[Image.Image]: + """ + Render HTML table as image. + + Args: + content: HTML table content + + Returns: + PIL Image or None if failed + """ + # For now, fall back to text rendering + # A full implementation would use a headless browser or + # specialized HTML-to-image converter + log.warning("Table rendering falling back to text mode") + return cls._render_text(content) + + @classmethod + def _crop_image(cls, image_path: str) -> Image.Image: + """ + Crop image to content bounds with padding. 
+ + Args: + image_path: Path to image file + + Returns: + Cropped PIL Image + """ + img = Image.open(image_path).convert("L") + img_data = np.asarray(img, dtype=np.uint8) + + # Find non-white pixels + nnz_inds = np.where(img_data < 250) + + if len(nnz_inds[0]) == 0: + # All white - return small image + return Image.new('RGB', (100, 50), 'white') + + y_min = np.min(nnz_inds[0]) + y_max = np.max(nnz_inds[0]) + x_min = np.min(nnz_inds[1]) + x_max = np.max(nnz_inds[1]) + + # Add padding + pad = cls.config.pad + h, w = img_data.shape + x_min = max(0, x_min - pad) + y_min = max(0, y_min - pad) + x_max = min(w, x_max + pad) + y_max = min(h, y_max + pad) + + # Crop and convert to RGB + img = Image.open(image_path).convert("RGB") + cropped = img.crop((x_min, y_min, x_max, y_max)) + + return cropped diff --git a/dingo/model/llm/vlm_ocr_understanding.py b/dingo/model/llm/vlm_ocr_understanding.py index 5a8a01f3..8c638e2d 100644 --- a/dingo/model/llm/vlm_ocr_understanding.py +++ b/dingo/model/llm/vlm_ocr_understanding.py @@ -21,7 +21,7 @@ class VLMOCRUnderstanding(BaseOpenAI): _metric_info = { "category": "Multimodality Assessment Metrics", "quality_dimension": "VLM_OCR_UNDERSTANDING", - "metric_name": "PromptVLMOCRUnderstanding", + "metric_name": "VLMOCRUnderstanding", "description": "评估多模态模型对图片中文字内容的识别和理解能力,使用DeepSeek-OCR作为Ground Truth", "paper_title": "DeepSeek-OCR: Contexts Optical Compression", "paper_url": "https://github.com/deepseek-ai/DeepSeek-OCR", diff --git a/dingo/model/llm/vlm_render_judge.py b/dingo/model/llm/vlm_render_judge.py new file mode 100644 index 00000000..9f3e2a72 --- /dev/null +++ b/dingo/model/llm/vlm_render_judge.py @@ -0,0 +1,359 @@ +""" +VLM Render Judge - Visual OCR Quality Evaluation + +This metric implements the "Render → Judge" pattern from MinerU_Metis: +1. Render OCR content as image (using LaTeX/HTML rendering) +2. Use VLM to compare original image vs rendered image +3. 
Output quality assessment
+
+This is a standalone metric that can be used independently or as part of
+an iterative refinement workflow.
+
+Prompt source: MinerU_Metis configs/prompts/text/judge-render.j2
+"""
+
+import base64
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+from dingo.io import Data
+from dingo.io.output.eval_detail import EvalDetail, QualityLabel
+from dingo.model import Model
+from dingo.model.llm.base_openai import BaseOpenAI
+from dingo.utils import log
+
+
+@Model.llm_register("VLMRenderJudge")
+class VLMRenderJudge(BaseOpenAI):
+    """
+    VLM-based OCR quality evaluation through visual comparison.
+
+    Workflow:
+    1. Receive original image + OCR content
+    2. Render OCR content as image
+    3. VLM compares original vs rendered
+    4. Output quality assessment
+
+    ┌──────────────┐    ┌──────────┐    ┌─────────────────────┐
+    │Original image│───▶│          │    │                     │
+    └──────────────┘    │   VLM    │───▶│ EvalDetail          │
+    ┌──────────────┐    │  Judge   │    │ - is_correct        │
+    │ OCR content  │───▶│          │    │ - reason            │
+    └──────┬───────┘    └──────────┘    └─────────────────────┘
+           │                 ▲
+           ▼                 │
+    ┌──────────────┐         │
+    │    Render    │─────────┘
+    │   (LaTeX)    │
+    └──────────────┘
+
+    Input Data Fields:
+    - image: Original document image (path or base64)
+    - content: OCR result text to evaluate
+    - content_type: "text" | "equation" | "table" (optional, default: "text")
+
+    Output:
+    - score: 1.0 (correct) or 0.0 (incorrect)
+    - label: QUALITY_GOOD or QUALITY_BAD_OCR.*
+    - reason: VLM's detailed judgment reason
+    - extra: {"is_correct": bool, "rendered_available": bool}
+
+    Configuration Example:
+    {
+        "name": "VLMRenderJudge",
+        "config": {
+            "key": "your-api-key",
+            "api_url": "https://api.openai.com/v1",
+            "model": "gpt-4o",
+            "parameters": {
+                "content_type": "equation",
+                "render_config": {
+                    "density": 150,
+                    "pad": 20
+                }
+            }
+        }
+    }
+    """
+
+    _metric_info = {
+        "category": "Multimodality Assessment Metrics",
+        "metric_name": "VLMRenderJudge",
+        "description": "VLM-based OCR quality evaluation through visual render-compare",
+    }
+
+    # Judge prompt from MinerU_Metis (configs/prompts/text/judge-render.j2)
+    JUDGE_PROMPT = """You are a Text Consistency Verification Expert. Your only task is to compare text content (including characters, symbols, and text) between two images:
+1. First Image: Ground Truth (original image with accurate text content)
+2. Second Image: Model-rendered OCR result to be evaluated
+
+Judgment Rules:
+1. Core Consistency: Return TRUE only if the text, characters, and symbols in the second image are fully consistent with those in the first image. Return FALSE if there are actual missing, incorrect, or extra text, characters, or symbols (excluding mere rendering differences). Additionally, any content that should be present but is not displayed (i.e., undisplayed portions) shall also be deemed inconsistent (return FALSE).
+2. In addition to punctuation marks, spaces, and line breaks, all symbols (including superscripts/subscripts, e.g., ⁸, ₃, ²α) must maintain both semantic consistency and visual shape consistency. Differences in text font styles (e.g., bold/italic, serif/sans-serif, font color) do NOT affect consistency, provided that the core character identity is preserved. Note: This font style exemption applies solely to standard text characters (not symbols) — superscripts, subscripts, and all non-text symbols are excluded from this exemption and must strictly match the ground truth in visual shape.
+3. Space Quantity Rule: Differences in the number of spaces (e.g., 1 space vs. 2 spaces, no space vs. multiple spaces) between the two images do NOT count as inconsistency (IGNORE).
+4. Symbol Punctuation Rule: Differences between Chinese and English symbols/punctuation (e.g., "," vs ",", "." vs "。", "" vs "", ":" vs ":") do NOT count as inconsistency (IGNORE).
+5. Line Break Rule: Differences in line breaks/line wrapping between the two images do NOT count as inconsistency (IGNORE).
+6. Special Character Rule: For ellipsis (.../.......), underscores (//) (or equivalent horizontal line representations), or similar repeated symbols:
+- Mandatory Requirement: The symbol type (e.g., ellipsis vs. underscore) MUST exist in the corresponding position of the second image as the first (complete absence of the symbol = ERROR; line break differences do NOT affect position judgment).
+- Acceptable Difference: The length/number of repeated symbols (e.g., 3 dots vs. 6 dots, 1 underscore "_" vs. 5 underscores "_____") does NOT need to match (ACCEPT any length).
+7. Truncated Character Rule: If either image contains partial/truncated characters (e.g., half-cut characters at image edges), the OCR result SHOULD NOT recognize these partial characters as valid text. These truncated characters must be IGNORED during consistency comparison – the OCR result is considered incorrect if it attempts to recognize/truncate partial characters.
+
+Output Requirements (MUST COMPLY)
+1. First output concise reason (max 300 words) explaining your judgment (key differences/findings)
+2. Then output ONLY XML (no extra text/formatting) with exactly this structure:
+<reason>Concise reason (max 300 words) explaining your judgment.</reason>
+<answer>The final judge result: true / false</answer>
+
+Example Reasoning & Output
+Case 1 (Allowed Rendering Difference):
+<reason>GT has "答案" , OCR has "答案" → allowed, consistent.</reason>
+<answer>true</answer>
+
+Case 2 (Forbidden Difference):
+<reason>GT has "ABC" , OCR has "EFG" → inconsistent.</reason>
+<answer>false</answer>
+
+Case 3 (Forbidden Difference: Symbol Shape)
+<reason>GT has "ⓐ 128r⁸", OCR has "(a) 128r⁸" → ⓐ changed to (a) (Rule 3A violation), inconsistent.</reason>
+<answer>false</answer>"""
+
+    @classmethod
+    def eval(cls, input_data: Data) -> EvalDetail:
+        """
+        Evaluate OCR quality through render-compare.
+
+        Args:
+            input_data: Data with 'image' and 'content' fields
+
+        Returns:
+            EvalDetail with quality assessment
+        """
+        try:
+            cls.create_client()
+
+            # Get inputs
+            image = cls._get_image(input_data)
+            content = cls._get_content(input_data)
+            content_type = cls._get_content_type(input_data)
+
+            if not image:
+                return cls._error_result("No image provided for comparison")
+
+            if not content:
+                return cls._error_result("No OCR content provided for evaluation")
+
+            log.info(f"{cls.__name__}: Evaluating {content_type} content")
+
+            # Step 1: Render OCR content
+            rendered_base64 = cls._render_content(content, content_type)
+
+            if not rendered_base64:
+                log.warning(f"{cls.__name__}: Render failed, using text-only comparison")
+                return cls._text_only_comparison(image, content)
+
+            # Step 2: VLM Judge
+            judge_result = cls._judge(image, rendered_base64)
+
+            # Step 3: Build result
+            return cls._build_result(judge_result, content)
+
+        except Exception as e:
+            log.error(f"{cls.__name__} failed: {e}")
+            return cls._error_result(f"Evaluation failed: {str(e)}")
+
+    @classmethod
+    def _render_content(cls, content: str, content_type: str) -> Optional[str]:
+        """
+        Render OCR content to image.
+
+        Uses RenderTool if available, otherwise returns None.
+        """
+        try:
+            from dingo.model.llm.agent.tools.render_tool import RenderTool
+
+            # Get render config
+            params = cls.dynamic_config.parameters or {}
+            render_config = params.get('render_config', {})
+
+            if render_config:
+                RenderTool.update_config(render_config)
+
+            result = RenderTool.execute(content=content, content_type=content_type)
+
+            if result.get('success'):
+                return result.get('image_base64')
+            else:
+                log.warning(f"Render failed: {result.get('error')}")
+                return None
+
+        except ImportError:
+            log.warning("RenderTool not available")
+            return None
+        except Exception as e:
+            log.warning(f"Render error: {e}")
+            return None
+
+    @classmethod
+    def _judge(cls, original_image: str, rendered_base64: str) -> Dict[str, Any]:
+        """
+        Use VLM to compare original vs rendered image.
+
+        Returns:
+            Dict with 'is_correct' and 'reason'
+        """
+        try:
+            # Build multimodal message
+            messages = cls._build_judge_message(original_image, rendered_base64)
+            response = cls.send_messages(messages)
+
+            # Parse XML response
+            return cls._parse_response(response)
+
+        except Exception as e:
+            log.error(f"Judge failed: {e}")
+            return {
+                'is_correct': False,
+                'reason': f"Judge failed: {str(e)}"
+            }
+
+    @classmethod
+    def _build_judge_message(cls, original_image: str, rendered_base64: str) -> List[Dict]:
+        """Build multimodal message with two images (original first, rendered second)."""
+        # Load original image: a file path, raw base64, or an existing data URL
+        if os.path.exists(original_image):
+            with open(original_image, 'rb') as f:
+                original_base64 = base64.b64encode(f.read()).decode('utf-8')
+        else:
+            original_base64 = original_image.split("base64,", 1)[-1]  # drop any data-URL prefix
+
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": cls.JUDGE_PROMPT},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{original_base64}"}
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{rendered_base64}"}
+                    }
+                ]
+            }
+        ]
+
+    @classmethod
+    def _parse_response(cls, response: str) -> Dict[str, Any]:
+        """
+        Parse VLM's XML response.
+
+        Expected format:
+        <reason>...</reason>
+        <answer>true / false</answer>
+        """
+        try:
+            reason_match = re.search(r'<reason>(.*?)</reason>', response, re.DOTALL)
+            answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
+
+            reason = reason_match.group(1).strip() if reason_match else response[:500]
+            answer_text = answer_match.group(1).strip().lower() if answer_match else ""
+
+            is_correct = "true" in answer_text and "false" not in answer_text
+
+            return {
+                'is_correct': is_correct,
+                'reason': reason
+            }
+
+        except Exception as e:
+            log.error(f"Parse failed: {e}")
+            return {
+                'is_correct': False,
+                'reason': response[:500]
+            }
+
+    @classmethod
+    def _build_result(cls, judge_result: Dict, content: str) -> EvalDetail:
+        """Build EvalDetail from judge result."""
+        result = EvalDetail(metric=cls.__name__)
+
+        is_correct = judge_result.get('is_correct', False)
+        reason = judge_result.get('reason', '')
+
+        result.score = 1.0 if is_correct else 0.0
+        result.status = not is_correct  # True if there's an issue
+
+        if is_correct:
+            result.label = [QualityLabel.QUALITY_GOOD]
+            result.reason = [
+                "✅ OCR content verified correct",
+                "",
+                "Judge reason:",
+                reason
+            ]
+        else:
+            result.label = ["QUALITY_BAD_OCR.VISUAL_MISMATCH"]
+            result.reason = [
+                "❌ OCR content has errors",
+                "",
+                "Judge reason:",
+                reason,
+                "",
+                "OCR content evaluated:",
+                content[:300] + "..." if len(content) > 300 else content
+            ]
+
+        return result
+
+    @classmethod
+    def _text_only_comparison(cls, image: str, content: str) -> EvalDetail:
+        """Fallback when render is not available."""
+        result = EvalDetail(metric=cls.__name__)
+        result.score = 0.5
+        result.status = True
+        result.label = ["QUALITY_UNKNOWN.RENDER_FAILED"]  # the label flags the render failure
+        result.reason = [
+            "⚠️ Could not render OCR content for visual comparison",
+            "Render tool may not be available or content format unsupported",
+            "",
+            "OCR content:",
+            content[:300] + "..." if len(content) > 300 else content
+        ]
+        return result
+
+    @classmethod
+    def _error_result(cls, message: str) -> EvalDetail:
+        """Create error result."""
+        result = EvalDetail(metric=cls.__name__)
+        result.status = True
+        result.label = [f"{QualityLabel.QUALITY_BAD_PREFIX}EVAL_ERROR"]
+        result.reason = [f"❌ {message}"]
+        return result
+
+    @classmethod
+    def _get_image(cls, input_data: Data) -> Optional[str]:
+        """Extract image from input data (first element when a list is given)."""
+        if hasattr(input_data, 'image'):
+            img = input_data.image
+            if isinstance(img, list) and img:
+                return img[0]
+            return img
+        return None
+
+    @classmethod
+    def _get_content(cls, input_data: Data) -> Optional[str]:
+        """Extract OCR content from input data."""
+        if hasattr(input_data, 'content'):
+            return input_data.content
+        return None
+
+    @classmethod
+    def _get_content_type(cls, input_data: Data) -> str:
+        """Get content type: per-sample field first, then config parameter, then "text"."""
+        if hasattr(input_data, 'content_type') and input_data.content_type:
+            return input_data.content_type
+
+        params = cls.dynamic_config.parameters or {}
+        return params.get('content_type', 'text')
diff --git a/docs/en/vlm_render_judge_guide.md b/docs/en/vlm_render_judge_guide.md
new file mode 100644
index 00000000..3071691f
--- /dev/null
+++ b/docs/en/vlm_render_judge_guide.md
@@ -0,0 +1,267 @@
+# VLMRenderJudge - Visual OCR Quality Evaluation Guide
+
+This guide introduces **VLMRenderJudge**, a visual comparison-based OCR quality evaluation metric in Dingo. It implements the **"Render → Judge"** pattern by rendering OCR results as images and comparing them with original images using VLM.
+
+## 🎯 Overview
+
+VLMRenderJudge is an innovative OCR quality assessment method that evaluates accuracy through visual comparison rather than text comparison. 
It's particularly suitable for: + +- **Mathematical Formula Recognition**: Accurately evaluate symbols, subscripts/superscripts, fractions, and other details +- **Table Structure Recognition**: Verify table borders, cell alignment, merging, and structural information +- **Document Layout Assessment**: Detect whether paragraphs, titles, lists, and other layout elements are correctly recognized +- **Multilingual OCR Evaluation**: Unified evaluation of OCR quality across different languages +- **Iterative OCR Optimization**: Serves as the Judge component in "Judge → Refine" iterative workflows + +### Comparison with Traditional Methods + +| Method | Advantages | Disadvantages | +|--------|-----------|---------------| +| **Text Similarity** (CER/WER) | Fast, quantifiable | Cannot assess format, layout, mathematical symbols | +| **Edit Distance** (Levenshtein) | Simple, intuitive | Symbol order sensitive, no visual perception | +| **VLMRenderJudge** | Visually accurate, supports complex formats, close to human judgment | Requires VLM API, depends on rendering tools | + +--- + +## 🔧 Core Principles + +### Render → Judge Workflow + +``` +┌──────────────┐ +│ Original Image│ ────────┐ +└──────────────┘ │ + ▼ + ┌──────────┐ ┌─────────────┐ + │ VLM │────▶│ EvalDetail │ + │ Judge │ │ score: 0/1 │ + └──────────┘ └─────────────┘ +┌──────────────┐ ▲ +│ OCR Result │ │ +│ (text) │─────────┘ +│ ↓ Render │ (Compare two images) +│ ┌──────────┐ │ +│ │ Rendered │ │ +│ └──────────┘ │ +└──────────────┘ +``` + +**Core Steps**: +1. **Receive Input**: Original image + OCR recognized text +2. **Render OCR Result**: Render as image based on content type (text/equation/table) +3. **VLM Visual Comparison**: Submit both images to VLM to judge consistency +4. 
**Output Result**: + - `score = 1.0`: OCR completely correct (QUALITY_GOOD) + - `score = 0.0`: OCR has errors (QUALITY_BAD_OCR.VISUAL_MISMATCH) + - `score = 0.5`: Render failed, cannot judge (QUALITY_UNKNOWN.RENDER_FAILED) + +--- + +## 📋 Requirements + +### Environment Dependencies + +```bash +# Basic dependencies +pip install dingo pillow + +# LaTeX rendering (for equation type, optional) +# macOS +brew install mactex-no-gui imagemagick + +# Ubuntu/Debian +sudo apt-get install texlive-xetex imagemagick +``` + +### Data Format + +```python +from dingo.io.input import Data + +data = Data( + data_id="test_1", + image="path/to/original_image.png", # Original document image (required) + content="The quick brown fox...", # OCR recognized text (required) + content_type="text" # Content type (optional, default "text") +) +``` + +--- + +## 🚀 Quick Start + +### Example 1: Standalone OCR Quality Evaluation + +```python +from dingo.config.input_args import InputArgs +from dingo.exec import Executor + +# Configure evaluator +args = InputArgs( + input_path="test_data.jsonl", + dataset={ + "source": "local", + "format": "jsonl" + }, + evaluator=[{ + "fields": { + "image": "image", + "content": "content", + "content_type": "content_type" + }, + "evals": [{ + "name": "VLMRenderJudge", + "config": { + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_tokens": 4000, + "temperature": 0, + "render_config": { + "density": 150, # LaTeX render DPI + "pad": 20, # Image padding + "cjk_font": None # CJK font for LaTeX (auto-detect by default) + } + } + } + }] + }] +) + +# Execute evaluation +executor = Executor.exec_map["local"](args) +summary = executor.execute() + +print(f"Evaluation complete: {summary.score:.2f}%") +print(f"Correct: {summary.num_good}/{summary.total}") +``` + +### Example 2: Test Data Format + +**test_data.jsonl** example: + +```jsonl +{"image": "images/doc1.png", "content": "The quick brown fox jumps over the 
lazy dog.", "content_type": "text"} +{"image": "images/formula1.png", "content": "E = mc^2", "content_type": "equation"} +{"image": "images/table1.png", "content": "
<table><tr><td>A</td><td>B</td></tr></table>
", "content_type": "table"} +``` + +--- + +## 🔥 Complete Configuration + +### Python Code + +```python +from dingo.model.llm.vlm_render_judge import VLMRenderJudge +from dingo.io.input import Data + +# 1. Configure VLM model +VLMRenderJudge.set_config({ + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_tokens": 4000, + "temperature": 0, + "render_config": { + "density": 150, # LaTeX render DPI (72-300) + "pad": 20, # Image padding + "timeout": 60, # Render timeout (seconds) + "font_path": None, # Custom font for text rendering (optional) + "cjk_font": None # CJK font for LaTeX (auto-detect: SimSun/PingFang SC/Noto Sans CJK SC) + } + } +}) + +# 2. Prepare test data +data = Data( + data_id="test_1", + image="test/images/formula.png", + content="\\int_{0}^{\\infty} e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}", + content_type="equation" +) + +# 3. Execute evaluation +result = VLMRenderJudge.eval(data) + +# 4. View results +print(f"Score: {result.score}") +print(f"Label: {result.label}") +print(f"Reason:\n{chr(10).join(result.reason)}") +``` + +--- + +## 🔄 Integration with AgentIterativeOCR + +VLMRenderJudge can serve as the **Judge component** in iterative OCR optimization: + +```python +from dingo.model.llm.agent.agent_iterative_ocr import AgentIterativeOCR +from dingo.config.input_args import InputArgs + +# Configure iterative OCR evaluation +args = InputArgs( + input_path="test_data.jsonl", + dataset={"source": "local", "format": "jsonl"}, + evaluator=[{ + "fields": { + "image": "image", + "content": "initial_ocr_result" + }, + "evals": [{ + "name": "AgentIterativeOCR", + "config": { + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_iterations": 3, + "content_type": "equation" + } + } + }] + }] +) +``` + +--- + +## 💡 Best Practices + +1. **Choose appropriate content_type** +2. **Use reasonable batch_size** for batch evaluation +3. 
**Adjust temperature** for different strictness levels +4. **Save rendered images** for debugging +5. **Use streaming** for large datasets + +--- + +## 📊 Metrics Interpretation + +| Score | Meaning | Label | Action | +|-------|---------|-------|--------| +| 1.0 | Completely correct | QUALITY_GOOD | No action needed | +| 0.0 | Has errors | QUALITY_BAD_OCR.VISUAL_MISMATCH | Fix or re-OCR | +| 0.5 | Render failed | QUALITY_UNKNOWN.RENDER_FAILED | Check render environment | + +--- + +## 🔗 Related Resources + +- **Example Script**: `examples/ocr/vlm_render_judge.py` +- **Test Data**: `test/data/img_OCR_iterative/` +- **API Docs**: `dingo.model.llm.vlm_render_judge.VLMRenderJudge` +- **Related Tools**: + - `RenderTool`: OCR content rendering + - `AgentIterativeOCR`: Iterative OCR optimization +- **Reference**: [MinerU_Metis](https://github.com/opendatalab/MinerU) - Original Render-Judge implementation + +--- + +## 📝 Changelog + +- **v1.0** (2026-01): Initial release with text/equation/table support +- Aligned with MinerU_Metis Render-Judge pattern +- Supports standalone and Agent integration diff --git a/docs/metrics.md b/docs/metrics.md index 4dd7bdd0..337d64de 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -51,6 +51,7 @@ This document provides comprehensive information about all quality metrics used |------|--------|-------------|--------------|-------------------|----------| | `LLMClassifyQR` | LLMClassifyQR | Identifies images as CAPTCHA, QR code, or normal images | Internal Implementation | N/A | N/A | | `VLMOCRUnderstanding` | VLMOCRUnderstanding | 评估多模态模型对图片中文字内容的识别和理解能力,使用DeepSeek-OCR作为Ground Truth | [DeepSeek-OCR: Contexts Optical Compression](https://github.com/deepseek-ai/DeepSeek-OCR) | [📊 See Results](通过对比VLM输出与OCR ground truth,识别文字遗漏、错误、幻觉等问题) | N/A | +| `VLMRenderJudge` | VLMRenderJudge | VLM-based OCR quality evaluation through visual render-compare | Internal Implementation | N/A | N/A | ### Rule-Based TEXT Quality Metrics diff --git 
a/docs/vlm_render_judge_guide.md b/docs/vlm_render_judge_guide.md new file mode 100644 index 00000000..2b05efc6 --- /dev/null +++ b/docs/vlm_render_judge_guide.md @@ -0,0 +1,528 @@ +# VLMRenderJudge - 基于视觉渲染的 OCR 质量评估指南 + +本指南介绍如何在 Dingo 中使用 **VLMRenderJudge**,一个基于视觉比较的 OCR 质量评估指标。该指标实现了 **"Render → Judge"** 模式,通过渲染 OCR 结果为图像并与原图进行 VLM 比较,从而准确评估 OCR 质量。 + +## 🎯 功能概述 + +VLMRenderJudge 是一种创新的 OCR 质量评估方法,通过视觉比较而非文本比较来判断 OCR 结果的准确性。该方法特别适用于: + +- **数学公式识别评估**: 准确评估复杂公式中的符号、上下标、分数等细节 +- **表格结构识别评估**: 验证表格边框、单元格对齐、合并等结构信息 +- **文档布局评估**: 检测段落、标题、列表等布局元素是否正确识别 +- **多语言 OCR 评估**: 统一评估不同语言的 OCR 质量 +- **迭代式 OCR 优化**: 作为 Judge 环节支持 "Judge → Refine" 迭代优化流程 + +### 与传统方法的对比 + +| 方法 | 优势 | 劣势 | +|------|------|------| +| **文本相似度** (CER/WER) | 快速、可量化 | 无法评估格式、布局、数学符号 | +| **编辑距离** (Levenshtein) | 简单直观 | 对符号顺序敏感,无视觉感知 | +| **VLMRenderJudge** | 视觉准确、支持复杂格式、接近人类判断 | 需要 VLM API、依赖渲染工具 | + +--- + +## 🔧 核心原理 + +### Render → Judge 流程 + +``` +┌─────────────┐ +│ 原始图像 │ ────────┐ +└─────────────┘ │ + ▼ + ┌──────────┐ ┌─────────────┐ + │ VLM │────▶│ EvalDetail │ + │ Judge │ │ score: 0/1 │ + └──────────┘ └─────────────┘ +┌─────────────┐ ▲ +│ OCR 结果 │ │ +│ (文本) │─────────┘ +│ ↓ 渲染 │ (两张图片比较) +│ ┌─────────┐ │ +│ │渲染图像 │ │ +│ └─────────┘ │ +└─────────────┘ +``` + +**核心步骤**: +1. **接收输入**: 原始图像 + OCR 识别文本 +2. **渲染 OCR 结果**: 根据内容类型(text/equation/table)渲染为图像 +3. **VLM 视觉比较**: 将原图和渲染图提交给 VLM,判断是否一致 +4. 
**输出结果**: + - `score = 1.0`: OCR 完全正确 (QUALITY_GOOD) + - `score = 0.0`: OCR 有错误 (QUALITY_BAD_OCR.VISUAL_MISMATCH) + - `score = 0.5`: 渲染失败,无法判断 (QUALITY_UNKNOWN.RENDER_FAILED) + +### 评判标准 + +VLM 使用严格的一致性规则(来自 MinerU_Metis 项目): + +- ✅ **忽略差异**: 字体样式、空格数量、换行位置、中英文标点互换 +- ❌ **标记为错误**: 字符缺失/增加/替换、符号错误、上下标错误、数字错误 + +--- + +## 📋 使用要求 + +### 环境依赖 + +```bash +# 基础依赖 +pip install dingo pillow + +# LaTeX 渲染(用于公式评估,可选) +# macOS +brew install mactex-no-gui imagemagick + +# Ubuntu/Debian +sudo apt-get install texlive-xetex imagemagick +``` + +### 数据格式要求 + +```python +from dingo.io.input import Data + +data = Data( + data_id="test_1", + image="path/to/original_image.png", # 原始文档图像(必需) + content="The quick brown fox...", # OCR 识别文本(必需) + content_type="text" # 内容类型(可选,默认 "text") +) +``` + +#### 支持的 content_type + +| 类型 | 说明 | 渲染方式 | +|------|------|---------| +| `text` | 纯文本、段落、标题 | PIL ImageDraw | +| `equation` | LaTeX 数学公式 | xelatex | +| `table` | HTML 表格 | HTML to Image | + +--- + +## 🚀 快速开始 + +### 示例 1: 独立使用评估 OCR 质量 + +```python +from dingo.config.input_args import InputArgs +from dingo.exec import Executor + +# 配置评估器 +args = InputArgs( + input_path="test_data.jsonl", + dataset={ + "source": "local", + "format": "jsonl" + }, + evaluator=[{ + "fields": { + "image": "image", # 原始图片字段 + "content": "content", # OCR 文本字段 + "content_type": "content_type" + }, + "evals": [{ + "name": "VLMRenderJudge", + "config": { + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_tokens": 4000, + "temperature": 0, + "render_config": { + "density": 150, # LaTeX 渲染 DPI + "pad": 20 # 图像边距 + } + } + } + }] + }] +) + +# 执行评估 +executor = Executor.exec_map["local"](args) +summary = executor.execute() + +print(f"评估完成: {summary.score:.2f}%") +print(f"正确数量: {summary.num_good}/{summary.total}") +``` + +### 示例 2: 测试数据格式 + +**test_data.jsonl** 示例: + +```jsonl +{"image": "images/doc1.png", "content": "The quick brown fox jumps over the lazy dog.", 
"content_type": "text"} +{"image": "images/formula1.png", "content": "E = mc^2", "content_type": "equation"} +{"image": "images/table1.png", "content": "
<table><tr><td>A</td><td>B</td></tr></table>
", "content_type": "table"} +``` + +--- + +## 🔥 完整配置示例 + +### Python 代码方式 + +```python +from dingo.model.llm.vlm_render_judge import VLMRenderJudge +from dingo.io.input import Data + +# 1. 配置 VLM 模型 +VLMRenderJudge.set_config({ + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_tokens": 4000, + "temperature": 0, + "render_config": { + "density": 150, # LaTeX 渲染 DPI (72-300) + "pad": 20, # 图像边距 + "timeout": 60, # 渲染超时时间(秒) + "font_path": None, # 文本渲染自定义字体路径(可选) + "cjk_font": None # LaTeX CJK 字体名称(可选,如 'SimSun'/'PingFang SC'/'Noto Sans CJK SC') + } + } +}) + +# 2. 准备测试数据 +data = Data( + data_id="test_1", + image="test/images/formula.png", + content="\\int_{0}^{\\infty} e^{-x^2} dx = \\frac{\\sqrt{\\pi}}{2}", + content_type="equation" +) + +# 3. 执行评估 +result = VLMRenderJudge.eval(data) + +# 4. 查看结果 +print(f"评分: {result.score}") +print(f"标签: {result.label}") +print(f"原因:\n{chr(10).join(result.reason)}") +``` + +### 输出示例 + +**正确的 OCR 结果**: +``` +评分: 1.0 +标签: ['QUALITY_GOOD'] +原因: +✅ OCR content verified correct + +Judge reason: +Both images show the same mathematical equation with consistent symbols, +subscripts, and superscripts. The content is fully consistent. +``` + +**错误的 OCR 结果**: +``` +评分: 0.0 +标签: ['QUALITY_BAD_OCR.VISUAL_MISMATCH'] +原因: +❌ OCR content has errors + +Judge reason: +GT has "e^{-x^2}" while OCR has "e^-x2". The OCR result is missing the +superscript braces, causing incorrect rendering. This is an actual symbol +difference. 
+ +OCR content evaluated: +\int_{0}^{\infty} e^-x2 dx = \frac{\sqrt{\pi}}{2} +``` + +--- + +## 🔄 与 AgentIterativeOCR 配合使用 + +VLMRenderJudge 可作为 **Judge 环节**,与 OCR Refiner 组成迭代优化流程: + +```python +from dingo.model.llm.agent.agent_iterative_ocr import AgentIterativeOCR +from dingo.config.input_args import InputArgs + +# 配置迭代式 OCR 评估 +args = InputArgs( + input_path="test_data.jsonl", + dataset={"source": "local", "format": "jsonl"}, + evaluator=[{ + "fields": { + "image": "image", + "content": "initial_ocr_result" # 初始 OCR 结果 + }, + "evals": [{ + "name": "AgentIterativeOCR", + "config": { + # VLM Judge 配置 + "model": "gpt-4o", + "key": "your-api-key", + "api_url": "https://api.openai.com/v1", + "parameters": { + "max_iterations": 3, # 最大迭代次数 + "content_type": "equation" + } + } + }] + }] +) + +# 执行迭代评估 +executor = Executor.exec_map["local"](args) +summary = executor.execute() +``` + +**迭代流程**: +1. **Judge**: VLMRenderJudge 判断当前 OCR 是否正确 +2. **Refine**: 如果不正确,调用 VLM 分析错误并生成改进版本 +3. **Repeat**: 重复步骤 1-2,直到正确或达到最大迭代次数 + +--- + +## ⚙️ 渲染配置详解 + +### render_config 参数 + +```python +"render_config": { + "density": 150, # LaTeX 渲染 DPI (默认: 150) + # - 72: 低质量,速度快 + # - 150: 平衡质量与速度(推荐) + # - 300: 高质量,速度慢 + + "pad": 20, # 图像边距,单位像素 (默认: 20) + + "timeout": 60, # 渲染超时时间(秒)(默认: 60) + + "font_path": None, # 文本渲染自定义字体路径(可选) + # 例如: "/usr/share/fonts/SimSun.ttc" + + "cjk_font": None # LaTeX CJK 字体名称(可选) + # - Windows: "SimSun", "Microsoft YaHei" + # - macOS: "PingFang SC", "Heiti SC" + # - Linux: "Noto Sans CJK SC", "WenQuanYi Micro Hei" +} +``` + +### 字体选择逻辑 + +#### 文本渲染字体(font_path) + +用于普通文本渲染,按以下顺序尝试: + +1. `render_config.font_path`(如果指定) +2. `/System/Library/Fonts/Helvetica.ttc` (macOS) +3. `Arial` (Windows/Linux) +4. `Arial Unicode MS` (Unicode 支持) +5. `DejaVuSans` (Linux) +6. `SimSun` (中文字体) +7. 
系统默认字体(最后备选) + +#### LaTeX CJK 字体(cjk_font) + +用于 LaTeX 公式中的中文字符渲染,**自动跨平台适配**: + +- 如果指定 `cjk_font`:使用指定字体 +- 如果未指定:**自动检测操作系统**并使用默认字体 + - Windows: `SimSun` (宋体) + - macOS: `PingFang SC` (苹方) + - Linux: `Noto Sans CJK SC` + +**跨平台配置示例**: + +```python +# 方式 1: 明确指定字体(推荐,确保一致性) +"render_config": { + "cjk_font": "SimSun" # 确保所有平台都安装了此字体 +} + +# 方式 2: 自动检测(默认,方便但可能导致不同平台渲染结果不同) +"render_config": { + "cjk_font": None # 自动根据操作系统选择 +} +``` + +**建议**: +- **英文文档**:使用默认配置 +- **中文文档**: + - 团队协作:明确指定 `cjk_font`,确保所有成员安装相同字体 + - 个人使用:使用自动检测(`cjk_font=None`) +- **混合文档**:指定支持中英文的字体(如 `Arial Unicode MS` + `cjk_font="PingFang SC"`) + +--- + +## 🎯 应用场景 + +### 场景 1: OCR 模型对比评估 + +```python +# 评估多个 OCR 模型的输出质量 +models = ["paddleocr", "tesseract", "mineru", "surya"] + +for model in models: + args = InputArgs( + input_path=f"ocr_results_{model}.jsonl", + evaluator=[{ + "fields": {"image": "image", "content": "ocr_text"}, + "evals": [{"name": "VLMRenderJudge", "config": llm_config}] + }] + ) + + summary = Executor.exec_map["local"](args).execute() + print(f"{model}: {summary.score:.2f}%") +``` + +### 场景 2: 数据集质量验证 + +```python +# 验证 OCR 训练数据集的标注质量 +args = InputArgs( + input_path="training_data_with_labels.jsonl", + evaluator=[{ + "fields": { + "image": "image", + "content": "ground_truth_label" # 人工标注的 GT + }, + "evals": [{"name": "VLMRenderJudge", "config": llm_config}] + }] +) + +summary = Executor.exec_map["local"](args).execute() + +# 标记质量有问题的样本 +bad_samples = [ + item for item in summary.details + if item.eval_status and item.eval_details[0].score == 0.0 +] + +print(f"发现 {len(bad_samples)} 个质量问题样本") +``` + +### 场景 3: 实时 OCR 质量监控 + +```python +from dingo.model.llm.vlm_render_judge import VLMRenderJudge +from dingo.io.input import Data + +def ocr_with_quality_check(image_path, ocr_function): + """带质量检查的 OCR 函数""" + # 1. 执行 OCR + ocr_text = ocr_function(image_path) + + # 2. 质量评估 + data = Data(image=image_path, content=ocr_text) + result = VLMRenderJudge.eval(data) + + # 3. 
根据质量分数决定是否重试 + if result.score < 0.5: # 质量不佳 + print(f"⚠️ OCR quality low: {result.score}") + # 可以触发重试、人工审核等 + + return { + "text": ocr_text, + "quality_score": result.score, + "is_reliable": result.score >= 0.8 + } +``` + +--- + +## 💡 最佳实践 + +### 1. 选择合适的 content_type + +```python +# ✅ 正确 +{"content": "x^2 + y^2 = r^2", "content_type": "equation"} + +# ❌ 错误 +{"content": "x^2 + y^2 = r^2", "content_type": "text"} # 无法正确渲染上标 +``` + +### 2. 批量评估时使用合理的 batch_size + +```python +args = InputArgs( + executor={ + "batch_size": 10, # 平衡速度与内存 + "num_workers": 2 # 并发数量 + } +) +``` + +### 3. 针对不同场景调整 temperature + +```python +# 严格评估(推荐) +"parameters": {"temperature": 0} + +# 宽松评估(容忍小差异) +"parameters": {"temperature": 0.3} +``` + +### 4. 保存渲染图像用于调试 + +```python +from dingo.model.llm.agent.tools import RenderTool + +# 渲染并保存图像 +result = RenderTool.execute( + content="test content", + content_type="text", + output_path="debug_render.png" # 保存到文件 +) +``` + +### 5. 处理大规模数据集 + +```python +# 使用流式处理,避免内存溢出 +args = InputArgs( + input_path="large_dataset.jsonl", + executor={ + "batch_size": 5, # 小批量 + "checkpoint_interval": 100 # 定期保存检查点 + } +) +``` + +--- + +## 📊 评估指标解读 + +### score 含义 + +| Score | 含义 | 标签 | 建议 | +|-------|------|------|------| +| 1.0 | 完全正确 | QUALITY_GOOD | 无需处理 | +| 0.0 | 有错误 | QUALITY_BAD_OCR.VISUAL_MISMATCH | 需要修正或重新 OCR | +| 0.5 | 渲染失败 | QUALITY_UNKNOWN.RENDER_FAILED | 检查渲染环境 | + +### reason 字段说明 + +```python +result.reason = [ + "✅ OCR content verified correct", # 或 "❌ OCR content has errors" + "", + "Judge reason:", + "VLM 的详细判断理由...", + "", + "OCR content evaluated:", + "被评估的 OCR 文本内容(前300字)" +] +``` + +--- + +## 🔗 相关资源 + +- **示例脚本**: `examples/ocr/vlm_render_judge.py` +- **测试数据**: `test/data/img_OCR_iterative/` +- **API 文档**: `dingo.model.llm.vlm_render_judge.VLMRenderJudge` +- **相关工具**: + - `RenderTool`: OCR 内容渲染工具 + - `AgentIterativeOCR`: 迭代式 OCR 优化 +- **参考项目**: [MinerU_Metis](https://github.com/opendatalab/MinerU) - Render-Judge 模式的原始实现 diff --git 
a/test/data/img_OCR_iterative/formula/input.png b/test/data/img_OCR_iterative/formula/input.png new file mode 100644 index 00000000..c9da7637 Binary files /dev/null and b/test/data/img_OCR_iterative/formula/input.png differ diff --git a/test/data/img_OCR_iterative/formula/render.png b/test/data/img_OCR_iterative/formula/render.png new file mode 100644 index 00000000..598fd290 Binary files /dev/null and b/test/data/img_OCR_iterative/formula/render.png differ diff --git a/test/data/img_OCR_iterative/simple_text/english.png b/test/data/img_OCR_iterative/simple_text/english.png new file mode 100644 index 00000000..d5475e76 Binary files /dev/null and b/test/data/img_OCR_iterative/simple_text/english.png differ diff --git a/test/data/img_OCR_iterative/simple_text/mixed.png b/test/data/img_OCR_iterative/simple_text/mixed.png new file mode 100644 index 00000000..ffde8634 Binary files /dev/null and b/test/data/img_OCR_iterative/simple_text/mixed.png differ diff --git a/test/data/img_OCR_iterative/simple_text/numbers.png b/test/data/img_OCR_iterative/simple_text/numbers.png new file mode 100644 index 00000000..b869b328 Binary files /dev/null and b/test/data/img_OCR_iterative/simple_text/numbers.png differ diff --git a/test/data/img_OCR_iterative/test_agent_iterative_ocr.jsonl b/test/data/img_OCR_iterative/test_agent_iterative_ocr.jsonl new file mode 100644 index 00000000..e69c173d --- /dev/null +++ b/test/data/img_OCR_iterative/test_agent_iterative_ocr.jsonl @@ -0,0 +1,3 @@ +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "Ths s a smple txt for OCR tstng.", "content_type": "text", "expected_iterations": 2, "description": "多处拼写错误,需要迭代修正"} +{"image": "test/data/img_OCR_iterative/formula/input.png", "content": "E = m c 2", "content_type": "equation", "expected_iterations": 1, "description": "公式格式需要修正"} +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "This is a sample text for OCR testing.", "content_type": "text", "expected_iterations": 0, 
"description": "初始结果正确,无需迭代"} diff --git a/test/data/img_OCR_iterative/test_vlm_render_judge.jsonl b/test/data/img_OCR_iterative/test_vlm_render_judge.jsonl new file mode 100644 index 00000000..5b69dd6f --- /dev/null +++ b/test/data/img_OCR_iterative/test_vlm_render_judge.jsonl @@ -0,0 +1,4 @@ +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "Theorem 4.6. For i ∈ ℤ≥1, the linear operator δ̃ᵢ commute with Ш*", "content_type": "text", "expected_result": "correct", "description": "正确识别数学定理文本"} +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "Theorem 4.6. For i ∈ Z≥1, the linear operator δi commute with Ш*", "content_type": "text", "expected_result": "incorrect", "description": "缺少上标符号(ℤ→Z)和波浪号(δ̃ᵢ→δi)"} +{"image": "test/data/img_OCR_iterative/formula/input.png", "content": "= 2 ∑_{1≤u≤√x} [x/u] - [√x]²\n= 2 ∑_{1≤u≤√x} (x/u - {x/u}) - (√x - {√x})²\n= 2x ∑_{1≤u≤√x} 1/u - 2 ∑_{1≤u≤√x} {x/u} - x + O(√x)\n= 2x (log√x + γ + O(1/√x)) - x + O(√x)\n= x log x + (2γ - 1)x + O(√x).", "content_type": "equation", "expected_result": "correct", "description": "正确识别复杂数学推导公式"} +{"image": "test/data/img_OCR_iterative/formula/input.png", "content": "= 2 ∑ [x/u] - [√x]²\n= 2 ∑ (x/u - {x/u}) - (√x - {√x})²\n= 2x ∑ 1/u - 2 ∑ {x/u} - x + O(√x)\n= 2x (log√x + γ + O(1/√x)) - x + O(√x)\n= x log x + (2γ - 1)x + O(√x).", "content_type": "equation", "expected_result": "incorrect", "description": "缺少求和范围(1≤u≤√x)"} diff --git a/test/data/img_OCR_iterative/test_vlm_render_judge_aligned.jsonl b/test/data/img_OCR_iterative/test_vlm_render_judge_aligned.jsonl new file mode 100644 index 00000000..994bf400 --- /dev/null +++ b/test/data/img_OCR_iterative/test_vlm_render_judge_aligned.jsonl @@ -0,0 +1,6 @@ +{"image": "test/data/img_OCR_iterative/simple_text/english.png", "content": "The quick brown fox jumps over the lazy dog.", "content_type": "text", "expected_result": "correct", "description": "完全正确的OCR结果"} +{"image": 
"test/data/img_OCR_iterative/simple_text/english.png", "content": "The quick brown fox jumps over the lzy dog.", "content_type": "text", "expected_result": "incorrect", "description": "拼写错误:lzy->lazy"} +{"image": "test/data/img_OCR_iterative/simple_text/numbers.png", "content": "Price: $123.45 (Discount: 20%)", "content_type": "text", "expected_result": "correct", "description": "完全正确的数字和符号识别"} +{"image": "test/data/img_OCR_iterative/simple_text/numbers.png", "content": "Price: $12345 (Discount: 20%)", "content_type": "text", "expected_result": "incorrect", "description": "数字错误:12345->123.45"} +{"image": "test/data/img_OCR_iterative/simple_text/mixed.png", "content": "Meeting at 3:00 PM in Room #405", "content_type": "text", "expected_result": "correct", "description": "完全正确的混合文本识别"} +{"image": "test/data/img_OCR_iterative/simple_text/mixed.png", "content": "Meeting at 3:00 PM in Room #45", "content_type": "text", "expected_result": "incorrect", "description": "房间号错误:#45->#405"} diff --git a/test/data/img_OCR_iterative/test_vlm_render_judge_simple.jsonl b/test/data/img_OCR_iterative/test_vlm_render_judge_simple.jsonl new file mode 100644 index 00000000..92306d42 --- /dev/null +++ b/test/data/img_OCR_iterative/test_vlm_render_judge_simple.jsonl @@ -0,0 +1,4 @@ +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "Theorem 4.6. For i ∈ ℤ≥₁, the linear operator δ̃ᵢ commute with ш*", "content_type": "text", "expected_result": "correct", "description": "使用Unicode下标符号₁"} +{"image": "test/data/img_OCR_iterative/text/input.png", "content": "Theorem 4.6. 
For i ∈ Z≥1, the linear operator δi commute with ш*", "content_type": "text", "expected_result": "incorrect", "description": "缺少特殊符号"} +{"image": "test/data/img_OCR_iterative/formula/input.png", "content": "= 2 ∑_{1≤u≤√x} ⌊x/u⌋ - ⌊√x⌋²\n= 2 ∑_{1≤u≤√x} (x/u - {x/u}) - (√x - {√x})²\n= 2x ∑_{1≤u≤√x} 1/u - 2 ∑_{1≤u≤√x} {x/u} - x + O(√x)\n= 2x (log√x + γ + O(1/√x)) - x + O(√x)\n= x log x + (2γ - 1)x + O(√x).", "content_type": "equation", "expected_result": "correct", "description": "使用Floor符号⌊⌋"} +{"image": "test/data/img_OCR_iterative/formula/input.png", "content": "= 2 ∑ [x/u] - [√x]²\n= 2 ∑ (x/u - {x/u}) - (√x - {√x})²\n= 2x ∑ 1/u - 2 ∑ {x/u} - x + O(√x)\n= 2x (log√x + γ + O(1/√x)) - x + O(√x)\n= x log x + (2γ - 1)x + O(√x).", "content_type": "equation", "expected_result": "incorrect", "description": "缺少求和范围且用错误符号"} diff --git a/test/data/img_OCR_iterative/text/input.png b/test/data/img_OCR_iterative/text/input.png new file mode 100644 index 00000000..a6b9b077 Binary files /dev/null and b/test/data/img_OCR_iterative/text/input.png differ diff --git a/test/data/img_OCR_iterative/text/render.png b/test/data/img_OCR_iterative/text/render.png new file mode 100644 index 00000000..4a553d2a Binary files /dev/null and b/test/data/img_OCR_iterative/text/render.png differ diff --git a/test/scripts/model/llm/agent/tools/test_mineru_ocr_tool.py b/test/scripts/model/llm/agent/tools/test_mineru_ocr_tool.py new file mode 100644 index 00000000..830412a3 --- /dev/null +++ b/test/scripts/model/llm/agent/tools/test_mineru_ocr_tool.py @@ -0,0 +1,242 @@ +""" +MinerUOCRTool 单元测试 + +测试 MinerU API OCR 工具的核心功能 + +运行方式: +pytest test/scripts/model/llm/agent/tools/test_mineru_ocr_tool.py -v +""" + +import base64 +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from dingo.model.llm.agent.tools.mineru_ocr_tool import MinerUOCRTool, MinerUOCRToolConfig + + +class TestMinerUOCRToolConfig: + """测试 MinerUOCRToolConfig 配置类""" + + def test_default_values(self): + 
"""测试默认配置值""" + config = MinerUOCRToolConfig() + assert config.api_key is None + assert config.api_url == "https://mineru.net/api/v4/extract/task" + assert config.timeout == 120 + assert config.poll_interval == 3 + + def test_custom_values(self): + """测试自定义配置值""" + config = MinerUOCRToolConfig( + api_key="test_key_123", + timeout=60, + poll_interval=5 + ) + assert config.api_key == "test_key_123" + assert config.timeout == 60 + assert config.poll_interval == 5 + + +class TestMinerUOCRTool: + """测试 MinerUOCRTool 核心功能""" + + def setup_method(self): + """每个测试前的设置""" + MinerUOCRTool.config = MinerUOCRToolConfig( + api_key="test_api_key", + timeout=120, + poll_interval=3 + ) + + def test_tool_attributes(self): + """测试工具的基本属性""" + assert MinerUOCRTool.name == "mineru_ocr_tool" + assert "MinerU API" in MinerUOCRTool.description + assert isinstance(MinerUOCRTool.config, MinerUOCRToolConfig) + + def test_execute_missing_api_key(self): + """测试缺少 API key 的情况""" + MinerUOCRTool.config.api_key = None + + result = MinerUOCRTool.execute(image_path="test.png") + + assert result['success'] is False + assert 'API key not configured' in result['error'] + + def test_execute_missing_image(self): + """测试缺少图像输入的情况""" + result = MinerUOCRTool.execute() + + assert result['success'] is False + assert 'No image provided' in result['error'] + + def test_execute_with_image_path(self): + """测试使用图像路径的情况""" + # 创建临时图像文件 + with tempfile.NamedTemporaryFile(mode='wb', suffix='.png', delete=False) as f: + f.write(b'fake_image_data') + temp_path = f.name + + try: + with patch.object(MinerUOCRTool, '_submit_and_wait') as mock_submit: + mock_submit.return_value = { + 'success': True, + 'content': 'Extracted text', + 'task_id': 'test_task_123' + } + + result = MinerUOCRTool.execute(image_path=temp_path) + + assert result['success'] is True + assert result['content'] == 'Extracted text' + assert result['task_id'] == 'test_task_123' + + # 验证 _submit_and_wait 被调用 + assert mock_submit.called + call_args = 
mock_submit.call_args[0] + assert isinstance(call_args[0], str) # Base64 string + finally: + import os + os.unlink(temp_path) + + def test_execute_with_image_base64(self): + """测试使用 Base64 图像数据的情况""" + fake_base64 = base64.b64encode(b'fake_image_data').decode('utf-8') + + with patch.object(MinerUOCRTool, '_submit_and_wait') as mock_submit: + mock_submit.return_value = { + 'success': True, + 'content': 'OCR result from base64', + 'task_id': 'test_task_456' + } + + result = MinerUOCRTool.execute(image_base64=fake_base64) + + assert result['success'] is True + assert result['content'] == 'OCR result from base64' + assert mock_submit.called + + @patch('dingo.model.llm.agent.tools.mineru_ocr_tool.requests.post') + def test_submit_and_wait_success(self, mock_post): + """测试成功的任务提交""" + # Mock submit response + mock_submit_response = MagicMock() + mock_submit_response.json.return_value = { + "code": 0, + "msg": "success", + "data": {"task_id": "test_task_789"} + } + mock_post.return_value = mock_submit_response + + with patch.object(MinerUOCRTool, '_poll_result') as mock_poll: + mock_poll.return_value = { + 'success': True, + 'content': 'Final OCR result', + 'task_id': 'test_task_789' + } + + result = MinerUOCRTool._submit_and_wait("fake_base64", "text") + + assert result['success'] is True + assert result['content'] == 'Final OCR result' + + # 验证 API 调用 + assert mock_post.called + call_kwargs = mock_post.call_args[1] + assert 'headers' in call_kwargs + assert 'json' in call_kwargs + assert 'Bearer test_api_key' in call_kwargs['headers']['Authorization'] + + @patch('dingo.model.llm.agent.tools.mineru_ocr_tool.requests.post') + def test_submit_and_wait_api_error(self, mock_post): + """测试 API 返回错误的情况""" + mock_response = MagicMock() + mock_response.json.return_value = { + "code": 400, + "msg": "Invalid request", + "data": None + } + mock_post.return_value = mock_response + + result = MinerUOCRTool._submit_and_wait("fake_base64", "text") + + assert result['success'] is False 
+ assert 'MinerU API error' in result['error'] + + @patch('dingo.model.llm.agent.tools.mineru_ocr_tool.requests.get') + def test_poll_result_immediate_success(self, mock_get): + """测试立即成功的任务轮询""" + mock_response = MagicMock() + mock_response.json.return_value = { + "code": 0, + "data": { + "status": "success", + "markdown": "# OCR Result\n\nExtracted content here." + } + } + mock_get.return_value = mock_response + + headers = {"Authorization": "Bearer test_key"} + result = MinerUOCRTool._poll_result("task_123", headers) + + assert result['success'] is True + assert result['content'] == "# OCR Result\n\nExtracted content here." + assert result['task_id'] == "task_123" + + def test_poll_result_task_failed(self): + """测试任务失败的情况""" + mock_response = MagicMock() + mock_response.json.return_value = { + "code": 0, + "data": { + "status": "failed", + "msg": "OCR processing failed" + } + } + + with patch('dingo.model.llm.agent.tools.mineru_ocr_tool.requests.get', return_value=mock_response): + headers = {"Authorization": "Bearer test_key"} + result = MinerUOCRTool._poll_result("task_789", headers) + + assert result['success'] is False + assert 'task failed' in result['error'].lower() + + def test_extract_content_markdown(self): + """测试从 markdown 字段提取内容""" + data = {"markdown": "# Title\n\nContent here."} + content = MinerUOCRTool._extract_content(data) + assert content == "# Title\n\nContent here." 
+ + def test_extract_content_text(self): + """测试从 text 字段提取内容""" + data = {"text": "Plain text content"} + content = MinerUOCRTool._extract_content(data) + assert content == "Plain text content" + + def test_extract_content_pages(self): + """测试从多页结果提取内容""" + data = { + "pages": [ + {"markdown": "Page 1"}, + {"markdown": "Page 2"} + ] + } + content = MinerUOCRTool._extract_content(data) + assert "Page 1" in content + assert "Page 2" in content + + def test_execute_exception_handling(self): + """测试异常处理""" + with patch.object(MinerUOCRTool, '_submit_and_wait') as mock_submit: + mock_submit.side_effect = Exception("Unexpected error") + + result = MinerUOCRTool.execute(image_base64="fake_base64") + + assert result['success'] is False + assert 'Unexpected error' in result['error'] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/scripts/model/llm/agent/tools/test_render_tool.py b/test/scripts/model/llm/agent/tools/test_render_tool.py new file mode 100644 index 00000000..8b421001 --- /dev/null +++ b/test/scripts/model/llm/agent/tools/test_render_tool.py @@ -0,0 +1,375 @@ +""" +RenderTool 单元测试 + +测试 OCR 内容渲染工具的核心功能: +1. 文本渲染 +2. LaTeX 公式渲染 +3. 配置管理 +4. 
错误处理 + +运行方式: +pytest test/scripts/model/llm/agent/tools/test_render_tool.py -v +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from dingo.model.llm.agent.tools.render_tool import RenderTool, RenderToolConfig + + +class TestRenderToolConfig: + """测试 RenderTool 配置""" + + def test_default_values(self): + """测试默认配置值""" + config = RenderToolConfig() + assert config.font_path is None + assert config.density == 150 + assert config.timeout == 60 + assert config.pad == 20 + + def test_custom_values(self): + """测试自定义配置值""" + config = RenderToolConfig( + font_path="/path/to/font.ttc", + density=300, + timeout=120, + pad=30 + ) + assert config.font_path == "/path/to/font.ttc" + assert config.density == 300 + assert config.timeout == 120 + assert config.pad == 30 + + def test_density_validation(self): + """测试 density 必须在 72-300 之间""" + # 有效值 + RenderToolConfig(density=72) + RenderToolConfig(density=300) + RenderToolConfig(density=150) + + # 无效值 + with pytest.raises(ValueError): + RenderToolConfig(density=71) + + with pytest.raises(ValueError): + RenderToolConfig(density=301) + + def test_pad_validation(self): + """测试 pad 必须 >= 0""" + # 有效值 + RenderToolConfig(pad=0) + RenderToolConfig(pad=20) + + # 无效值 + with pytest.raises(ValueError): + RenderToolConfig(pad=-1) + + +class TestRenderTool: + """测试 RenderTool 核心功能""" + + def setup_method(self): + """每个测试前的设置""" + RenderTool.config = RenderToolConfig() + + def test_tool_attributes(self): + """测试工具的基本属性""" + assert RenderTool.name == "render_tool" + assert "Render text, equations, or tables" in RenderTool.description + assert isinstance(RenderTool.config, RenderToolConfig) + + def test_empty_content(self): + """测试空内容返回错误""" + result = RenderTool.execute(content="", content_type="text") + assert result['success'] is False + assert 'empty' in result['error'].lower() + + def test_whitespace_content(self): + """测试纯空格内容返回错误""" + result = RenderTool.execute(content=" ", content_type="text") + assert 
result['success'] is False + assert 'empty' in result['error'].lower() + + def test_invalid_content_type(self): + """测试无效的 content_type - 会默认为 text""" + result = RenderTool.execute(content="test", content_type="invalid_type") + # RenderTool 不验证 content_type,无效类型会默认为 text 渲染 + assert result['success'] is True + assert result['content_type'] == "invalid_type" + + @patch('dingo.model.llm.agent.tools.render_tool.HAS_PIL', True) + @patch('dingo.model.llm.agent.tools.render_tool.Image') + @patch('dingo.model.llm.agent.tools.render_tool.ImageDraw') + @patch('dingo.model.llm.agent.tools.render_tool.ImageFont') + def test_render_text_success(self, mock_font, mock_draw, mock_image): + """测试文本渲染成功""" + # Mock PIL objects + mock_img = MagicMock() + mock_img.size = (200, 100) + mock_image.new.return_value = mock_img + + mock_draw_obj = MagicMock() + mock_draw_obj.textbbox.return_value = (0, 0, 100, 20) + mock_draw.Draw.return_value = mock_draw_obj + + mock_font.truetype.return_value = MagicMock() + + # Mock image to base64 conversion + with patch('io.BytesIO') as mock_bytesio: + mock_buffer = MagicMock() + mock_bytesio.return_value = mock_buffer + mock_buffer.getvalue.return_value = b'fake_image_data' + + result = RenderTool.execute( + content="Hello World", + content_type="text" + ) + + assert result['success'] is True + assert 'image_base64' in result + assert result['content_type'] == 'text' + + @patch('dingo.model.llm.agent.tools.render_tool.HAS_PIL', False) + def test_render_text_no_pil(self): + """测试 PIL 未安装的情况""" + result = RenderTool.execute(content="test", content_type="text") + assert result['success'] is False + assert 'PIL' in result['error'] + + @patch('dingo.model.llm.agent.tools.render_tool.subprocess.run') + @patch('dingo.model.llm.agent.tools.render_tool.os.path.exists') + @patch('dingo.model.llm.agent.tools.render_tool.Image') + @patch('tempfile.mkdtemp') + def test_render_latex_success(self, mock_mkdtemp, mock_image, mock_exists, mock_subprocess): + """测试 
LaTeX 渲染成功""" + # Mock temporary directory + mock_mkdtemp.return_value = '/tmp/test_dir' + + # Mock file existence checks + mock_exists.side_effect = lambda path: True + + # Mock subprocess (xelatex and magick) + mock_process = MagicMock() + mock_process.returncode = 0 + mock_subprocess.return_value = mock_process + + # Mock image loading + mock_img = MagicMock() + mock_image.open.return_value = mock_img + + # Mock image to base64 + with patch('io.BytesIO') as mock_bytesio: + mock_buffer = MagicMock() + mock_bytesio.return_value = mock_buffer + mock_buffer.getvalue.return_value = b'fake_image_data' + + with patch('shutil.rmtree'): + result = RenderTool.execute( + content="E = mc^2", + content_type="equation" + ) + + # LaTeX 渲染需要环境支持,可能失败 + # 这里主要测试代码路径是否正常 + assert 'success' in result + + @patch('dingo.model.llm.agent.tools.render_tool.subprocess.run') + @patch('tempfile.mkdtemp') + def test_render_latex_xelatex_not_found(self, mock_mkdtemp, mock_subprocess): + """测试 xelatex 未安装的情况""" + mock_mkdtemp.return_value = '/tmp/test_dir' + + # Mock xelatex not found + mock_subprocess.side_effect = FileNotFoundError("xelatex not found") + + with patch('shutil.rmtree'): + result = RenderTool.execute( + content="E = mc^2", + content_type="equation" + ) + + assert result['success'] is False + # 错误消息是 "failed to render equation content" + assert 'failed' in result['error'].lower() + assert 'equation' in result['error'].lower() + + def test_render_with_output_path(self): + """测试保存到指定文件路径""" + with patch('dingo.model.llm.agent.tools.render_tool.RenderTool._render_text') as mock_render: + # Mock successful render + mock_img = MagicMock() + mock_render.return_value = mock_img + + output_path = "/tmp/test_output.png" + result = RenderTool.execute( + content="test", + content_type="text", + output_path=output_path + ) + + if result['success']: + # save 会被调用两次:一次保存到 BytesIO,一次保存到文件 + assert mock_img.save.call_count == 2 + assert result['image_path'] == output_path + + def 
test_update_config(self): + """测试更新配置""" + original_density = RenderTool.config.density + original_pad = RenderTool.config.pad + + # 更新配置 + RenderTool.update_config({ + 'density': 200, + 'pad': 30 + }) + + assert RenderTool.config.density == 200 + assert RenderTool.config.pad == 30 + + # 恢复配置 + RenderTool.config.density = original_density + RenderTool.config.pad = original_pad + + def test_multiline_text_handling(self): + """测试多行文本处理""" + multiline_content = "Line 1\nLine 2\nLine 3" + + with patch('dingo.model.llm.agent.tools.render_tool.RenderTool._render_text') as mock_render: + mock_img = MagicMock() + mock_render.return_value = mock_img + + with patch('io.BytesIO') as mock_bytesio: + mock_buffer = MagicMock() + mock_bytesio.return_value = mock_buffer + mock_buffer.getvalue.return_value = b'fake_data' + + result = RenderTool.execute( + content=multiline_content, + content_type="text" + ) + + # 应该调用 _render_text 并传入多行内容 + if result['success']: + mock_render.assert_called_once() + call_args = mock_render.call_args[0] + assert '\n' in call_args[0] + + def test_font_fallback(self): + """测试字体加载失败时的 fallback""" + # 使用不存在的字体路径,测试是否能 fallback 到系统字体或默认字体 + RenderTool.config.font_path = "/nonexistent/path/to/font.ttf" + + result = RenderTool.execute( + content="test fallback", + content_type="text" + ) + + # 即使指定的字体不存在,也应该能够渲染成功(使用 fallback) + assert result['success'] is True + assert 'image_base64' in result + + def test_render_special_characters(self): + """测试特殊字符渲染""" + special_content = "Price: $123.45 (Discount: 20%)" + + with patch('dingo.model.llm.agent.tools.render_tool.RenderTool._render_text') as mock_render: + mock_img = MagicMock() + mock_render.return_value = mock_img + + with patch('io.BytesIO') as mock_bytesio: + mock_buffer = MagicMock() + mock_bytesio.return_value = mock_buffer + mock_buffer.getvalue.return_value = b'fake_data' + + result = RenderTool.execute( + content=special_content, + content_type="text" + ) + + if result['success']: + # 应该能处理特殊字符 + 
call_args = mock_render.call_args[0] + assert '$' in call_args[0] + assert '%' in call_args[0] + assert ':' in call_args[0] + + def test_unicode_arrow_symbols(self): + """测试 Unicode 箭头符号 (wasysym 包支持)""" + result = RenderTool.execute( + content="价格: $123.45 ◄ 原价: $200.00", + content_type="text" + ) + + # 如果环境支持,应该能成功渲染 + assert 'success' in result + if result['success']: + assert 'image_base64' in result + assert result['content_type'] == 'text' + + def test_greek_letters_equation(self): + """测试正体希腊字母 (upgreek 包支持)""" + result = RenderTool.execute( + content="α + β = γ", + content_type="equation" + ) + + # LaTeX 渲染需要环境支持 + assert 'success' in result + # 只验证返回结构,不强制要求成功(依赖环境) + + def test_copyright_symbol(self): + """测试版权符号 (textcomp 包支持)""" + result = RenderTool.execute( + content="版权所有 © 2026", + content_type="text" + ) + + assert 'success' in result + if result['success']: + assert 'image_base64' in result + + def test_large_matrix_support(self): + """测试大矩阵支持 (MaxMatrixCols=1000)""" + # 创建一个 50 列的矩阵(超过默认的 10 列限制) + matrix_content = "\\begin{pmatrix}" + " & ".join([str(i) for i in range(50)]) + "\\end{pmatrix}" + + result = RenderTool.execute( + content=matrix_content, + content_type="equation" + ) + + # 验证能处理大矩阵而不报错 + assert 'success' in result + # LaTeX 编译依赖环境,只验证结构 + + def test_extended_cjk_characters(self): + """测试扩展 CJK 字符范围支持""" + # 测试包含罕见汉字和符号的内容 + result = RenderTool.execute( + content="𠮷野家:讃岐うどん", # 包含扩展 B 区汉字 + content_type="text" + ) + + assert 'success' in result + if result['success']: + assert 'image_base64' in result + + def test_mixed_unicode_content(self): + """测试混合 Unicode 内容 (综合测试)""" + mixed_content = "价格 $99.99 ◄ 折扣 20% • 版权 © 2026 ★ α=0.5" + + result = RenderTool.execute( + content=mixed_content, + content_type="text" + ) + + assert 'success' in result + assert 'content_type' in result + # 验证内容类型正确传递 + assert result['content_type'] == 'text' + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git 
a/test/scripts/model/llm/test_vlm_render_judge.py b/test/scripts/model/llm/test_vlm_render_judge.py new file mode 100644 index 00000000..ae4a1a08 --- /dev/null +++ b/test/scripts/model/llm/test_vlm_render_judge.py @@ -0,0 +1,340 @@ +""" +VLMRenderJudge 单元测试 + +测试基于视觉渲染的 OCR 质量评估功能: +1. 成功的视觉比较(OCR 正确) +2. 失败的视觉比较(OCR 错误) +3. 渲染失败的处理 +4. 响应解析 + +运行方式: +pytest test/scripts/model/llm/test_vlm_render_judge.py -v +""" + +from unittest.mock import patch + +import pytest + +from dingo.config.input_args import EvaluatorLLMArgs +from dingo.io import Data +from dingo.model.llm.vlm_render_judge import VLMRenderJudge + + +class TestVLMRenderJudge: + """测试 VLMRenderJudge 核心功能""" + + def setup_method(self): + """每个测试前的设置""" + VLMRenderJudge.dynamic_config = EvaluatorLLMArgs( + model="gpt-4o", + key="test-api-key", + api_url="https://api.openai.com/v1" + ) + + def test_parse_response_correct(self): + """测试解析 VLM 判断为正确的响应""" + response = """ + + Both images show the same text content with consistent formatting. + The symbols, characters, and layout are fully consistent. + + true + """ + + result = VLMRenderJudge._parse_response(response) + + assert result['is_correct'] is True + assert 'consistent' in result['reason'].lower() + + def test_parse_response_incorrect(self): + """测试解析 VLM 判断为错误的响应""" + response = """ + + GT has "lazy" while OCR has "lzy". The character "a" is missing. + This is an actual character omission. + + false + """ + + result = VLMRenderJudge._parse_response(response) + + assert result['is_correct'] is False + assert 'missing' in result['reason'].lower() + + def test_parse_response_with_true_false_both(self): + """测试包含 true 和 false 的响应(应判断为 false)""" + response = """ + The text says "true to life" but the OCR result is false. 
+ false + """ + + result = VLMRenderJudge._parse_response(response) + + # "false" 在 answer 中,应该判为 false + assert result['is_correct'] is False + + def test_parse_response_no_xml_tags(self): + """测试没有 XML 标签的响应""" + response = "The OCR result looks correct to me." + + result = VLMRenderJudge._parse_response(response) + + # 没有 answer 标签,应判为 false + assert result['is_correct'] is False + assert len(result['reason']) > 0 + + def test_build_result_correct(self): + """测试构建正确的评估结果""" + judge_result = { + 'is_correct': True, + 'reason': 'Both images are consistent' + } + + result = VLMRenderJudge._build_result(judge_result, "test content") + + assert result.score == 1.0 + assert result.metric == "VLMRenderJudge" + assert "QUALITY_GOOD" in result.label + assert any("✅" in reason for reason in result.reason) + + def test_build_result_incorrect(self): + """测试构建错误的评估结果""" + judge_result = { + 'is_correct': False, + 'reason': 'OCR has missing characters' + } + + result = VLMRenderJudge._build_result(judge_result, "test content") + + assert result.score == 0.0 + assert result.metric == "VLMRenderJudge" + assert "QUALITY_BAD_OCR.VISUAL_MISMATCH" in result.label + assert any("❌" in reason for reason in result.reason) + assert any("test content" in reason for reason in result.reason) + + def test_text_only_comparison_fallback(self): + """测试渲染失败时的 fallback""" + result = VLMRenderJudge._text_only_comparison( + "test_image.png", + "test content" + ) + + assert result.score == 0.5 + assert result.status is True + assert "QUALITY_UNKNOWN.RENDER_FAILED" in result.label + assert any("Could not render" in reason for reason in result.reason) + assert result.metric == "VLMRenderJudge" + + def test_get_image_from_data(self): + """测试从 Data 对象提取图片""" + # 单个图片路径 + data = Data(image="test/image.png") + image = VLMRenderJudge._get_image(data) + assert image == "test/image.png" + + # 图片列表(取第一个) + data = Data(image=["test/image1.png", "test/image2.png"]) + image = VLMRenderJudge._get_image(data) + 
assert image == "test/image1.png" + + # 没有图片 + data = Data(content="text only") + image = VLMRenderJudge._get_image(data) + assert image is None + + def test_get_content_from_data(self): + """测试从 Data 对象提取内容""" + data = Data(content="The quick brown fox") + content = VLMRenderJudge._get_content(data) + assert content == "The quick brown fox" + + # 没有内容 + data = Data(image="test.png") + content = VLMRenderJudge._get_content(data) + assert content is None + + def test_get_content_type(self): + """测试获取内容类型""" + # 从 Data 对象 + data = Data(content="test", content_type="equation") + content_type = VLMRenderJudge._get_content_type(data) + assert content_type == "equation" + + # 从配置参数 + VLMRenderJudge.dynamic_config = EvaluatorLLMArgs( + model="gpt-4o", + key="test-key", + parameters={"content_type": "table"} + ) + data = Data(content="test") + content_type = VLMRenderJudge._get_content_type(data) + assert content_type == "table" + + # 默认值 + VLMRenderJudge.dynamic_config.parameters = None + data = Data(content="test") + content_type = VLMRenderJudge._get_content_type(data) + assert content_type == "text" + + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge._judge') + @patch('dingo.model.llm.agent.tools.render_tool.RenderTool.execute') + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge.create_client') + def test_eval_success_correct_ocr(self, mock_create_client, mock_render, mock_judge): + """测试完整评估流程 - OCR 正确""" + # Mock client creation (do nothing) + mock_create_client.return_value = None + + # Mock 渲染成功 + mock_render.return_value = { + 'success': True, + 'image_base64': 'base64_encoded_image_data' + } + + # Mock VLM 判断为正确 + mock_judge.return_value = { + 'is_correct': True, + 'reason': 'Both images are consistent' + } + + data = Data( + image="test/image.png", + content="The quick brown fox jumps over the lazy dog.", + content_type="text" + ) + + result = VLMRenderJudge.eval(data) + + # 验证结果 + assert result.score == 1.0 + assert result.metric == 
"VLMRenderJudge" + assert "QUALITY_GOOD" in result.label + assert mock_render.called + assert mock_judge.called + + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge._judge') + @patch('dingo.model.llm.agent.tools.render_tool.RenderTool.execute') + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge.create_client') + def test_eval_success_incorrect_ocr(self, mock_create_client, mock_render, mock_judge): + """测试完整评估流程 - OCR 错误""" + # Mock client creation (do nothing) + mock_create_client.return_value = None + + # Mock 渲染成功 + mock_render.return_value = { + 'success': True, + 'image_base64': 'base64_encoded_image_data' + } + + # Mock VLM 判断为错误 + mock_judge.return_value = { + 'is_correct': False, + 'reason': 'GT has "lazy" but OCR has "lzy", missing character "a"' + } + + data = Data( + image="test/image.png", + content="The quick brown fox jumps over the lzy dog.", + content_type="text" + ) + + result = VLMRenderJudge.eval(data) + + # 验证结果 + assert result.score == 0.0 + assert result.metric == "VLMRenderJudge" + assert "QUALITY_BAD_OCR.VISUAL_MISMATCH" in result.label + assert any("missing" in reason.lower() for reason in result.reason) + + @patch('dingo.model.llm.agent.tools.render_tool.RenderTool.execute') + def test_eval_render_failed(self, mock_render): + """测试渲染失败的情况""" + # Mock 渲染失败 + mock_render.return_value = { + 'success': False, + 'error': 'LaTeX compilation failed' + } + + data = Data( + image="test/image.png", + content="E = mc^2", + content_type="equation" + ) + + result = VLMRenderJudge.eval(data) + + # 验证 fallback 结果 + assert result.score == 0.5 + assert result.status is True + assert "QUALITY_UNKNOWN.RENDER_FAILED" in result.label + assert mock_render.called + + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge.create_client') + def test_eval_missing_image(self, mock_create_client): + """测试缺少图片的情况""" + mock_create_client.return_value = None + + data = Data(content="test content") + + result = VLMRenderJudge.eval(data) + + assert 
result.status is True + assert "QUALITY_BAD" in result.label[0] + assert any("image" in reason.lower() for reason in result.reason) + + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge.create_client') + def test_eval_missing_content(self, mock_create_client): + """测试缺少内容的情况""" + mock_create_client.return_value = None + + data = Data(image="test/image.png") + + result = VLMRenderJudge.eval(data) + + assert result.status is True + assert "QUALITY_BAD" in result.label[0] + assert any("content" in reason.lower() for reason in result.reason) + + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge._judge') + @patch('dingo.model.llm.agent.tools.render_tool.RenderTool.execute') + @patch('dingo.model.llm.vlm_render_judge.VLMRenderJudge.create_client') + def test_eval_with_render_config(self, mock_create_client, mock_render, mock_judge): + """测试使用自定义渲染配置""" + mock_create_client.return_value = None + + # 设置渲染配置 + VLMRenderJudge.dynamic_config = EvaluatorLLMArgs( + model="gpt-4o", + key="test-key", + api_url="https://api.openai.com/v1", + parameters={ + "render_config": { + "density": 300, + "pad": 30 + } + } + ) + + mock_render.return_value = { + 'success': True, + 'image_base64': 'base64_data' + } + mock_judge.return_value = { + 'is_correct': True, + 'reason': 'OK' + } + + data = Data( + image="test/image.png", + content="test", + content_type="equation" + ) + + result = VLMRenderJudge.eval(data) + + # 验证 RenderTool.update_config 被调用 + # (通过检查结果正常即可) + assert result.score == 1.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])