Skip to content

Commit b6e13f8

Browse files
committed
Added utility to compute BERT f1 score
1 parent a2ff835 commit b6e13f8

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

get_bert_f1.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import json
2+
3+
from data_util.evaluate import f1_score
4+
5+
PREDICTIONS_FILE_NAME = './bert_test/predictions.json'
6+
DEV_DATA_FILE_NAME = './data/dev-v1.1.json'
7+
8+
with open(DEV_DATA_FILE_NAME) as json_file:
9+
sets = json.load(json_file)['data']
10+
11+
id_answer_text_dict = {}
12+
print('Loading dev data')
13+
for topic_set in sets:
14+
for paragraph in topic_set['paragraphs']:
15+
for qas in paragraph['qas']:
16+
id_answer_text_dict[qas['id']] = qas['answers'][0]['text']
17+
18+
print('Loading predictions')
19+
with open(PREDICTIONS_FILE_NAME) as json_file:
20+
predictions = json.load(json_file)
21+
22+
count = 0
23+
f1 = 0
24+
for qas_id, prediction in predictions.items():
25+
ground_truth = id_answer_text_dict[qas_id]
26+
f1 += f1_score(prediction, ground_truth)
27+
count += 1
28+
29+
f1 = f1 / (count + 1)
30+
31+
# Print the f1 score to console
32+
print('Computed f1: ', f1)

0 commit comments

Comments
 (0)