File tree: 4 files changed, +7 -18 lines changed
Original file line number Diff line number Diff line change @@ -13,6 +13,7 @@ class TrainableEntityExtractorJob(BaseModel):
1313 gpu_needed : bool
1414 timeout : int
1515 output_path : str = ""
16+ metadata : dict [str , str ] = dict ()
1617
1718 def set_extractors_path (self , path : str ) -> "TrainableEntityExtractorJob" :
1819 self .output_path = path
Original file line number Diff line number Diff line change @@ -110,6 +110,9 @@ def get_distributed_jobs(self, extraction_data: ExtractionData) -> list[Trainabl
110110 timeout = getattr (method_instance , "timeout" , 3600 ),
111111 options = extraction_data .options if extraction_data .options else [],
112112 multi_value = extraction_data .multi_value if extraction_data .multi_value else False ,
113+ metadata = (
114+ extraction_data .extraction_identifier .metadata if extraction_data .extraction_identifier .metadata else {}
115+ ),
113116 )
114117 jobs .append (job )
115118
Original file line number Diff line number Diff line change @@ -58,7 +58,7 @@ def serialize_job_to_dict(job: TrainableEntityExtractorJob) -> dict:
5858 "options" : [option .model_dump () for option in job .options ],
5959 "gpu_needed" : job .gpu_needed ,
6060 "timeout" : job .timeout ,
61- "metadata" : {} ,
61+ "metadata" : job . metadata ,
6262 }
6363
6464 @staticmethod
@@ -74,6 +74,7 @@ def deserialize_job_from_dict(job_data: dict) -> TrainableEntityExtractorJob:
7474 options = [Option (** option_data ) for option_data in options_data ]
7575 gpu_needed = job_data .get ("gpu_needed" , False )
7676 timeout = job_data .get ("timeout" , 3600 )
77+ metadata = job_data .get ("metadata" , {})
7778
7879 if version != "1.0" :
7980 pass
@@ -87,4 +88,5 @@ def deserialize_job_from_dict(job_data: dict) -> TrainableEntityExtractorJob:
8788 options = options ,
8889 gpu_needed = gpu_needed ,
8990 timeout = timeout ,
91+ metadata = metadata ,
9092 )
Original file line number Diff line number Diff line change @@ -110,23 +110,6 @@ def test_get_distributed_jobs_no_compatible_extractor(self):
110110
111111 self .assertEqual (len (jobs ), 0 )
112112
113- def test_get_distributed_jobs_multiple_extractors_returns_last_compatible (self ):
114- text_extraction_data = self .create_text_extraction_data ()
115- multi_option_extraction_data = self .create_multi_option_extraction_data ()
116-
117- self .train_use_case = TrainUseCase (extractors = [TextToTextExtractor , TextToMultiOptionExtractor ], logger = self .logger )
118-
119- text_jobs = self .train_use_case .get_jobs (text_extraction_data )
120- multi_option_jobs = self .train_use_case .get_jobs (multi_option_extraction_data )
121-
122- self .assertGreater (len (text_jobs ), 0 )
123- for job in text_jobs :
124- self .assertEqual (job .extractor_name , "TextToTextExtractor" )
125-
126- self .assertGreater (len (multi_option_jobs ), 0 )
127- for job in multi_option_jobs :
128- self .assertEqual (job .extractor_name , "TextToMultiOptionExtractor" )
129-
130113 def test_get_distributed_jobs_with_incompatible_data_for_text_extractor (self ):
131114 extraction_data = ExtractionData (
132115 samples = [],
You can’t perform that action at this time.
0 commit comments