From e9768ae6a5910e19ed95917c1888651289a70720 Mon Sep 17 00:00:00 2001 From: Yusik Kim <107410898+kmyusk@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:35:29 +0100 Subject: [PATCH] chore: expose draw_clusters function (#803) feat: expose draw_clusters function add type annotations to function signature Signed-off-by: Yusik Kim --- docling/models/layout_model.py | 94 ++++------------------------------ docling/utils/visualization.py | 80 +++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 85 deletions(-) create mode 100644 docling/utils/visualization.py diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 9fa0ecb4..69193c94 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -1,28 +1,21 @@ import copy import logging -import random -import time from pathlib import Path -from typing import Iterable, List +from typing import Iterable -from docling_core.types.doc import CoordOrigin, DocItemLabel +from docling_core.types.doc import DocItemLabel from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor -from PIL import Image, ImageDraw, ImageFont - -from docling.datamodel.base_models import ( - BoundingBox, - Cell, - Cluster, - LayoutPrediction, - Page, -) +from PIL import Image + +from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions +from docling.datamodel.pipeline_options import AcceleratorOptions from docling.datamodel.settings import settings from docling.models.base_model import BasePageModel from docling.utils.accelerator_utils import decide_device from docling.utils.layout_postprocessor import LayoutPostprocessor from docling.utils.profiling import TimeRecorder +from docling.utils.visualization import draw_clusters _log = logging.getLogger(__name__) @@ -82,78 +75,9 @@ def 
draw_clusters_and_cells_side_by_side( left_image = copy.deepcopy(page.image) right_image = copy.deepcopy(page.image) - # Function to draw clusters on an image - def draw_clusters(image, clusters): - draw = ImageDraw.Draw(image, "RGBA") - # Create a smaller font for the labels - try: - font = ImageFont.truetype("arial.ttf", 12) - except OSError: - # Fallback to default font if arial is not available - font = ImageFont.load_default() - for c_tl in clusters: - all_clusters = [c_tl, *c_tl.children] - for c in all_clusters: - # Draw cells first (underneath) - cell_color = (0, 0, 0, 40) # Transparent black for cells - for tc in c.cells: - cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() - cx0 *= scale_x - cx1 *= scale_x - cy0 *= scale_x - cy1 *= scale_y - - draw.rectangle( - [(cx0, cy0), (cx1, cy1)], - outline=None, - fill=cell_color, - ) - # Draw cluster rectangle - x0, y0, x1, y1 = c.bbox.as_tuple() - x0 *= scale_x - x1 *= scale_x - y0 *= scale_x - y1 *= scale_y - - cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70) - cluster_outline_color = ( - *list(DocItemLabel.get_color(c.label)), - 255, - ) - draw.rectangle( - [(x0, y0), (x1, y1)], - outline=cluster_outline_color, - fill=cluster_fill_color, - ) - # Add label name and confidence - label_text = f"{c.label.name} ({c.confidence:.2f})" - # Create semi-transparent background for text - text_bbox = draw.textbbox((x0, y0), label_text, font=font) - text_bg_padding = 2 - draw.rectangle( - [ - ( - text_bbox[0] - text_bg_padding, - text_bbox[1] - text_bg_padding, - ), - ( - text_bbox[2] + text_bg_padding, - text_bbox[3] + text_bg_padding, - ), - ], - fill=(255, 255, 255, 180), # Semi-transparent white - ) - # Draw text - draw.text( - (x0, y0), - label_text, - fill=(0, 0, 0, 255), # Solid black - font=font, - ) - # Draw clusters on both images - draw_clusters(left_image, left_clusters) - draw_clusters(right_image, right_clusters) + draw_clusters(left_image, left_clusters, scale_x, scale_y) + draw_clusters(right_image, 
def _scale_bbox(
    bbox_tuple: tuple[float, float, float, float], scale_x: float, scale_y: float
) -> tuple[float, float, float, float]:
    """Scale an ``(x0, y0, x1, y1)`` tuple: x values by scale_x, y values by scale_y."""
    x0, y0, x1, y1 = bbox_tuple
    return x0 * scale_x, y0 * scale_y, x1 * scale_x, y1 * scale_y


def draw_clusters(
    image: Image.Image, clusters: list[Cluster], scale_x: float, scale_y: float
) -> None:
    """
    Draw clusters on an image.

    The image is modified in place through an "RGBA" ``ImageDraw`` context;
    nothing is returned. For every cluster in ``clusters`` and each of its
    direct ``children`` the function draws, in this order:

    1. a translucent black rectangle for every cell of the cluster,
    2. the cluster rectangle, filled and outlined in the colour returned by
       ``DocItemLabel.get_color`` for the cluster label,
    3. a ``"<label name> (<confidence, 2 decimals>)"`` caption on a
       semi-transparent white background at the top-left corner of the box.

    Args:
        image: Image to draw on.
        clusters: Top-level clusters; only one level of ``children`` is
            visited per cluster.
        scale_x: Factor applied to every x coordinate of the bounding boxes.
        scale_y: Factor applied to every y coordinate of the bounding boxes.
    """
    draw = ImageDraw.Draw(image, "RGBA")
    # Create a smaller font for the labels
    font: ImageFont.ImageFont | FreeTypeFont
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:
        # Fallback to default font if arial is not available
        font = ImageFont.load_default()

    # Loop-invariant drawing constants.
    cell_color = (0, 0, 0, 40)  # Transparent black for cells
    text_bg_color = (255, 255, 255, 180)  # Semi-transparent white
    text_color = (0, 0, 0, 255)  # Solid black
    text_bg_padding = 2

    for c_tl in clusters:
        for c in [c_tl, *c_tl.children]:
            # Draw cells first (underneath)
            for tc in c.cells:
                # Both y coordinates must use scale_y; scaling the top edge by
                # scale_x misplaces the box whenever the two factors differ.
                cx0, cy0, cx1, cy1 = _scale_bbox(
                    tc.bbox.as_tuple(), scale_x, scale_y
                )
                draw.rectangle(
                    [(cx0, cy0), (cx1, cy1)],
                    outline=None,
                    fill=cell_color,
                )

            # Draw cluster rectangle
            x0, y0, x1, y1 = _scale_bbox(c.bbox.as_tuple(), scale_x, scale_y)

            # Look the label colour up once; an alpha channel is appended for
            # the fill (70) and the outline (255).
            label_color = tuple(DocItemLabel.get_color(c.label))
            draw.rectangle(
                [(x0, y0), (x1, y1)],
                outline=(*label_color, 255),
                fill=(*label_color, 70),
            )

            # Add label name and confidence
            label_text = f"{c.label.name} ({c.confidence:.2f})"
            # Create semi-transparent background for text
            text_bbox = draw.textbbox((x0, y0), label_text, font=font)
            draw.rectangle(
                [
                    (
                        text_bbox[0] - text_bg_padding,
                        text_bbox[1] - text_bg_padding,
                    ),
                    (
                        text_bbox[2] + text_bg_padding,
                        text_bbox[3] + text_bg_padding,
                    ),
                ],
                fill=text_bg_color,
            )
            # Draw text
            draw.text(
                (x0, y0),
                label_text,
                fill=text_color,
                font=font,
            )