Skip to content

Commit 49c091b

Browse files
committed
Remove cuda necessity for SetFit
1 parent b6d3348 commit 49c091b

File tree

4 files changed

+18
-31
lines changed

4 files changed

+18
-31
lines changed

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/MultiLabelMethod.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,19 @@ def load_json(self, file_name: str):
5555
def remove_model(self):
5656
shutil.rmtree(join(self.get_path()), ignore_errors=True)
5757

58-
def get_texts_labels(self, multi_option_data: ExtractionData) -> (list[str], list[list[int]]):
58+
def get_texts_labels(self, extraction_data: ExtractionData) -> (list[str], list[list[int]]):
5959
texts = list()
60-
for sample in multi_option_data.samples:
60+
for sample in extraction_data.samples:
6161
texts.append(" ".join([x.text_content.strip() for x in sample.pdf_data.pdf_data_segments]))
6262

63-
labels = self.get_one_hot_encoding(multi_option_data)
63+
labels = self.get_one_hot_encoding(extraction_data)
6464
return texts, labels
6565

66-
def predictions_to_options_list(self, predictions) -> list[list[Value]]:
67-
return [self.one_prediction_to_option_list(prediction) for prediction in predictions]
68-
69-
def one_prediction_to_option_list(self, prediction) -> list[Value]:
70-
if not self.multi_value:
71-
best_score_index = argmax(prediction)
72-
return [self.options[best_score_index]] if prediction[best_score_index] > 0.5 else []
73-
74-
return [Value.from_option(self.options[i]) for i, value in enumerate(prediction) if value > 0.5]
75-
76-
def get_one_hot_encoding(self, multi_option_data: ExtractionData):
77-
options_ids = [option.id for option in self.options]
66+
@staticmethod
67+
def get_one_hot_encoding(extraction_data: ExtractionData):
68+
options_ids = [option.id for option in extraction_data.options]
7869
one_hot_encoding = list()
79-
for sample in multi_option_data.samples:
70+
for sample in extraction_data.samples:
8071
one_hot_encoding.append([0] * len(options_ids))
8172
for option in sample.labeled_data.values:
8273
if option.id not in options_ids:

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/multi_labels_methods/SetFitEnglishMethod.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def train(self, extraction_data: ExtractionData):
7575

7676
model = SetFitModel.from_pretrained(
7777
self.model_name,
78-
labels=[x.label for x in self.options],
78+
labels=[x.label for x in extraction_data.options],
7979
multi_target_strategy="one-vs-rest",
8080
trust_remote_code=True,
8181
)
@@ -118,6 +118,8 @@ def predict(self, prediction_samples_data: PredictionSamplesData) -> list[list[V
118118
if prediction_samples_data.multi_value:
119119
predictions_proba = model.predict_proba(texts)
120120
threshold = 0.5
121+
if hasattr(predictions_proba, "cpu"):
122+
predictions_proba = predictions_proba.cpu().numpy()
121123
predictions = (predictions_proba > threshold).astype(int)
122124

123125
predictions_values = list()
@@ -137,9 +139,6 @@ def predict(self, prediction_samples_data: PredictionSamplesData) -> list[list[V
137139
return predictions_values
138140

139141
def can_be_used(self, extraction_data: ExtractionData) -> bool:
140-
if not torch.cuda.is_available():
141-
return False
142-
143142
if not extraction_data.multi_value:
144143
return False
145144

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/multi_labels_methods/SingleLabelSetFitEnglishMethod.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ def gpu_needed(self) -> bool:
2828
return True
2929

3030
def can_be_used(self, extraction_data: ExtractionData) -> bool:
31-
if not torch.cuda.is_available():
32-
return False
33-
3431
if extraction_data.multi_value:
3532
return False
3633

src/trainable_entity_extractor/tests/unit_tests/extractors/pdf_to_multi_option_extractor/multi_labels_methods/test_setfit_multilingual.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from trainable_entity_extractor.domain.LabeledData import LabeledData
1212
from trainable_entity_extractor.domain.Option import Option
1313
from trainable_entity_extractor.domain.PdfData import PdfData
14+
from trainable_entity_extractor.domain.PredictionSample import PredictionSample
15+
from trainable_entity_extractor.domain.PredictionSamplesData import PredictionSamplesData
1416
from trainable_entity_extractor.domain.TrainingSample import TrainingSample
1517
from trainable_entity_extractor.adapters.extractors.pdf_to_multi_option_extractor.multi_labels_methods.SetFitMultilingualMethod import (
1618
SetFitMultilingualMethod,
@@ -25,7 +27,7 @@ class TestSetFitMultilingualMethod(TestCase):
2527
def tearDown(self):
2628
shutil.rmtree(join(DATA_PATH, self.TENANT), ignore_errors=True)
2729

28-
@unittest.SkipTest
30+
@unittest.skip("Skipping GPU test in CI/CD")
2931
def test_train_and_predict(self):
3032
if not torch.cuda.is_available():
3133
return
@@ -55,21 +57,19 @@ def test_train_and_predict(self):
5557
extraction_data = ExtractionData(
5658
multi_value=True, options=options, samples=samples, extraction_identifier=extraction_identifier
5759
)
58-
setfit_multilingual_method = SetFitMultilingualMethod(extraction_identifier, options, True)
60+
setfit_multilingual_method = SetFitMultilingualMethod(extraction_identifier)
5961

6062
try:
6163
setfit_multilingual_method.train(extraction_data)
6264
except Exception as e:
6365
self.fail(f"train() raised {type(e).__name__}")
6466

65-
prediction_sample_1 = TrainingSample(pdf_data=pdf_data_1)
66-
prediction_sample_2 = TrainingSample(pdf_data=pdf_data_2)
67-
prediction_sample_4 = TrainingSample(pdf_data=pdf_data_4)
67+
prediction_sample_1 = PredictionSample(pdf_data=pdf_data_1)
68+
prediction_sample_2 = PredictionSample(pdf_data=pdf_data_2)
69+
prediction_sample_4 = PredictionSample(pdf_data=pdf_data_4)
6870
prediction_samples = [prediction_sample_1, prediction_sample_2, prediction_sample_4]
6971

70-
prediction_data = ExtractionData(
71-
multi_value=True, options=options, samples=prediction_samples, extraction_identifier=extraction_identifier
72-
)
72+
prediction_data = PredictionSamplesData(multi_value=True, options=options, prediction_samples=prediction_samples)
7373
predictions = setfit_multilingual_method.predict(prediction_data)
7474

7575
self.assertEqual(3, len(predictions))

0 commit comments

Comments
 (0)