Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
figure classifier
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Omenetti <[email protected]>
Matteo-Omenetti committed Jan 24, 2025
1 parent 3213b24 commit 8ecb810
Showing 3 changed files with 199 additions and 0 deletions.
1 change: 1 addition & 0 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
@@ -221,6 +221,7 @@ class PdfPipelineOptions(PipelineOptions):
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_code_enrichment: bool = False # True: perform code OCR
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
do_picture_classification: bool = False # True: classify pictures in documents

table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[
187 changes: 187 additions & 0 deletions docling/models/document_picture_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union

from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureClassificationData,
PictureItem,
)
from PIL import Image
from pydantic import BaseModel

from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseEnrichmentModel
from docling.utils.accelerator_utils import decide_device


class DocumentPictureClassifierOptions(BaseModel):
"""
Options for configuring the DocumentPictureClassifier.
Attributes
----------
kind : Literal["document_picture_classifier"]
Identifier for the type of classifier.
"""

kind: Literal["document_picture_classifier"] = "document_picture_classifier"


class DocumentPictureClassifier(BaseEnrichmentModel):
"""
A model for classifying pictures in documents.
This class enriches document pictures with predicted classifications
based on a predefined set of classes.
Attributes
----------
enabled : bool
Whether the classifier is enabled for use.
options : DocumentPictureClassifierOptions
Configuration options for the classifier.
document_picture_classifier : DocumentPictureClassifierPredictor
The underlying prediction model, loaded if the classifier is enabled.
Methods
-------
__init__(enabled, artifacts_path, options, accelerator_options)
Initializes the classifier with specified configurations.
is_processable(doc, element)
Checks if the given element can be processed by the classifier.
__call__(doc, element_batch)
Processes a batch of elements and adds classification annotations.
"""

images_scale = 2

def __init__(
self,
enabled: bool,
artifacts_path: Optional[Union[Path, str]],
options: DocumentPictureClassifierOptions,
accelerator_options: AcceleratorOptions,
):
"""
Initializes the DocumentPictureClassifier.
Parameters
----------
enabled : bool
Indicates whether the classifier is enabled.
artifacts_path : Optional[Union[Path, str]],
Path to the directory containing model artifacts.
options : DocumentPictureClassifierOptions
Configuration options for the classifier.
accelerator_options : AcceleratorOptions
Options for configuring the device and parallelism.
"""
self.enabled = enabled
self.options = options

if self.enabled:
device = decide_device(accelerator_options.device)
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
DocumentFigureClassifierPredictor,
)

if artifacts_path is None:
artifacts_path = self.download_models_hf()
else:
artifacts_path = Path(artifacts_path)

self.document_picture_classifier = DocumentFigureClassifierPredictor(
artifacts_path=artifacts_path,
device=device,
num_threads=accelerator_options.num_threads,
)

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
local_dir=local_dir,
revision="v1.0.0",
)

return Path(download_path)

def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if the given element can be processed by the classifier.
Parameters
----------
doc : DoclingDocument
The document containing the element.
element : NodeItem
The element to be checked.
Returns
-------
bool
True if the element is a PictureItem and processing is enabled; False otherwise.
"""
return self.enabled and isinstance(element, PictureItem)

def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[NodeItem],
) -> Iterable[NodeItem]:
"""
Processes a batch of elements and enriches them with classification predictions.
Parameters
----------
doc : DoclingDocument
The document containing the elements to be processed.
element_batch : Iterable[NodeItem]
A batch of pictures to classify.
Returns
-------
Iterable[NodeItem]
An iterable of NodeItem objects after processing. The field
'data.classification' is added containing the classification for each picture.
"""
if not self.enabled:
for element in element_batch:
yield element
return

images: List[Image.Image] = []
elements: List[PictureItem] = []
for el in element_batch:
assert isinstance(el, PictureItem)
elements.append(el)
img = el.get_image(doc)
assert img is not None
images.append(img)

outputs = self.document_picture_classifier.predict(images)

for element, output in zip(elements, outputs):
element.annotations.append(
PictureClassificationData(
provenance="DocumentPictureClassifier",
predicted_classes=[
PictureClassificationClass(
class_name=pred[0],
confidence=pred[1],
)
for pred in output
],
)
)

yield element
11 changes: 11 additions & 0 deletions docling/pipeline/standard_pdf_pipeline.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,10 @@
)
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.document_picture_classifier import (
DocumentPictureClassifier,
DocumentPictureClassifierOptions,
)
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
@@ -104,6 +108,13 @@ def __init__(self, pipeline_options: PdfPipelineOptions):
),
accelerator_options=pipeline_options.accelerator_options,
),
# Document Picture Classifier
DocumentPictureClassifier(
enabled=pipeline_options.do_picture_classification,
artifacts_path=pipeline_options.artifacts_path,
options=DocumentPictureClassifierOptions(),
accelerator_options=pipeline_options.accelerator_options,
),
]

if (

0 comments on commit 8ecb810

Please sign in to comment.