@@ -30,12 +30,12 @@ def process_job(self, distributed_job: DistributedJob) -> JobProcessingResult:
3030 error_message = f"Job cancelled for extraction { distributed_job .extraction_identifier } " ,
3131 )
3232
33- if distributed_job .type == JobType .TRAIN :
33+ if distributed_job .type == JobType .PERFORMANCE :
34+ return self ._process_performance_job (distributed_job )
35+ elif distributed_job .type == JobType .TRAIN :
3436 return self ._process_training_job (distributed_job )
3537 elif distributed_job .type == JobType .PREDICT :
3638 return self ._process_prediction_job (distributed_job )
37- elif distributed_job .type == JobType .PERFORMANCE :
38- return self ._process_performance_job (distributed_job )
3939 else :
4040 self .distributed_jobs .remove (distributed_job )
4141 return JobProcessingResult (
@@ -180,24 +180,6 @@ def _handle_performance_results(self, distributed_job: DistributedJob) -> JobPro
180180 finished = True , success = False , error_message = "No valid performance results to select the best model"
181181 )
182182
183- if not best_job .extractor_job .should_be_retrained_with_more_data :
184- return self ._finalize_best_model (distributed_job , best_job )
185- else :
186- return self ._schedule_retraining (distributed_job , best_job )
187-
188- def _finalize_best_model (self , distributed_job : DistributedJob , best_job : DistributedSubJob ) -> JobProcessingResult :
189- if self .job_executor .upload_model (distributed_job .extraction_identifier , best_job .extractor_job ):
190- performance_score = self ._extract_performance_score (best_job )
191- return JobProcessingResult (
192- finished = True ,
193- success = True ,
194- error_message = f"Best model selected: { best_job .extractor_job .method_name } with performance { performance_score } " ,
195- gpu_needed = getattr (best_job .extractor_job , "requires_gpu" , False ),
196- )
197- else :
198- return JobProcessingResult (finished = True , success = False , error_message = "Best model selected but upload failed" )
199-
200- def _schedule_retraining (self , distributed_job : DistributedJob , best_job : DistributedSubJob ) -> JobProcessingResult :
201183 training_job = DistributedJob (
202184 extraction_identifier = distributed_job .extraction_identifier ,
203185 type = JobType .TRAIN ,
@@ -212,6 +194,18 @@ def _schedule_retraining(self, distributed_job: DistributedJob, best_job: Distri
212194 gpu_needed = getattr (best_job .extractor_job , "requires_gpu" , False ),
213195 )
214196
197+ def _finalize_best_model (self , distributed_job : DistributedJob , best_job : DistributedSubJob ) -> JobProcessingResult :
198+ if self .job_executor .upload_model (distributed_job .extraction_identifier , best_job .extractor_job ):
199+ performance_score = self ._extract_performance_score (best_job )
200+ return JobProcessingResult (
201+ finished = True ,
202+ success = True ,
203+ error_message = f"Best model selected: { best_job .extractor_job .method_name } with performance { performance_score } " ,
204+ gpu_needed = getattr (best_job .extractor_job , "requires_gpu" , False ),
205+ )
206+ else :
207+ return JobProcessingResult (finished = True , success = False , error_message = "Best model selected but upload failed" )
208+
215209 @staticmethod
216210 def _extract_performance_score (best_job : DistributedSubJob ) -> str :
217211 if best_job .result and hasattr (best_job .result , "performance_score" ):
0 commit comments