Skip to content

Commit 5dc0fd3

Browse files
committed
Fix single SetFit
1 parent 49c091b commit 5dc0fd3

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/multi_labels_methods/SingleLabelSetFitEnglishMethod.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_dataset_from_data(self, extraction_data: ExtractionData):
6969
for sample in extraction_data.samples:
7070
labels.append("no_label")
7171
if sample.labeled_data.values:
72-
options = [option for option in self.options if option.id == sample.labeled_data.values[0].id]
72+
options = [option for option in extraction_data.options if option.id == sample.labeled_data.values[0].id]
7373
if options:
7474
labels[-1] = options[0].label
7575

@@ -90,7 +90,9 @@ def train(self, extraction_data: ExtractionData):
9090
train_dataset = self.get_dataset_from_data(extraction_data)
9191
batch_size = get_batch_size(len(extraction_data.samples))
9292

93-
model = SetFitModel.from_pretrained(self.model_name, labels=[x.label for x in self.options], trust_remote_code=True)
93+
model = SetFitModel.from_pretrained(
94+
self.model_name, labels=[x.label for x in extraction_data.options], trust_remote_code=True
95+
)
9496

9597
args = TrainingArguments(
9698
output_dir=self.get_model_path(),

0 commit comments

Comments
 (0)