Skip to content

Commit ffaccc9

Browse files
committed
Fix Gliner slowness
1 parent 75a9f73 commit ffaccc9

File tree

4 files changed

+49
-15
lines changed

4 files changed

+49
-15
lines changed

src/trainable_entity_extractor/adapters/extractors/GlinerDateExtractor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import json
2-
32
from dateparser.search import search_dates
4-
from gliner import GLiNER
53

64

75
class GlinerDateExtractor:
6+
7+
def __init__(self, model):
8+
self.model = model
9+
810
@staticmethod
911
def find_unique_entity_dicts(entities: list[dict]) -> list[dict]:
1012
dicts_without_score = [{k: v for k, v in d.items() if k != "score"} for d in entities]
@@ -25,8 +27,6 @@ def remove_overlapping_entities(entities):
2527
return result
2628

2729
def extract_dates(self, text: str):
28-
gliner_model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
29-
3030
words = text.split()
3131

3232
entities = []
@@ -37,7 +37,7 @@ def extract_dates(self, text: str):
3737
for i in range(0, len(words), slide_size):
3838
window_words = words[i : i + window_size]
3939
window_text = " ".join(window_words)
40-
window_entities = gliner_model.predict_entities(window_text, ["date"])
40+
window_entities = self.model.predict_entities(window_text, ["date"])
4141

4242
for entity in window_entities:
4343
entity["start"] += last_slide_end_index

src/trainable_entity_extractor/adapters/extractors/pdf_to_text_extractor/methods/GlinerFirstDateMethod.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,54 @@
11
import re
22

3+
from gliner import GLiNER
4+
5+
from trainable_entity_extractor.adapters.extractors.ToTextExtractorMethod import ToTextExtractorMethod
6+
from trainable_entity_extractor.domain.ExtractionData import ExtractionData
37
from trainable_entity_extractor.domain.PdfDataSegment import PdfDataSegment
4-
from trainable_entity_extractor.adapters.extractors.pdf_to_text_extractor.methods.FirstDateMethod import FirstDateMethod
58
from trainable_entity_extractor.adapters.extractors.text_to_text_extractor.methods.GlinerDateParserMethod import (
69
GlinerDateParserMethod,
710
)
11+
from trainable_entity_extractor.domain.PredictionSamplesData import PredictionSamplesData
12+
13+
14+
class GlinerFirstDateMethod(ToTextExtractorMethod):
15+
def train(self, extraction_data: ExtractionData):
16+
languages = [x.labeled_data.language_iso for x in extraction_data.samples]
17+
self.save_json("languages.json", list(set(languages)))
18+
19+
def predict(self, prediction_samples_data: PredictionSamplesData) -> list[str]:
20+
gliner_model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
21+
predictions_samples = prediction_samples_data.prediction_samples
22+
predictions = [""] * len(predictions_samples)
23+
languages = self.load_json("languages.json")
24+
for index, prediction_sample in enumerate(predictions_samples):
25+
segments = prediction_sample.pdf_data.pdf_data_segments
26+
27+
if predictions[index] or not prediction_sample.pdf_data or not segments:
28+
continue
29+
30+
predictions[index] = self.get_date_from_segments(gliner_model, segments, languages)
31+
32+
return predictions
833

34+
@staticmethod
35+
def loop_segments(segments: list[PdfDataSegment]):
36+
for segment in segments:
37+
yield segment
938

10-
class GlinerFirstDateMethod(FirstDateMethod):
1139
@staticmethod
1240
def contains_year(text: str):
1341
year_pattern = re.compile(r"([0-9]{2})")
1442
return bool(year_pattern.search(text.replace(" ", "")))
1543

16-
def get_date_from_segments(self, segments: list[PdfDataSegment], languages):
44+
def get_date_from_segments(self, model, segments: list[PdfDataSegment], languages):
1745
merge_segments: list[list[PdfDataSegment]] = self.merge_segments_for_dates(segments)
1846
for segments in merge_segments:
1947
segment_merged = PdfDataSegment.from_list_to_merge(segments)
2048
if not self.contains_year(segment_merged.text_content):
2149
continue
2250

23-
date = GlinerDateParserMethod.get_date([segment_merged.text_content])
51+
date = GlinerDateParserMethod.get_date(model, [segment_merged.text_content])
2452
if date:
2553
for segment in segments:
2654
segment.ml_label = 1

src/trainable_entity_extractor/adapters/extractors/pdf_to_text_extractor/methods/GlinerLastDateMethod.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ def loop_segments(segments):
1313
for segment in reversed(segments):
1414
yield segment
1515

16-
def get_date_from_segments(self, segments, languages):
16+
def get_date_from_segments(self, model, segments, languages):
1717
for segment in self.loop_segments(segments):
1818
if not self.contains_year(segment.text_content):
1919
continue
2020

21-
date = GlinerDateParserMethod.get_date([segment.text_content])
21+
date = GlinerDateParserMethod.get_date(model, [segment.text_content])
2222
if date:
2323
segment.ml_label = 1
2424
return date.strftime("%Y-%m-%d")

src/trainable_entity_extractor/adapters/extractors/text_to_text_extractor/methods/GlinerDateParserMethod.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from gliner import GLiNER
2+
13
from trainable_entity_extractor.domain.ExtractionData import ExtractionData
24
from trainable_entity_extractor.domain.PredictionSamplesData import PredictionSamplesData
35
from trainable_entity_extractor.adapters.extractors.ToTextExtractorMethod import ToTextExtractorMethod
@@ -15,12 +17,12 @@ def get_alphanumeric_text_with_spaces(text):
1517
return "".join([letter for letter in text if letter.isalnum() or letter.isspace()])
1618

1719
@staticmethod
18-
def get_date(tags_texts: list[str]):
20+
def get_date(model, tags_texts: list[str]):
1921
if not tags_texts:
2022
return ""
2123
text = GlinerDateParserMethod.get_alphanumeric_text_with_spaces(" ".join(tags_texts))
2224
try:
23-
gliner_date_extractor = GlinerDateExtractor()
25+
gliner_date_extractor = GlinerDateExtractor(model)
2426
dates = gliner_date_extractor.extract_dates(text)
2527
return dates[0]
2628
except:
@@ -29,7 +31,9 @@ def get_date(tags_texts: list[str]):
2931
return None
3032

3133
def train(self, extraction_data: ExtractionData):
32-
gliner_date_extractor = GlinerDateExtractor()
34+
gliner_model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
35+
36+
gliner_date_extractor = GlinerDateExtractor(gliner_model)
3337

3438
for sample in extraction_data.samples[:15]:
3539
if not sample.labeled_data.label_text.strip():
@@ -42,11 +46,13 @@ def train(self, extraction_data: ExtractionData):
4246
self.save_json(self.IS_VALID_EXECUTION_FILE_NAME, "true")
4347

4448
def predict(self, prediction_samples_data: PredictionSamplesData) -> list[str]:
49+
gliner_model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
50+
4551
if self.load_json(self.IS_VALID_EXECUTION_FILE_NAME) == "false":
4652
return [""] * len(prediction_samples_data.prediction_samples)
4753

4854
predictions_dates = [
49-
self.get_date(prediction_sample.get_input_text_by_lines())
55+
self.get_date(gliner_model, prediction_sample.get_input_text_by_lines())
5056
for prediction_sample in prediction_samples_data.prediction_samples
5157
]
5258
predictions = [date.strftime("%Y-%m-%d") if date else "" for date in predictions_dates]

0 commit comments

Comments
 (0)