-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy patheprstmt.py
116 lines (92 loc) · 3.56 KB
/
eprstmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- encoding: utf-8 -*-
"""
-------------------------------------------------
File Name: eprstmt.py
Description :
Author : Wings DH
Time: 6/16/21 10:40 PM
-------------------------------------------------
Change Activity:
6/16/21: Create
-------------------------------------------------
"""
import sys
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from modeling.classifier import LabelData
from modeling.mlm_encoder import MlmBertEncoder
from modeling.retriever_classifier import RetrieverClassifier
from utils.cls_train import eval_model, dump_result
from absl import app, flags
sys.path.append('../')
sys.path.append('./')
from utils.data_utils import load_data, load_test_data
from utils.seed import set_seed
# Call the project seed helper with its default arguments
# (presumably fixes RNG seeds for reproducibility - confirm in utils.seed).
set_seed()
# String flag `c` (default '0', help text "index of dataset"); main() reads it
# as FLAGS.c to pick the train/dev split and to name the weight/output files.
flags.DEFINE_string('c', '0', 'index of dataset')
FLAGS = flags.FLAGS
def infer(test_data, classifier):
    """Label every record of ``test_data`` in place and return it.

    For each record the ``'sentence'`` entry is removed, passed to
    ``classifier.classify`` and the first element of the returned pair is
    stored under ``'label'``.

    Args:
        test_data: Iterable of dict-like records, each holding a ``'sentence'`` key.
        classifier: Object whose ``classify(sentence)`` returns a 2-item result;
            only the first item (the label) is kept.

    Returns:
        The same ``test_data`` object, with its records mutated.
    """
    for record in test_data:
        # The sentence is popped, so the record keeps only its other fields
        # plus the predicted label.
        predicted, _ = classifier.classify(record.pop('sentence'))
        record['label'] = predicted
    return test_data
# Maps each dataset label to a Chinese descriptor word
# ('满意' = "satisfied", '失望' = "disappointed"); main() hands this mapping
# to MlmBertEncoder together with the prompt prefix.
label_2_desc = {'Positive': '满意',
'Negative': '失望'}
def get_data_fp(use_index):
    """Build the dataset file paths for one EPRSTMT few-shot split.

    Args:
        use_index: Identifier of the split to train on (e.g. ``'0'``..``'4'``).
            It is normalized with ``str()`` so that an int such as ``0`` is
            treated exactly like ``'0'``.

    Returns:
        A 4-tuple ``(train_fp, dev_fp, my_test_fp, test_fp)`` where
        ``my_test_fp`` is the list of dev files of the splits ``0..4`` other
        than ``use_index`` (used as a held-out self-test set) and ``test_fp``
        is the shared test file.
    """
    # Normalize first: without this an int index never equals the string
    # indices below, and the chosen dev split would leak into my_test_fp.
    use_index = str(use_index)
    train_fp = f'dataset/eprstmt/train_{use_index}.json'
    dev_fp = f'dataset/eprstmt/dev_{use_index}.json'
    test_fp = 'dataset/eprstmt/test.json'
    my_test_fp = [
        f'dataset/eprstmt/dev_{ind}.json'
        for ind in range(5)
        if str(ind) != use_index
    ]
    return train_fp, dev_fp, my_test_fp, test_fp
def main(_):
    """Train the EPRSTMT retriever classifier and dump test predictions.

    Loads the train/dev split selected by ``FLAGS.c``, builds an
    ``MlmBertEncoder`` plus ``RetrieverClassifier`` and evaluates them once
    before training. It then runs 20 training rounds, saving the encoder
    whenever the dev result is at least the best saved so far (only from
    round index 6 onwards). Finally it reloads the saved encoder, evaluates
    on the remaining dev splits and writes predictions for the test file.

    Args:
        _: Unused; the positional argument passed in by ``app.run``.
    """
    # Load data
    train_fp, dev_fp, my_test_fp, test_fp = get_data_fp(FLAGS.c)
    key_label, key_sentence = 'label', 'sentence'

    def evaluate(clf, paths):
        # Score `clf` on `paths` and print the result next to the split in use.
        rst = eval_model(clf, paths, key_sentence, key_label)
        print(f'{train_fp} + {dev_fp} -> {rst}')
        return rst

    train_data = load_data(train_fp, key_sentence, key_label)
    dev_data = load_data(dev_fp, key_sentence, key_label)
    data = [LabelData(text, label) for text, label in train_data]

    # Initialize the encoder
    n_top = 3
    model_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12'
    weight_path = f'../temp_eprstmt_{FLAGS.c}.weights'
    # 0-based positions 9 and 10 of the prefix are the two placeholder
    # characters that mask_ind points at.
    prefix = '对以下产品感到十分锟锟,'
    mask_ind = [9, 10]
    encoder = MlmBertEncoder(
        model_path,
        weight_path,
        train_data,
        dev_data,
        prefix,
        mask_ind,
        label_2_desc,
        8,  # meaning not visible here (presumably batch size - TODO confirm)
        merge=MlmBertEncoder.CONCAT,
        norm=False,
        lr=1e-4,
        policy='attention',
    )

    # Baseline evaluation before any fine-tuning
    classifier = RetrieverClassifier(encoder, data, n_top=n_top)
    evaluate(classifier, [dev_fp])

    # Fine-tune
    best_acc = 0
    for epoch in range(20):
        print(f'Training epoch {epoch}')
        encoder.train(1)

        # Build a new classifier on the encoder after this training round
        classifier = RetrieverClassifier(encoder, data, n_top=n_top)
        print('Eval model')
        rst = evaluate(classifier, [dev_fp])

        # Save only after the first six rounds, and only when the dev result
        # is at least as good as the best one saved so far.
        if not (rst >= best_acc and epoch > 5):
            continue
        encoder.save()
        best_acc = rst
        print(f'Save for best {best_acc}')

    # Load the final (saved) model
    encoder.load()
    classifier = RetrieverClassifier(encoder, data, n_top=n_top)

    # Evaluate on the self-made test set (dev files of the other splits)
    evaluate(classifier, my_test_fp)

    # Official test set
    test_data = infer(load_test_data(test_fp), classifier)
    outp_fn = f'eprstmt_predict_{FLAGS.c.replace("few_all", "all")}.json'
    dump_result(outp_fn, test_data)
# Script entry point: hand `main` to absl's app.run so the flags defined
# above are available as FLAGS inside main().
if __name__ == "__main__":
    app.run(main)