Skip to content

Commit

Permalink
Ensure all models work only on valid pages (#158)
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Auer <[email protected]>
  • Loading branch information
cau-git authored Oct 18, 2024
1 parent 034a411 commit a00c937
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 390 deletions.
1 change: 1 addition & 0 deletions docling/models/ds_glm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def make_spans(cell):
page_dimensions = [
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
for p in conv_res.pages
if p.size is not None
]

ds_doc: DsDocument = DsDocument(
Expand Down
82 changes: 42 additions & 40 deletions docling/models/easyocr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,48 +41,50 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

for page in page_batch:
assert page._backend is not None

ocr_rects = self.get_ocr_rects(page)

all_ocr_cells = []
for ocr_rect in ocr_rects:
# Skip zero area boxes
if ocr_rect.area() == 0:
continue
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)
im = numpy.array(high_res_image)
result = self.reader.readtext(im)

del high_res_image
del im

cells = [
OcrCell(
id=ix,
text=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
),
if not page._backend.is_valid():
yield page
else:
ocr_rects = self.get_ocr_rects(page)

all_ocr_cells = []
for ocr_rect in ocr_rects:
# Skip zero area boxes
if ocr_rect.area() == 0:
continue
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)
for ix, line in enumerate(result)
]
all_ocr_cells.extend(cells)
im = numpy.array(high_res_image)
result = self.reader.readtext(im)

del high_res_image
del im

cells = [
OcrCell(
id=ix,
text=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
),
)
for ix, line in enumerate(result)
]
all_ocr_cells.extend(cells)

## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)

page.cells.extend(filtered_ocr_cells)
page.cells.extend(filtered_ocr_cells)

# DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects)
# DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects)

yield page
yield page
122 changes: 63 additions & 59 deletions docling/models/layout_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,68 +273,72 @@ def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height

def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
for page in page_batch:
assert page.size is not None

clusters = []
for ix, pred_item in enumerate(
self.layout_predictor.predict(page.get_image(scale=1.0))
):
label = DocItemLabel(
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster(
id=ix,
label=label,
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
cells=[],
)
clusters.append(cluster)

# Map cells to clusters
# TODO: Remove, postprocess should take care of it anyway.
for cell in page.cells:
for cluster in clusters:
if not cell.bbox.area() > 0:
overlap_frac = 0.0
else:
overlap_frac = (
cell.bbox.intersection_area_with(cluster.bbox)
/ cell.bbox.area()
)

if overlap_frac > 0.5:
cluster.cells.append(cell)

# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)

# DEBUG code:
def draw_clusters_and_cells():
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image)
for c in clusters:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")

cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
assert page.size is not None

clusters = []
for ix, pred_item in enumerate(
self.layout_predictor.predict(page.get_image(scale=1.0))
):
label = DocItemLabel(
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster(
id=ix,
label=label,
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
cells=[],
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()
clusters.append(cluster)

# Map cells to clusters
# TODO: Remove, postprocess should take care of it anyway.
for cell in page.cells:
for cluster in clusters:
if not cell.bbox.area() > 0:
overlap_frac = 0.0
else:
overlap_frac = (
cell.bbox.intersection_area_with(cluster.bbox)
/ cell.bbox.area()
)

if overlap_frac > 0.5:
cluster.cells.append(cell)

# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)

# DEBUG code:
def draw_clusters_and_cells():
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image)
for c in clusters:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")

cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()

# draw_clusters_and_cells()
# draw_clusters_and_cells()

clusters, page.cells = self.postprocess(
clusters, page.cells, page.size.height
)
clusters, page.cells = self.postprocess(
clusters, page.cells, page.size.height
)

# draw_clusters_and_cells()
# draw_clusters_and_cells()

page.predictions.layout = LayoutPrediction(clusters=clusters)
page.predictions.layout = LayoutPrediction(clusters=clusters)

yield page
yield page
Loading

0 comments on commit a00c937

Please sign in to comment.