diff --git a/README.md b/README.md index 847ede61..470936ab 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,9 @@ python examples/convert.py ``` The output of the above command will be written to `./scratch`. -### Enable or disable pipeline features +### Adjust pipeline features + +**Control pipeline options** You can control if table structure recognition or OCR should be performed by arguments passed to `DocumentConverter`: ```python @@ -60,6 +62,23 @@ doc_converter = DocumentConverter( ) ``` +**Control table extraction options** + +You can control if table structure recognition should map the recognized structure back to PDF cells (default) or use text cells from the structure prediction itself. +This can improve output quality if you find that multiple columns in extracted tables are erroneously merged into one. + + +```python + +pipeline_options = PipelineOptions(do_table_structure=True) +pipeline_options.table_structure_options.do_cell_matching = False # Uses text cells predicted from table structure model + +doc_converter = DocumentConverter( + artifacts_path=artifacts_path, + pipeline_options=pipeline_options, +) +``` + ### Impose limits on the document size You can limit the file size and number of pages which should be allowed to process per document: diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 1e446edd..8b6796d6 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -1,3 +1,4 @@ +import copy from enum import Enum, auto from io import BytesIO from typing import Any, Dict, List, Optional, Tuple, Union @@ -47,6 +48,15 @@ def width(self): def height(self): return abs(self.t - self.b) + def scaled(self, scale: float) -> "BoundingBox": + out_bbox = copy.deepcopy(self) + out_bbox.l *= scale + out_bbox.r *= scale + out_bbox.t *= scale + out_bbox.b *= scale + + return out_bbox + def as_tuple(self): if self.coord_origin == CoordOrigin.TOPLEFT: return (self.l, self.t, self.r, self.b) @@ -241,6 +251,17 @@ class DocumentStream(BaseModel): stream: BytesIO +class TableStructureOptions(BaseModel): + do_cell_matching: bool = ( + True + # True: Matches predictions back to PDF cells. Can break table output if PDF cells + # are merged across table columns. + # False: Let table structure model define the text cells, ignore PDF cells. + ) + + class PipelineOptions(BaseModel): - do_table_structure: bool = True - do_ocr: bool = False + do_table_structure: bool = True # True: perform table structure extraction + do_ocr: bool = False # True: perform OCR, replace programmatic PDF text + + table_structure_options: TableStructureOptions = TableStructureOptions() diff --git a/docling/models/page_assemble_model.py b/docling/models/page_assemble_model.py index 4ed0832d..2b9db544 100644 --- a/docling/models/page_assemble_model.py +++ b/docling/models/page_assemble_model.py @@ -19,18 +19,6 @@ class PageAssembleModel: def __init__(self, config): self.config = config - # self.line_wrap_pattern = re.compile(r'(?<=[^\W_])- \n(?=\w)') - - # def sanitize_text_poor(self, lines): - # text = '\n'.join(lines) - # - # # treat line wraps. - # sanitized_text = self.line_wrap_pattern.sub('', text) - # - # sanitized_text = sanitized_text.replace('\n', ' ') - # - # return sanitized_text - def sanitize_text(self, lines): if len(lines) <= 1: return " ".join(lines) diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index 8ee4bdae..132b141c 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -1,7 +1,10 @@ -from typing import Iterable +import copy +import random +from typing import Iterable, List import numpy from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor +from PIL import ImageDraw from docling.datamodel.base_models import ( BoundingBox, @@ -28,6 +31,21 @@ def __init__(self, config): self.tm_model_type = self.tm_config["model"]["type"] self.tf_predictor = TFPredictor(self.tm_config) + self.scale = 2.0 # Scale up table input images to 144 dpi + + def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]): + image = page._backend.get_page_image() + draw = ImageDraw.Draw(image) + + for table_element in tbl_list: + x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple() + draw.rectangle([(x0, y0), (x1, y1)], outline="red") + + for tc in table_element.table_cells: + x0, y0, x1, y1 = tc.bbox.as_tuple() + draw.rectangle([(x0, y0), (x1, y1)], outline="blue") + + image.show() def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: @@ -36,16 +54,17 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: return for page in page_batch: + page.predictions.tablestructure = TableStructurePrediction() # dummy in_tables = [ ( cluster, [ - round(cluster.bbox.l), - round(cluster.bbox.t), - round(cluster.bbox.r), - round(cluster.bbox.b), + round(cluster.bbox.l) * self.scale, + round(cluster.bbox.t) * self.scale, + round(cluster.bbox.r) * self.scale, + round(cluster.bbox.b) * self.scale, ], ) for cluster in page.predictions.layout.clusters @@ -65,20 +84,29 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: ): # Only allow non empty stings (spaces) into the cells of a table if len(c.text.strip()) > 0: - tokens.append(c.model_dump()) + new_cell = copy.deepcopy(c) + new_cell.bbox = new_cell.bbox.scaled(scale=self.scale) + + tokens.append(new_cell.model_dump()) - iocr_page = { - "image": numpy.asarray(page.image), + page_input = { "tokens": tokens, - "width": page.size.width, - "height": page.size.height, + "width": page.size.width * self.scale, + "height": page.size.height * self.scale, } + # add image to page input. + if self.scale == 1.0: + page_input["image"] = numpy.asarray(page.image) + else: # render new page image on the fly at desired scale + page_input["image"] = numpy.asarray( + page._backend.get_page_image(scale=self.scale) + ) table_clusters, table_bboxes = zip(*in_tables) if len(table_bboxes): tf_output = self.tf_predictor.multi_table_predict( - iocr_page, table_bboxes, do_matching=self.do_cell_matching + page_input, table_bboxes, do_matching=self.do_cell_matching ) for table_cluster, table_out in zip(table_clusters, tf_output): @@ -91,6 +119,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: element["bbox"]["token"] = text_piece tc = TableCell.model_validate(element) + tc.bbox = tc.bbox.scaled(1 / self.scale) table_cells.append(tc) # Retrieving cols/rows, after post processing: @@ -111,4 +140,7 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: page.predictions.tablestructure.table_map[table_cluster.id] = tbl + # For debugging purposes: + # self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values()) + yield page diff --git a/docling/pipeline/standard_model_pipeline.py b/docling/pipeline/standard_model_pipeline.py index 07c01135..33fee75e 100644 --- a/docling/pipeline/standard_model_pipeline.py +++ b/docling/pipeline/standard_model_pipeline.py @@ -34,7 +34,7 @@ def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions): "artifacts_path": artifacts_path / StandardModelPipeline._table_model_path, "enabled": pipeline_options.do_table_structure, - "do_cell_matching": False, + "do_cell_matching": pipeline_options.table_structure_options.do_cell_matching, } ), ]