Skip to content

Commit

Permalink
feat!: Resolve nested clusters for DoclingDocument (#92)
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Auer <[email protected]>
Signed-off-by: Peter Staar <[email protected]>
Co-authored-by: Peter Staar <[email protected]>
  • Loading branch information
cau-git and PeterStaar-IBM authored Dec 4, 2024
1 parent 49d2cd8 commit 61e5499
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 266 deletions.
261 changes: 1 addition & 260 deletions deepsearch_glm/utils/doc_utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
import re
from pathlib import Path
from typing import List

import pandas as pd
from docling_core.types.doc import (
BoundingBox,
CoordOrigin,
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupLabel,
ProvenanceItem,
Size,
TableCell,
TableData,
)
from docling_core.types.doc import DocItemLabel


def resolve_item(paths, obj):
Expand Down Expand Up @@ -51,252 +38,6 @@ def resolve_item(paths, obj):
return None


def _flatten_table_grid(grid: List[List[dict]]) -> List[dict]:
unique_objects = []
seen_spans = set()

for sublist in grid:
for obj in sublist:
# Convert the spans list to a tuple of tuples for hashing
spans_tuple = tuple(tuple(span) for span in obj["spans"])
if spans_tuple not in seen_spans:
seen_spans.add(spans_tuple)
unique_objects.append(obj)

return unique_objects


def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
origin = DocumentOrigin(
mimetype="application/pdf",
filename=doc_glm["file-info"]["filename"],
binary_hash=doc_glm["file-info"]["document-hash"],
)
doc_name = Path(origin.filename).stem

doc: DoclingDocument = DoclingDocument(name=doc_name, origin=origin)

if "properties" in doc_glm:
props = pd.DataFrame(
doc_glm["properties"]["data"], columns=doc_glm["properties"]["headers"]
)
else:
props = pd.DataFrame()

current_list = None

for ix, pelem in enumerate(doc_glm["page-elements"]):
ptype = pelem["type"]
span_i = pelem["span"][0]
span_j = pelem["span"][1]

if "iref" not in pelem:
# print(json.dumps(pelem, indent=2))
continue

iref = pelem["iref"]

if re.match("#/figures/(\\d+)/captions/(.+)", iref):
# print(f"skip {iref}")
continue

if re.match("#/tables/(\\d+)/captions/(.+)", iref):
# print(f"skip {iref}")
continue

path = iref.split("/")
obj = resolve_item(path, doc_glm)

if obj is None:
current_list = None
print(f"warning: undefined {path}")
continue

if ptype == "figure":
current_list = None
text = ""
caption_refs = []
for caption in obj["captions"]:
text += caption["text"]

for nprov in caption["prov"]:
npaths = nprov["$ref"].split("/")
nelem = resolve_item(npaths, doc_glm)

if nelem is None:
# print(f"warning: undefined caption {npaths}")
continue

span_i = nelem["span"][0]
span_j = nelem["span"][1]

cap_text = caption["text"][span_i:span_j]

# doc_glm["page-elements"].remove(nelem)

prov = ProvenanceItem(
page_no=nelem["page"],
charspan=tuple(nelem["span"]),
bbox=BoundingBox.from_tuple(
nelem["bbox"], origin=CoordOrigin.BOTTOMLEFT
),
)

caption_obj = doc.add_text(
label=DocItemLabel.CAPTION, text=cap_text, prov=prov
)
caption_refs.append(caption_obj.get_ref())

prov = ProvenanceItem(
page_no=pelem["page"],
charspan=(0, len(text)),
bbox=BoundingBox.from_tuple(
pelem["bbox"], origin=CoordOrigin.BOTTOMLEFT
),
)

pic = doc.add_picture(prov=prov)
pic.captions.extend(caption_refs)

elif ptype == "table":
current_list = None
text = ""
caption_refs = []
for caption in obj["captions"]:
text += caption["text"]

for nprov in caption["prov"]:
npaths = nprov["$ref"].split("/")
nelem = resolve_item(npaths, doc_glm)

if nelem is None:
# print(f"warning: undefined caption {npaths}")
continue

span_i = nelem["span"][0]
span_j = nelem["span"][1]

cap_text = caption["text"][span_i:span_j]

# doc_glm["page-elements"].remove(nelem)

prov = ProvenanceItem(
page_no=nelem["page"],
charspan=tuple(nelem["span"]),
bbox=BoundingBox.from_tuple(
nelem["bbox"], origin=CoordOrigin.BOTTOMLEFT
),
)

caption_obj = doc.add_text(
label=DocItemLabel.CAPTION, text=cap_text, prov=prov
)
caption_refs.append(caption_obj.get_ref())

table_cells_glm = _flatten_table_grid(obj["data"])

table_cells = []
for tbl_cell_glm in table_cells_glm:
if tbl_cell_glm["bbox"] is not None:
bbox = BoundingBox.from_tuple(
tbl_cell_glm["bbox"], origin=CoordOrigin.BOTTOMLEFT
)
else:
bbox = None

is_col_header = False
is_row_header = False
is_row_section = False

if tbl_cell_glm["type"] == "col_header":
is_col_header = True
elif tbl_cell_glm["type"] == "row_header":
is_row_header = True
elif tbl_cell_glm["type"] == "row_section":
is_row_section = True

table_cells.append(
TableCell(
row_span=tbl_cell_glm["row-span"][1]
- tbl_cell_glm["row-span"][0],
col_span=tbl_cell_glm["col-span"][1]
- tbl_cell_glm["col-span"][0],
start_row_offset_idx=tbl_cell_glm["row-span"][0],
end_row_offset_idx=tbl_cell_glm["row-span"][1],
start_col_offset_idx=tbl_cell_glm["col-span"][0],
end_col_offset_idx=tbl_cell_glm["col-span"][1],
text=tbl_cell_glm["text"],
bbox=bbox,
column_header=is_col_header,
row_header=is_row_header,
row_section=is_row_section,
)
)

tbl_data = TableData(
num_rows=obj.get("#-rows", 0),
num_cols=obj.get("#-cols", 0),
table_cells=table_cells,
)

prov = ProvenanceItem(
page_no=pelem["page"],
charspan=(0, 0),
bbox=BoundingBox.from_tuple(
pelem["bbox"], origin=CoordOrigin.BOTTOMLEFT
),
)

tbl = doc.add_table(data=tbl_data, prov=prov)
tbl.captions.extend(caption_refs)

elif "text" in obj:
text = obj["text"][span_i:span_j]

type_label = pelem["type"]
name_label = pelem["name"]
if update_name_label and len(props) > 0 and type_label == "paragraph":
prop = props[
(props["type"] == "semantic") & (props["subj_path"] == iref)
]
if len(prop) == 1 and prop.iloc[0]["confidence"] > 0.85:
name_label = prop.iloc[0]["label"]

prov = ProvenanceItem(
page_no=pelem["page"],
charspan=(0, len(text)),
bbox=BoundingBox.from_tuple(
pelem["bbox"], origin=CoordOrigin.BOTTOMLEFT
),
)
label = DocItemLabel(name_label)

if label == DocItemLabel.LIST_ITEM:
if current_list is None:
current_list = doc.add_group(label=GroupLabel.LIST, name="list")

# TODO: Infer if this is a numbered or a bullet list item
doc.add_list_item(
text=text, enumerated=False, prov=prov, parent=current_list
)
elif label == DocItemLabel.SECTION_HEADER:
current_list = None

doc.add_heading(text=text, prov=prov)
else:
current_list = None

doc.add_text(label=DocItemLabel(name_label), text=text, prov=prov)

for page_dim in doc_glm["page-dimensions"]:
page_no = int(page_dim["page"])
size = Size(width=page_dim["width"], height=page_dim["height"])

doc.add_page(page_no=page_no, size=size)

return doc


def to_legacy_document_format(doc_glm, doc_leg={}, update_name_label=False):
"""Convert Document object (with `body`) to its legacy format (with `main-text`)"""

Expand Down
22 changes: 19 additions & 3 deletions src/andromeda/tooling/structs/subjects/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace andromeda

const static inline std::string prov_lbl = "prov";

const static inline std::string payload_lbl = "payload"; // arbitrary data that needs to be carried through

const static inline std::string subj_hash_lbl = "subj_hash";
const static inline std::string text_hash_lbl = "text_hash"; // for text

Expand Down Expand Up @@ -141,6 +143,8 @@ namespace andromeda
std::vector<base_property> properties;
std::vector<base_instance> instances;
std::vector<base_relation> relations;

nlohmann::json payload;

//std::vector<base_entity> entities;
};
Expand All @@ -159,7 +163,9 @@ namespace andromeda

properties({}),
instances({}),
relations({})
relations({}),

payload(nlohmann::json::value_t::null)
{}

base_subject::base_subject(subject_name name):
Expand All @@ -176,7 +182,9 @@ namespace andromeda

properties({}),
instances({}),
relations({})
relations({}),

payload(nlohmann::json::value_t::null)
{}

base_subject::base_subject(uint64_t dhash,
Expand All @@ -195,7 +203,9 @@ namespace andromeda

properties({}),
instances({}),
relations({})
relations({}),

payload(nlohmann::json::value_t::null)
{
auto parts = utils::split(dloc, "#");
if(parts.size()==2)
Expand Down Expand Up @@ -321,6 +331,10 @@ namespace andromeda
{
nlohmann::json result = nlohmann::json::object({});

{
result[payload_lbl] = payload;
}

{
result[subj_hash_lbl] = hash;
result[dloc_lbl] = dloc;
Expand Down Expand Up @@ -379,6 +393,8 @@ namespace andromeda

bool base_subject::_from_json(const nlohmann::json& item)
{
payload = item.value(payload_lbl, payload);

hash = item.value(subj_hash_lbl, hash);

dloc = item.value(dloc_lbl, dloc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,14 @@ namespace andromeda

std::string base = parts.at(1);
std::size_t index = std::stoi(parts.at(2));

auto& item = orig.at(base).at(index);

if(is_text.count(prov->get_type()))
{
std::stringstream ss;
ss << doc_name << "#/" << doc_type::texts_lbl << "/" << texts.size();

std::string dloc = ss.str();

auto subj = std::make_shared<subject<TEXT> >(doc.get_hash(), dloc, prov);
Expand Down
9 changes: 9 additions & 0 deletions src/andromeda/tooling/structs/subjects/figure.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ namespace andromeda
{
base_subject::valid = true;

if(data.count(payload_lbl))
{
payload = data.value(payload_lbl, payload);
}
else
{
payload = nlohmann::json::value_t::null;
}

return base_subject::valid;
}

Expand Down
9 changes: 9 additions & 0 deletions src/andromeda/tooling/structs/subjects/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ namespace andromeda

data.clear();

if(item.count(payload_lbl))
{
payload = item.value(payload_lbl, payload);
}
else
{
payload = nlohmann::json::value_t::null;
}

{
conf = item.value(base_subject::confidence_lbl, conf);
created_by = item.value(base_subject::created_by_lbl, created_by);
Expand Down
Loading

0 comments on commit 61e5499

Please sign in to comment.