Skip to content

Commit e657341

Browse files
committed
Add if it failed to Performance object
1 parent ee19c8c commit e657341

File tree

4 files changed

+10
-35
lines changed

4 files changed

+10
-35
lines changed

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/PdfToMultiOptionExtractor.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -166,37 +166,6 @@ def get_predictions(
166166

167167
return method.get_samples_for_context(prediction_samples_data), prediction
168168

169-
def get_best_method(
170-
self, multi_option_data: ExtractionData, training_set: ExtractionData, test_set: ExtractionData
171-
) -> Optional[PdfMultiOptionMethod]:
172-
best_method_instance = self.METHODS[0]
173-
best_performance = 0
174-
performance_summary = PerformanceSummary.from_extraction_data(
175-
extractor_name=self.get_name(),
176-
training_samples_count=len(training_set.samples),
177-
testing_samples_count=len(test_set.samples),
178-
extraction_data=multi_option_data,
179-
)
180-
for method in self.METHODS:
181-
if self.extraction_identifier.is_training_canceled():
182-
self.logger.log(self.extraction_identifier, "Training canceled")
183-
return None
184-
185-
performance = self.get_method_performance(method, training_set, test_set)
186-
performance_summary.add_performance(method.get_name(), performance)
187-
if performance == 100:
188-
self.logger.log(self.extraction_identifier, performance_summary.to_log())
189-
self.extraction_identifier.save_content("performance_log.txt", performance_summary.to_log())
190-
return method
191-
192-
if round(performance, 2) > best_performance:
193-
best_performance = round(performance, 2)
194-
best_method_instance = method
195-
196-
self.logger.log(self.extraction_identifier, performance_summary.to_log())
197-
self.extraction_identifier.save_content("performance_log.txt", performance_summary.to_log())
198-
return best_method_instance
199-
200169
def get_method_performance(
201170
self, method: PdfMultiOptionMethod, train_set: ExtractionData, test_set: ExtractionData
202171
) -> float:

src/trainable_entity_extractor/domain/PerformanceLog.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class PerformanceLog(BaseModel):
55
method_name: str
66
performance: float
77
execution_seconds: int = 0
8+
failed: bool = False
89

910
@staticmethod
1011
def get_execution_time_string(execution_seconds: int):

src/trainable_entity_extractor/domain/PerformanceSummary.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ class PerformanceSummary(BaseModel):
1919
previous_timestamp: int = Field(default_factory=lambda: int(time()))
2020
empty_pdf_count: int = 0
2121

22-
def add_performance(self, method_name: str, performance: float):
22+
def add_performance(self, method_name: str, performance: float, failed: bool = False):
2323
current_time = int(time())
2424
performance = PerformanceLog(
25-
method_name=method_name, performance=performance, execution_seconds=int(current_time - self.previous_timestamp)
25+
method_name=method_name,
26+
performance=performance,
27+
execution_seconds=int(current_time - self.previous_timestamp),
28+
failed=failed,
2629
)
2730
self.previous_timestamp = current_time
2831
self.performances.append(performance)
@@ -36,7 +39,9 @@ def add_performance_from_sub_job(self, sub_job):
3639
else:
3740
performance_score = 0.0
3841

39-
self.add_performance(sub_job.extractor_job.method_name, performance_score)
42+
failed = sub_job.result is None or (hasattr(sub_job.result, "failed") and sub_job.result.failed)
43+
44+
self.add_performance(sub_job.extractor_job.method_name, performance_score, failed)
4045

4146
def to_log(self) -> str:
4247
total_time = sum(performance.execution_seconds for performance in self.performances)

src/trainable_entity_extractor/use_cases/OrchestratorUseCase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _log_performance_summary(self, distributed_job: DistributedJob) -> None:
154154
performance_summary = PerformanceSummary.from_distributed_job(distributed_job)
155155

156156
for sub_job in distributed_job.sub_jobs:
157-
if sub_job.status == JobStatus.SUCCESS and sub_job.result:
157+
if sub_job.status in [JobStatus.SUCCESS, JobStatus.FAILURE] and sub_job.result:
158158
performance_summary.add_performance_from_sub_job(sub_job)
159159

160160
summary_log = performance_summary.to_log()

0 commit comments

Comments
 (0)