diff --git a/automates/model_analysis/linking.py b/automates/model_analysis/linking.py index 248cd2184..3acb1c6e5 100644 --- a/automates/model_analysis/linking.py +++ b/automates/model_analysis/linking.py @@ -1,130 +1,636 @@ -import networkx as nx - - -def make_link_tables(GrFN): - # Add links to the link-graph G from the grounding data. Add all variables - # found to a list of vars. (Use the vars list to walk the graph in a DFS - # style to recover all rows for the table) - vars = list() - G = nx.Graph() - for link_dict in GrFN["grounding"]: - id1 = get_id(link_dict["element_1"]) - id2 = get_id(link_dict["element_2"]) - if id1[0] == "": - vars.append(id1) - if id2[0] == "": - vars.append(id2) - G.add_edge("\n".join(id1), "\n".join(id2), label=round(link_dict["score"], 3)) - - table_data = dict() - for var in vars: - var_name = "\n".join(var) - var_table = list() - for n1 in G.neighbors(var_name): - if n1.startswith(""): - for n2 in G.neighbors(n1): - if n2.startswith(""): - for n3 in G.neighbors(n2): - if n3.startswith(""): - var_table.append({ - "link_score": min([ - G[var_name][n1]["label"], - G[n1][n2]["label"], - G[n2][n3]["label"] - ]), - "comm": n1.split("\n")[1], - "vc_score": G[var_name][n1]["label"], - "txt": n2.split("\n")[1], - "ct_score": G[n1][n2]["label"], - "eqn": n3.split("\n")[1], - "te_score": G[n2][n3]["label"] - }) - var_table.sort( - key=lambda r: (r["link_score"], - r["vc_score"], - r["ct_score"], - r["te_score"]), - reverse=True - ) - table_data[var] = var_table - return table_data - - -def get_argument_lists(GrFN): - # Make a dict of the argument lists for each container in the GrFN indexed - # by the container basename - return { - c["name"].split("::")[-1]: c["arguments"] for c in GrFN["containers"] - } - - -def get_call_conventions(GrFN): - # Make a dict of all of the calls in every container in the GrFN. Index them - # by basename of the callee. 
Include the caller basename and input list in - # the value field - return { - stmt["function"]["name"].split("::")[-1]: { - "caller": container["name"].split("::")[-1], - "inputs": stmt["input"] - } - for container in GrFN["containers"] - for stmt in container["body"] - if stmt["function"]["type"] == "function_name" - } - - -def output_link_graph(G, filename="linking-graph.pdf"): - A = nx.nx_agraph.to_AGraph(G) - A.draw(filename, prog="circo") - - -def print_table_data(table_data): - for (_, scope, name, idx), table in table_data.items(): - print("::".join([scope, name, idx])) - print("L-SCORE\tComment\tV-C\tText-span\tC-T\tEquation\tT-E") - for row in table: - print(f"{row['link_score']}\t{row['comm']}\t{row['vc_score']}\t{row['txt']}\t{row['ct_score']}\t{row['eqn']}\t{row['te_score']}") - print("\n\n") +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import singledispatch +from typing import Tuple +import re + +from networkx import DiGraph +from automates.model_assembly.networks import GroundedFunctionNetwork + + +@dataclass(repr=False, frozen=True) +class LinkNode(ABC): + uid: str + content: str + + def __repr__(self): + return self.__str__() + + def __str__(self): + return self.content + + @staticmethod + def from_dict(data: dict, element_type: str, grounding_information: dict): + if element_type == "source": + return CodeVarNode(data["uid"], data["content"], data["source"]) + elif element_type == "gl_src_var": + src_vars = list() + for src_var_uid in data["identifier_objects"]: + src_var_data = [ + src_var + for src_var in grounding_information["src"] + if src_var_uid == src_var["uid"] + ][0] + src_vars.append( + CodeVarNode( + src_var_data["uid"], + src_var_data["content"], + src_var_data["source"], + ) + ) + return GCodeVarNode(data["uid"], data["content"], tuple(src_vars)) + elif element_type == "gl_comm": + comm_vars = list() + for comm_var_uid in data["identifier_objects"]: + comm_var_data = [ + comm_var + for comm_var in grounding_information["comment"] + if comm_var_uid == comm_var["uid"] + ][0] + content = comm_var_data["content"] + if "arguments" in comm_var_data: + description_arguments = [ + arg + for arg in comm_var_data["arguments"] + if "name" in arg and arg["name"] == "description" + ] + if len(description_arguments) > 0: + content = description_arguments[0]["text"] -def merge_similar_vars(vars): - unique_vars = dict() - for (_, scope, name, idx) in vars: - if (scope, name) in unique_vars: - unique_vars[(scope, name)].append(int(idx)) + comm_vars.append( + CommSpanNode( + comm_var_data["uid"], + content, + comm_var_data["source"], + ) + ) + return GCommSpanNode(data["uid"], data["content"], tuple(comm_vars)) + elif element_type == "gl_eq_var": + equation_nodes = list() + for equation_node_uid in data["identifier_objects"]: + equation_node_matches = [ + eq + for eq in grounding_information["equation"] + if eq["uid"] == equation_node_uid + ] + if len(equation_node_matches) == 0: + continue + equation_node = equation_node_matches[0] + equation = None + equation_index = -1 + if "equation_uid" in equation_node: + for eq in grounding_information["full_text_equation"]: + equation_index += 1 + if eq["uid"] == equation_node["equation_uid"]: + equation = FullTextEquationNode(eq["uid"], eq["content"]) + break + equation_nodes.append( + EqnVarNode( + equation_node["uid"], + equation_node["content"], + equation, + equation_index, + ) + ) + + return GEqnVarNode(data["uid"], data["content"], tuple(equation_nodes)) + + elif element_type == 
"gvar": + text_vars = list() + for text_var_uid in data["identifier_objects"]: + text_var_data = [ + text_var + for text_var in grounding_information["text_var"] + if text_var_uid == text_var["uid"] + ][0] + + content = text_var_data["content"] + if "arguments" in text_var_data: + description_arguments = [ + arg + for arg in text_var_data["arguments"] + if "name" in arg and arg["name"] == "description" + ] + if len(description_arguments) > 0: + content = description_arguments[0]["text"] + + text_vars.append( + TextVarNode( + text_var_data["source"], + content, + TextExtraction( + text_var_data["spans"]["page"], + text_var_data["spans"]["block"], + tuple( + Span(s["char_begin"], s["char_end"]) + for s in text_var_data["spans"]["spans"] + ), + ), + ) + ) + + return GVarNode(data["uid"], data["content"], tuple(text_vars)) + elif ( + element_type == "parameter_setting_via_idfr" + or element_type == "int_param_setting_via_idfr" + or element_type == "parameter_setting_via_cncpt" + or element_type == "int_param_setting_via_cncpt" + ): + return ParameterSettingNode( + data["uid"], + data["content"], + data["original_sentence"], + data["source"], + TextExtraction( + data["spans"]["page"], + data["spans"]["block"], + tuple( + Span(s["char_begin"], s["char_end"]) + for s in data["spans"]["spans"] + ), + ), + ) + elif element_type == "unit_via_idfr" or element_type == "unit_via_cncpt": + return UnitNode( + data["uid"], + data["content"], + TextExtraction( + data["spans"]["page"], + data["spans"]["block"], + tuple( + Span(s["char_begin"], s["char_end"]) + for s in data["spans"]["spans"] + ), + ), + ) else: - unique_vars[(scope, name)] = [int(idx)] - return unique_vars - - -def format_long_text(text): - new_text = list() - while len(text) > 8: - new_text.extend(text[:4]) - new_text.append("\n") - text = text[4:] - new_text.extend(text) - return new_text - - -def get_id(el_data): - el_type = el_data["type"] - if el_type == "identifier": - var_name = el_data["content"].split("::")[-3:] - return tuple([""] + var_name) - elif el_type == "comment_span": - tokens = el_data["content"].split() - name = tokens[0] - desc = " ".join(format_long_text(tokens[1:])) - return ("", f"{name}: {desc}") - return ("", name, desc) - elif el_type == "text_span": - desc = " ".join(format_long_text(el_data["content"].split())) - desc = el_data["content"] - return ("", desc) - elif el_type == "equation_span": - desc = " ".join(format_long_text(el_data["content"].split())) - desc = el_data["content"] - return ("", desc) - else: - raise ValueError(f"Unrecognized link type: {el_type}") + raise ValueError(f"Unrecognized link element type: {element_type}") + + @abstractmethod + def get_table_rows(self, link_graph: DiGraph) -> list: + return NotImplemented + + +@dataclass(repr=False, frozen=True) +class Span: + char_begin: int + char_end: int + + +@dataclass(frozen=True) +class TextExtraction: + page: int + block: int + spans: Tuple[Span] + + +@dataclass(repr=False, frozen=True) +class ParameterSettingNode(LinkNode): + + original_sentence: str + source: str + text_extraction: TextExtraction + + def get_table_rows(self, link_graph: DiGraph) -> list: + return None + + +@dataclass(repr=False, frozen=True) +class UnitNode(LinkNode): + text_extraction: TextExtraction + + def get_table_rows(self, link_graph: DiGraph) -> list: + return None + + +@dataclass(repr=False, frozen=True) +class CodeVarNode(LinkNode): + source: str + + def __repr__(self): + return self.__str__() + + def __str__(self): + # TODO the content no longer holds the var 
identifier. Is that correct?
+        # (namespace, scope, basename, index) = self.content.split("::")
+        # return "\n".join(
+        #     [
+        #         f"NAMESPACE: {namespace}",
+        #         f"SCOPE: {scope}",
+        #         f"NAME: {basename}",
+        #         f"INDEX: {index}",
+        #     ]
+        # )
+        return self.content
+
+    def get_varname(self) -> str:
+        (_, _, _, basename, _) = self.content.split("::")
+        return basename
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        gcode_var_span_nodes = [
+            n for n in L.predecessors(self) if isinstance(n, GCodeVarNode)
+        ]
+
+        rows = list()
+        for gcode_var_node in gcode_var_span_nodes:
+            w_vc = L.edges[gcode_var_node, self]["weight"]
+            for r in gcode_var_node.get_table_rows(L):
+                scores = [
+                    val
+                    for key, val in r.items()
+                    if key.endswith("_score") and val is not None
+                ]
+                w_row = min(w_vc, *scores)
+                r.update({"vc_score": w_vc, "link_score": w_row})
+                rows.append(r)
+
+        return rows
+
+
+@dataclass(repr=False, frozen=True)
+class GCodeVarNode(LinkNode):
+    source: list
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __str__(self):
+        return self.content
+
+    def get_varname(self) -> str:
+        (_, _, _, basename, _) = self.content.split("::")
+        return basename
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        comm_span_nodes = [
+            n for n in L.predecessors(self) if isinstance(n, GCommSpanNode)
+        ]
+
+        rows = list()
+        for comm_node in comm_span_nodes:
+            w_vc = L.edges[comm_node, self]["weight"]
+            for r in comm_node.get_table_rows(L):
+                scores = [
+                    val
+                    for key, val in r.items()
+                    if key.endswith("_score") and val is not None
+                ]
+                w_row = min(w_vc, *scores)
+                r.update({"vc_score": w_vc, "link_score": w_row})
+                rows.append(r)
+
+        return rows
+
+
+@dataclass(repr=False, frozen=True)
+class TextVarNode(LinkNode):
+    text_extraction: TextExtraction
+
+    def get_docname(self) -> str:
+        path_pieces = self.source.split("/")
+        doc_data = path_pieces[-1]
+        (docname, _) = doc_data.split(".pdf_")
+        return docname
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        # NOTE: nothing to do for now
+        return []
+
+
+@dataclass(repr=False, frozen=True)
+class GVarNode(LinkNode):
+    text_vars: tuple
+
+    def get_text_vars(self):
+        return self.text_vars
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        text_vars = [t_var for t_var in self.text_vars]
+        txt = [n.content for n in text_vars]
+
+        eqn_span_nodes = [n for n in L.predecessors(self) if isinstance(n, GEqnVarNode)]
+
+        rows = list()
+        if len(eqn_span_nodes) == 0:
+            rows.append({"txt": txt, "te_score": None})
+        for eqn_span in eqn_span_nodes:
+            te_ct = L.edges[eqn_span, self]["weight"]
+            for r in eqn_span.get_table_rows(L):
+                r.update({"txt": txt, "te_score": te_ct})
+                rows.append(r)
+
+        return rows
+
+
+@dataclass(repr=False, frozen=True)
+class CommSpanNode(LinkNode):
+    source: str
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __str__(self):
+        tokens = self.content.strip().split()
+        if len(tokens) <= 4:
+            return " ".join(tokens)
+
+        new_content = ""
+        while len(tokens) > 4:
+            new_content += "\n" + " ".join(tokens[:4])
+            tokens = tokens[4:]
+        new_content += "\n" + " ".join(tokens)
+        return new_content
+
+    def get_comment_location(self):
+        (filename, sub_name, place) = self.source.split("; ")
+        filename = filename[: filename.rfind(".f")]
+        return f"{filename}::{sub_name}${place}"
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        gcomm_nodes = [n for n in L.predecessors(self) if isinstance(n, GCommSpanNode)]
+
+        rows = list()
+        for gcomm_node in gcomm_nodes:
+            w_ct = L.edges[gcomm_node, self]["weight"]
+            for r in gcomm_node.get_table_rows(L):
+                r.update({"comm": str(self).replace("\n", " "), "ct_score": w_ct})
+                rows.append(r)
+
+        return rows
+
+
+@dataclass(repr=False, frozen=True)
+class GCommSpanNode(LinkNode):
+    source: list
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __str__(self):
+        tokens = self.content.strip().split()
+        if len(tokens) <= 4:
+            return " ".join(tokens)
+
+        new_content = ""
+        while len(tokens) > 4:
+            new_content += "\n" + " ".join(tokens[:4])
+            tokens = tokens[4:]
+        new_content += "\n" + " ".join(tokens)
+        return new_content
+
+    def get_comment_location(self):
+        (filename, sub_name, place) = self.source.split("; ")
+        filename = filename[: filename.rfind(".f")]
+        return f"{filename}::{sub_name}${place}"
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        gvar_nodes = [n for n in L.predecessors(self) if isinstance(n, GVarNode)]
+
+        comment_spans = [s.content.replace("\n", " ") for s in self.source]
+
+        rows = list()
+        for gvar_node in gvar_nodes:
+            w_ct = L.edges[gvar_node, self]["weight"]
+            for r in gvar_node.get_table_rows(L):
+                r.update({"comm": comment_spans, "ct_score": w_ct})
+                rows.append(r)
+
+        return rows
+
+
+@dataclass(repr=False, frozen=True)
+class FullTextEquationNode(LinkNode):
+    def get_table_rows(self, L: DiGraph) -> list:
+        # TODO
+        return None
+
+
+@dataclass(repr=False, frozen=True)
+class EqnVarNode(LinkNode):
+    full_text_equations: FullTextEquationNode
+    equation_number: int
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        return [{"eqn": str(self)}]
+
+
+@dataclass(repr=False, frozen=True)
+class GEqnVarNode(LinkNode):
+    equation_nodes: list
+
+    def get_table_rows(self, L: DiGraph) -> list:
+        return [{"eqn": str(self)}]
+
+
+def build_link_graph(grounding_information: dict) -> DiGraph:
+    G = DiGraph()
+
+    def report_bad_link(n1, n2):
+        raise ValueError(f"Inappropriate link type: ({type(n1)}, {type(n2)})")
+
+    @singledispatch
+    def add_link_node(node):
+        raise ValueError(f"Inappropriate node type: {type(node)}")
+
+    @add_link_node.register
+    def _(node: CodeVarNode):
+        G.add_node(node, color="darkviolet")
+
+    @add_link_node.register
+    def _(node: GCodeVarNode):
+        G.add_node(node, color="violet")
+
+    @add_link_node.register
+    def _(node: CommSpanNode):
+        G.add_node(node, color="lightskyblue")
+
+    @add_link_node.register
+    def _(node: GCommSpanNode):
+        G.add_node(node, color="blue")
+
+    @add_link_node.register
+    def _(node: EqnVarNode):
+        G.add_node(node, color="orange")
+
+    @add_link_node.register
+    def _(node: GEqnVarNode):
+        G.add_node(node, color="yellow")
+
+    @add_link_node.register
+    def _(node: ParameterSettingNode):
+        G.add_node(node, color="green")
+
+    @add_link_node.register
+    def _(node: UnitNode):
+        G.add_node(node, color="crimson")
+
+    @add_link_node.register
+    def _(node: GVarNode):
+        if node not in G:
+            for text_var in node.text_vars:
+                G.add_node(text_var)
+                G.add_edge(node, text_var)
+            G.add_node(node, color="deeppink")
+
+    @singledispatch
+    def add_link(n1, n2, score):
+        raise ValueError(f"Inappropriate node type: {type(n1)}")
+
+    @add_link.register
+    def _(n1: GCodeVarNode, n2, score):
+        add_link_node(n1)
+        add_link_node(n2)
+
+        if isinstance(n2, GCommSpanNode):
+            G.add_edge(n2, n1, weight=score)
+        else:
+            report_bad_link(n1, n2)
+
+    @add_link.register
+    def _(n1: CommSpanNode, n2, score):
+        add_link_node(n1)
+        add_link_node(n2)
+
+        if isinstance(n2, GCommSpanNode):
+            G.add_edge(n2, n1, weight=score)
+        else:
+            report_bad_link(n1, n2)
+
+    @add_link.register
+    def _(n1: GCommSpanNode, n2, score):
+        add_link_node(n1)
+        add_link_node(n2)
+
+        if isinstance(n2, CommSpanNode):
+
G.add_edge(n1, n2, weight=score) + elif isinstance(n2, GVarNode): + G.add_edge(n2, n1, weight=score) + else: + report_bad_link(n1, n2) + + @add_link.register + def _(n1: ParameterSettingNode, n2, score): + add_link_node(n1) + add_link_node(n2) + + if isinstance(n2, GVarNode): + G.add_edge(n2, n1, weight=score) + else: + report_bad_link(n1, n2) + + @add_link.register + def _(n1: UnitNode, n2, score): + add_link_node(n1) + add_link_node(n2) + + if isinstance(n2, GVarNode): + G.add_edge(n2, n1, weight=score) + else: + report_bad_link(n1, n2) + + @add_link.register + def _(n1: GEqnVarNode, n2, score): + add_link_node(n1) + add_link_node(n2) + + if isinstance(n2, GVarNode): + G.add_edge(n1, n2, weight=score) + else: + report_bad_link(n1, n2) + + @add_link.register + def _(n1: GVarNode, n2, score): + add_link_node(n1) + add_link_node(n2) + + if isinstance(n2, (ParameterSettingNode, UnitNode)): + G.add_edge(n1, n2, weight=score) + else: + report_bad_link(n1, n2) + + def build_link_node(element, type): + # Element ids are structured like: ::. We want just the uid. + uid = element.split("::")[0] + # Go to its item type and gets its data information + node_data = [ + item for item in grounding_information[type] if item["uid"] == uid + ][0] + + return LinkNode.from_dict(node_data, type, grounding_information) + + def update_type(found_type): + # In link objects, the to/from node type is specified as + # "param_setting_via_idfr" but in the top level grounding information + # it is specified as "parameter_setting_via_idfr" + if found_type == "param_setting_via_idfr": + return "parameter_setting_via_idfr" + elif found_type == "interval_param_setting_via_idfr": + return "int_param_setting_via_idfr" + elif found_type == "param_setting_via_cpcpt": + return "parameter_setting_via_cncpt" + elif found_type == "interval_param_setting_via_cpcpt": + return "int_param_setting_via_cncpt" + elif found_type == "source": + return "gl_src_var" + elif found_type == "equation": + return "gl_eq_var" + elif found_type == "comment": + return "gl_comm" + elif found_type == "unit_via_cpcpt": + return "unit_via_cncpt" + return found_type + + link_hypotheses = grounding_information["links"] + for link_dict in link_hypotheses: + + link_type = link_dict["link_type"] + (node_1_type, node_2_type) = map(update_type, tuple(link_type.split("_to_"))) + + node1 = build_link_node(link_dict["element_1"], node_1_type) + node2 = build_link_node(link_dict["element_2"], node_2_type) + + link_score = round(link_dict["score"], 3) + add_link(node1, node2, link_score) + + return G + + +def extract_link_tables(L: DiGraph) -> dict: + var_nodes = [n for n in L.nodes if isinstance(n, GCodeVarNode)] + + tables = dict() + for var_node in var_nodes: + var_name = str(var_node) + if var_name not in tables: + table_rows = var_node.get_table_rows(L) + table_rows.sort( + key=lambda r: [ + r[f"{key}_score"] + for key in ["vc", "ct", "te"] + if r[f"{key}_score"] is not None + ], + reverse=True, + ) + tables[var_name] = table_rows + + return tables + + +def print_table_data(table_data: dict) -> None: + for var_name, table in table_data.items(): + print(var_name) + print("L-SCORE\tComment\tV-C\tText-span\tC-T\tEquation\tT-E") + for row in table: + row_data = [ + str(row["link_score"]), + row["comm"], + str(row["vc_score"]), + row["txt"], + str(row["ct_score"]), + row["eqn"], + str(row["te_score"]), + ] + print("\t".join(row_data)) + print("\n\n") diff --git a/automates/model_assembly/interfaces.py b/automates/model_assembly/interfaces.py index 
2ef56a855..d9b2383fe 100644 --- a/automates/model_assembly/interfaces.py +++ b/automates/model_assembly/interfaces.py @@ -1,11 +1,37 @@ import os import json from typing import List, Dict, NoReturn +from abc import ABC, abstractmethod import requests -class TextReadingInterface: +class TextReadingInterface(ABC): + # TODO I dislike how these methods take file paths to read from and then + # pass that information to TR app. However, sometimes the TR app requires + # the path to files in its payload (this is really bad and needs to be + # changed). Eventually should move away from this model. + + @abstractmethod + def extract_mentions(self, doc_path: str, out_path: str) -> dict: + pass + + @abstractmethod + def get_link_hypotheses( + self, + mentions_path: str, + eqns_path: str, + grfn_path: str, + comments_path: str, + ) -> dict: + pass + + @abstractmethod + def ground_to_SVO(self, mentions_path: str) -> dict: + pass + + +class TextReadingAppInterface(TextReadingInterface): def __init__(self, addr): self.webservice = addr @@ -40,17 +66,37 @@ def extract_mentions(self, doc_path: str, out_path: str) -> dict: print(f"HTTP {res} for /json_doc_to_mentions on {doc_path}") else: - raise ValueError(f"Unknown input document extension in file {doc_path} (pdf or json expected)") + raise ValueError( + f"Unknown input document extension in file {doc_path} (pdf or json expected)" + ) return json.load(open(out_path, "r")) + def get_grfn_link_hypothesis( + self, mentions_path: str, eqns_path: str, grfn_path: str, comments_path: str + ) -> dict: + if not grfn_path.endswith(".json"): + raise ValueError("/align expects GrFN to be a JSON file") + + grfn_data = json.load(open(grfn_path, "r")) + unique_var_names = list( + { + "::".join(var_def["identifier"].split("::")[:-1]) + "::0" + for var_def in grfn_data["variables"] + } + ) + variable_names = [{"name": var_name} for var_name in unique_var_names] + + return self.get_link_hypotheses( + mentions_path, eqns_path, comments_path, variable_names + ) + def get_link_hypotheses( self, mentions_path: str, eqns_path: str, - grfn_path: str, comments_path: str, - wikidata_path: str + variable_names: list, ) -> dict: if not os.path.isfile(mentions_path): raise RuntimeError(f"Mentions not found: {mentions_path}") @@ -58,9 +104,6 @@ def get_link_hypotheses( if not os.path.isfile(eqns_path): raise RuntimeError(f"Equations not found: {eqns_path}") - if not os.path.isfile(grfn_path): - raise RuntimeError(f"GrFN not found: {grfn_path}") - if not os.path.isfile(comments_path): raise RuntimeError(f"Comments not found: {comments_path}") @@ -70,22 +113,9 @@ def get_link_hypotheses( if not eqns_path.endswith(".txt"): raise ValueError("/align expects equations to be a text file") - if not grfn_path.endswith(".json"): - raise ValueError("/align expects GrFN to be a JSON file") - if not comments_path.endswith(".json"): raise ValueError("/align expects comments to be a JSON file") - grfn_data = json.load(open(grfn_path, "r")) - - unique_var_names = list( - { - "::".join(var_def["identifier"].split("::")[:-1]) + "::0" - for var_def in grfn_data["variables"] - } - ) - variables = [{"name": var_name} for var_name in unique_var_names] - equations = list() with open(eqns_path, "r") as infile: for eqn_line in infile.readlines(): @@ -96,12 +126,11 @@ def get_link_hypotheses( "documents": mentions_path, "equations": equations, "source_code": { - "variables": variables, + "variables": variable_names, "comments": json.load(open(comments_path, "r")), }, - "toggles": {"groundToSVO": False, 
"groundToWiki": False, "saveWikiGroundings": False, "appendToGrFN": False}, + "toggles": {"groundToSVO": False, "appendToGrFN": False}, "arguments": {"maxSVOgroundingsPerVar": 5}, - "wikidata": wikidata_path } payload_path = f"{os.getcwd()}/align_payload.json" json.dump(payload, open(payload_path, "w")) @@ -111,18 +140,17 @@ def get_link_hypotheses( headers={"Content-type": "application/json"}, json={"pathToJson": payload_path}, ) - print(f"HTTP {res} for /align on:\n\t{mentions_path}\n\t{grfn_path}\n") + print(f"HTTP {res} for /align on:\n\t{mentions_path}\n\t{variable_names}\n") json_dict = res.json() return json_dict def ground_to_SVO(self, mentions_path: str) -> dict: + if not os.path.isfile(mentions_path): raise RuntimeError(f"Mentions file not found: {mentions_path}") if not mentions_path.endswith(".json"): - raise ValueError( - "/groundMentionsToSVO expects mentions to be a JSON file" - ) + raise ValueError("/groundMentionsToSVO expects mentions to be a JSON file") res = requests.post( f"{self.webservice}/groundMentionsToSVO", @@ -135,6 +163,60 @@ def ground_to_SVO(self, mentions_path: str) -> dict: return json_dict +class LocalTextReadingInterface(TextReadingInterface): + name: str + + def __init__(self, name): + # key into the index of the files we will be dealing with + self.name = name + # build an index of local tr mentions and alignment files. These are + # all stored in the scripts/model_assembly/example-tr-data folder. + self.index = { + "SIR-simple": { + "mentions": "tests/data/example-tr-data/sir-simple-documents/SIR-simple--mentions.json", + "alignment": "tests/data/example-tr-data/sir-simple-documents/SIR-simple--alignment.json", + }, + "CHIME_SIR": { + # TODO mentions file doesnt actually have data in it, just mocked + "mentions": "tests/data/example-tr-data/chime-sir/CHIME-SIR--mentions.json", + "alignment": "tests/data/example-tr-data/chime-sir/CHIME-SIR--alignment.json", + }, + "CHIME_SVIIvR": { + # TODO mentions file doesnt actually have data in it, just mocked + "mentions": "tests/data/example-tr-data/chime-sviivr/CHIME-SVIIvR--mentions.json", + "alignment": "tests/data/example-tr-data/chime-sviivr/CHIME_SVIIvR--GrFN3--alignment.json", + }, + } + + def extract_mentions(self, doc_path: str, out_path: str) -> dict: + if self.name in self.index: + # NOTE: This is a silly bit of code, but the user expects the + # given out path file to hold the mentions data after running. 
+ # So fill that in then return the object, + mentions_obj = json.load(open(self.index[self.name]["mentions"], "r")) + json.dump(mentions_obj, open(out_path, "w")) + return mentions_obj + else: + # TODO + raise Exception(f"Error: Unable to find local TR data for {self.name}") + + def get_link_hypotheses( + self, + mentions_path: str, + eqns_path: str, + grfn_path: str, + comments_path: str, + ) -> dict: + if self.name in self.index: + return json.load(open(self.index[self.name]["alignment"], "r"))["grounding"] + else: + # TODO + raise Exception(f"Error: Unable to find local TR data for {self.name}") + + def ground_to_SVO(self, mentions_path: str) -> dict: + pass + + class EquationReadingInterface: # TODO: define this for interface to EqDec and Cosmos equation-detection pass diff --git a/automates/model_assembly/metadata.py b/automates/model_assembly/metadata.py index d44fe9f8c..153397334 100644 --- a/automates/model_assembly/metadata.py +++ b/automates/model_assembly/metadata.py @@ -2,13 +2,15 @@ from abc import ABC, abstractclassmethod, abstractmethod from copy import deepcopy from enum import Enum, auto, unique -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from typing import List, Union, Type, Dict from time import time from ..utils.misc import uuid +from automates.program_analysis.CAST2GrFN.model.cast.boolean import Boolean + CategoricalTypes = Union[bool, str, int] NumericalTypes = Union[int, float] @@ -19,7 +21,7 @@ class MissingEnumError(Exception): class AutoMATESBaseEnum(Enum): def __str__(self): - return str(self.name).lower() + return self.name.lower() @abstractclassmethod def from_str(cls, child_cls: Type, data: str): @@ -36,10 +38,13 @@ class MetadataType(AutoMATESBaseEnum): NONE = auto() GRFN_CREATION = auto() EQUATION_EXTRACTION = auto() - TEXT_EXTRACTION = auto() + TEXT_DEFINITION = auto() CODE_SPAN_REFERENCE = auto() CODE_COLLECTION_REFERENCE = auto() DOMAIN = auto() + PARAMETER_SETTING = auto() + EQUATION_PARAMETER = auto() + TEXT_UNIT = auto() FROM_SOURCE = auto() @classmethod @@ -48,14 +53,22 @@ def from_str(cls, data: str): @classmethod def get_metadata_class(cls, mtype: MetadataType) -> TypedMetadata: - if mtype == cls.GRFN_CREATION: + if mtype == MetadataType.GRFN_CREATION: return GrFNCreation - elif mtype == cls.CODE_SPAN_REFERENCE: + elif mtype == MetadataType.CODE_SPAN_REFERENCE: return CodeSpanReference - elif mtype == cls.CODE_COLLECTION_REFERENCE: + elif mtype == MetadataType.CODE_COLLECTION_REFERENCE: return CodeCollectionReference - elif mtype == cls.DOMAIN: + elif mtype == MetadataType.DOMAIN: return Domain + elif mtype == MetadataType.TEXT_DEFINITION: + return VariableTextDefinition + elif mtype == MetadataType.PARAMETER_SETTING: + return VariableTextParameter + elif mtype == MetadataType.EQUATION_PARAMETER: + return VariableEquationParameter + elif mtype == MetadataType.TEXT_UNIT: + return VariableTextUnit elif mtype == cls.FROM_SOURCE: return VariableFromSource else: @@ -141,35 +154,73 @@ class LambdaType(AutoMATESBaseEnum): PACK = auto() OPERATOR = auto() - def __str__(self): - return str(self.name) + def shortname(self): + return self.__str__()[0] + + @classmethod + def get_lambda_type(cls, type_str: str, num_inputs: int): + expected_val = cls.from_str(type_str) + if num_inputs == 0 and expected_val == cls.ASSIGN: + return cls.LITERAL + else: + return expected_val + + @classmethod + def from_str(cls, data: str): + return super().from_str(cls, data) + + +@unique +class 
FunctionType(AutoMATESBaseEnum): + ASSIGN = auto() + LITERAL = auto() + CONDITION = auto() + DECISION = auto() + INTERFACE = auto() + EXTRACT = auto() + PACK = auto() + OPERATOR = auto() + CONDITIONAL = auto() + CONTAINER = auto() + ITERABLE = auto() def shortname(self): return self.__str__()[0] @classmethod def get_lambda_type(cls, type_str: str, num_inputs: int): - if type_str == "assign": - if num_inputs == 0: - return cls.LITERAL - return cls.ASSIGN - elif type_str == "condition": - return cls.CONDITION - elif type_str == "decision": - return cls.DECISION - elif type_str == "interface": - return cls.INTERFACE - elif type_str == "pack": - return cls.PACK - elif type_str == "extract": - return cls.EXTRACT + expected_val = cls.from_str(type_str) + if num_inputs == 0 and expected_val == cls.ASSIGN: + return cls.LITERAL + else: + return expected_val + + @classmethod + def from_con(cls, con_cls_name: str) -> FunctionType: + if con_cls_name == "FuncContainerDef": + return cls.CONTAINER + elif con_cls_name == "CondContainerDef": + return cls.CONDITIONAL + elif con_cls_name == "LoopContainerDef": + return cls.ITERABLE else: - raise ValueError(f"Unrecognized lambda type name: {type_str}") + raise ValueError(f"Unexpected Container type: {con_cls_name}") @classmethod def from_str(cls, data: str): return super().from_str(cls, data) + @classmethod + def is_expression_type(cls, type_str: str) -> Boolean: + return cls.from_str(type_str) in [ + cls.ASSIGN, + cls.DECISION, + cls.CONDITION, + cls.INTERFACE, + cls.EXTRACT, + cls.PACK, + ] + @unique class DataType(AutoMATESBaseEnum): @@ -248,7 +299,7 @@ def get_dt_timestamp() -> datetime: return datetime.fromtimestamp(time()) @classmethod - def from_data(cls, data: dict) -> ProvenanceData: + def from_dict(cls, data: dict) -> ProvenanceData: return cls(MetadataMethod.from_str(data["method"]), data["timestamp"]) def to_dict(self): @@ -261,19 +312,26 @@ class TypedMetadata(BaseMetadata): provenance: ProvenanceData @abstractclassmethod - def from_data(cls, data): + def from_dict(cls, data): data = deepcopy(data) mtype = MetadataType.from_str(data["type"]) - provenance = ProvenanceData.from_data(data["provenance"]) + provenance = ProvenanceData.from_dict(data["provenance"]) ChildMetadataClass = MetadataType.get_metadata_class(mtype) data.update({"type": mtype, "provenance": provenance}) - return ChildMetadataClass.from_data(data) + return ChildMetadataClass.from_dict(data) def to_dict(self): - return { - "type": str(self.type), - "provenance": self.provenance.to_dict(), - } + def as_dict_enum_factory(data): + def convert_value(obj): + if isinstance(obj, Enum): + return obj.name + elif isinstance(obj, datetime): + return str(obj) + return obj + + return dict((k, convert_value(v)) for k, v in data) + + return asdict(self, dict_factory=as_dict_enum_factory) @dataclass @@ -296,7 +354,7 @@ def get_ref_with_default(ref: str) -> Union[int, None]: ) @classmethod - def from_data(cls, data: dict) -> CodeSpan: + def from_dict(cls, data: dict) -> CodeSpan: return cls(**data) def to_dict(self): @@ -322,7 +380,7 @@ def from_str(cls, filepath: str) -> CodeFileReference: return cls(str(uuid.uuid4()), filename, dirpath) @classmethod - def from_data(cls, data: dict) -> CodeFileReference: + def from_dict(cls, data: dict) -> CodeFileReference: return cls(**data) def to_dict(self): @@ -337,7 +395,7 @@ class DomainInterval(BaseMetadata): u_inclusive: bool @classmethod - def from_data(cls, data: dict) -> DomainInterval: + def from_dict(cls, data: dict) -> DomainInterval: return 
cls(**data) def to_dict(self): @@ -355,7 +413,7 @@ class DomainSet(BaseMetadata): predicate: str @classmethod - def from_data(cls, data: dict) -> DomainSet: + def from_dict(cls, data: dict) -> DomainSet: return cls(SuperSet.from_str(data["superset"]), data["predicate"]) def to_dict(self): @@ -372,7 +430,7 @@ class CodeSpanReference(TypedMetadata): code_span: CodeSpan @classmethod - def from_air_data(cls, data: dict) -> CodeSpanReference: + def from_air_json(cls, data: dict) -> CodeSpanReference: return cls( MetadataType.CODE_SPAN_REFERENCE, ProvenanceData( @@ -385,13 +443,13 @@ def from_air_data(cls, data: dict) -> CodeSpanReference: ) @classmethod - def from_data(cls, data: dict) -> CodeSpanReference: + def from_dict(cls, data: dict) -> CodeSpanReference: return cls( data["type"], data["provenance"], CodeSpanType.from_str(data["code_type"]), data["code_file_reference_uid"], - CodeSpan.from_data(data["code_span"]), + CodeSpan.from_dict(data["code_span"]), ) def to_dict(self): @@ -416,7 +474,6 @@ class VariableCreationReason(AutoMATESBaseEnum): COMPLEX_RETURN_EXPR = auto() CONDITION_RESULT = auto() LOOP_EXIT_VAR = auto() - LITERAL_FUNCTION_ARG = auto() def __str__(self): return str(self.name) @@ -445,11 +502,11 @@ def from_air_data(cls, data: dict) -> VariableFromSource: ) @classmethod - def from_data(cls, data: dict) -> VariableFromSource: + def from_dict(cls, data: dict) -> VariableFromSource: return cls( data["type"], data["provenance"], - data["from_source"] or data["from_source"] == "True", + bool(data["from_source"]), VariableCreationReason.from_str(data["creation_reason"]), ) @@ -481,7 +538,7 @@ def from_name(cls, filepath: str) -> GrFNCreation: ) @classmethod - def from_data(cls, data: dict) -> GrFNCreation: + def from_dict(cls, data: dict) -> GrFNCreation: return cls(**data) def to_dict(self): @@ -490,6 +547,131 @@ def to_dict(self): return data +@dataclass +class EquationExtraction(BaseMetadata): + source_type: str + document_reference_uid: str + equation_number: int + + @classmethod + def from_dict(cls, data: dict) -> EquationExtraction: + return cls("equation_document_source", "", data["equation_number"]) + + def to_dict(self) -> str: + return NotImplemented + + +@dataclass +class TextSpan(BaseMetadata): + char_begin: int + char_end: int + + @classmethod + def from_dict(cls, data: dict) -> TextSpan: + return cls(data["char_begin"], data["char_end"]) + + def to_dict(self) -> str: + return NotImplemented + + +@dataclass +class TextSpanRef(BaseMetadata): + page: int + block: int + span: TextSpan + + @classmethod + def from_dict(cls, data: dict) -> TextSpanRef: + return cls(None, None, TextSpan.from_dict(data["span"])) + + def to_dict(self) -> str: + return NotImplemented + + +@dataclass +class TextExtraction(BaseMetadata): + source_type: str + document_reference_uid: str + text_spans: List[TextSpanRef] + + @classmethod + def from_dict(cls, data: dict) -> TextExtraction: + return cls( + "text_document_source", + "", + [TextSpanRef.from_dict(span) for span in data["text_spans"]], + ) + + def to_dict(self) -> str: + return NotImplemented + + +@dataclass +class VariableEquationParameter(TypedMetadata): + equation_extraction: EquationExtraction + variable_identifier: str + value: str + + @classmethod + def from_dict(cls, data: dict) -> VariableEquationParameter: + return cls( + data["type"], + data["provenance"], + EquationExtraction.from_dict(data["equation_extraction"]), + data["variable_identifier"], + data["value"], + ) + + +@dataclass +class VariableTextDefinition(TypedMetadata): + 
text_extraction: TextExtraction
+    variable_identifier: str
+    variable_definition: str
+
+    @classmethod
+    def from_dict(cls, data: dict) -> VariableTextDefinition:
+        return cls(
+            data["type"],
+            data["provenance"],
+            TextExtraction.from_dict(data["text_extraction"]),
+            data["variable_identifier"],
+            data["variable_definition"],
+        )
+
+
+@dataclass
+class VariableTextParameter(TypedMetadata):
+    text_extraction: TextExtraction
+    variable_identifier: str
+    value: str
+
+    @classmethod
+    def from_dict(cls, data: dict) -> VariableTextParameter:
+        return cls(
+            data["type"],
+            data["provenance"],
+            TextExtraction.from_dict(data["text_extraction"]),
+            data["variable_identifier"],
+            data["value"],
+        )
+
+
+@dataclass
+class VariableTextUnit(TypedMetadata):
+    text_extraction: TextExtraction
+    unit: str
+
+    @classmethod
+    def from_dict(cls, data: dict) -> VariableTextUnit:
+        return cls(
+            data["type"],
+            data["provenance"],
+            TextExtraction.from_dict(data["unit_extraction"]),
+            data["unit"],
+        )
+
+
 @dataclass
 class CodeCollectionReference(TypedMetadata):
     global_reference_uid: str
@@ -508,12 +690,12 @@ def from_sources(cls, sources: List[str]) -> CodeCollectionReference:
         )
 
     @classmethod
-    def from_data(cls, data: dict) -> CodeCollectionReference:
+    def from_dict(cls, data: dict) -> CodeCollectionReference:
         return cls(
             data["type"],
             data["provenance"],
             data["global_reference_uid"],
-            [CodeFileReference.from_data(d) for d in data["files"]],
+            [CodeFileReference.from_dict(d) for d in data["files"]],
         )
 
     def to_dict(self):
@@ -544,12 +726,12 @@ class Domain(TypedMetadata):
     elements: List[DomainElement]
 
     @classmethod
-    def from_data(cls, data: dict) -> Domain:
+    def from_dict(cls, data: dict) -> Domain:
        mtype = MeasurementType.from_str(data["measurement_scale"])
         if MeasurementType.isa_categorical(mtype):
-            els = [DomainSet.from_data(dom_el) for dom_el in data["elements"]]
+            els = [DomainSet.from_dict(dom_el) for dom_el in data["elements"]]
         elif MeasurementType.isa_numerical(mtype):
-            els = [DomainInterval.from_data(dom_el) for dom_el in data["elements"]]
+            els = [DomainInterval.from_dict(dom_el) for dom_el in data["elements"]]
         else:
             els = []
         return cls(
diff --git a/automates/model_assembly/text_reading_linker.py b/automates/model_assembly/text_reading_linker.py
new file mode 100644
index 000000000..51c8df199
--- /dev/null
+++ b/automates/model_assembly/text_reading_linker.py
@@ -0,0 +1,286 @@
+import os
+
+from automates.model_assembly.metadata import (
+    MetadataMethod,
+    ProvenanceData,
+    TypedMetadata,
+)
+from automates.model_assembly.networks import GroundedFunctionNetwork
+from automates.model_assembly.interfaces import TextReadingInterface
+from automates.model_assembly.linking import (
+    GCodeVarNode,
+    GCommSpanNode,
+    GEqnVarNode,
+    GVarNode,
+    ParameterSettingNode,
+    UnitNode,
+    build_link_graph,
+)
+
+
+class TextReadingLinker:
+    text_reading_interface: TextReadingInterface
+
+    def __init__(self, text_reading_interface) -> None:
+        self.text_reading_interface = text_reading_interface
+
+    def gather_tr_sources(self, grfn: GroundedFunctionNetwork):
+        """
+        Given a GrFN, gather the following required sources for TR:
+            1. Comment text document
+            2. Source document text
+            3. Equations document
+
+        TODO At this point these files will be passed into the GrFN translation
+        process. Eventually we need to attempt to automatically generate this
+        data for all GrFNs passed in.
+ + Args: + grfn (GroundedFunctionNetwork): GrFN to generate TR data for + """ + pass + + def groundings_to_metadata(self, groundings): + vars_to_metadata = {} + for var, grounding in groundings.items(): + vars_to_metadata[var] = list() + + provenance = { + "method": "TEXT_READING_PIPELINE", + "timestamp": ProvenanceData.get_dt_timestamp(), + } + + for text_definition in grounding["text_definition"]: + vars_to_metadata[var].append( + TypedMetadata.from_dict( + { + "type": "TEXT_DEFINITION", + "provenance": provenance, + "text_extraction": text_definition["text_extraction"], + "variable_identifier": grounding["gvar"], + "variable_definition": text_definition["variable_def"], + } + ) + ) + + for param_setting in grounding["parameter_setting"]: + vars_to_metadata[var].append( + TypedMetadata.from_dict( + { + "type": "PARAMETER_SETTING", + "provenance": provenance, + "text_extraction": param_setting["text_extraction"], + "variable_identifier": grounding["gvar"], + "value": param_setting["value"], + } + ) + ) + + for equation_parameter in grounding["equation_parameter"]: + vars_to_metadata[var].append( + TypedMetadata.from_dict( + { + "type": "EQUATION_PARAMETER", + "provenance": provenance, + "equation_extraction": equation_parameter[ + "equation_extraction" + ], + "variable_identifier": grounding["gvar"], + "value": equation_parameter["value"], + } + ) + ) + + for text_unit in grounding["text_unit"]: + vars_to_metadata[var].append( + TypedMetadata.from_dict( + { + "type": "TEXT_UNIT", + "provenance": provenance, + "unit_extraction": text_unit["unit_extraction"], + "variable_identifier": grounding["gvar"], + "unit": text_unit["unit"], + } + ) + ) + return vars_to_metadata + + def build_text_extraction(self, text_extraction): + return { + "text_spans": [ + { + "span": { + "char_begin": span.char_begin, + "char_end": span.char_end, + } + } + for span in text_extraction.spans + ] + } + + def build_text_definition(self, gvar: GVarNode): + return [ + { + "variable_def": text_var.content, + "text_extraction": self.build_text_extraction(text_var.text_extraction), + } + for text_var in gvar.text_vars + ] + + def build_text_unit(self, gvar: GVarNode, L): + + text_unit_settings = [ + text_unit + for text_unit in L.successors(gvar) + if isinstance(text_unit, UnitNode) + ] + + if len(text_unit_settings) == 0: + return [] + + selected_text_unit = max( + text_unit_settings, key=lambda unit: L.edges[gvar, unit]["weight"] + ) + + return [ + { + "unit": selected_text_unit.content, + "unit_extraction": self.build_text_extraction( + selected_text_unit.text_extraction + ), + } + ] + + def build_parameter_setting(self, gvar: GVarNode, L): + parameter_settings = [ + param + for param in L.successors(gvar) + if isinstance(param, ParameterSettingNode) + ] + + if len(parameter_settings) == 0: + return [] + + selected_parameter_setting = max( + parameter_settings, key=lambda unit: L.edges[gvar, unit]["weight"] + ) + + return [ + { + "value": selected_parameter_setting.content, + "text_extraction": self.build_text_extraction( + selected_parameter_setting.text_extraction + ), + } + ] + + def build_equation_groundings(self, gvar: GVarNode, L): + equation_vars = [ + param for param in L.predecessors(gvar) if isinstance(param, GEqnVarNode) + ] + + if len(equation_vars) == 0: + return [] + + selected_g_eq_var = max( + equation_vars, key=lambda eq: L.edges[eq, gvar]["weight"] + ) + + return [ + { + "value": selected_eqn_var.content, + "equation_extraction": { + "equation_number": selected_eqn_var.equation_number + }, + } + 
for selected_eqn_var in selected_g_eq_var.equation_nodes
+        ]
+
+    def get_links_from_graph(self, L):
+        grfn_var_to_groundings = {}
+        for gcode_var in [n for n in L.nodes if isinstance(n, GCodeVarNode)]:
+            gcode_var_name = gcode_var.content
+
+            gcomm_nodes_to_gvars = {
+                comm: [
+                    gvar for gvar in L.predecessors(comm) if isinstance(gvar, GVarNode)
+                ]
+                for comm in L.predecessors(gcode_var)
+                if isinstance(comm, GCommSpanNode)
+            }
+
+            for gcomm, gvars in gcomm_nodes_to_gvars.items():
+                for gvar in gvars:
+                    score = min(
+                        L.edges[gcomm, gcode_var]["weight"],
+                        L.edges[gvar, gcomm]["weight"],
+                    )
+                    if (
+                        gcode_var_name not in grfn_var_to_groundings
+                        or grfn_var_to_groundings[gcode_var_name]["score"] < score
+                    ):
+                        grfn_var_to_groundings[gcode_var_name] = {
+                            "score": score,
+                            "gvar": gvar.content,
+                            "equation_parameter": self.build_equation_groundings(
+                                gvar, L
+                            ),
+                            "parameter_setting": self.build_parameter_setting(gvar, L),
+                            "text_definition": self.build_text_definition(gvar),
+                            "text_unit": self.build_text_unit(gvar, L),
+                        }
+
+        return grfn_var_to_groundings
+
+    def perform_tr_grfn_linking(self, grfn: GroundedFunctionNetwork, tr_sources: dict):
+        """
+        Enriches the given grfn with text reading metadata given the text reading
+        source files (comments, source document, and equations text files).
+
+        Args:
+            grfn (GroundedFunctionNetwork): the GrFN to enrich with TR metadata
+            tr_sources (dict): paths to the TR sources, keyed by "doc_file", "comm_file", and "eqn_file"
+        """
+
+        # Make sure all required sources are given
+        for document_type in ["doc_file", "comm_file", "eqn_file"]:
+            if document_type not in tr_sources:
+                print(
+                    f"Error: required TR source {document_type} not passed "
+                    + "into TR-GrFN linking."
+                )
+                return grfn
+
+        # Generate temporary output file names for TR mentions
+        cur_dir = os.getcwd()
+        mentions_path = f"{cur_dir}/mentions.json"
+
+        # Generate variables list for linking
+        variable_ids = [v.identifier for k, v in grfn.variables.items()]
+
+        # Build the hypothesis data by first getting mentions then generating
+        # the hypothesis
+        self.text_reading_interface.extract_mentions(
+            tr_sources["doc_file"], mentions_path
+        )
+        hypothesis_data = self.text_reading_interface.get_link_hypotheses(
+            mentions_path,
+            tr_sources["eqn_file"],
+            tr_sources["comm_file"],
+            variable_ids,
+        )
+
+        # Cleanup temp files
+        if os.path.isfile(mentions_path):
+            os.remove(mentions_path)
+
+        L = build_link_graph(hypothesis_data)
+        grfn_var_to_groundings = self.get_links_from_graph(L)
+        vars_to_metadata = self.groundings_to_metadata(grfn_var_to_groundings)
+
+        for var_id, var in grfn.variables.items():
+            var_name = var_id.name
+            if var_name in vars_to_metadata:
+                for metadata in vars_to_metadata[var_name]:
+                    var.add_metadata(metadata)
+
+        return grfn
diff --git a/scripts/model_assembly/alignment_experiment.py b/scripts/model_assembly/alignment_experiment.py
index eed6f3509..9e063e8c2 100644
--- a/scripts/model_assembly/alignment_experiment.py
+++ b/scripts/model_assembly/alignment_experiment.py
@@ -2,7 +2,7 @@
 import json
 import os
 
-from automates.model_assembly.interfaces import TextReadingInterface
+from automates.model_assembly.interfaces import TextReadingAppInterface
 
 
 def main(args):
@@ -12,17 +12,18 @@ def main(args):
     MENTIONS_PATH = f"{CUR_DIR}/{MODEL_NAME}--mentions.json"
     ALIGNMENT_PATH = f"{CUR_DIR}/{MODEL_NAME}--alignment.json"
 
-    caller = TextReadingInterface(f"http://{args.address}:{args.port}")
+    caller = TextReadingAppInterface(f"http://{args.address}:{args.port}")
     if not os.path.isfile(MENTIONS_PATH):
         caller.extract_mentions(args.doc_file, MENTIONS_PATH)
     else:
-        print(f"Mentions have been previously extracted and are stored in {MENTIONS_PATH}")
-
+        print(
+            f"Mentions have been previously extracted and are stored in {MENTIONS_PATH}"
+        )
-    hypothesis_data = caller.get_link_hypotheses(
-        MENTIONS_PATH, args.eqn_file, args.grfn_file, args.comm_file, args.wikidata_file,
+    hypothesis_data = caller.get_grfn_link_hypothesis(
+        MENTIONS_PATH, args.eqn_file, args.grfn_file, args.comm_file
     )
-    json.dump({"grounding": hypothesis_data}, open(ALIGNMENT_PATH, "w", encoding='utf8'), ensure_ascii=False)
+    json.dump({"grounding": hypothesis_data}, open(ALIGNMENT_PATH, "w"))
 
 
 if __name__ == "__main__":
@@ -31,7 +32,6 @@ def main(args):
     parser.add_argument("comm_file", help="filepath to a comments JSON file")
     parser.add_argument("doc_file", help="filepath to a source paper file (COSMOS or Science Parse)")
     parser.add_argument("eqn_file", help="filepath to an equations txt file")
-    parser.add_argument("--wikidata_file", help="filepath to a wikidata grounding json file", type=str, default="None")
     parser.add_argument(
         "-a",
         "--address",
diff --git a/scripts/model_assembly/grfn_links_to_csv.py b/scripts/model_assembly/grfn_links_to_csv.py
index 1876c1d3f..0e02b9098 100644
--- a/scripts/model_assembly/grfn_links_to_csv.py
+++ b/scripts/model_assembly/grfn_links_to_csv.py
@@ -30,22 +30,18 @@ def main():
     outpath = filepath.replace("-alignment.json", "-link-tables.csv")
     with open(outpath, "w", newline="") as csvfile:
         link_writer = csv.writer(csvfile, dialect="excel")
         for var_name, var_data in tables.items():
-            (_, sub_name_str, basename_str, idx_str) = var_name.split("\n")
-            (_, sub_name) = sub_name_str.split(": ")
-            (_, basename) = basename_str.split(": ")
-            (_, idx) = idx_str.split(": ")
-            short_varname = "::".join([sub_name, basename, idx])
+            short_varname = var_name
             link_writer.writerow([short_varname])
             link_writer.writerow(
                 [
                     "Link Score",
-                    "Var-Comm Score",
-                    "Comm-Text Score",
-                    "Text-Eqn Score",
+                    "GlobalCodeVar-GlobalComm Score",
+                    "GlobalComm-GlobalVar Score",
+                    "GlobalVar-GlobalEqn Score",
                     "Comment Span",
-                    "Text Mention",
+                    "Text Mention(s) List",
                     "Equation Symbol",
                 ]
             )
@@ -56,10 +52,10 @@ def main():
                     link_data["link_score"],
                     link_data["vc_score"],
                     link_data["ct_score"],
-                    link_data["te_score"],
+                    link_data["te_score"] if "te_score" in link_data else None,
                     link_data["comm"],
                     link_data["txt"],
-                    link_data["eqn"],
+                    link_data["eqn"] if "eqn" in link_data else None,
                 ]
             )
             link_writer.writerow([])
diff --git a/scripts/model_assembly/grfn_tr_alignment_merge.py b/scripts/model_assembly/grfn_tr_alignment_merge.py
new file mode 100644
index 000000000..a0e018322
--- /dev/null
+++ b/scripts/model_assembly/grfn_tr_alignment_merge.py
@@ -0,0 +1,45 @@
+import argparse
+import json
+import os
+
+from automates.model_assembly.air import AutoMATES_IR
+from automates.model_assembly.networks import GroundedFunctionNetwork
+from automates.model_assembly.text_reading_linker import TextReadingLinker
+from automates.model_assembly.interfaces import (
+    LocalTextReadingInterface,
+)
+
+
+def main(args):
+
+    air_filepath = args.air_file
+    air_json_data = json.load(open(air_filepath, "r"))
+    AIR = AutoMATES_IR.from_air_json(air_json_data)
+    GrFN = GroundedFunctionNetwork.from_AIR(AIR)
+
+    name = air_filepath.split("/")[-1].rsplit("--AIR", 1)[0]
+    tr_interface = LocalTextReadingInterface(name)
+    tr_linker = TextReadingLinker(tr_interface)
+
+    GrFN = tr_linker.perform_tr_grfn_linking(
+        GrFN,
+        {
+            "comm_file": args.comm_file,
+            "eqn_file": args.eqn_file,
+            "doc_file": args.doc_file,
+        },
+    )
+
+    grfn_file = air_filepath.replace("--AIR.json", "_with_metadata--GrFN3.json")
+    GrFN.to_json_file(grfn_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("air_file", help="filepath to an AIR JSON file")
+    parser.add_argument("comm_file", help="filepath to a comments JSON file")
+    parser.add_argument("doc_file", help="filepath to a source text pdf file")
+    parser.add_argument("eqn_file", help="filepath to an equations txt file")
+
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/model_assembly/variable_name_alignment.py b/scripts/model_assembly/variable_name_alignment.py
new file mode 100644
index 000000000..8a2b2de2f
--- /dev/null
+++ b/scripts/model_assembly/variable_name_alignment.py
@@ -0,0 +1,59 @@
+"""
+Adapted from "alignment_experiment.py". The difference here is that we read
+in a list of variable names for the alignment rather than grounding the
+alignment in an actual GrFN.
+"""
+
+import argparse
+import json
+import os
+
+from automates.model_assembly.interfaces import TextReadingAppInterface
+
+
+def main(args):
+    CUR_DIR = os.getcwd()
+    MODEL_NAME = os.path.basename(args.names).replace("--vars.json", "")
+
+    MENTIONS_PATH = f"{CUR_DIR}/{MODEL_NAME}--mentions.json"
+    ALIGNMENT_PATH = f"{CUR_DIR}/{MODEL_NAME}--alignment.json"
+
+    caller = TextReadingAppInterface(f"http://{args.address}:{args.port}")
+    if not os.path.isfile(MENTIONS_PATH):
+        caller.extract_mentions(args.doc_file, MENTIONS_PATH)
+    else:
+        print(
+            f"Mentions have been previously extracted and are stored in {MENTIONS_PATH}"
+        )
+
+    variable_names_json = json.load(open(args.names, "r"))
+    variable_names = [{"name": name} for name in variable_names_json["variables"]]
+
+    hypothesis_data = caller.get_link_hypotheses(
+        MENTIONS_PATH, args.eqn_file, args.comm_file, variable_names
+    )
+    json.dump({"grounding": hypothesis_data}, open(ALIGNMENT_PATH, "w"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("names", help="filepath to a json file with variable names")
+    parser.add_argument("comm_file", help="filepath to a comments JSON file")
+    parser.add_argument("doc_file", help="filepath to a source text pdf file")
+    parser.add_argument("eqn_file", help="filepath to an equations txt file")
+    parser.add_argument(
+        "-a",
+        "--address",
+        type=str,
+        default="localhost",
+        help="Address to reach the TextReading webapp",
+    )
+    parser.add_argument(
+        "-p",
+        "--port",
+        type=int,
+        default=9000,
+        help="Port to reach the TextReading webapp",
+    )
+    args = parser.parse_args()
+    main(args)
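
Note for reviewers (not part of the patch): a minimal sketch of how the new
pieces compose end to end, mirroring scripts/model_assembly/grfn_tr_alignment_merge.py.
The file paths and the "SIR-simple" model name are illustrative assumptions;
LocalTextReadingInterface can only resolve model names present in its
hard-coded example-data index in interfaces.py.

    import json

    from automates.model_assembly.air import AutoMATES_IR
    from automates.model_assembly.networks import GroundedFunctionNetwork
    from automates.model_assembly.interfaces import LocalTextReadingInterface
    from automates.model_assembly.text_reading_linker import TextReadingLinker

    # Build a GrFN from an AIR JSON file (hypothetical path).
    air_json = json.load(open("SIR-simple--AIR.json", "r"))
    GrFN = GroundedFunctionNetwork.from_AIR(AutoMATES_IR.from_air_json(air_json))

    # Link the GrFN against locally cached text-reading data.
    linker = TextReadingLinker(LocalTextReadingInterface("SIR-simple"))
    GrFN = linker.perform_tr_grfn_linking(
        GrFN,
        {
            "comm_file": "SIR-simple--comments.json",  # hypothetical path
            "doc_file": "SIR-simple--doc.json",        # hypothetical path
            "eqn_file": "SIR-simple--eqns.txt",        # hypothetical path
        },
    )
    GrFN.to_json_file("SIR-simple_with_metadata--GrFN3.json")

perform_tr_grfn_linking writes a temporary mentions.json to the working
directory (and removes it afterwards) and returns the same GrFN with
TEXT_DEFINITION, PARAMETER_SETTING, EQUATION_PARAMETER, and TEXT_UNIT
metadata attached to any variables that were linked.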