diff --git a/README.md b/README.md index 4507a0a..e8e02ae 100644 --- a/README.md +++ b/README.md @@ -42,68 +42,64 @@ pip install . >>> article = "Lionel Andrés Messi (born 24 June 1987) is an Argentine professional footballer who plays as a forward and captains both Spanish club Barcelona and the Argentina national team. Often considered as the best player in the world and widely regarded as one of the greatest players of all time, Messi has won a record six Ballon d'Or awards, a record six European Golden Shoes, and in 2020 was named to the Ballon d'Or Dream Team." >>> summary = "Lionel Andrés Messi (born 24 Aug 1997) is an Spanish professional footballer who plays as a forward and captains both Spanish club Barcelona and the Spanish national team." >>> factsumm(article, summary, verbose=True) -SOURCE Entities -1: [('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', -'GPE'), ('Argentina', 'GPE')] -2: [('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), -("the Ballon d'Or Dream Team", 'ORG')] + +Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', 'ORG'), ('Argentina', 'GPE')]] +Line No.2: [[('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ("Ballon d'Or", 'WORK_OF_ART'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), ("Ballon d'Or Dream Team", 'WORK_OF_ART')]] -SUMMARY Entities -1: [('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')] + +Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]] -SOURCE Facts + ('Lionel Andrés Messi', 'per:origin', 'Argentine') -('Spanish', 'per:date_of_birth', '24 June 1987') -('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') ('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') +('Barcelona', 'org:country_of_headquarters', 'Spanish') ('Lionel Andrés Messi', 'per:date_of_birth', '24 June 1987') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Argentina') +('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') -SUMMARY Facts -('Lionel Andrés Messi', 'per:origin', 'Spanish') + +('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') ('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997') -('Spanish', 'per:date_of_birth', '24 Aug 1997') +('Barcelona', 'org:country_of_headquarters', 'Spanish') ('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') -('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') +('Lionel Andrés Messi', 'per:origin', 'Spanish') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish') -COMMON Facts -('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') -('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') + ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') +('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') +('Barcelona', 'org:country_of_headquarters', 'Spanish') +('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -DIFF Facts -('Lionel Andrés Messi', 'per:origin', 'Spanish') + ('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997') -('Spanish', 'per:date_of_birth', '24 Aug 1997') +('Lionel Andrés Messi', 'per:origin', 'Spanish') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish') Fact Score: 0.5714285714285714 - -Answers based on SOURCE (Questions are generated from Summary) -[Q] Who is the captain of the Spanish national team? [Pred] -[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987 -[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine -[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona - -Answers based on SUMMARY (Questions are generated from Summary) -[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi -[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997 -[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish -[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona +Answers based on Source (Questions are generated from Summary) +[Q] Who is the captain of the Spanish national team? [Pred] +[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987 +[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine +[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona + +Answers based on Summary (Questions are generated from Summary) +[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi +[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997 +[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish +[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona QAGS Score: 0.3333333333333333 Avg. ROUGE-1: 0.4415584415584415 Avg. ROUGE-2: 0.3287671232876712 Avg. ROUGE-L: 0.4415584415584415 - -BERTScore Score -Precision: 0.9151781797409058 -Recall: 0.9141832590103149 -F1: 0.9150083661079407 + +Precision: 0.9760397672653198 +Recall: 0.9778039455413818 +F1: 0.9769210815429688 ``` You can use the GPU with the `device`. If you want to use GPU, pass `cuda` (default is `cpu`) @@ -126,43 +122,41 @@ From [here](https://arxiv.org/pdf/2104.14839.pdf), you can find various way to s >>> from factsumm import FactSumm >>> factsumm = FactSumm() >>> factsumm.extract_facts(article, summary, verbose=True) -SOURCE Entities -1: [('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', -'GPE'), ('Argentina', 'GPE')] -2: [('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), -("the Ballon d'Or Dream Team", 'ORG')] + +Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', 'ORG'), ('Argentina', 'GPE')]] +Line No.2: [[('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ("Ballon d'Or", 'WORK_OF_ART'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), ("Ballon d'Or Dream Team", 'WORK_OF_ART')]] -SUMMARY Entities -1: [('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')] + +Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]] -SOURCE Facts + ('Lionel Andrés Messi', 'per:origin', 'Argentine') -('Spanish', 'per:date_of_birth', '24 June 1987') -('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') ('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') +('Barcelona', 'org:country_of_headquarters', 'Spanish') ('Lionel Andrés Messi', 'per:date_of_birth', '24 June 1987') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Argentina') +('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') -SUMMARY Facts -('Lionel Andrés Messi', 'per:origin', 'Spanish') + +('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') ('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997') -('Spanish', 'per:date_of_birth', '24 Aug 1997') +('Barcelona', 'org:country_of_headquarters', 'Spanish') ('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') -('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') +('Lionel Andrés Messi', 'per:origin', 'Spanish') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish') -COMMON Facts -('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -('Spanish', 'org:members', 'Barcelona') -('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') + ('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi') +('Lionel Andrés Messi', 'per:employee_of', 'Barcelona') +('Barcelona', 'org:country_of_headquarters', 'Spanish') +('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi') -DIFF Facts -('Lionel Andrés Messi', 'per:origin', 'Spanish') + ('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997') -('Spanish', 'per:date_of_birth', '24 Aug 1997') +('Lionel Andrés Messi', 'per:origin', 'Spanish') +('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish') Fact Score: 0.5714285714285714 ``` @@ -181,17 +175,17 @@ If you ask questions about the summary and the source document, you will get a s >>> from factsumm import FactSumm >>> factsumm = FactSumm() >>> factsumm.extract_qas(article, summary, verbose=True) -Answers based on SOURCE (Questions are generated from Summary) -[Q] Who is the captain of the Spanish national team? [Pred] -[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987 -[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine -[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona - -Answers based on SUMMARY (Questions are generated from Summary) -[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi -[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997 -[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish -[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona +Answers based on Source (Questions are generated from Summary) +[Q] Who is the captain of the Spanish national team? [Pred] +[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987 +[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine +[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona + +Answers based on Summary (Questions are generated from Summary) +[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi +[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997 +[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish +[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona QAGS Score: 0.3333333333333333 ``` @@ -219,10 +213,10 @@ Simple but effective word-level overlap ROUGE score >>> from factsumm import FactSumm >>> factsumm = FactSumm() >>> factsumm.calculate_bert_score(article, summary) -BERTScore Score -Precision: 0.9151781797409058 -Recall: 0.9141832590103149 -F1: 0.9150083661079407 + +Precision: 0.9760397672653198 +Recall: 0.9778039455413818 +F1: 0.9769210815429688 ``` [BERTScore](https://github.com/Tiiiger/bert_score) can be used to calculate the similarity between each source sentence and the summary sentence diff --git a/factsumm/factsumm.py b/factsumm/factsumm.py index 54d452f..b5bbc59 100644 --- a/factsumm/factsumm.py +++ b/factsumm/factsumm.py @@ -9,7 +9,7 @@ from factsumm.utils.module_entity import load_ner, load_rel from factsumm.utils.module_question import load_qa, load_qg from factsumm.utils.module_sentence import load_bert_score -from factsumm.utils.utils import Config, qags_score +from factsumm.utils.utils import Config, score_qags os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -99,7 +99,7 @@ def get_facts(self, lines: List[str], entities: List[List[Dict]]) -> Set: triples = [] for perm, entity in zip(perms, entities): - entity_key = {ent["word"].replace("▁", ""): ent["entity"] for ent in entity} + entity_key = {ent["word"]: ent["entity_group"] for ent in entity} facts = self.rel(perm) filtered_facts = [] @@ -109,6 +109,9 @@ def get_facts(self, lines: List[str], entities: List[List[Dict]]) -> Set: head = head.strip() tail = tail.strip() + if head == tail: + continue + head_entity_type = entity_key.get(head, None) tail_entity_type = entity_key.get(tail, None) @@ -143,7 +146,14 @@ def _segment_sentence(self, text: str) -> List[str]: def _print_entities(self, mode: str, total_entities: List[List[Dict]]): logging.info("<%s Entities>", mode.capitalize()) for i, line_entities in enumerate(total_entities): - logging.info("Line No.%s: [%s]", i+1, [(entity["word"].replace("▁", ""), entity["entity"]) for entity in line_entities]) + printable_elements = [] + dedup = {} + for entity in line_entities: + if entity["word"] not in dedup: + printable_elements.append((entity["word"], entity["entity_group"])) + dedup[entity["word"]] = True + + logging.info("Line No.%s: [%s]", i+1, printable_elements) logging.info("") def calculate_rouge( @@ -319,7 +329,7 @@ def extract_qas( self._print_qas("source", source_answers) self._print_qas("summary", summary_answers) - qa_score = qags_score(source_answers, summary_answers) + qa_score = score_qags(source_answers, summary_answers) logging.info("QAGS Score: %s\n", qa_score) return qa_score @@ -329,7 +339,7 @@ def calculate_bert_score( source: str, summary: str, device: str = "cpu", - ) -> Dict[str, float]: + ) -> Tuple[float, float, float]: """ Calculate BERTScore @@ -341,7 +351,7 @@ def calculate_bert_score( device (str): device info Returns: - Dict: (Precision, Recall, F1) BERTScore dictionary + Tuple[float]: (Precision, Recall, F1) BERTScore tuple """ if self.bert_score is None: @@ -350,26 +360,25 @@ def calculate_bert_score( source_lines = self._segment_sentence(source) summary_lines = self._segment_sentence(summary) - scores = { - "precision": 0.0, - "recall": 0.0, - "f1": 0.0, - } + total_precision = 0.0 + total_recall = 0.0 + total_f1 = 0.0 for summary_line in summary_lines: precision, recall, f1 = self.bert_score([summary_line], [source_lines]) - scores["precision"] += precision.item() - scores["recall"] += recall.item() - scores["f1"] += f1.item() + + total_precision += precision.item() + total_recall += recall.item() + total_f1 += f1.item() if len(summary_lines) > 1: - scores["precision"] /= len(summary_lines) - scores["recall"] /= len(summary_lines) - scores["f1"] /= len(summary_lines) + total_precision /= len(summary_lines) + total_recall /= len(summary_lines) + total_f1 /= len(summary_lines) - logging.info("\nPrecision: %s\nRecall: %s\nF1: %s", scores["precision"], scores["recall"], scores["f1"]) + logging.info("\nPrecision: %s\nRecall: %s\nF1: %s", total_precision, total_recall, total_f1) - return scores + return total_precision, total_recall, total_f1 def __call__( self, @@ -378,12 +387,14 @@ def __call__( verbose: bool = False, device: str = "cpu", ) -> Dict: - if isinstance(sources, str) and isinstance(summaries, str): + if isinstance(sources, str): sources = [sources] + + if isinstance(summaries, str): summaries = [summaries] if len(sources) != len(summaries): - raise ValueError("`sources` and `summaries` must have the same number of elements!") + raise ValueError("`sources` and `summaries` should have the same number of elements!") num_pairs = len(sources) @@ -424,11 +435,26 @@ def __call__( rouge_scores["rouge-2"] += rouge_2 rouge_scores["rouge-l"] += rouge_l - bert_scores = self.calculate_bert_score(source, summary, device) + precision, recall, f1 = self.calculate_bert_score(source, summary, device) + bert_scores["precision"] += precision + bert_scores["recall"] += recall + bert_scores["f1"] += f1 + + if num_pairs > 1: + fact_scores /= num_pairs + qags_scores /= num_pairs + + rouge_scores["rouge-1"] /= num_pairs + rouge_scores["rouge-2"] /= num_pairs + rouge_scores["rouge-l"] /= num_pairs + + bert_scores["precision"] /= num_pairs + bert_scores["recall"] /= num_pairs + bert_scores["f1"] /= num_pairs return { - "fact_score": fact_scores / num_pairs, - "qa_score": qags_scores / num_pairs, + "fact_score": fact_scores, + "qa_score": qags_scores, "rouge": rouge_scores, "bert_score": bert_scores, } diff --git a/factsumm/utils/module_entity.py b/factsumm/utils/module_entity.py index d7f1b8e..3f2d144 100644 --- a/factsumm/utils/module_entity.py +++ b/factsumm/utils/module_entity.py @@ -4,8 +4,6 @@ from requests import HTTPError from transformers import LukeForEntityPairClassification, LukeTokenizer, pipeline -from factsumm.utils.utils import grouped_entities - def load_ner(model: str, device: str) -> object: """ @@ -29,24 +27,22 @@ def load_ner(model: str, device: str) -> object: ignore_labels=[], framework="pt", device=-1 if device == "cpu" else 0, + aggregation_strategy="simple", ) except (HTTPError, OSError): logging.warning("Input model is not supported by HuggingFace Hub") raise - def extract_entities_hf(sentences: List[str]): - result = [] + def extract_entities(sentences: List[str]): total_entities = ner(sentences) - if isinstance(total_entities[0], dict): - total_entities = [total_entities] - + result = [] for line_entities in total_entities: - result.append(grouped_entities(line_entities)) + result.append([entity for entity in line_entities if entity["entity_group"] != "O"]) return result - return extract_entities_hf + return extract_entities def load_rel(model: str, device: str): diff --git a/factsumm/utils/module_question.py b/factsumm/utils/module_question.py index dd25dd0..458c93d 100644 --- a/factsumm/utils/module_question.py +++ b/factsumm/utils/module_question.py @@ -40,8 +40,11 @@ def generate_question(sentences: List[str], total_entities: List): qa_pairs = [] for sentence, line_entities in zip(sentences, total_entities): + dedup = {} for entity in line_entities: entity = entity["word"] + if entity in dedup: + continue template = f"answer: {entity} context: {sentence} " @@ -65,6 +68,8 @@ def generate_question(sentences: List[str], total_entities: List): "answer": entity, }) + dedup[entity] = True + return qa_pairs return generate_question diff --git a/factsumm/utils/utils.py b/factsumm/utils/utils.py index 31a59ac..473a2d1 100644 --- a/factsumm/utils/utils.py +++ b/factsumm/utils/utils.py @@ -1,7 +1,7 @@ import re import string from collections import Counter -from typing import Dict, List +from typing import List from transformers import pipeline @@ -14,89 +14,6 @@ class Config: SUMM_MODEL: str = "sshleifer/distilbart-cnn-12-6" -def grouped_entities(entities: List[Dict]) -> List: - """ - Group entities to concatenate BIO - - Args: - entities (List[Dict]): list of inference entities - - Returns: - List[Tuple]: list of grouped BIO scheme entities - - """ - - def _remove_prefix(entity: str) -> str: - if "-" in entity: - entity = entity[2:] - return entity - - def _append(lst: List, word: str, entity_type: str, start: int, end: int): - if prev_word != "": - lst.append((word, entity_type, start, end)) - - result = [] - - prev_word = entities[0]["word"] - prev_entity = entities[0]["entity"] - prev_entity_type = _remove_prefix(prev_entity) - prev_start = entities[0]["start"] - prev_end = entities[0]["end"] - - for pair in entities[1:]: - word = pair["word"] - entity = pair["entity"] - entity_type = _remove_prefix(entity) - start = pair["start"] - end = pair["end"] - - if "##" in word: - prev_word += word - prev_end = end - continue - - if entity == prev_entity: - if entity == "O": - _append(result, prev_word, prev_entity_type, prev_start, prev_end) - result.append((word, entity_type)) - prev_word = "" - prev_start = start - prev_end = end - if "I-" in entity: - prev_word += f" {word}" - prev_end = end - elif (entity != prev_entity) and ("I-" in entity) and (entity_type != "O"): - prev_word += f" {word}" - prev_end = end - else: - _append(result, prev_word, prev_entity_type, prev_start, prev_end) - prev_word = word - prev_entity_type = entity_type - prev_start = start - prev_end = end - - prev_entity = entity - - _append(result, prev_word, prev_entity_type, prev_start, prev_end) - - cache = {} - dedup = [] - - for pair in result: - if pair[1] == "O": - continue - - if pair[0] not in cache: - dedup.append({ - "word": pair[0].replace("##", ""), - "entity": pair[1], - "start": pair[2], - "end": pair[3] - }) - cache[pair[0]] = None - return dedup - - def load_summarizer(model: str) -> object: """ Load Summarization model from HuggingFace hub @@ -162,7 +79,7 @@ def _white_space_fix(text: str): return f1 -def qags_score(source_answers: List, summary_answers: List) -> float: +def score_qags(source_answers: List, summary_answers: List) -> float: """ Caculate QAGS Score