Skip to content

Commit

Permalink
introduce img understand pipeline
Browse files Browse the repository at this point in the history
Signed-off-by: Michele Dolfi <[email protected]>
  • Loading branch information
dolfim-ibm committed Sep 22, 2024
1 parent 1f4b224 commit a122a7b
Show file tree
Hide file tree
Showing 8 changed files with 561 additions and 11 deletions.
23 changes: 16 additions & 7 deletions docling/datamodel/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,27 @@ class TableStructurePrediction(BaseModel):
class TextElement(BasePageElement): ...


class FigureClassificationData(BaseModel):
    """Predicted class label for a figure, with its confidence score."""

    # Identifier of the model/heuristic that produced this classification.
    provenance: str
    predicted_class: str
    confidence: float


class FigureDescriptionData(BaseModel):
    """Natural-language description generated for a figure."""

    text: str
    # Identifier of the model that produced the description; empty if unknown.
    provenance: str = ""


class FigureData(BaseModel):
    """Aggregated model outputs attached to a single figure.

    Either field stays ``None`` until the corresponding model has run.
    """

    # Removed the stray `pass` placeholder that was left above the real fields.
    classification: Optional[FigureClassificationData] = None
    description: Optional[FigureDescriptionData] = None


class FigureElement(BasePageElement):
    """A figure detected on a page together with its model predictions.

    The legacy flat fields (provenance/predicted_class/confidence) were
    superseded by the structured ``FigureData`` container; keeping both left
    ``data`` declared twice with conflicting types.
    """

    # NOTE: pydantic deep-copies field defaults per instance, so a shared
    # mutable default is not an aliasing hazard here.
    data: FigureData = FigureData()


class FigurePrediction(BaseModel):
    """Page-level figure predictions: total count and cluster-id -> element map.

    Renamed from ``FigureClassificationPrediction``; the duplicate class header
    left over from the rename is removed.
    """

    figure_count: int = 0
    # Keyed by layout cluster id; pydantic copies the default dict per instance.
    figure_map: Dict[int, FigureElement] = {}

Expand All @@ -248,7 +257,7 @@ class EquationPrediction(BaseModel):
class PagePredictions(BaseModel):
    """Container for all per-page model predictions; each stays ``None`` until its model runs."""

    layout: Optional[LayoutPrediction] = None
    tablestructure: Optional[TableStructurePrediction] = None
    # Replaces the removed `figures_classification` field (old
    # FigureClassificationPrediction type) after the rename.
    figures_prediction: Optional[FigurePrediction] = None
    equations_prediction: Optional[EquationPrediction] = None


Expand Down
123 changes: 123 additions & 0 deletions docling/models/img_understand_api_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import base64
import datetime
import io
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import httpx
from PIL import Image
from pydantic import AnyUrl, BaseModel, ConfigDict

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
ImgUnderstandBaseModel,
ImgUnderstandOptions,
)

_log = logging.getLogger(__name__)


class ImgUnderstandApiOptions(ImgUnderstandOptions):
    """Options for describing figures via a remote chat-completions API."""

    kind: Literal["api"] = "api"

    # Endpoint plus per-request HTTP headers and extra JSON body parameters.
    url: AnyUrl
    headers: Dict[str, str]
    params: Dict[str, Any]
    # Request timeout in seconds, passed to httpx.post.
    timeout: float = 20

    # Prompt sent alongside each image; provenance is stored in the results.
    llm_prompt: str
    provenance: str


class ChatMessage(BaseModel):
    """One message of a chat-completion response (role + text content)."""

    role: str
    content: str


class ResponseChoice(BaseModel):
    """One candidate completion returned by the API."""

    index: int
    message: ChatMessage
    finish_reason: str


class ResponseUsage(BaseModel):
    """Token accounting reported by the API for one request."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ApiResponse(BaseModel):
    """Subset of an OpenAI-/watsonx-style chat-completions response body."""

    # Allow the `model*` field names below without pydantic's
    # protected-namespace warnings.
    model_config = ConfigDict(
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage


class ImgUnderstandApiModel(ImgUnderstandBaseModel):
    """Image-understanding model that delegates figure description to a remote chat API."""

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        # Narrow the attribute type for checkers; value is set by the base class.
        self.options: ImgUnderstandApiOptions

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Describe each figure image in *batch* with one API call per image.

        Returns one FigureDescriptionData per input item, in order. When the
        model is disabled, empty descriptions are returned instead.

        Raises:
            httpx.HTTPStatusError: if the API responds with a non-success status.
        """
        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        results = []
        for cluster, image in batch:
            # Encode the figure crop as a base64 PNG data URL for the payload.
            img_io = io.BytesIO()
            image.save(img_io, "PNG")
            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.options.llm_prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]

            payload = {
                "messages": messages,
                **self.options.params,
            }

            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=payload,
                timeout=self.options.timeout,
            )
            if not r.is_success:
                # Log the body before raising so the API error detail is kept.
                # Fix: message typo "Reponse" -> "Response"; lazy log args.
                _log.error("Error calling the API. Response was %s", r.text)
                r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info("Generated description: %s", generated_text)

        return results
145 changes: 145 additions & 0 deletions docling/models/img_understand_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import logging
import time
from typing import Iterable, List, Literal, Tuple

from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import (
Cluster,
FigureData,
FigureDescriptionData,
FigureElement,
FigurePrediction,
Page,
)

_log = logging.getLogger(__name__)


class ImgUnderstandOptions(BaseModel):
    """Configuration shared by all image-understanding models."""

    # Discriminator for the concrete options subclass (e.g. "api").
    kind: str
    # Number of figure crops accumulated before _annotate_image_batch() is called.
    batch_size: int = 8
    # Upscaling factor applied to the page image before cropping figures.
    scale: float = 2

    # if the relative area of the image with respect to the whole image page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    min_area: float = 0.05


class ImgUnderstandBaseModel:
    """Pipeline stage that crops "Picture" clusters from pages and annotates them in batches.

    Subclasses implement :meth:`_annotate_image_batch` to produce one
    FigureDescriptionData per (cluster, image) pair.
    """

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Return one description per batch item; must be overridden.

        Fix: was ``raise NotImplemented()`` — ``NotImplemented`` is a sentinel
        value, not an exception, so calling it raised a confusing TypeError.
        """
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        """Annotate one batch of figure crops and merge results into *figures_prediction*."""
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The returned annotations is not matching the input size"
        end_time = time.time()
        # Lazy %-args so formatting is skipped when INFO is disabled.
        _log.info(
            "Batch of %d images processed in %.1f seconds. Time per image is %.3f seconds.",
            len(results_batch),
            end_time - start_time,
            (end_time - start_time) / len(results_batch),
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    # Fix: keyword was misspelled `desciption`, so the new
                    # description never reached the FigureData model.
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    f"Conflicting predictions. "
                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
                    f"was already predicting an image description. The new prediction will be skipped."
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Yield pages from *page_batch*, annotating their "Picture" clusters in place."""
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append(
                    (
                        cluster,
                        crop_bbox.as_tuple(),
                    )
                )

            if not in_clusters:
                yield page
                continue

            # save classifications using proper object
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            cluster_figure_batch = []
            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            for cluster, figure_bbox in in_clusters:
                figure = page_image.crop(figure_bbox)
                cluster_figure_batch.append((cluster, figure))

                # if enough figures then flush
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # final flush
            if len(cluster_figure_batch) > 0:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )
                cluster_figure_batch = []

            yield page
Loading

0 comments on commit a122a7b

Please sign in to comment.