-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtnews.py
167 lines (139 loc) · 5.07 KB
/
tnews.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# -*- encoding: utf-8 -*-
"""
-------------------------------------------------
File Name: tnews.py
Description : TNEWS news-category classification script. Fine-tunes an
    MlmBertEncoder on the data split selected by --c, keeps the epoch with the
    best accuracy on that split's dev file, evaluates a RetrieverClassifier on
    the dev files of the other splits, and writes predictions for the
    official test set.
Author : Wings DH
Time: 6/16/21 10:40 PM
-------------------------------------------------
Change Activity:
6/16/21: Create
-------------------------------------------------
"""
import sys
import os
from absl import app, flags
from utils.seed import set_seed
# Seeding happens before the modeling modules below are imported, presumably so
# that any framework RNG state they create at import time is already seeded.
set_seed()
# Presumably silences TensorFlow's C++ INFO/WARNING logging. It only takes
# effect if set before TensorFlow is loaded, which is assumed to happen inside
# the modeling.* imports that follow — TODO confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from modeling.classifier import LabelData
from modeling.mlm_encoder import MlmBertEncoder
from modeling.retriever_classifier import RetrieverClassifier
from utils.cls_train import eval_model, dump_result
from utils.data_utils import load_data, load_test_data
# --c selects the data split. get_data_fp() reads train_{c}.json / dev_{c}.json
# and treats the other dev_{0..4}.json files as a held-out self-test set.
# main() also maps a value containing "few_all" to "all" in the output filename.
flags.DEFINE_string('c', '0', 'index of tnews dataset')
FLAGS = flags.FLAGS
def infer(test_data, classifier, label_map=None):
    """Predict a label code for every record in ``test_data``.

    Each record's ``'sentence'`` is classified with ``classifier.classify``
    (whose first return value is the predicted label name). The sentence is
    dropped from the output and replaced by a ``'label'`` entry holding the
    mapped code. New dicts are built, so the caller's records are left intact.

    Args:
        test_data: iterable of dicts, each containing a ``'sentence'`` key.
        classifier: object exposing ``classify(sentence) -> (label, _)``.
        label_map: mapping from predicted label name to output code. When
            omitted, the module-level ``label_2_num`` is used.

    Returns:
        A list of new dicts with every key of the input record except
        ``'sentence'``, followed by ``'label'``.

    Raises:
        KeyError: if a record lacks ``'sentence'`` or the classifier returns a
            label that is missing from ``label_map``.
    """
    if label_map is None:
        # Resolved at call time because label_2_num is defined below this function.
        label_map = label_2_num
    results = []
    for record in test_data:
        sentence = record['sentence']
        label, _ = classifier.classify(sentence)
        # Copy rather than pop() so the caller's records are not mutated.
        out = {key: value for key, value in record.items() if key != 'sentence'}
        out['label'] = label_map[label]
        results.append(out)
    return results
# Label name -> two-character Chinese description. main() passes this to
# MlmBertEncoder; presumably these are the words the MLM is asked to predict in
# the prompt's placeholder slot. Insertion order matters: code_2_label is
# enumerated from it.
label_2_desc = {
    'news_tech': '科技',
    'news_entertainment': '娱乐',
    'news_car': '汽车',
    'news_travel': '旅游',
    'news_finance': '财经',
    'news_edu': '教育',
    'news_world': '国际',
    'news_house': '房产',
    'news_game': '电竞',
    'news_military': '军事',
    'news_story': '故事',
    'news_culture': '文化',
    'news_sports': '体育',
    'news_agriculture': '农业',
    'news_stock': '股票',
}

# Label name -> numeric label id (kept as a string) written by infer() into the
# prediction file.
label_2_num = dict(
    news_car='107',
    news_world='113',
    news_stock='114',
    news_story='100',
    news_culture='101',
    news_edu='108',
    news_house='106',
    news_game='116',
    news_travel='112',
    news_sports='103',
    news_military='110',
    news_finance='104',
    news_tech='109',
    news_entertainment='102',
    news_agriculture='115',
)

# Class index -> label name, following label_2_desc's order (0 = 'news_tech').
code_2_label = dict(enumerate(label_2_desc))
def get_data_fp(use_index, data_dir='dataset/tnews', n_folds=5):
    """Return the file paths used for one TNEWS data split.

    Args:
        use_index: split identifier, e.g. ``'0'``..``'4'``. It is compared as a
            string, so ``2`` and ``'2'`` select the same split. A value that
            matches no fold index (e.g. ``'few_all'``) holds out every dev file.
        data_dir: directory holding the json files (joined with ``'/'``).
        n_folds: number of ``dev_{i}.json`` fold files in ``data_dir``.

    Returns:
        ``(train_fp, dev_fp, my_test_fp, test_fp)``. ``my_test_fp`` lists the
        dev files of every other fold and is used as a held-out self-test set.
    """
    # Normalise once: comparing str(ind) against an int never matches, which
    # would leak the split's own dev file into the held-out list.
    use_index = str(use_index)
    train_fp = f'{data_dir}/train_{use_index}.json'
    dev_fp = f'{data_dir}/dev_{use_index}.json'
    test_fp = f'{data_dir}/test.json'
    my_test_fp = [f'{data_dir}/dev_{ind}.json'
                  for ind in range(n_folds) if str(ind) != use_index]
    return train_fp, dev_fp, my_test_fp, test_fp
def main(_):
    """Fine-tune, evaluate and predict for the TNEWS split selected by ``--c``.

    Steps:
      1. Load the split's train and dev data.
      2. Build an MlmBertEncoder and a RetrieverClassifier over the training
         examples.
      3. Fine-tune for 10 epochs, saving whenever dev accuracy improves.
      4. Reload the best checkpoint and evaluate on the other splits' dev files.
      5. Dump predictions for the official test set.

    Args:
        _: argv list passed in by ``absl.app.run`` (unused).
    """
    # Load data.
    train_fp, dev_fp, my_test_fp, test_fp = get_data_fp(FLAGS.c)
    key_label = 'label_desc'
    key_sentence = 'sentence'
    train_data = load_data(train_fp, key_sentence, key_label)
    data = [LabelData(text, label) for text, label in train_data]
    dev_data = load_data(dev_fp, key_sentence, key_label)

    # Initialise the encoder.
    model_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12'
    weight_path = '../temp_tnews.weights'
    # Prompt template. prefix[7:9] is the two-character placeholder '啊啊', and
    # mask_ind [7, 8] are those 0-based character positions. Every description
    # in label_2_desc is also two characters long. Whether the encoder shifts
    # these positions (e.g. for [CLS]) is not visible here.
    prefix = '以下是一则关于啊啊的新闻。'
    mask_ind = [7, 8]
    n_top = 3
    encoder = MlmBertEncoder(model_path, weight_path, train_data, dev_data, prefix, mask_ind,
                             label_2_desc, 8, merge=MlmBertEncoder.CONCAT, norm=False,
                             lr=1e-4, policy='attention')
    classifier = RetrieverClassifier(encoder, data, n_top=n_top)
    print('Eval model')
    rst = eval_model(classifier, [dev_fp], key_sentence, key_label)
    print(f'{train_fp} + {dev_fp} -> {rst}')

    # Fine-tune, checkpointing whenever dev accuracy improves.
    # Start below any reachable score so the first epoch is always saved.
    # encoder.save()/load() presumably use weight_path, a fixed file shared by
    # every run. With best_acc = 0, a run where no epoch scored above 0 would
    # make encoder.load() read a missing or stale checkpoint.
    best_acc = float('-inf')
    for epoch in range(10):
        print(f'Training epoch {epoch}')
        encoder.train(1)
        # Rebuild the classifier after training; presumably it re-encodes
        # `data` with the updated encoder on construction.
        classifier = RetrieverClassifier(encoder, data, n_top=n_top)
        print('Eval model')
        rst = eval_model(classifier, [dev_fp], key_sentence, key_label)
        if rst > best_acc:
            encoder.save()
            best_acc = rst
            print(f'Save for best {best_acc}')
        print(f'{train_fp} + {dev_fp} -> {rst}')

    # Reload the best checkpoint.
    encoder.load()
    classifier = RetrieverClassifier(encoder, data, n_top=n_top)

    # Self-test on the held-out dev files of the other splits.
    rst = eval_model(classifier, my_test_fp, key_sentence, key_label)
    print(f'{train_fp} + {dev_fp} -> {rst}')

    # Official test set.
    test_data = load_test_data(test_fp)
    test_data = infer(test_data, classifier)
    outp_fn = f'tnewsf_predict_{FLAGS.c.replace("few_all", "all")}.json'
    dump_result(outp_fn, test_data)
# Script entry point: absl's app.run parses the command-line flags (including
# --c) and then calls main with the remaining argv.
if __name__ == "__main__":
    app.run(main)