-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeneration.py
170 lines (155 loc) · 9.63 KB
/
generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import argparse
import random
import torch
import os
from DataProcessor import DataSet
from get_logger import get_logger
from get_logger import task_uuid
from Vocab import Vocab
from upcrgene import Upcrgene, Engine
from torch.utils.data import DataLoader
from DataLoaderTopic import collate_fn
main_logger = get_logger("main", './log/test.log')
main_logger.info("TASK ID {}".format(task_uuid))
def config():
parser = argparse.ArgumentParser()
parser.add_argument("-test", "--test", action="store_true")
parser.add_argument('--inference', type=bool, default=False, )
parser.add_argument("-use_cuda", "--use_cuda", type=bool, default=False)
parser.add_argument("-gpu", "--gpu", type=str, default='1')
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--processed", action='store_false', )
parser.add_argument("--not_topic_guide", action='store_true', )
parser.add_argument("-dataset", "--dataset", choices=["TG-ReDial", "PersonaChat"], default="TG-ReDial")
parser.add_argument("-global_topic", "--global_topic", action='store_true', )
parser.add_argument("-global_topic_for_action", "--global_topic_for_action", action='store_true', )
parser.add_argument("-top_p", "--top_p", type=float, default=0., )
parser.add_argument("-profile_agg", "--profile_agg", action='store_true', )
parser.add_argument("-otp", "--otp", action='store_true', )
parser.add_argument("-s_profile_add_t", "--s_profile_add_t", action='store_true', )
parser.add_argument("-topic_copynet", "--topic_copynet", action='store_true', )
parser.add_argument("-profile_pred", "--profile_pred", action='store_true', )
parser.add_argument("-profile_select", "--profile_select", action='store_true', )
parser.add_argument("-sharping_profile", "--sharping_profile", action='store_true', )
parser.add_argument("-profile_contrast", "--profile_contrast", action='store_true', )
parser.add_argument("-global_topics", "--global_topics", type=int, default=10, )
parser.add_argument("-gene_add_profile", "--gene_add_profile", action='store_true',
)
parser.add_argument("-gpt2", "--gpt2", action='store_true', )
parser.add_argument("-only_context", "--only_context", action='store_true', )
parser.add_argument("-decoder_strategy", "--decoder_strategy", type=str, choices=['greedy', 'beam_search'], default='greedy', )
parser.add_argument("-history_turn", "--history_turn", type=int, default=100, )
parser.add_argument('--n_layers', type=int, default=6)
parser.add_argument('--n_position', type=int, default=160)
parser.add_argument('--n_inner_vocab', type=int, default=5000)
parser.add_argument('--n_inner_layers', type=int, default=3)
parser.add_argument('--n_inner_position', type=int, default=15)
parser.add_argument('--d_word_vec', type=int, default=512)
parser.add_argument('--n_head', type=int, default=8)
parser.add_argument('--d_k', type=int, default=64)
parser.add_argument('--d_v', type=int, default=64)
parser.add_argument('--pad_idx', type=int, default=2)
parser.add_argument('--d_model', type=int, default=512)
parser.add_argument('--d_inner', type=int, default=2048)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--n_warmup_steps', type=int, default=2000)
parser.add_argument('--scale_emb', type=bool, default=False)
parser.add_argument('--switch_interval', type=int, default=16)
parser.add_argument('--cache_turn', type=int, default=0)
parser.add_argument('--context_all_max_len', type=int, default=1024)
parser.add_argument('--context_max_len', type=int, default=150)
parser.add_argument('--r_max_len', type=int, default=50)
parser.add_argument('--r_beam_max_len', type=int, default=30)
parser.add_argument('--conv_max_len', type=int, default=500)
parser.add_argument('--profile_num', type=int, default=1)
parser.add_argument('--state_num', type=int, default=20)
parser.add_argument('--state_num_redial', type=int, default=20)
parser.add_argument('--pretrain_state_num', type=int, default=50)
parser.add_argument('--all_topic_num', type=int, default=20)
parser.add_argument('--all_topic_num_redial', type=int, default=40)
parser.add_argument('--movie_path_len', type=int, default=3)
parser.add_argument('--tag_num', type=int, default=3)
parser.add_argument('--preference_num', type=int, default=1)
parser.add_argument('--topic_num', type=int, default=2)
parser.add_argument('--action_num', type=int, default=10)
parser.add_argument('--action_num_redial', type=int, default=1)
parser.add_argument('--relation_num', type=int, default=150)
parser.add_argument('--movie_num', type=int, default=200)
parser.add_argument('--state_token', type=int, default=40)
parser.add_argument('--scale_prj', type=bool, default=True)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--task', type=str, default="meddg")
parser.add_argument('--dataset_file', type=str, default="./dataset/{}.zip")
parser.add_argument('--topic_file', type=str, default="./dataset/{}/topic.txt")
parser.add_argument('--topic_movie_file', type=str, default="./dataset/{}/tpmv.txt")
parser.add_argument('--vocab_file', type=str, default="./dataset/{}/tpvocab.txt")
parser.add_argument('--vocab_movie_file', type=str, default="./dataset/{}/tpmvvocab.txt")
parser.add_argument('--no_action_super', type=str, default=None)
parser.add_argument('--max_patience', type=int, default=20)
parser.add_argument('--log_loss_interval', type=int, default=100)
parser.add_argument('--gradient_stack', type=int, default=80)
parser.add_argument('--decay_interval', type=int, default=10000)
parser.add_argument('--decay_rate', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--valid_eval_interval', type=int, default=10000)
parser.add_argument('--test_eval_interval', type=int, default=10000)
parser.add_argument('--force_ckpt_dump', action='store_true')
parser.add_argument('--sub_gen_lambda', type=float, default=0.01)
parser.add_argument('--s_copy_lambda', type=int, default=1)
parser.add_argument('--a_copy_lambda', type=int, default=1)
parser.add_argument('--copy_lambda_mini', type=float, default=0.1)
parser.add_argument('--copy_lambda_decay_steps', type=int, default=10000)
parser.add_argument('--copy_lambda_decay_value', type=float, default=1.0)
parser.add_argument('--init_tau', type=float, default=1.0)
parser.add_argument('--tau_mini', type=float, default=0.1)
parser.add_argument('--tau_decay_total_steps', type=int, default=5000)
parser.add_argument('--tau_decay_rate', type=float, default=0.5)
parser.add_argument('--beam_width', type=int, default=1)
parser.add_argument('--wo_l', action='store_true')
parser.add_argument('--wo_m', action='store_true')
parser.add_argument('--wo_entropy_restrain', action='store_true')
parser.add_argument('--wo_repeat_penalty', action='store_true')
parser.add_argument('--wo_rl', action='store_true')
parser.add_argument('--super_only', action='store_true')
parser.add_argument('--hungary', action='store_true')
parser.add_argument('--super_epoch', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--reg_lambda', type=float, default=5e-3)
parser.add_argument('--BOS_CONTEXT', type=str, default="[s_context]")
parser.add_argument('--EOS_CONTEXT', type=str, default="[/s_context]")
parser.add_argument('--BOS_RESPONSE', type=str, default="[s_response>]")
parser.add_argument('--EOS_RESPONSE', type=str, default="[/s_response]")
parser.add_argument('--BOS_ACTION', type=str, default="[s_action]")
parser.add_argument('--EOS_ACTION', type=str, default="[/s_action]")
parser.add_argument('--PAD_WORD', type=str, default="[PAD]")
parser.add_argument('--SENTENCE_SPLITER', type=str, default="[sent]")
parser.add_argument('--TOPIC_SPLITER', type=str, default="[unused2]")
parser.add_argument('--UNK_WORD', type=str, default="[UNK]")
parser.add_argument('--BOS_PRE', type=str, default="[s_preference]")
parser.add_argument('--EOS_PRE', type=str, default="[/s_preference]")
parser.add_argument('--BOS_PRO', type=str, default="[s_profile]")
parser.add_argument('--EOS_PRO', type=str, default="[/s_profile]")
args = parser.parse_args()
return args
def main():
random.seed(1234)
args = config()
main_logger.info("preparing data")
vocab = Vocab(args)
dataset = DataSet(args=args, vocab=vocab)
train_set, valid_set, test_set, users, user_cont = dataset.get_dialog(task='gene')
train_loader = DataLoader(train_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False)
vocab = Vocab(args)
excrs = Upcrgene(args=args, vocab=vocab, user_cont=user_cont, train_set=train_set)
engine = Engine(args=args, model=excrs, vocab=vocab)
if not os.path.exists('saved_model'):
os.mkdir('saved_model')
if args.test:
engine.model.load_state_dict(torch.load('saved_model/best_generate_model_{}.pkl'.format(args.dataset)), strict=False)
engine.test(test_loader, get_ppl=False, is_show=False)
else:
engine.train(train_loader, valid_loader, test_loader)
if __name__ == '__main__':
main()