From 4251a381ef3ea9b0dba562c8602426ca6698aa4d Mon Sep 17 00:00:00 2001
From: Huffon
Date: Fri, 14 May 2021 01:51:34 +0900
Subject: [PATCH] #11 Implement BERTScore-based module

---
 README.md                        | 21 ++++++++++++++
 factsumm/__init__.py             | 47 +++++++++++++++++++++++++++++++-
 factsumm/utils/level_sentence.py | 17 +++++++++++-
 factsumm/utils/utils.py          |  1 +
 4 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index dd12b56..09ef427 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,11 @@ Triple Score: 0.5
 Avg. ROUGE-1: 0.4415584415584415
 Avg. ROUGE-2: 0.3287671232876712
 Avg. ROUGE-L: 0.4415584415584415
+
+BERTScore Score
+Precision: 0.9151781797409058
+Recall: 0.9141832590103149
+F1: 0.9150083661079407
 ```
@@ -241,6 +246,22 @@ Simple but effective word-level overlap ROUGE score
+### BERTScore Module
+
+```python
+>>> from factsumm import FactSumm
+>>> factsumm = FactSumm()
+>>> factsumm.calculate_bert_score(article, summary)
+BERTScore Score
+Precision: 0.9151781797409058
+Recall: 0.9141832590103149
+F1: 0.9150083661079407
+```
+
+[BERTScore](https://github.com/Tiiiger/bert_score) can be used to calculate the similarity between each source sentence and the generated summary.
+
+
+
 ### Citation
 
 If you apply this library to any project, please cite:
diff --git a/factsumm/__init__.py b/factsumm/__init__.py
index 6ba8dc3..4d046ce 100644
--- a/factsumm/__init__.py
+++ b/factsumm/__init__.py
@@ -8,7 +8,7 @@
 from sumeval.metrics.rouge import RougeCalculator
 
 from factsumm.utils.level_entity import load_ie, load_ner, load_rel
-from factsumm.utils.level_sentence import load_qa, load_qg
+from factsumm.utils.level_sentence import load_bert_score, load_qa, load_qg
 from factsumm.utils.utils import Config, qags_score
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -26,6 +26,7 @@ def __init__(
         rel_model: str = None,
         qg_model: str = None,
         qa_model: str = None,
+        bert_score_model: str = None,
     ):
         self.config = Config()
         self.segmenter = pysbd.Segmenter(language="en", clean=False)
@@ -36,6 +37,7 @@ def __init__(
         self.rel = rel_model if rel_model is not None else self.config.REL_MODEL
         self.qg = qg_model if qg_model is not None else self.config.QG_MODEL
         self.qa = qa_model if qa_model is not None else self.config.QA_MODEL
+        self.bert_score = bert_score_model if bert_score_model is not None else self.config.BERT_SCORE_MODEL
         self.ie = None
 
     def build_perm(
@@ -321,12 +323,46 @@ def extract_triples(self, source: str, summary: str, verbose: bool = False):
 
         return triple_score
 
+    def calculate_bert_score(self, source: str, summary: str):
+        """
+        Calculate BERTScore
+
+        See also https://arxiv.org/abs/2005.03754
+
+        Args:
+            source (str): original source
+            summary (str): generated summary
+
+        """
+        add_dummy = False
+
+        if isinstance(self.bert_score, str):
+            self.bert_score = load_bert_score(self.bert_score)
+
+        source_lines = self._segment(source)
+        summary_lines = [summary, "dummy"]
+
+        scores = self.bert_score(summary_lines, source_lines)
+        filtered_scores = list()
+
+        for score in scores:
+            score = score.tolist()
+            score.pop(-1)
+            filtered_scores.append(sum(score) / len(score))
+
+        print(
+            f"BERTScore Score\nPrecision: {filtered_scores[0]}\nRecall: {filtered_scores[1]}\nF1: {filtered_scores[2]}"
+        )
+
+        return filtered_scores
+
     def __call__(self, source: str, summary: str, verbose: bool = False):
         source_ents, summary_ents, fact_score = self.extract_facts(
             source,
             summary,
             verbose,
         )
+
         qags_score = self.extract_qas(
             source,
             summary,
@@ -334,12 +370,21 @@ def __call__(self, source: str, summary: str, verbose: bool = False):
             summary_ents,
             verbose,
         )
+
         triple_score = self.extract_triples(source, summary, verbose)
+
         rouge_1, rouge_2, rouge_l = self.calculate_rouge(source, summary)
+        bert_scores = self.calculate_bert_score(source, summary)
+
         return {
             "fact_score": fact_score,
             "qa_score": qags_score,
             "triple_score": triple_score,
             "rouge": (rouge_1, rouge_2, rouge_l),
+            "bert_score": {
+                "precision": bert_scores[0],
+                "recall": bert_scores[1],
+                "f1": bert_scores[2],
+            },
         }
diff --git a/factsumm/utils/level_sentence.py b/factsumm/utils/level_sentence.py
index ffd676f..ded05c5 100644
--- a/factsumm/utils/level_sentence.py
+++ b/factsumm/utils/level_sentence.py
@@ -1,5 +1,6 @@
 from typing import List
 
+from bert_score import BERTScorer
 from rich import print
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 
@@ -111,4 +112,18 @@ def answer_question(context: str, qa_pairs: List):
 
     return answer_question
 
-# TODO: NLI, FactCC
+def load_bert_score(model: str):
+    """
+    Load BERTScore model from HuggingFace hub
+
+    Args:
+        model (str): model name to be loaded
+
+    Returns:
+        function: BERTScore score function
+
+    """
+    print("Loading BERTScore Pipeline...")
+
+    scorer = BERTScorer(model_type=model, lang="en", rescale_with_baseline=True)
+    return scorer.score
diff --git a/factsumm/utils/utils.py b/factsumm/utils/utils.py
index 3e59f65..66afb39 100644
--- a/factsumm/utils/utils.py
+++ b/factsumm/utils/utils.py
@@ -15,6 +15,7 @@ class Config:
     QG_MODEL: str = "mrm8488/t5-base-finetuned-question-generation-ap"
     QA_MODEL: str = "deepset/roberta-base-squad2"
     SUMM_MODEL: str = "sshleifer/distilbart-cnn-12-6"
+    BERT_SCORE_MODEL: str = "microsoft/deberta-base-mnli"
 
 
 def grouped_entities(entities: List[Dict]):
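For reference, here is a minimal standalone sketch of how the `BERTScorer` configured by `load_bert_score` is driven. It is illustrative only and not part of the patch: the `article`/`summary` strings are hypothetical placeholders, the model name simply mirrors the new `Config.BERT_SCORE_MODEL` default, and it assumes the `bert_score` package is installed.

```python
# Illustrative sketch (not part of the patch): calls bert_score's BERTScorer
# with the same arguments that load_bert_score() uses.
from bert_score import BERTScorer

scorer = BERTScorer(
    model_type="microsoft/deberta-base-mnli",  # mirrors Config.BERT_SCORE_MODEL
    lang="en",
    rescale_with_baseline=True,
)

# Placeholder texts standing in for factsumm's source article and summary
article = "Lionel Messi is an Argentine footballer who plays as a forward."
summary = "Messi is a footballer from Argentina."

# BERTScorer.score returns one precision/recall/F1 value per candidate
precision, recall, f1 = scorer.score([summary], [article])
print(f"Precision: {precision.mean().item():.4f}")
print(f"Recall: {recall.mean().item():.4f}")
print(f"F1: {f1.mean().item():.4f}")
```

Because `BERTScorer.score` returns per-candidate tensors, `calculate_bert_score` in this patch converts each of the three tensors to a list, drops the trailing `"dummy"` entry, and averages what remains to obtain a single precision, recall, and F1 value.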