-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Michele Dolfi <[email protected]>
- Loading branch information
1 parent
1f4b224
commit a122a7b
Showing
8 changed files
with
561 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import base64 | ||
import datetime | ||
import io | ||
import logging | ||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple | ||
|
||
import httpx | ||
from PIL import Image | ||
from pydantic import AnyUrl, BaseModel, ConfigDict | ||
|
||
from docling.datamodel.base_models import Cluster, FigureDescriptionData | ||
from docling.models.img_understand_base_model import ( | ||
ImgUnderstandBaseModel, | ||
ImgUnderstandOptions, | ||
) | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ImgUnderstandApiOptions(ImgUnderstandOptions):
    """Options for an image-understanding model backed by a remote
    chat-completion HTTP endpoint."""

    kind: Literal["api"] = "api"

    # Endpoint location and per-request HTTP settings.
    url: AnyUrl
    headers: Dict[str, str]
    params: Dict[str, Any]
    timeout: float = 20

    # Prompt sent alongside each image, and the provenance string recorded
    # on every generated description.
    llm_prompt: str
    provenance: str
|
||
|
||
class ChatMessage(BaseModel):
    """One message of a chat-completion exchange (role + text content)."""

    role: str
    content: str
|
||
|
||
class ResponseChoice(BaseModel):
    """A single completion choice returned by the API."""

    index: int
    message: ChatMessage
    finish_reason: str
|
||
|
||
class ResponseUsage(BaseModel):
    """Token accounting reported by the API for one request."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||
|
||
class ApiResponse(BaseModel):
    """Top-level chat-completion response payload.

    Both OpenAI-style and watsonx-style responses are accepted: the two
    services name the model field differently, so both are optional.
    """

    # "model_id" would otherwise clash with pydantic's protected "model_"
    # namespace; disable the protection for this model.
    model_config = ConfigDict(
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by OpenAI
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage
|
||
|
||
class ImgUnderstandApiModel(ImgUnderstandBaseModel):
    """Image-understanding model that asks a remote chat-completion API
    (e.g. OpenAI or watsonx) to describe each figure image.
    """

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        # Narrow the attribute type for static checkers; the value itself
        # is assigned by the base-class constructor.
        self.options: ImgUnderstandApiOptions

    def _build_payload(self, image: Image.Image) -> Dict[str, Any]:
        """Build the chat-completion request body for a single image.

        The image is PNG-encoded and embedded as a base64 data URL, paired
        with the configured prompt; any extra options.params (model name,
        max_tokens, ...) are merged into the top level of the payload.
        """
        img_io = io.BytesIO()
        image.save(img_io, "PNG")
        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": self.options.llm_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"
                        },
                    },
                ],
            }
        ]
        return {
            "messages": messages,
            **self.options.params,
        }

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Call the API once per image and return one description per pair.

        Returns results in input order. Raises httpx.HTTPStatusError when
        the API replies with a non-success status code. When the model is
        disabled, returns empty descriptions without any network call.
        """
        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        results = []
        for _cluster, image in batch:
            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=self._build_payload(image),
                timeout=self.options.timeout,
            )
            if not r.is_success:
                # Fixed typo in the original message ("Reponse" -> "Response")
                # and switched to lazy %-style logging args.
                _log.error("Error calling the API. Response was %s", r.text)
                r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info("Generated description: %s", generated_text)

        return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import logging | ||
import time | ||
from typing import Iterable, List, Literal, Tuple | ||
|
||
from PIL import Image | ||
from pydantic import BaseModel | ||
|
||
from docling.datamodel.base_models import ( | ||
Cluster, | ||
FigureData, | ||
FigureDescriptionData, | ||
FigureElement, | ||
FigurePrediction, | ||
Page, | ||
) | ||
|
||
_log = logging.getLogger(__name__) | ||
|
||
|
||
class ImgUnderstandOptions(BaseModel):
    """Configuration shared by every image-understanding model."""

    kind: str
    # How many cropped figures are sent to the model per call.
    batch_size: int = 8
    # Scale factor applied to the page image before cropping figures.
    scale: float = 2

    # if the relative area of the image with respect to the whole image page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    min_area: float = 0.05
|
||
|
||
class ImgUnderstandBaseModel:
    """Base pipeline model that generates textual descriptions for the
    "Picture" clusters of each page.

    Subclasses implement _annotate_image_batch(); this class crops the
    figures out of the page image, batches them, and merges the resulting
    descriptions back into the page's figure predictions.
    """

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        """Return one description per (cluster, image) pair, in order."""
        # Fixed: the original did `raise NotImplemented()` — NotImplemented
        # is a constant, not callable, so that line raised a confusing
        # TypeError instead of NotImplementedError.
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        """Annotate one batch of cropped figures and merge the results
        into figures_prediction, logging the per-image processing time.
        """
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The returned annotations is not matching the input size"
        elapsed = time.time() - start_time
        _log.info(
            "Batch of %d images processed in %.1f seconds. "
            "Time per image is %.3f seconds.",
            len(results_batch),
            elapsed,
            elapsed / len(results_batch),
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    # Fixed: the original passed the misspelled kwarg
                    # "desciption"; every read site uses .data.description,
                    # so the generated text was presumably being dropped.
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    f"Conflicting predictions. "
                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
                    f"was already predicting an image description. The new prediction will be skipped."
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Annotate the picture clusters of each page, yielding pages in
        order. When the model is disabled, pages pass through unchanged.

        Raises RuntimeError when a previous model counted a different
        number of figures on a page, or when the page image is missing.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters and precompute their crop boxes
            # in top-left-origin coordinates at the configured scale.
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append((cluster, crop_bbox.as_tuple()))

            if not in_clusters:
                yield page
                continue

            # Guard against disagreeing models: another model may already
            # have recorded a figure count for this page.
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            cluster_figure_batch = []
            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            for cluster, figure_bbox in in_clusters:
                figure = page_image.crop(figure_bbox)
                cluster_figure_batch.append((cluster, figure))

                # if enough figures then flush
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # final flush
            if len(cluster_figure_batch) > 0:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )
                cluster_figure_batch = []

            yield page
Oops, something went wrong.