-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpersona_eval.py
72 lines (58 loc) · 1.67 KB
/
persona_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
import math
def docs(w, history_list):
    """Return the number of entries in *history_list* that contain *w*.

    Containment is tested with the ``in`` operator on each entry, so an
    entry counts at most once no matter how many times *w* occurs in it.
    """
    return sum(1 for history in history_list if w in history)
def gen_idf_dict(history_list):
    """Build a word -> inverse-document-frequency table for *history_list*.

    Each entry of *history_list* is treated as one document (presumably a
    tokenized persona/history sentence -- confirm against the caller).
    For every word ``w`` that occurs in any document the result holds
    ``log(N / df(w))``, where ``N`` is ``len(history_list)`` and ``df(w)``
    is the number of documents containing ``w``.

    Returns:
        dict mapping each distinct word to its IDF, keyed in first-seen
        order. Empty when *history_list* is empty.
    """
    total = len(history_list)

    # Count document frequencies in a single pass instead of rescanning
    # every history for every distinct word.
    doc_freq = {}
    for history in history_list:
        # dict.fromkeys de-duplicates within one document while keeping
        # the first-seen order, so the result's key order is stable.
        for word in dict.fromkeys(history):
            doc_freq[word] = doc_freq.get(word, 0) + 1

    return {
        word: math.log(total * 1.0 / count)
        for word, count in doc_freq.items()
    }
def cal_s_for_each_history(r, h, idf_dict):
    """Sum the IDF weights of the distinct items of *r* that also occur in *h*.

    Each matching item contributes ``idf_dict[item]`` exactly once, even
    if it is repeated in *r*. A matching item that is missing from
    *idf_dict* raises ``KeyError``.
    """
    score = 0
    counted = set()
    for token in r:
        # Skip anything outside the history or already accounted for.
        if token not in h or token in counted:
            continue
        score += idf_dict[token]
        counted.add(token)
    return score
def cal_p_cover(src_generate, all_personas):
    """Average, over all samples, the best IDF-weighted persona overlap.

    *src_generate* and *all_personas* are walked in lockstep with ``zip``
    (extra items in the longer one are ignored). For each
    ``(result, personas)`` pair an IDF table is built from ``personas``;
    the result is scored against every persona entry with
    ``cal_s_for_each_history`` and the maximum score is kept.

    Returns:
        The mean of the per-sample maxima as a float. A sample with no
        persona entries scores 0, and 0.0 is returned when there are no
        samples at all, rather than raising.
    """
    s_sum = 0
    line_cnt = 0
    for result, personas in zip(src_generate, all_personas):
        idf_dict = gen_idf_dict(personas)
        # default=0: with no persona entries there is nothing to cover,
        # so the sample contributes a zero score instead of raising.
        s_max = max(
            (
                cal_s_for_each_history(result, persona, idf_dict)
                for persona in personas
            ),
            default=0,
        )
        s_sum += s_max
        line_cnt += 1

    # Guard against division by zero on empty input.
    if line_cnt == 0:
        return 0.0
    return (s_sum + 0.0) / line_cnt
def cal_f1(result, personas):
    """Return the set-level F1 between *result* and the union of *personas*.

    All entries of *personas* are merged into one set of distinct items;
    *result* is likewise reduced to its distinct items. Precision is the
    overlap divided by the size of the result set, recall is the overlap
    divided by the size of the persona set. Returns 0 when either set is
    empty or when they share nothing.
    """
    persona_items = set()
    for persona in personas:
        persona_items.update(persona)
    result_items = set(result)

    if not persona_items or not result_items:
        return 0

    overlap = len(persona_items & result_items)
    if overlap == 0:
        # Precision and recall are both zero; avoid dividing by zero.
        return 0

    precision = overlap / len(result_items)
    recall = overlap / len(persona_items)
    return (2 * precision * recall) / (precision + recall)
def cal_p_f1(src_generate, all_personas):
    """Average ``cal_f1`` over paired results and persona lists.

    *src_generate* and *all_personas* are walked in lockstep with ``zip``
    (extra items in the longer one are ignored).

    Returns:
        The mean per-sample F1 as a float, or 0.0 when there are no
        samples (instead of raising ``ZeroDivisionError``).
    """
    s_sum = 0
    line_cnt = 0
    for result, personas in zip(src_generate, all_personas):
        s_sum += cal_f1(result, personas)
        line_cnt += 1

    # Guard against division by zero on empty input.
    if line_cnt == 0:
        return 0.0
    return (s_sum + 0.0) / line_cnt