Skip to content

perf: Speed up method LayoutPostprocessor._process_special_clusters by 653% #1952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

mohammedahmed18
Copy link

πŸ“„ 653% (6.53x) speedup for LayoutPostprocessor._process_special_clusters in docling/utils/layout_postprocessor.py

⏱️ Runtime : 236 milliseconds β†’ 31.3 milliseconds (best of 43 runs)

πŸ“ Explanation and details

Here are targeted optimizations based on the profiling output and the code.

Major bottlenecks & optimization strategies

1. _process_special_clusters:

  • Main bottleneck:
    • The nested loop: for each special cluster, loop through all regular clusters and compute .bbox.intersection_over_self(special.bbox).
    • This is O(N*M) for N special and M regular clusters and is by far the slowest part.
  • Optimization:
    • Pre-index regular clusters by bounding box for fast containment:
      • Build a simple R-tree-like spatial grid (using bins, or just a fast bbox filtering pass) to filter out regular clusters that are definitely non-overlapping before running the expensive geometric calculation.
    • If spatial index unavailable: Pre-filter regulars to those whose bbox intersects the special’s bbox (quick min/max bbox checks), greatly reducing pairwise calculations.

2. _handle_cross_type_overlaps:

  • Similar bottleneck: Again, checking every regular cluster for every wrapper.
    • We can apply the same bbox quick-check.

3. Miscellaneous.

  • _deduplicate_cells/_sort_cells optimizations: Minor, but batch sort/unique patterns can help.
  • Avoid recomputation: Avoid recomputing thresholds/constants in hot loops.

Below is the optimized code addressing the biggest O(N*M) loop, using fast bbox intersection check for quick rejection before expensive calculation.
We achieve this purely with local logic in the function (no external indices needed), and respect your constraint not to introduce module-level classes.
Comments in the code indicate all changes.

Summary of changes:

  • For both _process_special_clusters and _handle_cross_type_overlaps, we avoid unnecessary .intersection_over_self calculations by pre-filtering clusters based on simple bbox intersection conditions (l < rx and r > lx and t < by and b > ty).
  • This turns expensive O(N*M) geometric checks into a two-stage filter, which is extremely fast for typical bbox distributions.
  • All hot-spot loops now use local variables rather than repeated attribute lookups.
  • No changes are made to APIs, outputs, or major logic branches; only faster candidate filtering is introduced.

This should reduce total runtime of _process_special_clusters and _handle_cross_type_overlaps by an order of magnitude on large documents.

βœ… Correctness verification report:

Test Status
βš™οΈ Existing Unit Tests πŸ”˜ None Found
πŸŒ€ Generated Regression Tests βœ… 61 Passed
βͺ Replay Tests πŸ”˜ None Found
πŸ”Ž Concolic Coverage Tests πŸ”˜ None Found
πŸ“Š Tests Coverage 100.0%
πŸŒ€ Generated Regression Tests and Runtime
from enum import Enum, auto
from typing import List, Optional

# imports
import pytest
from docling.utils.layout_postprocessor import LayoutPostprocessor

# --- Minimal stubs for required classes and types ---

class DocItemLabel(Enum):
    CAPTION = auto()
    FOOTNOTE = auto()
    FORMULA = auto()
    LIST_ITEM = auto()
    PAGE_FOOTER = auto()
    PAGE_HEADER = auto()
    PICTURE = auto()
    SECTION_HEADER = auto()
    TABLE = auto()
    TEXT = auto()
    TITLE = auto()
    CODE = auto()
    CHECKBOX_SELECTED = auto()
    CHECKBOX_UNSELECTED = auto()
    FORM = auto()
    KEY_VALUE_REGION = auto()
    DOCUMENT_INDEX = auto()

class BoundingBox:
    def __init__(self, l, t, r, b):
        self.l = l
        self.t = t
        self.r = r
        self.b = b

    def area(self):
        width = max(0, self.r - self.l)
        height = max(0, self.b - self.t)
        return width * height

    def as_tuple(self):
        return (self.l, self.t, self.r, self.b)

    def intersection(self, other: 'BoundingBox'):
        l = max(self.l, other.l)
        t = max(self.t, other.t)
        r = min(self.r, other.r)
        b = min(self.b, other.b)
        if l < r and t < b:
            return BoundingBox(l, t, r, b)
        return None

    def intersection_area(self, other: 'BoundingBox'):
        inter = self.intersection(other)
        return inter.area() if inter else 0

    def intersection_over_union(self, other: 'BoundingBox'):
        inter = self.intersection_area(other)
        union = self.area() + other.area() - inter
        return inter / union if union > 0 else 0

    def intersection_over_self(self, other: 'BoundingBox'):
        inter = self.intersection_area(other)
        return inter / self.area() if self.area() > 0 else 0

    def __eq__(self, other):
        return (self.l, self.t, self.r, self.b) == (other.l, other.t, other.r, other.b)

    def __repr__(self):
        return f"BoundingBox({self.l}, {self.t}, {self.r}, {self.b})"

class TextCell:
    def __init__(self, index, text=""):
        self.index = index
        self.text = text

    def __eq__(self, other):
        return self.index == other.index and self.text == other.text

    def __repr__(self):
        return f"TextCell({self.index}, '{self.text}')"

class Cluster:
    def __init__(self, id, label, bbox, confidence, cells=None):
        self.id = id
        self.label = label
        self.bbox = bbox
        self.confidence = confidence
        self.cells = cells if cells is not None else []
        self.children = []

    def __eq__(self, other):
        return (
            self.id == other.id
            and self.label == other.label
            and self.bbox == other.bbox
            and abs(self.confidence - other.confidence) < 1e-6
            and self.cells == other.cells
            and self.children == other.children
        )

    def __repr__(self):
        return (
            f"Cluster(id={self.id}, label={self.label}, bbox={self.bbox}, "
            f"confidence={self.confidence:.2f}, cells={self.cells}, children={self.children})"
        )

class Page:
    def __init__(self, cells: List[TextCell], size):
        self.cells = cells
        self.size = size

class PageSize:
    def __init__(self, width, height):
        self.width = width
        self.height = height

class LayoutOptions:
    pass

# --- Unit Tests ---

# Helper to create clusters and page
def make_page_and_clusters(
    special_clusters: List[Cluster],
    regular_clusters: Optional[List[Cluster]] = None,
    page_size=(1000, 1000),
    cells: Optional[List[TextCell]] = None,
):
    if cells is None:
        cells = []
        for c in special_clusters + (regular_clusters or []):
            cells.extend(c.cells)
        # Remove duplicates
        seen = set()
        dedup_cells = []
        for cell in cells:
            if cell.index not in seen:
                seen.add(cell.index)
                dedup_cells.append(cell)
        cells = dedup_cells
    page = Page(cells, PageSize(*page_size))
    all_clusters = (regular_clusters or []) + special_clusters
    return page, all_clusters

# --- Basic Test Cases ---

def test_empty_special_clusters():
    # No clusters at all
    page, clusters = make_page_and_clusters([])
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.20ΞΌs -> 2.21ΞΌs (0.543% slower)

def test_special_clusters_below_confidence():
    # All special clusters below threshold
    c1 = Cluster(1, DocItemLabel.PICTURE, BoundingBox(10, 10, 20, 20), 0.2, [])
    c2 = Cluster(2, DocItemLabel.FORM, BoundingBox(30, 30, 40, 40), 0.3, [])
    page, clusters = make_page_and_clusters([c1, c2])
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.16ΞΌs -> 2.92ΞΌs (25.8% slower)

def test_special_clusters_above_confidence():
    # One special cluster above threshold
    c1 = Cluster(1, DocItemLabel.PICTURE, BoundingBox(10, 10, 20, 20), 0.6, [])
    page, clusters = make_page_and_clusters([c1])
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.02ΞΌs -> 2.33ΞΌs (13.2% slower)



def test_full_page_picture_is_removed():
    # Picture covers >90% of page area
    page_width, page_height = 1000, 1000
    big_pic = Cluster(1, DocItemLabel.PICTURE, BoundingBox(0, 0, 1000, 950), 0.99, [])
    page, clusters = make_page_and_clusters([big_pic], page_size=(page_width, page_height))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.84ΞΌs -> 4.01ΞΌs (29.2% slower)

def test_picture_just_below_full_page_is_kept():
    # Picture covers exactly 90% of page area (should be kept)
    page_width, page_height = 1000, 1000
    area = page_width * page_height
    pic_height = int((0.9 * area) / page_width)
    pic = Cluster(1, DocItemLabel.PICTURE, BoundingBox(0, 0, 1000, pic_height), 0.99, [])
    page, clusters = make_page_and_clusters([pic], page_size=(page_width, page_height))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.06ΞΌs -> 2.52ΞΌs (18.1% slower)

def test_cross_type_overlap_removes_key_value():
    # KEY_VALUE_REGION overlaps TABLE by >0.9 and confidence diff < 0.1
    reg = Cluster(2, DocItemLabel.TABLE, BoundingBox(10, 10, 50, 50), 0.8, [])
    kv = Cluster(1, DocItemLabel.KEY_VALUE_REGION, BoundingBox(10, 10, 50, 50), 0.85, [])
    page, clusters = make_page_and_clusters([kv], [reg])
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 2.09ΞΌs -> 2.48ΞΌs (15.8% slower)

def test_cross_type_overlap_does_not_remove_if_conf_high():
    # KEY_VALUE_REGION overlaps TABLE but confidence diff >= 0.1
    reg = Cluster(2, DocItemLabel.TABLE, BoundingBox(10, 10, 50, 50), 0.7, [])
    kv = Cluster(1, DocItemLabel.KEY_VALUE_REGION, BoundingBox(10, 10, 50, 50), 0.85, [])
    page, clusters = make_page_and_clusters([kv], [reg])
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 1.99ΞΌs -> 2.51ΞΌs (20.4% slower)




def test_many_special_clusters_scaling():
    # 500 special clusters, all above threshold, all non-overlapping
    specials = [
        Cluster(i, DocItemLabel.PICTURE, BoundingBox(i*2, 0, i*2+1, 1), 0.99, [])
        for i in range(1, 501)
    ]
    page, clusters = make_page_and_clusters(specials, page_size=(2000, 2000))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 4.35ΞΌs -> 19.2ΞΌs (77.4% slower)
    ids = set(c.id for c in result)

def test_many_regulars_and_specials_assignment():
    # 50 wrappers, each with 10 regulars inside, all should be assigned as children
    wrappers = []
    regulars = []
    for i in range(50):
        l, t = i*10, i*10
        wrapper = Cluster(i+1000, DocItemLabel.FORM, BoundingBox(l, t, l+10, t+10), 0.95, [])
        wrappers.append(wrapper)
        for j in range(10):
            idx = i*10 + j
            cell = TextCell(idx, f"cell{idx}")
            reg = Cluster(i*100 + j, DocItemLabel.TEXT, BoundingBox(l+1, t+1, l+2, t+2), 0.9, [cell])
            regulars.append(reg)
    page, clusters = make_page_and_clusters(wrappers, regulars, page_size=(1000, 1000))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output
    # Each wrapper should have 10 children
    for wrapper in result:
        indices = [cell.index for cell in wrapper.cells]

def test_large_number_of_clusters_efficiency():
    # 900 clusters, 300 special, 600 regular, all above threshold, no overlaps
    specials = [
        Cluster(i, DocItemLabel.PICTURE, BoundingBox(i*3, 0, i*3+1, 1), 0.99, [])
        for i in range(1, 301)
    ]
    regulars = [
        Cluster(i+1000, DocItemLabel.TEXT, BoundingBox(i*2, 2, i*2+1, 3), 0.99, [TextCell(i+1000)])
        for i in range(1, 601)
    ]
    page, clusters = make_page_and_clusters(specials, regulars, page_size=(5000, 5000))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output # 4.47ΞΌs -> 27.4ΞΌs (83.7% slower)
    # All special clusters should be present
    ids = set(c.id for c in result)

def test_special_and_regular_clusters_with_some_overlap():
    # 10 wrappers, each with 5 regulars, and some regulars not contained
    wrappers = []
    regulars = []
    for i in range(10):
        l, t = i*10, i*10
        wrapper = Cluster(i+100, DocItemLabel.FORM, BoundingBox(l, t, l+10, t+10), 0.95, [])
        wrappers.append(wrapper)
        for j in range(5):
            idx = i*10 + j
            cell = TextCell(idx, f"cell{idx}")
            reg = Cluster(i*100 + j, DocItemLabel.TEXT, BoundingBox(l+1, t+1, l+2, t+2), 0.9, [cell])
            regulars.append(reg)
    # Add 10 regulars not contained in any wrapper
    for i in range(10):
        cell = TextCell(1000+i, f"cell{1000+i}")
        reg = Cluster(2000+i, DocItemLabel.TEXT, BoundingBox(500+i, 500+i, 501+i, 501+i), 0.9, [cell])
        regulars.append(reg)
    page, clusters = make_page_and_clusters(wrappers, regulars, page_size=(1000, 1000))
    proc = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = proc._process_special_clusters(); result = codeflash_output
    # Each wrapper should have 5 children, and the 10 uncontained regulars should not be children
    for wrapper in result:
        pass
    # The 10 uncontained regulars should not be assigned as children anywhere
    all_child_ids = set()
    for wrapper in result:
        all_child_ids.update(c.id for c in wrapper.children)
    for reg in regulars[-10:]:
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

from typing import List, Optional

# imports
import pytest
from docling.utils.layout_postprocessor import LayoutPostprocessor

# --- Minimal stubs for required classes and enums ---

class DocItemLabel:
    # Simulate enum values
    CAPTION = "caption"
    FOOTNOTE = "footnote"
    FORMULA = "formula"
    LIST_ITEM = "list_item"
    PAGE_FOOTER = "page_footer"
    PAGE_HEADER = "page_header"
    PICTURE = "picture"
    SECTION_HEADER = "section_header"
    TABLE = "table"
    TEXT = "text"
    TITLE = "title"
    CODE = "code"
    CHECKBOX_SELECTED = "checkbox_selected"
    CHECKBOX_UNSELECTED = "checkbox_unselected"
    FORM = "form"
    KEY_VALUE_REGION = "key_value_region"
    DOCUMENT_INDEX = "document_index"

class BoundingBox:
    def __init__(self, l, t, r, b):
        self.l = l
        self.t = t
        self.r = r
        self.b = b

    def area(self):
        width = max(0, self.r - self.l)
        height = max(0, self.b - self.t)
        return width * height

    def as_tuple(self):
        return (self.l, self.t, self.r, self.b)

    def intersection(self, other: "BoundingBox") -> Optional["BoundingBox"]:
        l = max(self.l, other.l)
        t = max(self.t, other.t)
        r = min(self.r, other.r)
        b = min(self.b, other.b)
        if l < r and t < b:
            return BoundingBox(l, t, r, b)
        return None

    def intersection_area(self, other: "BoundingBox") -> float:
        inter = self.intersection(other)
        return inter.area() if inter else 0.0

    def intersection_over_union(self, other: "BoundingBox") -> float:
        inter = self.intersection_area(other)
        union = self.area() + other.area() - inter
        return inter / union if union > 0 else 0.0

    def intersection_over_self(self, other: "BoundingBox") -> float:
        inter = self.intersection_area(other)
        return inter / self.area() if self.area() > 0 else 0.0

    def __eq__(self, other):
        return (self.l, self.t, self.r, self.b) == (other.l, other.t, other.r, other.b)

    def __repr__(self):
        return f"BoundingBox({self.l},{self.t},{self.r},{self.b})"

class TextCell:
    def __init__(self, index):
        self.index = index

    def __eq__(self, other):
        return self.index == other.index

    def __repr__(self):
        return f"TextCell({self.index})"

class Cluster:
    def __init__(self, id, label, bbox, confidence, cells=None):
        self.id = id
        self.label = label
        self.bbox = bbox
        self.confidence = confidence
        self.cells = cells if cells is not None else []
        self.children = []

    def __eq__(self, other):
        return (
            self.id == other.id and
            self.label == other.label and
            self.bbox == other.bbox and
            abs(self.confidence - other.confidence) < 1e-6 and
            self.cells == other.cells and
            self.children == other.children
        )

    def __repr__(self):
        return (f"Cluster(id={self.id}, label={self.label}, bbox={self.bbox}, "
                f"confidence={self.confidence}, cells={self.cells}, children={self.children})")

class Page:
    def __init__(self, cells: List[TextCell], size):
        self.cells = cells
        self.size = size

class LayoutOptions:
    pass

class PageSize:
    def __init__(self, width, height):
        self.width = width
        self.height = height

# --- Unit Tests ---

# Helper to create a page of given size with N cells
def make_page(width, height, n_cells):
    return Page([TextCell(i) for i in range(n_cells)], PageSize(width, height))

def make_cluster(id, label, l, t, r, b, confidence, cell_indices):
    return Cluster(
        id=id,
        label=label,
        bbox=BoundingBox(l, t, r, b),
        confidence=confidence,
        cells=[TextCell(i) for i in cell_indices]
    )

# -------------------- BASIC TEST CASES --------------------

def test_returns_empty_if_no_special_clusters():
    # No clusters at all
    page = make_page(100, 100, 0)
    clusters = []
    lp = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = lp._process_special_clusters() # 2.11ΞΌs -> 2.11ΞΌs (0.190% faster)

def test_filters_by_confidence():
    # One special cluster below threshold, one above
    page = make_page(100, 100, 1)
    c1 = make_cluster(1, DocItemLabel.PICTURE, 0, 0, 10, 10, 0.4, [0])  # below threshold
    c2 = make_cluster(2, DocItemLabel.PICTURE, 10, 10, 20, 20, 0.6, [0]) # above threshold
    lp = LayoutPostprocessor(page, [c1, c2], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 58.9ΞΌs -> 58.4ΞΌs (0.914% faster)

def test_removes_full_page_picture():
    # Picture covers >90% of page area
    page = make_page(100, 100, 1)
    c1 = make_cluster(1, DocItemLabel.PICTURE, 0, 0, 100, 100, 0.8, [0])  # full page
    c2 = make_cluster(2, DocItemLabel.PICTURE, 0, 0, 80, 80, 0.8, [0])    # not full page
    lp = LayoutPostprocessor(page, [c1, c2], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 30.2ΞΌs -> 30.5ΞΌs (1.18% slower)

def test_assigns_children_and_merges_cells_for_form():
    # FORM contains two regular clusters, should assign as children and merge cells
    page = make_page(100, 100, 4)
    reg1 = make_cluster(10, DocItemLabel.TEXT, 10, 10, 20, 20, 0.8, [0, 1])
    reg2 = make_cluster(11, DocItemLabel.TEXT, 20, 20, 30, 30, 0.8, [2, 3])
    form = make_cluster(1, DocItemLabel.FORM, 5, 5, 35, 35, 0.8, [])
    lp = LayoutPostprocessor(page, [reg1, reg2, form], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 68.2ΞΌs -> 65.1ΞΌs (4.73% faster)
    form_out = result[0]

def test_assigns_children_and_merges_cells_for_key_value_region():
    # KEY_VALUE_REGION contains one regular cluster
    page = make_page(100, 100, 2)
    reg = make_cluster(10, DocItemLabel.TEXT, 10, 10, 20, 20, 0.8, [0, 1])
    kvr = make_cluster(1, DocItemLabel.KEY_VALUE_REGION, 5, 5, 25, 25, 0.8, [])
    lp = LayoutPostprocessor(page, [reg, kvr], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 43.7ΞΌs -> 41.8ΞΌs (4.65% faster)
    kvr_out = result[0]

def test_children_not_assigned_to_picture_or_table():
    # Children only assigned to FORM and KEY_VALUE_REGION
    page = make_page(100, 100, 2)
    reg = make_cluster(10, DocItemLabel.TEXT, 10, 10, 20, 20, 0.8, [0, 1])
    pic = make_cluster(1, DocItemLabel.PICTURE, 5, 5, 25, 25, 0.8, [])
    tbl = make_cluster(2, DocItemLabel.TABLE, 5, 5, 25, 25, 0.8, [])
    lp = LayoutPostprocessor(page, [reg, pic, tbl], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 52.9ΞΌs -> 53.4ΞΌs (0.894% slower)
    for c in result:
        if c.label in [DocItemLabel.PICTURE, DocItemLabel.TABLE]:
            pass

# -------------------- EDGE TEST CASES --------------------

def test_zero_area_bbox_is_ignored_for_containment():
    # If a regular cluster has zero area, it should not be assigned as child
    page = make_page(100, 100, 1)
    reg = make_cluster(10, DocItemLabel.TEXT, 10, 10, 10, 10, 0.8, [0])  # zero area
    form = make_cluster(1, DocItemLabel.FORM, 5, 5, 15, 15, 0.8, [])
    lp = LayoutPostprocessor(page, [reg, form], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 27.1ΞΌs -> 27.8ΞΌs (2.57% slower)

def test_cluster_with_no_cells():
    # Special cluster with no cells, but children with cells
    page = make_page(100, 100, 2)
    reg = make_cluster(10, DocItemLabel.TEXT, 10, 10, 20, 20, 0.8, [0, 1])
    form = make_cluster(1, DocItemLabel.FORM, 5, 5, 25, 25, 0.8, [])
    lp = LayoutPostprocessor(page, [reg, form], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 42.6ΞΌs -> 42.7ΞΌs (0.222% slower)
    form_out = result[0]

def test_duplicate_cells_are_deduplicated():
    # Children with overlapping cells, deduplication should occur
    page = make_page(100, 100, 3)
    reg1 = make_cluster(10, DocItemLabel.TEXT, 10, 10, 20, 20, 0.8, [0, 1])
    reg2 = make_cluster(11, DocItemLabel.TEXT, 20, 20, 30, 30, 0.8, [1, 2])
    form = make_cluster(1, DocItemLabel.FORM, 5, 5, 35, 35, 0.8, [])
    lp = LayoutPostprocessor(page, [reg1, reg2, form], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 44.4ΞΌs -> 43.8ΞΌs (1.37% faster)
    form_out = result[0]

def test_handle_cross_type_overlaps_removes_wrapper():
    # KEY_VALUE_REGION overlaps TABLE almost exactly, and confidence difference < 0.1, so wrapper removed
    page = make_page(100, 100, 1)
    table = make_cluster(2, DocItemLabel.TABLE, 10, 10, 30, 30, 0.7, [0])
    kvr = make_cluster(1, DocItemLabel.KEY_VALUE_REGION, 10, 10, 30, 30, 0.75, [])
    lp = LayoutPostprocessor(page, [table, kvr], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 68.2ΞΌs -> 68.0ΞΌs (0.326% faster)

def test_handle_cross_type_overlaps_does_not_remove_if_confidence_high():
    # KEY_VALUE_REGION overlaps TABLE, but confidence difference >= 0.1, so wrapper kept
    page = make_page(100, 100, 1)
    table = make_cluster(2, DocItemLabel.TABLE, 10, 10, 30, 30, 0.6, [0])
    kvr = make_cluster(1, DocItemLabel.KEY_VALUE_REGION, 10, 10, 30, 30, 0.8, [])
    lp = LayoutPostprocessor(page, [table, kvr], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 58.7ΞΌs -> 59.1ΞΌs (0.707% slower)

def test_removes_full_page_picture_edge_case_just_under_threshold():
    # Picture covers exactly 90% of page area, should NOT be removed
    page = make_page(100, 100, 1)
    area = 100 * 100
    pic_area = int(area * 0.9)
    side = int(pic_area ** 0.5)
    c1 = make_cluster(1, DocItemLabel.PICTURE, 0, 0, side, side, 0.8, [0])
    lp = LayoutPostprocessor(page, [c1], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 26.5ΞΌs -> 26.3ΞΌs (0.536% faster)

def test_special_cluster_with_no_regular_clusters():
    # Special cluster present, but no regular clusters to assign as children
    page = make_page(100, 100, 0)
    form = make_cluster(1, DocItemLabel.FORM, 5, 5, 25, 25, 0.8, [])
    lp = LayoutPostprocessor(page, [form], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 25.7ΞΌs -> 24.6ΞΌs (4.57% faster)

# -------------------- LARGE SCALE TEST CASES --------------------

def test_many_special_and_regular_clusters():
    # 100 special clusters, 100 regular clusters, each special contains one regular
    page = make_page(200, 200, 200)
    regulars = [
        make_cluster(i, DocItemLabel.TEXT, i, i, i+2, i+2, 0.8, [i])
        for i in range(100)
    ]
    specials = [
        make_cluster(1000+i, DocItemLabel.FORM, i, i, i+2, i+2, 0.8, [])
        for i in range(100)
    ]
    clusters = regulars + specials
    lp = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 12.4ms -> 4.83ms (155% faster)
    for i, form in enumerate(result):
        pass

def test_large_number_of_cells_and_deduplication():
    # Special cluster with 500 children, each with overlapping cells, deduplication should keep all unique
    n = 500
    page = make_page(1000, 1000, n)
    regulars = [
        make_cluster(i, DocItemLabel.TEXT, i, i, i+1, i+1, 0.8, [i, (i+1)%n])
        for i in range(n)
    ]
    form = make_cluster(9999, DocItemLabel.FORM, 0, 0, n+1, n+1, 0.9, [])
    clusters = regulars + [form]
    lp = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 979ΞΌs -> 1.03ms (4.67% slower)
    form_out = result[0]

def test_performance_with_max_clusters():
    # 500 special clusters, 500 regular clusters, no overlap
    n = 500
    page = make_page(2000, 2000, n)
    regulars = [
        make_cluster(i, DocItemLabel.TEXT, i*2, i*2, i*2+1, i*2+1, 0.8, [i])
        for i in range(n)
    ]
    specials = [
        make_cluster(1000+i, DocItemLabel.FORM, n*2+i*2, n*2+i*2, n*2+i*2+1, n*2+i*2+1, 0.8, [])
        for i in range(n)
    ]
    clusters = regulars + specials
    lp = LayoutPostprocessor(page, clusters, LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 222ms -> 24.8ms (795% faster)
    for form in result:
        pass

def test_large_page_area_and_full_page_picture_removal():
    # Large page, picture covers >90%, should be removed
    page = make_page(1000, 1000, 1)
    c1 = make_cluster(1, DocItemLabel.PICTURE, 0, 0, 1000, 1000, 0.8, [0])
    c2 = make_cluster(2, DocItemLabel.PICTURE, 0, 0, 900, 900, 0.8, [0])
    lp = LayoutPostprocessor(page, [c1, c2], LayoutOptions())
    codeflash_output = lp._process_special_clusters(); result = codeflash_output # 47.3ΞΌs -> 45.3ΞΌs (4.31% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-LayoutPostprocessor._process_special_clusters-mcu3u6n5 and push.

Codeflash

codeflash-ai bot and others added 3 commits July 8, 2025 05:43
… 653%

Here are targeted optimizations based on the profiling output and the code.

### Major bottlenecks & optimization strategies

#### 1. `_process_special_clusters`:  
- **Main bottleneck:**  
  - The nested loop: for each special cluster, loop through all regular clusters and compute `.bbox.intersection_over_self(special.bbox)`.
  - This is `O(N*M)` for N special and M regular clusters and is by far the slowest part.
- **Optimization:**  
  - **Pre-index regular clusters by bounding box for fast containment:**  
    - Build a simple R-tree-like spatial grid (using bins, or just a fast bbox filtering pass) to filter out regular clusters that are definitely non-overlapping before running the expensive geometric calculation.  
  - **If spatial index unavailable:** Pre-filter regulars to those whose bbox intersects the special’s bbox (quick min/max bbox checks), greatly reducing pairwise calculations.

#### 2. `_handle_cross_type_overlaps`:  
- **Similar bottleneck:** Again, checking every regular cluster for every wrapper.  
  - We can apply the same bbox quick-check.

#### 3. Miscellaneous.
- **`_deduplicate_cells`/`_sort_cells` optimizations:** Minor, but batch sort/unique patterns can help.
- **Avoid recomputation:** Avoid recomputing thresholds/constants in hot loops.

Below is the optimized code addressing the biggest O(N*M) loop, using fast bbox intersection check for quick rejection before expensive calculation.
We achieve this purely with local logic in the function (no external indices needed), and respect your constraint not to introduce module-level classes.
Comments in the code indicate all changes.



**Summary of changes:**
- For both `_process_special_clusters` and `_handle_cross_type_overlaps`, we avoid unnecessary `.intersection_over_self` calculations by pre-filtering clusters based on simple bbox intersection conditions (`l < rx and r > lx and t < by and b > ty`).
- This turns expensive O(N*M) geometric checks into a two-stage filter, which is extremely fast for typical bbox distributions.
- All hot-spot loops now use local variables rather than repeated attribute lookups.
- No changes are made to APIs, outputs, or major logic branches; only faster candidate filtering is introduced.

This should reduce total runtime of `_process_special_clusters` and `_handle_cross_type_overlaps` by an order of magnitude on large documents.
Copy link
Contributor

github-actions bot commented Jul 16, 2025

βœ… DCO Check Passed

Thanks @mohammedahmed18, all your commits are properly signed off. πŸŽ‰

Copy link

mergify bot commented Jul 16, 2025

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🟒 Enforce conventional commit

Wonderful, this rule succeeded.

Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/

  • title ~= ^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert)(?:\(.+\))?(!)?:

…, mohammed <[email protected]>, hereby add my Signed-off-by to this commit: d982474\n\nSigned-off-by: mohammed <[email protected]>n
…bot]@users.noreply.github.com>

I, codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>, hereby add my Signed-off-by to this commit: 3b8deae
I, mohammed <[email protected]>, hereby add my Signed-off-by to this commit: bd8b1c4
I, mohammed <[email protected]>, hereby add my Signed-off-by to this commit: 7b84668
I, mohammed <[email protected]>, hereby add my Signed-off-by to this commit: ad90f33

Signed-off-by: mohammed <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant