diff --git a/demo/demo_code_formula_predictor.py b/demo/demo_code_formula_predictor.py
new file mode 100644
index 0000000..b6b6579
--- /dev/null
+++ b/demo/demo_code_formula_predictor.py
@@ -0,0 +1,111 @@
+#
+# Copyright IBM Corp. 2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+import argparse
+import logging
+import os
+import sys
+import time
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+from PIL import Image
+
+from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor
+
+
+def demo(
+    logger: logging.Logger,
+    artifact_path: str,
+    device: str,
+    num_threads: int,
+    image_dir: str,
+    viz_dir: str,
+):
+    r"""
+    Apply CodeFormulaPredictor on the input image directory
+
+    If you want to load from PDF:
+    pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
+    """
+    # Create the code formula predictor
+    code_formula_predictor = CodeFormulaPredictor(artifact_path, device=device, num_threads=num_threads)
+
+    image_dir = Path(image_dir)
+    images = []
+    image_names = os.listdir(image_dir)
+    image_names.sort()
+    for image_name in image_names:
+        image = Image.open(image_dir / image_name)
+        images.append(image)
+
+    t0 = time.perf_counter()
+    outputs = code_formula_predictor.predict(images, ['code', 'formula'], temperature=0)
+    total_ms = 1000 * (time.perf_counter() - t0)
+    avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0
+    logger.info(
+        "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
+            len(image_names), total_ms, avg_ms
+        )
+    )
+
+    for i, output in enumerate(outputs):
+        logger.info(f"\nOutput {i}:\n{output}\n\n")
+
+
+def main(args):
+    num_threads = int(args.num_threads) if args.num_threads is not None else None
+    device = args.device.lower()
+    image_dir = args.image_dir
+    viz_dir = args.viz_dir
+
+    # Initialize logger
+    logging.basicConfig(level=logging.DEBUG)
+    logger = logging.getLogger("CodeFormulaPredictor")
+    logger.setLevel(logging.DEBUG)
+    if not logger.hasHandlers():
+        handler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
+        )
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    # Ensure the viz dir exists
+    Path(viz_dir).mkdir(parents=True, exist_ok=True)
+
+    # Download models from HF
+    download_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0")
+
+    # Test the Code+Equation model
+    demo(logger, download_path, device, num_threads, image_dir, viz_dir)
+
+
+if __name__ == "__main__":
+    r"""
+    python -m demo.demo_code_formula_predictor -i
+    """
+    parser = argparse.ArgumentParser(description="Test the CodeFormulaPredictor")
+    parser.add_argument(
+        "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
+    )
+    parser.add_argument(
+        "-n", "--num_threads", required=False, default=4, help="Number of threads"
+    )
+    parser.add_argument(
+        "-i",
+        "--image_dir",
+        required=True,
+        help="PNG images input directory",
+    )
+    parser.add_argument(
+        "-v",
+        "--viz_dir",
+        required=False,
+        default="viz/",
+        help="Directory to save prediction visualizations",
+    )
+
+    args = parser.parse_args()
+    main(args)
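Note (illustrative, not part of the patch): the demo above iterates over a whole directory; for quick reference, a minimal single-image sketch of the same API follows. The image path is a placeholder; the constructor and predict() signatures are the ones introduced in this PR, and temperature=0 selects greedy decoding.

# Minimal usage sketch (placeholder paths, not part of the patch).
from huggingface_hub import snapshot_download
from PIL import Image

from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor

artifacts_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0")
predictor = CodeFormulaPredictor(artifacts_path, device="cpu", num_threads=4)

image = Image.open("formula.png")  # placeholder path; numpy arrays are accepted too
outputs = predictor.predict([image], ["formula"], temperature=0)  # temperature=0 -> greedy
print(outputs[0])  # predicted LaTeX string (or source code for label "code")

diff --git a/docling_ibm_models/code_formula_model/code_formula_predictor.py b/docling_ibm_models/code_formula_model/code_formula_predictor.py
new file mode 100644
index 0000000..52849a7
--- /dev/null
+++ b/docling_ibm_models/code_formula_model/code_formula_predictor.py
@@ -0,0 +1,223 @@
+#
+# Copyright IBM Corp.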
2024 - 2024 +# SPDX-License-Identifier: MIT +# +import logging +from typing import List, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer + +from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM +from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import ( + SamOptImageProcessor, +) + +_log = logging.getLogger(__name__) + + +class CodeFormulaPredictor: + """ + Code and Formula Predictor using a multi-modal vision-language model. + + This class enables the prediction of code or LaTeX representations + from input images of code snippets or mathematical formulas. + + Attributes + ---------- + _device : str + The device on which the model is loaded (e.g., 'cpu' or 'cuda'). + _num_threads : int + Number of threads used for inference when running on CPU. + _tokenizer : transformers.PreTrainedTokenizer + Tokenizer for processing textual inputs to the model. + _model : transformers.PreTrainedModel + Pretrained multi-modal vision-language model. + _image_processor : transformers.ImageProcessor + Processor for normalizing and preparing input images. + _temperature : float + Sampling temperature for generation; controls randomness in predictions. + """ + + def __init__( + self, + artifacts_path: str, + device: str = "cpu", + num_threads: int = 4, + ): + """ + Initializes the CodeFormulaPredictor with the specified model artifacts. + + Parameters + ---------- + artifacts_path : str + Path to the directory containing the pretrained model files. + device : str, optional + Device to run the inference on ('cpu' or 'cuda'), by default "cpu". + num_threads : int, optional + Number of threads for CPU inference, by default 4. + """ + self._device = device + self._num_threads = num_threads + if device == "cpu": + torch.set_num_threads(self._num_threads) + + self._tokenizer = AutoTokenizer.from_pretrained( + artifacts_path, use_fast=True, padding_side="left" + ) + self._model = SamOPTForCausalLM.from_pretrained(artifacts_path).to(self._device) + self._model.eval() + + self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path) + + _log.debug("CodeFormulaModel settings: {}".format(self.info())) + + def info(self) -> dict: + """ + Retrieves configuration details of the CodeFormulaPredictor instance. + + Returns + ------- + dict + A dictionary containing configuration details such as the device and + the number of threads used. + """ + info = { + "device": self._device, + "num_threads": self._num_threads, + } + return info + + def _get_prompt(self, label: str) -> str: + """ + Constructs the prompt for the model based on the input label. + + Parameters + ---------- + label : str + The type of input, either 'code' or 'formula'. + + Returns + ------- + str + The constructed prompt including necessary tokens and query. + + Raises + ------ + NotImplementedError + If the label is not 'code' or 'formula'. + """ + if label == "code": + query = "" + elif label == "formula": + query = "" + else: + raise NotImplementedError("Label must be either code or formula") + + prompt = ( + "A chat between a curious user and an artificial intelligence" + " assistant. The assistant gives helpful, detailed, and polite answers to" + " the user's questions. 
USER:"
+        )
+        prompt += (
+            "" + "" * 256 + "" + "\n" + " ASSISTANT:" + "\n" + query
+        )
+
+        return prompt
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        images: List[Union[Image.Image, np.ndarray]],
+        labels: List[str],
+        temperature: float = 0.1,
+    ) -> List[str]:
+        """
+        Predicts the textual representation of input images (code or LaTeX).
+
+        Parameters
+        ----------
+        images : List[Union[Image.Image, np.ndarray]]
+            List of images to be processed, provided as PIL Image objects or numpy arrays.
+        labels : List[str]
+            List of labels indicating the type of each image ('code' or 'formula').
+        temperature : float, optional
+            Sampling temperature for generation, by default set to 0.1.
+
+        Returns
+        -------
+        List[str]
+            List of predicted textual outputs for each input image in the given input
+            order.
+
+        Raises
+        ------
+        TypeError
+            If any of the input images is not of a supported type (PIL Image or numpy array).
+        Exception
+            In case the temperature is an invalid number.
+        """
+        if (type(temperature) != float and type(temperature) != int) or temperature < 0:
+            raise Exception("Temperature must be a number greater than or equal to 0.")
+
+        do_sample = True
+        if temperature == 0:
+            do_sample = False
+            temperature = None
+
+        if len(labels) != len(images):
+            raise Exception(
+                "The number of images must be the same as the number of labels."
+            )
+
+        images_tmp = []
+        for image in images:
+            if isinstance(image, Image.Image):
+                image = image.convert("RGB")
+            elif isinstance(image, np.ndarray):
+                image = Image.fromarray(image).convert("RGB")
+            else:
+                raise TypeError("Unsupported input image format")
+            images_tmp.append(image)
+        images = images_tmp
+
+        images_tensor = torch.stack([self._image_processor(img) for img in images]).to(
+            self._device
+        )
+
+        prompts = [self._get_prompt(label) for label in labels]
+
+        tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt")
+        tokenized = {k: v.to(self._device) for k, v in tokenized.items()}
+
+        prompt_ids = tokenized["input_ids"]
+        attention_mask = tokenized["attention_mask"]
+
+        if self._device == "cpu":
+            output_ids_list = self._model.generate(
+                input_ids=prompt_ids,
+                attention_mask=attention_mask,
+                images=images_tensor,
+                do_sample=do_sample,
+                temperature=temperature,
+                max_new_tokens=4096 - prompt_ids.shape[1],
+                use_cache=True,
+            )
+        else:
+            with torch.autocast(device_type=self._device, dtype=torch.bfloat16):
+                output_ids_list = self._model.generate(
+                    prompt_ids,
+                    images=images_tensor,
+                    do_sample=do_sample,
+                    temperature=temperature,
+                    max_new_tokens=4096 - prompt_ids.shape[1],
+                    use_cache=True,
+                )
+
+        outputs = self._tokenizer.batch_decode(
+            output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True
+        )
+
+        return outputs
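Two notes on the predictor above (reviewer commentary, not part of the patch). First, predict() resizes every crop to 1024x1024, stacks the whole batch into a single image tensor, and left-pads all prompts together, so callers with many crops may want to split them into smaller chunks. Second, the prompt reserves 256 image-placeholder positions per input; these are filled by the SAM ViT-B encoder added in the next file, whose output for a 1024x1024 image is a (B, 1024, 16, 16) feature map, i.e. exactly 256 patch embeddings of width 1024, which sam_opt.py then projects to the language-model hidden size. A minimal, hypothetical shape check (randomly initialized weights, CPU):

# Illustrative only: verify the encoder geometry that the 256 prompt placeholders rely on.
import torch

from docling_ibm_models.code_formula_model.models.sam import build_sam_vit_b

encoder = build_sam_vit_b(image_size=1024)  # no checkpoint -> random weights, shapes only
with torch.inference_mode():
    feats = encoder(torch.zeros(1, 3, 1024, 1024))

print(feats.shape)                              # torch.Size([1, 1024, 16, 16])
print(feats.flatten(2).permute(0, 2, 1).shape)  # torch.Size([1, 256, 1024]) -> 256 image tokens

diff --git a/docling_ibm_models/code_formula_model/models/sam.py b/docling_ibm_models/code_formula_model/models/sam.py
new file mode 100644
index 0000000..98179cf
--- /dev/null
+++ b/docling_ibm_models/code_formula_model/models/sam.py
@@ -0,0 +1,514 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This file was originally developed by Meta Platforms, Inc. as part of
+# the Segment Anything project (https://github.com/facebookresearch/segment-anything).
+# It has been adapted by contributors from the Vary-toy project
+# (https://github.com/Ucas-HaoranWei/Vary-toy).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.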
+# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. 
+ """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x = self.net_2(x) + x = self.net_3(x) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +def build_sam_vit_b(checkpoint=None, image_size=1024): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + image_size=image_size, + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + image_size=1024, +): + prompt_embed_dim = 256 + vit_patch_size = 16 + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + + if checkpoint is not None: + # with open(checkpoint, "rb") as f: + state_dict = torch.load(checkpoint) + + image_encoder.load_state_dict(state_dict, strict=True) + return image_encoder diff --git a/docling_ibm_models/code_formula_model/models/sam_opt.py b/docling_ibm_models/code_formula_model/models/sam_opt.py new file mode 100644 index 0000000..c682f45 --- /dev/null +++ b/docling_ibm_models/code_formula_model/models/sam_opt.py @@ -0,0 +1,237 @@ +# Copyright 2023 Haotian Liu +# +# This file is part of the Vary project, originally located at: +# https://github.com/Ucas-HaoranWei/Vary-toy/blob/main/Vary-master/vary/model/vary_opt.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + OPTConfig, + OPTForCausalLM, + OPTModel, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) + +from docling_ibm_models.code_formula_model.models.sam import build_sam_vit_b + + +class SamOptConfig(OPTConfig): + model_type = "sam_opt" + + def __init__( + self, + sam_image_size=1024, + sam_mm_projector_in=1024, + sam_mm_projector_out=768, + **kwargs, + ): + super().__init__(**kwargs) + self.sam_image_size = sam_image_size + self.sam_mm_projector_in = sam_mm_projector_in + self.sam_mm_projector_out = sam_mm_projector_out + + +class SamOPTModel(OPTModel): + config_class = SamOptConfig + + def __init__(self, config: OPTConfig): + super(SamOPTModel, self).__init__(config) + self.vision_tower = build_sam_vit_b(image_size=config.sam_image_size) + + self.mm_projector = nn.Linear( + config.sam_mm_projector_in, config.sam_mm_projector_out + ) + + def embed_tokens(self, x): + return self.get_input_embeddings()(x) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: torch.FloatTensor = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower = getattr(self, "vision_tower", None) + im_start_token = getattr(self.config, "im_start_token", -1) + + if input_ids.shape[1] != 1 or self.training: + with torch.set_grad_enabled(self.training): + image_features = vision_tower(images) + image_features = image_features.flatten(2).permute(0, 2, 1) + image_features = self.mm_projector(image_features) + + new_input_embeds = [] + for cur_input_ids, cur_input_embeds, cur_image_features in zip( + input_ids, inputs_embeds, image_features + ): + image_start_token_position = torch.where( + cur_input_ids == im_start_token + )[0].item() + + cur_image_features = cur_image_features.to( + device=cur_input_embeds.device + ) + num_patches = cur_image_features.shape[0] + cur_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_position + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_position + num_patches + 1 : + ], + ), + dim=0, + ) + + new_input_embeds.append(cur_input_embeds) + + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(SamOPTModel, self).forward( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SamOPTForCausalLM(OPTForCausalLM): + config_class = SamOptConfig + + def __init__(self, config): + super(OPTForCausalLM, self).__init__(config) + self.model = SamOPTModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + 
token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + images=images, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states).contiguous() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + token_type_ids = kwargs.get("token_type_ids", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "images": kwargs.get("images", None), + } + ) + return model_inputs + + +AutoConfig.register("sam_opt", SamOptConfig) +AutoModelForCausalLM.register(SamOptConfig, SamOPTForCausalLM) diff --git a/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py b/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py new file mode 100644 index 0000000..75459aa --- /dev/null +++ b/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py @@ -0,0 +1,31 @@ +# +# Copyright IBM Corp. 
2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+from PIL import Image
+from torchvision.transforms import functional as F
+from transformers import AutoImageProcessor
+from transformers.image_processing_utils import ImageProcessingMixin
+
+
+class SamOptImageProcessor(ImageProcessingMixin):
+
+    def __init__(self, size=(1024, 1024), mean=None, std=None, **kwargs):
+        super().__init__(**kwargs)
+        self.size = size
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, image):
+        if not isinstance(image, Image.Image):
+            raise ValueError("Input must be a PIL Image")
+
+        image = F.resize(image, self.size)
+        image = F.to_tensor(image)
+
+        image = F.normalize(image, mean=self.mean, std=self.std)
+
+        return image
+
+
+AutoImageProcessor.register(SamOptImageProcessor, SamOptImageProcessor)
diff --git a/tests/test_code_formula_predictor.py b/tests/test_code_formula_predictor.py
new file mode 100644
index 0000000..52614f4
--- /dev/null
+++ b/tests/test_code_formula_predictor.py
@@ -0,0 +1,131 @@
+#
+# Copyright IBM Corp. 2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+import os
+import numpy as np
+import pytest
+from PIL import Image
+
+from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor
+
+from huggingface_hub import snapshot_download
+
+@pytest.fixture(scope="module")
+def init() -> dict:
+    r"""
+    Initialize the testing environment
+    """
+    init = {
+        "num_threads": 1,
+        "test_imgs": [
+            {
+                "label": "code",
+                "image_path": "tests/test_data/code_formula/images/code.png",
+                "gt_path": "tests/test_data/code_formula/gt/code.txt",
+            },
+            {
+                "label": "formula",
+                "image_path": "tests/test_data/code_formula/images/formula.png",
+                "gt_path": "tests/test_data/code_formula/gt/formula.txt",
+            },
+        ],
+        "info": {
+            "device": "auto",
+            "temperature": 0,
+        },
+    }
+
+    # Download models from HF
+    artifact_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0")
+
+    init["artifact_path"] = artifact_path
+
+    return init
+
+
+def test_code_formula_predictor(init: dict):
+    r"""
+    Unit test for the CodeFormulaPredictor
+    """
+    device = "cpu"
+    num_threads = 2
+
+    # Initialize CodeFormulaPredictor
+    code_formula_predictor = CodeFormulaPredictor(
+        init["artifact_path"], device=device, num_threads=num_threads
+    )
+
+    # Check info
+    info = code_formula_predictor.info()
+    assert info["device"] == device, "Wrongly set device"
+    assert info["num_threads"] == num_threads, "Wrongly set number of threads"
+
+    # Unsupported input image
+    is_exception = False
+    try:
+        for _ in code_formula_predictor.predict(["wrong"], ['label']):
+            pass
+    except TypeError:
+        is_exception = True
+    assert is_exception
+
+    # wrong type for temperature
+    is_exception = False
+    try:
+        dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255))
+        for _ in code_formula_predictor.predict([dummy_image], ['label'], "0.1"):
+            pass
+    except Exception:
+        is_exception = True
+    assert is_exception
+
+    # wrong value for temperature
+    is_exception = False
+    try:
+        dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255))
+        for _ in code_formula_predictor.predict([dummy_image], ['label'], -0.1):
+            pass
+    except Exception:
+        is_exception = True
+    assert is_exception
+
+    # mismatched number of images and labels
+    is_exception = False
+    try:
+        dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255))
+        for _ in code_formula_predictor.predict([dummy_image], ['label', 'label']):
+            pass
+    except Exception:
+        is_exception = True
+    assert is_exception
+
+    # Predict
on test images, not batched + temperature = init['info']['temperature'] + for d in init["test_imgs"]: + label = d['label'] + img_path = d['image_path'] + gt_path = d['gt_path'] + + with Image.open(img_path) as img, open(gt_path, 'r') as gt_fp: + gt = gt_fp.read() + + output = code_formula_predictor.predict([img], [label], temperature) + output = output[0] + + assert output == gt + + # Load images as numpy arrays + np_arr = np.asarray(img) + output = code_formula_predictor.predict([np_arr], [label], temperature) + output = output[0] + + assert output == gt + + # Predict on test images, batched + labels = [d['label'] for d in init["test_imgs"]] + images = [Image.open(d['image_path']) for d in init["test_imgs"]] + gts = [open(d['gt_path'], 'r').read() for d in init["test_imgs"]] + + outputs = code_formula_predictor.predict(images, labels, temperature) + assert outputs == gts diff --git a/tests/test_data/code_formula/gt/code.txt b/tests/test_data/code_formula/gt/code.txt new file mode 100644 index 0000000..9de7750 --- /dev/null +++ b/tests/test_data/code_formula/gt/code.txt @@ -0,0 +1,27 @@ +<_C++_> #include +using namespace std; + +int main(){ + + int n; + + while(cin>>n, n){ + int cnt=0; + + n=1000-n; + cnt+=n/500; + n%=500; + cnt+=n/100; + n%=100; + cnt+=n/50; + n%=50; + cnt+=n/10; + n%=10; + cnt+=n/5; + n%=5; + + cout<