Skip to content

Commit

Permalink
[fix] replace aggregation logic with huggingface official logic
Browse files Browse the repository at this point in the history
  • Loading branch information
karter-liner committed Jan 1, 2024
1 parent 9ac7210 commit 84f4cdf
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 196 deletions.
150 changes: 72 additions & 78 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,68 +42,64 @@ pip install .
>>> article = "Lionel Andrés Messi (born 24 June 1987) is an Argentine professional footballer who plays as a forward and captains both Spanish club Barcelona and the Argentina national team. Often considered as the best player in the world and widely regarded as one of the greatest players of all time, Messi has won a record six Ballon d'Or awards, a record six European Golden Shoes, and in 2020 was named to the Ballon d'Or Dream Team."
>>> summary = "Lionel Andrés Messi (born 24 Aug 1997) is an Spanish professional footballer who plays as a forward and captains both Spanish club Barcelona and the Spanish national team."
>>> factsumm(article, summary, verbose=True)
SOURCE Entities
1: [('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona',
'GPE'), ('Argentina', 'GPE')]
2: [('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'),
("the Ballon d'Or Dream Team", 'ORG')]
<Source Entities>
Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', 'ORG'), ('Argentina', 'GPE')]]
Line No.2: [[('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ("Ballon d'Or", 'WORK_OF_ART'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), ("Ballon d'Or Dream Team", 'WORK_OF_ART')]]

SUMMARY Entities
1: [('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]
<Summary Entities>
Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]]

SOURCE Facts
<Source Facts>
('Lionel Andrés Messi', 'per:origin', 'Argentine')
('Spanish', 'per:date_of_birth', '24 June 1987')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Lionel Andrés Messi', 'per:date_of_birth', '24 June 1987')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Argentina')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')

SUMMARY Facts
('Lionel Andrés Messi', 'per:origin', 'Spanish')
<Summary Facts>
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997')
('Spanish', 'per:date_of_birth', '24 Aug 1997')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')
('Lionel Andrés Messi', 'per:origin', 'Spanish')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish')

COMMON Facts
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
<Common Facts>
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')

DIFF Facts
('Lionel Andrés Messi', 'per:origin', 'Spanish')
<Diff Facts>
('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997')
('Spanish', 'per:date_of_birth', '24 Aug 1997')
('Lionel Andrés Messi', 'per:origin', 'Spanish')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish')

Fact Score: 0.5714285714285714

Answers based on SOURCE (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] <unanswerable>
[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

Answers based on SUMMARY (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi
[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona
Answers based on Source (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] <unanswerable>
[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

Answers based on Summary (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi
[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

QAGS Score: 0.3333333333333333

Avg. ROUGE-1: 0.4415584415584415
Avg. ROUGE-2: 0.3287671232876712
Avg. ROUGE-L: 0.4415584415584415

BERTScore Score
Precision: 0.9151781797409058
Recall: 0.9141832590103149
F1: 0.9150083661079407
<BERTScore Score>
Precision: 0.9760397672653198
Recall: 0.9778039455413818
F1: 0.9769210815429688
```

You can use the GPU with the `device`. If you want to use GPU, pass `cuda` (default is `cpu`)
Expand All @@ -126,43 +122,41 @@ From [here](https://arxiv.org/pdf/2104.14839.pdf), you can find various way to s
>>> from factsumm import FactSumm
>>> factsumm = FactSumm()
>>> factsumm.extract_facts(article, summary, verbose=True)
SOURCE Entities
1: [('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona',
'GPE'), ('Argentina', 'GPE')]
2: [('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'),
("the Ballon d'Or Dream Team", 'ORG')]
<Source Entities>
Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 June 1987', 'DATE'), ('Argentine', 'NORP'), ('Spanish', 'NORP'), ('Barcelona', 'ORG'), ('Argentina', 'GPE')]]
Line No.2: [[('one', 'CARDINAL'), ('Messi', 'PERSON'), ('six', 'CARDINAL'), ("Ballon d'Or", 'WORK_OF_ART'), ('European Golden Shoes', 'WORK_OF_ART'), ('2020', 'DATE'), ("Ballon d'Or Dream Team", 'WORK_OF_ART')]]

SUMMARY Entities
1: [('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]
<Summary Entities>
Line No.1: [[('Lionel Andrés Messi', 'PERSON'), ('24 Aug 1997', 'DATE'), ('Spanish', 'NORP'), ('Barcelona', 'ORG')]]

SOURCE Facts
<Source Facts>
('Lionel Andrés Messi', 'per:origin', 'Argentine')
('Spanish', 'per:date_of_birth', '24 June 1987')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Lionel Andrés Messi', 'per:date_of_birth', '24 June 1987')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Argentina')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')

SUMMARY Facts
('Lionel Andrés Messi', 'per:origin', 'Spanish')
<Summary Facts>
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997')
('Spanish', 'per:date_of_birth', '24 Aug 1997')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')
('Lionel Andrés Messi', 'per:origin', 'Spanish')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish')

COMMON Facts
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')
('Spanish', 'org:members', 'Barcelona')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
<Common Facts>
('Barcelona', 'org:top_members/employees', 'Lionel Andrés Messi')
('Lionel Andrés Messi', 'per:employee_of', 'Barcelona')
('Barcelona', 'org:country_of_headquarters', 'Spanish')
('Spanish', 'org:top_members/employees', 'Lionel Andrés Messi')

DIFF Facts
('Lionel Andrés Messi', 'per:origin', 'Spanish')
<Diff Facts>
('Lionel Andrés Messi', 'per:date_of_birth', '24 Aug 1997')
('Spanish', 'per:date_of_birth', '24 Aug 1997')
('Lionel Andrés Messi', 'per:origin', 'Spanish')
('Lionel Andrés Messi', 'per:countries_of_residence', 'Spanish')

Fact Score: 0.5714285714285714
```
Expand All @@ -181,17 +175,17 @@ If you ask questions about the summary and the source document, you will get a s
>>> from factsumm import FactSumm
>>> factsumm = FactSumm()
>>> factsumm.extract_qas(article, summary, verbose=True)
Answers based on SOURCE (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] <unanswerable>
[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

Answers based on SUMMARY (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi
[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona
Answers based on Source (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] <unanswerable>
[Q] When was Lionel Andrés Messi born? [Pred] 24 June 1987
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Argentine
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

Answers based on Summary (Questions are generated from Summary)
[Q] Who is the captain of the Spanish national team? [Pred] Lionel Andrés Messi
[Q] When was Lionel Andrés Messi born? [Pred] 24 Aug 1997
[Q] Lionel Andrés Messi is a professional footballer of what nationality? [Pred] Spanish
[Q] Lionel Messi is a captain of which Spanish club? [Pred] Barcelona

QAGS Score: 0.3333333333333333
```
Expand Down Expand Up @@ -219,10 +213,10 @@ Simple but effective word-level overlap ROUGE score
>>> from factsumm import FactSumm
>>> factsumm = FactSumm()
>>> factsumm.calculate_bert_score(article, summary)
BERTScore Score
Precision: 0.9151781797409058
Recall: 0.9141832590103149
F1: 0.9150083661079407
<BERTScore Score>
Precision: 0.9760397672653198
Recall: 0.9778039455413818
F1: 0.9769210815429688
```

[BERTScore](https://github.com/Tiiiger/bert_score) can be used to calculate the similarity between each source sentence and the summary sentence
Expand Down
74 changes: 50 additions & 24 deletions factsumm/factsumm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from factsumm.utils.module_entity import load_ner, load_rel
from factsumm.utils.module_question import load_qa, load_qg
from factsumm.utils.module_sentence import load_bert_score
from factsumm.utils.utils import Config, qags_score
from factsumm.utils.utils import Config, score_qags

os.environ["TOKENIZERS_PARALLELISM"] = "false"

Expand Down Expand Up @@ -99,7 +99,7 @@ def get_facts(self, lines: List[str], entities: List[List[Dict]]) -> Set:
triples = []

for perm, entity in zip(perms, entities):
entity_key = {ent["word"].replace("▁", ""): ent["entity"] for ent in entity}
entity_key = {ent["word"]: ent["entity_group"] for ent in entity}
facts = self.rel(perm)
filtered_facts = []

Expand All @@ -109,6 +109,9 @@ def get_facts(self, lines: List[str], entities: List[List[Dict]]) -> Set:
head = head.strip()
tail = tail.strip()

if head == tail:
continue

head_entity_type = entity_key.get(head, None)
tail_entity_type = entity_key.get(tail, None)

Expand Down Expand Up @@ -143,7 +146,14 @@ def _segment_sentence(self, text: str) -> List[str]:
def _print_entities(self, mode: str, total_entities: List[List[Dict]]):
logging.info("<%s Entities>", mode.capitalize())
for i, line_entities in enumerate(total_entities):
logging.info("Line No.%s: [%s]", i+1, [(entity["word"].replace("▁", ""), entity["entity"]) for entity in line_entities])
printable_elements = []
dedup = {}
for entity in line_entities:
if entity["word"] not in dedup:
printable_elements.append((entity["word"], entity["entity_group"]))
dedup[entity["word"]] = True

logging.info("Line No.%s: [%s]", i+1, printable_elements)
logging.info("")

def calculate_rouge(
Expand Down Expand Up @@ -319,7 +329,7 @@ def extract_qas(
self._print_qas("source", source_answers)
self._print_qas("summary", summary_answers)

qa_score = qags_score(source_answers, summary_answers)
qa_score = score_qags(source_answers, summary_answers)
logging.info("QAGS Score: %s\n", qa_score)

return qa_score
Expand All @@ -329,7 +339,7 @@ def calculate_bert_score(
source: str,
summary: str,
device: str = "cpu",
) -> Dict[str, float]:
) -> Tuple[float, float, float]:
"""
Calculate BERTScore
Expand All @@ -341,7 +351,7 @@ def calculate_bert_score(
device (str): device info
Returns:
Dict: (Precision, Recall, F1) BERTScore dictionary
Tuple[float]: (Precision, Recall, F1) BERTScore tuple
"""
if self.bert_score is None:
Expand All @@ -350,26 +360,25 @@ def calculate_bert_score(
source_lines = self._segment_sentence(source)
summary_lines = self._segment_sentence(summary)

scores = {
"precision": 0.0,
"recall": 0.0,
"f1": 0.0,
}
total_precision = 0.0
total_recall = 0.0
total_f1 = 0.0

for summary_line in summary_lines:
precision, recall, f1 = self.bert_score([summary_line], [source_lines])
scores["precision"] += precision.item()
scores["recall"] += recall.item()
scores["f1"] += f1.item()

total_precision += precision.item()
total_recall += recall.item()
total_f1 += f1.item()

if len(summary_lines) > 1:
scores["precision"] /= len(summary_lines)
scores["recall"] /= len(summary_lines)
scores["f1"] /= len(summary_lines)
total_precision /= len(summary_lines)
total_recall /= len(summary_lines)
total_f1 /= len(summary_lines)

logging.info("<BERTScore Score>\nPrecision: %s\nRecall: %s\nF1: %s", scores["precision"], scores["recall"], scores["f1"])
logging.info("<BERTScore Score>\nPrecision: %s\nRecall: %s\nF1: %s", total_precision, total_recall, total_f1)

return scores
return total_precision, total_recall, total_f1

def __call__(
self,
Expand All @@ -378,12 +387,14 @@ def __call__(
verbose: bool = False,
device: str = "cpu",
) -> Dict:
if isinstance(sources, str) and isinstance(summaries, str):
if isinstance(sources, str):
sources = [sources]

if isinstance(summaries, str):
summaries = [summaries]

if len(sources) != len(summaries):
raise ValueError("`sources` and `summaries` must have the same number of elements!")
raise ValueError("`sources` and `summaries` should have the same number of elements!")

num_pairs = len(sources)

Expand Down Expand Up @@ -424,11 +435,26 @@ def __call__(
rouge_scores["rouge-2"] += rouge_2
rouge_scores["rouge-l"] += rouge_l

bert_scores = self.calculate_bert_score(source, summary, device)
precision, recall, f1 = self.calculate_bert_score(source, summary, device)
bert_scores["precision"] += precision
bert_scores["recall"] += recall
bert_scores["f1"] += f1

if num_pairs > 1:
fact_scores /= num_pairs
qags_scores /= num_pairs

rouge_scores["rouge-1"] /= num_pairs
rouge_scores["rouge-2"] /= num_pairs
rouge_scores["rouge-l"] /= num_pairs

bert_scores["precision"] /= num_pairs
bert_scores["recall"] /= num_pairs
bert_scores["f1"] /= num_pairs

return {
"fact_score": fact_scores / num_pairs,
"qa_score": qags_scores / num_pairs,
"fact_score": fact_scores,
"qa_score": qags_scores,
"rouge": rouge_scores,
"bert_score": bert_scores,
}
Loading

0 comments on commit 84f4cdf

Please sign in to comment.