From 8aa476ccd3cb7bc513a40df67fc6aa99aab0e780 Mon Sep 17 00:00:00 2001 From: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:56:29 +0200 Subject: [PATCH] test: improve typing definitions (part 1) (#72) Signed-off-by: Michele Dolfi --- docling/backend/abstract_backend.py | 9 ++++--- docling/backend/docling_parse_backend.py | 9 +++---- docling/backend/pypdfium2_backend.py | 8 +++--- docling/datamodel/base_models.py | 6 ++--- docling/pipeline/base_model_pipeline.py | 4 +-- docling/utils/export.py | 29 ++++++++++++++------- poetry.lock | 33 +++++++++++++++++++++++- pyproject.toml | 9 +++++++ tests/verify_utils.py | 13 +++++++++- 9 files changed, 91 insertions(+), 29 deletions(-) diff --git a/docling/backend/abstract_backend.py b/docling/backend/abstract_backend.py index 7bb53fce..66df2869 100644 --- a/docling/backend/abstract_backend.py +++ b/docling/backend/abstract_backend.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from PIL import Image +if TYPE_CHECKING: + from docling.datamodel.base_models import BoundingBox, Cell, PageSize + class PdfPageBackend(ABC): @@ -17,12 +20,12 @@ def get_text_cells(self) -> Iterable["Cell"]: pass @abstractmethod - def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]: + def get_bitmap_rects(self, scale: float = 1) -> Iterable["BoundingBox"]: pass @abstractmethod def get_page_image( - self, scale: int = 1, cropbox: Optional["BoundingBox"] = None + self, scale: float = 1, cropbox: Optional["BoundingBox"] = None ) -> Image.Image: pass diff --git a/docling/backend/docling_parse_backend.py b/docling/backend/docling_parse_backend.py index aeaf4739..d7a116d4 100644 --- a/docling/backend/docling_parse_backend.py +++ b/docling/backend/docling_parse_backend.py @@ -2,7 +2,7 @@ import random from io import BytesIO from 
pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union import pypdfium2 as pdfium from docling_parse.docling_parse import pdf_parser @@ -22,7 +22,6 @@ def __init__( self._ppage = page_obj parsed_page = parser.parse_pdf_from_key_on_page(document_hash, page_no) - self._dpage = None self.valid = "pages" in parsed_page if self.valid: self._dpage = parsed_page["pages"][0] @@ -68,7 +67,7 @@ def get_text_in_rect(self, bbox: BoundingBox) -> str: return text_piece def get_text_cells(self) -> Iterable[Cell]: - cells = [] + cells: List[Cell] = [] cell_counter = 0 if not self.valid: @@ -130,7 +129,7 @@ def draw_clusters_and_cells(): return cells - def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]: + def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]: AREA_THRESHOLD = 32 * 32 for i in range(len(self._dpage["images"])): @@ -145,7 +144,7 @@ def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]: yield cropbox def get_page_image( - self, scale: int = 1, cropbox: Optional[BoundingBox] = None + self, scale: float = 1, cropbox: Optional[BoundingBox] = None ) -> Image.Image: page_size = self.get_size() diff --git a/docling/backend/pypdfium2_backend.py b/docling/backend/pypdfium2_backend.py index b7ec824a..81ab8488 100644 --- a/docling/backend/pypdfium2_backend.py +++ b/docling/backend/pypdfium2_backend.py @@ -7,7 +7,7 @@ import pypdfium2 as pdfium import pypdfium2.raw as pdfium_c from PIL import Image, ImageDraw -from pypdfium2 import PdfPage +from pypdfium2 import PdfPage, PdfTextPage from pypdfium2._helpers.misc import PdfiumError from docling.backend.abstract_backend import PdfDocumentBackend, PdfPageBackend @@ -29,12 +29,12 @@ def __init__( exc_info=True, ) self.valid = False - self.text_page = None + self.text_page: Optional[PdfTextPage] = None def is_valid(self) -> bool: return self.valid - def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]: + def 
get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]: AREA_THRESHOLD = 32 * 32 for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]): pos = obj.get_pos() @@ -189,7 +189,7 @@ def draw_clusters_and_cells(): return cells def get_page_image( - self, scale: int = 1, cropbox: Optional[BoundingBox] = None + self, scale: float = 1, cropbox: Optional[BoundingBox] = None ) -> Image.Image: page_size = self.get_size() diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 71238b8d..e9c51d69 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -87,7 +87,7 @@ def as_tuple(self): return (self.l, self.b, self.r, self.t) @classmethod - def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin): + def from_tuple(cls, coord: Tuple[float, ...], origin: CoordOrigin): if origin == CoordOrigin.TOPLEFT: l, t, r, b = coord[0], coord[1], coord[2], coord[3] if r < l: @@ -246,7 +246,7 @@ class EquationPrediction(BaseModel): class PagePredictions(BaseModel): - layout: LayoutPrediction = None + layout: Optional[LayoutPrediction] = None tablestructure: Optional[TableStructurePrediction] = None figures_classification: Optional[FigureClassificationPrediction] = None equations_prediction: Optional[EquationPrediction] = None @@ -267,7 +267,7 @@ class Page(BaseModel): page_no: int page_hash: Optional[str] = None size: Optional[PageSize] = None - cells: List[Cell] = None + cells: List[Cell] = [] predictions: PagePredictions = PagePredictions() assembled: Optional[AssembledUnit] = None diff --git a/docling/pipeline/base_model_pipeline.py b/docling/pipeline/base_model_pipeline.py index 680a1140..4fdde951 100644 --- a/docling/pipeline/base_model_pipeline.py +++ b/docling/pipeline/base_model_pipeline.py @@ -1,12 +1,12 @@ from pathlib import Path -from typing import Iterable +from typing import Callable, Iterable, List from docling.datamodel.base_models import Page, PipelineOptions class BaseModelPipeline: 
def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions): - self.model_pipe = [] + self.model_pipe: List[Callable] = [] self.artifacts_path = artifacts_path self.pipeline_options = pipeline_options diff --git a/docling/utils/export.py b/docling/utils/export.py index f438ed1d..115f7646 100644 --- a/docling/utils/export.py +++ b/docling/utils/export.py @@ -1,10 +1,10 @@ import logging -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple, Union -from docling_core.types.doc.base import BaseCell, Ref, Table, TableCell +from docling_core.types.doc.base import BaseCell, BaseText, Ref, Table, TableCell from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell -from docling.datamodel.document import ConvertedDocument, Page +from docling.datamodel.document import ConversionResult, Page _log = logging.getLogger(__name__) @@ -15,7 +15,10 @@ def _export_table_to_html(table: Table): # to the docling-core package. 
def _get_tablecell_span(cell: TableCell, ix): - span = set([s[ix] for s in cell.spans]) + if cell.spans is None: + span = set() + else: + span = set([s[ix] for s in cell.spans]) if len(span) == 0: return 1, None, None return len(span), min(span), max(span) @@ -24,6 +27,8 @@ def _get_tablecell_span(cell: TableCell, ix): nrows = table.num_rows ncols = table.num_cols + if table.data is None: + return "" for i in range(nrows): body += "" for j in range(ncols): @@ -66,7 +71,7 @@ def _get_tablecell_span(cell: TableCell, ix): def generate_multimodal_pages( - doc_result: ConvertedDocument, + doc_result: ConversionResult, ) -> Iterable[Tuple[str, str, List[Dict[str, Any]], List[Dict[str, Any]], Page]]: label_to_doclaynet = { @@ -94,7 +99,7 @@ def generate_multimodal_pages( page_no = 0 start_ix = 0 end_ix = 0 - doc_items = [] + doc_items: List[Tuple[int, Union[BaseCell, BaseText]]] = [] doc = doc_result.output @@ -105,11 +110,11 @@ def _process_page_segments(doc_items: list[Tuple[int, BaseCell]], page: Page): item_type = item.obj_type label = label_to_doclaynet.get(item_type, None) - if label is None: + if label is None or item.prov is None or page.size is None: continue bbox = BoundingBox.from_tuple( - item.prov[0].bbox, origin=CoordOrigin.BOTTOMLEFT + tuple(item.prov[0].bbox), origin=CoordOrigin.BOTTOMLEFT ) new_bbox = bbox.to_top_left_origin(page_height=page.size.height).normalized( page_size=page.size @@ -137,13 +142,15 @@ def _process_page_segments(doc_items: list[Tuple[int, BaseCell]], page: Page): return segments def _process_page_cells(page: Page): - cells = [] + cells: List[dict] = [] + if page.size is None: + return cells for cell in page.cells: new_bbox = cell.bbox.to_top_left_origin( page_height=page.size.height ).normalized(page_size=page.size) is_ocr = isinstance(cell, OcrCell) - ocr_confidence = cell.confidence if is_ocr else 1.0 + ocr_confidence = cell.confidence if isinstance(cell, OcrCell) else 1.0 cells.append( { "text": cell.text, @@ -170,6 +177,8 @@ def 
_process_page(): return content_text, content_md, content_dt, page_cells, page_segments, page + if doc.main_text is None: + return for ix, orig_item in enumerate(doc.main_text): item = doc._resolve_ref(orig_item) if isinstance(orig_item, Ref) else orig_item diff --git a/poetry.lock b/poetry.lock index e711babc..9492e560 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3771,6 +3771,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.2.240909" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.2.240909-py3-none-any.whl", hash = "sha256:e230f5fa4065f9417804f4d65cd98f86c002efcc07933e8abcd48c3fad9c30a2"}, + {file = "pandas_stubs-2.2.2.240909.tar.gz", hash = "sha256:3c0951a2c3e45e3475aed9d80b7147ae82f176b9e42e9fb321cfdebf3d411b3d"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "parso" version = "0.8.4" @@ -6584,6 +6599,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + 
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -6617,6 +6637,17 @@ rfc3986 = ">=1.4.0" tqdm = ">=4.14" urllib3 = ">=1.26.0" +[[package]] +name = "types-pytz" +version = "2024.1.0.20240417" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.1.0.20240417.tar.gz", hash = "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981"}, + {file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"}, +] + [[package]] name = "types-requests" version = "2.32.0.20240907" @@ -7169,4 +7200,4 @@ examples = ["langchain-huggingface", "langchain-milvus", "langchain-text-splitte [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5ce8fc1e245442e355b967430e211b1378fed2e9fd20d2ddbea47f0e9f1dfcd5" +content-hash = "b881ea7a3504555707e0778c7c25631cbb353b78da04bd724852c7d34f39d46d" diff --git a/pyproject.toml b/pyproject.toml index 0542e679..9842831b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ pytest-xdist = "^3.3.1" types-requests = "^2.31.0.2" flake8-pyproject = "^1.2.3" pylint = "^2.17.5" +pandas-stubs = "^2.2.2.240909" ipykernel = "^6.29.5" ipywidgets = "^8.1.5" nbqa = "^1.9.0" @@ -114,6 +115,14 @@ pretty = true no_implicit_optional = true python_version = "3.10" +[[tool.mypy.overrides]] +module = [ + "docling_parse.*", + "pypdfium2.*", + "networkx.*", +] +ignore_missing_imports = true + [tool.flake8] 
max-line-length = 88 extend-ignore = ["E203", "E501"] diff --git a/tests/verify_utils.py b/tests/verify_utils.py index 448b7b61..e66cc79c 100644 --- a/tests/verify_utils.py +++ b/tests/verify_utils.py @@ -45,6 +45,8 @@ def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]): def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument): + assert doc_true.main_text is not None, "doc_true cannot be None" + assert doc_pred.main_text is not None, "doc_pred cannot be None" assert len(doc_true.main_text) == len( doc_pred.main_text @@ -68,6 +70,13 @@ def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument): def verify_tables(doc_pred: DsDocument, doc_true: DsDocument): + if doc_true.tables is None: + # No tables to check + assert doc_pred.tables is None, "not expecting any table on this document" + return True + + assert doc_pred.tables is not None, "no tables predicted, but expected in doc_true" + assert len(doc_true.tables) == len( doc_pred.tables ), "document has different count of tables than expected." @@ -82,6 +91,8 @@ def verify_tables(doc_pred: DsDocument, doc_true: DsDocument): true_item.num_cols == pred_item.num_cols ), "table does not have the same #-cols" + assert true_item.data is not None, "documents are expected to have table data" + assert pred_item.data is not None, "documents are expected to have table data" for i, row in enumerate(true_item.data): for j, col in enumerate(true_item.data[i]): @@ -135,7 +146,7 @@ def verify_conversion_result( doc_true_pages = PageList.validate_json(fr.read()) with open(json_path, "r") as fr: - doc_true = DsDocument.model_validate_json(fr.read()) + doc_true: DsDocument = DsDocument.model_validate_json(fr.read()) with open(md_path, "r") as fr: doc_true_md = fr.read()