Skip to content

Commit 3e81c54

Browse files
committed
Add data summary
1 parent 24e44e9 commit 3e81c54

File tree

16 files changed

+367
-53
lines changed

16 files changed

+367
-53
lines changed

src/trainable_entity_extractor/adapters/extractors/pdf_to_multi_option_extractor/PdfMultiOptionMethod.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_performance(self, train_set: ExtractionData, test_set: ExtractionData) -
5757

5858
self.train(train_set)
5959
prediction_samples_data = PredictionSamplesData(
60-
prediction_samples=[PredictionSample.from_text(x.pdf_data.get_text()) for x in test_set.samples],
60+
prediction_samples=[PredictionSample.from_pdf_data(x.pdf_data) for x in test_set.samples],
6161
options=self.options,
6262
multi_value=self.multi_value,
6363
)

src/trainable_entity_extractor/adapters/extractors/segment_selector/methods/lightgbm_frequent_words/LightgbmFrequentWords.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self):
3030

3131
def create_model(self, training_pdfs_segments: list[PdfData], model_path):
3232
start = time()
33-
self.set_segments(pdfs_segments=training_pdfs_segments)
33+
self.set_segments(pdfs_data=training_pdfs_segments)
3434

3535
config_logger.info(f"Set segments {int(time() - start)} seconds")
3636

@@ -81,13 +81,13 @@ def get_training_data(self):
8181

8282
return X, y
8383

84-
def set_segments(self, pdfs_segments: list[PdfData]):
84+
def set_segments(self, pdfs_data: list[PdfData]):
8585
self.segments = list()
86-
for pdf_features in pdfs_segments:
87-
self.segments.extend(SegmentLightgbmFrequentWords.from_pdf_features(pdf_features))
86+
for pdf_data in pdfs_data:
87+
self.segments.extend(SegmentLightgbmFrequentWords.from_pdf_data(pdf_data))
8888

89-
def predict(self, model, testing_pdfs_segments: list[PdfData], model_path):
90-
self.set_segments(testing_pdfs_segments)
89+
def predict(self, model, pdfs_data: list[PdfData], model_path):
90+
self.set_segments(pdfs_data)
9191
self.set_most_frequent_words_to_segments(model_path)
9292
x, y = self.get_training_data()
9393
x = x[:, : model.num_feature()]

src/trainable_entity_extractor/adapters/extractors/segment_selector/methods/lightgbm_frequent_words/SegmentLightgbmFrequentWords.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,11 @@ def token_after_last_token(self, token: PdfToken):
235235
return False
236236

237237
@staticmethod
238-
def from_pdf_features(pdf_features: PdfData) -> list["SegmentLightgbmFrequentWords"]:
239-
modes = Modes(pdf_features)
238+
def from_pdf_data(pdf_data: PdfData) -> list["SegmentLightgbmFrequentWords"]:
239+
modes = Modes(pdf_data)
240240
segments: list["SegmentLightgbmFrequentWords"] = list()
241-
for index, pdf_segment in enumerate(pdf_features.pdf_data_segments):
242-
segment_landmarks = SegmentLightgbmFrequentWords(index, pdf_segment, pdf_features, modes)
241+
for index, pdf_segment in enumerate(pdf_data.pdf_data_segments):
242+
segment_landmarks = SegmentLightgbmFrequentWords(index, pdf_segment, pdf_data, modes)
243243
segments.append(segment_landmarks)
244244

245245
sorted_pdf_segments = sorted(segments, key=lambda x: (x.page_index, x.top))
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from collections import Counter
2+
from typing import Optional
3+
4+
from pydantic import BaseModel, Field
5+
6+
from trainable_entity_extractor.domain.ExtractionData import ExtractionData
7+
8+
9+
class OptionDistribution(BaseModel):
10+
option_id: str
11+
option_label: str
12+
count: int
13+
percentage: float
14+
15+
16+
class LanguageDistribution(BaseModel):
17+
language_iso: str
18+
count: int
19+
percentage: float
20+
21+
22+
class TextLengthStats(BaseModel):
23+
min_length: int
24+
max_length: int
25+
avg_length: float
26+
median_length: float
27+
28+
29+
class ExtractionDataSummary(BaseModel):
30+
total_samples: int
31+
total_options: int
32+
has_pdf_data: bool
33+
empty_pdfs_count: int = 0
34+
languages: list[LanguageDistribution] = Field(default_factory=list)
35+
option_distribution: list[OptionDistribution] = Field(default_factory=list)
36+
label_text_stats: Optional[TextLengthStats] = None
37+
source_text_stats: Optional[TextLengthStats] = None
38+
samples_with_values: int = 0
39+
40+
@staticmethod
41+
def from_extraction_data(extraction_data: ExtractionData) -> "ExtractionDataSummary":
42+
total_samples = len(extraction_data.samples)
43+
total_options = len(extraction_data.options) if extraction_data.options else 0
44+
45+
has_pdf_data = any(sample.pdf_data and sample.pdf_data.get_text() for sample in extraction_data.samples)
46+
empty_pdfs_count = 0
47+
48+
if has_pdf_data:
49+
for sample in extraction_data.samples:
50+
if sample.pdf_data:
51+
if not sample.pdf_data.get_text():
52+
empty_pdfs_count += 1
53+
54+
language_counter = Counter()
55+
for sample in extraction_data.samples:
56+
if sample.labeled_data and sample.labeled_data.language_iso:
57+
language_counter[sample.labeled_data.language_iso] += 1
58+
59+
languages = [
60+
LanguageDistribution(language_iso=lang, count=count, percentage=round(count / total_samples * 100, 2))
61+
for lang, count in language_counter.most_common()
62+
]
63+
64+
option_counter = Counter()
65+
for sample in extraction_data.samples:
66+
if sample.labeled_data and sample.labeled_data.values:
67+
for value in sample.labeled_data.values:
68+
option_counter[value.id] += 1
69+
70+
option_distribution = []
71+
if extraction_data.options:
72+
for option in extraction_data.options:
73+
count = option_counter.get(option.id, 0)
74+
option_distribution.append(
75+
OptionDistribution(
76+
option_id=option.id,
77+
option_label=option.label,
78+
count=count,
79+
percentage=round(count / total_samples * 100, 2) if total_samples > 0 else 0,
80+
)
81+
)
82+
option_distribution = sorted(option_distribution, key=lambda x: x.count, reverse=True)[:30]
83+
84+
label_text_lengths = []
85+
source_text_lengths = []
86+
samples_with_values = 0
87+
88+
for sample in extraction_data.samples:
89+
if sample.labeled_data:
90+
if sample.labeled_data.label_text:
91+
label_text_lengths.append(len(sample.labeled_data.label_text))
92+
if sample.labeled_data.source_text:
93+
source_text_lengths.append(len(sample.labeled_data.source_text))
94+
if sample.labeled_data.values:
95+
samples_with_values += 1
96+
97+
label_text_stats = None
98+
if label_text_lengths:
99+
sorted_lengths = sorted(label_text_lengths)
100+
label_text_stats = TextLengthStats(
101+
min_length=min(label_text_lengths),
102+
max_length=max(label_text_lengths),
103+
avg_length=round(sum(label_text_lengths) / len(label_text_lengths), 2),
104+
median_length=sorted_lengths[len(sorted_lengths) // 2],
105+
)
106+
107+
source_text_stats = None
108+
if source_text_lengths:
109+
sorted_lengths = sorted(source_text_lengths)
110+
source_text_stats = TextLengthStats(
111+
min_length=min(source_text_lengths),
112+
max_length=max(source_text_lengths),
113+
avg_length=round(sum(source_text_lengths) / len(source_text_lengths), 2),
114+
median_length=sorted_lengths[len(sorted_lengths) // 2],
115+
)
116+
117+
return ExtractionDataSummary(
118+
total_samples=total_samples,
119+
total_options=total_options,
120+
has_pdf_data=has_pdf_data,
121+
empty_pdfs_count=empty_pdfs_count,
122+
languages=languages,
123+
option_distribution=option_distribution,
124+
label_text_stats=label_text_stats,
125+
source_text_stats=source_text_stats,
126+
samples_with_values=samples_with_values,
127+
)
128+
129+
def to_report_string(self) -> str:
130+
lines = [
131+
"Data Summary",
132+
"=" * 80,
133+
f"Total Samples: {self.total_samples}",
134+
]
135+
136+
if self.total_options:
137+
lines.append(f"Total Options: {self.total_options}")
138+
139+
if self.total_options and self.option_distribution:
140+
lines.append("\nOption Distribution:")
141+
for dist in self.option_distribution:
142+
lines.append(f" - {dist.option_label} (id: {dist.option_id}): {dist.count} samples ({dist.percentage}%)")
143+
144+
if self.samples_with_values > 0:
145+
percentage = round(self.samples_with_values / self.total_samples * 100, 2)
146+
lines.append(f"\nSamples with Option Values: {self.samples_with_values} ({percentage}%)")
147+
148+
if self.languages:
149+
lines.append("\nLanguage Distribution:")
150+
for lang_dist in self.languages:
151+
lines.append(f" - {lang_dist.language_iso}: {lang_dist.count} samples ({lang_dist.percentage}%)")
152+
153+
if self.has_pdf_data:
154+
lines.append(f"\nPDF Data: Present")
155+
if self.empty_pdfs_count > 0:
156+
lines.append(f"Empty PDFs: {self.empty_pdfs_count}")
157+
158+
if self.label_text_stats:
159+
lines.append("\nLabel Text Length:")
160+
lines.append(f" - Min: {self.label_text_stats.min_length}")
161+
lines.append(f" - Max: {self.label_text_stats.max_length}")
162+
lines.append(f" - Average: {self.label_text_stats.avg_length}")
163+
lines.append(f" - Median: {self.label_text_stats.median_length}")
164+
165+
if self.source_text_stats:
166+
lines.append("\nSource Text Length:")
167+
lines.append(f" - Min: {self.source_text_stats.min_length}")
168+
lines.append(f" - Max: {self.source_text_stats.max_length}")
169+
lines.append(f" - Average: {self.source_text_stats.avg_length}")
170+
lines.append(f" - Median: {self.source_text_stats.median_length}")
171+
172+
lines.append("=" * 80)
173+
174+
return "\n".join(lines)

src/trainable_entity_extractor/domain/PerformanceSummary.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ class PerformanceSummary(BaseModel):
1717
performances: list[PerformanceLog] = []
1818
extraction_identifier: ExtractionIdentifier | None = None
1919
previous_timestamp: int = Field(default_factory=lambda: int(time()))
20-
empty_pdf_count: int = 0
2120

2221
def add_performance(self, method_name: str, performance: float, failed: bool = False):
2322
current_time = int(time())
@@ -53,7 +52,6 @@ def to_log(self) -> str:
5352
text += f"Samples: {self.samples_count}\n"
5453
text += f"Train/test: {self.training_samples_count}/{self.testing_samples_count}\n"
5554
text += f"{len(self.languages)} language(s): {', '.join(self.languages) if self.languages else 'None'}\n"
56-
text += f"Empty PDFs: {self.empty_pdf_count}\n" if self.empty_pdf_count else ""
5755
text += f"Options count: {self.options_count}\n" if self.options_count > 0 else ""
5856
text += "Methods by performance:\n"
5957
for performance in sorted(self.performances, key=lambda x: x.performance, reverse=True):
@@ -94,5 +92,4 @@ def from_distributed_job(distributed_job: DistributedJob) -> "PerformanceSummary
9492
languages=languages,
9593
training_samples_count=training_samples_count,
9694
testing_samples_count=testing_samples_count,
97-
empty_pdf_count=0,
9895
)

src/trainable_entity_extractor/domain/PredictionSample.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def get_input_text_by_lines(self) -> list[str]:
3232
return [""]
3333

3434
@staticmethod
35-
def from_pdf_data(pdf_data: PdfData):
36-
return PredictionSample(pdf_data=pdf_data)
35+
def from_pdf_data(pdf_data: PdfData, entity_name: str = ""):
36+
prediction_sample = PredictionSample.from_text(pdf_data.get_text(), entity_name)
37+
prediction_sample.pdf_data = pdf_data
38+
return prediction_sample
3739

3840
@staticmethod
3941
def from_text(text: str, entity_name: str = ""):

src/trainable_entity_extractor/domain/Suggestion.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,18 @@ def from_prediction_text(extraction_identifier: ExtractionIdentifier, entity_nam
108108
def from_prediction_multi_option(extraction_identifier: ExtractionIdentifier, entity_name: str, values: list[Value]):
109109
suggestion = Suggestion.get_empty(extraction_identifier, entity_name)
110110
suggestion.values = values
111-
if values:
112-
suggestion.segment_text = values[0].segment_text
111+
for value in values:
112+
if value.segment_text:
113+
suggestion._raw_context = [values[0].segment_text]
114+
suggestion.segment_text = FormatSegmentText([values[0].segment_text], value.label).get_text()
115+
break
116+
117+
for value in values:
118+
if value.segment_text:
119+
value.segment_text = FormatSegmentText([value.segment_text], value.label).get_text()
120+
else:
121+
value.segment_text = FormatSegmentText(suggestion._raw_context, value.label).get_text()
122+
113123
return suggestion
114124

115125
def set_segment_text_from_sample(self, prediction_sample: PredictionSample):

src/trainable_entity_extractor/domain/Value.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __eq__(self, other):
1717
if not isinstance(other, Value):
1818
return False
1919

20-
if other.segment_text and self.segment_text != other.segment_text:
20+
if self.segment_text and other.segment_text and self.segment_text != other.segment_text:
2121
return False
2222

2323
return self.id == other.id and self.label == other.label

src/trainable_entity_extractor/tests/unit_tests/domain/test_PerformanceSummary.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,12 @@ def test_direct_instantiation_with_empty_pdf_count(self):
125125
languages=["en", "es"],
126126
training_samples_count=25,
127127
testing_samples_count=15,
128-
empty_pdf_count=2,
129128
)
130129

131130
assert result.extractor_name == "PDF Extractor"
132131
assert result.samples_count == 4
133132
assert result.options_count == 0
134133
assert set(result.languages) == {"en", "es"}
135-
assert result.empty_pdf_count == 2
136134

137135
def test_to_log_basic_summary_no_methods(self):
138136
"""Test to_log with basic summary but no performance methods"""

src/trainable_entity_extractor/tests/unit_tests/extractors/pdf_to_multi_option_extractor/multi_labels_methods/test_single_label_setfit_english.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from trainable_entity_extractor.domain.LabeledData import LabeledData
1212
from trainable_entity_extractor.domain.Option import Option
1313
from trainable_entity_extractor.domain.PdfData import PdfData
14+
from trainable_entity_extractor.domain.PredictionSample import PredictionSample
15+
from trainable_entity_extractor.domain.PredictionSamplesData import PredictionSamplesData
1416
from trainable_entity_extractor.domain.TrainingSample import TrainingSample
1517
from trainable_entity_extractor.adapters.extractors.pdf_to_multi_option_extractor.multi_labels_methods.SingleLabelSetFitEnglishMethod import (
1618
SingleLabelSetFitEnglishMethod,
@@ -25,7 +27,7 @@ class TestSetFitSingleLabelEnglishMethod(TestCase):
2527
def tearDown(self):
2628
shutil.rmtree(join(DATA_PATH, self.TENANT), ignore_errors=True)
2729

28-
@unittest.SkipTest
30+
@unittest.skip("Skipping GPU test in CI/CD")
2931
def test_train_and_predict(self):
3032
if not torch.cuda.is_available():
3133
return
@@ -45,21 +47,19 @@ def test_train_and_predict(self):
4547
extraction_data = ExtractionData(
4648
multi_value=False, options=options, samples=samples, extraction_identifier=extraction_identifier
4749
)
48-
setfit_english_method = SingleLabelSetFitEnglishMethod(extraction_identifier, options, False)
50+
setfit_english_method = SingleLabelSetFitEnglishMethod(extraction_identifier)
4951

5052
try:
5153
setfit_english_method.train(extraction_data)
5254
except Exception as e:
5355
self.fail(f"train() raised {type(e).__name__}")
5456

55-
prediction_sample_1 = TrainingSample(pdf_data=pdf_data_1)
56-
prediction_sample_2 = TrainingSample(pdf_data=pdf_data_2)
57-
prediction_sample_3 = TrainingSample(pdf_data=pdf_data_3)
57+
prediction_sample_1 = PredictionSample(pdf_data=pdf_data_1)
58+
prediction_sample_2 = PredictionSample(pdf_data=pdf_data_2)
59+
prediction_sample_3 = PredictionSample(pdf_data=pdf_data_3)
5860
prediction_samples = [prediction_sample_1, prediction_sample_2, prediction_sample_3]
5961

60-
prediction_data = ExtractionData(
61-
multi_value=False, options=options, samples=prediction_samples, extraction_identifier=extraction_identifier
62-
)
62+
prediction_data = PredictionSamplesData(multi_value=False, options=options, prediction_samples=prediction_samples)
6363
predictions = setfit_english_method.predict(prediction_data)
6464

6565
self.assertEqual(3, len(predictions))

0 commit comments

Comments
 (0)