Commit f45e503 (1 parent: bf118a2)

Retrain always

File tree: 1 file changed, +15 -21 lines

src/trainable_entity_extractor/use_cases/OrchestratorUseCase.py

Lines changed: 15 additions & 21 deletions
@@ -30,12 +30,12 @@ def process_job(self, distributed_job: DistributedJob) -> JobProcessingResult:
                 error_message=f"Job cancelled for extraction {distributed_job.extraction_identifier}",
             )

-        if distributed_job.type == JobType.TRAIN:
+        if distributed_job.type == JobType.PERFORMANCE:
+            return self._process_performance_job(distributed_job)
+        elif distributed_job.type == JobType.TRAIN:
             return self._process_training_job(distributed_job)
         elif distributed_job.type == JobType.PREDICT:
             return self._process_prediction_job(distributed_job)
-        elif distributed_job.type == JobType.PERFORMANCE:
-            return self._process_performance_job(distributed_job)
         else:
             self.distributed_jobs.remove(distributed_job)
             return JobProcessingResult(
@@ -180,24 +180,6 @@ def _handle_performance_results(self, distributed_job: DistributedJob) -> JobProcessingResult:
                 finished=True, success=False, error_message="No valid performance results to select the best model"
             )

-        if not best_job.extractor_job.should_be_retrained_with_more_data:
-            return self._finalize_best_model(distributed_job, best_job)
-        else:
-            return self._schedule_retraining(distributed_job, best_job)
-
-    def _finalize_best_model(self, distributed_job: DistributedJob, best_job: DistributedSubJob) -> JobProcessingResult:
-        if self.job_executor.upload_model(distributed_job.extraction_identifier, best_job.extractor_job):
-            performance_score = self._extract_performance_score(best_job)
-            return JobProcessingResult(
-                finished=True,
-                success=True,
-                error_message=f"Best model selected: {best_job.extractor_job.method_name} with performance {performance_score}",
-                gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
-            )
-        else:
-            return JobProcessingResult(finished=True, success=False, error_message="Best model selected but upload failed")
-
-    def _schedule_retraining(self, distributed_job: DistributedJob, best_job: DistributedSubJob) -> JobProcessingResult:
         training_job = DistributedJob(
             extraction_identifier=distributed_job.extraction_identifier,
             type=JobType.TRAIN,
@@ -212,6 +194,18 @@ def _schedule_retraining(self, distributed_job: DistributedJob, best_job: DistributedSubJob) -> JobProcessingResult:
             gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
         )

+    def _finalize_best_model(self, distributed_job: DistributedJob, best_job: DistributedSubJob) -> JobProcessingResult:
+        if self.job_executor.upload_model(distributed_job.extraction_identifier, best_job.extractor_job):
+            performance_score = self._extract_performance_score(best_job)
+            return JobProcessingResult(
+                finished=True,
+                success=True,
+                error_message=f"Best model selected: {best_job.extractor_job.method_name} with performance {performance_score}",
+                gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
+            )
+        else:
+            return JobProcessingResult(finished=True, success=False, error_message="Best model selected but upload failed")
+
     @staticmethod
     def _extract_performance_score(best_job: DistributedSubJob) -> str:
         if best_job.result and hasattr(best_job.result, "performance_score"):
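
Net effect of the three hunks: _handle_performance_results no longer branches on should_be_retrained_with_more_data. Once a best sub-job has been picked it always schedules a follow-up TRAIN job, and _finalize_best_model keeps the same body but is moved below the retraining logic. A minimal runnable sketch of the resulting flow with simplified stand-in types; field names beyond those visible in the diff (sub_jobs, result, and the finished/success values of the retraining result) are assumptions:

from dataclasses import dataclass, field
from enum import Enum, auto


class JobType(Enum):
    TRAIN = auto()
    PREDICT = auto()
    PERFORMANCE = auto()


@dataclass
class ExtractorJob:
    method_name: str
    requires_gpu: bool = False


@dataclass
class DistributedSubJob:
    extractor_job: ExtractorJob
    result: object = None


@dataclass
class JobProcessingResult:
    finished: bool
    success: bool
    error_message: str = ""
    gpu_needed: bool = False


@dataclass
class DistributedJob:
    extraction_identifier: str
    type: JobType
    sub_jobs: list = field(default_factory=list)  # assumed field name


def handle_performance_results(distributed_job, best_job, pending_jobs):
    """Post-commit flow: the best method is always sent back for retraining."""
    # The old should_be_retrained_with_more_data check is gone, so there is no
    # early call to _finalize_best_model here any more.
    training_job = DistributedJob(
        extraction_identifier=distributed_job.extraction_identifier,
        type=JobType.TRAIN,
        sub_jobs=[best_job],  # assumption: only the winning sub-job is retrained
    )
    pending_jobs.append(training_job)
    # Only gpu_needed appears in the diff context; the other result fields
    # below are illustrative assumptions.
    return JobProcessingResult(
        finished=False,
        success=True,
        error_message="Retraining scheduled for best method",
        gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
    )

_finalize_best_model, now defined after this path, would presumably be invoked once the scheduled training job completes and the model can be uploaded.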
