-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrecommendation.py
125 lines (118 loc) · 6.79 KB
/
recommendation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import random
from DataProcessor import DataSet
from get_logger import get_logger
from get_logger import task_uuid
from Vocab import Vocab
from upcrrec import Upcrrec,Engine
main_logger = get_logger("main", './log/test.log')
main_logger.info("TASK ID {}".format(task_uuid))
def config():
parser = argparse.ArgumentParser()
parser.add_argument('--inference',type=bool,default=False,)
parser.add_argument("--processed", type=bool, default=True, )
parser.add_argument('--n_layers', type=int, default=6)
parser.add_argument('--n_position', type=int, default=160)
parser.add_argument('--n_inner_vocab', type=int, default=5000)
parser.add_argument('--n_inner_layers', type=int, default=3)
parser.add_argument('--n_inner_position', type=int, default=15)
parser.add_argument('--d_word_vec', type=int, default=512)
parser.add_argument('--n_head', type=int, default=8)
parser.add_argument('--d_k', type=int, default=64)
parser.add_argument('--d_v', type=int, default=64)
parser.add_argument('--pad_idx', type=int, default=2)
parser.add_argument('--d_model', type=int, default=512)
parser.add_argument('--d_inner', type=int, default=2048)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--n_warmup_steps', type=int, default=2000)
parser.add_argument('--scale_emb', type=bool, default=False)
parser.add_argument('--switch_interval', type=int, default=16)
parser.add_argument('--cache_turn', type=int, default=0)
parser.add_argument('--context_all_max_len', type=int, default=1024)
parser.add_argument('--context_max_len', type=int, default=150)
parser.add_argument('--r_max_len', type=int, default=50)
parser.add_argument('--r_beam_max_len', type=int, default=30)
parser.add_argument('--profile_num', type=int, default=1)
parser.add_argument('--state_num', type=int, default=20)
parser.add_argument('--state_num_redial', type=int, default=20)
parser.add_argument('--pretrain_state_num', type=int, default=50)
parser.add_argument('--all_topic_num', type=int, default=20)
parser.add_argument('--all_topic_num_redial', type=int, default=40)
parser.add_argument('--movie_path_len', type=int, default=3)
parser.add_argument('--tag_num', type=int, default=3)
parser.add_argument('--preference_num', type=int, default=1)
parser.add_argument('--topic_num', type=int, default=2)
parser.add_argument('--action_num', type=int, default=10)
parser.add_argument('--action_num_redial', type=int, default=1)
parser.add_argument('--relation_num', type=int, default=150)
parser.add_argument('--movie_num', type=int, default=200)
parser.add_argument('--state_token', type=int, default=40)
parser.add_argument('--scale_prj', type=bool, default=True)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--task', type=str, default="meddg")
parser.add_argument('--dataset_file', type=str, default="./dataset/{}.zip")
parser.add_argument('--topic_file', type=str, default="./dataset/{}/topic.txt")
parser.add_argument('--topic_movie_file', type=str, default="./dataset/{}/tpmv.txt")
parser.add_argument('--vocab_file', type=str, default="./dataset/{}/tpvocab.txt")
parser.add_argument('--vocab_movie_file', type=str, default="./dataset/{}/tpmvvocab.txt")
parser.add_argument('--no_action_super', type=str, default=None)
parser.add_argument('--max_patience', type=int, default=20)
parser.add_argument('--log_loss_interval', type=int, default=100)
parser.add_argument('--gradient_stack', type=int, default=80)
parser.add_argument('--decay_interval', type=int, default=10000)
parser.add_argument('--decay_rate', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--valid_eval_interval', type=int, default=10000)
parser.add_argument('--test_eval_interval', type=int, default=10000)
parser.add_argument('--force_ckpt_dump', action='store_true')
parser.add_argument('--sub_gen_lambda', type=float, default=0.01)
parser.add_argument('--s_copy_lambda', type=int, default=1)
parser.add_argument('--a_copy_lambda', type=int, default=1)
parser.add_argument('--copy_lambda_mini', type=float, default=0.1)
parser.add_argument('--copy_lambda_decay_steps', type=int, default=10000)
parser.add_argument('--copy_lambda_decay_value', type=float, default=1.0)
parser.add_argument('--init_tau', type=float, default=1.0)
parser.add_argument('--tau_mini', type=float, default=0.1)
parser.add_argument('--tau_decay_total_steps', type=int, default=5000)
parser.add_argument('--tau_decay_rate', type=float, default=0.5)
parser.add_argument('--beam_width', type=int, default=1)
parser.add_argument('--wo_l', action='store_true')
parser.add_argument('--wo_m', action='store_true')
parser.add_argument('--wo_entropy_restrain', action='store_true')
parser.add_argument('--wo_repeat_penalty', action='store_true')
parser.add_argument('--wo_rl', action='store_true')
parser.add_argument('--super_only', action='store_true')
parser.add_argument('--hungary', action='store_true')
parser.add_argument('--super_epoch', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--reg_lambda', type=float, default=5e-3)
parser.add_argument('--BOS_CONTEXT', type=str, default="[s_context]")
parser.add_argument('--EOS_CONTEXT', type=str, default="[/s_context]")
parser.add_argument('--BOS_RESPONSE', type=str, default="[s_response>]")
parser.add_argument('--EOS_RESPONSE', type=str, default="[/s_response]")
parser.add_argument('--BOS_ACTION', type=str, default="[s_action]")
parser.add_argument('--EOS_ACTION', type=str, default="[/s_action]")
parser.add_argument('--PAD_WORD', type=str, default="[PAD]")
parser.add_argument('--SENTENCE_SPLITER', type=str, default="[sent]")
parser.add_argument('--TOPIC_SPLITER', type=str, default="[unused2]")
parser.add_argument('--UNK_WORD', type=str, default="[UNK]")
parser.add_argument('--BOS_PRE', type=str, default="[s_preference]")
parser.add_argument('--EOS_PRE', type=str, default="[/s_preference]")
parser.add_argument('--BOS_PRO', type=str, default="[s_profile]")
parser.add_argument('--EOS_PRO', type=str, default="[/s_profile]")
args = parser.parse_args()
return args
def main():
random.seed(1234)
args = config()
main_logger.info("preparing data")
vocab = Vocab(args)
dataset = DataSet(args=args, vocab=vocab)
train, valid, test, users, user_cont = dataset.get_dialog(task='rec')
vocab = Vocab(task='rec')
random.shuffle(train)
excrs = Upcrrec(vocab=vocab,user_cont=user_cont)
engine = Engine(model=excrs,vocab=vocab)
engine.train(train,test)
if __name__ == '__main__':
main()