Skip to content

Commit cdbfc7b

Browse files
committed
Add save extractor job to model storage
1 parent 56251c6 commit cdbfc7b

File tree

4 files changed

+30
-87
lines changed

4 files changed

+30
-87
lines changed

src/trainable_entity_extractor/adapters/LocalModelStorage.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,24 @@
22
import json
33
from typing import Optional
44

5-
from trainable_entity_extractor.config import EXTRACTOR_JOB_PATH, CACHE_PATH
5+
from trainable_entity_extractor.config import EXTRACTOR_JOB_PATH
66
from trainable_entity_extractor.domain.ExtractionIdentifier import ExtractionIdentifier
7-
from trainable_entity_extractor.domain.Option import Option
87
from trainable_entity_extractor.domain.TrainableEntityExtractorJob import TrainableEntityExtractorJob
98
from trainable_entity_extractor.ports.ModelStorage import ModelStorage
109

1110

1211
class LocalModelStorage(ModelStorage):
13-
def __init__(self):
14-
self.completion_signals = {}
1512

1613
def upload_model(self, extraction_identifier: ExtractionIdentifier, extractor_job: TrainableEntityExtractorJob) -> bool:
17-
try:
18-
model_path = extraction_identifier.get_path()
19-
if not os.path.exists(model_path):
20-
os.makedirs(model_path, exist_ok=True)
21-
22-
extractor_job_dir = os.path.join(model_path, EXTRACTOR_JOB_PATH.parent)
23-
if not os.path.exists(extractor_job_dir):
24-
os.makedirs(extractor_job_dir, exist_ok=True)
25-
26-
job_file_path = os.path.join(model_path, EXTRACTOR_JOB_PATH)
27-
job_data = self.serialize_job_to_dict(extractor_job)
28-
29-
with open(job_file_path, "w", encoding="utf-8") as f:
30-
json.dump(job_data, f, indent=2, ensure_ascii=False)
31-
32-
return True
33-
except Exception as e:
34-
print(f"Error saving job: {e}")
35-
return False
14+
return self.save_extractor_job(extraction_identifier, extractor_job)
3615

3716
def download_model(self, extraction_identifier: ExtractionIdentifier) -> bool:
38-
"""Download/load model locally"""
3917
try:
4018
model_path = extraction_identifier.get_path()
4119
return os.path.exists(model_path)
4220
except Exception:
4321
return False
4422

45-
def check_model_completion_signal(self, extraction_identifier: ExtractionIdentifier) -> bool:
46-
key = f"{extraction_identifier.run_name}_{extraction_identifier.extraction_name}"
47-
return self.completion_signals.get(key, False)
48-
49-
def create_model_completion_signal(self, extraction_identifier: ExtractionIdentifier) -> bool:
50-
try:
51-
key = f"{extraction_identifier.run_name}_{extraction_identifier.extraction_name}"
52-
self.completion_signals[key] = True
53-
54-
# Also create a physical completion signal file
55-
completion_file = os.path.join(extraction_identifier.get_path(), "training_complete.signal")
56-
with open(completion_file, "w") as f:
57-
f.write("Training completed successfully")
58-
59-
return True
60-
except Exception:
61-
return False
62-
6323
def get_extractor_job(self, extraction_identifier: ExtractionIdentifier) -> Optional[TrainableEntityExtractorJob]:
6424
try:
6525
model_path = extraction_identifier.get_path()

src/trainable_entity_extractor/ports/JobExecutor.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,7 @@ def upload_model(self, extraction_identifier: ExtractionIdentifier, extractor_jo
6262
try:
6363
extraction_identifier.clean_extractor_folder(extractor_job.method_name)
6464
shutil.rmtree(CACHE_PATH / extraction_identifier.run_name, ignore_errors=True)
65-
upload_success = self.model_storage.upload_model(extraction_identifier, extractor_job)
66-
if upload_success:
67-
signal_success = self.model_storage.create_model_completion_signal(extraction_identifier)
68-
if signal_success:
69-
self.logger.log(
70-
extraction_identifier, f"Model and completion signal uploaded for method {extractor_job.method_name}"
71-
)
72-
return True
73-
else:
74-
self.logger.log(
75-
extraction_identifier,
76-
f"Model uploaded but completion signal creation failed for method {extractor_job.method_name}",
77-
LogSeverity.error,
78-
)
79-
return False
80-
else:
81-
self.logger.log(
82-
extraction_identifier, f"Model upload failed for method {extractor_job.method_name}", LogSeverity.error
83-
)
84-
return False
85-
65+
return self.model_storage.upload_model(extraction_identifier, extractor_job)
8666
except Exception as e:
8767
self.logger.log(extraction_identifier, f"Model upload failed with exception: {e}", LogSeverity.error, e)
8868
return False

src/trainable_entity_extractor/ports/ModelStorage.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import json
2+
import os
13
from abc import ABC, abstractmethod
24
from typing import Optional
5+
6+
from trainable_entity_extractor.config import EXTRACTOR_JOB_PATH
37
from trainable_entity_extractor.domain.ExtractionIdentifier import ExtractionIdentifier
48
from trainable_entity_extractor.domain.Option import Option
59
from trainable_entity_extractor.domain.TrainableEntityExtractorJob import TrainableEntityExtractorJob
@@ -16,16 +20,31 @@ def download_model(self, extraction_identifier: ExtractionIdentifier) -> bool:
1620
pass
1721

1822
@abstractmethod
19-
def check_model_completion_signal(self, extraction_identifier: ExtractionIdentifier) -> bool:
23+
def get_extractor_job(self, extraction_identifier: ExtractionIdentifier) -> Optional[TrainableEntityExtractorJob]:
2024
pass
2125

22-
@abstractmethod
23-
def create_model_completion_signal(self, extraction_identifier: ExtractionIdentifier) -> bool:
24-
pass
26+
def save_extractor_job(
27+
self, extraction_identifier: ExtractionIdentifier, extractor_job: TrainableEntityExtractorJob
28+
) -> bool:
29+
try:
30+
model_path = extraction_identifier.get_path()
31+
if not os.path.exists(model_path):
32+
os.makedirs(model_path, exist_ok=True)
2533

26-
@abstractmethod
27-
def get_extractor_job(self, extraction_identifier: ExtractionIdentifier) -> Optional[TrainableEntityExtractorJob]:
28-
pass
34+
extractor_job_dir = os.path.join(model_path, EXTRACTOR_JOB_PATH.parent)
35+
if not os.path.exists(extractor_job_dir):
36+
os.makedirs(extractor_job_dir, exist_ok=True)
37+
38+
job_file_path = os.path.join(model_path, EXTRACTOR_JOB_PATH)
39+
job_data = self.serialize_job_to_dict(extractor_job)
40+
41+
with open(job_file_path, "w", encoding="utf-8") as f:
42+
json.dump(job_data, f, indent=2, ensure_ascii=False)
43+
44+
return True
45+
except Exception as e:
46+
print(f"Error saving job: {e}")
47+
return False
2948

3049
@staticmethod
3150
def serialize_job_to_dict(job: TrainableEntityExtractorJob) -> dict:
@@ -56,7 +75,6 @@ def deserialize_job_from_dict(job_data: dict) -> TrainableEntityExtractorJob:
5675
gpu_needed = job_data.get("gpu_needed", False)
5776
timeout = job_data.get("timeout", 3600)
5877

59-
additional_fields = {}
6078
if version != "1.0":
6179
pass
6280

src/trainable_entity_extractor/use_cases/OrchestratorUseCase.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ def _process_prediction_job(self, distributed_job: DistributedJob) -> JobProcess
113113
)
114114

115115
def _process_performance_job(self, distributed_job: DistributedJob) -> JobProcessingResult:
116-
if len(distributed_job.sub_jobs) == [x for x in distributed_job.sub_jobs if x.status == JobStatus.WAITING]:
117-
self.job_executor.recreate_model_folder(distributed_job.extraction_identifier)
118-
119116
self._start_pending_performance_evaluations(distributed_job)
120117

121118
if self._has_perfect_score_job(distributed_job):
@@ -129,8 +126,8 @@ def _process_performance_job(self, distributed_job: DistributedJob) -> JobProces
129126
gpu_needed=any(getattr(job.extractor_job, "requires_gpu", False) for job in distributed_job.sub_jobs),
130127
)
131128

132-
self._log_performance_summary(distributed_job)
133129
self._remove_job_from_queue(distributed_job)
130+
self._log_performance_summary(distributed_job)
134131

135132
return self._handle_performance_results(distributed_job)
136133

@@ -194,18 +191,6 @@ def _handle_performance_results(self, distributed_job: DistributedJob) -> JobPro
194191
gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
195192
)
196193

197-
def _finalize_best_model(self, distributed_job: DistributedJob, best_job: DistributedSubJob) -> JobProcessingResult:
198-
if self.job_executor.upload_model(distributed_job.extraction_identifier, best_job.extractor_job):
199-
performance_score = self._extract_performance_score(best_job)
200-
return JobProcessingResult(
201-
finished=True,
202-
success=True,
203-
error_message=f"Best model selected: {best_job.extractor_job.method_name} with performance {performance_score}",
204-
gpu_needed=getattr(best_job.extractor_job, "requires_gpu", False),
205-
)
206-
else:
207-
return JobProcessingResult(finished=True, success=False, error_message="Best model selected but upload failed")
208-
209194
@staticmethod
210195
def _extract_performance_score(best_job: DistributedSubJob) -> str:
211196
if best_job.result and hasattr(best_job.result, "performance_score"):

0 commit comments

Comments
 (0)