Skip to content

Commit e12e3ae

Browse files
committed
Fix output path for predictions use case
1 parent 34480ba commit e12e3ae

File tree

4 files changed

+3
-23
lines changed

4 files changed

+3
-23
lines changed

src/trainable_entity_extractor/domain/TrainableEntityExtractorJob.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ class TrainableEntityExtractorJob(BaseModel):
1212
options: list[Option] = []
1313
gpu_needed: bool
1414
timeout: int
15+
output_path: str = ""

src/trainable_entity_extractor/ports/JobExecutor.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,6 @@ def recreate_model_folder(extraction_identifier: ExtractionIdentifier) -> None:
5858
shutil.rmtree(extraction_identifier.get_path(), ignore_errors=True)
5959
extraction_identifier.get_path().mkdir(parents=True, exist_ok=True)
6060

61-
def check_and_wait_for_model(self, extraction_identifier: ExtractionIdentifier) -> bool:
62-
try:
63-
completion_signal_exists = self.model_storage.check_model_completion_signal(extraction_identifier)
64-
65-
if not completion_signal_exists:
66-
self.logger.log(extraction_identifier, "Model completion signal not found, model may still be uploading")
67-
return False
68-
69-
self.logger.log(extraction_identifier, "Model completion signal found, model is ready")
70-
model_downloaded = self.model_storage.download_model(extraction_identifier)
71-
if model_downloaded:
72-
self.logger.log(
73-
extraction_identifier, "Model download failed, checking completion signal", LogSeverity.warning
74-
)
75-
return True
76-
77-
return False
78-
except Exception as e:
79-
self.logger.log(extraction_identifier, f"Error checking model availability: {e}", LogSeverity.error)
80-
return False
81-
8261
def is_extractor_cancelled(self, extractor_identifier: ExtractionIdentifier) -> bool:
8362
try:
8463
return self.data_retriever.is_extractor_cancelled(extractor_identifier)

src/trainable_entity_extractor/tests/use_cases/test_extractor_text_to_text.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def test_predictions_from_source_text_in_labeled_data(self):
8787
# Create prediction samples
8888
texts = ["test 0"]
8989
predictions_samples = [PredictionSample.from_text(text, str(i)) for i, text in enumerate(texts)]
90-
predictions_samples[0].segment_selector_texts = []
9190

9291
# Save prediction data
9392
self.data_retriever.save_prediction_data(extraction_identifier, predictions_samples)

src/trainable_entity_extractor/use_cases/PredictUseCase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ def __init__(self, extractors: list[type[ExtractorBase]], logger: Logger):
1414
self.logger = logger
1515

1616
def predict(self, extractor_job: TrainableEntityExtractorJob, samples: list[PredictionSample]) -> list[Suggestion]:
17+
output_path = extractor_job.output_path if extractor_job.output_path else DATA_PATH
1718
extraction_identifier = ExtractionIdentifier(
18-
run_name=extractor_job.run_name, output_path=DATA_PATH, extraction_name=extractor_job.extraction_name
19+
run_name=extractor_job.run_name, output_path=output_path, extraction_name=extractor_job.extraction_name
1920
)
2021

2122
extractor_name = extractor_job.extractor_name

0 commit comments

Comments
 (0)