-
Notifications
You must be signed in to change notification settings - Fork 37
fix: Enable MyPy in pre-commit and refactor the code to fix all errors #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a5e6084
71ff4bc
14e71a7
43ead93
24baf54
b45c176
fe7ddf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
# SPDX-License-Identifier: MIT | ||
# | ||
import logging | ||
from typing import List, Union | ||
from typing import List, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
@@ -132,7 +132,7 @@ def predict( | |
self, | ||
images: List[Union[Image.Image, np.ndarray]], | ||
labels: List[str], | ||
temperature: float = 0.1, | ||
temperature: Optional[float] = 0.1, | ||
) -> List[str]: | ||
""" | ||
Predicts the textual representation of input images (code or LaTeX). | ||
|
@@ -143,7 +143,7 @@ def predict( | |
List of images to be processed, provided as PIL Image objects or numpy arrays. | ||
labels : List[str] | ||
List of labels indicating the type of each image ('code' or 'formula'). | ||
temperature : float, optional | ||
temperature : Optional[float] | ||
Sampling temperature for generation, by default set to 0.1. | ||
|
||
Returns | ||
|
@@ -159,7 +159,11 @@ def predict( | |
Excpetion | ||
In case the temperature is an invalid number. | ||
""" | ||
if (type(temperature) != float and type(temperature) != int) or temperature < 0: | ||
if ( | ||
temperature is None | ||
or not (isinstance(temperature, float) or isinstance(temperature, int)) | ||
or temperature < 0 | ||
): | ||
raise Exception("Temperature must be a number greater or equal to 0.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. { // Predicate: |
||
|
||
do_sample = True | ||
|
@@ -181,11 +185,10 @@ def predict( | |
else: | ||
raise TypeError("Not supported input image format") | ||
images_tmp.append(image) | ||
images = images_tmp | ||
|
||
images_tensor = torch.stack([self._image_processor(img) for img in images]).to( | ||
self._device | ||
) | ||
images_tensor = torch.stack( | ||
[self._image_processor(img) for img in images_tmp] | ||
).to(self._device) | ||
|
||
prompts = [self._get_prompt(label) for label in labels] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
{
// Standard attestation fields:
"_type": "https://in-toto.io/Statement/v0.1",
"subject": [{ ... }],
// Predicate:
"predicateType": "https://cyclonedx.org/bom/v1.4",
"predicate": {
"bomFormat": "CycloneDX",
"specVersion": "1.4",
"serialNumber": "urn:uuid:3e671687-395b-41f5-a30f-a58921a69b79",
"version": 1,
"components": [
{
"type": "library",
"name": "acme-library",
"version": "1.0.0"
}
]
...
}
}