From a122a7be4cae5a17b39803316b1f17824add4ece Mon Sep 17 00:00:00 2001 From: Michele Dolfi Date: Sun, 22 Sep 2024 20:24:38 +0200 Subject: [PATCH] introduce img understand pipeline Signed-off-by: Michele Dolfi --- docling/datamodel/base_models.py | 23 +++- docling/models/img_understand_api_model.py | 123 +++++++++++++++++ docling/models/img_understand_base_model.py | 145 ++++++++++++++++++++ docling/models/img_understand_vllm_model.py | 87 ++++++++++++ docling/models/page_assemble_model.py | 7 +- docling/pipeline/img_understand_pipeline.py | 53 +++++++ examples/img_understand_pipeline.py | 132 ++++++++++++++++++ pyproject.toml | 2 + 8 files changed, 561 insertions(+), 11 deletions(-) create mode 100644 docling/models/img_understand_api_model.py create mode 100644 docling/models/img_understand_base_model.py create mode 100644 docling/models/img_understand_vllm_model.py create mode 100644 docling/pipeline/img_understand_pipeline.py create mode 100644 examples/img_understand_pipeline.py diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index e9c51d69..1e8f7981 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -224,18 +224,27 @@ class TableStructurePrediction(BaseModel): class TextElement(BasePageElement): ... +class FigureClassificationData(BaseModel): + provenance: str + predicted_class: str + confidence: float + + +class FigureDescriptionData(BaseModel): + text: str + provenance: str = "" + + class FigureData(BaseModel): - pass + classification: Optional[FigureClassificationData] = None + description: Optional[FigureDescriptionData] = None class FigureElement(BasePageElement): - data: Optional[FigureData] = None - provenance: Optional[str] = None - predicted_class: Optional[str] = None - confidence: Optional[float] = None + data: FigureData = FigureData() -class FigureClassificationPrediction(BaseModel): +class FigurePrediction(BaseModel): figure_count: int = 0 figure_map: Dict[int, FigureElement] = {} @@ -248,7 +257,7 @@ class EquationPrediction(BaseModel): class PagePredictions(BaseModel): layout: Optional[LayoutPrediction] = None tablestructure: Optional[TableStructurePrediction] = None - figures_classification: Optional[FigureClassificationPrediction] = None + figures_prediction: Optional[FigurePrediction] = None equations_prediction: Optional[EquationPrediction] = None diff --git a/docling/models/img_understand_api_model.py b/docling/models/img_understand_api_model.py new file mode 100644 index 00000000..bc62b206 --- /dev/null +++ b/docling/models/img_understand_api_model.py @@ -0,0 +1,123 @@ +import base64 +import datetime +import io +import logging +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple + +import httpx +from PIL import Image +from pydantic import AnyUrl, BaseModel, ConfigDict + +from docling.datamodel.base_models import Cluster, FigureDescriptionData +from docling.models.img_understand_base_model import ( + ImgUnderstandBaseModel, + ImgUnderstandOptions, +) + +_log = logging.getLogger(__name__) + + +class ImgUnderstandApiOptions(ImgUnderstandOptions): + kind: Literal["api"] = "api" + + url: AnyUrl + headers: Dict[str, str] + params: Dict[str, Any] + timeout: float = 20 + + llm_prompt: str + provenance: str + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: str + + +class ResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class 
ApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model_id: Optional[str] = None  # returned by watsonx
+    model: Optional[str] = None  # returned by openai
+    choices: List[ResponseChoice]
+    created: int
+    usage: ResponseUsage
+
+
+class ImgUnderstandApiModel(ImgUnderstandBaseModel):
+
+    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: ImgUnderstandApiOptions
+
+    def _annotate_image_batch(
+        self, batch: Iterable[Tuple[Cluster, Image.Image]]
+    ) -> List[FigureDescriptionData]:
+
+        if not self.enabled:
+            return [FigureDescriptionData() for _ in batch]
+
+        results = []
+        for cluster, image in batch:
+            img_io = io.BytesIO()
+            image.save(img_io, "PNG")
+            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": self.options.llm_prompt,
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{image_base64}"
+                            },
+                        },
+                    ],
+                }
+            ]
+
+            payload = {
+                "messages": messages,
+                **self.options.params,
+            }
+
+            r = httpx.post(
+                str(self.options.url),
+                headers=self.options.headers,
+                json=payload,
+                timeout=self.options.timeout,
+            )
+            if not r.is_success:
+                _log.error(f"Error calling the API. Response was {r.text}")
+            r.raise_for_status()
+
+            api_resp = ApiResponse.model_validate_json(r.text)
+            generated_text = api_resp.choices[0].message.content.strip()
+            results.append(
+                FigureDescriptionData(
+                    text=generated_text, provenance=self.options.provenance
+                )
+            )
+            _log.info(f"Generated description: {generated_text}")
+
+        return results
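The API model above posts an OpenAI-style chat-completions request with the figure embedded as a base64 `data:` URL and parses the answer back through `ApiResponse`. As orientation for reviewers, a minimal sketch of wiring these options to a generic OpenAI-compatible endpoint could look as follows; the endpoint URL, bearer token, and model name are illustrative placeholders, not values required by this patch.

```python
# Sketch only: configure ImgUnderstandApiModel against a hypothetical
# OpenAI-compatible chat endpoint. URL, token, and model name are placeholders.
from docling.models.img_understand_api_model import (
    ImgUnderstandApiModel,
    ImgUnderstandApiOptions,
)

options = ImgUnderstandApiOptions(
    url="https://example.com/v1/chat/completions",  # placeholder endpoint
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
    params={"model": "my-vision-model", "max_tokens": 64},  # merged into the payload
    llm_prompt="Describe this figure in three sentences.",
    provenance="my-vision-model",
    timeout=30,
)
img_understand_model = ImgUnderstandApiModel(enabled=True, options=options)
```

Since `params` is spread directly into the request body, any extra field accepted by the target endpoint (model id, decoding parameters, project id, ...) can be passed through without changing the model code.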
diff --git a/docling/models/img_understand_base_model.py b/docling/models/img_understand_base_model.py
new file mode 100644
index 00000000..7abb6e55
--- /dev/null
+++ b/docling/models/img_understand_base_model.py
@@ -0,0 +1,145 @@
+import logging
+import time
+from typing import Iterable, List, Literal, Tuple
+
+from PIL import Image
+from pydantic import BaseModel
+
+from docling.datamodel.base_models import (
+    Cluster,
+    FigureData,
+    FigureDescriptionData,
+    FigureElement,
+    FigurePrediction,
+    Page,
+)
+
+_log = logging.getLogger(__name__)
+
+
+class ImgUnderstandOptions(BaseModel):
+    kind: str
+    batch_size: int = 8
+    scale: float = 2
+
+    # if the relative area of the image with respect to the whole page
+    # is larger than this threshold it will be processed, otherwise not.
+    # TODO: implement the skip logic
+    min_area: float = 0.05
+
+
+class ImgUnderstandBaseModel:
+
+    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
+        self.enabled = enabled
+        self.options = options
+
+    def _annotate_image_batch(
+        self, batch: Iterable[Tuple[Cluster, Image.Image]]
+    ) -> List[FigureDescriptionData]:
+        raise NotImplementedError()
+
+    def _flush_merge(
+        self,
+        page: Page,
+        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
+        figures_prediction: FigurePrediction,
+    ):
+        start_time = time.time()
+        results_batch = self._annotate_image_batch(cluster_figure_batch)
+        assert len(results_batch) == len(
+            cluster_figure_batch
+        ), "The number of returned annotations does not match the input size"
+        end_time = time.time()
+        _log.info(
+            f"Batch of {len(results_batch)} images processed in {end_time-start_time:.1f} seconds. Time per image is {(end_time-start_time) / len(results_batch):.3f} seconds."
+        )
+
+        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
+            if cluster.id not in figures_prediction.figure_map:
+                figures_prediction.figure_map[cluster.id] = FigureElement(
+                    label=cluster.label,
+                    id=cluster.id,
+                    data=FigureData(description=desc_data),
+                    cluster=cluster,
+                    page_no=page.page_no,
+                )
+            elif figures_prediction.figure_map[cluster.id].data.description is None:
+                figures_prediction.figure_map[cluster.id].data.description = desc_data
+            else:
+                _log.warning(
+                    f"Conflicting predictions. "
+                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
+                    f"already provided an image description. The new prediction will be skipped."
+                )
+
+    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
+
+        if not self.enabled:
+            yield from page_batch
+            return
+
+        for page in page_batch:
+
+            # This model could be the first one initializing figures_prediction
+            if page.predictions.figures_prediction is None:
+                page.predictions.figures_prediction = FigurePrediction()
+
+            # Select the picture clusters
+            in_clusters = []
+            for cluster in page.predictions.layout.clusters:
+                if cluster.label != "Picture":
+                    continue
+
+                crop_bbox = cluster.bbox.scaled(
+                    scale=self.options.scale
+                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
+                in_clusters.append(
+                    (
+                        cluster,
+                        crop_bbox.as_tuple(),
+                    )
+                )
+
+            if not len(in_clusters):
+                yield page
+                continue
+
+            # check that the figure count stays consistent across models
+            if (
+                page.predictions.figures_prediction.figure_count > 0
+                and page.predictions.figures_prediction.figure_count != len(in_clusters)
+            ):
+                raise RuntimeError(
+                    "Different models predicted a different number of figures."
+                )
+            page.predictions.figures_prediction.figure_count = len(in_clusters)
+
+            cluster_figure_batch = []
+            page_image = page.get_image(scale=self.options.scale)
+            if page_image is None:
+                raise RuntimeError("The page image cannot be generated.")
+
+            for cluster, figure_bbox in in_clusters:
+                figure = page_image.crop(figure_bbox)
+                cluster_figure_batch.append((cluster, figure))
+
+                # flush as soon as the batch is full
+                if len(cluster_figure_batch) == self.options.batch_size:
+                    self._flush_merge(
+                        page=page,
+                        cluster_figure_batch=cluster_figure_batch,
+                        figures_prediction=page.predictions.figures_prediction,
+                    )
+                    cluster_figure_batch = []
+
+            # final flush
+            if len(cluster_figure_batch) > 0:
+                self._flush_merge(
+                    page=page,
+                    cluster_figure_batch=cluster_figure_batch,
+                    figures_prediction=page.predictions.figures_prediction,
+                )
+                cluster_figure_batch = []
+
+            yield page
diff --git a/docling/models/img_understand_vllm_model.py b/docling/models/img_understand_vllm_model.py
new file mode 100644
index 00000000..3f9103de
--- /dev/null
+++ b/docling/models/img_understand_vllm_model.py
@@ -0,0 +1,87 @@
+import json
+import logging
+from typing import Any, Dict, Iterable, List, Literal, Tuple
+
+from PIL import Image
+
+from docling.datamodel.base_models import Cluster, FigureDescriptionData
+from docling.models.img_understand_base_model import (
+    ImgUnderstandBaseModel,
+    ImgUnderstandOptions,
+)
+from docling.utils.utils import create_hash
+
+_log = logging.getLogger(__name__)
+
+
+class ImgUnderstandVllmOptions(ImgUnderstandOptions):
+    kind: Literal["vllm"] = "vllm"
+
+    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html
+
+    # Parameters for LLaVA-1.6/LLaVA-NeXT
+    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
+    
llm_prompt: str = "[INST] \nDescribe the image in details. [/INST]" + llm_extra: Dict[str, Any] = dict(max_model_len=8192) + + # Parameters for Phi-3-Vision + # llm_name: str = "microsoft/Phi-3-vision-128k-instruct" + # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n" + # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True) + + sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42) + + +class ImgUnderstandVllmModel(ImgUnderstandBaseModel): + + def __init__(self, enabled: bool, options: ImgUnderstandVllmOptions): + super().__init__(enabled=enabled, options=options) + self.options: ImgUnderstandVllmOptions + + if self.enabled: + try: + from vllm import LLM, SamplingParams + except ImportError: + raise ImportError( + "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`." + ) + + self.sampling_params = SamplingParams(**self.options.sampling_params) + self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra) + + # Generate a stable hash from the extra parameters + params_hash = create_hash( + json.dumps(self.options.llm_extra, sort_keys=True) + + json.dumps(self.options.sampling_params, sort_keys=True) + ) + self.provenance = f"{self.options.llm_name}-{params_hash[:8]}" + + def _annotate_image_batch( + self, batch: Iterable[Tuple[Cluster, Image.Image]] + ) -> List[FigureDescriptionData]: + + if not self.enabled: + return [FigureDescriptionData() for _ in batch] + + from vllm import RequestOutput + + inputs = [ + { + "prompt": self.options.llm_prompt, + "multi_modal_data": {"image": im}, + } + for _, im in batch + ] + outputs: List[RequestOutput] = self.llm.generate( + inputs, sampling_params=self.sampling_params + ) + + results = [] + for o in outputs: + generated_text = o.outputs[0].text + results.append( + FigureDescriptionData(text=generated_text, provenance=self.provenance) + ) + _log.info(f"Generated description: {generated_text}") + + return results diff --git a/docling/models/page_assemble_model.py b/docling/models/page_assemble_model.py index 2b9db544..27b1afa0 100644 --- a/docling/models/page_assemble_model.py +++ b/docling/models/page_assemble_model.py @@ -98,18 +98,17 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: body.append(tbl) elif cluster.label == LayoutModel.FIGURE_LABEL: fig = None - if page.predictions.figures_classification: - fig = page.predictions.figures_classification.figure_map.get( + if page.predictions.figures_prediction: + fig = page.predictions.figures_prediction.figure_map.get( cluster.id, None ) if ( not fig - ): # fallback: add figure without classification, if it isn't present + ): # fallback: add figure with default data, if it isn't present fig = FigureElement( label=cluster.label, id=cluster.id, text="", - data=None, cluster=cluster, page_no=page.page_no, ) diff --git a/docling/pipeline/img_understand_pipeline.py b/docling/pipeline/img_understand_pipeline.py new file mode 100644 index 00000000..8db581bd --- /dev/null +++ b/docling/pipeline/img_understand_pipeline.py @@ -0,0 +1,53 @@ +from pathlib import Path +from typing import Union + +from pydantic import BaseModel, Field + +from docling.datamodel.base_models import PipelineOptions +from docling.models.img_understand_api_model import ( + ImgUnderstandApiModel, + ImgUnderstandApiOptions, +) +from docling.models.img_understand_vllm_model import ( + ImgUnderstandVllmModel, + ImgUnderstandVllmOptions, +) +from docling.pipeline.standard_model_pipeline 
import StandardModelPipeline
+
+
+class ImgUnderstandPipelineOptions(PipelineOptions):
+    do_img_understand: bool = True
+    img_understand_options: Union[ImgUnderstandApiOptions, ImgUnderstandVllmOptions] = (
+        Field(ImgUnderstandVllmOptions(), discriminator="kind")
+    )
+
+
+class ImgUnderstandPipeline(StandardModelPipeline):
+
+    def __init__(
+        self, artifacts_path: Path, pipeline_options: ImgUnderstandPipelineOptions
+    ):
+        super().__init__(artifacts_path, pipeline_options)
+
+        if isinstance(
+            pipeline_options.img_understand_options, ImgUnderstandVllmOptions
+        ):
+            self.model_pipe.append(
+                ImgUnderstandVllmModel(
+                    enabled=pipeline_options.do_img_understand,
+                    options=pipeline_options.img_understand_options,
+                )
+            )
+        elif isinstance(
+            pipeline_options.img_understand_options, ImgUnderstandApiOptions
+        ):
+            self.model_pipe.append(
+                ImgUnderstandApiModel(
+                    enabled=pipeline_options.do_img_understand,
+                    options=pipeline_options.img_understand_options,
+                )
+            )
+        else:
+            raise RuntimeError(
+                f"The specified image understanding kind is not supported: {pipeline_options.img_understand_options.kind}."
+            )
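Once a pipeline with one of these models has run, the generated text ends up in `FigureElement.data.description` inside `page.predictions.figures_prediction.figure_map`, as populated by `_flush_merge` above. A small, hypothetical post-processing helper (assuming a `page` that already went through the pipeline) shows where downstream code would pick the annotations up:

```python
# Sketch only: read image descriptions back from a converted page.
# `page` is assumed to be a docling Page that was processed by the
# image-understanding pipeline introduced in this patch.
from docling.datamodel.base_models import Page


def iter_figure_descriptions(page: Page):
    figures = page.predictions.figures_prediction
    if figures is None:
        return
    for cluster_id, figure in figures.figure_map.items():
        desc = figure.data.description
        if desc is not None:
            # `provenance` records which model produced the text
            yield cluster_id, desc.provenance, desc.text
```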
diff --git a/examples/img_understand_pipeline.py b/examples/img_understand_pipeline.py
new file mode 100644
index 00000000..29cfc96a
--- /dev/null
+++ b/examples/img_understand_pipeline.py
@@ -0,0 +1,132 @@
+import logging
+import os
+import time
+from pathlib import Path
+from typing import Iterable
+
+import httpx
+from dotenv import load_dotenv
+
+from docling.datamodel.base_models import ConversionStatus
+from docling.datamodel.document import ConversionResult, DocumentConversionInput
+from docling.document_converter import DocumentConverter
+from docling.pipeline.img_understand_pipeline import (
+    ImgUnderstandApiOptions,
+    ImgUnderstandPipeline,
+    ImgUnderstandPipelineOptions,
+    ImgUnderstandVllmOptions,
+)
+
+_log = logging.getLogger(__name__)
+
+
+def export_documents(
+    conv_results: Iterable[ConversionResult],
+    output_dir: Path,
+):
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    success_count = 0
+    failure_count = 0
+
+    for conv_res in conv_results:
+        if conv_res.status == ConversionStatus.SUCCESS:
+            success_count += 1
+            doc_filename = conv_res.input.file.stem
+
+            # # Export Deep Search document JSON format:
+            # with (output_dir / f"{doc_filename}.json").open("w") as fp:
+            #     fp.write(json.dumps(conv_res.render_as_dict()))
+
+            # # Export Text format:
+            # with (output_dir / f"{doc_filename}.txt").open("w") as fp:
+            #     fp.write(conv_res.render_as_text())
+
+            # # Export Markdown format:
+            # with (output_dir / f"{doc_filename}.md").open("w") as fp:
+            #     fp.write(conv_res.render_as_markdown())
+
+            # # Export Document Tags format:
+            # with (output_dir / f"{doc_filename}.doctags").open("w") as fp:
+            #     fp.write(conv_res.render_as_doctags())
+
+        else:
+            _log.info(f"Document {conv_res.input.file} failed to convert.")
+            failure_count += 1
+
+    _log.info(
+        f"Processed {success_count + failure_count} docs, of which {failure_count} failed"
+    )
+
+    return success_count, failure_count
+
+
+def _get_iam_access_token(api_key: str) -> str:
+    res = httpx.post(
+        url="https://iam.cloud.ibm.com/identity/token",
+        headers={
+            "Content-Type": "application/x-www-form-urlencoded",
+        },
+        data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={api_key}",
+    )
+    res.raise_for_status()
+    api_out = res.json()
+    print(f"{api_out=}")
+    return api_out["access_token"]
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    input_doc_paths = [
+        Path("./tests/data/2206.01062.pdf"),
+    ]
+
+    load_dotenv()
+    api_key = os.environ.get("WX_API_KEY")
+    project_id = os.environ.get("WX_PROJECT_ID")
+
+    doc_converter = DocumentConverter(
+        pipeline_cls=ImgUnderstandPipeline,
+        # TODO: make DocumentConverter provide the correct default value
+        # for pipeline_options, given the pipeline_cls
+        pipeline_options=ImgUnderstandPipelineOptions(
+            img_understand_options=ImgUnderstandApiOptions(
+                url="https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2023-05-29",
+                headers={
+                    "Authorization": "Bearer " + _get_iam_access_token(api_key=api_key),
+                },
+                params=dict(
+                    model_id="meta-llama/llama3-llava-next-8b-hf",
+                    project_id=project_id,
+                    max_tokens=512,
+                    seed=42,
+                ),
+                llm_prompt="Describe this figure in three sentences.",
+                provenance="llama3-llava-next-8b-hf",
+            )
+        ),
+    )
+
+    # Define input files
+    input = DocumentConversionInput.from_paths(input_doc_paths)
+
+    start_time = time.time()
+
+    conv_results = doc_converter.convert(input)
+    success_count, failure_count = export_documents(
+        conv_results, output_dir=Path("./scratch")
+    )
+
+    end_time = time.time() - start_time
+
+    _log.info(f"All documents were converted in {end_time:.2f} seconds.")
+
+    if failure_count > 0:
+        raise RuntimeError(
+            f"The example failed to convert {failure_count} of {len(input_doc_paths)} documents."
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index cd20fb64..35537d26 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ pyarrow = "^16.1.0"
 #########
 # extras:
 #########
+# vllm = { version = "^0.5.0", optional = true, markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'" }
 python-dotenv = { version = "^1.0.1", optional = true }
 llama-index-embeddings-huggingface = { version = "^0.3.1", optional = true }
 llama-index-llms-huggingface-api = { version = "^0.2.0", optional = true }
@@ -84,6 +85,7 @@ nbqa = "^1.9.0"
 datasets = "^2.21.0"
 
 [tool.poetry.extras]
+# vllm = ["vllm"]
 examples = [
     "python-dotenv",
     # LlamaIndex examples:
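The example above exercises only the `api` backend against watsonx.ai. For completeness, the local vLLM variant of the same conversion is just a different options object; a minimal sketch is shown below, assuming the (still commented-out) `vllm` extra from `pyproject.toml` is installed, which its markers exclude on Intel macOS.

```python
# Sketch only: the same conversion driven by the local vLLM backend.
# Assumes vLLM is installed and a GPU is available; the options shown are
# the defaults introduced by this patch (LLaVA-1.6 Mistral 7B).
from pathlib import Path

from docling.datamodel.document import DocumentConversionInput
from docling.document_converter import DocumentConverter
from docling.pipeline.img_understand_pipeline import (
    ImgUnderstandPipeline,
    ImgUnderstandPipelineOptions,
    ImgUnderstandVllmOptions,
)

doc_converter = DocumentConverter(
    pipeline_cls=ImgUnderstandPipeline,
    pipeline_options=ImgUnderstandPipelineOptions(
        img_understand_options=ImgUnderstandVllmOptions()
    ),
)

input = DocumentConversionInput.from_paths([Path("./tests/data/2206.01062.pdf")])
conv_results = doc_converter.convert(input)
for conv_res in conv_results:
    print(conv_res.status)
```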