feat!: New API for models initialization with accelerators parameters. Use HF implementation for LayoutPredictor. Migrate models to safetensors format. (#50)

Signed-off-by: Nikos Livathinos <[email protected]>
Co-authored-by: Christoph Auer <[email protected]>
nikos-livathinos and cau-git authored Dec 11, 2024
1 parent 33e0216 commit 04295b2
Showing 9 changed files with 1,426 additions and 276 deletions.
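
In short: the breaking change threads an explicit accelerator choice through model initialization and loads the RT-DETR layout model from safetensors via Hugging Face Transformers. A minimal sketch of the new call pattern, assembled from the diffs below (not an official snippet; paths and values are taken from the demo code in this commit):

import os

from huggingface_hub import snapshot_download

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

# Fetch the safetensors weights and HF configs (new repo layout as of this commit)
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
artifact_path = os.path.join(download_path, "model_artifacts/layout")

# Old API: LayoutPredictor(artifact_path, num_threads=num_threads)
# New API: the device is passed explicitly ("cpu", "cuda", or "mps");
# num_threads is only honored when device == "cpu".
predictor = LayoutPredictor(artifact_path, device="cpu", num_threads=4)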
24 changes: 6 additions & 18 deletions .pre-commit-config.yaml
@@ -19,21 +19,9 @@ repos:
entry: poetry lock --check
pass_filenames: false
language: system

# Ready to be enabled soon
# - repo: local
# hooks:
# - id: system
# name: flake8
# entry: poetry run flake8 docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
# - repo: local
# hooks:
# - id: system
# name: MyPy
# entry: poetry run mypy docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
# - id: system
# name: MyPy
# entry: poetry run mypy docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
107 changes: 67 additions & 40 deletions demo/demo_layout_predictor.py
@@ -10,15 +10,52 @@
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor


def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: dict):
img_path = Path(img_fn)

image = img.copy()
draw = ImageDraw.Draw(image)

predictions_filename = f"{prefix}_{img_path.stem}.txt"
predictions_fn = os.path.join(viz_dir, predictions_filename)
with open(predictions_fn, "w") as fd:
for pred in predictions:
bbox = [
round(pred["l"], 2),
round(pred["t"], 2),
round(pred["r"], 2),
round(pred["b"], 2),
]
label = pred["label"]
confidence = round(pred["confidence"], 3)

# Save the predictions in txt file
pred_txt = f"{prefix} {img_fn}: {label} - {bbox} - {confidence}\n"
fd.write(pred_txt)

# Draw the bbox and label
draw.rectangle(bbox, outline="orange")
txt = f"{label}: {confidence}"
draw.text(
(bbox[0], bbox[1]), text=txt, font=ImageFont.load_default(), fill="blue"
)

draw_filename = f"{prefix}_{img_path.name}"
draw_fn = os.path.join(viz_dir, draw_filename)
image.save(draw_fn)


def demo(
logger: logging.Logger,
artifact_path: str,
device: str,
num_threads: int,
img_dir: str,
viz_dir: str,
@@ -30,58 +30,67 @@ def demo(
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
"""
# Create the layout predictor
lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
lpredictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads)

# Predict all test png images
t0 = time.perf_counter()
img_counter = 0
for img_fn in Path(img_dir).rglob("*.png"):
img_counter += 1
logger.info("Predicting '%s'...", img_fn)
start_t = time.time()

with Image.open(img_fn) as image:
# Predict layout
img_t0 = time.perf_counter()
preds = list(lpredictor.predict(image))
dt_ms = 1000 * (time.time() - start_t)
logger.debug("Time elapsed for prediction(ms): %s", dt_ms)

# Draw predictions
out_img = image.copy()
draw = ImageDraw.Draw(out_img)

for i, pred in enumerate(preds):
score = pred["confidence"]
label = pred["label"]
box = [
round(pred["l"]),
round(pred["t"]),
round(pred["r"]),
round(pred["b"]),
]

# Draw bbox and label
draw.rectangle(
box,
outline="red",
)
draw.text(
(box[0], box[1]),
text=str(label),
fill="blue",
)
logger.info("%s: [label|score|bbox] = ['%s' | %s | %s]", i, label, score, box)

save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
out_img.save(save_fn)
logger.info("Saving prediction visualization in: '%s'", save_fn)
img_ms = 1000 * (time.perf_counter() - img_t0)
logger.debug("Prediction(ms): {:.2f}".format(img_ms))

# Save predictions
logger.info("Saving prediction visualization in: '%s'", viz_dir)
save_predictions("ST", viz_dir, img_fn, image, preds)
total_ms = 1000 * (time.perf_counter() - t0)
avg_ms = (total_ms / img_counter) if img_counter > 0 else 0
logger.info(
"For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
img_counter, total_ms, avg_ms
)
)


def main(args):
r""" """
num_threads = int(args.num_threads) if args.num_threads is not None else None
device = args.device.lower()
img_dir = args.img_dir
viz_dir = args.viz_dir

# Initialize logger
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("LayoutPredictor")
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
@@ -96,11 +96,13 @@ def main(args):
Path(viz_dir).mkdir(parents=True, exist_ok=True)

# Download models from HF
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.0.1")
artifact_path = os.path.join(download_path, "model_artifacts/layout/beehive_v0.0.5_pt")
download_path = snapshot_download(
repo_id="ds4sd/docling-models", revision="v2.1.0"
)
artifact_path = os.path.join(download_path, "model_artifacts/layout")

# Test the LayoutPredictor
demo(logger, artifact_path, num_threads, img_dir, viz_dir)
demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)


if __name__ == "__main__":
@@ -109,7 +109,10 @@ def main(args):
"""
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
parser.add_argument(
"-n", "--num_threads", required=False, default=None, help="Number of threads"
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
)
parser.add_argument(
"-n", "--num_threads", required=False, default=4, help="Number of threads"
)
parser.add_argument(
"-i",
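
For reference, a minimal sketch of driving the updated demo helper directly. This assumes the repo root is on sys.path so the demo module is importable, continues from the predictor sketch above, and uses placeholder paths:

import os

from PIL import Image

from demo.demo_layout_predictor import save_predictions  # assumption: importable from repo root

viz_dir = "viz"  # placeholder output directory
os.makedirs(viz_dir, exist_ok=True)

img_fn = "test_data/sample_page.png"  # placeholder image path
with Image.open(img_fn) as image:
    preds = list(predictor.predict(image))  # predictor from the earlier sketch
    # Writes viz/ST_sample_page.txt, one line per prediction formatted as
    # "ST <img_fn>: <label> - [l, t, r, b] - <confidence>", plus
    # viz/ST_sample_page.png with the boxes and labels drawn on a copy.
    save_predictions("ST", viz_dir, img_fn, image, preds)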
140 changes: 74 additions & 66 deletions docling_ibm_models/layoutmodel/layout_predictor.py
@@ -2,6 +2,7 @@
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import os
from collections.abc import Iterable
from typing import Union
@@ -10,38 +11,30 @@
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

MODEL_CHECKPOINT_FN = "model.pt"
DEFAULT_NUM_THREADS = 4
_log = logging.getLogger(__name__)


class LayoutPredictor:
r"""
Document layout prediction using torch
"""
Document layout prediction using safe tensors
"""

def __init__(
self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
self,
artifact_path: str,
device: str = "cpu",
num_threads: int = 4,
):
r"""
"""
Provide the artifact path that contains the LayoutModel file
The number of threads is decided, in the following order, by:
1. The init method parameter `num_threads`, if it is set.
2. The envvar "OMP_NUM_THREADS", if it is set.
3. The default value DEFAULT_NUM_THREADS.
The execution provided is decided, in the following order:
1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
it uses the "CPUExecutionProvider".
3. Otherwise if the "CUDAExecutionProvider" is present, use:
["CUDAExecutionProvider", "CPUExecutionProvider"]:
Parameters
----------
artifact_path: Path for the model torch file.
num_threads: (Optional) Number of threads to run the inference.
use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
device: (Optional) device to run the inference.
num_threads: (Optional) Number of threads to run the inference if device = 'cpu'
Raises
------
@@ -70,40 +63,51 @@ def __init__(
}

# Blacklisted classes
self._black_classes = set(["Form", "Key-Value Region"])
self._black_classes = set() # ["Form", "Key-Value Region"])

# Set basic params
self._threshold = 0.6 # Score threshold
self._threshold = 0.3 # Score threshold
self._image_size = 640
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)

# Model file
self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
if not os.path.isfile(self._torch_fn):
raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))

# Get env vars
if num_threads is None:
num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
# Set number of threads for CPU
self._device = torch.device(device)
self._num_threads = num_threads
if device == "cpu":
torch.set_num_threads(self._num_threads)

# Model file and configurations
self._st_fn = os.path.join(artifact_path, "model.safetensors")
if not os.path.isfile(self._st_fn):
raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))

self.model = torch.jit.load(self._torch_fn)
# Load model and move to device
processor_config = os.path.join(artifact_path, "preprocessor_config.json")
model_config = os.path.join(artifact_path, "config.json")
self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
self._model = RTDetrForObjectDetection.from_pretrained(
artifact_path, config=model_config
).to(self._device)
self._model.eval()

_log.debug("LayoutPredictor settings: {}".format(self.info()))

def info(self) -> dict:
r"""
"""
Get information about the configuration of LayoutPredictor
"""
info = {
"torch_file": self._torch_fn,
"use_cpu_only": self._use_cpu_only,
"safe_tensors_file": self._st_fn,
"device": self._device.type,
"num_threads": self._num_threads,
"image_size": self._image_size,
"threshold": self._threshold,
}
return info

@torch.inference_mode()
def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
r"""
"""
Predict bounding boxes for a given image.
The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
[left, top, right, bottom]
@@ -128,40 +132,44 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
else:
raise TypeError("Not supported input image format")

resize = {"height": self._image_size, "width": self._image_size}
inputs = self._image_processor(
images=page_img,
return_tensors="pt",
size=resize,
).to(self._device)
outputs = self._model(**inputs)
results = self._image_processor.post_process_object_detection(
outputs,
target_sizes=torch.tensor([page_img.size[::-1]]),
threshold=self._threshold,
)

w, h = page_img.size
orig_size = torch.tensor([w, h])[None]

transforms = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
]
)
img = transforms(page_img)[None]
# Predict
with torch.no_grad():
labels, boxes, scores = self.model(img, orig_size)
result = results[0]
for score, label_id, box in zip(
result["scores"], result["labels"], result["boxes"]
):
score = float(score.item())

label_id = int(label_id.item()) + 1 # Advance the label_id
label_str = self._classes_map[label_id]

# Yield output
for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
# Filter out blacklisted classes
label_idx = int(label_idx.item())
score = float(score.item())
label = self._classes_map[label_idx + 1]
if label in self._black_classes:
if label_str in self._black_classes:
continue

# Check against threshold
if score > self._threshold:
l = min(w, max(0, box[0]))
t = min(h, max(0, box[1]))
r = min(w, max(0, box[2]))
b = min(h, max(0, box[3]))
yield {
"l": l,
"t": t,
"r": r,
"b": b,
"label": label,
"confidence": score,
}
bbox_float = [float(b.item()) for b in box]
l = min(w, max(0, bbox_float[0]))
t = min(h, max(0, bbox_float[1]))
r = min(w, max(0, bbox_float[2]))
b = min(h, max(0, bbox_float[3]))
yield {
"l": l,
"t": t,
"r": r,
"b": b,
"label": label_str,
"confidence": score,
}
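
A usage note on the reworked predict(): it now runs RT-DETR through the HF image processor and post-processing, and yields plain dicts. A minimal consumption sketch, continuing from the predictor above (the image path is a placeholder):

from PIL import Image

with Image.open("page.png") as image:  # placeholder path
    for pred in predictor.predict(image):
        # Coordinates are top-left-origin pixels, clipped to the image bounds;
        # detections below the 0.3 score threshold are already filtered out.
        print(
            "{}: {:.3f} [{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(
                pred["label"], pred["confidence"],
                pred["l"], pred["t"], pred["r"], pred["b"],
            )
        )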