Skip to content

Commit b58b8e3

Browse files
committed
Add metadata to trainable extraction jobs
1 parent bc13de0 commit b58b8e3

File tree

4 files changed

+7
-18
lines changed

4 files changed

+7
-18
lines changed

src/trainable_entity_extractor/domain/TrainableEntityExtractorJob.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TrainableEntityExtractorJob(BaseModel):
1313
gpu_needed: bool
1414
timeout: int
1515
output_path: str = ""
16+
metadata: dict[str, str] = dict()
1617

1718
def set_extractors_path(self, path: str) -> "TrainableEntityExtractorJob":
1819
self.output_path = path

src/trainable_entity_extractor/ports/ExtractorBase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def get_distributed_jobs(self, extraction_data: ExtractionData) -> list[Trainabl
110110
timeout=getattr(method_instance, "timeout", 3600),
111111
options=extraction_data.options if extraction_data.options else [],
112112
multi_value=extraction_data.multi_value if extraction_data.multi_value else False,
113+
metadata=(
114+
extraction_data.extraction_identifier.metadata if extraction_data.extraction_identifier.metadata else {}
115+
),
113116
)
114117
jobs.append(job)
115118

src/trainable_entity_extractor/ports/ModelStorage.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def serialize_job_to_dict(job: TrainableEntityExtractorJob) -> dict:
5858
"options": [option.model_dump() for option in job.options],
5959
"gpu_needed": job.gpu_needed,
6060
"timeout": job.timeout,
61-
"metadata": {},
61+
"metadata": job.metadata,
6262
}
6363

6464
@staticmethod
@@ -74,6 +74,7 @@ def deserialize_job_from_dict(job_data: dict) -> TrainableEntityExtractorJob:
7474
options = [Option(**option_data) for option_data in options_data]
7575
gpu_needed = job_data.get("gpu_needed", False)
7676
timeout = job_data.get("timeout", 3600)
77+
metadata = job_data.get("metadata", {})
7778

7879
if version != "1.0":
7980
pass
@@ -87,4 +88,5 @@ def deserialize_job_from_dict(job_data: dict) -> TrainableEntityExtractorJob:
8788
options=options,
8889
gpu_needed=gpu_needed,
8990
timeout=timeout,
91+
metadata=metadata,
9092
)

src/trainable_entity_extractor/tests/use_cases/test_get_distributed_tasks.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,23 +110,6 @@ def test_get_distributed_jobs_no_compatible_extractor(self):
110110

111111
self.assertEqual(len(jobs), 0)
112112

113-
def test_get_distributed_jobs_multiple_extractors_returns_last_compatible(self):
114-
text_extraction_data = self.create_text_extraction_data()
115-
multi_option_extraction_data = self.create_multi_option_extraction_data()
116-
117-
self.train_use_case = TrainUseCase(extractors=[TextToTextExtractor, TextToMultiOptionExtractor], logger=self.logger)
118-
119-
text_jobs = self.train_use_case.get_jobs(text_extraction_data)
120-
multi_option_jobs = self.train_use_case.get_jobs(multi_option_extraction_data)
121-
122-
self.assertGreater(len(text_jobs), 0)
123-
for job in text_jobs:
124-
self.assertEqual(job.extractor_name, "TextToTextExtractor")
125-
126-
self.assertGreater(len(multi_option_jobs), 0)
127-
for job in multi_option_jobs:
128-
self.assertEqual(job.extractor_name, "TextToMultiOptionExtractor")
129-
130113
def test_get_distributed_jobs_with_incompatible_data_for_text_extractor(self):
131114
extraction_data = ExtractionData(
132115
samples=[],

0 commit comments

Comments
 (0)