Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enable MyPy in pre-commit and refactor the code to fix all errors #74

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ repos:
entry: poetry lock --check
pass_filenames: false
language: system
# - id: system
# name: MyPy
# entry: poetry run mypy docling_ibm_models
# pass_filenames: false
# language: system
# files: '\.py$'
- id: system
name: MyPy
entry: poetry run mypy docling_ibm_models
pass_filenames: false
language: system
files: '\.py$'
Empty file added docling_ibm_models/__init__.py
Empty file.
Empty file.
19 changes: 11 additions & 8 deletions docling_ibm_models/code_formula_model/code_formula_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT
#
import logging
from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -132,7 +132,7 @@ def predict(
self,
images: List[Union[Image.Image, np.ndarray]],
labels: List[str],
temperature: float = 0.1,
temperature: Optional[float] = 0.1,
) -> List[str]:
"""
Predicts the textual representation of input images (code or LaTeX).
Expand All @@ -143,7 +143,7 @@ def predict(
List of images to be processed, provided as PIL Image objects or numpy arrays.
labels : List[str]
List of labels indicating the type of each image ('code' or 'formula').
temperature : float, optional
temperature : Optional[float]
Sampling temperature for generation, by default set to 0.1.

Returns
Expand All @@ -159,7 +159,11 @@ def predict(
Excpetion
In case the temperature is an invalid number.
"""
if (type(temperature) != float and type(temperature) != int) or temperature < 0:
if (
temperature is None
or not (isinstance(temperature, float) or isinstance(temperature, int))
or temperature < 0
):
raise Exception("Temperature must be a number greater or equal to 0.")

do_sample = True
Expand All @@ -181,11 +185,10 @@ def predict(
else:
raise TypeError("Not supported input image format")
images_tmp.append(image)
images = images_tmp

images_tensor = torch.stack([self._image_processor(img) for img in images]).to(
self._device
)
images_tensor = torch.stack(
[self._image_processor(img) for img in images_tmp]
).to(self._device)

prompts = [self._get_prompt(label) for label in labels]

Expand Down
Empty file.
13 changes: 7 additions & 6 deletions docling_ibm_models/code_formula_model/models/sam_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def embed_tokens(self, x):

def forward(
self,
input_ids: torch.LongTensor = None,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: torch.FloatTensor = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:

Expand All @@ -86,6 +86,7 @@ def forward(

if input_ids.shape[1] != 1 or self.training:
with torch.set_grad_enabled(self.training):
assert vision_tower is not None
image_features = vision_tower(images)
image_features = image_features.flatten(2).permute(0, 2, 1)
image_features = self.mm_projector(image_features)
Expand All @@ -94,9 +95,9 @@ def forward(
for cur_input_ids, cur_input_embeds, cur_image_features in zip(
input_ids, inputs_embeds, image_features
):
image_start_token_position = torch.where(
cur_input_ids == im_start_token
)[0].item()
image_start_token_position = int(
torch.where(cur_input_ids == im_start_token)[0].item()
) # cast to int for mypy

cur_image_features = cur_image_features.to(
device=cur_input_embeds.device
Expand All @@ -115,7 +116,7 @@ def forward(

new_input_embeds.append(cur_input_embeds)

inputs_embeds = torch.stack(new_input_embeds, dim=0)
inputs_embeds = torch.stack(new_input_embeds, dim=0) # type: ignore

return super(SamOPTModel, self).forward(
input_ids=None,
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,23 @@ def predict(

The predictions for each image are sorted in descending order of confidence.
"""
processed_images = []
rgb_images = []
for image in images:
if isinstance(image, Image.Image):
processed_images.append(image.convert("RGB"))
rgb_images.append(image.convert("RGB"))
elif isinstance(image, np.ndarray):
processed_images.append(Image.fromarray(image).convert("RGB"))
rgb_images.append(Image.fromarray(image).convert("RGB"))
else:
raise TypeError(
"Supported input formats are PIL.Image.Image or numpy.ndarray."
)
images = processed_images

# (batch_size, 3, 224, 224)
images = [self._image_processor(image) for image in images]
images = torch.stack(images).to(self._device)
processed_images = [self._image_processor(image) for image in rgb_images]
torch_images = torch.stack(processed_images).to(self._device)

with torch.no_grad():
logits = self._model(images).logits # (batch_size, num_classes)
logits = self._model(torch_images).logits # (batch_size, num_classes)
probs_batch = logits.softmax(dim=1) # (batch_size, num_classes)
probs_batch = probs_batch.cpu().numpy().tolist()

Expand Down
Empty file.
Empty file added docling_ibm_models/py.typed
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward(self, x):


class TMTransformerDecoder(nn.TransformerDecoder):
def forward(
def forward( # type: ignore
self,
tgt: Tensor,
memory: Optional[Tensor] = None,
Expand Down Expand Up @@ -69,11 +69,11 @@ def forward(
else:
out_cache = torch.stack(tag_cache, dim=0)

return output, out_cache
return output, out_cache # type: ignore


class TMTransformerDecoderLayer(nn.TransformerDecoderLayer):
def forward(
def forward( # type: ignore
self,
tgt: Tensor,
memory: Optional[Tensor] = None,
Expand Down
2 changes: 1 addition & 1 deletion docling_ibm_models/tableformer/otsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
png_files = {} # Evaluation files
# png_files = {} # Evaluation files
total_pics = 0


Expand Down
5 changes: 3 additions & 2 deletions docling_ibm_models/tableformer/utils/mem_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import platform
import re
from typing import Dict, Union


class MemMonitor:
Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(self, enable=True):
regex_str = r"({}:)(\s+)(\d*)(.*)".format(mem_field)
self._status_regex[mem_field] = re.compile(regex_str)

def get_memory_full(self) -> dict:
def get_memory_full(self) -> Union[Dict, int]:
r"""
- Parse /proc/<pid>status to get all memory info.
- The method returns a dict with the fields self._status_fields
Expand Down Expand Up @@ -140,7 +141,7 @@ def get_memory_full(self) -> dict:

return memory

def get_memory(self) -> dict:
def get_memory(self) -> Union[Dict, int]:
r"""
- Parse /proc/<pid>statm to get the most important memory fields
- This is a fast implementation.
Expand Down
Loading
Loading