diff --git a/src/data_model/SegmentBox.py b/src/data_model/SegmentBox.py index 5929768..a629828 100644 --- a/src/data_model/SegmentBox.py +++ b/src/data_model/SegmentBox.py @@ -2,6 +2,7 @@ from pdf_features.PdfPage import PdfPage from pdf_token_type_labels.TokenType import TokenType from pydantic import BaseModel +from typing import List, Optional, Dict, Any class SegmentBox(BaseModel): @@ -14,9 +15,10 @@ class SegmentBox(BaseModel): page_height: int text: str = "" type: TokenType = TokenType.TEXT + sub_element_positions: Optional[List[Dict[str, Any]]] = None def to_dict(self): - return { + result = { "left": self.left, "top": self.top, "width": self.width, @@ -27,9 +29,12 @@ def to_dict(self): "text": self.text, "type": self.type.value, } + if self.sub_elements_positions is not None: + result["sub_elements_positions"] = self.sub_elements_positions + return result @staticmethod - def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage]): + def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage], sub_elements: Optional[List[Dict[str, Any]]] = None): return SegmentBox( left=pdf_segment.bounding_box.left, top=pdf_segment.bounding_box.top, @@ -40,6 +45,7 @@ def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage]): page_height=pdf_pages[pdf_segment.page_number - 1].page_height, text=pdf_segment.text_content, type=pdf_segment.segment_type, + sub_elements_positions=sub_elements, ) diff --git a/src/fast_trainer/PdfSegment.py b/src/fast_trainer/PdfSegment.py index 0427a7b..a17f575 100644 --- a/src/fast_trainer/PdfSegment.py +++ b/src/fast_trainer/PdfSegment.py @@ -2,6 +2,7 @@ from pdf_features.PdfToken import PdfToken from pdf_features.Rectangle import Rectangle from pdf_token_type_labels.TokenType import TokenType +from typing import List, Dict, Any, Optional class PdfSegment: @@ -13,6 +14,7 @@ def __init__( self.text_content = text_content self.segment_type = segment_type self.pdf_name = pdf_name + self.sub_element_positions: Optional[List[Dict[str, Any]]] = None @staticmethod def from_pdf_tokens(pdf_tokens: list[PdfToken], pdf_name: str = ""): diff --git a/src/pdf_layout_analysis/run_pdf_layout_analysis.py b/src/pdf_layout_analysis/run_pdf_layout_analysis.py index 0201509..6fe4d27 100644 --- a/src/pdf_layout_analysis/run_pdf_layout_analysis.py +++ b/src/pdf_layout_analysis/run_pdf_layout_analysis.py @@ -68,7 +68,7 @@ def analyze_pdf(file: AnyStr, xml_file_name: str, extraction_format: str = "", k pdf_path.unlink(missing_ok=True) return [ - SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages).to_dict() + SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages, pdf_segment.sub_element_positions).to_dict() for pdf_segment in predicted_segments ] diff --git a/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py b/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py index e06b36f..9db2773 100644 --- a/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py +++ b/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py @@ -37,4 +37,4 @@ def analyze_pdf_fast( pdf_images.remove_images() if not keep_pdf: pdf_path.unlink(missing_ok=True) - return [SegmentBox.from_pdf_segment(pdf_segment, pdf_images.pdf_features.pages).to_dict() for pdf_segment in segments] + return [SegmentBox.from_pdf_segment(pdf_segment, pdf_images.pdf_features.pages, pdf_segment.sub_element_positions).to_dict() for pdf_segment in segments] diff --git a/src/toc/TitleFeatures.py b/src/toc/TitleFeatures.py index 9769183..dfce408 100755 --- a/src/toc/TitleFeatures.py +++ b/src/toc/TitleFeatures.py @@ -157,7 +157,7 @@ def to_toc_item(self, indentation): return TOCItem( indentation=indentation, label=self.text_content, - selection_rectangle=SegmentBox.from_pdf_segment(self.pdf_segment, self.pdf_features.pages), + selection_rectangle=SegmentBox.from_pdf_segment(self.pdf_segment, self.pdf_features.pages, self.pdf_segment.sub_element_positions), ) def append(self, other_title_features: "TitleFeatures"): diff --git a/src/vgt/get_most_probable_pdf_segments.py b/src/vgt/get_most_probable_pdf_segments.py index f802b4f..aaed442 100644 --- a/src/vgt/get_most_probable_pdf_segments.py +++ b/src/vgt/get_most_probable_pdf_segments.py @@ -93,7 +93,33 @@ def merge_colliding_predictions(predictions: list[Prediction]): def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dict): most_probable_pdf_segments_for_page: list[PdfSegment] = [] most_probable_tokens_by_predictions: dict[Prediction, list[PdfToken]] = {} - vgt_predictions_dict[page_pdf_name] = merge_colliding_predictions(vgt_predictions_dict[page_pdf_name]) + + # Store original predictions before merging + original_predictions = vgt_predictions_dict[page_pdf_name].copy() + + # Create a mapping from merged predictions to their original constituent predictions + merged_to_original_mapping = {} + + # Merge predictions and track which original predictions went into each merged one + merged_predictions = merge_colliding_predictions(vgt_predictions_dict[page_pdf_name]) + + # Build mapping by checking which original predictions intersect with merged ones + for merged_pred in merged_predictions: + original_elements = [] + for orig_pred in original_predictions: + if merged_pred.bounding_box.get_intersection_percentage(orig_pred.bounding_box) > 0: + original_elements.append({ + "left": orig_pred.bounding_box.left, + "top": orig_pred.bounding_box.top, + "width": orig_pred.bounding_box.width, + "height": orig_pred.bounding_box.height, + "type": DOCLAYNET_TYPE_BY_ID[orig_pred.category_id], + "score": orig_pred.score + }) + if len(original_elements) > 1: # Only store if there were actually multiple elements merged + merged_to_original_mapping[merged_pred] = original_elements + + vgt_predictions_dict[page_pdf_name] = merged_predictions for token in page.tokens: find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions) @@ -102,6 +128,11 @@ def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dic new_segment = PdfSegment.from_pdf_tokens(tokens, pdf_name) new_segment.bounding_box = prediction.bounding_box new_segment.segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id]) + + # Store original elements if this prediction was merged + if prediction in merged_to_original_mapping: + new_segment.sub_element_positions = merged_to_original_mapping[prediction] + most_probable_pdf_segments_for_page.append(new_segment) no_token_predictions = [ @@ -114,6 +145,11 @@ def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dic segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id]) page_number = page.page_number new_segment = PdfSegment(page_number, prediction.bounding_box, "", segment_type, pdf_name) + + # Store original elements if this prediction was merged + if prediction in merged_to_original_mapping: + new_segment.sub_elements_positions = merged_to_original_mapping[prediction] + most_probable_pdf_segments_for_page.append(new_segment) return most_probable_pdf_segments_for_page