Skip to content

Commit ee19c8c

Browse files
committed
Add method to Performance object
1 parent b58b8e3 commit ee19c8c

File tree

5 files changed

+34
-7
lines changed

5 files changed

+34
-7
lines changed

src/trainable_entity_extractor/adapters/LocalJobExecutor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from datetime import timedelta, datetime
2+
from pathlib import Path
3+
import shutil
14
from typing import Tuple, List
25

36
from trainable_entity_extractor.adapters.extractors.pdf_to_multi_option_extractor.PdfToMultiOptionExtractor import (
@@ -29,10 +32,26 @@ class LocalJobExecutor(JobExecutor):
2932
TextToTextExtractor,
3033
]
3134

35+
@staticmethod
def ensure_fresh_model_folder(extraction_identifier: ExtractionIdentifier, max_age_hours: int = 1) -> None:
    """Make sure the extraction folder exists and is not stale.

    The folder returned by ``extraction_identifier.get_path()`` is deleted
    with all of its contents when its modification time is more than
    ``max_age_hours`` hours in the past, and is then (re)created empty.
    A folder that is recent enough is left untouched; a missing folder is
    created, including any missing parents.

    Args:
        extraction_identifier: Object whose ``get_path()`` gives the folder
            to check.
        max_age_hours: Maximum allowed age of the folder, in hours, before
            it is wiped. Defaults to 1.
    """
    # Function-scope import so the block does not depend on the module's
    # import list.
    import time

    path = Path(extraction_identifier.get_path())

    # Ask for the mtime directly instead of exists() + stat(): this avoids
    # failing if the folder disappears between the two calls.
    try:
        modified_at = path.stat().st_mtime
    except FileNotFoundError:
        modified_at = None

    if modified_at is not None:
        # Compare epoch seconds rather than naive local datetimes, so a DST
        # clock shift between the two instants cannot distort the age.
        age_seconds = time.time() - modified_at
        if age_seconds > timedelta(hours=max_age_hours).total_seconds():
            shutil.rmtree(path)

    # Single creation point: no-op for a kept folder, recreates a wiped one,
    # creates a missing one.
    path.mkdir(parents=True, exist_ok=True)
49+
3250
def start_performance_evaluation(
3351
self, extraction_identifier: ExtractionIdentifier, distributed_sub_job: DistributedSubJob
3452
):
3553
try:
54+
self.ensure_fresh_model_folder(extraction_identifier)
3655
extraction_data = self.data_retriever.get_extraction_data(extraction_identifier)
3756
if not extraction_data:
3857
distributed_sub_job.status = JobStatus.FAILURE
@@ -67,9 +86,6 @@ def start_performance_evaluation(
6786
distributed_sub_job.status = JobStatus.FAILURE
6887
return None
6988

70-
distributed_sub_job.status = JobStatus.FAILURE
71-
return None
72-
7389
def upload_model(self, extraction_identifier: ExtractionIdentifier, extractor_job: TrainableEntityExtractorJob) -> bool:
7490
try:
7591
extraction_identifier.clean_extractor_folder(extractor_job.method_name)

src/trainable_entity_extractor/domain/Performance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
class Performance(BaseModel):
5+
method_name: str = "Unknown Method"
56
performance: float = 0.0
67
execution_seconds: int = 0
78
is_perfect: bool = False

src/trainable_entity_extractor/domain/PerformanceSummary.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,19 @@ def from_distributed_job(distributed_job: DistributedJob) -> "PerformanceSummary
9191
testing_samples_count = 0
9292
training_samples_count = 0
9393
options_count = 0
94+
extractor_name = "Unknown Extractor"
9495

9596
for sub_job in distributed_job.sub_jobs:
9697
if not sub_job.result:
9798
continue
9899
testing_samples_count = sub_job.result.testing_samples_count
99100
training_samples_count = sub_job.result.training_samples_count
100101
options_count = len(sub_job.extractor_job.options) if sub_job.extractor_job.options else 0
102+
extractor_name = sub_job.extractor_job.extractor_name
101103

102104
return PerformanceSummary(
103105
extraction_identifier=distributed_job.extraction_identifier,
104-
extractor_name="Performance Evaluation",
106+
extractor_name=extractor_name,
105107
samples_count=0,
106108
options_count=options_count,
107109
languages=[],

src/trainable_entity_extractor/ports/ExtractorBase.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,14 @@ def get_performance(self, extractor_job: TrainableEntityExtractorJob, extraction
125125
method_instance = self.get_method_instance_by_name(method_name)
126126
if not method_instance:
127127
self.logger.log(extraction_data.extraction_identifier, f"Method {method_name} not found")
128-
return Performance(failed=True)
128+
return Performance(method_name=method_name, failed=True)
129129

130130
if hasattr(method_instance, "can_be_used"):
131131
if not method_instance.can_be_used(extraction_data):
132132
self.logger.log(
133133
extraction_data.extraction_identifier, f"Method {method_name} cannot be used with current data"
134134
)
135-
return Performance(failed=True)
135+
return Performance(method_name=method_name, failed=True)
136136

137137
self.logger.log(extraction_data.extraction_identifier, f"\nChecking {method_name}")
138138

@@ -145,6 +145,7 @@ def get_performance(self, extractor_job: TrainableEntityExtractorJob, extraction
145145
is_perfect = performance_score >= 99.99
146146

147147
return Performance(
148+
method_name=method_name,
148149
performance=performance_score,
149150
execution_seconds=execution_time,
150151
is_perfect=is_perfect,
@@ -157,7 +158,7 @@ def get_performance(self, extractor_job: TrainableEntityExtractorJob, extraction
157158
self.logger.log(extraction_data.extraction_identifier, "ERROR", LogSeverity.info, e)
158159
execution_time = int(time.time() - start_time)
159160

160-
return Performance(execution_seconds=execution_time)
161+
return Performance(method_name=method_name, execution_seconds=execution_time)
161162

162163
def train_one_method(
163164
self, extractor_job: TrainableEntityExtractorJob, extraction_data: ExtractionData

src/trainable_entity_extractor/use_cases/TrainUseCase.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
from trainable_entity_extractor.domain.ExtractionData import ExtractionData
24
from trainable_entity_extractor.domain.Performance import Performance
35
from trainable_entity_extractor.domain.TrainableEntityExtractorJob import TrainableEntityExtractorJob
@@ -13,6 +15,11 @@ def __init__(self, extractors: list[type[ExtractorBase]], logger: Logger):
1315
def train_one_method(
1416
self, extractor_job: TrainableEntityExtractorJob, extraction_data: ExtractionData
1517
) -> tuple[bool, str]:
18+
19+
method_path = Path(extraction_data.extraction_identifier.get_path()) / extractor_job.method_name
20+
if method_path.exists() and any(method_path.iterdir()):
21+
return True, ""
22+
1623
extractor_name = extractor_job.extractor_name
1724
for extractor in self.extractors:
1825
extractor_instance = extractor(extraction_data.extraction_identifier, self.logger)

0 commit comments

Comments
 (0)