diff --git a/ActionTopic.py b/ActionTopic.py new file mode 100644 index 0000000..6c6209b --- /dev/null +++ b/ActionTopic.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +from tools import Tools + +class Action(nn.Module): + def __init__(self, args, p_encoder, a_decoder, hidden_size, main_encoder, + n_topic_vocab, bos_idx, max_len, glo2loc, loc2glo, vocab): + super(Action, self).__init__() + self.args = args + self.p_encoder = p_encoder + self.a_decoder = a_decoder + self.main_encoder = main_encoder + self.vocab = vocab + self.hidden_size = hidden_size + self.n_topic_vocab = n_topic_vocab + self.bos_idx = bos_idx + self.max_len = max_len + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) + + def forward(self, m, l, context, context_len, related_topics, related_topics_len, ar_gth, ar_gth_len, tp_path, tp_path_len, user_id, user_embed, mode, encoder_embed=None, decoder_embed=None, profile_prob=None, tp_path_embed=None): + if l is not None: + l_mask = l.new_ones(l.size(0), 1, l.size(1)) + l_hidden = self.p_encoder(l, l_mask, embed=encoder_embed) + if tp_path_embed==None: + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask, embed=encoder_embed) + else: + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path_embed, tp_mask, embed=encoder_embed, embed_input=True) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, context_mask) + if related_topics is not None: + related_topics = one_hot_scatter(related_topics, self.n_topic_vocab) + related_topics_mask = Tools.get_mask_via_len(related_topics_len, self.args.relation_num) + related_topic_hidden = self.p_encoder(related_topics, related_topics_mask, embed=encoder_embed) + if m is not None: + m_mask = m.new_ones(m.size(0), 1, m.size(1)) + m_hidden = self.p_encoder(m, m_mask, embed=encoder_embed) + + if m is None: + src_hidden = torch.cat([tp_hidden, context_hidden, related_topic_hidden], 1) + src_mask = torch.cat([tp_mask, context_mask, related_topics_mask], 2) + elif related_topics is None: + src_hidden = torch.cat([m_hidden, l_hidden, tp_hidden, context_hidden], 1) + src_mask = torch.cat([m_mask, l_mask, tp_mask, context_mask], 2) + elif l is None: + src_hidden = torch.cat([m_hidden, tp_hidden, context_hidden, related_topic_hidden], 1) + src_mask = torch.cat([m_mask, tp_mask, context_mask, related_topics_mask], 2) + else: + src_hidden = torch.cat([m_hidden, l_hidden, tp_hidden, context_hidden, related_topic_hidden], 1) + src_mask = torch.cat([m_mask, l_mask, tp_mask, context_mask, related_topics_mask], 2) + + probs = None + action_mask = Tools.get_mask_via_len(ar_gth_len, self.args.action_num) + if mode == 'train': + for i in range(0, self.args.action_num, 2): + seq_gth = ar_gth[:, 0: i + 1] + ar_mask = action_mask[:, :, 0:i + 1] + dec_output = Tools._single_decode(seq_gth.detach(), src_hidden, src_mask, self.a_decoder, ar_mask) + prob = self.proj(dec_output, src_hidden, src_mask, m, l, context, tp_path, related_topics, embed=decoder_embed, profile_prob=profile_prob) + if i == 0: + probs = prob + else: + probs = torch.cat([probs, prob], 1) + return probs + else: + seq_gen = None + for i in range(0, self.args.action_num, 2): + if i == 0: + seq_gen = ar_gth[:, 0:i + 1] 
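# Aside (commentary, not part of the patch): action sequences alternate
# [act_type, topic, act_type, topic, ...], which is why both decode loops step
# by 2. The train branch above teacher-forces every even-length ground-truth
# prefix and predicts the topic at the following position; in the inference
# branch below, the act-type token at each even position is still copied from
# ar_gth, so only the topic slots are actually generated (via the argmax
# feedback at the end of the loop).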
+ else: + seq_gen = torch.cat([seq_gen, ar_gth[:, i:i + 1]], 1) + ar_mask = action_mask[:, :, 0:i + 1] + dec_output = Tools._single_decode(seq_gen.detach(), src_hidden, src_mask, self.a_decoder, ar_mask) + single_step_prob = self.proj(dec_output, src_hidden, src_mask, m, l, context, tp_path, related_topics, embed=decoder_embed, profile_prob=profile_prob) + if i == 0: + probs = single_step_prob + else: + probs = torch.cat([probs, single_step_prob], 1) + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen = torch.cat([seq_gen, single_step_word], 1) + return seq_gen, probs + + def proj(self, dec_out, src_hidden, src_mask, pv_m, l, context, tp, related_topics, embed=None, profile_prob=None): + B, L_a = dec_out.size(0), dec_out.size(1) + gen_logit = self.gen_prob(dec_out, embed) + copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_a, -1), -1e9 if copy_logit.dtype==torch.float32 else -1e4) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + if self.args.not_copynet: + probs = gen_prob + elif pv_m is None: + copy_context_prob = probs[:, :, self.n_topic_vocab + self.args.state_num: self.n_topic_vocab + self.args.state_num + self.args.context_max_len] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2, + index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), + src=copy_context_prob) + copy_tp_prob = probs[:, :, self.n_topic_vocab : self.n_topic_vocab + self.args.state_num] + copy_tp_prob = torch.bmm(copy_tp_prob, tp) + copy_relation_prob = probs[:, :, self.n_topic_vocab + self.args.state_num + self.args.context_max_len:] + copy_relation_prob = torch.bmm(copy_relation_prob, related_topics) + probs = gen_prob + copy_tp_prob + copy_context_prob + copy_relation_prob + elif related_topics is None: + copy_m_prob = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.preference_num] + copy_m_prob = torch.bmm(copy_m_prob, pv_m) + copy_l_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num:self.n_topic_vocab + self.args.preference_num + self.args.profile_num] + copy_l_prob = torch.bmm(copy_l_prob, l) + copy_context_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num: self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num + self.args.context_max_len] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2, + index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), + src=copy_context_prob) + copy_tp_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num: self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num] + copy_tp_prob = torch.bmm(copy_tp_prob, tp) + probs = gen_prob + copy_l_prob + copy_tp_prob + copy_context_prob + copy_m_prob + elif l is None: + copy_m_prob = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.preference_num] + copy_m_prob = torch.bmm(copy_m_prob, pv_m) + copy_context_prob = probs[:, :, self.n_topic_vocab + 
self.args.preference_num + self.args.state_num: self.n_topic_vocab + self.args.preference_num + self.args.state_num + self.args.context_max_len]
+            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
+            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
+            copy_context_prob = copy_context_temp.scatter_add(dim=2,
+                                                              index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1),
+                                                              src=copy_context_prob)
+            copy_tp_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num: self.n_topic_vocab + self.args.preference_num + self.args.state_num]
+            copy_tp_prob = torch.bmm(copy_tp_prob, tp)
+            copy_relation_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.state_num + self.args.context_max_len:]
+            copy_relation_prob = torch.bmm(copy_relation_prob, related_topics)
+            probs = gen_prob + copy_tp_prob + copy_context_prob + copy_relation_prob + copy_m_prob
+        else:
+            copy_m_prob = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.preference_num]
+            copy_m_prob = torch.bmm(copy_m_prob, pv_m)
+            copy_l_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num:self.n_topic_vocab + self.args.preference_num + self.args.profile_num]
+            copy_l_prob = torch.bmm(copy_l_prob, l)
+            copy_context_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num: self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num + self.args.context_max_len]
+            transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context)
+            copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
+            copy_context_prob = copy_context_temp.scatter_add(dim=2, index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1), src=copy_context_prob)
+            copy_tp_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num: self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num]
+            copy_tp_prob = torch.bmm(copy_tp_prob, tp)
+            copy_relation_prob = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.state_num + self.args.context_max_len:]
+            copy_relation_prob = torch.bmm(copy_relation_prob, related_topics)
+            probs = gen_prob + copy_l_prob + copy_tp_prob + copy_context_prob + copy_relation_prob + copy_m_prob
+        if self.args.topic_copynet:
+            probs = probs + profile_prob.unsqueeze(1)
+        return probs
+
+    def gen_prob(self, dec_output, embed=None):
+        if embed is None:
+            prob = self.gen_proj(dec_output)
+        else:
+            assert embed.size(0) == self.gen_proj[0].out_features
+            prob = dec_output.matmul(embed.T)
+        return prob
+
+
+def one_hot_scatter(indice, num_classes, dtype=torch.float):
+    indice_shape = list(indice.shape)
+    placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype)
+    v = 1 if dtype == torch.long else 1.0
+    placeholder.scatter_(-1, indice.unsqueeze(-1), v)
+    return placeholder
diff --git a/Action_all.py b/Action_all.py
new file mode 100644
index 0000000..909155f
--- /dev/null
+++ b/Action_all.py
@@ -0,0 +1,102 @@
+import json
+import torch.nn.functional as F
+import ipdb
+import torch
+import torch.nn as nn
+from tools import Tools
+
+class Action(nn.Module):
+    def __init__(self, args, p_encoder, a_decoder, hidden_size, main_encoder,
+                 m_encoder, n_topic_vocab, bos_idx, max_len, glo2loc, loc2glo, vocab):
+        super(Action, self).__init__()
+        self.args = args
+        self.p_encoder = p_encoder
+        self.a_decoder = a_decoder
+        self.m_encoder
= m_encoder + self.main_encoder = main_encoder + self.vocab = vocab + self.hidden_size = hidden_size + self.n_topic_vocab = n_topic_vocab + self.bos_idx = bos_idx + self.max_len = max_len + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.gen_proj = nn.Linear(self.hidden_size,self.n_topic_vocab) + self.topic2movie = nn.Linear(2583,self.n_topic_vocab-2583) + self.mask = torch.zeros(self.args.batch_size,1,self.n_topic_vocab).cuda() + self.mask[:,:,2583:] = 1 + self.pad = torch.zeros(self.args.batch_size,1,2583).cuda() + def forward(self,m,l,context,context_len,ar_gth,ar_gth_len, + tp_path,tp_path_len,related_movies,related_movies_len,mode): + if mode == 'test': + m = one_hot_scatter(m, self.n_topic_vocab) + l = one_hot_scatter(l, self.n_topic_vocab) + m_mask = m.new_ones(m.size(0),1,m.size(1)) + m_hidden = self.p_encoder(m,m_mask) + l_mask = l.new_ones(l.size(0),1,l.size(1)) + l_hidden = self.p_encoder(l,l_mask) + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_path_hidden = self.p_encoder(tp_path, tp_mask) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, context_mask) + related_movies = one_hot_scatter(related_movies,self.n_topic_vocab) + related_movies_mask = Tools.get_mask_via_len(related_movies_len,self.args.movie_num) + related_movies_hidden = self.p_encoder(related_movies,related_movies_mask) + src_hidden = torch.cat([m_hidden,l_hidden,context_hidden,tp_path_hidden,related_movies_hidden],1) + src_mask = torch.cat([m_mask,l_mask,context_mask,tp_mask,related_movies_mask],2) + action_mask = Tools.get_mask_via_len(ar_gth_len,self.args.action_num) + if mode == 'train': + seq_gth = ar_gth[:,[0]] + ar_mask = action_mask[:,:,[0]] + dec_output = Tools._single_decode(seq_gth.detach(), src_hidden, src_mask, self.a_decoder, ar_mask) + prob = self.proj(dec_out=dec_output, src_hidden=src_hidden, src_mask=src_mask, + tp=tp_path, m=m, l=l, context=context,related_movies=related_movies) + return prob + else: + seq_gen = ar_gth[:,[0]] + ar_mask = action_mask[:, :, [0]] + dec_output = Tools._single_decode(seq_gen.detach(), src_hidden, src_mask, self.a_decoder,ar_mask) + prob = self.proj(dec_out=dec_output, src_hidden=src_hidden, src_mask=src_mask, + tp=tp_path, m=m, l=l, context=context,related_movies=related_movies) + word = torch.argmax(prob, -1) + return word, prob + + def proj(self,dec_out, src_hidden,src_mask, tp, m, l, context,related_movies ): + B,L_a =dec_out.size(0), dec_out.size(1) + gen_logit = self.gen_proj(dec_out) + copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_a, -1), -1e9) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + copy_m = probs[:, :, self.n_topic_vocab : + self.n_topic_vocab + self.args.preference_num] + copy_m_prob = torch.bmm(copy_m, m) + copy_l = probs[:, :, self.n_topic_vocab+ self.args.preference_num:self.n_topic_vocab+ self.args.preference_num+ self.args.profile_num] + copy_l_prob = torch.bmm(copy_l, l) + copy_context_prob = probs[:, :, self.n_topic_vocab+ self.args.preference_num+ self.args.profile_num:self.n_topic_vocab+ self.args.preference_num+ self.args.profile_num+ self.args.context_max_len] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context) 
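# Aside (illustrative, not part of the patch): the gather above and the
# scatter_add below remap copy probabilities from context *positions* to topic
# *ids*: glo2loc translates each context word id into its local topic id, and
# scatter_add accumulates the position probabilities onto those ids. A minimal
# self-contained demo of the same idiom, with toy tensors:
#
#     import torch
#     B, L_a, L_ctx, V = 1, 1, 3, 5
#     pos_prob = torch.tensor([[[0.2, 0.3, 0.5]]])   # probs over 3 context slots
#     ctx_topic_ids = torch.tensor([[4, 2, 4]])      # topic id of each slot
#     out = torch.zeros(B, L_a, V).scatter_add(
#         2, ctx_topic_ids.unsqueeze(1).expand(-1, L_a, -1), pos_prob)
#     # out[0, 0] -> [0.0, 0.0, 0.3, 0.0, 0.7]: both id-4 slots are summed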
+        copy_context_temp = copy_context_prob.new_zeros(B, L_a, self.n_topic_vocab)
+        copy_context_prob = copy_context_temp.scatter_add(dim=2,
+                                                          index=transfer_context_word.unsqueeze(1).expand(-1, L_a, -1),
+                                                          src=copy_context_prob)
+        copy_tp = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.context_max_len:
+                        self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.context_max_len + self.args.state_num]
+        copy_tp_prob = torch.bmm(copy_tp, tp)
+        copy_relation = probs[:, :, self.n_topic_vocab + self.args.preference_num + self.args.profile_num + self.args.context_max_len + self.args.state_num:]
+        copy_relation_prob = torch.bmm(copy_relation, related_movies)
+        probs = gen_prob + copy_m_prob + copy_l_prob + copy_context_prob + copy_tp_prob + copy_relation_prob
+        # restrict the distribution to movie entries, then renormalize
+        probs = probs.mul(self.mask)
+        norm = torch.sum(probs, -1, keepdim=True)
+        probs = probs / norm
+        return probs
+
+def one_hot_scatter(indice, num_classes, dtype=torch.float):
+    indice_shape = list(indice.shape)
+    placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype)
+    v = 1 if dtype == torch.long else 1.0
+    placeholder.scatter_(-1, indice.unsqueeze(-1), v)
+    return placeholder
\ No newline at end of file
diff --git a/Bleu.py b/Bleu.py
new file mode 100644
index 0000000..89e841f
--- /dev/null
+++ b/Bleu.py
@@ -0,0 +1,57 @@
+
+import argparse
+import sys
+import numpy as np
+from tqdm import tqdm
+from nltk.translate.bleu_score import sentence_bleu
+import ipdb
+import re
+import jieba
+
+# Placeholder standing in for movie mentions when sentences are split and
+# re-joined below; the concrete special token appears blank in this patch, so
+# the value here is an assumed stand-in.
+MOVIE_TOKEN = '<movie>'
+
+def bleu_cal(sen1, tar1):
+    bleu1 = sentence_bleu([tar1], sen1, weights=(1, 0, 0, 0))
+    bleu2 = sentence_bleu([tar1], sen1, weights=(0.5, 0.5, 0, 0))
+    bleu3 = sentence_bleu([tar1], sen1, weights=(0.33, 0.33, 0.33, 0))
+    bleu4 = sentence_bleu([tar1], sen1, weights=(0.25, 0.25, 0.25, 0.25))
+    return bleu1, bleu2, bleu3, bleu4
+
+def bleu(args, tokenized_gen, tokenized_tar):
+    bleu1_sum, bleu2_sum, bleu3_sum, bleu4_sum, count = 0, 0, 0, 0, 0
+    for sen, tar in zip(tokenized_gen, tokenized_tar):
+        # truncate hypothesis and reference at the end-of-response token
+        for j, word in enumerate(sen):
+            if word == args.EOS_RESPONSE:
+                sen = sen[:j]
+                break
+        tar = tar[1:]
+        for k, word in enumerate(tar):
+            if word == args.EOS_RESPONSE:
+                tar = tar[:k]
+                break
+        full_sen_gen = ''.join(sen)
+        full_sen_gth = ''.join(tar)
+        # re-segment with jieba, keeping movie placeholders as single tokens
+        sen_split_by_movie = full_sen_gen.split(MOVIE_TOKEN)
+        sen_1 = []
+        for i, sen_split in enumerate(sen_split_by_movie):
+            for segment in jieba.cut(sen_split):
+                sen_1.append(segment)
+            if i != len(sen_split_by_movie) - 1:
+                sen_1.append(MOVIE_TOKEN)
+        tar_split_by_movie = full_sen_gth.split(MOVIE_TOKEN)
+        tar_1 = []
+        for i, tar_split in enumerate(tar_split_by_movie):
+            for segment in jieba.cut(tar_split):
+                tar_1.append(segment)
+            if i != len(tar_split_by_movie) - 1:
+                tar_1.append(MOVIE_TOKEN)
+        bleu1, bleu2, bleu3, bleu4 = bleu_cal(sen_1, tar_1)
+        bleu1_sum += bleu1
+        bleu2_sum += bleu2
+        bleu3_sum += bleu3
+        bleu4_sum += bleu4
+        count += 1
+    return bleu1_sum / count, bleu2_sum / count, bleu3_sum / count, bleu4_sum / count
\ No newline at end of file
diff --git a/DataLoaderRec.py b/DataLoaderRec.py
new file mode 100644
index 0000000..2684786
--- /dev/null
+++ b/DataLoaderRec.py
@@ -0,0 +1,110 @@
+import json
+from DataProcessor import clip_pad_sentence, clip_pad_context
+import torch
+import pandas as pd
+import re
+
+class DataLoaderRec():
+    def __init__(self, args, dataset, vocab):
+        self.args = args
+        self.dataset = dataset
+        self.vocab = vocab
+        self.batch_size = self.args.batch_size
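# Aside: DataLoaderRec streams conversations through a fixed pool of
# batch_size slots (history_convs below). __next__ refills any empty slot with
# the next processed session, pops the first remaining segment of every slot
# to form a batch (the pool shrinks as sessions run out), and raises
# StopIteration once the dataset is exhausted and all slots have drained.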
self.history_convs = [ [] for _ in range(self.batch_size)] + self.number_workers = self.args.worker_num + self.sunset=False + self.conv_index = 0 + self.name2id = self.get_name2id() + self.topic_graph = self.get_topic_graph() + + def __iter__(self): + return self + + def __next__(self): + for i in range(len(self.history_convs)): + if len(self.history_convs[i]) == 0: + if not self.sunset: + processed_session = self.load_processed_session() + if processed_session is not None: + self.history_convs[i] = processed_session + self.history_convs = [ conv for conv in self.history_convs if len(conv)>0 ] + if len(self.history_convs) == 0: + print("stop") + raise StopIteration + batch_convs = [ conv[0] for conv in self.history_convs ] + self.history_convs = [ conv[1:] for conv in self.history_convs ] + nn_inputs = [] + for idx, batch_data in enumerate(zip(*batch_convs)): + nn_inputs.append(torch.tensor(data=batch_data, dtype=torch.long).cuda()) + return nn_inputs + + def load_processed_session(self): + if self.conv_index >= len(self.dataset): + self.sunset = True + return None + conv = self.dataset[self.conv_index] + processed_session = self.process(conv) + self.conv_index += 1 + return processed_session + + def process(self,conversation): + session_segs = [] + id = int(conversation[0]) + contexts = conversation[-2] + all_topics = conversation[-1] + utterances = conversation[1:-2] + uttr_len = len(utterances) + all_topic, all_topic_len = clip_pad_sentence(all_topics, self.args.all_topic_num, self.args.PAD_WORD) + all_topic = self.vocab.topic2index(all_topic) + for i in range(2,uttr_len,2): + response = utterances[i] + resp = response[0] + resp, resp_len = clip_pad_sentence(resp, self.args.r_max_len, self.args.PAD_WORD, sos=self.args.BOS_RESPONSE, eos=self.args.EOS_RESPONSE) + action_R = response[2] + if action_R == []: + continue + a_R, a_R_len = clip_pad_sentence(action_R, self.args.action_num, self.args.PAD_WORD) + context = contexts[:i] + context, context_len = clip_pad_context(context, self.args.context_max_len, self.args.PAD_WORD, self.args.SENTENCE_SPLITER) + state_R = utterances[i-2][1] + state_R, state_R_len = clip_pad_sentence(state_R, self.args.state_num, self.args.PAD_WORD) + Seeker = utterances[i-1] + seek = Seeker[0] + seek, seek_len = clip_pad_sentence(seek, self.args.r_max_len, self.args.PAD_WORD, sos=self.args.BOS_CONTEXT,eos=self.args.EOS_CONTEXT, save_prefix=False) + state_U = Seeker[1] + state_U, state_U_len = clip_pad_sentence(state_U, self.args.state_num, self.args.PAD_WORD) + action_U = Seeker[1][-1] + pv_action = action_U + related_topics = self.get_related_movies(pv_action,self.args.movie_num) + related_topics, related_topics_len = clip_pad_sentence(related_topics, self.args.movie_num, self.args.PAD_WORD) + context_idx = self.vocab.word2index(context) + seek_idx = self.vocab.word2index(seek) + resp_idx = self.vocab.word2index(resp) + state_R = self.vocab.topic2index(state_R) + a_R = self.vocab.topic2index(a_R) + state_U = self.vocab.topic2index(state_U) + related_topics = self.vocab.topic2index(related_topics) + session_segs.append([id,all_topic, all_topic_len,context_idx,context_len, + state_U, state_U_len,a_R, a_R_len,seek_idx,seek_len, + resp_idx,resp_len,state_R,state_R_len,related_topics,related_topics_len,1]) + session_segs[0][-1] = 0 + return session_segs + + def get_topic_graph(self): + with open('TG-ReDial/dataset/TG-ReDial/graph_rec.json') as f: + topic_graph = json.load(f) + return topic_graph + + def get_related_movies(self, action_U, relation_num): + return 
self.topic_graph[action_U][0:relation_num] + + def get_name2id(self): + name2id = {} + movie_id = pd.read_csv('TG-ReDial/dataset/TG-ReDial/movie_with_mentions.csv', usecols=[1, 2, 3], encoding='gbk') + movies = movie_id.values.tolist() + for movie in movies: + movie[1] = re.sub('\(\d*\)', '', movie[1]) + movie[1] = re.sub('\(上\)', '', movie[1]) + movie[1] = re.sub('\(下\)', '', movie[1]) + movie[1] = re.sub('\(美版\)', '', movie[1]) + name2id[movie[1]] = str(movie[0]) + return name2id diff --git a/DataLoaderResp.py b/DataLoaderResp.py new file mode 100644 index 0000000..378267a --- /dev/null +++ b/DataLoaderResp.py @@ -0,0 +1,145 @@ +import json +import csv +from tqdm import tqdm +from enum import Enum +from DataProcessor import clip_pad_sentence,clip_pad_context +from Vocab import Vocab +import torch +import pickle +import json as js +import numpy as np + +class DataLoaderResp(): + def __init__(self, args, dataset, vocab): + self.args = args + self.dataset = dataset + self.vocab = vocab + self.final_topic = json.load(open('./dataset/{}/final_topic.json'.format(self.args.dataset), 'r')) + self.processed_data = [] + self.processed_session() + self.user2character_metric = self.get_user2character_metric() + + def __iter__(self): + return self + + def __len__(self): + return len(self.processed_data) + + def __getitem__(self, idx): + return self.processed_data[idx] + + def processed_session(self): + for conv in tqdm(self.dataset): + if len(conv) > 5: + processed_session = self.process(conv) + self.processed_data.extend(processed_session) + + def process(self, conversation): + session_segs = [] + id = int(conversation[0]) + contexts = conversation[-3] + conv_id = conversation[-1] + utterances = conversation[1:-3] + uttr_len = len(utterances) + pv_action = [] + if self.args.dataset == 'TG-ReDial': + skip_len = 2 + elif self.args.dataset == 'PersonaChat': + skip_len = 1 + if self.args.gpt2: + for i in range(len(contexts)): + contexts[i] = [i for i in ''.join(contexts[i])] + for i in range(2, uttr_len, skip_len): + if self.args.dataset == 'PersonaChat' and (utterances[i - 1][2][-1] == '[UNK]' or utterances[i][2][-1] == '[UNK]'): + continue + response = utterances[i] + action_R = response[2] + if action_R == []: + continue + resp = response[0] + if self.args.gpt2: + resp = [i for i in ''.join(resp)] + resp, resp_len = clip_pad_sentence(resp, self.args.r_max_len, self.args.PAD_WORD, sos='[CLS]', eos='[SEP]') + resp = self.vocab.tokenizer.convert_tokens_to_ids(resp) + context = contexts[:i] + context_all, context_all_len = clip_pad_context(context, self.args.context_all_max_len, self.args.PAD_WORD, '[SEP]') + context, context_len = clip_pad_context(context, self.args.context_max_len, self.args.PAD_WORD, '[SEP]', pad_suffix=False) + else: + resp, resp_len = clip_pad_sentence(resp, self.args.r_max_len, self.args.PAD_WORD, sos=self.args.BOS_RESPONSE, eos=self.args.EOS_RESPONSE) + resp = self.vocab.word2index(resp) + context = contexts[:i] + context_all, context_all_len = clip_pad_context(context, self.args.context_all_max_len, self.args.PAD_WORD, self.args.SENTENCE_SPLITER) + context, context_len = clip_pad_context(context, self.args.context_max_len, self.args.PAD_WORD, self.args.SENTENCE_SPLITER) + final_topic_len = len(self.final_topic[str(conv_id) + '/' + str(i+1)]) + if self.args.not_topic_guide: + state_U = response[1][:-final_topic_len] + else: + state_U = response[1] + + topic2context = [] + k = 0 + for topic in state_U[:-final_topic_len]: + if topic in conversation[k+1][-2]: + 
topic2context.append(k) + else: + while k <= len(conversation) - 1: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + break + k += 1 + for _ in range(final_topic_len): + topic2context.append(i - 1) + if max(topic2context) >= i: + state_U = response[1] + + topic2context = [] + k = 0 + for topic in state_U[:-1]: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + else: + while k <= len(conversation) - 1: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + break + k += 1 + topic2context.append(i - 1) + assert len(state_U) == len(topic2context) + if len(topic2context) >= self.args.state_num: + topic2context = topic2context[-self.args.state_num:] + else: + topic2context = topic2context + [0] * (self.args.state_num - len(topic2context)) + state_U, state_U_len = clip_pad_sentence(state_U, self.args.state_num, self.args.PAD_WORD) + Seeker = utterances[i - 1] + action_U = Seeker[2] + if self.args.gpt2: + context_all_idx = self.vocab.tokenizer.convert_tokens_to_ids(context_all) + context_idx = self.vocab.tokenizer.convert_tokens_to_ids(context) + else: + context_all_idx = self.vocab.word2index(context_all) + context_idx = self.vocab.word2index(context) + state_U = self.vocab.topic2index(state_U) + a_R, a_R_len = clip_pad_sentence(action_R, self.args.action_num, self.args.PAD_WORD) + a_R = self.vocab.topic2index(a_R) + session_segs.append([id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, a_R, a_R_len, resp, resp_len, topic2context, 0]) + if len(session_segs) != 0: + session_segs[0][-1] = 0 + return session_segs + + def get_user2character_metric(self): + + print('create user2character metric') + max_character_num = max([len(i) for i in self.vocab.user_to_Sentidx.values()]) + user2character_metric = np.zeros((self.vocab.n_user + 1, max_character_num), dtype=int) + for user, sent_list in tqdm(self.vocab.user_to_Sentidx.items()): + user_idx = int(user) + for idx, sent_idx in enumerate(sent_list): + user2character_metric[user_idx, idx] = sent_idx + return user2character_metric + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder \ No newline at end of file diff --git a/DataLoaderTopic.py b/DataLoaderTopic.py new file mode 100644 index 0000000..8f81ee9 --- /dev/null +++ b/DataLoaderTopic.py @@ -0,0 +1,460 @@ +import copy +import json +import os +import numpy as np +from enum import Enum +from tqdm import tqdm +from DataProcessor import clip_pad_sentence, clip_pad_context +from Vocab import Vocab +import torch +from torch.utils.data import Dataset +import pickle +import requests +import json as js + +def collate_fn(batch_convs): + nn_inputs = [] + for idx, batch_data in enumerate(zip(*batch_convs)): + try: + nn_inputs.append(torch.tensor(data=batch_data, dtype=torch.long)) + except: + print('here') + return nn_inputs + +class DataLoaderTopic(): + def __init__(self, args, dataset, vocab): + self.args = args + self.dataset = dataset + self.vocab = vocab + self.final_topic = json.load(open('./dataset/{}/final_topic.json'.format(self.args.dataset), 'r')) + self.topic_graph = self.get_topic_graph(dataset) + self.topic_co_graph = self.get_topic_co_graph(dataset) + + if self.args.use_ckg != 0: + self.ckg = self.get_ckg(args.dataset) + + self.user2character_metric = 
self.get_user2character_metric() + + self.processed_data = [] + self.processed_session() + + + + def __iter__(self): + return self + + def __len__(self): + return len(self.processed_data) + + def __getitem__(self, idx): + return self.processed_data[idx] + + def get_ckg(self, dataset_name): + ckg_path = "./dataset/{}/ckg.pkl".format(dataset_name) + if os.path.exists(ckg_path): + dump_data = pickle.load(open(ckg_path, 'rb')) + ckg = dump_data[0] + relation2idx = dump_data[1] + other_entity2idx = dump_data[2] + else: + ckg = dict() + if dataset_name == "TG-ReDial": + language = 'zh' + elif dataset_name == "PersonaChat": + language = 'en' + other_entity2idx = copy.deepcopy(self.vocab.topic2idx) + relation2idx = {'self_loop': 0} + for topic in tqdm(self.vocab.topic2idx.keys()): + + obj = requests.get('http://222.20.75.16:8884/c/' + language + '/' + topic + '?limit=2000').json() + if topic not in ckg.keys(): + ckg[topic] = set() + if 'error' in obj.keys(): + print(topic) + continue + for edge in obj['edges']: + start = edge['start']['label'] + end = edge['end']['label'] + relation = edge['rel']['label'] + weight = edge['weight'] + + if weight >= 1 and \ + (start in self.vocab.topic2idx.keys() or start in self.vocab.word_list) and \ + (end in self.vocab.topic2idx.keys() or end in self.vocab.word_list): + if start not in ckg.keys(): + ckg[start] = set() + ckg[start].add((relation, end)) + if relation not in relation2idx.keys(): + relation2idx[relation] = len(relation2idx) + relation2idx[relation+'_inv'] = len(relation2idx) + if start not in other_entity2idx.keys(): + other_entity2idx[start] = len(other_entity2idx) + if end not in other_entity2idx.keys(): + other_entity2idx[end] = len(other_entity2idx) + + dump_data = [ckg, relation2idx, other_entity2idx] + with open(ckg_path, 'wb') as keywords_conept_file: + pickle.dump(dump_data, keywords_conept_file) + + self.relation2idx = relation2idx + self.other_entity2idx = other_entity2idx + + ckg_trans = set() + for head, edges in ckg.items(): + head_idx = other_entity2idx[head] + ckg_trans.add((head_idx, relation2idx['self_loop'], head_idx)) + for relation, tail in edges: + relation_id = relation2idx[relation] + relation_inv_id = relation2idx[relation+'_inv'] + tail_id = other_entity2idx[tail] + ckg_trans.add((head_idx, relation_id,tail_id)) + ckg_trans.add((tail_id, relation_inv_id,head_idx)) + edge_set = [[head for (head, relation, tail) in list(ckg_trans)], + [tail for (head, relation, tail) in list(ckg_trans)]] + edge_type = [relation for (head, relation, tail) in list(ckg_trans)] + + self.edge_set = edge_set + self.edge_type = edge_type + + + def statistic(self): + + + user2all_topic = {} + user2session_topic = {} + for conv in self.dataset: + user_idx = int(conv[0]) + topic_list = conv[-1] + if user_idx not in user2all_topic.keys(): + user2all_topic[user_idx] = set() + user2session_topic[user_idx] = [] + user2all_topic[user_idx].update(topic_list) + user2session_topic[user_idx].append(topic_list) + user_cross_session = {'Y':0, 'N':0} + topic_cross_session = {'Y':0, 'N':0} + for user_idx, all_topic in user2all_topic.items(): + if len(user2session_topic[user_idx]) == 1: + continue + user_cross_flag = False + for topic_idx in all_topic: + topic_appear = [topic_idx in cur_topics for cur_topics in user2session_topic[user_idx]] + if sum(topic_appear) > 1: + topic_cross_session['Y'] += 1 + user_cross_flag = True + else: + topic_cross_session['N'] += 1 + if user_cross_flag: + user_cross_session['Y'] += 1 + else: + user_cross_session['N'] += 1 + all_user = 
(user_cross_session['Y'] + user_cross_session['N'])
+        user_cross_session['Y'] = user_cross_session['Y'] / all_user
+        user_cross_session['N'] = user_cross_session['N'] / all_user
+        all_topic = (topic_cross_session['Y'] + topic_cross_session['N'])
+        topic_cross_session['Y'] = topic_cross_session['Y'] / all_topic
+        topic_cross_session['N'] = topic_cross_session['N'] / all_topic
+        print('user_cross_session', user_cross_session)
+        print('topic_cross_session', topic_cross_session)
+
+        topic2user = {}
+        for user_idx, topic_set in user2all_topic.items():
+            for topic in topic_set:
+                if topic not in topic2user.keys():
+                    topic2user[topic] = set()
+                topic2user[topic].add(user_idx)
+        chars_cross_user = {'Y': 0, 'N': 0}
+        for topic, user_set in topic2user.items():
+            if len(user_set) == 1:
+                continue
+            user_set = list(user_set)
+            for i in range(len(user_set)):
+                chars_i = set(self.vocab.user_to_Sentidx[str(user_set[i])])
+                for j in range(i, len(user_set)):
+                    chars_j = set(self.vocab.user_to_Sentidx[str(user_set[j])])
+                    if not chars_i.isdisjoint(chars_j):
+                        chars_cross_user['Y'] += 1
+                    else:
+                        chars_cross_user['N'] += 1
+        all_user = chars_cross_user['Y'] + chars_cross_user['N']
+        chars_cross_user['Y'] = chars_cross_user['Y'] / all_user
+        chars_cross_user['N'] = chars_cross_user['N'] / all_user
+        print('chars_cross_user', chars_cross_user)
+
+    def processed_session(self):
+        for conv in tqdm(self.dataset):
+            if len(conv) > 5:
+                processed_session = self.process(conv)
+                self.processed_data.extend(processed_session)
+
+    def process(self, conversation):
+        session_segs = []
+        user_id = int(conversation[0])
+        contexts = conversation[-3]
+        all_topics = conversation[-2]
+        conv_id = conversation[-1]
+        all_topic, all_topic_len = clip_pad_sentence(all_topics, self.args.all_topic_num, self.args.PAD_WORD)
+        all_topic = self.vocab.topic2index(all_topic)
+        utterances = conversation[1:-3]
+        uttr_len = len(utterances)
+        pv_action = []
+        if self.args.dataset == 'TG-ReDial':
+            skip_len = 2
+        elif self.args.dataset == 'PersonaChat':
+            skip_len = 1
+        for i in range(2, uttr_len, skip_len):
+            if self.args.dataset == 'PersonaChat' and (utterances[i - 1][2][-1] == '[UNK]' or utterances[i][2][-1] == '[UNK]'):
+                continue
+            response = utterances[i]
+            action_R = response[2]
+            if action_R == []:
+                continue
+            resp = response[0]
+            resp, resp_len = clip_pad_sentence(resp, self.args.r_max_len, self.args.PAD_WORD, sos=self.args.BOS_RESPONSE, eos=self.args.EOS_RESPONSE)
+            begin_turn = max(0, i - self.args.history_turn)
+            context = contexts[begin_turn:i]
+            context_all, context_all_len = clip_pad_context(context, self.args.context_all_max_len, self.args.PAD_WORD, self.args.SENTENCE_SPLITER)
+            context, context_len = clip_pad_context(context, self.args.context_max_len, self.args.PAD_WORD, self.args.SENTENCE_SPLITER)
+            final_topic_len = len(self.final_topic[str(conv_id) + '/' + str(i + 1)])
+            if self.args.not_topic_guide:
+                state_U = response[1][:-final_topic_len]
+
+                # align each state topic with the turn in which it appears
+                topic2context = []
+                k = 0
+                for topic in state_U[:-final_topic_len]:
+                    if topic in conversation[k + 1][-2]:
+                        topic2context.append(k)
+                    else:
+                        while k <= len(conversation) - 1:
+                            if topic in conversation[k + 1][-2]:
+                                topic2context.append(k)
+                                break
+                            k += 1
+                for _ in range(final_topic_len):
+                    topic2context.append(i - 1)
+                if max(topic2context) >= i:
+                    state_U = response[1]
+
+                    topic2context = []
+                    k = 0
+                    for topic in state_U:
+                        if topic in conversation[k + 1][-2]:
+                            topic2context.append(k)
+                        else:
+                            while k <= len(conversation) - 1:
+                                if topic in conversation[k + 1][-2]:
topic2context.append(k) + break + k += 1 + assert len(state_U) == len(topic2context) + topic2context = [topic2context[i]-begin_turn for i in range(len(topic2context)) if topic2context[i] >= begin_turn] + state_U = [state_U[i] for i in range(len(topic2context)) if topic2context[i] >= begin_turn] + if len(topic2context) >= self.args.state_num: + topic2context = topic2context[-self.args.state_num:] + else: + topic2context = topic2context + [0] * (self.args.state_num - len(topic2context)) + else: + state_U = response[1] + + topic2context = [] + k = 0 + for topic in state_U[:-final_topic_len]: + if topic in conversation[k+1][-2]: + topic2context.append(k) + else: + while k <= len(conversation) - 1: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + break + k += 1 + for _ in range(final_topic_len): + topic2context.append(i - 1) + if max(topic2context) >= i: + state_U = response[1] + + topic2context = [] + k = 0 + for topic in state_U[:-1]: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + else: + while k <= len(conversation) - 1: + if topic in conversation[k + 1][-2]: + topic2context.append(k) + break + k += 1 + topic2context.append(i - 1) + assert len(state_U) == len(topic2context) + topic2context = [topic2context[i]-begin_turn for i in range(len(topic2context)) if topic2context[i] >= begin_turn] + state_U = [state_U[i] for i in range(len(topic2context)) if topic2context[i] >= begin_turn] + if len(topic2context) >= self.args.state_num: + topic2context = topic2context[-self.args.state_num:] + else: + topic2context = topic2context + [0] * (self.args.state_num - len(topic2context)) + state_U, state_U_len = clip_pad_sentence(state_U, self.args.state_num, self.args.PAD_WORD) + Seeker = utterances[i - 1] + action_U = Seeker[2] + if action_U != []: + pv_action = action_U + related_topics = self.get_related_topics(pv_action, self.args.relation_num, action_R) + related_topics, related_topics_len = clip_pad_sentence(related_topics, self.args.relation_num, self.args.PAD_WORD) + context_all_idx = self.vocab.word2index(context_all) + context_idx = self.vocab.word2index(context) + response_idx = self.vocab.word2index(resp) + state_U = self.vocab.topic2index(state_U) + related_topics = self.vocab.topic2index(related_topics) + a_R, a_R_len = clip_pad_sentence(action_R, self.args.action_num, self.args.PAD_WORD) + a_R = self.vocab.topic2index(a_R) + + + + session_segs.append([user_id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, related_topics, related_topics_len, a_R, a_R_len, all_topic, all_topic_len, topic2context, 1, response_idx]) + if len(session_segs) != 0: + session_segs[0][-2] = 0 + return session_segs + + + def get_topic_graph(self, dataset): + + if os.path.exists('./dataset/{}/topic_graph.json'.format(self.args.dataset)): + print('load topic graph') + with open('./dataset/{}/topic_graph.json'.format(self.args.dataset), 'r') as f: + topic_graph = json.load(f) + else: + print('get topic graph') + topic_graph = {} + all_topic_list = list(self.vocab.topic2idx.keys())[0:1] + list(self.vocab.topic2idx.keys())[3:4] + list(self.vocab.topic2idx.keys())[11:] + for topic in all_topic_list: + topic_graph[topic] = set() + for data in dataset: + for idx in range(2, len(data) - 2): + if len(data[idx - 1][2]) != 0 and len(data[idx][2]) != 0: + last_topic = data[idx - 1][2][-1] + cur_topic = data[idx][2][-1] + if last_topic is None: + last_topic = '[UNK]' + if cur_topic is None: + cur_topic = '[UNK]' + topic_graph[last_topic].add(cur_topic) + for 
key, values in topic_graph.items(): + topic_graph[key] = list(values) + with open('./dataset/{}/topic_graph.json'.format(self.args.dataset), 'w') as f: + json.dump(topic_graph, f) + return topic_graph + + def get_topic_co_graph(self, dataset): + + if os.path.exists('./dataset/{}/topic_co_graph.json'.format(self.args.dataset)): + print('load topic graph') + with open('./dataset/{}/topic_co_graph.json'.format(self.args.dataset), 'r') as f: + topic_graph = json.load(f) + else: + print('get topic graph') + topic_graph = {} + all_topic_list = list(self.vocab.topic2idx.keys())[0:1] + list(self.vocab.topic2idx.keys())[3:4] + list(self.vocab.topic2idx.keys())[11:] + for topic in all_topic_list: + topic_graph[topic] = set() + for data in dataset: + topic_list = data[-2] + for i in range(len(topic_list)): + for j in range(len(topic_list)): + topic_graph[topic_list[i]].add(topic_list[j]) + for key, values in topic_graph.items(): + topic_graph[key] = list(values) + with open('./dataset/{}/topic_co_graph.json'.format(self.args.dataset), 'w') as f: + json.dump(topic_graph, f) + return topic_graph + + def get_related_topics(self, action_U, relation_num, action_R): + + gth = [] + for i in range(0, len(action_R), 2): + gth.append(action_R[i+1]) + related_topics = [] + a_len = len(action_U) + for k in range(0, a_len, 2): + action_type = action_U[k] + topic = action_U[k+1] + if '拒绝' in action_type: + assert a_len > 1 + continue + related_topic = self.topic_graph[topic][0:int(2*relation_num/a_len)] + related_topics.extend(related_topic) + return related_topics + + def get_cut_graph(self): + + if not os.path.exists('./dataset/{}/processed_data/'.format(self.args.dataset)): + os.mkdir('./dataset/{}/processed_data/'.format(self.args.dataset)) + cut_trans_path = './dataset/{}/processed_data/cut_trans.pkl'.format(self.args.dataset) + if os.path.exists(cut_trans_path): + print('load c-u-t graph') + cut_graph = pickle.load(open(cut_trans_path, 'rb')) + edge_set = cut_graph['edge_set'] + edge_type = cut_graph['edge_type'] + else: + print('create c-u-t graph') + + cut_trans = set() + + for conv in self.dataset: + user_idx = int(conv[0]) + self.vocab.topic_len + topic_list = conv[-2] + for topic in topic_list: + topic_idx = self.vocab.topic2idx[topic] + cut_trans.add((user_idx, self.vocab.relation2idx['user2topic'], topic_idx)) + cut_trans.add((topic_idx, self.vocab.relation2idx['user2topic_inv'], user_idx)) + + + + + + + + edge_set = [[head for (head, relation, tail) in list(cut_trans)], [tail for (head, relation, tail) in list(cut_trans)]] + edge_type = [relation for (head, relation, tail) in list(cut_trans)] + cut_graph = {'edge_set': edge_set, 'edge_type': edge_type} + pickle.dump(cut_graph, open(cut_trans_path, 'wb')) + return edge_set, edge_type + + + + def get_user2character_metric(self): + + print('create user2character metric') + max_character_num = max([len(i) for i in self.vocab.user_to_Sentidx.values()]) + user2character_metric = np.zeros((self.vocab.n_user + 1, max_character_num), dtype=int) + for user, sent_list in tqdm(self.vocab.user_to_Sentidx.items()): + user_idx = int(user) + for idx, sent_idx in enumerate(sent_list): + user2character_metric[user_idx, idx] = sent_idx + return user2character_metric + + def show_case(self): + for conversation in self.dataset: + if len(conversation) > 5: + print('\n') + user_id = int(conversation[0]) + contexts = conversation[-3] + all_topics = conversation[-2] + conv_id = conversation[-1] + utterances = conversation[1:-3] + persona_ids = 
list(self.user2character_metric[user_id]) + persona_sents = [self.vocab.idx_to_userSent[i] for i in persona_ids] + print('user profilie:') + for i in persona_sents: + print('\t'+i) + print() + print('contexts:') + for i in range(len(utterances)): + print('\t'+str(i%2)+' '+ ''.join(utterances[i][0]) + '\t'+' '.join(utterances[i][2])) + a = 1 + + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + diff --git a/DataProcessor.py b/DataProcessor.py new file mode 100644 index 0000000..f8c9402 --- /dev/null +++ b/DataProcessor.py @@ -0,0 +1,702 @@ +import json +import os +import pickle +import json +import copy +from tqdm import tqdm +import re +import collections +import numpy as np +import nltk +from nltk.stem import WordNetLemmatizer +_lemmatizer = WordNetLemmatizer() + +def tokenize(example, ppln): + for fn in ppln: + example = fn(example) + return example + + +def kw_tokenize(string): + return tokenize(string, [nltk_tokenize, lower, pos_tag, to_basic_form]) + + +def simp_tokenize(string): + return tokenize(string, [nltk_tokenize, lower]) + + +def nltk_tokenize(string): + return nltk.word_tokenize(string) + + +def lower(tokens): + if not isinstance(tokens, str): + return [lower(token) for token in tokens] + return tokens.lower() + + +def pos_tag(tokens): + return nltk.pos_tag(tokens) + + +def to_basic_form(tokens): + if not isinstance(tokens, tuple): + return [to_basic_form(token) for token in tokens] + word, tag = tokens + if tag.startswith('NN'): + pos = 'n' + elif tag.startswith('VB'): + pos = 'v' + elif tag.startswith('JJ'): + pos = 'a' + else: + return word + return _lemmatizer.lemmatize(word, pos) + + +class DataSet(): + def __init__(self, args, vocab): + super(DataSet, self).__init__() + self.args = args + self.vocab = vocab + self.userSet = set() + self.topics = self.get_topics() + if os.path.exists('./dataset/{}/final_topic.json'.format(self.args.dataset)): + self.final_topic = json.load(open('./dataset/{}/final_topic.json'.format(self.args.dataset), 'r')) + else: + self.final_topic = None + + def get_dialog(self, task): + from DataLoaderTopic import DataLoaderTopic + from DataLoaderResp import DataLoaderResp + if self.args.processed: + with open('./dataset/{}/train_topic.pkl'.format(self.args.dataset), 'rb+') as train_set: + train = pickle.load(train_set) + with open('./dataset/{}/valid_topic.pkl'.format(self.args.dataset), 'rb+') as valid_set: + valid = pickle.load(valid_set) + with open('./dataset/{}/test_topic.pkl'.format(self.args.dataset), 'rb+') as test_set: + test = pickle.load(test_set) + all = [train, valid, test] + users = [] + for dataset in all: + for data in dataset: + user_id = data[0] + if user_id not in users: + user_id = int(user_id) + users.append(user_id) + user_cont = max(users)+1 + + if task == 'topic': + train_set = DataLoaderTopic(self.args, train, self.vocab) + valid_set = DataLoaderTopic(self.args, valid, self.vocab) + test_set = DataLoaderTopic(self.args, test, self.vocab) + elif task == 'gene': + train_set = DataLoaderResp(self.args, train, self.vocab) + valid_set = DataLoaderResp(self.args, valid, self.vocab) + test_set = DataLoaderResp(self.args, test, self.vocab) + return train_set, valid_set, test_set, users, user_cont + else: + if self.args.dataset == 'TG-ReDial': + train_data = 
pickle.load(open('dataset/TG-ReDial/train_data.pkl', 'rb+'))[:] + valid_data = pickle.load(open('dataset/TG-ReDial/valid_data.pkl', 'rb+'))[:] + test_data = pickle.load(open('dataset/TG-ReDial/test_data.pkl', 'rb+'))[:] + def _excute_data(conversations): + convs = [] + for conversation in tqdm(conversations): + conv_id, user_id, conv_id, utterances, topic_thread, movies = conversation['conv_id'], conversation['user_id'], conversation['conv_id'], conversation['messages'], conversation['goal_path'], conversation['mentionMovies'] + conv = [] + self.userSet.add(user_id) + conv.append(user_id) + contents_word = [] + states = [] + alltopic = [] + ks = 1 + + + for utterance in utterances: + processed_sentence = [] + utter_round, role, content = int(utterance['local_id']), utterance['role'], utterance['content'] + goal = topic_thread[utter_round] if utter_round != 1 else [0] + action, topics = self.get_action(goal, movies, utter_round) + if '推荐电影' in action or '反馈,结束' in action: + action, topics = [], [] + + + if utter_round != 1: + final_topic = self.get_final_topic(conv_id, utter_round) + final_states = states.copy() + for topic in final_topic: + final_states.append(topic) + word_level, word2token, leng, ks = self.tokenize_sentence(content,movies,utter_round,ks) + else: + action = [] + final_states = states + word_level, word2token, leng, ks = self.tokenize_sentence(content,movies,utter_round,ks) + contents_word.append(word_level) + processed_sentence.append(word_level) + processed_sentence.append(final_states.copy()) + processed_sentence.append(action) + processed_sentence.append([utter_round]) + conv.append(processed_sentence) + for topic in topics: + states.append(topic) + for topic in states: + if topic not in alltopic and topic is not None: + alltopic.append(topic) + conv.append(contents_word) + conv.append(alltopic) + conv.append(conv_id) + convs.append(conv) + return convs + train = _excute_data(train_data) + valid = _excute_data(valid_data) + test = _excute_data(test_data) + elif self.args.dataset == 'PersonaChat': + all_data = open('dataset/PersonaChat/ConvAI2/train_both_original_no_cands.txt', 'r').readlines() + open( + 'dataset/PersonaChat/ConvAI2/valid_both_original_no_cands.txt', 'r').readlines() + all_data = self.process_raw_data(all_data, 0) + self.get_vocab(all_data) + idf_dict = self.cal_idf(all_data) + + kg_1hop_triples = pickle.load(open('dataset/PersonaChat/dict_file_1hop.pkl', 'rb')) + kg_1hop = {} + for head, tails in kg_1hop_triples.items(): + kg_1hop[head] = [] + for relation, tail in tails: + kg_1hop[head].append(tail) + self.kg_1hop = kg_1hop + all_data = self.get_topic(all_data, idf_dict) + + if True: + self.final_topic = self.get_all_final_topic(all_data, target_len='kg') + kw_counter = collections.Counter() + for data in all_data: + kw_counter.update(data['all_topics']) + kw_freq = {} + kw_sum = sum(kw_counter.values()) + for k, v in kw_counter.most_common(): + kw_freq[k] = v / kw_sum + for data in all_data: + data['score'] = 0. 
+ for kw in set(data['all_topics']): + data['score'] += kw_freq[kw] + data['score'] /= len(set(data['all_topics'])) + all_data.sort(key=lambda x: x['score'], reverse=True) + train_data, valid_data, test_data = [], [], [] + all_dataset_num = len(all_data) + test_end_id = 500 + valid_end_id = 500 + int((all_dataset_num - 500) * 0.05) + for idx, data in enumerate(all_data): + if idx < test_end_id: + test_data.append(data) + elif idx < valid_end_id: + valid_data.append(data) + else: + train_data.append(data) + def _excute_data(conversations, dataset): + + convs = [] + for idx, conversation in enumerate(tqdm(conversations)): + user_id, conv_id, utterances, topic_thread, movies = conversation['user_id'], conversation['conv_id'], conversation['messages'], conversation['goal_path'], conversation['mentionMovies'] + conv = [] + self.userSet.add(user_id) + conv.append(str(user_id)) + contents_word = [] + states = [] + alltopic = [] + ks = 1 + for utterance in utterances: + processed_sentence = [] + utter_round, role, content = int(utterance['local_id']), utterance['role'], utterance['content'] + goal = topic_thread[utter_round] + action, topics = self.get_action(goal, movies, utter_round) + final_topic = self.get_final_topic(conv_id, utter_round) + final_states = states.copy() + for topic in final_topic: + final_states.append(topic) + if dataset == 'TG-ReDial': + word_level, word2token, leng, ks = self.tokenize_sentence(content, movies, utter_round, ks) + elif dataset == 'PersonaChat': + word_level = self.chinese_tokenize_sentence(content) + contents_word.append(word_level) + processed_sentence.append(word_level) + processed_sentence.append(final_states.copy()) + processed_sentence.append(action) + processed_sentence.append([utter_round]) + conv.append(processed_sentence) + for topic in topics: + if topic is None: + states.append('[UNK]') + else: + states.append(topic) + for topic in states: + if topic not in alltopic: + alltopic.append(topic) + conv.append(contents_word) + conv.append(alltopic) + conv.append(conv_id) + convs.append(conv) + return convs + train = _excute_data(train_data, self.args.dataset) + valid = _excute_data(valid_data, self.args.dataset) + test = _excute_data(test_data, self.args.dataset) + with open('./dataset/{}/train_topic.pkl'.format(self.args.dataset), 'wb+') as f: + pickle.dump(train, f) + with open('./dataset/{}/valid_topic.pkl'.format(self.args.dataset), 'wb+') as f: + pickle.dump(valid, f) + with open('./dataset/{}/test_topic.pkl'.format(self.args.dataset), 'wb+') as f: + pickle.dump(test, f) + train_set = DataLoaderTopic(self.args, train, self.vocab) + valid_set = DataLoaderTopic(self.args, valid, self.vocab) + test_set = DataLoaderTopic(self.args, test, self.vocab) + return train_set, valid_set, test_set, self.userSet, len(self.userSet) + + def get_vocab(self, dataset): + counter = collections.Counter() + for data in dataset: + dialog = data['messages'] + for uttr in dialog: + counter.update(simp_tokenize(uttr['content'])) + print('total vocab count: ', len(counter.items())) + sepetial_vocab = ['[PAD]', '[s_context]', '[ / s_context]', '[s_response >]', '[ / s_response]', '[sent]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'] + vocab = sepetial_vocab + [token for token, times in sorted(list(counter.items()), key=lambda x: (-x[1], x[0]))] + with open('./dataset/{}/tpvocab.txt'.format(self.args.dataset), 'w') as f: + for word in vocab: + f.write(word + '\n') + print('save vocab in vocab.txt') + + def cal_idf(self, dataset): + print('get topic cal idf') + counter = 
collections.Counter() + total = 0. + for data in tqdm(dataset): + dialog = data['messages'] + for uttr in dialog: + total += 1 + counter.update(set(kw_tokenize(uttr['content']))) + idf_dict = {} + for k, v in counter.items(): + idf_dict[k] = np.log10(total / (v+1.)) + return idf_dict + + def get_topic(self, dataset, idf_dict): + print('extract topic') + keyword_extractor = KeywordExtractor(self.args, idf_dict, self.kg_1hop) + for data in tqdm(dataset): + dialog = data['messages'] + data['all_topics'] = [] + his_topic = None + for uttr in dialog: + topic = keyword_extractor.idf_extract(uttr['content'], his_topic=his_topic) + his_topic = topic + data['goal_path'][uttr['local_id']] = [uttr['role'], '谈论', topic] + if topic != None: + data['all_topics'].append(topic) + return dataset + + def process_raw_data(self, raw_data: list, conv_id: int): + print('process_raw_data') + data_list = [] + processed_data = { + 'conv_id': conv_id, + 'messages': [], + 'goal_path': {}, + 'mentionMovies': {}, + 'user_id': conv_id, + 'user_profile': [] + } + role = ['Recommender', 'Seeker'] + local_id = 1 + for idx, line in enumerate(tqdm(raw_data)): + line = line.strip() + if line[:2] == '1 ' and idx != 0: + + user_profile_set = processed_data['user_profile'] + user_profile_set = frozenset([self.vocab.userSent_to_idx[sent] for sent in user_profile_set]) + user_id = self.vocab.Sentset_to_user[user_profile_set] + processed_data['user_id'] = user_id + data_list.append(processed_data) + conv_id += 1 + local_id = 1 + processed_data = { + 'conv_id': conv_id, + 'messages': [], + 'goal_path': {}, + 'mentionMovies': {}, + 'user_id': None, + 'user_profile': [] + } + if 'your persona: ' in line: + line = line[line.find('your persona: ') + len('your persona: '):] + processed_data['user_profile'].append(line) + elif "partner's persona: " in line: + line = line[line.find("partner's persona: ") + len("partner's persona: "):] + processed_data['user_profile'].append(line) + else: + line = line[line.find(" ") + 1:] + line = line.split('\t') + processed_data['messages'].append({ + 'local_id': local_id, + 'role': role[local_id % 2], + 'content': line[0] + }) + local_id += 1 + processed_data['messages'].append({ + 'local_id': local_id, + 'role': role[local_id % 2], + 'content': line[1] + }) + local_id += 1 + user_profile_set = processed_data['user_profile'] + user_profile_set = frozenset([self.vocab.userSent_to_idx[sent] for sent in user_profile_set]) + user_id = self.vocab.Sentset_to_user[user_profile_set] + processed_data['user_id'] = user_id + data_list.append(processed_data) + return data_list + + def chinese_tokenize_sentence(self, sentence: str): + return simp_tokenize(sentence) + + def tokenize_sentence(self,sentence: str, movies, turn, ks): + raw_sentence = copy.copy(sentence) + if turn in movies: + assert "《" in sentence and "》" in sentence + movie_id = movies[turn][0] + con = re.sub(r'《(.*)》', '', sentence) + split_content = con.split('') + sentence = split_content[0] + '' + split_content[1] + processed_sentence = [] + while (sentence): + flag = 0 + for topic in self.topics: + if topic in sentence: + idx = sentence.index(topic) + if idx == 0: + flag = 1 + processed_sentence.append(topic) + sentence = sentence[len(topic):] + continue + if turn in movies and movies[turn][0] in sentence: + if sentence.index(movies[turn][0]) == 0: + flag = 1 + processed_sentence.append(movies[turn][0]) + sentence = sentence[len(movies[turn][0]):] + if flag == 0: + word = sentence[0] + processed_sentence.append(word) + sentence = sentence[1:] + 
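# Aside (illustrative, not part of the patch): the while-loop above scans the
# sentence left to right, emitting the first entry of self.topics that matches
# at the current position (plus the current turn's movie placeholder) and
# falling back to single characters. Matching is first-hit in topic-list
# order, not longest-match. Toy trace of the same idiom, with a hypothetical
# topic list:
#
#     topics = ['功夫', '功夫熊猫']
#     sentence, out = '功夫熊猫真好看', []
#     while sentence:
#         for t in topics:
#             if sentence.startswith(t):
#                 out.append(t)
#                 sentence = sentence[len(t):]
#                 break
#         else:
#             out.append(sentence[0])
#             sentence = sentence[1:]
#     # out == ['功夫', '熊', '猫', '真', '好', '看']  ('功夫熊猫' never fires)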
word2token = [] + for word in processed_sentence: + if word == '': + length = 3 + else: + length = len(word) + word2token.append([ks+j for j in range(length)]) + ks+=length + leng = [] + for word in word2token: + leng.append(len(word)) + word2token_pad = [] + for i in range(len(word2token)): + word = word2token[i] + length = leng[i] + pad_token = word + [0]*(10-length) + word2token_pad.append(pad_token) + if turn == 1: + processed_sentence = [i for i in raw_sentence] + return processed_sentence, word2token_pad, leng, ks + + def get_action(self, goals, movies, utter_round): + action = [] + topic_path = [] + goal = goals[1:] + if '反馈' in goal: + assert goal[0] == '反馈' + goal = goal[2:4] + + + if '谈论' in goal and '请求推荐' in goal: + goal = goal[:2] + if len(goal) == 2: + action_type = goal[0] + topics = goal[1] + if '推荐电影' in action_type: + if isinstance(topics, str): + action.append(action_type) + movie = movies[utter_round][0] + action.append('') + if '拒绝' not in action_type: + topic_path.append(movie) + elif isinstance(topics, list): + for topic in topics: + action.append(action_type) + action.append('') + if '拒绝' not in action_type: + topic_path.append(topic) + else: + if isinstance(topics, str): + action.append(action_type) + action.append(topics) + topic_path.append(topics) + elif isinstance(topics, list): + for topic in topics: + action.append(action_type) + action.append(topic) + topic_path.append(topic) + elif topics is None: + action.append(action_type) + action.append('[UNK]') + topic_path.append(topics) + elif len(goal) == 4: + for i in range(0, 4, 2): + action_type = goal[i] + topics = goal[i + 1] + if '推荐电影' in action_type: + continue + + + + + + + + + + + + + else: + if isinstance(topics, str): + action.append(action_type) + action.append(topics) + + + if '拒绝' not in action_type: + topic_path.append(topics) + if isinstance(topics, list): + for topic in topics: + action.append(action_type) + action.append(topic) + if '拒绝' not in action_type: + topic_path.append(topic) + return action, topic_path + + def get_state(self,action): + state = [] + delete_state = [] + + action_len = len(action) + for k in range(0, action_len, 2): + action_type = action[k] + topic = action[k+1] + if '拒绝' in action_type: + delete_state.append(topic) + else: + state.append(topic) + return state, delete_state + + def get_final_topic(self, conv_id, utter_id): + kw_list = [] + conv_id = str(conv_id) + utter_id = str(utter_id) + identity = conv_id + '/' + utter_id + if identity in self.final_topic: + kw_list = self.final_topic[identity] + return kw_list + + def get_all_final_topic(self, dataset, target_len=None): + print('get_all_final_topic') + all_trans = 0. + all_num = 0. 
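# Aside: with target_len == 'kg', the loop below walks forward from each
# utterance while consecutive goal-path topics remain 1-hop neighbours in
# kg_1hop, and takes the last reachable topic as that utterance's final
# (target) topic; all_trans / all_num accumulate the look-ahead hop counts
# behind the 'avg trans hop' figure printed at the end.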
+ all_final_topic = {} + for data in tqdm(dataset): + conv_id = str(data['conv_id']) + if target_len is None: + if len([topic[-1] for idx, topic in data['goal_path'].items() if topic[-1] is not None]) != 0: + final_topic = [topic[-1] for idx, topic in data['goal_path'].items() if topic[-1] is not None][-1] + else: + final_topic = '[UNK]' + dialog = data['messages'] + for uttr in dialog[1:]: + utter_id = str(uttr['local_id']) + identity = conv_id + '/' + utter_id + if target_len is None: + all_final_topic[identity] = [final_topic] + elif target_len == 'kg': + for j in range(uttr['local_id'], len(dialog)+1): + if data['goal_path'][j][-1] is None or data['goal_path'][j-1][-1] not in self.kg_1hop.keys() or data['goal_path'][j][-1] not in self.kg_1hop[data['goal_path'][j-1][-1]]: + break + all_final_topic[identity] = [data['goal_path'][min(len(dialog), j)][-1]] + all_trans += max(1, j+1-uttr['local_id']) + all_num += 1 + else: + all_final_topic[identity] = [final_topic if uttr['local_id']+target_len > len(dialog) else data['goal_path'][uttr['local_id']+target_len][-1]] + with open('./dataset/{}/final_topic.json'.format(self.args.dataset), 'w') as f: + json.dump(all_final_topic, f) + if all_num != 0: + print('avg trans hop is ', all_trans/all_num) + return all_final_topic +
+ def get_topics(self): + topic_file = open(self.args.topic_file.format(self.args.dataset), encoding='utf-8') + topic_vocab = [] + for line in topic_file.readlines(): + line = line.strip('\n') + topic_vocab.append(line) + return topic_vocab +
+ def get_sparsity(self): +
+ with open('./dataset/{}/train_topic.pkl'.format(self.args.dataset), 'rb+') as train_set: + train = pickle.load(train_set) + with open('./dataset/{}/valid_topic.pkl'.format(self.args.dataset), 'rb+') as valid_set: + valid = pickle.load(valid_set) + with open('./dataset/{}/test_topic.pkl'.format(self.args.dataset), 'rb+') as test_set: + test = pickle.load(test_set) + data_all = train + valid + test + user_set = set([int(data[0]) for data in data_all]) + user2topic = np.zeros((self.vocab.n_user + 1, self.vocab.topic_len)) + for data in tqdm(data_all): + user_id = int(data[0]) + topics = self.vocab.topic2index(data[-2]) + user2topic[user_id, topics] = 1 + all_interactions_num = user2topic.sum() + sparsity = 1 - all_interactions_num / (len(user_set) * self.vocab.topic_len) + print('Sparsity is ', sparsity) +
+ def get_co_topic(self, datasets): +
+ co_topic_path = './dataset/{}/processed_data/co_topic.pkl'.format(self.args.dataset) + if os.path.exists(co_topic_path) and False: # cache loading deliberately short-circuited: always rebuild + print('load co-occurrence topic') + co_topic = pickle.load(open(co_topic_path, 'rb')) + co_topic_graph = co_topic['co_topic_graph'] + persona_co_topic = co_topic['persona_co_topic'] + else: + print('create co-occurrence topic') + co_topic_graph = np.zeros([self.vocab.topic_len, self.vocab.topic_len], dtype=np.int32) + persona_co_topic = np.zeros([self.vocab.n_character, self.vocab.topic_len, self.vocab.topic_len], dtype=np.int8) + for dataset in datasets: + for conv in tqdm(dataset): + user_idx = int(conv[0]) + topic_list = conv[-2] + personas = self.vocab.user_to_Sentidx[str(user_idx)] + for i in range(len(topic_list)): + for j in range(i, len(topic_list)): + idx = self.vocab.topic2idx[topic_list[i]] + jdx = self.vocab.topic2idx[topic_list[j]] + co_topic_graph[idx, jdx] += 1 + co_topic_graph[jdx, idx] += 1 + for pid in personas: + persona_co_topic[pid, idx, jdx] += 1 + persona_co_topic[pid, jdx, idx] += 1 + co_topic = {'co_topic_graph': co_topic_graph, 'persona_co_topic': persona_co_topic}
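+ # Persist the symmetric topic co-occurrence counts and their per-persona breakdown; since the load branch above is short-circuited by `and False`, the matrices are rebuilt on every run.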
+ pickle.dump(co_topic, open(co_topic_path, 'wb')) + +class KeywordExtractor(): + def __init__(self, args, idf_dict=None, kg_1hop=None): + self.args = args + self.idf_dict = idf_dict + candi_keyword_path = args.topic_file + self.candiwords = [x.strip() for x in open(candi_keyword_path).readlines()] + self.kg_1hop = kg_1hop + + @staticmethod + def is_keyword_tag(tag): + return tag.startswith('VB') or tag.startswith('NN') or tag.startswith('JJ') + + @staticmethod + def cal_tag_score(tag): + if tag.startswith('VB'): + return 1. + if tag.startswith('NN'): + return 2. + if tag.startswith('JJ'): + return 0.5 + return 0. + + def idf_extract(self, string, con_kw=None, his_topic=None): + tokens = simp_tokenize(string) + seq_len = len(tokens) + tokens = pos_tag(tokens) + source = kw_tokenize(string) + candi = [] + for i, (word, tag) in enumerate(tokens): + score = self.cal_tag_score(tag) + if source[i] not in self.candiwords or score == 0.: + continue + if con_kw is not None and source[i] in con_kw: + continue + score *= source.count(source[i]) + score *= 1 / seq_len + score *= self.idf_dict[source[i]] + candi.append((source[i], score)) + + + + if len(candi) > 0: + if his_topic is not None: + kg_candi = [(i, j) for (i, j) in candi if i in self.kg_1hop[his_topic]] + if len(kg_candi) != 0: + max_idx = np.argmax([i[1] for i in kg_candi]) + topic = kg_candi[max_idx][0] + else: + max_idx = np.argmax([i[1] for i in candi]) + topic = candi[max_idx][0] + else: + max_idx = np.argmax([i[1] for i in candi]) + topic = candi[max_idx][0] + else: + topic = None + return topic + + def extract(self, string): + tokens = simp_tokenize(string) + tokens = pos_tag(tokens) + source = kw_tokenize(string) + kwpos_alters = [] + for i, (word, tag) in enumerate(tokens): + if source[i] and self.is_keyword_tag(tag): + kwpos_alters.append(i) + _, keywords = [], [] + for id in kwpos_alters: + if source[id]: + keywords.append(source[id]) + return list(set(keywords)) + +def clip_pad_sentence(sentence, max_len, pad, sos=None, eos=None, save_prefix=False, pad_suffix=True, return_length=True): + ml = max_len + if eos is not None: + ml = ml - 2 + if save_prefix: + sentence = sentence[:ml] + else: + sentence = sentence[-ml:] + if eos is not None: + sentence = [sos] + sentence + sentence = sentence + [eos] + length = None + if return_length: + length = len(sentence) + if pad_suffix: + sentence += [pad] * (max_len - len(sentence)) + else: + sentence = [pad] * (max_len - len(sentence)) + sentence + if not return_length: + return sentence + return sentence, length + +def clip_pad_context(context, max_len, pad, sent, pad_suffix=True): + sentence = [] + for turn in context: + turn = turn + [sent] + sentence = sentence + turn + real_len = len(sentence) + if real_len > max_len: + sentence = sentence[-max_len:] + else: + if pad_suffix: + sentence = sentence + [pad] * (max_len - real_len) + else: + sentence = [pad] * (max_len - real_len) + sentence + return sentence, real_len \ No newline at end of file diff --git a/OPT.py b/OPT.py new file mode 100644 index 0000000..1ba0952 --- /dev/null +++ b/OPT.py @@ -0,0 +1,83 @@ +import torch +from mip import Model, xsum, minimize, maximize, BINARY + +def IPconstraint_and_solve(S, K=50, mean_row_min=1, mean_col_min=1, union=True): + + + S_raw = S.clone() + S = S.detach().cpu() + S_topk_row, top_id_row = S.topk(K, dim=1, largest=False) if K < S.size(1) else S.topk(S.size(1), dim=1, largest=False) + S_topk_col, top_id_col = S.topk(K, dim=0, largest=False) if K < S.size(1) else S.topk(S.size(0), dim=0, 
largest=False) +
+ S_empty_col = torch.zeros(S.size(0)).unsqueeze(1) + S_empty_row = torch.zeros(S.size(1)).unsqueeze(0) + c, top_id, col_constraint = quick_extract_col_constraint_list(S, S_topk_row, top_id_row, S_empty_row, S_topk_col, top_id_col, S_empty_col, union) +
+ IPmodel = Model() +
+ x = [[IPmodel.add_var(var_type=BINARY) for j in range(len(c[i]))] for i in range(len(c))] + IPmodel.objective = minimize(xsum(c[i][j] * x[i][j] for i in range(len(x)) for j in range(len(x[i])))) + # each real row must select exactly 10 candidates (the appended empty row is skipped) + for i in range(len(x) - 1): + IPmodel += xsum(x[i][j] for j in range(len(x[i]))) == 10 +
+ # and each column must likewise be chosen exactly 10 times overall + for i in range(len(col_constraint)): + IPmodel += xsum(x[j[0]][j[1]] for j in col_constraint[i]) == 10 +
+ IPmodel.optimize() +
+ result_g1 = [[] for i in range(len(top_id))] + for i in range(len(x)): + for j in range(len(x[i])): + if x[i][j].x > 0: + result_g1[i].append(top_id[i][j]) +
+ # drop the appended dummy row before gathering the selected topic ids + topic_ids = torch.tensor(result_g1[:-1], device=S.device, dtype=torch.long) + topic_att = torch.stack([S_raw[i, topic_ids[i]] for i in range(S.size(0))], dim=0) + return topic_att, topic_ids +
+def quick_extract_col_constraint_list(S, S_topk_row, top_id_row, S_empty_row, S_topk_col, top_id_col, S_empty_col, union=True): + top_id = [set(row.numpy().tolist()) for row in top_id_row] + for i in range(len(top_id)): +
+ location = torch.nonzero(torch.where(top_id_col == i, torch.tensor(1), torch.tensor(0))) + if union: + top_id[i] = top_id[i].union(set(location[:, 1].numpy().tolist())) + else: + top_id[i] = top_id[i].intersection(set(location[:, 1].numpy().tolist())) + top_id[i].add(S.size(1)) +
+ for i in range(len(top_id)): + top_id[i] = list(top_id[i]) + top_id[i].sort() +
+ top_id.append(torch.arange(S.size(1)).numpy().tolist()) +
+ selected = torch.zeros(S.size(0) + 1, S.size(1) + 1).long() + for i in range(len(top_id)): + selected[i, top_id[i]] = 1 +
+ row_nonzero_selected = torch.nonzero(selected).numpy().tolist() + col_nonzero_selected = torch.nonzero(selected.T).numpy().tolist() + temp = torch.zeros(selected.size(0)).long() + for line in row_nonzero_selected: + selected[line[0], line[1]] = temp[line[0]] + temp[line[0]] += 1 +
+ quick_col_constraint = [[] for i in range(S.size(1))] + for line in col_nonzero_selected: + if line[0] < S.size(1): + quick_col_constraint[line[0]].append([line[1], selected[line[1], line[0]].numpy().tolist()]) + c = [[] for i in range(S.size(0))] + for i in range(S.size(0)): + for j in top_id[i]: + if j < S.size(1): + c[i].append(S[i, j].numpy().tolist()) + elif j == S.size(1): + c[i].append(S_empty_col[i, 0].numpy().tolist()) +
+ c = c + S_empty_row.numpy().tolist() + return c, top_id, quick_col_constraint \ No newline at end of file diff --git a/PreferenceTopic.py b/PreferenceTopic.py new file mode 100644 index 0000000..d92bbcf --- /dev/null +++ b/PreferenceTopic.py @@ -0,0 +1,214 @@ +from gumbel_softmax import GumbelSoftmax +from tau_scheduler import TauScheduler +import torch +import torch.nn as nn +from tools import Tools +
+class PriorPreference(nn.Module): + def __init__(self, args, encoder, decoder, hidden_size, n_topic_vocab, + trg_bos_idx, max_seq_len, glo2loc, loc2glo, main_tfr_encoder, + gs: GumbelSoftmax, ts: TauScheduler): + super(PriorPreference, self).__init__() + self.args = args + self.decoder = decoder + self.p_encoder = encoder + self.main_tfr_encoder = main_tfr_encoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size
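+ # gen_proj projects decoder states onto the topic vocabulary; gen_prob() below swaps in a tied output embedding (logits = hidden @ embed.T) whenever decoder_embed is supplied.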
+ self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) +
+ def forward(self, context, context_len, tp_path, tp_path_len, encoder_embed=None, decoder_embed=None): + bs = context.size(0) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_tfr_encoder(context, context_mask) +
+ tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask, embed=encoder_embed) +
+ src_hiddens = torch.cat([tp_hidden, context_hidden], 1) + src_mask = torch.cat([tp_mask, context_mask], 2) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.bos_idx, training=self.training) + seq_gen_prob = None + seq_gen_prob_raw = None + for _ in range(self.args.preference_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), src_hiddens, src_mask, self.decoder) + single_step_prob, single_step_prob_raw = self.proj(dec_out=dec_output, context=context, src_hidden=src_hiddens, src_mask=src_mask, tp_path=tp_path, embed=decoder_embed) + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(), normed=True) + if self.training: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + seq_gen_prob_raw = torch.cat([seq_gen_prob_raw, single_step_prob_raw], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_prob_raw = single_step_prob_raw + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:, 1:, :] + else: + return seq_gen_prob, seq_gen_gumbel[:, 1:] +
+ def proj(self, dec_out, context, src_hidden, src_mask, tp_path, embed=None): + gen_logit = self.gen_prob(dec_out, embed) + L_s = dec_out.size(1) + B = context.size(0) + copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_s, -1), -1e9 if copy_logit.dtype==torch.float32 else -1e4) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + copy_context_prob = probs[:, :, self.n_topic_vocab + self.args.state_num:] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_s, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2, + index=transfer_context_word.unsqueeze(1).expand(-1, L_s, -1), + src=copy_context_prob) + # src_hiddens is [tp (state_num), context (context_max_len)], so the tp copy scores sit directly after the vocabulary logits + copy_tp_prob = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.state_num] + copy_tp_prob = torch.bmm(copy_tp_prob, tp_path) + probs = gen_prob + copy_tp_prob + copy_context_prob +
+ probs_raw = logits + gen_prob_raw = probs_raw[:, :, :self.n_topic_vocab] + copy_tp_prob_raw = probs_raw[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.state_num] + copy_tp_prob_raw = torch.bmm(copy_tp_prob_raw, tp_path) + copy_context_prob_raw = probs_raw[:, :, self.n_topic_vocab + self.args.state_num:]
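+ # The raw branch mirrors the normalised path on pre-softmax logits so callers can reuse unscaled scores; context positions are scattered back to global topic ids via glo2loc exactly as above. +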
copy_context_temp_raw = copy_context_prob.new_zeros(B, L_s, self.n_topic_vocab) + copy_context_prob_raw = copy_context_temp_raw.scatter_add(dim=2, + index=transfer_context_word.unsqueeze(1).expand(-1, + L_s, + -1), + src=copy_context_prob_raw) + probs_raw = gen_prob_raw + copy_tp_prob_raw + copy_context_prob_raw + return probs, probs_raw + + def gen_prob(self, dec_output, embed=None): + + if embed is None: + prob = self.gen_proj(dec_output) + else: + assert embed.size(0) == self.gen_proj[0].out_features + prob = dec_output.matmul(embed.T) + return prob + + +class PosteriorPreference(nn.Module): + def __init__(self, encoder, decoder, main_encoder, hidden_size, n_topic_vocab, glo2loc, loc2glo, + trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PosteriorPreference, self).__init__() + self.p_encoder = encoder + self.decoder = decoder + self.main_encoder = main_encoder + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size + self.n_topic_vocab = n_topic_vocab + self.trg_bos_idx = trg_bos_idx + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) + + def forward(self, context, context_len, ar_gth, ar_gth_len, tp_path, tp_path_len, encoder_embed=None, decoder_embed=None): + bs = context.size(0) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, context_mask) + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask, embed=encoder_embed) + if ar_gth is not None and ar_gth_len is not None: + ar_gth_len = [int(length / 2) for length in ar_gth_len] + ar_gth_len = torch.tensor(ar_gth_len).cuda() + ar_gth = ar_gth[:, list(range(1, self.args.action_num, 2))] + ar_gth = one_hot_scatter(ar_gth, self.n_topic_vocab) + ar_mask = Tools.get_mask_via_len(ar_gth_len, int(self.args.action_num / 2)) + ar_hidden = self.p_encoder(ar_gth, ar_mask, embed=encoder_embed) + src_hiddens = torch.cat([tp_hidden, context_hidden, ar_hidden], 1) + ar_mask[ar_mask] = False + src_mask = torch.cat([tp_mask, context_mask, ar_mask], 2) + else: + src_hiddens = torch.cat([tp_hidden, context_hidden], 1) + src_mask = torch.cat([tp_mask, context_mask], 2) + + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.trg_bos_idx, training=self.training) + seq_gen_prob = None + for _ in range(self.args.preference_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), src_hiddens, src_mask, self.decoder) + single_step_prob = self.proj(dec_out=dec_output, src_hidden=src_hiddens, src_mask=src_mask, context=context, tp=tp_path, embed=decoder_embed, no_action=ar_gth is None) + if self.training: + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(), normed=True) + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:, 1:, :] + else: + return seq_gen_prob, seq_gen_gumbel[:, 1:] + + def 
proj(self, dec_out, src_hidden, src_mask, context, tp, embed=None, no_action=False): + B, L_s = dec_out.size(0), dec_out.size(1) + gen_logit = self.gen_prob(dec_out, embed) + if no_action: + hidden_no_At = src_hidden + mask_no_At = src_mask + else: + hidden_no_At = src_hidden[:, :-self.args.action_num // 2, :] + mask_no_At = src_mask[:, :, :-self.args.action_num // 2] + copy_logit = torch.bmm(dec_out, hidden_no_At.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((mask_no_At == 0).expand(-1, L_s, -1), -1e9 if copy_logit.dtype==torch.float32 else -1e4) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + copy_tp_path_prob = probs[:, :, self.n_topic_vocab:self.n_topic_vocab + self.args.state_num] + copy_tp_path_prob = torch.bmm(copy_tp_path_prob, tp) + copy_context_prob = probs[:, :, self.n_topic_vocab + self.args.state_num:] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1), 1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_s, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2, + index=transfer_context_word.unsqueeze(1).expand(-1, L_s, -1), + src=copy_context_prob) + probs = gen_prob + copy_tp_path_prob + copy_context_prob + return probs + + def gen_prob(self, dec_output, embed=None): + + if embed is None: + prob = self.gen_proj(dec_output) + else: + assert embed.size(0) == self.gen_proj[0].out_features + prob = dec_output.matmul(embed.T) + return prob + + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder diff --git a/Preference_all.py b/Preference_all.py new file mode 100644 index 0000000..f869e51 --- /dev/null +++ b/Preference_all.py @@ -0,0 +1,163 @@ +from gumbel_softmax import GumbelSoftmax +from tau_scheduler import TauScheduler +import torch +import torch.nn as nn +from tools import Tools + +class PriorPreference(nn.Module): + def __init__(self,encoder, decoder,m_encoder, hidden_size, n_topic_vocab, + trg_bos_idx, max_seq_len,glo2loc,loc2glo,main_tfr_encoder, + gs: GumbelSoftmax, ts: TauScheduler): + super(PriorPreference, self).__init__() + self.decoder = decoder + self.p_encoder = encoder + self.m_encoder = m_encoder + self.main_tfr_encoder = main_tfr_encoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size + self.gen_proj = nn.Linear(self.hidden_size, self.n_topic_vocab) + + def forward(self,context,context_len,pv_m,pv_m_mask,tp_path,tp_path_len): + bs = pv_m.size(0) + pv_m_hidden = self.p_encoder(pv_m,pv_m_mask) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_tfr_encoder(context, context_mask) + tp_path = one_hot_scatter(tp_path,self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len,self.args.state_num) + tp_path_hidden = self.p_encoder(tp_path,tp_mask) + src_hiddens = torch.cat([context_hidden,pv_m_hidden,tp_path_hidden], 1) + src_mask = torch.cat([context_mask,pv_m_mask,tp_mask], 2) + seq_gen_gumbel = Tools._generate_init(bs,self.n_topic_vocab , 
trg_bos_idx=self.bos_idx,training=self.training) + seq_gen_prob = None + for _ in range(self.args.preference_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(),src_hiddens,src_mask,self.decoder) + single_step_prob = self.proj(dec_out=dec_output,context=context,src_hidden=src_hiddens, + src_mask=src_mask,pv_m=pv_m,tp_path=tp_path) + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(), normed=True) + if self.training: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:] + else: + return seq_gen_gumbel[:,1:] + + def proj(self,dec_out,context,src_hidden,src_mask,pv_m,tp_path): + gen_logit = self.gen_proj(dec_out) + L_s = dec_out.size(1) + B = dec_out.size(0) + copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_s, -1), -1e9) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + copy_context_prob = probs[:, :,self.n_topic_vocab :self.n_topic_vocab + self.args.context_max_len] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1),1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_s, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2,index=transfer_context_word.unsqueeze(1).expand(-1, L_s, -1), + src=copy_context_prob) + copy_pv_m_prob = probs[:, :, self.n_topic_vocab + self.args.context_max_len : self.n_topic_vocab + self.args.context_max_len + self.args.preference_num ] + copy_pv_m_prob = torch.bmm(copy_pv_m_prob, pv_m) + copy_tp_prob = probs[:, :,self.n_topic_vocab + self.args.context_max_len + self.args.preference_num: ] + copy_tp_prob = torch.bmm(copy_tp_prob,tp_path) + probs = gen_prob + copy_pv_m_prob + copy_tp_prob + copy_context_prob + return probs + +class PosteriorPreference(nn.Module): + def __init__(self,encoder,decoder,m_encoder,main_encoder, hidden_size, n_topic_vocab,glo2loc,loc2glo, + trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PosteriorPreference, self).__init__() + self.p_encoder = encoder + self.decoder = decoder + self.m_encoder = m_encoder + self.main_encoder = main_encoder + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size + self.n_topic_vocab = n_topic_vocab + self.trg_bos_idx = trg_bos_idx + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.gen_proj = nn.Linear(self.hidden_size, self.n_topic_vocab) + + def forward(self,context,context_len,pv_m,pv_m_mask,ar_gth,ar_gth_len, + tp_path,tp_path_len): + bs = pv_m.size(0) + ar_gth = ar_gth[:,[1]] + ar_gth = one_hot_scatter(ar_gth,self.n_topic_vocab) + ar_mask = Tools.get_mask_via_len(ar_gth_len, 1) + ar_hidden = self.p_encoder(ar_gth,ar_mask) + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, context_mask) + pv_m_hidden = self.p_encoder(pv_m, pv_m_mask) + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_path_hidden = 
self.p_encoder(tp_path, tp_mask) + src_hiddens = torch.cat([context_hidden, pv_m_hidden, tp_path_hidden, ar_hidden], 1) + src_mask = torch.cat([context_mask, pv_m_mask, tp_mask, ar_mask], 2) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.trg_bos_idx) + seq_gen_prob = None + for _ in range(self.args.preference_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), src_hiddens, src_mask, self.decoder) + single_step_prob = self.proj(dec_out=dec_output,src_hidden=src_hiddens,src_mask=src_mask,pv_m=pv_m,context=context,tp=tp_path) + if self.training: + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(),normed=True) + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:] + else: + return seq_gen_gumbel[:,1:] + + def proj(self, dec_out, src_hidden, src_mask, pv_m , context ,tp ): + B, L_s = dec_out.size(0), dec_out.size(1) + gen_logit = self.gen_proj(dec_out) + hidden_no_At = src_hidden[:, : ,:] + mask_no_At = src_mask[:,:, : ] + copy_logit = torch.bmm(dec_out, hidden_no_At.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((mask_no_At == 0).expand(-1, L_s, -1), -1e9) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_topic_vocab] + copy_pv_m_prob = probs[:, :,self.n_topic_vocab + self.args.context_max_len :self.n_topic_vocab + self.args.context_max_len + self.args.preference_num] + copy_pv_m_prob = torch.bmm(copy_pv_m_prob, pv_m) + copy_context_prob = probs[:,:,self.n_topic_vocab :self.n_topic_vocab + self.args.context_max_len] + transfer_context_word = torch.gather(self.glo2loc.unsqueeze(0).expand(B, -1),1, context) + copy_context_temp = copy_context_prob.new_zeros(B, L_s, self.n_topic_vocab) + copy_context_prob = copy_context_temp.scatter_add(dim=2,index=transfer_context_word.unsqueeze(1).expand(-1, L_s, -1), + src=copy_context_prob) + copy_tp_path_prob = probs[:, :,self.n_topic_vocab + self.args.context_max_len + self.args.preference_num : + self.n_topic_vocab + self.args.context_max_len + self.args.preference_num + self.args.state_num] + copy_tp_path_prob = torch.bmm(copy_tp_path_prob, tp) + probs = gen_prob + copy_pv_m_prob + copy_tp_path_prob + copy_context_prob + return probs + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder \ No newline at end of file diff --git a/Profile.py b/Profile.py new file mode 100644 index 0000000..c4a666e --- /dev/null +++ b/Profile.py @@ -0,0 +1,102 @@ +from gumbel_softmax import GumbelSoftmax +from tau_scheduler import TauScheduler +import torch.nn as nn +import torch +from tools import Tools + +class PriorProfile(nn.Module): + def __init__(self,encoder, decoder, hidden_size, n_topic_vocab, + trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PriorProfile, self).__init__() + self.id_encoder = encoder + self.decoder = decoder + 
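# NOTE: forward() below reads self.args.profile_num, yet args is not a parameter of this constructor; it is presumably attached to the module elsewhere before use. +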
self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.hidden_size = hidden_size + self.gen_proj = nn.Linear(self.hidden_size, self.n_topic_vocab) + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + + def forward(self,id): + bs = id.size(0) + id_mask = id.new_ones(bs, 1, 1).cuda() + user_id = id.unsqueeze(-1) + id_hidden_p = self.id_encoder(user_id, id_mask) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.bos_idx,training=self.training) + seq_gen_prob = None + for _ in range(self.args.profile_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), id_hidden_p, id_mask, self.decoder) + single_step_prob = self.gen_proj(dec_output) + single_step_prob = torch.softmax(single_step_prob,-1) + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(), normed=True) + if self.training: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:] + else: + return seq_gen_gumbel[:,1:] + +class PosteriorProfile(nn.Module): + def __init__(self,main_encoder,topic_encoder,id_encoder,decoder, hidden_size, n_topic_vocab,glo2loc,loc2glo, + m_encoder,trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PosteriorProfile, self).__init__() + self.main_encoder = main_encoder + self.m_encoder = m_encoder + self.topic_encoder = topic_encoder + self.id_encoder = id_encoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size + self.decoder = decoder + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.gen_proj = nn.Linear(self.hidden_size, self.n_topic_vocab) + + def forward(self,id,topics,topics_len): + bs = id.size(0) + topics = one_hot_scatter(topics,self.n_topic_vocab) + topic_mask = Tools.get_mask_via_len(topics_len, self.args.all_topic_num) + topic_hidden = self.topic_encoder(topics,topic_mask) + id_mask = id.new_ones(bs, 1, 1).cuda() + user_id = id.unsqueeze(-1) + id_hidden_q = self.id_encoder(user_id, id_mask) + src_hidden = torch.cat([id_hidden_q,topic_hidden],1) + src_mask = torch.cat([id_mask,topic_mask],2) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.bos_idx,training=self.training) + seq_gen_prob = None + for _ in range(self.args.profile_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), src_hidden, src_mask, self.decoder) + single_step_prob = self.gen_proj(dec_output) + single_step_prob = torch.softmax(single_step_prob, -1) + if self.training: + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(),normed=True) + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:] + else: + return seq_gen_gumbel[:,1:] + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = 
torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder diff --git a/ProfileTopic.py b/ProfileTopic.py new file mode 100644 index 0000000..a84f1d0 --- /dev/null +++ b/ProfileTopic.py @@ -0,0 +1,398 @@ +from gumbel_softmax import GumbelSoftmax +from tau_scheduler import TauScheduler +import torch.nn as nn +import torch +import torch.nn.functional as F +from tools import Tools +from transformer.SubLayers import PositionwiseFeedForward + +class PriorProfile(nn.Module): + def __init__(self, encoder, decoder, hidden_size, n_topic_vocab, + trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PriorProfile, self).__init__() + self.id_encoder = encoder + self.decoder = decoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.hidden_size = hidden_size + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + + def forward(self, id, encoder_embed=None, decoder_embed=None): + bs = id.size(0) + id_mask = id.new_ones(bs, 1, 1).cuda() + user_id = id.unsqueeze(-1) + id_hidden_p = self.id_encoder(user_id, id_mask, embed=encoder_embed) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.bos_idx, training=self.training) + seq_gen_prob = None + seq_gen_prob_raw = None + for _ in range(self.args.profile_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), id_hidden_p, id_mask, self.decoder) + single_step_prob_raw = self.gen_prob(dec_output, embed=decoder_embed) + single_step_prob = torch.softmax(single_step_prob_raw, -1) + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(), normed=True) + if self.training: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + seq_gen_prob_raw = torch.cat([seq_gen_prob_raw, single_step_prob_raw], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_prob_raw = single_step_prob_raw + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:], seq_gen_prob_raw + else: + return seq_gen_prob, seq_gen_gumbel[:,1:], None + + def gen_prob(self, dec_output, embed=None): + + if embed is None: + prob = self.gen_proj(dec_output) + else: + assert embed.size(0) == self.gen_proj[0].out_features + prob = dec_output.matmul(embed.T) + return prob + +class PosteriorProfile(nn.Module): + def __init__(self,main_encoder,topic_encoder,id_encoder,decoder, hidden_size, n_topic_vocab,glo2loc,loc2glo, + trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler): + super(PosteriorProfile, self).__init__() + self.main_encoder = main_encoder + self.topic_encoder = topic_encoder + self.id_encoder = id_encoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.glo2loc = glo2loc + self.loc2glo = loc2glo + self.hidden_size = hidden_size + self.decoder = decoder + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) + + def forward(self, id, topics, topics_len, 
encoder_embed=None, decoder_embed=None): + + bs = id.size(0) + topics = one_hot_scatter(topics, self.n_topic_vocab) + topic_mask = Tools.get_mask_via_len(topics_len, self.args.all_topic_num) + topic_hidden = self.topic_encoder(topics, topic_mask) + id_mask = id.new_ones(bs, 1, 1).cuda() + user_id = id.unsqueeze(-1) + id_hidden_q = self.id_encoder(user_id, id_mask, embed=encoder_embed) + src_hidden = torch.cat([id_hidden_q, topic_hidden], 1) + src_mask = torch.cat([id_mask, topic_mask], 2) + seq_gen_gumbel = Tools._generate_init(bs, self.n_topic_vocab, trg_bos_idx=self.bos_idx, training=self.training) + seq_gen_prob = None + for _ in range(self.args.profile_num): + dec_output = Tools._single_decode(seq_gen_gumbel.detach(), src_hidden, src_mask, self.decoder) + single_step_prob = self.gen_prob(dec_output, embed=decoder_embed) + single_step_prob = torch.softmax(single_step_prob, -1) + if self.training: + single_step_gumbel_word = self.gs.forward(single_step_prob, self.ts.step_on(),normed=True) + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_gumbel_word], 1) + else: + if seq_gen_prob is not None: + seq_gen_prob = torch.cat([seq_gen_prob, single_step_prob], 1) + else: + seq_gen_prob = single_step_prob + single_step_word = torch.argmax(single_step_prob, -1) + seq_gen_gumbel = torch.cat([seq_gen_gumbel, single_step_word], 1) + if self.training: + return seq_gen_prob, seq_gen_gumbel[:,1:,:] + else: + return seq_gen_prob, seq_gen_gumbel[:,1:] + + def gen_prob(self, dec_output, embed=None): + + if embed is None: + prob = self.gen_proj(dec_output) + else: + assert embed.size(0) == self.gen_proj[0].out_features + prob = dec_output.matmul(embed.T) + return prob + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + + +class UserContrast(nn.Module): + def __init__(self, args, p_encoder, p_g_encoder, context_encoder, decoder, hidden_size, n_topic_vocab, trg_bos_idx, max_seq_len, gs: GumbelSoftmax, ts: TauScheduler, user2character_metric, topic_co_graph, no_prob=False): + super(UserContrast, self).__init__() + self.args = args + self.no_prob = no_prob + self.p_encoder = p_encoder + self.p_g_encoder = p_g_encoder + self.context_encoder = context_encoder + self.decoder = decoder + self.n_topic_vocab = n_topic_vocab + self.bos_idx = trg_bos_idx + self.hidden_size = hidden_size + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_topic_vocab)) + self.persona2topic = nn.Linear(self.hidden_size, self.hidden_size) + self.max_seq_len = max_seq_len + self.gs = gs + self.ts = ts + self.user2character_metric = user2character_metric + self.subfun2char_mlp = nn.Sequential( + nn.Linear(self.hidden_size * 3, self.hidden_size * 4), + nn.LeakyReLU(), + nn.Linear(self.hidden_size * 4, self.hidden_size) + ) + if self.args.s_profile_add_t and self.args.no_user_emb == 0: + self.main2char_mlp = nn.Sequential( + nn.Linear(self.hidden_size * 3, self.hidden_size * 4), + nn.LeakyReLU(), + nn.Linear(self.hidden_size * 4, self.hidden_size) + ) + else: + self.main2char_mlp = nn.Sequential( + nn.Linear(self.hidden_size * 2, self.hidden_size * 4), + nn.LeakyReLU(), + nn.Linear(self.hidden_size * 4, self.hidden_size) + ) + 
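# Persona sentences whose sigmoid score clears this threshold count as selected in get_global_tp_hidden() and forward(); the rest are suppressed with a large negative logit before the softmax. +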
self.profile_select_threshold = 0.5 + self.global_ffn = PositionwiseFeedForward(self.hidden_size, self.hidden_size * 4) + + self.topic_co_graph = topic_co_graph + + if self.args.no_learn == 0: + self.liner_1 = nn.Linear(self.hidden_size, self.hidden_size) + self.liner_2 = nn.Linear(self.hidden_size, self.hidden_size) + + + def get_global_tp_hidden(self, user_id, context, context_len, tp_path, tp_path_len, topic2context, user_embed=None, character_embed=None, topic_embed=None): + bs = user_id.size(0) + user_id = user_id + character_id = self.user2character_metric[user_id] + + character_id = torch.cat([character_id, torch.zeros((bs, 1)).long().to(device=character_id.device)], dim=-1) + character_vectors = character_embed[character_id] + user_vector = user_embed[user_id] + context_mask = Tools.get_mask_via_len(context_len, self.args.context_all_max_len) + context_hidden = self.context_encoder(context, context_mask, only_sent=True) + tp_path_vec = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_g_encoder(tp_path_vec, tp_mask, embed=topic_embed) + tp_context_hidden = torch.stack([context_hidden[i][topic2context[i]] for i in range(bs)], dim=0) + tp_context_hidden[topic2context.eq(0)] = 0 + + + char_q = torch.cat([tp_context_hidden, tp_hidden, user_vector.unsqueeze(1).expand(-1, self.args.state_num, -1)], dim=-1) + char_q = self.subfun2char_mlp(char_q) + if self.args.profile_agg: + + char_prob_raw = char_q.matmul(character_vectors.transpose(-1, -2)) + if self.args.top_p != 0: + char_prob = char_prob_raw.clone() + sorted_logits, sorted_indices = torch.sort(char_prob, descending=True) + cumulative_probs = torch.cumsum(sorted_logits.softmax(dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > self.args.top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove) + char_prob[indices_to_remove] = -1e9 if char_prob.dtype == torch.float32 else -1e4 + char_prob = char_prob.softmax(-1) + else: + char_mask = char_prob_raw.sigmoid() > self.profile_select_threshold + char_prob = char_prob_raw.clone() + char_prob[~char_mask] = -1e9 if char_prob.dtype == torch.float32 else -1e4 + char_prob = char_prob.softmax(dim=-1) + + if self.args.otp: + topic_att, topic_ids = self.topic_att, self.topic_ids + + + else: + + topic_att, topic_ids = character_embed.matmul(topic_embed.T).topk(self.args.global_topics, -1) + if self.args.use_co_occurrence == 1: + + bsz, seq_len, char_num = char_prob.size() + character_vect = character_embed[character_id] + topic_score = character_vect.matmul(topic_embed.T.unsqueeze(0)).unsqueeze(1).repeat(1, seq_len, 1, 1) + topic_score_mask = torch.zeros_like(topic_score, dtype=torch.bool) + for i in range(bsz): + for j in range(seq_len): + global_topic_ids = self.topic_co_graph[tp_path[i, j].tolist()] + topic_score_mask[i, j, :, global_topic_ids] = True + topic_score[~topic_score_mask] = -1e13 + topic_att, topic_ids = topic_score.topk(self.args.global_topics, -1) + topic_prob = topic_att.softmax(-1) + character_topic_embed = (topic_prob.unsqueeze(-1) * topic_embed[topic_ids]).sum(-2) + tp_global_embed = (char_prob.unsqueeze(-1) * character_topic_embed).sum(-2) + else: + topic_prob = topic_att.softmax(-1) + character_topic_embed = topic_prob.unsqueeze(1).matmul(topic_embed[topic_ids]).squeeze(1) + tp_global_embed = 
char_prob.unsqueeze(2).matmul(character_topic_embed[character_id].unsqueeze(1).expand(-1, self.args.state_num, -1, -1)).squeeze(2) + else: + + topic_att, topic_ids = char_q.matmul(topic_embed.T).topk(self.args.global_topics, -1) + topic_prob = topic_att.softmax(-1) + tp_global_embed = topic_prob.unsqueeze(2).matmul(topic_embed[topic_ids]).squeeze(2) + + + tp_agg_hidden = self.global_ffn(tp_global_embed + tp_hidden) + tp_agg_hidden[~tp_mask.squeeze(1)] = 0 + return tp_agg_hidden + + def forward(self, user_id, context, context_len, tp_path, tp_path_len, all_topic, topic2context, user_embed=None, character_embed=None, topic_embed=None): + bs = user_id.size(0) + user_mask = user_id.new_ones(bs, 1, 1).cuda() + user_id = user_id + character_id = self.user2character_metric[user_id] + user_vector = user_embed[user_id] + context_mask = Tools.get_mask_via_len(context_len, self.args.context_all_max_len) + context_vector = self.context_encoder(context, context_mask, only_last=True) + character_vectors = character_embed[character_id] + + if self.args.s_profile_add_t: + if self.args.global_topic: + + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden_raw = self.get_global_tp_hidden(user_id=user_id, context=context, context_len=context_len, tp_path=tp_path, tp_path_len=tp_path_len, topic2context=topic2context, user_embed=user_embed, character_embed=character_embed, topic_embed=topic_embed) + tp_hidden = self.p_encoder(tp_hidden_raw, tp_mask, embed=topic_embed, only_last=True, embed_input=True) + else: + tp_path = one_hot_scatter(tp_path, self.n_topic_vocab) + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask, embed=topic_embed, only_last=True) + + + if self.args.no_user_emb == 0: + char_q = torch.cat([context_vector, tp_hidden, user_vector], dim=-1) + else: + char_q = torch.cat([context_vector, tp_hidden], dim=-1) + char_q = self.main2char_mlp(char_q) + else: + char_q = torch.cat([context_vector, user_vector], dim=-1) + char_q = self.main2char_mlp(char_q) + + if self.args.no_learn == 0: + char_porb_raw = character_vectors.matmul(self.persona2topic(char_q).unsqueeze(1).transpose(1, 2)).squeeze(-1) + else: + char_porb_raw = character_vectors.matmul(char_q.unsqueeze(1).transpose(1, 2)).squeeze(-1) + + if self.args.profile_select: + if self.args.random_persona == 0: + if self.args.sharping_profile: + pos_mask = char_porb_raw.sigmoid() < self.profile_select_threshold + char_prob = char_porb_raw.clone() + char_prob[pos_mask] = -1e9 if char_prob.dtype == torch.float32 else -1e4 + char_prob = char_prob.softmax(-1) + char_inv_prob = char_porb_raw.clone() + char_inv_prob[~pos_mask] = 1e9 if char_prob.dtype == torch.float32 else 1e4 + char_inv_prob = (-char_inv_prob).softmax(-1) + else: + char_prob = char_porb_raw.softmax(-1) + char_inv_prob = (-char_porb_raw).softmax(-1) + else: + select_indices = torch.randint(char_porb_raw.size(-1), (char_porb_raw.size(0), self.args.random_persona)) + char_prob = torch.zeros_like(char_porb_raw) + char_inv_prob = torch.full_like(char_porb_raw, 1 / (char_porb_raw.size(-1) - self.args.random_persona)) + for i in range(char_porb_raw.size(0)): + char_prob[i, select_indices[i]] = 1 / self.args.random_persona + char_inv_prob[i, select_indices[i]] = 0 + + dec_q_pos = char_prob.unsqueeze(1).matmul(character_vectors).squeeze(1) + dec_q_anchor = character_vectors.mean(1) + dec_q_neg = char_inv_prob.unsqueeze(1).matmul(character_vectors).squeeze(1) + + single_step_prob_raw_pos = 
self.gen_prob(dec_q_pos, embed=topic_embed) + single_step_prob_pos = torch.softmax(single_step_prob_raw_pos, -1).unsqueeze(1) + single_step_gumbel_word_pos = self.gs.forward(single_step_prob_pos, self.ts.step_on(), normed=True) + + if all_topic is not None: + single_step_pos = self.gen_prob(dec_q_pos, embed=topic_embed).softmax(-1) + single_step_anchor = self.gen_prob(dec_q_anchor, embed=topic_embed).softmax(-1) + single_step_neg = self.gen_prob(dec_q_neg, embed=topic_embed).softmax(-1) + all_topic_label = F.one_hot(all_topic, num_classes=self.n_topic_vocab).sum(dim=1) + all_topic_label[:, 0] = 0 + all_topic_label = all_topic_label.to(dtype=torch.bool) + contrast_loss = - (single_step_pos-single_step_anchor)[all_topic_label].sigmoid().log().sum() - (single_step_anchor-single_step_neg)[all_topic_label].sigmoid().log().sum() + else: + dec_q_pos = character_vectors.mean(1) + single_step_prob_raw_pos = self.gen_prob(dec_q_pos, embed=topic_embed) + single_step_prob_pos = torch.softmax(single_step_prob_raw_pos, -1).unsqueeze(1) + single_step_gumbel_word_pos = self.gs.forward(single_step_prob_pos, self.ts.step_on(), normed=True) + contrast_loss = None + + if self.args.topic_copynet: + + if self.args.otp: + char2topic_att, char2topic_ids = self.topic_att, self.topic_ids + elif self.args.no_learn == 0: + char2topic_att, char2topic_ids = character_embed.matmul(topic_embed.T).topk(10, -1) + else: + char2topic_att, char2topic_ids = self.liner_1(character_embed).matmul(topic_embed.T).topk(10, -1) + char2topic_att = char2topic_att.softmax(-1) + char2topic_prob = torch.zeros((character_embed.size(0), topic_embed.size(0)), device=topic_embed.device) + char2topic_prob.scatter_(dim=1, index=char2topic_ids, src=char2topic_att) + + profile_prob = torch.zeros((bs, character_embed.size(0)), device=topic_embed.device) + profile_prob.scatter_(dim=1, index=character_id, src=char_prob) + + profile_prob = profile_prob.matmul(char2topic_prob) + else: + profile_prob = None + + if self.training: + seq_gen_prob = single_step_prob_pos + seq_gen_gumbel = single_step_gumbel_word_pos + else: + seq_gen_prob = single_step_prob_pos + seq_gen_gumbel = torch.argmax(single_step_prob_pos, -1) + + if not self.args.global_topic_for_action: + tp_hidden_raw = None + + if self.args.get_personas: + pos_mask = torch.nonzero(~pos_mask).tolist() + selected_personas = [[] for i in range(bs)] + for (i, j) in pos_mask: + selected_personas[i].append(j) + else: + selected_personas = None + + if self.no_prob: + + if self.training: + return dec_q_pos, tp_hidden_raw, profile_prob, contrast_loss + else: + return dec_q_pos, tp_hidden_raw, profile_prob + else: + if self.training: + return seq_gen_prob, seq_gen_gumbel, tp_hidden_raw, profile_prob, contrast_loss, selected_personas + else: + return seq_gen_prob, seq_gen_gumbel, tp_hidden_raw, profile_prob, selected_personas + + + + def gen_prob(self, dec_output, embed=None): + + if embed is None: + prob = self.gen_proj(dec_output) + else: + if embed.ndim == 2: + assert embed.size(0) == self.gen_proj[0].out_features + prob = dec_output.matmul(embed.T) + elif embed.ndim == 3: + prob = embed.matmul(dec_output.unsqueeze(1).transpose(1, 2)).squeeze(-1) + return prob diff --git a/README.md b/README.md new file mode 100644 index 0000000..d367520 --- /dev/null +++ b/README.md @@ -0,0 +1,37 @@ +# Personalized Topic Selection Model for Topic-Grounded Dialogue (ACL 2024 findings) + +## System Requirements +```bash +conda create -n PETD python==3.10 +conda activate PETD +pip install -r requirement.txt +``` + +## 
Train & Eval +* For the topic prediction task : `python topic.py -dataset {dataset} ` +* For the response generation task: `python generation.py -dataset {dataset}` + + +## Cite +``` +@inproceedings{fan-etal-2024-personalized, + title = "Personalized Topic Selection Model for Topic-Grounded Dialogue", + author = "Fan, Shixuan and + Wei, Wei and + Wen, Xiaofei and + Mao, Xian-Ling and + Chen, Jixiong and + Chen, Dangyang", + editor = "Ku, Lun-Wei and + Martins, Andre and + Srikumar, Vivek", + booktitle = "Findings of the Association for Computational Linguistics: ACL 2024", + month = aug, + year = "2024", + address = "Bangkok, Thailand", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.findings-acl.429/", + doi = "10.18653/v1/2024.findings-acl.429", + pages = "7188--7202" +} +``` diff --git a/Response.py b/Response.py new file mode 100644 index 0000000..e8bf0dd --- /dev/null +++ b/Response.py @@ -0,0 +1,281 @@ +import torch +import torch.nn as nn +from tools import Tools + + +class Response(nn.Module): + def __init__(self, args, vocab, a_encoder, p_encoder, decoder, hidden_size, n_vocab, trg_bos_idx, trg_eos_idx, + max_seq_len, main_encoder, beam_width, loc2glo, n_topic): + super(Response, self).__init__() + self.args = args + self.vocab = vocab + self.main_encoder = main_encoder + self.a_encoder = a_encoder + self.p_encoder = p_encoder + self.decoder = decoder + self.hidden_size = hidden_size + self.n_vocab = n_vocab + self.n_topic = n_topic + self.bos_idx = trg_bos_idx + self.eos_idx = trg_eos_idx + self.pad_idx = self.vocab.tokenizer.pad_token_id if self.args.gpt2 else self.vocab.word2idx['[PAD]'] + self.max_len = max_seq_len + self.beam_width = beam_width + self.gen_proj = nn.Sequential(nn.Linear(self.hidden_size, self.n_vocab)) + self.loc2glo = loc2glo + + def forward(self, ar, ar_len, context, context_len, tp_path, tp_path_len, + resp_gth=None, resp_gth_len=None, + tp_path_embed=None, profile_embed=None): + + bs = ar.size(0) + if self.args.only_context: + if self.args.gpt2: + context_hidden = self.decoder.get_input_embeddings()(context) + context_mask = context.ne(0).unsqueeze(1) + else: + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, context_mask) + src_hidden = context_hidden + src_mask = context_mask + else: + if self.args.gpt2: + action_mask = Tools.get_mask_via_len(ar_len, self.args.action_num) + action_hidden = self.a_encoder(ar, action_mask) + if tp_path_embed is not None and profile_embed is not None: + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path_embed, tp_mask, embed_input=True) + + contect_vectors = self.decoder.get_input_embeddings()(context) + + src_hidden = torch.cat([tp_hidden, action_hidden, profile_embed.unsqueeze(1), contect_vectors], 1) + profile_mask = torch.ones((bs, 1, 1), device=tp_hidden.device).to(dtype=torch.bool) + src_mask = torch.cat([tp_mask, action_mask, profile_mask, context.ne(0).unsqueeze(1)], 2) + else: + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask) + + contect_vectors = self.decoder.get_input_embeddings()(context) + + src_hidden = torch.cat([tp_hidden, action_hidden, contect_vectors], 1) + src_mask = torch.cat([tp_mask, action_mask, context.ne(0).unsqueeze(1)], 2) + + + else: + context_mask = Tools.get_mask_via_len(context_len, self.args.context_max_len) + context_hidden = self.main_encoder(context, 
context_mask) + action_mask = Tools.get_mask_via_len(ar_len, self.args.action_num) + action_hidden = self.a_encoder(ar, action_mask) + if tp_path_embed is not None and profile_embed is not None: + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path_embed, tp_mask, embed_input=True) + + src_hidden = torch.cat([context_hidden, tp_hidden, action_hidden, profile_embed.unsqueeze(1)], 1) + profile_mask = torch.ones((bs, 1, 1), device=tp_hidden.device).to(dtype=torch.bool) + src_mask = torch.cat([context_mask, tp_mask, action_mask, profile_mask], 2) + else: + tp_mask = Tools.get_mask_via_len(tp_path_len, self.args.state_num) + tp_hidden = self.p_encoder(tp_path, tp_mask) + + src_hidden = torch.cat([context_hidden, tp_hidden, action_hidden], 1) + src_mask = torch.cat([context_mask, tp_mask, action_mask], 2) + + if resp_gth is not None: + if self.args.gpt2: + input_mask = torch.cat((src_mask.squeeze(1).to(dtype=torch.float), resp_gth.ne(0)), dim=-1) + input_vectors = self.decoder.get_input_embeddings()(resp_gth) + input_vectors = torch.cat((src_hidden, input_vectors), dim=1) + probs = self.decoder.forward(inputs_embeds=input_vectors, attention_mask=input_mask) + probs = probs[0][:, -resp_gth.shape[-1]:, :] + probs = probs.softmax(-1) + else: + resp_mask = Tools.get_mask_via_len(resp_gth_len, self.args.r_max_len) & Tools.get_subsequent_mask(resp_gth) + dec_out = self.decoder(resp_gth, resp_mask, src_hidden, src_mask) + if tp_path_embed is not None and profile_embed is not None: + probs = self.proj(dec_out=dec_out, src_hidden=src_hidden[:, :-1, :], src_mask=src_mask[:, :, :-1], context=context, action=ar, tp=tp_path) + else: + probs = self.proj(dec_out=dec_out, src_hidden=src_hidden, src_mask=src_mask, context=context, action=ar, tp=tp_path) + return probs + else: + seq_gen = torch.ones(bs, 1, dtype=torch.long) * self.bos_idx + seq_gen = seq_gen.cuda() + if tp_path_embed is not None and profile_embed is not None: + src_hidden_decode = src_hidden[:, :-1, :] + src_mask_decode = src_mask[:, :, :-1] + else: + src_hidden_decode = src_hidden + src_mask_decode = src_mask + if self.args.decoder_strategy == 'greedy': + decoder_function = self._greedy_search + elif self.args.decoder_strategy == 'beam_search': + decoder_function = self._beam_search + seq_gen, probs = decoder_function(seq_gen=seq_gen, src_hidden=src_hidden, src_mask=src_mask, action=ar, + context=context, tp=tp_path, + src_hidden_decode=src_hidden_decode, src_mask_decode=src_mask_decode) + return seq_gen, probs + + def proj(self, dec_out, src_hidden, src_mask, context, action, tp): + B = action.size(0) + gen_logit = self.gen_proj(dec_out) + L_r = dec_out.size(1) + copy_logit = torch.bmm(dec_out, src_hidden.permute(0, 2, 1)) + copy_logit = copy_logit.masked_fill((src_mask == 0).expand(-1, L_r, -1), -1e9) + logits = torch.cat([gen_logit, copy_logit], -1) + if self.args.scale_prj: + logits *= self.hidden_size ** -0.5 + probs = torch.softmax(logits, -1) + gen_prob = probs[:, :, :self.n_vocab] + copy_context_prob = probs[:, :, self.n_vocab:self.n_vocab + self.args.context_max_len] + context = one_hot_scatter(context, self.n_vocab) + copy_context_prob = torch.bmm(copy_context_prob, context) + copy_tp_prob = probs[:, :, self.n_vocab + self.args.context_max_len:self.n_vocab + self.args.context_max_len + self.args.state_num] + transfer_tp_word = torch.gather(self.loc2glo.unsqueeze(0).expand(B, -1), 1, tp) + copy_tp_temp = copy_tp_prob.new_zeros(B, L_r, self.n_vocab) + copy_tp_prob = 
copy_tp_temp.scatter_add(dim=2, index=transfer_tp_word.unsqueeze(1).expand(-1, L_r, -1), + src=copy_tp_prob) + copy_ar_prob = probs[:, :, self.n_vocab + self.args.context_max_len + self.args.state_num:] + transfer_ar_word = torch.gather(self.loc2glo.unsqueeze(0).expand(B, -1), 1, action) + copy_ar_temp = copy_ar_prob.new_zeros(B, L_r, self.n_vocab) + copy_ar_prob = copy_ar_temp.scatter_add(dim=2, index=transfer_ar_word.unsqueeze(1).expand(-1, L_r, -1), + src=copy_ar_prob) + probs = gen_prob + copy_context_prob + copy_tp_prob + copy_ar_prob + + return probs + + def _greedy_search(self, seq_gen, src_hidden, src_mask, action, context, tp, src_hidden_decode, src_mask_decode): + probs = None + for step in range(self.args.r_max_len): + if self.args.gpt2: + input_mask = torch.cat((src_mask.squeeze(1).to(dtype=torch.float), seq_gen.ne(0)), dim=-1) + input_vectors = self.decoder.get_input_embeddings()(seq_gen) + input_vectors = torch.cat((src_hidden, input_vectors), dim=1) + single_step_probs = self.decoder.forward(inputs_embeds=input_vectors, attention_mask=input_mask) + single_step_probs = single_step_probs[0][:, -1:, :] + single_step_probs = single_step_probs.softmax(-1) + else: + single_step_probs = self.single_decode(input_seq=seq_gen, src_hidden=src_hidden, src_mask=src_mask, + decoder=self.decoder, action=action, context=context, tp=tp, + src_hidden_decode=src_hidden_decode, src_mask_decode=src_mask_decode) + if probs is None: + probs = single_step_probs + else: + probs = torch.cat([probs, single_step_probs], 1) + single_step_word = torch.argmax(single_step_probs, -1) + seq_gen = torch.cat([seq_gen, single_step_word], 1) + return seq_gen[:, 1:], probs + + def _beam_search(self, seq_gen, src_hidden, src_mask, action, context, tp, src_hidden_decode, src_mask_decode, beam_width=10): + + + batch_size, seq_len, embed_size = src_hidden.size() + device = seq_gen.device + + + src_hidden = src_hidden.unsqueeze(1).repeat(1, beam_width, 1, 1).view(batch_size * beam_width, seq_len, embed_size) + src_mask = src_mask.repeat(1, beam_width, 1).view(batch_size * beam_width, seq_len) + action = action.unsqueeze(1).repeat(1, beam_width, 1).view(batch_size * beam_width, -1) + context = context.unsqueeze(1).repeat(1, beam_width, 1).view(batch_size * beam_width, -1) + tp = tp.unsqueeze(1).repeat(1, beam_width, 1).view(batch_size * beam_width, -1) + src_hidden_decode = src_hidden_decode.unsqueeze(1).repeat(1, beam_width, 1, 1).view(batch_size * beam_width, -1, embed_size) + src_mask_decode = src_mask_decode.unsqueeze(1).repeat(1, beam_width, 1, 1).view(batch_size * beam_width, 1, -1) + + + batch_beams = [[{'sequence': seq_gen[i].unsqueeze(0), 'score': 0.0} for j in range(beam_width)] for i in range(batch_size)] + completed_beams = [[] for _ in range(batch_size)] + + for step in range(self.args.r_max_len): + candidate_sequences = [] + candidate_scores = [] + + for i, beams in enumerate(batch_beams): + for j, beam in enumerate(beams): + sequence = beam['sequence'] + score = beam['score'] + candidate_sequences.append(sequence) + candidate_scores.append(score) + + if len(candidate_sequences) == 0: + break + + candidate_sequences = torch.cat(candidate_sequences, dim=0) + candidate_scores = torch.tensor(candidate_scores, dtype=torch.float, device=device) + + if self.args.gpt2: + input_mask = torch.cat((src_mask.to(dtype=torch.float), candidate_sequences.ne(0)), dim=-1) + input_vectors = self.decoder.get_input_embeddings()(candidate_sequences) + input_vectors = torch.cat((src_hidden, input_vectors), dim=1) + step_probs = 
self.decoder.forward(inputs_embeds=input_vectors, attention_mask=input_mask)
+                step_probs = step_probs[0][:, -1:, :]
+                step_probs = step_probs.softmax(-1)
+            else:
+                step_probs = self.single_decode(input_seq=candidate_sequences, src_hidden=src_hidden, src_mask=src_mask.unsqueeze(1),
+                                                decoder=self.decoder, action=action, context=context, tp=tp,
+                                                src_hidden_decode=src_hidden_decode, src_mask_decode=src_mask_decode)
+
+            scores = candidate_scores.unsqueeze(-1) + torch.log(step_probs.squeeze(1))
+            if step == 0:
+                # every beam holds the same BOS sequence at step 0, so only the first beam is expanded
+                scores = scores.view(batch_size, beam_width, self.decoder.config.vocab_size if self.args.gpt2 else self.n_vocab)[:, 0, :]
+            else:
+                scores = scores.reshape(batch_size, beam_width * (self.decoder.config.vocab_size if self.args.gpt2 else self.n_vocab))
+
+            sort_scores, sort_indices = scores.sort(dim=-1, descending=True)
+
+            beam_indices = sort_indices // (self.decoder.config.vocab_size if self.args.gpt2 else self.n_vocab)
+            word_indices = sort_indices % (self.decoder.config.vocab_size if self.args.gpt2 else self.n_vocab)
+
+            new_beams = [[] for _ in range(batch_size)]
+            for i in range(batch_size):
+                completed_beam_num = 0
+                j = 0
+                while j < (beam_width + completed_beam_num):
+                    beam_index = beam_indices[i, j] + i * beam_width
+                    word_index = word_indices[i, j]
+
+                    sequence = torch.cat((candidate_sequences[beam_index], word_index.unsqueeze(0)), dim=0)
+                    score = sort_scores[i, j].item()
+                    if sequence[-1] == self.eos_idx:
+                        completed_beams[i].append({'sequence': sequence, 'score': score})
+                        completed_beam_num += 1
+                    else:
+                        new_beams[i].append({'sequence': sequence.unsqueeze(0), 'score': score})
+                    j += 1
+
+            batch_beams = new_beams
+
+        # merge still-active beams; squeeze their (1, L) sequences so pad_sequence below
+        # receives 1-D tensors like the completed ones (the original mixed both shapes)
+        for i in range(batch_size):
+            completed_beams[i] += [{'sequence': beam['sequence'].squeeze(0), 'score': beam['score']} for beam in batch_beams[i]]
+
+        output_seqs = []
+        output_probs = []
+        for i in range(batch_size):
+            completed_beams[i].sort(key=lambda x: x['score'], reverse=True)
+            top_beam = completed_beams[i][0]
+            output_seqs.append(top_beam['sequence'])
+            output_probs.append(top_beam['score'])
+
+        output_seqs = torch.nn.utils.rnn.pad_sequence(output_seqs, batch_first=True)
+        output_probs = torch.tensor(output_probs, dtype=torch.float, device=device)
+
+        return output_seqs[:, 1:], output_probs
+
+    def single_decode(self, input_seq, src_hidden, src_mask, decoder, action, context, tp, src_hidden_decode, src_mask_decode):
+        dec_output = Tools._single_decode(input_seq.detach(), src_hidden, src_mask, decoder)
+        single_step_probs = self.proj(dec_out=dec_output, context=context,
+                                      src_hidden=src_hidden_decode, src_mask=src_mask_decode,
+                                      action=action, tp=tp)
+        return single_step_probs
+
+
+def one_hot_scatter(indice, num_classes, dtype=torch.float):
+    indice_shape = list(indice.shape)
+    placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype)
+    v = 1 if dtype == torch.long else 1.0
+    placeholder.scatter_(-1, indice.unsqueeze(-1), v)
+    return placeholder
diff --git a/Vocab.py b/Vocab.py
new file mode 100644
index 0000000..e9f3ee2
--- /dev/null
+++ b/Vocab.py
@@ -0,0 +1,260 @@
+import numpy as np
+import os
+import json
+import pickle
+import jieba
+import requests
+from tqdm import tqdm
+from transformers import BertTokenizer
+
+
+class Vocab(object):
+    def __init__(self, args, task='', word_vocab=False, topic_vocab=False):
+        super(Vocab, self).__init__()
+        self.args = args
+        self.word_list, self.word_len, self.topic_list, self.topic_len = self.get_vocab(task)
+        self.word2idx = dict(zip(self.word_list, range(len(self.word_list))))
+        self.idx2word = {id:word for word,id in self.word2idx.items()}
+        self.topic2idx = dict(zip(self.topic_list, range(len(self.topic_list))))
+        self.idx2topic = {id:word for word,id in self.topic2idx.items()}
+        self.word_vocab = word_vocab
+        self.topic_vocab = topic_vocab
+        self.get_userSent(args.dataset)
+
+        if self.args.gpt2:
+            self.tokenizer = BertTokenizer(vocab_file='vocabulary/vocab_small.txt')
+            self.vocab_size = len(self.tokenizer)
+            self.pad_id = self.tokenizer.convert_tokens_to_ids('[PAD]')
+
+    def get_userSent(self, dataset_name):
+        if not os.path.exists('./dataset/{}/processed_data'.format(dataset_name)):
+            os.mkdir('./dataset/{}/processed_data'.format(dataset_name))
+        userSent_to_idx_path = './dataset/{}/processed_data/userSent_to_idx.json'.format(dataset_name)
+        user_to_Sentidx_path = './dataset/{}/processed_data/user_to_Sentidx.json'.format(dataset_name)
+        if os.path.exists(userSent_to_idx_path) and os.path.exists(user_to_Sentidx_path):
+            userSent_to_idx = json.load(open(userSent_to_idx_path, 'r'))
+            user_to_Sentidx = json.load(open(user_to_Sentidx_path, 'r'))
+        else:
+            print('get user profile data')
+            if dataset_name == 'TG-ReDial':
+                user_to_topic_sents = pickle.load(open("./dataset/{}/user2TopicSent.pkl".format(dataset_name), 'rb'))
+
+                userSent_to_idx = {'[PAD]': 0}
+                for idx, sent_set in user_to_topic_sents.items():
+                    for sent in list(sent_set):
+                        if sent not in userSent_to_idx.keys():
+                            userSent_to_idx[sent] = len(userSent_to_idx)
+                user_to_Sentidx = {}
+                for idx, sent_set in user_to_topic_sents.items():
+                    if idx not in user_to_Sentidx.keys():
+                        user_to_Sentidx[int(idx)] = []
+                    for sent in list(sent_set):
+                        user_to_Sentidx[int(idx)].append(userSent_to_idx[sent])
+                with open(userSent_to_idx_path, 'w') as f:
+                    json.dump(userSent_to_idx, f)
+                with open(user_to_Sentidx_path, 'w') as f:
+                    json.dump(user_to_Sentidx, f)
+            elif dataset_name == 'PersonaChat':
+                train_data = open('dataset/PersonaChat/ConvAI2/train_both_original_no_cands.txt', 'r').readlines()
+                valid_data = open('dataset/PersonaChat/ConvAI2/valid_both_original_no_cands.txt', 'r').readlines()
+                all_data = train_data + valid_data
+                userSent_to_idx = {'[PAD]': 0}
+                user_idx = -1
+                user_to_Sentidx = {}
+                for line in tqdm(all_data):
+                    line = line.strip()
+                    if line[:2] == '1 ':
+                        user_idx += 1
+                        user_to_Sentidx[user_idx] = []
+                    if 'your persona: ' in line:
+                        line = line[line.find('your persona: ') + len('your persona: '):]
+                        if line not in userSent_to_idx.keys():
+                            userSent_to_idx[line] = len(userSent_to_idx)
+                        user_to_Sentidx[user_idx].append(userSent_to_idx[line])
+                    elif "partner's persona: " in line:
+                        line = line[line.find("partner's persona: ") + len("partner's persona: "):]
+                        if line not in userSent_to_idx.keys():
+                            userSent_to_idx[line] = len(userSent_to_idx)
+                        user_to_Sentidx[user_idx].append(userSent_to_idx[line])
+
+                character_set = list(set([frozenset(i) for i in user_to_Sentidx.values()]))
+                user_to_Sentidx = {}
+                for i in range(len(character_set)):
+                    user_to_Sentidx[str(i)] = list(character_set[i])
+                with open(userSent_to_idx_path, 'w') as f:
+                    json.dump(userSent_to_idx, f)
+                with open(user_to_Sentidx_path, 'w') as f:
+                    json.dump(user_to_Sentidx, f)
+        Sentset_to_user = {}
+        for (user, Sentlist) in user_to_Sentidx.items():
+            Sentset_to_user[frozenset(Sentlist)] = user
+        self.Sentset_to_user = Sentset_to_user
+        self.userSent_to_idx = userSent_to_idx
+        self.idx_to_userSent = {v:k for k, v in userSent_to_idx.items()}
+        self.user_to_Sentidx = user_to_Sentidx
+        self.n_user = max([int(i) for i in user_to_Sentidx.keys()])
+        self.n_character = len(userSent_to_idx)
+
+    def get_Character2topic(self, dataset_name):
+        character2topic_path = './dataset/{}/processed_data/character2topic.json'.format(dataset_name)
+        if os.path.exists(character2topic_path):
+            character2topic = json.load(open(character2topic_path, 'r'))
+        else:
+            character2topic = {}
+            for character, char_idx in self.userSent_to_idx.items():
+                if character == '[PAD]':
+                    character2topic[character] = '[PAD]'
+                    continue
+                is_match = False
+                character_cuted = list(jieba.cut(character, cut_all=True))
+                for topic, topic_idx in self.topic2idx.items():
+                    if topic in character_cuted:
+                        is_match = True
+                        character2topic[character] = topic
+                        break
+                if is_match == False:
+                    for topic, topic_idx in self.topic2idx.items():
+                        if topic in character:
+                            is_match = True
+                            character2topic[character] = topic
+                            break
+                if is_match == False:
+                    print('not match', character)
+            with open(character2topic_path, 'w') as f:
+                json.dump(character2topic, f)
+        self.character2topic = character2topic
+
+    def get_vocab(self, task):
+        action_type = ['谈论', '拒绝', '请求推荐', '允许推荐', '推荐电影', '反馈', '反馈,结束']
+        RESERVED_WORDS = [self.args.PAD_WORD, self.args.BOS_PRE, self.args.BOS_PRO, self.args.UNK_WORD]
+        topic_vocab = []
+        word_vocab = []
+        if task == 'rec':
+            with open(self.args.topic_movie_file.format(self.args.dataset), encoding='utf-8') as topic_file:
+                for line in topic_file:
+                    line = line.strip('\n')
+                    topic_vocab.append(line)
+            topic_vocab = RESERVED_WORDS + action_type + topic_vocab
+            topic_len = len(topic_vocab)
+            with open(self.args.vocab_movie_file.format(self.args.dataset), encoding='utf-8') as vocab_file:
+                for line in vocab_file.readlines():
+                    line = line.strip('\n')
+                    word_vocab.append(line)
+            word_len = len(word_vocab)
+        else:
+            with open(self.args.topic_file.format(self.args.dataset), encoding='utf-8') as topic_file:
+                for line in topic_file:
+                    line = line.strip('\n')
+                    topic_vocab.append(line)
+            topic_vocab = RESERVED_WORDS + action_type + topic_vocab
+            topic_len = len(topic_vocab)
+            with open(self.args.vocab_file.format(self.args.dataset), encoding='utf-8') as vocab_file:
+                for line in vocab_file.readlines():
+                    line = line.strip('\n')
+                    word_vocab.append(line)
+            word_len = len(word_vocab)
+        return word_vocab, word_len, topic_vocab, topic_len
+
+    def word2index(self, word):
+        unk_id = self.word2idx.get('[UNK]')
+        if isinstance(word, str):
+            return self.word2idx.get(word, unk_id)
+        elif isinstance(word, list):
+            return [self.word2index(w) for w in word]
+        else:
+            raise ValueError("wrong type {}".format(type(word)))
+
+    def index2word(self, index):
+        if isinstance(index, int):
+            if index < len(self.word_list):
+                return self.word_list[index]
+            else:
+                raise ValueError("{} is out of {}".format(index, len(self.word_list)))
+        elif isinstance(index, np.ndarray):
+            index = index.tolist()
+            return [self.index2word(i) for i in index]
+        elif isinstance(index, list):
+            return [self.index2word(i) for i in index]
+        else:
+            raise ValueError("wrong type {}".format(type(index)))
+
+    def topic2index(self, topic):
+        unk_id = self.topic2idx.get('[UNK]')
+        if isinstance(topic, str):
+            return self.topic2idx.get(topic, unk_id)
+        elif isinstance(topic, list):
+            return [self.topic2index(w) for w in topic]
+        elif isinstance(topic, int):
+            # an int is already a topic index; return it unchanged
+            # (the original returned the builtin `int` by mistake)
+            return topic
+        elif topic is None:
+            return self.topic2idx.get(self.args.PAD_WORD)
+        else:
+            raise ValueError("wrong type {}".format(type(topic)))
{}".format(type(topic))) + + def index2topic(self, index): + if isinstance(index, int): + if index < len(self.topic_list): + return self.topic_list[index] + elif index == len(self.topic_list): + return None + else: + raise ValueError("{} is out of {}".format(index, len(self.word_list))) + elif isinstance(index, np.ndarray): + index = index.tolist() + return [self.index2topic(i) for i in index] + elif isinstance(index, list): + return [self.index2topic(i) for i in index] + else: + raise ValueError("wrong type {}".format(type(index))) + + def item_in(self, word): + if self.word_vocab: + return self.word2index(word) + elif self.topic_vocab: + return self.topic2index(word) + else: + raise ValueError("word_vocab or topic_vocab must be true") + + def __len__(self,word=False,topic=False): + if word: + return self.word_len + elif topic: + return self.topic_len + else: + raise ValueError("word_vocab or topic_vocab must be true") + + def vocab_transfer(self): + + glo2loc = [] + for word in self.word_list: + glo2loc.append(self.topic2index(word)) + loc2glo = [] + for index, topic in enumerate(self.topic_list): + loc2glo.append(self.word2index(topic)) + + + + + return glo2loc, loc2glo + + def get_word_pad(self): + return self.word2index('[PAD]') + + def get_topic_pad(self): + return self.topic2index('[PAD]') + + def topic_num(self): + return self.topic_len + + def movie_num(self): + non_movie = self.topic2index('') + 1 + movienum = self.topic_num() - non_movie + return movienum + diff --git a/copy_scheduler.py b/copy_scheduler.py new file mode 100644 index 0000000..66dc43a --- /dev/null +++ b/copy_scheduler.py @@ -0,0 +1,41 @@ + + + +class CopyScheduler: + def __init__(self, origin_lambda, mini_lambda, n_step, s_step=0): + + self.origin_lambda = origin_lambda + self.mini_lambda = mini_lambda + self.n_step = n_step + self.s_step = s_step + self.step_interval = (self.mini_lambda - self.origin_lambda) / self.n_step + + def step_on(self, advance=True): + if advance: + self.s_step += 1 + if self.s_step >= self.n_step: + return self.mini_lambda + return self.s_step * self.step_interval + self.origin_lambda + + def dump(self): + self_dict = { + "origin_lambda": self.origin_lambda, + "mini_lambda": self.mini_lambda, + "n_step": self.n_step, + "s_step": self.s_step, + } + return self_dict + + @staticmethod + def load(self_dict): + return CopyScheduler(self_dict["origin_lambda"], + self_dict["mini_lambda"], + self_dict["n_step"], + self_dict["s_step"]) + + def self_load(self, self_dict): + self.origin_lambda = self_dict["origin_lambda"] + self.mini_lambda = self_dict["mini_lambda"] + self.n_step = self_dict["n_step"] + self.s_step = self_dict["s_step"] + self.step_interval = (self.mini_lambda - self.origin_lambda) / self.n_step diff --git a/dataset/PersonaChat/.gitkeep b/dataset/PersonaChat/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/dataset/TG-ReDial/.gitkeep b/dataset/TG-ReDial/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/distinct.py b/distinct.py new file mode 100644 index 0000000..d4b3d55 --- /dev/null +++ b/distinct.py @@ -0,0 +1,29 @@ +import jieba +def cal_calculate(args, tokenized_gen, tokenized_tar): + dis1=0 + dis_set1=set() + dis2=0 + dis_set2=set() + for sen, tar in zip(tokenized_gen, tokenized_tar): + for j,word in enumerate(sen): + if word == args.EOS_RESPONSE: + sen = sen[:j] + break + full_sen_gen = '' + for word in sen: + full_sen_gen += word + sen_split_by_movie = list(full_sen_gen.split('')) + sen_1 = [] + for i, sen_split in 
+            for segment in jieba.cut(sen_split):
+                sen_1.append(segment)
+            if i != len(sen_split_by_movie) - 1:
+                sen_1.append(MOVIE_TOKEN)
+        prediction = sen_1
+        for word in prediction:
+            dis_set1.add(word)
+            dis1 += 1
+        for i in range(1, len(prediction)):
+            dis_set2.add(prediction[i - 1] + ' ' + prediction[i])
+            dis2 += 1
+    return len(dis_set1)/dis1, len(dis_set2)/dis2
\ No newline at end of file
diff --git a/draw_picture.py b/draw_picture.py
new file mode 100644
index 0000000..023c7a5
--- /dev/null
+++ b/draw_picture.py
@@ -0,0 +1,164 @@
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+import os
+import pickle
+import json
+import pandas as pd
+import networkx as nx
+import scipy.stats
+from tqdm import tqdm
+from matplotlib.patches import Patch
+
+
+def softmax(x):
+    exp_x = np.exp(x)
+    return exp_x/np.sum(exp_x)
+
+def Hyperparameter_experiment():
+    sns.set_style('darkgrid')
+    t10 = [0.842, 0.907, 0.921]
+    t7 = [0.826, 0.879, 0.975]
+    t4 = [0.609, 0.727, 0.773]
+    t1 = [0.312, 0.474, 0.557]
+
+    plt.figure(figsize=(9, 9))
+
+    labels = ['Hit@1', 'Hit@3', 'Hit@5']
+    y = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
+    y_label = ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0']
+
+    x = np.arange(len(labels))
+    width = 0.2
+
+    # NOTE: the original hex color codes were stripped from this diff; the default
+    # Matplotlib cycle colors 'C0'..'C3' are stand-ins.
+    plt.bar(x - 1.5 * width, t1, width, label='t = 1', color='C0')
+    plt.bar(x - 0.5 * width, t4, width, label='t = 4', color='C1')
+    plt.bar(x + 0.5 * width, t7, width, label='t = 7', color='C2')
+    plt.bar(x + 1.5 * width, t10, width, label='t = 10', color='C3')
+
+    plt.xticks(x, labels=labels, fontsize=30)
+    plt.yticks(y, labels=y_label, fontsize=30)
+    plt.legend(ncol=4, loc='upper center', bbox_to_anchor=(0.5, 1.15), fontsize=23, columnspacing=0.5)
+    plt.savefig('pic/turns.pdf', dpi=400)
+    plt.show()
+
+    n5 = [0.818, 0.863, 0.923]
+    n10 = [0.837, 0.899, 0.920]
+    n15 = [0.807, 0.864, 0.888]
+    n20 = [0.768, 0.827, 0.860]
+
+    plt.figure(figsize=(9, 9))
+
+    labels = ['Hit@1', 'Hit@3', 'Hit@5']
+
+    x = np.arange(len(labels))
+    width = 0.2
+
+    # stand-in colors again; the originals were stripped from this diff
+    plt.bar(x - 1.5 * width, n5, width, label='k = 5', color='C0')
+    plt.bar(x - 0.5 * width, n10, width, label='k = 10', color='C1')
+    plt.bar(x + 0.5 * width, n15, width, label='k = 15', color='C2')
+    plt.bar(x + 1.5 * width, n20, width, label='k = 20', color='C3')
+
+    plt.xticks(x, labels=labels, fontsize=30)
+    plt.yticks(y, labels=y_label, fontsize=30)
+    plt.legend(ncol=4, loc='upper center', bbox_to_anchor=(0.5, 1.15), fontsize=23, columnspacing=0.5)
+    plt.savefig('pic/topics.pdf', dpi=400)
+    plt.show()
+
+
+def show_kl(dataset):
+    co_topic_path = './dataset/{}/processed_data/co_topic.pkl'.format(dataset)
+    co_topic = pickle.load(open(co_topic_path, 'rb'))
+    co_topic_graph = co_topic['co_topic_graph']
+    persona_co_topic = co_topic['persona_co_topic']
+
+    max_kl = []
+    max_pl_p = []
+    for i, num in enumerate(tqdm(co_topic_graph.sum(-1)[:100])):
+        if num == 0:
+            max_kl.append(-1)
+            max_pl_p.append(-1)
+        else:
+            persona_kl = []
+            for p in range(persona_co_topic.shape[0]):
+                if co_topic_graph[i].sum() == 0 or persona_co_topic[p, i].sum() == 0:
+                    persona_kl.append(0)
+                    continue
+                kl = scipy.stats.entropy(co_topic_graph[i]+1E-5, persona_co_topic[p, i]+1E-5)
+                persona_kl.append(kl)
+            max_persona_kl = max(persona_kl)
+            max_presona = persona_kl.index(max_persona_kl)
+            max_kl.append(max_persona_kl)
+            max_pl_p.append(max_presona)
+
+    for i, p in enumerate(max_pl_p):
+        if p == -1:
+            continue
+        all_global = co_topic_graph[i]
+        persona_global = persona_co_topic[p, i]
+
+        all_global_cleared, persona_global_cleared = [], []
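+        # keep only topic pairs with a nonzero global count; entries below zero are
+        # assumed to be wrapped counts and are shifted back into range by +1000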
+        for all, per in zip(all_global, persona_global):
+            if all != 0:
+                all_global_cleared.append(all if all >= 0 else all+1000)
+                persona_global_cleared.append(per if per >= 0 else per+1000)
+        show_num = 20
+        if len(all_global_cleared) < show_num:
+            continue
+        all_global_cleared = all_global_cleared[:show_num]
+        persona_global_cleared = persona_global_cleared[:show_num]
+
+        all_global_cleared = [idx for idx, num in enumerate(all_global_cleared) for jdx in range(num)]
+        persona_global_cleared = [idx for idx, num in enumerate(persona_global_cleared) for jdx in range(num)]
+
+        # NOTE: the facecolor value was stripped from this diff; 'white' is a stand-in.
+        plt.rcParams['axes.facecolor'] = 'white'
+        # The original called figure-level sns.displot twice, which draws each
+        # distribution on its own figure; histplot draws both on the current axes,
+        # which is what the shared legend below suggests was intended.
+        ax = sns.histplot(all_global_cleared, stat="density", common_norm=False, label='all_global')
+        ax = sns.histplot(persona_global_cleared, stat="density", common_norm=False, label='persona_global')
+        text_size = 18
+
+        # stand-in colors and labels; the originals were stripped from this diff
+        legend_elements = [Patch(facecolor='C0', label='all_global'),
+                           Patch(facecolor='C1', label='persona_global'),
+                           ]
+        # assumed use of legend_elements; the original legend call was stripped
+        ax.legend(handles=legend_elements, fontsize=text_size)
+
+        ax.set(yticklabels=[])
+        ax.tick_params(left=False)
+        ax.set(ylabel=None)
+
+        plt.show()
+        print('kl:{:.4f}'.format(max_kl[i]))
+
+if __name__ == '__main__':
+    if not os.path.exists('pic/'):
+        os.mkdir('pic/')
+
+    show_kl('TG-ReDial')
\ No newline at end of file
diff --git a/eval_personas.py b/eval_personas.py
new file mode 100644
index 0000000..b9c38d6
--- /dev/null
+++ b/eval_personas.py
@@ -0,0 +1,34 @@
+import json
+import os
+
+if __name__ == '__main__':
+    with open('human_eval_persona.jsonl', 'r') as f:
+        dataset = json.load(f)
+
+    tp, pred_true, gold_true = 0, 0, 0
+    for data in dataset:
+        if len(data['gold_personas']) == 0:
+            continue
+        pred_true += len(data['selected_personas'])
+        gold_true += len(data['gold_personas'])
+        for persona in data['selected_personas']:
+            if persona in data['gold_personas']:
+                tp += 1
+    recall = tp / gold_true
+    precision = tp / pred_true
+    f1 = 2 * recall * precision / (recall + precision)
+    print('Recall: {:.4f}; Precision: {:.4f}; F1: {:.4f}'.format(recall, precision, f1))
diff --git a/generation.py b/generation.py
new file mode 100644
index 0000000..e0ab2fc
--- /dev/null
+++ b/generation.py
@@ -0,0 +1,170 @@
+import argparse
+import random
+import torch
+import os
+from DataProcessor import DataSet
+from get_logger import get_logger
+from get_logger import task_uuid
+from Vocab import Vocab
+from upcrgene import Upcrgene, Engine
+from torch.utils.data import DataLoader
+from DataLoaderTopic import collate_fn
+
+
+main_logger = get_logger("main", './log/test.log')
+main_logger.info("TASK ID {}".format(task_uuid))
+
+def config():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-test", "--test", action="store_true")
+    parser.add_argument('--inference', type=bool, default=False, )
+    parser.add_argument("-use_cuda", "--use_cuda", type=bool, default=False)
+    parser.add_argument("-gpu", "--gpu", type=str, default='1')
+    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--processed", action='store_false', )
+    parser.add_argument("--not_topic_guide", action='store_true', )
+    parser.add_argument("-dataset", "--dataset", choices=["TG-ReDial", "PersonaChat"], default="TG-ReDial")
+
+    parser.add_argument("-global_topic", "--global_topic", action='store_true', )
+    parser.add_argument("-global_topic_for_action", "--global_topic_for_action", action='store_true', )
+    parser.add_argument("-top_p", "--top_p", type=float, default=0., )
+    parser.add_argument("-profile_agg", "--profile_agg", action='store_true', )
+    parser.add_argument("-otp", "--otp", action='store_true', )
+
parser.add_argument("-s_profile_add_t", "--s_profile_add_t", action='store_true', ) + parser.add_argument("-topic_copynet", "--topic_copynet", action='store_true', ) + + parser.add_argument("-profile_pred", "--profile_pred", action='store_true', ) + parser.add_argument("-profile_select", "--profile_select", action='store_true', ) + parser.add_argument("-sharping_profile", "--sharping_profile", action='store_true', ) + parser.add_argument("-profile_contrast", "--profile_contrast", action='store_true', ) + parser.add_argument("-global_topics", "--global_topics", type=int, default=10, ) + + parser.add_argument("-gene_add_profile", "--gene_add_profile", action='store_true', + ) + parser.add_argument("-gpt2", "--gpt2", action='store_true', ) + parser.add_argument("-only_context", "--only_context", action='store_true', ) + parser.add_argument("-decoder_strategy", "--decoder_strategy", type=str, choices=['greedy', 'beam_search'], default='greedy', ) + + parser.add_argument("-history_turn", "--history_turn", type=int, default=100, ) + + + parser.add_argument('--n_layers', type=int, default=6) + parser.add_argument('--n_position', type=int, default=160) + parser.add_argument('--n_inner_vocab', type=int, default=5000) + parser.add_argument('--n_inner_layers', type=int, default=3) + parser.add_argument('--n_inner_position', type=int, default=15) + parser.add_argument('--d_word_vec', type=int, default=512) + parser.add_argument('--n_head', type=int, default=8) + parser.add_argument('--d_k', type=int, default=64) + parser.add_argument('--d_v', type=int, default=64) + parser.add_argument('--pad_idx', type=int, default=2) + parser.add_argument('--d_model', type=int, default=512) + parser.add_argument('--d_inner', type=int, default=2048) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--n_warmup_steps', type=int, default=2000) + parser.add_argument('--scale_emb', type=bool, default=False) + parser.add_argument('--switch_interval', type=int, default=16) + parser.add_argument('--cache_turn', type=int, default=0) + parser.add_argument('--context_all_max_len', type=int, default=1024) + parser.add_argument('--context_max_len', type=int, default=150) + parser.add_argument('--r_max_len', type=int, default=50) + parser.add_argument('--r_beam_max_len', type=int, default=30) + parser.add_argument('--conv_max_len', type=int, default=500) + parser.add_argument('--profile_num', type=int, default=1) + parser.add_argument('--state_num', type=int, default=20) + parser.add_argument('--state_num_redial', type=int, default=20) + parser.add_argument('--pretrain_state_num', type=int, default=50) + parser.add_argument('--all_topic_num', type=int, default=20) + parser.add_argument('--all_topic_num_redial', type=int, default=40) + parser.add_argument('--movie_path_len', type=int, default=3) + parser.add_argument('--tag_num', type=int, default=3) + parser.add_argument('--preference_num', type=int, default=1) + parser.add_argument('--topic_num', type=int, default=2) + parser.add_argument('--action_num', type=int, default=10) + parser.add_argument('--action_num_redial', type=int, default=1) + parser.add_argument('--relation_num', type=int, default=150) + parser.add_argument('--movie_num', type=int, default=200) + parser.add_argument('--state_token', type=int, default=40) + parser.add_argument('--scale_prj', type=bool, default=True) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--task', type=str, default="meddg") + parser.add_argument('--dataset_file', type=str, 
default="./dataset/{}.zip") + parser.add_argument('--topic_file', type=str, default="./dataset/{}/topic.txt") + parser.add_argument('--topic_movie_file', type=str, default="./dataset/{}/tpmv.txt") + parser.add_argument('--vocab_file', type=str, default="./dataset/{}/tpvocab.txt") + parser.add_argument('--vocab_movie_file', type=str, default="./dataset/{}/tpmvvocab.txt") + parser.add_argument('--no_action_super', type=str, default=None) + parser.add_argument('--max_patience', type=int, default=20) + parser.add_argument('--log_loss_interval', type=int, default=100) + parser.add_argument('--gradient_stack', type=int, default=80) + parser.add_argument('--decay_interval', type=int, default=10000) + parser.add_argument('--decay_rate', type=float, default=0.9) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--valid_eval_interval', type=int, default=10000) + parser.add_argument('--test_eval_interval', type=int, default=10000) + parser.add_argument('--force_ckpt_dump', action='store_true') + parser.add_argument('--sub_gen_lambda', type=float, default=0.01) + parser.add_argument('--s_copy_lambda', type=int, default=1) + parser.add_argument('--a_copy_lambda', type=int, default=1) + parser.add_argument('--copy_lambda_mini', type=float, default=0.1) + parser.add_argument('--copy_lambda_decay_steps', type=int, default=10000) + parser.add_argument('--copy_lambda_decay_value', type=float, default=1.0) + parser.add_argument('--init_tau', type=float, default=1.0) + parser.add_argument('--tau_mini', type=float, default=0.1) + parser.add_argument('--tau_decay_total_steps', type=int, default=5000) + parser.add_argument('--tau_decay_rate', type=float, default=0.5) + parser.add_argument('--beam_width', type=int, default=1) + parser.add_argument('--wo_l', action='store_true') + parser.add_argument('--wo_m', action='store_true') + parser.add_argument('--wo_entropy_restrain', action='store_true') + parser.add_argument('--wo_repeat_penalty', action='store_true') + parser.add_argument('--wo_rl', action='store_true') + parser.add_argument('--super_only', action='store_true') + parser.add_argument('--hungary', action='store_true') + parser.add_argument('--super_epoch', type=int, default=5) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--reg_lambda', type=float, default=5e-3) + parser.add_argument('--BOS_CONTEXT', type=str, default="[s_context]") + parser.add_argument('--EOS_CONTEXT', type=str, default="[/s_context]") + parser.add_argument('--BOS_RESPONSE', type=str, default="[s_response>]") + parser.add_argument('--EOS_RESPONSE', type=str, default="[/s_response]") + parser.add_argument('--BOS_ACTION', type=str, default="[s_action]") + parser.add_argument('--EOS_ACTION', type=str, default="[/s_action]") + parser.add_argument('--PAD_WORD', type=str, default="[PAD]") + parser.add_argument('--SENTENCE_SPLITER', type=str, default="[sent]") + parser.add_argument('--TOPIC_SPLITER', type=str, default="[unused2]") + parser.add_argument('--UNK_WORD', type=str, default="[UNK]") + parser.add_argument('--BOS_PRE', type=str, default="[s_preference]") + parser.add_argument('--EOS_PRE', type=str, default="[/s_preference]") + parser.add_argument('--BOS_PRO', type=str, default="[s_profile]") + parser.add_argument('--EOS_PRO', type=str, default="[/s_profile]") + args = parser.parse_args() + return args + +def main(): + + random.seed(1234) + args = config() + main_logger.info("preparing data") + + vocab = Vocab(args) + dataset = DataSet(args=args, vocab=vocab) + train_set, 
valid_set, test_set, users, user_cont = dataset.get_dialog(task='gene') + train_loader = DataLoader(train_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=True) + valid_loader = DataLoader(valid_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False) + test_loader = DataLoader(test_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False) + + + vocab = Vocab(args) + excrs = Upcrgene(args=args, vocab=vocab, user_cont=user_cont, train_set=train_set) + engine = Engine(args=args, model=excrs, vocab=vocab) + if not os.path.exists('saved_model'): + os.mkdir('saved_model') + if args.test: + engine.model.load_state_dict(torch.load('saved_model/best_generate_model_{}.pkl'.format(args.dataset)), strict=False) + engine.test(test_loader, get_ppl=False, is_show=False) + else: + engine.train(train_loader, valid_loader, test_loader) +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/get_logger.py b/get_logger.py new file mode 100644 index 0000000..00a95b1 --- /dev/null +++ b/get_logger.py @@ -0,0 +1,29 @@ + + + +import os +import uuid +import logging + +task_uuid = str(uuid.uuid4())[:8] + +def get_logger(name, filename): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + fmt="[{}] - ".format(task_uuid) + '%(asctime)s [%(filename)s %(lineno)d] %(name)s - %(levelname)s: %(message)s', + datefmt='%m/%d/%Y %H:%M:%S') + + if not os.path.exists('log/'): + os.mkdir('log') + file_handler = logging.FileHandler(filename) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + return logger diff --git a/gumbel_softmax.py b/gumbel_softmax.py new file mode 100644 index 0000000..17e0601 --- /dev/null +++ b/gumbel_softmax.py @@ -0,0 +1,100 @@ + +import torch +import torch.nn as nn +import random +import argparse +import numpy as np +from tqdm import tqdm +from collections import Counter + +def config(): + parser = argparse.ArgumentParser() + parser.add_argument("--sample_num", type=int, default=5000) + parser.add_argument("--hidden_size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--train_epoch", type=int, default=10000) + cfg = parser.parse_args() + return cfg + +def one_hot(indice, num_classes): + I = torch.eye(num_classes) + T = I[indice] + T.requires_grad = False + return T + +class GumbelSoftmax(nn.Module): + def __init__(self, origin_version=True, rep_penalize=False, reoper=10): + super(GumbelSoftmax, self).__init__() + + self.origin_version = origin_version + self.eps = 1e-24 + self.step = 0 + self.rep_penalize = rep_penalize + self.reoper = reoper + + def forward(self, inp, tau, normed): + if normed: + inp = torch.log(inp + self.eps) + if not self.origin_version: + gk = -torch.log(-torch.log(torch.rand(inp.shape))) + out = torch.softmax((inp + gk) / tau, dim=-1) + else: + if self.rep_penalize: + expand_inp = inp.unsqueeze(1).expand(-1, self.reoper, -1, -1) + out = torch.nn.functional.gumbel_softmax(expand_inp, tau=tau) + max_index = out.argmax(-1) + max_index = max_index.reshape(max_index.size(0), -1) + max_index = max_index.detach().cpu().tolist() + def 
find_index(rand_value, prob_list): + ceil = np.cumsum(prob_list[:-1]) + index = (rand_value > ceil).astype(np.long).sum() + return int(index) + batch_selected_indexs = [] + for b in range(expand_inp.size(0)): + c = Counter() + c.update(max_index[b]) + index2prob = dict([(x, 1 / y) for x, y in c.most_common()]) + probs = [index2prob[i] for i in max_index[b]] + probs_sum = sum(probs) + normalized_probs = [x / probs_sum for x in probs] + indexs = [find_index(random.random(), normalized_probs) for _ in range(expand_inp.size(2))] + batch_selected_indexs.append(indexs) + B, _, S, T = out.shape + flat_out = out.reshape(-1, T) + indexs = torch.tensor(batch_selected_indexs).reshape(-1) + indexs = indexs + torch.arange(B).unsqueeze(1).expand(-1, self.reoper).reshape(-1) * self.reoper * S + flat_out = flat_out.index_select(0, indexs) + out = flat_out.reshape(B, S, -1) + else: + out = torch.nn.functional.gumbel_softmax(inp, tau=tau) + return out + +class Argmax(nn.Module): + def __init__(self): + super(Argmax, self).__init__() + def forward(self, inp): + return torch.argmax(inp, dim=-1) + +class GUMBEL(nn.Module): + def __init__(self, sample_num, hidden_size, is_train=False, gumbel_act=True): + super(GUMBEL, self).__init__() + self.is_train = is_train + self.gumbel_act = gumbel_act + self.embedding_layer = nn.Linear(sample_num, hidden_size) + self.pred_layer = nn.Sequential(nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, sample_num)) + self.train_act1 = nn.Softmax(dim=-1) + self.train_act2 = GumbelSoftmax() + self.test_act3 = Argmax() + + def get_act_fn(self): + act_fn = self.test_act3 if not self.is_train else (self.train_act2 if self.gumbel_act else self.train_act1) + return act_fn + + def forward(self, sample): + sample = sample.cuda() + sample_embedding = self.embedding_layer(sample) + pred = self.pred_layer(sample_embedding) + ret = self.get_act_fn()(pred) + return ret \ No newline at end of file diff --git a/option.py b/option.py new file mode 100644 index 0000000..a2d373f --- /dev/null +++ b/option.py @@ -0,0 +1,93 @@ +import torch + +class option(object): + n_layers=6 + n_position = 160 + n_inner_vocab = 5000 + n_inner_layers = 3 + n_inner_position = 15 + d_word_vec = 512 + n_head = 8 + d_k = 64 + d_v = 64 + pad_idx = 2 + d_model = 512 + d_inner = 2048 + dropout = 0.1 + n_warmup_steps = 2000 + scale_emb = False + switch_interval = 16 + cache_turn = 0 + context_max_len = 150 + r_max_len = 50 + r_beam_max_len = 30 + conv_max_len = 500 + profile_num = 1 + state_num = 10 + state_num_redial = 20 + pretrain_state_num = 50 + all_topic_num = 20 + all_topic_num_redial = 40 + movie_path_len = 3 + tag_num = 3 + preference_num = 1 + topic_num = 2 + action_num = 10 + action_num_redial = 1 + relation_num = 150 + movie_num = 200 + state_token = 40 + scale_prj = True + epoch = 100 + task = "meddg" + dataset_file = "dataset/{dataset}.zip" + topic_file = "dataset/TG-ReDial/topic.txt" + topic_movie_file = "dataset/TG-ReDial/tpmv.txt" + vocab_file = "dataset/TG-ReDial/tpvocab.txt" + vocab_movie_file = "dataset/TG-ReDial/tpmvvocab.txt" + no_action_super = None + max_patience = 20 + log_loss_interval = 100 + gradient_stack = 8 + decay_interval = 10000 + decay_rate = 0.9 + lr = 1e-5 + valid_eval_interval = 10000 + test_eval_interval = 10000 + force_ckpt_dump = True + sub_gen_lambda = 0.01 + s_copy_lambda = 1 + a_copy_lambda = 1 + copy_lambda_mini = 0.1 + copy_lambda_decay_steps = 10000 + copy_lambda_decay_value = 1.0 + init_tau = 1.0 + tau_mini = 0.1 + tau_decay_total_steps = 5000 + 
tau_decay_rate = 0.5 + beam_width = 1 + wo_l = False + wo_m = False + wo_entropy_restrain = False + wo_repeat_penalty = False + wo_rl = False + super_only = False + hungary = False + super_rate = 0.0 + super_epoch = 5 + batch_size = 16 + reg_lambda = 5e-3 + BOS_CONTEXT = "[s_context]" + EOS_CONTEXT = "[/s_context]" + BOS_RESPONSE = "[s_response>]" + EOS_RESPONSE = "[/s_response]" + BOS_ACTION = "[s_action]" + EOS_ACTION = "[/s_action]" + PAD_WORD = "[PAD]" + SENTENCE_SPLITER = "[sent]" + TOPIC_SPLITER = "[unused2]" + UNK_WORD = "[UNK]" + BOS_PRE = "[s_preference]" + EOS_PRE = "[/s_preference]" + BOS_PRO = "[s_profile]" + EOS_PRO = "[/s_profile]" \ No newline at end of file diff --git a/persona_eval.py b/persona_eval.py new file mode 100644 index 0000000..7c18be1 --- /dev/null +++ b/persona_eval.py @@ -0,0 +1,72 @@ +import json +import math + + +def docs(w, history_list): + c = 0 + for i, h in enumerate(history_list): + if w in h: + c += 1 + return c + + +def gen_idf_dict(history_list): + idf_dict = {} + for i, h in enumerate(history_list): + for w in h: + if w not in idf_dict: + idf = math.log(len(history_list) * 1.0 / docs(w, history_list)) + idf_dict[w] = idf + return idf_dict + + +def cal_s_for_each_history(r, h, idf_dict): + c = 0 + has_c = {} + for w in r: + if w in h and w not in has_c: + c += idf_dict[w] + has_c[w] = 1 + return c + + +def cal_p_cover(src_generate, all_personas): + s_sum = 0 + line_cnt = 0 + for result, personas in zip(src_generate, all_personas): + idf_dict = gen_idf_dict(personas) + + + s_list = [] + for i, persona in enumerate(personas): + s = cal_s_for_each_history(result, persona, idf_dict) + s_list.append(s) + s_max = max(s_list) + s_sum += s_max + line_cnt += 1 + return (s_sum + 0.0) / line_cnt + + +def cal_f1(result, personas): + p_all = [] + for i, p in enumerate(personas): + p_all += p + h_set = set(p_all) + r_set = set(result) + if len(h_set) == 0 or len(r_set) == 0: + p, r = 0, 0 + else: + p = len(h_set & r_set) / len(r_set) + r = len(h_set & r_set) / len(h_set) + if p == r == 0: + return 0 + return (2 * p * r) / (p + r) + + +def cal_p_f1(src_generate, all_personas): + s_sum = 0 + line_cnt = 0 + for result, personas in zip(src_generate, all_personas): + s_sum += cal_f1(result, personas) + line_cnt += 1 + return (s_sum + 0.0) / line_cnt diff --git a/pretrained_models/GPT2/.gitkeep b/pretrained_models/GPT2/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/recommendation.py b/recommendation.py new file mode 100644 index 0000000..8e20bbc --- /dev/null +++ b/recommendation.py @@ -0,0 +1,125 @@ +import argparse +import random +from DataProcessor import DataSet +from get_logger import get_logger +from get_logger import task_uuid +from Vocab import Vocab +from upcrrec import Upcrrec,Engine + +main_logger = get_logger("main", './log/test.log') +main_logger.info("TASK ID {}".format(task_uuid)) + +def config(): + parser = argparse.ArgumentParser() + parser.add_argument('--inference',type=bool,default=False,) + parser.add_argument("--processed", type=bool, default=True, ) + + + parser.add_argument('--n_layers', type=int, default=6) + parser.add_argument('--n_position', type=int, default=160) + parser.add_argument('--n_inner_vocab', type=int, default=5000) + parser.add_argument('--n_inner_layers', type=int, default=3) + parser.add_argument('--n_inner_position', type=int, default=15) + parser.add_argument('--d_word_vec', type=int, default=512) + parser.add_argument('--n_head', type=int, default=8) + parser.add_argument('--d_k', type=int, default=64) + 
parser.add_argument('--d_v', type=int, default=64) + parser.add_argument('--pad_idx', type=int, default=2) + parser.add_argument('--d_model', type=int, default=512) + parser.add_argument('--d_inner', type=int, default=2048) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--n_warmup_steps', type=int, default=2000) + parser.add_argument('--scale_emb', type=bool, default=False) + parser.add_argument('--switch_interval', type=int, default=16) + parser.add_argument('--cache_turn', type=int, default=0) + parser.add_argument('--context_all_max_len', type=int, default=1024) + parser.add_argument('--context_max_len', type=int, default=150) + parser.add_argument('--r_max_len', type=int, default=50) + parser.add_argument('--r_beam_max_len', type=int, default=30) + + parser.add_argument('--profile_num', type=int, default=1) + parser.add_argument('--state_num', type=int, default=20) + parser.add_argument('--state_num_redial', type=int, default=20) + parser.add_argument('--pretrain_state_num', type=int, default=50) + parser.add_argument('--all_topic_num', type=int, default=20) + parser.add_argument('--all_topic_num_redial', type=int, default=40) + parser.add_argument('--movie_path_len', type=int, default=3) + parser.add_argument('--tag_num', type=int, default=3) + parser.add_argument('--preference_num', type=int, default=1) + parser.add_argument('--topic_num', type=int, default=2) + parser.add_argument('--action_num', type=int, default=10) + parser.add_argument('--action_num_redial', type=int, default=1) + parser.add_argument('--relation_num', type=int, default=150) + parser.add_argument('--movie_num', type=int, default=200) + parser.add_argument('--state_token', type=int, default=40) + parser.add_argument('--scale_prj', type=bool, default=True) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--task', type=str, default="meddg") + parser.add_argument('--dataset_file', type=str, default="./dataset/{}.zip") + parser.add_argument('--topic_file', type=str, default="./dataset/{}/topic.txt") + parser.add_argument('--topic_movie_file', type=str, default="./dataset/{}/tpmv.txt") + parser.add_argument('--vocab_file', type=str, default="./dataset/{}/tpvocab.txt") + parser.add_argument('--vocab_movie_file', type=str, default="./dataset/{}/tpmvvocab.txt") + parser.add_argument('--no_action_super', type=str, default=None) + parser.add_argument('--max_patience', type=int, default=20) + parser.add_argument('--log_loss_interval', type=int, default=100) + parser.add_argument('--gradient_stack', type=int, default=80) + parser.add_argument('--decay_interval', type=int, default=10000) + parser.add_argument('--decay_rate', type=float, default=0.9) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--valid_eval_interval', type=int, default=10000) + parser.add_argument('--test_eval_interval', type=int, default=10000) + parser.add_argument('--force_ckpt_dump', action='store_true') + parser.add_argument('--sub_gen_lambda', type=float, default=0.01) + parser.add_argument('--s_copy_lambda', type=int, default=1) + parser.add_argument('--a_copy_lambda', type=int, default=1) + parser.add_argument('--copy_lambda_mini', type=float, default=0.1) + parser.add_argument('--copy_lambda_decay_steps', type=int, default=10000) + parser.add_argument('--copy_lambda_decay_value', type=float, default=1.0) + parser.add_argument('--init_tau', type=float, default=1.0) + parser.add_argument('--tau_mini', type=float, default=0.1) + 
parser.add_argument('--tau_decay_total_steps', type=int, default=5000)
+    parser.add_argument('--tau_decay_rate', type=float, default=0.5)
+    parser.add_argument('--beam_width', type=int, default=1)
+    parser.add_argument('--wo_l', action='store_true')
+    parser.add_argument('--wo_m', action='store_true')
+    parser.add_argument('--wo_entropy_restrain', action='store_true')
+    parser.add_argument('--wo_repeat_penalty', action='store_true')
+    parser.add_argument('--wo_rl', action='store_true')
+    parser.add_argument('--super_only', action='store_true')
+    parser.add_argument('--hungary', action='store_true')
+    parser.add_argument('--super_epoch', type=int, default=5)
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--reg_lambda', type=float, default=5e-3)
+    parser.add_argument('--BOS_CONTEXT', type=str, default="[s_context]")
+    parser.add_argument('--EOS_CONTEXT', type=str, default="[/s_context]")
+    parser.add_argument('--BOS_RESPONSE', type=str, default="[s_response>]")
+    parser.add_argument('--EOS_RESPONSE', type=str, default="[/s_response]")
+    parser.add_argument('--BOS_ACTION', type=str, default="[s_action]")
+    parser.add_argument('--EOS_ACTION', type=str, default="[/s_action]")
+    parser.add_argument('--PAD_WORD', type=str, default="[PAD]")
+    parser.add_argument('--SENTENCE_SPLITER', type=str, default="[sent]")
+    parser.add_argument('--TOPIC_SPLITER', type=str, default="[unused2]")
+    parser.add_argument('--UNK_WORD', type=str, default="[UNK]")
+    parser.add_argument('--BOS_PRE', type=str, default="[s_preference]")
+    parser.add_argument('--EOS_PRE', type=str, default="[/s_preference]")
+    parser.add_argument('--BOS_PRO', type=str, default="[s_profile]")
+    parser.add_argument('--EOS_PRO', type=str, default="[/s_profile]")
+    args = parser.parse_args()
+    return args
+
+def main():
+    random.seed(1234)
+    args = config()
+    main_logger.info("preparing data")
+    vocab = Vocab(args)
+    dataset = DataSet(args=args, vocab=vocab)
+    train, valid, test, users, user_cont = dataset.get_dialog(task='rec')
+    # Vocab takes args as its first argument (the original call omitted it)
+    vocab = Vocab(args, task='rec')
+    random.shuffle(train)
+    excrs = Upcrrec(vocab=vocab, user_cont=user_cont)
+    engine = Engine(model=excrs, vocab=vocab)
+    engine.train(train, test)
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5e63f0a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,361 @@
+absl-py==1.3.0
+aiohttp==3.8.3
+aiosignal==1.3.1
+alabaster==0.7.12
+antlr4-python3-runtime==4.8
+anyascii==0.3.1
+anykeystore==0.2
+apex @ file:///home/fsx/downloads/apex
+asttokens==2.2.1
+astunparse==1.6.3
+async-timeout==4.0.2
+attrdict==2.0.1
+attrs==20.2.0
+autopep8==2.0.0
+Babel==2.11.0
+backcall==0.2.0
+bcrypt==4.0.1
+bert-score==0.3.13
+better-exceptions==0.3.3
+bitarray==2.8.0
+bleach==6.0.0
+blinker==1.6.2
+blis==0.7.9
+blobfile==2.0.2
+boto3==1.26.77
+botocore==1.29.77
+Bottleneck==1.3.7
+cachetools==5.2.0
+catalogue==2.0.8
+certifi==2022.12.7
+cffi==1.15.1
+charset-normalizer==2.1.1
+click==8.0.4
+cloudpickle==2.2.1
+cmake==3.25.0
+colorama==0.4.6
+coloredlogs==15.0.1
+colorful==0.5.5
+colorlog==6.7.0
+commonmark==0.9.1
+confection==0.0.4
+contourpy==1.0.6
+contractions==0.1.73
+cPython==0.0.6
+cryptacular==1.6.2
+cryptography==39.0.0
+cycler==0.11.0
+cymem==2.0.7
+Cython==0.29.35
+datasets==1.8.0
+decorator==5.1.1
+defusedxml==0.7.1
+dgl==1.0.1+cu116
+dgl-cu116==0.9.1.post1
+dglgo==0.0.2
+dill==0.3.6
+dm-tree==0.1.8
+dnspython==2.3.0
+docformatter==1.5.1
+docker-pycreds==0.4.0
+docopt==0.6.2 +docutils==0.15.2 +dotmap==1.3.30 +embeddings==0.0.8 +emoji==2.2.0 +en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz +entmax==1.1 +exceptiongroup==1.1.0 +executing==1.2.0 +fairscale==0.4.13 +fairseq==0.12.2 +faiss-gpu==1.7.2 +FastNLP==1.0.1 +filelock==3.8.2 +fitlog==0.9.15 +flake8==6.0.0 +flake8-bugbear==23.1.20 +Flask==1.1.1 +flatbuffers==22.12.6 +fonttools==4.38.0 +frozenlist==1.3.3 +fsspec==2022.2.0 +fst-pso==1.8.1 +future==0.18.3 +FuzzyTM==2.0.5 +fuzzywuzzy==0.18.0 +gast==0.4.0 +gensim==4.3.0 +gevent==22.10.2 +gitdb==4.0.10 +gitdb2==4.0.2 +GitPython==3.1.29 +google-api-core==2.10.1 +google-auth==2.15.0 +google-auth-oauthlib==1.0.0 +google-cloud-core==2.3.2 +google-cloud-storage==2.7.0 +google-crc32c==1.5.0 +google-pasta==0.2.0 +google-resumable-media==2.4.1 +googleapis-common-protos==1.56.4 +greenlet==2.0.2 +grpcio==1.51.1 +h5py==3.7.0 +hnswlib==0.7.0 +huggingface-hub==0.16.4 +humanfriendly==10.0 +hupper==1.11 +hydra-core==1.0.7 +idna==3.4 +imageio==2.28.1 +imagesize==1.4.1 +importlib-metadata==4.2.0 +-e git+https://github.com/XiangLi1999/Diffusion-LM.git@759889d58ef38e2eed41a8c34db8032e072826f4 +iniconfig==2.0.0 +install==1.3.5 +iopath==0.1.10 +ipdb==0.13.11 +ipython==8.9.0 +isort==5.10.1 +itsdangerous==1.1.0 +jaraco.classes==3.2.3 +jax==0.4.9 +jedi==0.18.2 +jeepney==0.8.0 +jieba==0.42.1 +Jinja2==2.10.3 +jmespath==1.0.1 +joblib==1.2.0 +jsonlines==3.1.0 +keras==2.12.0 +keyring==24.0.0 +kiwisolver==1.4.4 +langcodes==3.3.0 +language-evaluation @ git+https://github.com/bckim92/language-evaluation.git@8e9e03feb7f5e605f20fc0c8dba0e24532a26216 +lazy_loader==0.2 +libclang==14.0.6 +lightning-utilities==0.8.0 +lit==15.0.7 +littleutils==0.2.2 +loguru==0.7.0 +lxml==4.9.3 +Markdown==3.3.2 +markdown-it-py==0.5.8 +MarkupSafe==2.1.3 +matplotlib==3.6.2 +matplotlib-inline==0.1.6 +mccabe==0.7.0 +miniful==0.0.6 +mip==1.15.0 +ml-dtypes==0.1.0 +mock==5.0.1 +more-itertools==9.1.0 +mpi4py @ file:///home/conda/feedstock_root/build_artifacts/mpi4py_1667459929885/work +mpmath==1.2.1 +multidict==6.0.4 +multiprocess==0.70.14 +munkres==1.1.4 +murmurhash==1.0.9 +myst-parser==0.12.10 +namedlist==1.8 +networkx==2.8.8 +ninja==1.10.2.4 +nltk==3.7 +numpy==1.23.5 +numpydoc==1.5.0 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-ml-py==11.515.75 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +nvitop==1.1.2 +oauthlib==3.2.2 +ogb==1.3.5 +ogb-lite==0.0.3 +omegaconf==2.0.6 +opt-einsum==3.3.0 +outdated==0.2.2 +packaging==22.0 +pandas==1.5.2 +paramiko==2.12.0 +parlai==0.1.0 +parso==0.8.3 +PasteDeploy==3.0.1 +pathlib2==2.3.7.post1 +pathos==0.3.0 +pathtools==0.1.2 +pathy==0.10.1 +pbkdf2==1.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.3.0 +pkginfo==1.9.6 +plaster==1.1.2 +plaster-pastedeploy==1.0.1 +pluggy==1.0.0 +portalocker==2.7.0 +pox==0.3.2 +ppft==1.7.6.6 +preshed==3.0.8 +prettytable==3.8.0 +promise==2.3 +prompt-toolkit==3.0.36 +protobuf==3.20.1 +psutil==5.9.4 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py-gfm==2.0.0 +py-rouge==1.1 +pyahocorasick==2.0.0 +pyarrow==3.0.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.10.0 +pycparser==2.21 +pycryptodomex==3.18.0 +pydantic==1.8.2 +pyDeprecate==0.3.1 +pyflakes==3.0.1 +pyFUME==0.2.25 +Pygments==2.13.0 +pymongo==4.3.3 +PyNaCl==1.5.0 
+pyparsing==3.0.9 +pyramid==2.0.1 +pyramid-mailer==0.15.1 +PySocks==1.7.1 +pytest==7.2.1 +pytest-datadir==1.4.1 +pytest-regressions==2.4.2 +python-dateutil==2.8.2 +python3-openid==3.2.0 +pytorch-ignite==0.4.9 +pytorch-lightning==1.5.4 +pytz==2022.6 +PyWavelets==1.4.1 +PyYAML==6.0 +pyzmq==25.0.0 +rdkit-pypi==2022.9.3 +readme-renderer==40.0 +regex==2022.10.31 +repoze.sendmail==4.4.1 +requests==2.28.1 +requests-mock==1.10.0 +requests-oauthlib==1.3.1 +requests-toolbelt==1.0.0 +responses==0.18.0 +rfc3986==2.0.0 +rich==11.2.0 +rouge==1.0.0 +rsa==4.9 +ruamel.yaml==0.17.21 +ruamel.yaml.clib==0.2.7 +s3transfer==0.6.0 +sacrebleu==2.3.1 +sacremoses==0.0.53 +scikit-image==0.20.0 +scikit-learn==1.2.0 +scipy==1.9.1 +seaborn==0.12.2 +SecretStorage==3.3.3 +sentry-sdk==1.11.1 +setproctitle==1.3.2 +sh==1.14.3 +shortuuid==1.0.11 +simpful==2.9.0 +six==1.16.0 +sklearn==0.0.post1 +smart-open==6.3.0 +smmap==5.0.0 +snowballstemmer==2.2.0 +spacy==3.2.4 +spacy-legacy==3.0.12 +spacy-loggers==1.0.4 +Sphinx==2.2.2 +sphinx-autodoc-typehints==1.10.3 +sphinx-rtd-theme==1.1.1 +sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-serializinghtml==1.1.5 +SQLAlchemy==1.4.46 +srsly==2.4.6 +stack-data==0.6.2 +stanza==1.4.2 +subword-nmt==0.3.8 +sympy==1.11.1 +tabulate==0.9.0 +tensorboard==2.12.3 +tensorboard-data-server==0.7.0 +tensorboard-plugin-wit==1.8.1 +tensorboardX==2.5.1 +tensorflow-addons==0.20.0 +tensorflow-estimator==2.12.0 +tensorflow-io-gcs-filesystem==0.28.0 +tensorflow-probability==0.20.0 +termcolor==2.1.1 +textsearch==0.0.24 +tf-geometric==0.1.5 +tf-sparse==0.0.17 +thinc==8.0.17 +threadpoolctl==3.1.0 +tifffile==2023.4.12 +tokenizers==0.10.3 +tomli==2.0.1 +torch==2.0.1 +torch-geometric==2.2.0 +torch-scatter==2.1.1+pt20cu118 +torch-sparse==0.6.17+pt20cu118 +torchaudio==2.0.2+cu118 +torchdata==0.6.1 +torchdiffeq==0.2.3 +torchdynamo==1.13.0 +torchmetrics==0.11.4 +torchsample==0.1.0 +torchtext==0.15.2 +tornado==6.2 +tqdm==4.49.0 +traitlets==5.9.0 +transaction==3.0.1 +-e git+https://github.com/XiangLi1999/Diffusion-LM.git@759889d58ef38e2eed41a8c34db8032e072826f4 +translationstring==1.4 +triton==2.0.0 +twine==3.8.0 +typeguard==2.13.3 +typer==0.4.2 +typing_extensions==4.4.0 +Unidecode==1.3.6 +untokenize==0.1.1 +urllib3==1.26.13 +velruse==1.1.1 +venusian==3.0.0 +wandb==0.13.6 +wasabi==0.10.1 +wcwidth==0.2.6 +webencodings==0.5.1 +WebOb==1.8.7 +websocket==0.2.1 +websocket-client==1.5.1 +websocket-server==0.6.4 +Werkzeug==2.3.6 +wget==3.2 +wrapt==1.14.1 +WTForms==3.0.1 +wtforms-recaptcha==0.3.2 +xxhash==3.2.0 +yarl==1.8.2 +zipp==3.11.0 +zope.deprecation==4.4.0 +zope.event==4.6 +zope.interface==5.5.2 +zope.sqlalchemy==2.0 diff --git a/rgat.py b/rgat.py new file mode 100644 index 0000000..c9e5c4e --- /dev/null +++ b/rgat.py @@ -0,0 +1,152 @@ +import dgl.function as fn +import torch.nn as nn +import torch +import math + + +class DualRGATLayer(nn.Module): + + def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2): + super(DualRGATLayer, self).__init__() + self.ndim, self.edim = ndim, edim + self.num_heads = num_heads + self.node_update = RGATLayer(self.ndim, self.edim, self.num_heads, feat_drop=feat_drop) + self.edge_update = EdgeRGATLayer(self.edim, self.ndim, self.num_heads, feat_drop=feat_drop) + + def forward(self, node_embed, edge_embed, node_graph, line_graph, src_ids, dst_ids): + out_x, _ = self.node_update(node_embed, edge_embed, node_graph) + src_x = torch.index_select(node_embed, dim=0, index=src_ids) + 
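# endpoint node states for every edge: src_x here, dst_x below; the edge-level RGAT conditions each edge update on its two endpoints
+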
dst_x = torch.index_select(node_embed, dim=0, index=dst_ids) + out_local_lgx, _ = self.edge_update(edge_embed, src_x, dst_x, line_graph) + return out_x, out_local_lgx + + +class RGATLayer(nn.Module): + + def __init__(self, ndim, edim, num_heads=8, feat_drop=0.2): + super(RGATLayer, self).__init__() + self.ndim, self.edim = ndim, edim + self.num_heads = num_heads + dim = max([ndim, edim]) + self.d_k = dim // self.num_heads + self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.ndim, dim), nn.Linear(self.ndim, dim, bias=False), nn.Linear(self.ndim, dim, bias=False) + self.affine_o = nn.Linear(dim, self.ndim) + self.layernorm = nn.LayerNorm(self.ndim) + self.feat_dropout = nn.Dropout(p=feat_drop) + self.ffn = FFN(self.ndim) + + def forward(self, x, lgx, g): + + + q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x)) + e = lgx.view(-1, self.num_heads, self.d_k) if lgx.size(-1) == q.size(-1) else lgx.unsqueeze(1).expand(-1, self.num_heads, -1) + with g.local_scope(): + g.ndata['q'], g.ndata['k'] = q.view(-1, self.num_heads, self.d_k), k.view(-1, self.num_heads, self.d_k) + g.ndata['v'] = v.view(-1, self.num_heads, self.d_k) + g.edata['e'] = e + out_x = self.propagate_attention(g) + + out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k))) + out_x = self.ffn(out_x) + return out_x, lgx + + def propagate_attention(self, g): + + g.apply_edges(src_sum_edge_mul_dst('k', 'q', 'e', 'score')) + g.apply_edges(scaled_exp('score', math.sqrt(self.d_k))) + + g.update_all(src_sum_edge_mul_edge('v', 'e', 'score', 'v'), fn.sum('v', 'wv')) + g.update_all(fn.copy_e('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o')) + out_x = g.ndata['o'] + return out_x + + +class EdgeRGATLayer(nn.Module): + + def __init__(self, edim, ndim, num_heads=8, feat_drop=0.2): + super(EdgeRGATLayer, self).__init__() + self.edim, self.ndim = edim, ndim + self.num_heads = num_heads + self.d_k = self.ndim // self.num_heads + self.affine_q, self.affine_k, self.affine_v = nn.Linear(self.edim, self.ndim), nn.Linear(self.edim, self.ndim, bias=False), nn.Linear(self.edim, self.ndim, bias=False) + self.affine_o = nn.Linear(self.ndim, self.edim) + self.layernorm = nn.LayerNorm(self.edim) + self.feat_dropout = nn.Dropout(p=feat_drop) + self.ffn = FFN(self.edim) + + def forward(self, x, src_x, dst_x, g): + q, k, v = self.affine_q(self.feat_dropout(x)), self.affine_k(self.feat_dropout(x)), self.affine_v(self.feat_dropout(x)) + with g.local_scope(): + g.ndata['q'] = (q + src_x).view(-1, self.num_heads, self.d_k) + g.ndata['k'] = k.view(-1, self.num_heads, self.d_k) + g.ndata['v'] = (v + dst_x).view(-1, self.num_heads, self.d_k) + out_x = self.propagate_attention(g) + out_x = self.layernorm(x + self.affine_o(out_x.view(-1, self.num_heads * self.d_k))) + out_x = self.ffn(out_x) + return out_x, (src_x, dst_x) + + def propagate_attention(self, g): + + g.apply_edges(src_dot_dst('k', 'q', 'score')) + g.apply_edges(scaled_exp('score', math.sqrt(self.d_k))) + + g.update_all(fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv')) + g.update_all(fn.copy_e('score', 'score'), fn.sum('score', 'z'), div_by_z('wv', 'z', 'o')) + out_x = g.ndata['o'] + return out_x + + + +class FFN(nn.Module): + def __init__(self, input_size): + super(FFN, self).__init__() + self.input_size = input_size + self.feedforward = nn.Sequential( + nn.Linear(self.input_size, self.input_size * 4), + nn.ReLU(inplace=True), + nn.Linear(self.input_size * 4, self.input_size) + ) + 
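# post-norm residual block: forward() returns LayerNorm(x + feedforward(x)), the standard position-wise FFN
+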
self.layernorm = nn.LayerNorm(self.input_size) + + def forward(self, inputs): + return self.layernorm(inputs + self.feedforward(inputs)) + + + + +def src_dot_dst(src_field, dst_field, out_field): + def func(edges): + return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)} + + return func + + +def src_sum_edge_mul_dst(src_field, dst_field, e_field, out_field): + def func(edges): + return {out_field: ((edges.src[src_field] + edges.data[e_field]) * edges.dst[dst_field]).sum(-1, keepdim=True)} + + return func + + +def scaled_exp(field, scale_constant): + def func(edges): + + return {field: torch.exp((edges.data[field] / scale_constant).clamp(-10, 10))} + + return func + + +def src_sum_edge_mul_edge(src_field, e_field1, e_field2, out_field): + def func(edges): + return {out_field: (edges.src[src_field] + edges.data[e_field1]) * edges.data[e_field2]} + + return func + + +def div_by_z(in_field, norm_field, out_field): + def func(nodes): + + nodes.data[norm_field][nodes.data[norm_field].eq(0)] = 1 + return {out_field: nodes.data[in_field] / nodes.data[norm_field]} + + return func diff --git a/saved_model/.gitkeep b/saved_model/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tau_scheduler.py b/tau_scheduler.py new file mode 100644 index 0000000..ff89c2d --- /dev/null +++ b/tau_scheduler.py @@ -0,0 +1,39 @@ + + +class TauScheduler: + def __init__(self, origin_tau, mini_tau, n_step, s_step=0): + self.origin_tau = origin_tau + self.mini_tau = mini_tau + self.n_step = n_step + self.s_step = s_step + self.step_interval = (self.mini_tau - self.origin_tau) / self.n_step + + def step_on(self, advance=False): + if advance: + self.s_step += 1 + if self.s_step >= self.n_step: + return self.mini_tau + return self.s_step * self.step_interval + self.origin_tau + + def dump(self): + self_dict = { + "origin_tau": self.origin_tau, + "mini_tau": self.mini_tau, + "n_step": self.n_step, + "s_step": self.s_step, + } + return self_dict + + @staticmethod + def load(self_dict): + return TauScheduler(self_dict["origin_tau"], + self_dict["mini_tau"], + self_dict["n_step"], + self_dict["s_step"]) + + def self_load(self, self_dict): + self.origin_tau = self_dict["origin_tau"] + self.mini_tau = self_dict["mini_tau"] + self.n_step = self_dict["n_step"] + self.s_step = self_dict["s_step"] + self.step_interval = (self.mini_tau - self.origin_tau) / self.n_step diff --git a/tools.py b/tools.py new file mode 100644 index 0000000..e0ced07 --- /dev/null +++ b/tools.py @@ -0,0 +1,96 @@ +import torch +import functools + +class Tools(): + @staticmethod + def one_hot(tensor, n_vocab): + shape = list(tensor.shape) + shape = shape + [n_vocab] + new_tensor = tensor.new_zeros(*shape, dtype=torch.float) + new_tensor = new_tensor.scatter(dim=-1,index=tensor.unsqueeze(-1), + src=torch.ones_like(tensor, dtype=torch.float).unsqueeze(-1)) + return new_tensor + + @staticmethod + def _single_decode(input_seq, src_hiddens, src_mask, decoder, input_mask=None, ret_last_step=True): + batch_size = input_seq.size(0) + trg_seq_mask = Tools.get_subsequent_mask(input_seq) + trg_seq_mask = trg_seq_mask.expand(batch_size, -1, -1) + if input_mask is not None: + trg_seq_mask = input_mask & trg_seq_mask + dec_output = decoder(input_seq, trg_seq_mask, src_hiddens, src_mask) + if ret_last_step: + last_step_dec_output = dec_output[:, -1, :].unsqueeze(1) + return last_step_dec_output + else: + return dec_output + + @staticmethod + def get_subsequent_mask(seq): + sz_b, len_s = seq.size(0), seq.size(1) + subsequent_mask 
= (1 - torch.triu(torch.ones((1, len_s, len_s)), diagonal=1)).bool() + subsequent_mask = subsequent_mask.cuda() + return subsequent_mask + + @staticmethod + def _generate_init(batch_size, n_vocab, trg_bos_idx, training=True): + ret = torch.ones(batch_size, 1, dtype=torch.long) * trg_bos_idx + if training : + ret = Tools.one_hot(ret, n_vocab) + ret = ret.cuda() + return ret + + @staticmethod + def get_mask_via_len(length, max_len): + B = length.size(0) + mask = torch.ones([B, max_len]).cuda() + mask = torch.cumsum(mask, 1) + mask = mask <= length.unsqueeze(1) + mask = mask.unsqueeze(-2) + return mask + + @staticmethod + def nested_index_select(origin_data, select_index): + origin_data_shape = list(origin_data.shape) + select_index_shape = list(select_index.shape) + work_axes = len(select_index_shape) - 1 + grad_v = functools.reduce(lambda x, y: x * y, origin_data_shape[:work_axes]) + new_dim = select_index_shape[-1] + grad = torch.arange(0, grad_v, dtype=torch.long).unsqueeze(-1) + grad = grad.expand(-1, new_dim) + grad = grad.reshape(-1) + grad = grad * origin_data_shape[work_axes] + select_index = select_index.reshape(-1) + grad + reshaped_data = origin_data.reshape(grad_v * origin_data_shape[work_axes], -1) + selected_data = reshaped_data.index_select(0, select_index) + origin_data_shape[work_axes] = new_dim + selected_data = selected_data.reshape(origin_data_shape) + return selected_data + + @staticmethod + def repeat_penalty(dist, pad_idx=None): + L, V = dist.size(1), dist.size(2) + diag = torch.ones(L, dtype=torch.float) + mask = torch.ones(L, L, dtype=torch.float) - torch.diag_embed(diag) + mask = mask.unsqueeze(0).cuda() + eps = 1e-9 + dist1 = dist.unsqueeze(2).expand(-1, -1, L, -1) + dist2 = dist.unsqueeze(1).expand(-1, L, -1, -1) + pad_mask = torch.ones(1, 1, 1, V, dtype=torch.float) + if pad_idx is not None: + pad_mask[:, :, :, pad_idx] = .0 + pad_mask = pad_mask.cuda() + kl = (dist1 * torch.log(dist1 / (dist2 + eps) + eps) * pad_mask).sum(-1) + kl = (kl * mask).sum(-1).sum(-1) / (L * max((L - 1), 1)) + return - kl.mean() + + @staticmethod + def entropy_restrain(dist): + eps = 1e-9 + if len(dist.shape) == 3: + B, L, V = dist.shape + dist = dist.reshape(-1, V) + else: + B = dist.size(1) + entropy = (dist * torch.log(dist + eps)).sum() / B + return - entropy \ No newline at end of file diff --git a/topic.py b/topic.py new file mode 100644 index 0000000..107a0ce --- /dev/null +++ b/topic.py @@ -0,0 +1,217 @@ +import argparse +import os.path +import random +import torch +from DataProcessor import DataSet +from get_logger import get_logger +from get_logger import task_uuid +from Vocab import Vocab +from upcrtopic import Upcrtopic, Engine +import numpy as np +from torch.utils.data import DataLoader +from DataLoaderTopic import collate_fn + +main_logger = get_logger("main", './log/test.log') +main_logger.info("TASK ID {}".format(task_uuid)) + +def config(): + parser = argparse.ArgumentParser() + parser.add_argument("-test", "--test", action="store_true") + parser.add_argument("--test_data", choices=["test", "valid"], default="test") + parser.add_argument("-dataset", "--dataset", choices=["TG-ReDial", "PersonaChat"], default="TG-ReDial") + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--super_rate", type=float, default=0.) 
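+    # Runtime flags: device/GPU selection, seeding, checkpoint loading, and the
+    # preprocessing / ablation switches declared below.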
+ parser.add_argument("--device", type=int, default=-1) + parser.add_argument("-seed", "--seed", type=int, default=1234) + parser.add_argument("--ckpt", type=str, default=None) + parser.add_argument('--inference', type=bool, default=False, ) + parser.add_argument("-use_cuda", "--use_cuda", type=bool, default=False) + parser.add_argument("-gpu", "--gpu", type=int, default=0) + parser.add_argument("-processed", "--processed", action='store_false', ) + parser.add_argument("-not_topic_guide", "--not_topic_guide", action='store_true', ) + parser.add_argument("--not_copynet", action='store_true', ) + + + parser.add_argument("-get_personas", "--get_personas", action='store_true', ) + + parser.add_argument("-global_topic", "--global_topic", action='store_true', ) + parser.add_argument("-global_topic_for_action", "--global_topic_for_action", action='store_true', ) + parser.add_argument("-top_p", "--top_p", type=float, default=0., ) + parser.add_argument("-profile_agg", "--profile_agg", action='store_true', ) + parser.add_argument("-otp", "--otp", action='store_true', ) + parser.add_argument("-s_profile_add_t", "--s_profile_add_t", action='store_true', ) + parser.add_argument("-topic_copynet", "--topic_copynet", action='store_true', ) + + parser.add_argument("-profile_pred", "--profile_pred", action='store_true', ) + parser.add_argument("-profile_select", "--profile_select", action='store_true', ) + parser.add_argument("-sharping_profile", "--sharping_profile", action='store_true', ) + parser.add_argument("-profile_contrast", "--profile_contrast", action='store_true', ) + + parser.add_argument("-gene_add_profile", "--gene_add_profile", action='store_true', + ) + parser.add_argument("-gpt2", "--gpt2", action='store_true', ) + + parser.add_argument("-history_turn", "--history_turn", type=int, default=100, ) + parser.add_argument("-global_topics", "--global_topics", type=int, default=10, ) + parser.add_argument("-early_stop_num", "--early_stop_num", type=int, default=5, ) + + + parser.add_argument("-use_ckg", "--use_ckg", type=int, default=0, ) + parser.add_argument("-random_persona", "--random_persona", type=int, default=0, ) + parser.add_argument("-use_co_occurrence", "--use_co_occurrence", type=int, default=0, ) + parser.add_argument("-no_user_emb", "--no_user_emb", type=int, default=0, ) + parser.add_argument("-no_learn", "--no_learn", type=int, default=0, ) + parser.add_argument("-infoNCE_num", "--infoNCE_num", type=int, default=0, ) + + + + parser.add_argument('--n_layers', type=int, default=6) + parser.add_argument('--n_position', type=int, default=160) + parser.add_argument('--n_inner_vocab', type=int, default=5000) + parser.add_argument('--n_inner_layers', type=int, default=3) + parser.add_argument('--n_inner_position', type=int, default=15) + parser.add_argument('--d_word_vec', type=int, default=512) + parser.add_argument('--n_head', type=int, default=8) + parser.add_argument('--d_k', type=int, default=64) + parser.add_argument('--d_v', type=int, default=64) + parser.add_argument('--pad_idx', type=int, default=2) + parser.add_argument('--d_model', type=int, default=512) + parser.add_argument('--d_inner', type=int, default=2048) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--n_warmup_steps', type=int, default=2000) + parser.add_argument('--scale_emb', type=bool, default=False) + parser.add_argument('--switch_interval', type=int, default=16) + parser.add_argument('--cache_turn', type=int, default=0) + parser.add_argument('--context_all_max_len', type=int, 
default=1024) + parser.add_argument('--context_max_len', type=int, default=150) + parser.add_argument('--r_max_len', type=int, default=50) + parser.add_argument('--r_beam_max_len', type=int, default=30) + parser.add_argument('--conv_max_len', type=int, default=500) + parser.add_argument('--profile_num', type=int, default=1) + parser.add_argument('--state_num', type=int, default=20) + parser.add_argument('--state_num_redial', type=int, default=20) + parser.add_argument('--pretrain_state_num', type=int, default=50) + parser.add_argument('--all_topic_num', type=int, default=20) + parser.add_argument('--all_topic_num_redial', type=int, default=40) + parser.add_argument('--movie_path_len', type=int, default=3) + parser.add_argument('--tag_num', type=int, default=3) + parser.add_argument('--preference_num', type=int, default=1) + parser.add_argument('--topic_num', type=int, default=2) + parser.add_argument('--action_num', type=int, default=10) + parser.add_argument('--action_num_redial', type=int, default=1) + parser.add_argument('--relation_num', type=int, default=150) + parser.add_argument('--movie_num', type=int, default=200) + parser.add_argument('--state_token', type=int, default=40) + parser.add_argument('--scale_prj', type=bool, default=True) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--task', type=str, default="meddg") + parser.add_argument('--dataset_file', type=str, default="./dataset/{}.zip") + parser.add_argument('--topic_file', type=str, default="./dataset/{}/topic.txt") + parser.add_argument('--topic_movie_file', type=str, default="./dataset/{}/tpmv.txt") + parser.add_argument('--vocab_file', type=str, default="./dataset/{}/tpvocab.txt") + parser.add_argument('--vocab_movie_file', type=str, default="./dataset/{}/tpmvvocab.txt") + parser.add_argument('--no_action_super', type=str, default=None) + parser.add_argument('--max_patience', type=int, default=20) + parser.add_argument('--log_loss_interval', type=int, default=100) + parser.add_argument('--gradient_stack', type=int, default=80) + parser.add_argument('--decay_interval', type=int, default=10000) + parser.add_argument('--decay_rate', type=float, default=0.9) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--valid_eval_interval', type=int, default=10000) + parser.add_argument('--test_eval_interval', type=int, default=10000) + parser.add_argument('--force_ckpt_dump', action='store_true') + parser.add_argument('--sub_gen_lambda', type=float, default=0.01) + parser.add_argument('--s_copy_lambda', type=int, default=1) + parser.add_argument('--a_copy_lambda', type=int, default=1) + parser.add_argument('--copy_lambda_mini', type=float, default=0.1) + parser.add_argument('--copy_lambda_decay_steps', type=int, default=10000) + parser.add_argument('--copy_lambda_decay_value', type=float, default=1.0) + parser.add_argument('--init_tau', type=float, default=1.0) + parser.add_argument('--tau_mini', type=float, default=0.1) + parser.add_argument('--tau_decay_total_steps', type=int, default=5000) + parser.add_argument('--tau_decay_rate', type=float, default=0.5) + parser.add_argument('--beam_width', type=int, default=1) + parser.add_argument('--wo_l', action='store_true') + parser.add_argument('--wo_m', action='store_true') + parser.add_argument('--wo_entropy_restrain', action='store_true') + parser.add_argument('--wo_repeat_penalty', action='store_true') + parser.add_argument('--wo_rl', action='store_true') + parser.add_argument('--super_only', action='store_true') + 
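+    # Training-scheme knobs (Hungarian matching for the KL terms, batch size,
+    # regularization weight), followed by the special-token literals, which are
+    # expected to match entries in the dataset vocabulary files.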
+    parser.add_argument('--hungary', action='store_true')
+    parser.add_argument('--super_epoch', type=int, default=5)
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--reg_lambda', type=float, default=5e-3)
+    parser.add_argument('--BOS_CONTEXT', type=str, default="[s_context]")
+    parser.add_argument('--EOS_CONTEXT', type=str, default="[/s_context]")
+    parser.add_argument('--BOS_RESPONSE', type=str, default="[s_response>]")
+    parser.add_argument('--EOS_RESPONSE', type=str, default="[/s_response]")
+    parser.add_argument('--BOS_ACTION', type=str, default="[s_action]")
+    parser.add_argument('--EOS_ACTION', type=str, default="[/s_action]")
+    parser.add_argument('--PAD_WORD', type=str, default="[PAD]")
+    parser.add_argument('--SENTENCE_SPLITER', type=str, default="[sent]")
+    parser.add_argument('--TOPIC_SPLITER', type=str, default="[unused2]")
+    parser.add_argument('--UNK_WORD', type=str, default="[UNK]")
+    parser.add_argument('--BOS_PRE', type=str, default="[s_preference]")
+    parser.add_argument('--EOS_PRE', type=str, default="[/s_preference]")
+    parser.add_argument('--BOS_PRO', type=str, default="[s_profile]")
+    parser.add_argument('--EOS_PRO', type=str, default="[/s_profile]")
+    args = parser.parse_args()
+
+    args.dataset_file = args.dataset_file.format(args.dataset)
+    args.topic_file = args.topic_file.format(args.dataset)
+    args.topic_movie_file = args.topic_movie_file.format(args.dataset)
+    args.vocab_file = args.vocab_file.format(args.dataset)
+    args.vocab_movie_file = args.vocab_movie_file.format(args.dataset)
+    return args
+
+def set_seed(seed):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # cudnn.benchmark autotunes non-deterministic kernels, which would defeat
+    # the deterministic flag below, so it is kept off for reproducible runs.
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+def main():
+    args = config()
+    print(args)
+    set_seed(args.seed)
+    main_logger.info("preparing data")
+    vocab = Vocab(args)
+    dataset = DataSet(args=args, vocab=vocab)
+
+    train_set, valid_set, test_set, users, user_cont = dataset.get_dialog(task='topic')
+
+    train_loader = DataLoader(train_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=True)
+    valid_loader = DataLoader(valid_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False)
+    test_loader = DataLoader(test_set, batch_size=args.batch_size, collate_fn=collate_fn, pin_memory=True, num_workers=args.num_workers, shuffle=False)
+
+    excrs_topic = Upcrtopic(args=args, vocab=vocab, user_cont=user_cont, train_set=train_set, d_word_vec=768, d_model=768)
+
+    all_param = 0
+    for param in excrs_topic.parameters():
+        all_param += np.prod(param.size())
+    print('all_param: ', all_param)
+
+    engine = Engine(args=args, model=excrs_topic, vocab=vocab)
+    if not os.path.exists('saved_model'):
+        os.mkdir('saved_model')
+    if args.test:
+        engine.model.load_state_dict(torch.load('saved_model/best_topic_model_{}.pkl'.format(args.dataset)), strict=False)
+        engine.test(test_loader)
+    else:
+        print('args:', args)
+        engine.train(train_loader, valid_loader, test_loader, args.early_stop_num)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/topic_with_early_stop.py b/topic_with_early_stop.py
new file mode 100644
index 0000000..6d2abfc
--- /dev/null
+++ b/topic_with_early_stop.py
@@ -0,0 +1,39 @@
+import argparse
+import random
+from DataProcessor import DataSet
+from get_logger import get_logger
+from
get_logger import task_uuid +from Vocab import Vocab +from upcrtopic import Upcrtopic,Engine + +main_logger = get_logger("main", './log/test.log') +main_logger.info("TASK ID {}".format(task_uuid)) + +def config(): + parser = argparse.ArgumentParser() + parser.add_argument("--test", action="store_true") + parser.add_argument("--test_data", choices=["test", "valid"], default="test") + parser.add_argument("--super_rate", type=float, default=0.) + parser.add_argument("--device", type=int, default=-1) + parser.add_argument("--ckpt", type=str, default=None) + parser.add_argument('--inference',type=bool,default=False,) + parser.add_argument("-use_cuda", "--use_cuda", type=bool, default=False) + parser.add_argument("-gpu", "--gpu", type=str, default='1') + parser.add_argument("--processed", type=bool, default=True, ) + args = parser.parse_args() + return args + +def main(): + random.seed(1234) + args = config() + main_logger.info("preparing data") + dataset = DataSet(args=args) + train, valid, test, users, user_cont = dataset.get_dialog(task='topic') + vocab = Vocab() + random.shuffle(train) + excrs_topic = Upcrtopic(vocab=vocab, user_cont=user_cont) + engine = Engine(model=excrs_topic, vocab=vocab) + engine.train(train, valid, test) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/transformer/Constants.py b/transformer/Constants.py new file mode 100644 index 0000000..9a8b7a2 --- /dev/null +++ b/transformer/Constants.py @@ -0,0 +1,17 @@ + +BOS_WORD = '' +SOS_WORD = BOS_WORD +EOS_WORD = '' +PAD_WORD = '' +UNK_WORD = '' +R_BOS_WORD = "" +R_EOS_WORD = "" +SENTENCE_SPLITER = "" + + +STATE_BOS_WORD = "" +STATE_EOS_WORD = "" +ACTION_BOS_WORD = "" +ACTION_EOS_WORD = "" + + diff --git a/transformer/Layers.py b/transformer/Layers.py new file mode 100644 index 0000000..6a3fffe --- /dev/null +++ b/transformer/Layers.py @@ -0,0 +1,43 @@ + +import torch +import torch.nn as nn + +from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward + +__author__ = "Yu-Hsiang Huang" + + +class EncoderLayer(nn.Module): + + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) + + enc_output = self.pos_ffn(enc_output) + + return enc_output, enc_slf_attn + + +class DecoderLayer(nn.Module): + + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(DecoderLayer, self).__init__() + + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None): + dec_output, dec_slf_attn = self.slf_attn(dec_input, dec_input, dec_input, mask=slf_attn_mask) + + dec_output, dec_enc_attn = self.enc_attn(dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + + dec_output = self.pos_ffn(dec_output) + + return dec_output, dec_slf_attn, dec_enc_attn diff --git a/transformer/Models.py b/transformer/Models.py new file mode 100644 index 0000000..f1a3d5c --- /dev/null +++ b/transformer/Models.py @@ -0,0 +1,247 @@ + +import torch +import torch.nn as nn +import 
numpy as np +from transformer.Layers import EncoderLayer, DecoderLayer +import ipdb +__author__ = "Yu-Hsiang Huang" + + +def get_pad_mask(seq, pad_idx): + return (seq != pad_idx).unsqueeze(-2) + + +def get_subsequent_mask(seq): + + sz_b, len_s = seq.size() + subsequent_mask = (1 - torch.triu( + torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() + + + + + return subsequent_mask + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid, n_position=200): + super(PositionalEncoding, self).__init__() + + self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) + + @staticmethod + def _get_sinusoid_encoding_table(n_position, d_hid): + + + + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def forward(self, x): + + return x + self.pos_table[:, :x.size(1)].clone().detach() + + +class Encoder(nn.Module): + + + def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v, + d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False, word_emb=None, sent_id=0): + + super().__init__() + + if word_emb is not None: + self.src_word_emb = word_emb + else: + self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx) + + self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.scale_emb = scale_emb + self.d_model = d_model + self.sent_id = sent_id + + def forward(self, src_seq, src_mask, return_attns=False, embed=None, only_last=False, only_sent=False, embed_input=False): + + if embed is not None: + assert self.src_word_emb.num_embeddings == embed.shape[0] + + enc_slf_attn_list = [] + if not embed_input: + + + if len(src_seq.shape) == 2: + if embed is None: + enc_output = self.src_word_emb(src_seq) + else: + enc_output = embed[src_seq] + + elif len(src_seq.shape) == 3: + if src_seq.size(2) == self.d_model: + enc_output = src_seq + else: + batch_size = src_seq.size(0) + if embed is None: + enc_output = torch.bmm(src_seq, self.src_word_emb.weight.unsqueeze(0).expand(batch_size, -1, -1)) + else: + enc_output = torch.bmm(src_seq, embed.unsqueeze(0).expand(batch_size, -1, -1)) + else: + raise RuntimeError + else: + enc_output = src_seq + enc_output.cuda() + if self.scale_emb: + enc_output *= self.d_model ** 0.5 + + enc_output = self.dropout(self.position_enc(enc_output)) + enc_output = self.layer_norm(enc_output) + + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask) + enc_slf_attn_list += [enc_slf_attn] if return_attns else [] + + if only_sent: + sent_embeds = enc_output[src_seq.eq(self.sent_id)].split(src_seq.eq(self.sent_id).sum(dim=-1).tolist(), dim=0) + enc_output = nn.utils.rnn.pad_sequence(sent_embeds, batch_first=True) + + if only_last: + seq_len = src_mask.squeeze(1).sum(-1) + bsz = seq_len.shape[0] + enc_output = enc_output[torch.arange(bsz), seq_len - 1, :] + + if return_attns: + return enc_output, enc_slf_attn_list + + return enc_output + + 
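+# A minimal usage sketch with invented sizes (an illustration, not part of the
+# patch's training pipeline): the Encoder above accepts either hard token ids
+# of shape (B, L) or a "soft" one-hot / Gumbel-softmax distribution of shape
+# (B, L, V), which is routed through the bmm(src_seq, embedding-weight) branch
+# of forward(). A CUDA device is assumed, matching the hard-coded .cuda()
+# calls in this file.
+def _demo_soft_input_encoder():
+    enc = Encoder(n_src_vocab=100, d_word_vec=512, n_layers=2, n_head=8,
+                  d_k=64, d_v=64, d_model=512, d_inner=2048, pad_idx=0).cuda()
+    soft_seq = torch.softmax(torch.randn(4, 10, 100), dim=-1).cuda()  # (B, L, V)
+    src_mask = torch.ones(4, 1, 10, dtype=torch.bool).cuda()          # (B, 1, L)
+    out = enc(soft_seq, src_mask)
+    assert out.shape == (4, 10, 512)  # (B, L, d_model)
+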
+class Decoder(nn.Module):
+
+    def __init__(
+            self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
+            d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False,
+            word_emb=None):
+
+        super().__init__()
+
+        if word_emb is not None:
+            self.trg_word_emb = word_emb
+        else:
+            self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
+
+        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
+        self.dropout = nn.Dropout(p=dropout)
+        self.layer_stack = nn.ModuleList([
+            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+            for _ in range(n_layers)])
+        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+        self.scale_emb = scale_emb
+        self.d_model = d_model
+
+    def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
+
+        dec_slf_attn_list, dec_enc_attn_list = [], []
+
+        # hard token ids (B, L) or soft distributions (B, L, V) over the vocab
+        if len(trg_seq.shape) == 2:
+            dec_output = self.trg_word_emb(trg_seq)
+        elif len(trg_seq.shape) == 3:
+            batch_size = trg_seq.size(0)
+            dec_output = torch.bmm(trg_seq, self.trg_word_emb.weight.unsqueeze(0).expand(batch_size, -1, -1).to(torch.device("cuda:0")))
+        else:
+            raise RuntimeError
+
+        if self.scale_emb:
+            dec_output *= self.d_model ** 0.5
+
+        dec_output = self.dropout(self.position_enc(dec_output))
+        dec_output = self.layer_norm(dec_output)
+
+        for dec_layer in self.layer_stack:
+            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
+            dec_slf_attn_list += [dec_slf_attn] if return_attns else []
+            dec_enc_attn_list += [dec_enc_attn] if return_attns else []
+
+        if return_attns:
+            return dec_output, dec_slf_attn_list, dec_enc_attn_list
+
+        return dec_output
+
+
+class Transformer(nn.Module):
+
+    def __init__(
+            self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx,
+            d_word_vec=512, d_model=512, d_inner=2048,
+            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200,
+            trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True,
+            scale_emb_or_prj='prj'):
+
+        super().__init__()
+        assert scale_emb_or_prj in ['emb', 'prj', 'none']
+        scale_emb = (scale_emb_or_prj == 'emb') if trg_emb_prj_weight_sharing else False
+        self.scale_prj = (scale_emb_or_prj == 'prj') if trg_emb_prj_weight_sharing else False
+        self.d_model = d_model
+        # kept on the module because forward() needs them to build the masks
+        self.src_pad_idx = src_pad_idx
+        self.trg_pad_idx = trg_pad_idx
+
+        self.encoder = Encoder(
+            n_src_vocab=n_src_vocab, n_position=n_position,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=src_pad_idx, dropout=dropout, scale_emb=scale_emb)
+
+        self.decoder = Decoder(
+            n_trg_vocab=n_trg_vocab, n_position=n_position,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=trg_pad_idx, dropout=dropout, scale_emb=scale_emb)
+
+        self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+        assert d_model == d_word_vec, \
+            'To facilitate the residual connections, \
+             the dimensions of all module outputs shall be the same.'
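+        # Weight tying (Press & Wolf, "Using the Output Embedding to Improve
+        # Language Models"): the target embedding doubles as the pre-softmax
+        # projection and can additionally be shared with the source embedding;
+        # scale_emb_or_prj picks where the compensating sqrt(d_model) scaling
+        # is applied.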
+ + if trg_emb_prj_weight_sharing: + + self.trg_word_prj.weight = self.decoder.trg_word_emb.weight + + if emb_src_trg_weight_sharing: + self.encoder.src_word_emb.weight = self.decoder.trg_word_emb.weight + + def forward(self, src_seq, trg_seq): + + + + + + src_mask = get_pad_mask(src_seq, self.src_pad_idx) + trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) + + enc_output, *_ = self.encoder(src_seq, src_mask) + dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask) + + seq_logit = self.trg_word_prj(dec_output) + + if self.scale_prj: + seq_logit *= self.d_model ** -0.5 + + return seq_logit.view(-1, seq_logit.size(2)) diff --git a/transformer/Modules.py b/transformer/Modules.py new file mode 100644 index 0000000..fd18958 --- /dev/null +++ b/transformer/Modules.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__author__ = "Yu-Hsiang Huang" + + +class ScaledDotProductAttention(nn.Module): + + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, -1e9 if attn.dtype==torch.float32 else -1e4) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn diff --git a/transformer/Optim.py b/transformer/Optim.py new file mode 100644 index 0000000..774b73e --- /dev/null +++ b/transformer/Optim.py @@ -0,0 +1,45 @@ +import ipdb + + +class ScheduledOptim(): + + + def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps): + self._optimizer = optimizer + self.lr_mul = lr_mul + self.d_model = d_model + self.n_warmup_steps = n_warmup_steps + self.n_steps = 0 + self.param_groups = optimizer.param_groups + + def step(self): + "Step with the inner optimizer" + + + + + self._update_learning_rate() + self._optimizer.step() + + def update_step(self, global_step): + self.n_steps = global_step + + def zero_grad(self): + "Zero out the gradients with the inner optimizer" + self._optimizer.zero_grad() + + def _get_lr_scale(self): + d_model = self.d_model + n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps + return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)) + + def _update_learning_rate(self): + + + self.n_steps += 1 + lr = self.lr_mul * self._get_lr_scale() + + self._optimizer.param_groups[0]['lr'] = lr + + if len(self._optimizer.param_groups)>1: + self._optimizer.param_groups[1]['lr'] = 0.6 * lr \ No newline at end of file diff --git a/transformer/SubLayers.py b/transformer/SubLayers.py new file mode 100644 index 0000000..76f04e7 --- /dev/null +++ b/transformer/SubLayers.py @@ -0,0 +1,84 @@ +import torch.nn as nn +import torch.nn.functional as F +from transformer.Modules import ScaledDotProductAttention +import torch + +__author__ = "Yu-Hsiang Huang" + + +class MultiHeadAttention(nn.Module): + + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + + self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) + + self.dropout = nn.Dropout(dropout) 
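+        # Post-norm residual wiring: forward() adds the residual first, then
+        # applies this LayerNorm (q += residual; q = layer_norm(q)).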
+        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+    def forward(self, q, k, v, mask=None):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+        residual = q
+
+        # separate heads: (B, L, n_head * d) -> (B, L, n_head, d)
+        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+        # transpose for attention: (B, n_head, L, d)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        if mask is not None:
+            mask = mask.unsqueeze(1)  # broadcast over the head dimension
+
+        q, attn = self.attention(q, k, v, mask=mask)
+
+        # merge heads back: (B, L, n_head * d_v)
+        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+        q = self.dropout(self.fc(q))
+
+        q += residual
+        q = self.layer_norm(q)
+
+        return q, attn
+
+
+class PositionwiseFeedForward(nn.Module):
+
+    def __init__(self, d_in, d_hid, dropout=0.1):
+        super().__init__()
+        self.w_1 = nn.Linear(d_in, d_hid)
+        self.w_2 = nn.Linear(d_hid, d_in)
+        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+
+        x = self.w_2(F.relu(self.w_1(x)))
+        x = self.dropout(x)
+
+        x += residual
+        x = self.layer_norm(x)
+
+        return x
diff --git a/transformer/Translator.py b/transformer/Translator.py
new file mode 100644
index 0000000..0843087
--- /dev/null
+++ b/transformer/Translator.py
@@ -0,0 +1,113 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformer.Models import Transformer, get_pad_mask, get_subsequent_mask
+
+
+class Translator(nn.Module):
+
+    def __init__(
+            self, model, beam_size, max_seq_len,
+            src_pad_idx, trg_pad_idx, trg_bos_idx, trg_eos_idx):
+
+        super(Translator, self).__init__()
+
+        self.alpha = 0.7
+        self.beam_size = beam_size
+        self.max_seq_len = max_seq_len
+        self.src_pad_idx = src_pad_idx
+        self.trg_bos_idx = trg_bos_idx
+        self.trg_eos_idx = trg_eos_idx
+
+        self.model = model
+        self.model.eval()
+
+        self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))
+        self.register_buffer(
+            'blank_seqs',
+            torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long))
+        self.blank_seqs[:, 0] = self.trg_bos_idx
+        self.register_buffer(
+            'len_map',
+            torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0))
+
+    def _model_decode(self, trg_seq, enc_output, src_mask):
+        trg_mask = get_subsequent_mask(trg_seq)
+        dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
+        return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)
+
+    def _get_init_state(self, src_seq, src_mask):
+        beam_size = self.beam_size
+
+        # the Transformer imported above exposes its encoder as `model.encoder`
+        enc_output, *_ = self.model.encoder(src_seq, src_mask)
+        dec_output = self._model_decode(self.init_seq, enc_output, src_mask)
+
+        best_k_probs, best_k_idx = dec_output[:, -1, :].topk(beam_size)
+
+        scores = torch.log(best_k_probs).view(beam_size)
+        gen_seq = self.blank_seqs.clone().detach()
+        gen_seq[:, 1] = best_k_idx[0]
+        enc_output = enc_output.repeat(beam_size, 1, 1)
+        return enc_output, gen_seq, scores
+
+    def _get_the_best_score_and_idx(self, gen_seq, dec_output, scores, step):
+        assert len(scores.size()) == 1
+
+        beam_size = self.beam_size
+
+        # candidate tokens: the best k continuations of each of the k beams
+        best_k2_probs, best_k2_idx = dec_output[:, -1, :].topk(beam_size)
+
+        # accumulated log-probability of each of the k * k candidates
+        scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(beam_size, 1)
+
+        # keep the best k candidates overall
+        scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
+
+        best_k_r_idxs, best_k_c_idxs = best_k_idx_in_k2 // beam_size, best_k_idx_in_k2 % beam_size
+        best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
+
+        gen_seq[:, :step] = gen_seq[best_k_r_idxs,
:step] + + gen_seq[:, step] = best_k_idx + + return gen_seq, scores + + def translate_sentence(self, src_seq): + + + + assert src_seq.size(0) == 1 + + src_pad_idx, trg_eos_idx = self.src_pad_idx, self.trg_eos_idx + max_seq_len, beam_size, alpha = self.max_seq_len, self.beam_size, self.alpha + + with torch.no_grad(): + src_mask = get_pad_mask(src_seq, src_pad_idx) + enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask) + ans_idx = 0 + + for step in range(2, max_seq_len): + dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask) + gen_seq, scores = self._get_the_best_score_and_idx(gen_seq, dec_output, scores, step) + + + + eos_locs = gen_seq == trg_eos_idx + + seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1) + + + if (eos_locs.sum(1) > 0).sum(0).item() == beam_size: + + + _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0) + ans_idx = ans_idx.item() + break + + return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist() diff --git a/transformer/__init__.py b/transformer/__init__.py new file mode 100644 index 0000000..6ce696c --- /dev/null +++ b/transformer/__init__.py @@ -0,0 +1,12 @@ +import transformer.Constants +import transformer.Modules +import transformer.Layers +import transformer.SubLayers +import transformer.Models +import transformer.Translator +import transformer.Optim + +__all__ = [ + transformer.Constants, transformer.Modules, transformer.Layers, + transformer.SubLayers, transformer.Models, transformer.Optim, + transformer.Translator] diff --git a/upcrgene.py b/upcrgene.py new file mode 100644 index 0000000..dd8ee63 --- /dev/null +++ b/upcrgene.py @@ -0,0 +1,400 @@ +import torch.nn as nn +import torch +from transformer.Models import Encoder +from transformer.Models import Decoder +from gumbel_softmax import GumbelSoftmax +from tau_scheduler import TauScheduler +from transformers import GPT2LMHeadModel +from ProfileTopic import UserContrast +from tools import Tools +from Response import Response +from Vocab import Vocab +import torch.nn.functional as F +from transformer.Optim import ScheduledOptim +import Bleu +import distinct +from upcrtopic import get_time +from tqdm import tqdm +import sys +import math +from persona_eval import cal_p_cover, cal_p_f1 + + +class Upcrgene(nn.Module): + def __init__(self, args, vocab:Vocab, user_cont, train_set, n_layers=6, p_layers=3, + d_word_vec=768, d_model=768, d_inner=3072, beam_width=1, + n_head=8, d_k=64, d_v=64, dropout=0.1): + super(Upcrgene, self).__init__() + self.args = args + self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + self.vocab = vocab + self.glo2loc, self.loc2glo = vocab.vocab_transfer() + self.glo2loc = torch.tensor(self.glo2loc).cuda() + self.loc2glo = torch.tensor(self.loc2glo).cuda() + self.topic_num = vocab.topic_num() + self.user_num = vocab.n_user + 1 + self.character_num = vocab.n_character + self.word_vocab, self.word_len, self.topic_vocab, self.topic_len = vocab.get_vocab(task='gene') + self.word_pad_idx = vocab.get_word_pad() + self.topic_pad_idx = vocab.get_topic_pad() + self.r_bos_idx = vocab.word2index(self.args.BOS_RESPONSE) + self.r_eos_idx = vocab.word2index(self.args.EOS_RESPONSE) + self.beam_width = beam_width + self.word_emb = nn.Embedding(len(self.vocab.tokenizer) if self.args.gpt2 else self.word_len, d_word_vec, padding_idx=self.word_pad_idx) + self.topic_emb = nn.Embedding(self.topic_len, d_word_vec, padding_idx=self.topic_pad_idx) + self.user_emb = nn.Embedding(self.user_num, d_word_vec).to(device=self.device) + self.character_emb = 
nn.Embedding(self.character_num, d_word_vec).to(device=self.device) + self.gumbel_softmax = GumbelSoftmax() + self.global_step = 0 + self.main_tfr_encoder = Encoder( + n_src_vocab=self.word_len, n_position=self.args.context_all_max_len, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.word_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.word_emb + ) + self.p_tfr_encoder4p = Encoder( + n_src_vocab=self.topic_len, n_position=20, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + self.a_tfr_encoder = Encoder( + n_src_vocab=self.topic_len, n_position=self.args.action_num, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + if self.args.gpt2: + self.main_tfr_decoder = GPT2LMHeadModel.from_pretrained('pretrained_models/GPT2').cuda() + + self.main_tfr_decoder.resize_token_embeddings(len(self.vocab.tokenizer)) + self.n_ctx = self.main_tfr_decoder.config.to_dict().get("n_ctx") + else: + self.main_tfr_decoder = Decoder( + n_trg_vocab=self.word_len, n_position=self.args.r_max_len, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.word_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.word_emb + ) + self.response = Response(args=args, vocab=self.vocab, + a_encoder=self.a_tfr_encoder,decoder=self.main_tfr_decoder, + p_encoder=self.p_tfr_encoder4p, main_encoder=self.main_tfr_encoder, + hidden_size=d_model, n_vocab=self.word_len,trg_bos_idx=self.vocab.tokenizer.convert_tokens_to_ids('[CLS]') if self.args.gpt2 else self.r_bos_idx, + trg_eos_idx=self.vocab.tokenizer.convert_tokens_to_ids('[SEP]') if self.args.gpt2 else self.r_eos_idx,max_seq_len=self.args.r_max_len,beam_width=beam_width, + loc2glo=self.loc2glo,n_topic=self.topic_len).cuda() + + + self.p_tfr_encoder4p_g = Encoder( + n_src_vocab=self.topic_len, n_position=20, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_decoder4p_l = Decoder( + n_trg_vocab=self.topic_len, n_position=30, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.l_bos_idx = vocab.topic2index(args.BOS_PRO) + self.pro_tau_scheduler = TauScheduler(args.init_tau, args.tau_mini, args.tau_decay_total_steps) + self.user2character_metric = torch.LongTensor(train_set.user2character_metric).to(device=self.device) + self.p_l = UserContrast(args=args, no_prob=True, + p_encoder=self.p_tfr_encoder4p, p_g_encoder=self.p_tfr_encoder4p_g, + context_encoder=self.main_tfr_encoder, decoder=self.p_tfr_decoder4p_l, + hidden_size=d_model, n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx, max_seq_len=self.args.profile_num, + gs=self.gumbel_softmax, ts=self.pro_tau_scheduler, + user2character_metric=self.user2character_metric).cuda() + + def forward(self, user_id, context_all, context_all_len, context, context_len, tp_path, tp_path_len, ar_gth, ar_gth_len, resp, 
resp_len, topic2context, mode='train'): + assert mode in ['train', 'valid', 'test'] + if mode == 'train': + self.global_step += 1 + if self.args.gene_add_profile: + user_emb = self.user_emb.weight + topic_emb = self.topic_emb.weight + character_emb = self.character_emb.weight + char_vectors, tp_hidden_raw, profile_prob, contrast_loss = self.p_l.forward(user_id=user_id, context=context_all, context_len=context_all_len, tp_path=tp_path, tp_path_len=tp_path_len, all_topic=ar_gth[:, torch.arange(1, self.args.action_num, 2)], topic2context=topic2context, user_embed=user_emb, character_embed=character_emb, topic_embed=topic_emb) + + resp = self.response.forward(ar=ar_gth, ar_len=ar_gth_len, + context=context, context_len=context_len, + tp_path=tp_path, tp_path_len=tp_path_len, + resp_gth=resp, resp_gth_len=resp_len, + tp_path_embed=tp_hidden_raw, profile_embed=char_vectors) + else: + contrast_loss = None + resp = self.response.forward(ar=ar_gth, ar_len=ar_gth_len, context=context, context_len=context_len, + tp_path=tp_path, tp_path_len=tp_path_len, resp_gth=resp, resp_gth_len=resp_len) + return resp, contrast_loss + else: + if resp is not None: + probs = None + resp = self.response.forward(ar=ar_gth, ar_len=ar_gth_len, + context=context, context_len=context_len, + tp_path=tp_path, tp_path_len=tp_path_len, + resp_gth=resp, resp_gth_len=resp_len) + else: + resp, probs = self.response.forward(ar=ar_gth, ar_len=ar_gth_len, context=context, context_len=context_len, + tp_path=tp_path, tp_path_len=tp_path_len, resp_gth=resp, resp_gth_len=resp_len) + return resp, probs + + def topictensor2nl(self,tensor): + words = tensor.detach().cpu().numpy() + words = self.vocab.index2topic(words) + return words + +class Engine(): + def __init__(self,args, model:torch.nn.Module,vocab): + self.model = model + self.args = args + lr = self.args.lr + self.optimizer = torch.optim.Adam(self.model.parameters(), lr, betas=(0.9, 0.98), eps=1e-9) + self.optimizer = ScheduledOptim(self.optimizer, 0.5, self.args.d_model, self.args.n_warmup_steps) + self.vocab = vocab + self.topic_pad_idx = self.vocab.topic2index(self.args.PAD_WORD) + self.word_pad_idx = self.vocab.word2index(self.args.PAD_WORD) + self.global_step = 0 + self.loss = 0 + + def train(self, train_loader, valid_loader, test_loader, early_stop_num=10): + min_ppl, max_bleu_1, max_bleu_2, max_dist_1, max_dist_2, best_epoch = 1000, 0, 0, 0, 0, 0 + for e in range(self.args.epoch): + print("epoch : {}".format(e)) + self.optimizer.zero_grad() + for index, input in enumerate(tqdm(train_loader)): + if input[0].size(0) != self.args.batch_size: + break + input = [data.to(device='cuda:0') for data in input] + user_id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, a_R, a_R_len, resp, resp_len, topic2context, final = input + resp_gen, contrast_loss = self.model.forward(user_id=user_id, + context=context_idx, context_len=context_len, + context_all=context_all_idx, context_all_len=context_all_len, + tp_path=state_U, tp_path_len=state_U_len, + ar_gth=a_R, ar_gth_len=a_R_len, + resp=resp, resp_len=resp_len, + topic2context=topic2context) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + loss, _ = nll_loss(resp_gen, resp.detach(), self.word_pad_idx) + if contrast_loss is not None: + loss = loss + contrast_loss + self.loss += loss.item() + loss = loss / float(self.args.gradient_stack) + loss.backward(retain_graph=False) + + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + if self.global_step % self.args.gradient_stack == 0: 
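+                    # Gradient accumulation: the loss is scaled by
+                    # 1/gradient_stack and backward() runs on every batch, but
+                    # the warmup-scheduled optimizer only steps (and gradients
+                    # are cleared) once every gradient_stack batches.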
+ self.optimizer.step() + self.optimizer.zero_grad() + self.global_step += 1 + print(get_time(), 'valid:', end=' ') + ppl_val, bleu_1_val, bleu_2_val, dist_1_val, dist_2_val = self.test(valid_loader) + print(get_time(), 'test:', end=' ') + ppl, bleu_1, bleu_2, dist_1, dist_2 = self.test(test_loader) + if bleu_1_val > max_bleu_1: + print('new best epoch!!!') + min_ppl, max_bleu_1, max_bleu_2, max_dist_1, max_dist_2 = ppl_val, bleu_1_val, bleu_2_val, dist_1_val, dist_2_val + best_epoch = e + torch.save(self.model.state_dict(), 'saved_model/best_generate_model_{}.pkl'.format(self.args.dataset)) + elif e - best_epoch >= early_stop_num: + print('early stop the best epoch is', str(e-early_stop_num)) + break + self.model.load_state_dict(torch.load('saved_model/best_generate_model_{}.pkl'.format(self.args.dataset)), strict=False) + ppl, bleu_1, bleu_2, dist_1, dist_2 = self.test(test_loader) + print(get_time(), 'best test:' + "ppl:{:.4f},bleu_1:{:.4f},bleu_2:{:.4f},dist_1:{:.4f},dist_2:{:.4f}".format(ppl, bleu_1, bleu_2, dist_1, dist_2)) + + def test(self, test_loader, get_ppl=False, is_show=False): + print('start eval') + self.model.eval() + res_gen = [] + res_gth = [] + all_personas = [] + step = 0 + with torch.no_grad(): + for index, input in enumerate(tqdm(test_loader)): + if input[0].size(0) != self.args.batch_size: + break + step += 1 + input = [data.to(device='cuda:0') for data in input] + user_id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, a_R, a_R_len, resp, resp_len, topic2context, final = input + resp_gen, probs = self.model.forward(user_id=user_id, + context=context_idx, context_len=context_len, + context_all=context_all_idx, context_all_len=context_all_len, + tp_path=state_U, tp_path_len=state_U_len, + ar_gth=a_R, ar_gth_len=a_R_len, + resp=None, resp_len=None, + topic2context=topic2context, + mode='test') + resp_gen_word = self.wordtensor2nl(resp_gen) + resp_gth_word = self.wordtensor2nl(resp) + resp_gth_word = [i[1:] for i in resp_gth_word] + + + + + res_gen.extend(resp_gen_word) + res_gth.extend(resp_gth_word) + for i in range(self.args.batch_size): + personas = self.vocab.user_to_Sentidx[str(user_id[i].tolist())] + personas = [self.vocab.idx_to_userSent[idx] for idx in personas] + if self.args.dataset == 'PersonaChat': + personas = [i.split() for i in personas] + elif self.args.dataset == 'TG-ReDial': + personas = [[w for w in i] for i in personas] + all_personas.append(personas) + if is_show: + for i in range(self.args.batch_size): + personas = self.vocab.user_to_Sentidx[str(user_id[i].tolist())] + personas = [self.vocab.idx_to_userSent[idx] for idx in personas] + context = self.vocab.index2word(context_idx[i].tolist()) + context = ''.join([word if word!='[sent]' else '\n' for word in context if word!='[PAD]']) + gold_response = ''.join(resp_gth_word[i]) + pred_response = ''.join(resp_gen_word[i]) + topic_his = self.vocab.index2topic(state_U[i].tolist()) + topic_his = ' '.join([t for t in topic_his if t != '[PAD]']) + gold_topic = self.vocab.index2topic(a_R[i, 1].tolist()) + print('personas:') + print('\n'.join(personas)) + print('context:') + print(context, end='') + print('gold response:', gold_response) + print('pred response:', pred_response) + print('topic_his:', topic_his) + print('gold_topic', gold_topic) + print() + bleu_1, bleu_2, bleu_3, bleu_4 = Bleu.bleu(self.args, res_gen, res_gth) + dist_1, dist_2 = distinct.cal_calculate(self.args, res_gen, res_gth) + p_f1 = cal_p_f1(res_gen, all_personas) + p_cover = cal_p_cover(res_gen, 
all_personas) + + + ppl = 0. + if get_ppl: + step = 0 + for index, input in enumerate(test_loader): + if input[0].size(0) != self.args.batch_size: + break + step += 1 + input = [data.to(device='cuda:0') for data in input] + user_id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, a_R, a_R_len, resp, resp_len, topic2context, final = input + resp_gen, probs = self.model.forward(user_id=user_id, + context=context_idx, context_len=context_len, + context_all=context_all_idx, context_all_len=context_all_len, + tp_path=state_U, tp_path_len=state_U_len, + ar_gth=a_R, ar_gth_len=a_R_len, + resp=resp, resp_len=resp_len, + topic2context=topic2context, + mode='test') + loss, _ = nll_loss(resp_gen, resp.detach(), self.word_pad_idx) + ppl += math.exp(loss.tolist()) + ppl = ppl / step + print("ppl:{:.4f},bleu_1:{:.4f},bleu_2:{:.4f},dist_1:{:.4f},dist_2:{:.4f},p_f1:{:.4f},p_cover:{:.4f}".format(ppl, bleu_1, bleu_2, dist_1, dist_2, p_f1, p_cover)) + sys.stdout.flush() + self.model.train() + return ppl, bleu_1, bleu_2, dist_1, dist_2 + + def wordtensor2nl(self, tensor): + bs = tensor.shape[0] + if self.args.gpt2: + words = [] + for i in range(bs): + words.append(self.vocab.tokenizer.convert_ids_to_tokens(tensor[i])) + return words + else: + words = tensor.detach().cpu().numpy() + words = self.vocab.index2word(words) + return words + + + + + + + + + + + + +def get_mask_via_len(length, max_len): + + B = length.size(0) + mask = torch.ones([B, max_len]).cuda() + mask = torch.cumsum(mask, 1) + mask = mask <= length.unsqueeze(1) + mask = mask.unsqueeze(-2) + return mask + +def get_default_tensor(shape, dtype, pad_idx=None): + pad_tensor = torch.zeros(shape, dtype=dtype) + pad_tensor[..., pad_idx] = 1.0 if dtype == torch.float else 1 + pad_tensor = pad_tensor.cuda() + return pad_tensor + +def sparse_prefix_pad(inp, sos_idx): + n_vocab = inp.size(2) + pad = inp.new_ones(inp.size(0), 1, dtype=torch.long) * sos_idx + sparse_pad = Tools.one_hot(pad, n_vocab).cuda() + tensor = torch.cat([sparse_pad, inp], 1) + return tensor + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + +def nll_loss(hypothesis, target, pad_id): + + eps = 1e-9 + B, T = target.shape + hypothesis = hypothesis.reshape(-1, hypothesis.size(-1)) + target = target[:, 1:] + padding = torch.ones(target.size(0), 1, dtype=torch.long) * pad_id + padding = padding.cuda() + target = torch.cat([target, padding], 1) + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis + 1e-20), target, ignore_index=pad_id, reduction='none') + not_ignore_tag = (target != pad_id).float() + not_ignore_num = not_ignore_tag.reshape(B, T).sum(-1) + sum_nll_loss = nll_loss.reshape(B, T).sum(-1) + nll_loss_vector = sum_nll_loss / (not_ignore_num + eps) + nll_loss = nll_loss_vector.mean() + return nll_loss, nll_loss_vector.detach() diff --git a/upcrrec.py b/upcrrec.py new file mode 100644 index 0000000..5f44cf3 --- /dev/null +++ b/upcrrec.py @@ -0,0 +1,427 @@ +import torch.nn as nn +import torch +from tau_scheduler import TauScheduler +from copy_scheduler import CopyScheduler +from transformer.Models import Encoder +from transformer.Models import Decoder +from Profile import PriorProfile,PosteriorProfile +from Preference_all import PriorPreference,PosteriorPreference +from 
gumbel_softmax import GumbelSoftmax
+from Action_all import Action
+from tools import Tools
+from Vocab import Vocab
+from scipy import optimize
+import torch.nn.functional as F
+from transformer.Optim import ScheduledOptim
+from DataLoaderRec import DataLoaderRec
+import math
+import sys

+class Upcrrec(nn.Module):
+    def __init__(self, args, vocab: Vocab, user_cont, n_layers=6, p_layers=3,
+                 d_word_vec=512, d_model=512, d_inner=2048, beam_width=1,
+                 n_head=8, d_k=64, d_v=64, dropout=0.1):
+
+        super(Upcrrec, self).__init__()
+        self.args = args
+        self.vocab = vocab
+        self.glo2loc, self.loc2glo = vocab.vocab_transfer()
+        self.glo2loc = torch.tensor(self.glo2loc).cuda()
+        self.loc2glo = torch.tensor(self.loc2glo).cuda()
+        self.topic_num = vocab.topic_num()
+        self.word_vocab, self.word_len, self.topic_vocab, self.topic_len = vocab.get_vocab(task='rec')
+        self.word_pad_idx = vocab.get_word_pad()
+        self.topic_pad_idx = vocab.get_topic_pad()
+        self.m_bos_idx = vocab.topic2index(self.args.BOS_PRE)
+        self.l_bos_idx = vocab.topic2index(self.args.BOS_PRO)
+        self.a_bos_idx = vocab.topic2index(self.args.BOS_ACTION)
+        self.r_bos_idx = vocab.word2index(self.args.BOS_RESPONSE)
+        self.r_eos_idx = vocab.word2index(self.args.EOS_RESPONSE)
+        self.beam_width = beam_width
+        self.pro_tau_scheduler = TauScheduler(self.args.init_tau, self.args.tau_mini, self.args.tau_decay_total_steps)
+        self.pre_tau_scheduler = TauScheduler(self.args.init_tau, self.args.tau_mini, self.args.tau_decay_total_steps)
+        self.m_copy_scheduler = CopyScheduler(self.args.s_copy_lambda, self.args.copy_lambda_mini, self.args.copy_lambda_decay_steps)
+        self.l_copy_scheduler = CopyScheduler(self.args.a_copy_lambda, self.args.copy_lambda_mini, self.args.copy_lambda_decay_steps)
+        self.word_emb = nn.Embedding(self.word_len, d_word_vec, padding_idx=self.word_pad_idx)
+        self.topic_emb = nn.Embedding(self.topic_len, d_word_vec, padding_idx=self.topic_pad_idx)
+        self.user_emb = nn.Embedding(user_cont, d_word_vec)
+        self.gumbel_softmax = GumbelSoftmax()
+        self.global_step = 0
+        self.main_tfr_encoder = Encoder(
+            n_src_vocab=self.word_len, n_position=self.args.conv_max_len,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=self.word_pad_idx, dropout=dropout, scale_emb=False,
+            word_emb=self.word_emb
+        )
+
+        self.u_tfr_encoder4p = Encoder(
+            n_src_vocab=user_cont, n_position=1,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False,
+            word_emb=self.user_emb
+        )
+
+        self.u_tfr_encoder4q = Encoder(
+            n_src_vocab=user_cont, n_position=1,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False,
+            word_emb=self.user_emb
+        )
+
+        self.p_tfr_encoder4p = Encoder(
+            n_src_vocab=self.topic_len, n_position=200,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False,
+            word_emb=self.topic_emb
+        )
+        self.p_tfr_decoder4p = Decoder(
+            n_trg_vocab=self.topic_len, n_position=15,
+            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
+            n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v,
+            pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False,
+            word_emb=self.topic_emb
+        )
+
+        self.m_tfr_encoder4p = Encoder(
+            n_src_vocab=self.topic_len, n_position=100,
+
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + + self.p_tfr_encoder4q = Encoder( + n_src_vocab=self.topic_len, n_position=200, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + self.p_tfr_decoder4q = Decoder( + n_trg_vocab=self.topic_len, n_position=15, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + self.m_tfr_encoder4q = Encoder( + n_src_vocab=self.topic_len, n_position=100, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + + self.a_tfr_decoder = Decoder( + n_trg_vocab=self.topic_len, n_position=self.args.action_num, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ) + + if self.args.wo_l: + self.p_l = None + self.q_l = None + else: + self.p_l = PriorProfile(encoder=self.u_tfr_encoder4p,decoder=self.p_tfr_decoder4p, + hidden_size=d_model,n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx,max_seq_len=self.args.profile_num, + gs=self.gumbel_softmax,ts=self.pro_tau_scheduler).cuda() + + self.q_l = PosteriorProfile(main_encoder=self.main_tfr_encoder,id_encoder = self.u_tfr_encoder4q, + topic_encoder=self.p_tfr_encoder4q, decoder=self.p_tfr_decoder4q,m_encoder=self.m_tfr_encoder4q, + hidden_size=d_model,n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx,max_seq_len=self.args.profile_num, + gs=self.gumbel_softmax,ts=self.pro_tau_scheduler, + glo2loc=self.glo2loc,loc2glo=self.loc2glo).cuda() + + if self.args.wo_m: + self.p_mt = None + self.q_mt = None + + else: + self.p_mt = PriorPreference(encoder=self.p_tfr_encoder4p,decoder=self.p_tfr_decoder4p, + m_encoder=self.m_tfr_encoder4p,main_tfr_encoder=self.main_tfr_encoder, + hidden_size=d_model,n_topic_vocab=self.topic_len,trg_bos_idx=self.m_bos_idx, + max_seq_len=self.args.preference_num,gs=self.gumbel_softmax,glo2loc=self.glo2loc, + loc2glo=self.loc2glo,ts=self.pre_tau_scheduler).cuda() + + self.q_mt = PosteriorPreference(encoder=self.p_tfr_encoder4q,main_encoder=self.main_tfr_encoder, + m_encoder=self.m_tfr_encoder4q,decoder=self.p_tfr_decoder4q, + hidden_size=d_model,n_topic_vocab=self.topic_len,trg_bos_idx=self.m_bos_idx, + max_seq_len=self.args.preference_num,gs=self.gumbel_softmax,glo2loc=self.glo2loc, + loc2glo=self.loc2glo,ts=self.pre_tau_scheduler).cuda() + + self.action = Action(p_encoder=self.p_tfr_encoder4p,main_encoder=self.main_tfr_encoder,m_encoder=self.m_tfr_encoder4p, + a_decoder=self.a_tfr_decoder,hidden_size=d_model, + n_topic_vocab=self.topic_len,bos_idx=self.a_bos_idx,vocab=self.vocab, + max_len=self.args.action_num,glo2loc=self.glo2loc,loc2glo=self.loc2glo).cuda() + + def forward(self,user_id,all_topic, all_topic_len,context, context_len,tp_path, tp_path_len, + ar_gth, ar_gth_len,related_movies, related_movies_len,final,pv_m,mode='train'): + assert mode in ['train','valid','test'] + pv_m, pv_m_mask = self.mask_preference(pv_m, 
final)
+        if mode == 'train':
+            p_l, p_l_gumbel = self.p_l.forward(id=user_id)
+            q_l, q_l_gumbel = self.q_l.forward(id=user_id, topics=all_topic, topics_len=all_topic_len)
+            p_m, p_m_gumbel = self.p_mt.forward(context=context, context_len=context_len, pv_m=pv_m, pv_m_mask=pv_m_mask,
+                                                tp_path=tp_path, tp_path_len=tp_path_len)
+            q_m, q_m_gumbel = self.q_mt.forward(context=context, context_len=context_len, pv_m=pv_m, pv_m_mask=pv_m_mask,
+                                                ar_gth=ar_gth, ar_gth_len=ar_gth_len, tp_path=tp_path, tp_path_len=tp_path_len)
+            ar = self.action.forward(m=q_m_gumbel,
+                                     l=q_l_gumbel,
+                                     context=context,
+                                     context_len=context_len,
+                                     ar_gth=ar_gth, ar_gth_len=ar_gth_len,
+                                     tp_path=tp_path, tp_path_len=tp_path_len,
+                                     related_movies=related_movies, related_movies_len=related_movies_len,
+                                     mode='train')
+            self.global_step += 1
+            return p_l, q_l, p_m, q_m, ar, q_m_gumbel
+        else:
+            p_l = self.p_l.forward(id=user_id)
+            p_m = self.p_mt.forward(context=context, context_len=context_len, pv_m=pv_m, pv_m_mask=pv_m_mask, tp_path=tp_path, tp_path_len=tp_path_len)
+            ar, ar_probs = self.action.forward(m=p_m, l=p_l, context=context, context_len=context_len, ar_gth=ar_gth, ar_gth_len=ar_gth_len, tp_path=tp_path,
+                                               tp_path_len=tp_path_len, related_movies=related_movies, related_movies_len=related_movies_len, mode='test')
+            return ar, ar_probs, p_m
+
+    def mask_preference(self, pv_m, final):
+        # reset the cached preference state for samples whose dialog just ended
+        b = range(self.args.batch_size)
+        b = [i + 1 for i in b]
+        b = torch.tensor(b).cuda().tolist()
+        final = list(final)
+        final = [int(i) for i in final]
+        c = [i * j for i, j in zip(final, b)]
+        c = list(c)
+        d = []
+        for i in c:
+            if i != 0:
+                d.append(c.index(i))
+        if d:
+            d = torch.tensor(d).cuda()
+            pv_m[d, :, :] = 0
+            pv_m[d, :, self.topic_pad_idx] = 1.0
+        pv_m_mask = pv_m.new_ones(pv_m.size(0), 1, pv_m.size(1))
+        pv_m_mask[d, :, :] = 0
+        return pv_m, pv_m_mask
+
+    def topictensor2nl(self, tensor):
+        words = tensor.detach().cpu().numpy()
+        words = self.vocab.index2topic(words)
+        return words
+
+    def wordtensor2nl(self, tensor):
+        words = tensor.detach().cpu().numpy()
+        words = self.vocab.index2word(words)
+        return words
+
+class Engine():
+    def __init__(self, args, model: torch.nn.Module, vocab):
+        self.args = args
+        self.model = model
+        lr = self.args.lr
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr, betas=(0.9, 0.98), eps=1e-9)
+        self.optimizer = ScheduledOptim(self.optimizer, 0.5, self.args.d_model, self.args.n_warmup_steps)
+        self.vocab = vocab
+        self.topic_pad_idx = self.vocab.topic2index(self.args.PAD_WORD)
+        self.global_step = 0
+        self.action_loss = 0
+        self.kl_l_loss = 0
+        self.kl_m_loss = 0
+
+    def train(self, train_set, test_set):
+        bst_metric = 0
+        patience = 0
+        gen_stop = False
+        for e in range(self.args.epoch):
+            print("epoch : {}".format(e))
+            train_loader = DataLoaderRec(train_set, self.vocab)
+            self.pv_m = get_default_tensor([self.args.batch_size, self.args.preference_num, self.model.topic_len], torch.float, pad_idx=self.topic_pad_idx)
+            self.optimizer.zero_grad()
+            for index, input in enumerate(train_loader):
+                if input[0].size(0) != self.args.batch_size:
+                    break
+                id, all_topic, all_topic_len, context_idx, context_len, topic_path, topic_path_len, a_R, a_R_len, \
+                    seek_idx, seek_len, resp_idx, resp_len, state_R, state_R_len, related_movies, related_movies_len, final = input
+                p_l, q_l, p_m, q_m, ar, m = self.model.forward(user_id=id, all_topic=all_topic, all_topic_len=all_topic_len, context=context_idx, context_len=context_len,
+                                                               tp_path=topic_path, tp_path_len=topic_path_len, ar_gth=a_R, ar_gth_len=a_R_len,
+                                                               related_movies=related_movies,
related_movies_len=related_movies_len,final=final,pv_m=self.pv_m) + kl_l = kl_loss(p_l, q_l.detach()) + self.kl_l_loss += kl_l.item() + kl_m = kl_loss(p_m, q_m.detach()) + self.kl_m_loss += kl_m.item() + nll_ar = action_nll(ar, a_R.detach(), self.model.topic_pad_idx) + self.action_loss += nll_ar.item() + p_l_reg, q_l_reg = regularization_loss(p_l), regularization_loss(q_l) + p_m_reg, q_m_reg = regularization_loss(p_m), regularization_loss(q_m) + reg_loss = self.args.reg_lambda * (p_l_reg + q_l_reg + p_m_reg + q_m_reg) + loss = 0.3 * kl_m + 0.3 * kl_l + nll_ar + reg_loss + if (self.global_step % 200 == 0): + print("global_step: {}".format(self.global_step)) + print("kl_preference: {}".format(self.kl_m_loss / self.model.global_step)) + print("kl_profile: {}".format(self.kl_l_loss / self.model.global_step)) + print("nll_ar: {}".format(self.action_loss / self.model.global_step)) + sys.stdout.flush() + loss = loss / float(self.args.gradient_stack) + loss.backward(retain_graph=False) + if self.global_step % self.args.gradient_stack == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.pv_m = m.detach() + self.global_step += 1 + metric = self.test(test_set) + print("train finished ! ") + + def test(self,test_set): + self.model.eval() + print(" test ") + dataloader = DataLoaderRec(test_set,self.vocab) + metrics = { + "NDCG1": 0, + "NDCG10": 0, + "NDCG50": 0, + "MRR1": 0, + "MRR10": 0, + "MRR50": 0, + "rec_count": 0, + "count":0 + } + self.pv_m = get_default_tensor([self.args.batch_size, self.args.preference_num, self.model.topic_len], torch.float,pad_idx=self.model.topic_pad_idx) + with torch.no_grad(): + for index,data in enumerate(dataloader): + if data[0].size(0) != self.args.batch_size: + break + id, all_topic, all_topic_len, context_idx, context_len, topic_path, topic_path_len, a_R, a_R_len, \ + seek_idx, seek_len, resp_idx, resp_len, state_R, state_R_len, related_movies, related_movies_len, final = data + ar, ar_probs, m = self.model.forward(user_id=id,all_topic=all_topic,all_topic_len=all_topic_len,context=context_idx, context_len=context_len, + tp_path=topic_path,tp_path_len=topic_path_len,ar_gth=a_R, ar_gth_len=a_R_len, + related_movies=related_movies, related_movies_len=related_movies_len,final=final,pv_m=self.pv_m,mode='test') + self.pv_m = one_hot_scatter(m,self.vocab.topic_num()) + self.compute_metrics(ar_probs, a_R, a_R_len, metrics) + metrics['NDCG1'] = round(metrics['NDCG1'] / metrics['rec_count'], 4) + metrics['NDCG10'] = round(metrics['NDCG10'] / metrics['rec_count'], 4) + metrics['NDCG50'] = round(metrics['NDCG50'] / metrics['rec_count'], 4) + metrics['MRR1'] = round(metrics['MRR1'] / metrics['rec_count'], 4) + metrics['MRR10'] = round(metrics['MRR10'] / metrics['rec_count'], 4) + metrics['MRR50'] = round(metrics['MRR50'] / metrics['rec_count'], 4) + print(metrics) + self.model.train() + print('test finished!') + return metrics + + def compute_metrics(self,ar_probs, ar_gth, a_R_len, metrics): + tanlun = self.vocab.topic2index('谈论') + qingqiutuijian = self.vocab.topic2index('请求推荐') + def _topic_prediction(tar,gen,metrics): + metrics['topic_count'] += 1 + for k in [1,3,5]: + pred, pred_id = torch.topk(gen,k,-1) + pred_id = pred_id.tolist() + if tar in pred_id: + metrics["TopicId_Hits@{}".format(k)] += 1 + def _movie_recommendation(tar,gen,metrics): + _, pred_idx = torch.topk(gen, k=100, dim=0) + metrics["count"] += 1 + metrics['rec_count'] += 1 + for k in [1,10,50]: + pred, pred_id = torch.topk(gen,k,-1) + pred_id = pred_id.tolist() + if tar in pred_id: + rank = 
pred_id.index(tar) + metrics['NDCG{}'.format(k)] += 1.0 / math.log(rank + 2.0, 2) + metrics['MRR{}'.format(k)] += 1.0 / (rank + 1.0) + for i, gt in enumerate(ar_gth): + ar_gen = ar_probs[i,:] + gt_len = int(a_R_len[i]) + for j in range(0,gt_len,2): + action_type = gt[j] + if action_type == self.vocab.topic2index('推荐电影'): + _movie_recommendation(gt[j+1],ar_gen[int(j/2)],metrics) + else: + _topic_prediction(gt[j+1],ar_gen[int(j/2)],metrics) + if tanlun in gt and qingqiutuijian in gt: + break + +def get_mask_via_len(length, max_len): + B = length.size(0) + mask = torch.ones([B, max_len]).cuda() + mask = torch.cumsum(mask, 1) + mask = mask <= length.unsqueeze(1) + mask = mask.unsqueeze(-2) + return mask + +def get_default_tensor(shape, dtype, pad_idx=None): + pad_tensor = torch.zeros(shape, dtype=dtype) + pad_tensor[..., pad_idx] = 1.0 if dtype == torch.float else 1 + pad_tensor = pad_tensor.cuda() + return pad_tensor + +def sparse_prefix_pad(inp, sos_idx): + n_vocab = inp.size(2) + pad = inp.new_ones(inp.size(0), 1, dtype=torch.long) * sos_idx + sparse_pad = Tools.one_hot(pad, n_vocab).cuda() + tensor = torch.cat([sparse_pad, inp], 1) + return tensor + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + +def kl_loss(prior_dist, posterior_dist): + bias = 1e-24 + if (len(prior_dist.shape) >= 3) and self.args.hungary: + B, S = prior_dist.size(0), prior_dist.size(1) + expand_prior_dist = prior_dist.unsqueeze(2).expand(-1, -1, S, -1).reshape(B, S * S, -1) + expand_posterior_dist = posterior_dist.unsqueeze(1).expand(-1, S, -1, -1).reshape(B, S * S, -1) + cost_vector = F.kl_div((expand_prior_dist + bias).log(), expand_posterior_dist, reduce=False).sum(-1) + cost_matrix = cost_vector.reshape(-1, S, S) + cost_matrix_np = cost_matrix.detach().cpu().numpy() + row_idx, col_idx = zip(*[optimize.linear_sum_assignment(cost_matrix_np[i]) for i in range(B)]) + col_idx = torch.tensor(col_idx, dtype=torch.long) + posterior_dist = Tools.nested_index_select(posterior_dist, col_idx) + flat_prior_dist = prior_dist.reshape(-1, prior_dist.size(-1)) + flat_posterior_dist = posterior_dist.reshape(-1, posterior_dist.size(-1)) + kl_div = F.kl_div((flat_prior_dist + bias).log(), flat_posterior_dist, reduce=False).sum(-1) + kl_div = kl_div.mean() + return kl_div + +def nll_loss(hypothesis, target, pad_id ): + eps = 1e-9 + B, T = target.shape + hypothesis = hypothesis.reshape(-1, hypothesis.size(-1)) + target = target[:,1:] + padding = torch.ones(target.size(0),1,dtype=torch.long) * pad_id + padding = padding.cuda() + target = torch.cat([target,padding],1) + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis + 1e-20), target, ignore_index=pad_id, reduce=False) + not_ignore_tag = (target != pad_id).float() + not_ignore_num = not_ignore_tag.reshape(B, T).sum(-1) + sum_nll_loss = nll_loss.reshape(B, T).sum(-1) + nll_loss_vector = sum_nll_loss / (not_ignore_num + eps) + nll_loss = nll_loss_vector.mean() + return nll_loss, nll_loss_vector.detach() + +def regularization_loss(dist): + entropy_loss, repeat_loss = torch.tensor(0.), torch.tensor(0.) 
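
The kl_loss defined above treats the S prior and S posterior latent distributions as unordered sets: it builds an S x S matrix of pairwise KL costs per batch item, solves the matching with scipy's linear_sum_assignment, reorders the posterior slots accordingly, and only then averages the divergence. (Note that this module-level kl_loss, and the regularization_loss beside it, reference self.args even though no self is in scope, so they would raise a NameError if reached; the Engine-method versions later in upcrtopic.py take self properly.) A minimal standalone sketch of the alignment step, not part of this diff, with illustrative names:

    import torch
    import torch.nn.functional as F
    from scipy import optimize

    def hungarian_kl(prior, posterior, bias=1e-24):
        # prior, posterior: (B, S, V); each of the S rows is a distribution over V
        B, S, V = prior.shape
        p = prior.unsqueeze(2).expand(-1, -1, S, -1)      # prior slot i ...
        q = posterior.unsqueeze(1).expand(-1, S, -1, -1)  # ... against posterior slot j
        cost = F.kl_div((p + bias).log(), q, reduction='none').sum(-1)  # (B, S, S)
        cols = [optimize.linear_sum_assignment(cost[b].detach().cpu().numpy())[1]
                for b in range(B)]
        perm = torch.stack([torch.as_tensor(c, dtype=torch.long) for c in cols])
        perm = perm.to(posterior.device)
        # reorder posterior slots to the minimum-cost matching, then average the KL
        aligned = torch.gather(posterior, 1, perm.unsqueeze(-1).expand(-1, -1, V))
        kl = F.kl_div((prior + bias).log(), aligned, reduction='none').sum(-1)
        return kl.mean()

    # e.g.: hungarian_kl(torch.softmax(torch.randn(2, 3, 5), -1),
    #                    torch.softmax(torch.randn(2, 3, 5), -1))

Aligning slots before the KL keeps the prior from being penalised merely for emitting the right set of distributions in a different order.
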
+ if not self.args.wo_entropy_restrain: + entropy_loss = Tools.entropy_restrain(dist) + if not self.args.wo_repeat_penalty: + repeat_loss = Tools.repeat_penalty(dist) + regularization = entropy_loss + repeat_loss + return regularization + +def action_nll(hypothesis,target,pad_idx): + eps = 1e-9 + hypothesis = hypothesis.reshape(-1,hypothesis.size(-1)) + target = target[:,[1]] + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis+eps),target,ignore_index=pad_idx) + return nll_loss \ No newline at end of file diff --git a/upcrtopic.py b/upcrtopic.py new file mode 100644 index 0000000..2161c0c --- /dev/null +++ b/upcrtopic.py @@ -0,0 +1,539 @@ +import json +import dgl +import torch.nn as nn +import torch +from torch.cuda.amp import autocast as autocast +from tau_scheduler import TauScheduler +from copy_scheduler import CopyScheduler +from transformer.Models import Encoder +from transformer.Models import Decoder +from ProfileTopic import PriorProfile, PosteriorProfile, UserContrast +from PreferenceTopic import PriorPreference, PosteriorPreference +from gumbel_softmax import GumbelSoftmax +from ActionTopic import Action +from tools import Tools +from Vocab import Vocab +from scipy import optimize +import numpy as np +import torch.nn.functional as F +from transformer.Optim import ScheduledOptim +from DataLoaderTopic import DataLoaderTopic +import math +import sys +from tqdm import tqdm +import datetime +from rgat import RGATLayer, DualRGATLayer +from OPT import IPconstraint_and_solve + + +def get_time(): + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ': ' + +class Upcrtopic(nn.Module): + def __init__(self, args, vocab: Vocab, user_cont, train_set, n_layers=6, p_layers=3, + d_word_vec=768, d_model=768, d_inner=3072, beam_width=1, + n_head=8, d_k=64, d_v=64, dropout=0.1): + super(Upcrtopic, self).__init__() + self.args = args + self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + self.vocab = vocab + self.d_word_vec = d_word_vec + self.glo2loc, self.loc2glo = vocab.vocab_transfer() + self.glo2loc = torch.tensor(self.glo2loc).to(device=self.device) + self.loc2glo = torch.tensor(self.loc2glo).to(device=self.device) + self.topic_num = vocab.topic_num() + self.user_num = vocab.n_user + 1 + self.character_num = vocab.n_character + self.word_vocab, self.word_len, self.topic_vocab, self.topic_len = vocab.get_vocab(task='topic') + self.word_pad_idx = vocab.get_word_pad() + self.topic_pad_idx = vocab.get_topic_pad() + self.m_bos_idx = vocab.topic2index(args.BOS_PRE) + self.l_bos_idx = vocab.topic2index(args.BOS_PRO) + self.a_bos_idx = vocab.topic2index(args.BOS_ACTION) + self.r_bos_idx = vocab.word2index(args.BOS_RESPONSE) + self.r_eos_idx = vocab.word2index(args.EOS_RESPONSE) + self.beam_width = beam_width + self.pro_tau_scheduler = TauScheduler(args.init_tau, args.tau_mini, args.tau_decay_total_steps) + self.pre_tau_scheduler = TauScheduler(args.init_tau, args.tau_mini, args.tau_decay_total_steps) + self.m_copy_scheduler = CopyScheduler(args.s_copy_lambda, args.copy_lambda_mini, args.copy_lambda_decay_steps) + self.l_copy_scheduler = CopyScheduler(args.a_copy_lambda, args.copy_lambda_mini, args.copy_lambda_decay_steps) + self.word_emb = nn.Embedding(self.word_len, d_word_vec, padding_idx=self.word_pad_idx).to(device=self.device) + self.topic_emb = nn.Embedding(self.topic_len, d_word_vec, padding_idx=self.topic_pad_idx).to(device=self.device) + self.user_emb = nn.Embedding(self.user_num, d_word_vec).to(device=self.device) + self.character_emb = 
nn.Embedding(self.character_num, d_word_vec).to(device=self.device) + self.gumbel_softmax = GumbelSoftmax() + self.global_step = 0 + self.main_tfr_encoder = Encoder( + n_src_vocab=self.word_len, n_position=args.context_all_max_len, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.word_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.word_emb, sent_id=self.vocab.word2idx['[sent]'] + ).cuda() + self.p_tfr_encoder4p = Encoder( + n_src_vocab=self.topic_len, n_position=20, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_encoder4p_g = Encoder( + n_src_vocab=self.topic_len, n_position=20, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_encoder4p_a = Encoder( + n_src_vocab=self.topic_len, n_position=200, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_decoder4p_l = Decoder( + n_trg_vocab=self.topic_len, n_position=30, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.a_tfr_decoder = Decoder( + n_trg_vocab=self.topic_len, n_position=self.args.action_num, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.user2character_metric = torch.LongTensor(train_set.user2character_metric).to(device=self.device) + + self.topic_co_graph = dict() + for key, values in train_set.topic_co_graph.items(): + key_id = vocab.topic2idx[key] + value_ids = [vocab.topic2idx[v] for v in values] + self.topic_co_graph[key_id] = value_ids + + self.p_l = UserContrast(args=args, + p_encoder=self.p_tfr_encoder4p, p_g_encoder=self.p_tfr_encoder4p_g, + context_encoder=self.main_tfr_encoder, decoder=self.p_tfr_decoder4p_l, + hidden_size=d_model, n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx, max_seq_len=self.args.profile_num, + gs=self.gumbel_softmax, ts=self.pro_tau_scheduler, + user2character_metric=self.user2character_metric, + topic_co_graph=self.topic_co_graph).cuda() + + self.action = Action(args=args, p_encoder=self.p_tfr_encoder4p_a, + main_encoder=self.main_tfr_encoder, a_decoder=self.a_tfr_decoder, + hidden_size=d_model, n_topic_vocab=self.topic_len, bos_idx=self.a_bos_idx, vocab=self.vocab, + max_len=args.action_num, glo2loc=self.glo2loc, loc2glo=self.loc2glo).cuda() + + + if self.args.use_ckg == 1: + self.conv_layers = 2 + self.edge_dropout_rate = 0.2 + self.relation2idx = train_set.relation2idx + self.relation_num = len(self.relation2idx) + self.other_entity2idx = train_set.other_entity2idx + self.other_entity_emb = nn.Embedding(len(self.other_entity2idx)-self.topic_len, d_word_vec).to(device=self.device) + self.relation_emb = nn.Embedding(self.relation_num, d_word_vec).to(device=self.device) + self.cut_RGCN = nn.ModuleList([ + RGATLayer(self.d_word_vec, 
self.d_word_vec, num_heads=8, feat_drop=0.2).to(device=self.device) + for _ in range(self.conv_layers) + ]) + self.ckg_edge_sets = torch.tensor(train_set.edge_set, dtype=torch.long, device=self.device) + self.ckg_edge_types = torch.tensor(train_set.edge_type, dtype=torch.long, device=self.device) + self.ckg = dgl.graph((self.ckg_edge_sets[0], self.ckg_edge_sets[1]), num_nodes=self.ckg_edge_sets.max().tolist()+1 , idtype=torch.int32).to(device=self.device) + self.ckg.edata['edge_types'] = self.ckg_edge_types + self.edge_dropout = dgl.transforms.DropEdge(p=self.edge_dropout_rate) + + def forward(self, user_id, all_topic, all_topic_len, context_all, context_all_len, context, context_len, tp_path, tp_path_len, ar_gth, ar_gth_len, related_topics, related_topics_len, final, topic2context, mode='train'): + assert mode in ['train', 'valid', 'test'] + if self.args.use_ckg == 1: + all_embedding = torch.cat([self.topic_emb.weight, self.other_entity_emb.weight], dim=0) + if mode == 'train': + ckg = self.edge_dropout(self.ckg) + else: + ckg = self.ckg + relation_embed = self.relation_emb(ckg.edata['edge_types']) + for i in range(self.conv_layers): + all_embedding, relation_embed = self.cut_RGCN[i](all_embedding, relation_embed, ckg) + topic_emb = all_embedding[:self.topic_len] + else: + topic_emb = self.topic_emb.weight + + user_emb = self.user_emb.weight + character_emb = self.character_emb.weight + + if mode == 'train': + self.global_step += 1 + if self.args.profile_pred: + p_l, p_l_gumbel, tp_hidden_raw, profile_prob, contrast_loss, selected_personas = self.p_l.forward(user_id=user_id, context=context_all, context_len=context_all_len, tp_path=tp_path, tp_path_len=tp_path_len, all_topic=ar_gth[:, torch.arange(1, self.args.action_num, 2)], topic2context=topic2context, user_embed=user_emb, character_embed=character_emb, topic_embed=topic_emb) + if not self.args.profile_contrast: + contrast_loss = None + else: + p_l, p_l_gumbel, tp_hidden_raw, profile_prob, contrast_loss, selected_personas = None, None, None, None, None, None + p_m, p_m_gumbel = None, None + if self.args.global_topic_for_action: + + ar = self.action.forward(m=p_m, l=p_l, context=context, context_len=context_len, + ar_gth=ar_gth, ar_gth_len=ar_gth_len, + tp_path=tp_path, tp_path_len=tp_path_len, tp_path_embed=tp_hidden_raw, + related_topics=related_topics, related_topics_len=related_topics_len, + encoder_embed=topic_emb, decoder_embed=topic_emb, profile_prob=profile_prob, + user_id=user_id, user_embed=user_emb, + mode='train') + else: + ar = self.action.forward(m=p_m, l=p_l, context=context, context_len=context_len, + ar_gth=ar_gth, ar_gth_len=ar_gth_len, + tp_path=tp_path, tp_path_len=tp_path_len, + related_topics=related_topics, related_topics_len=related_topics_len, + encoder_embed=topic_emb, decoder_embed=topic_emb, profile_prob=profile_prob, + user_id=user_id, user_embed=user_emb, + mode='train') + return p_l, p_m, ar, contrast_loss, selected_personas + else: + p_l, p_l_gumbel, tp_hidden_raw, profile_prob, selected_personas = self.p_l.forward(user_id=user_id, context=context_all, context_len=context_all_len, tp_path=tp_path, tp_path_len=tp_path_len, all_topic=ar_gth[:, torch.arange(1, self.args.action_num, 2)], topic2context=topic2context, user_embed=user_emb, character_embed=character_emb, topic_embed=topic_emb) + p_m, p_m_gumbel = None, None + if self.args.global_topic_for_action: + + ar, ar_probs = self.action.forward(m=p_m, l=p_l, context=context, context_len=context_len, + ar_gth=ar_gth, ar_gth_len=ar_gth_len, + 
tp_path=tp_path, tp_path_len=tp_path_len, tp_path_embed=tp_hidden_raw, + related_topics=related_topics, related_topics_len=related_topics_len, + encoder_embed=topic_emb, decoder_embed=topic_emb, profile_prob=profile_prob, + user_id=user_id, user_embed=user_emb, + mode='test') + else: + ar, ar_probs = self.action.forward(m=p_m, l=p_l, context=context, context_len=context_len, + ar_gth=ar_gth, ar_gth_len=ar_gth_len, + tp_path=tp_path, tp_path_len=tp_path_len, + related_topics=related_topics, related_topics_len=related_topics_len, + encoder_embed=topic_emb, decoder_embed=topic_emb, profile_prob=profile_prob, + user_id=user_id, user_embed=user_emb, + mode='test') + return ar, ar_probs, p_m_gumbel, None, selected_personas + + def mask_preference(self, pv_m, final): + b = range(self.args.batch_size) + b = [i + 1 for i in b] + b = torch.tensor(b).cuda().tolist() + final = list(final) + final = [int(i) for i in final] + c = [i * j for i, j in zip(final, b)] + c = list(c) + d = [] + for i in c: + if i != 0: + d.append(c.index(i)) + if d: + d = torch.tensor(d).cuda() + pv_m[d, :, :] = 0 + pv_m[d, :, self.topic_pad_idx] = 1.0 + pv_m_mask = pv_m.new_ones(pv_m.size(0), 1, pv_m.size(1)) + pv_m_mask[d, :, :] = 0 + return pv_m, pv_m_mask + + def topictensor2nl(self, tensor): + words = tensor.detach().cpu().numpy() + words = self.vocab.index2topic(words) + return words + + def wordtensor2nl(self, tensor): + words = tensor.detach().cpu().numpy() + words = self.vocab.index2word(words) + return words + + def edge_set_dropout(self, edge_sets, edge_types=None, dropout_rate=None): + _, n_edges = edge_sets.shape + random_indices = np.random.choice(n_edges, size=int(n_edges * (1 - dropout_rate)), replace=False) + if edge_types is not None: + return edge_sets[:, random_indices], edge_types[random_indices] + else: + return edge_sets[:, random_indices] + + +class Engine(): + def __init__(self, args, model: torch.nn.Module, vocab): + self.args = args + self.model = model + lr = self.args.lr + self.optimizer = torch.optim.Adam(self.model.parameters(), lr, betas=(0.9, 0.98), eps=1e-9) + self.optimizer = ScheduledOptim(self.optimizer, 0.5, args.d_model, args.n_warmup_steps) + self.vocab = vocab + self.topic_pad_idx = self.vocab.topic2index(args.PAD_WORD) + self.global_step = 0 + + def train(self, train_loader, valid_loader, test_loader, early_stop_num=3): + + best_metrics = {"TopicId_Hits@1": 0, "TopicId_Hits@3": 0, "TopicId_Hits@5": 0, "best_epoch": 0} + for e in range(self.args.epoch): + train_metrics = {"topic_Loss": 0, "TopicId_Hits@1": 0, "TopicId_Hits@3": 0, "TopicId_Hits@5": 0, "topic_count": 0} + print("epoch : {}".format(e)) + self.optimizer.zero_grad() + for index, input in enumerate(tqdm(train_loader)): + if input[0].size(0) != self.args.batch_size: + break + input = [data.to(device='cuda:0') for data in input] + id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, related_topics, related_topics_len, a_R, a_R_len, all_topic, all_topic_len, topic2context, final, response_idx = input + + + p_l, p_m, ar, contrast_loss, selected_personas = self.model.forward(user_id=id, + all_topic=all_topic, all_topic_len=all_topic_len, + context_all=context_all_idx, context_all_len=context_all_len, + context=context_idx, context_len=context_len, + tp_path=state_U, tp_path_len=state_U_len, + ar_gth=a_R, ar_gth_len=a_R_len, + related_topics=related_topics, related_topics_len=related_topics_len, + final=final, topic2context=topic2context) + self.compute_metrics(ar, a_R, a_R_len, train_metrics) + 
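
The loop around this point uses plain gradient accumulation: each batch's loss is divided by gradient_stack, backward() runs every batch so gradients add up in .grad, and the optimizer steps only when global_step % gradient_stack == 0 (global_step starts at 0, so the very first update fires after a single batch). A condensed standalone sketch of the pattern, not part of this diff, with a vanilla Adam standing in for ScheduledOptim and placeholder model/loader names:

    import torch

    def train_epoch_accumulated(model, loader, loss_fn, accum_steps=4, lr=1e-4):
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer.zero_grad()
        for step, (x, y) in enumerate(loader, start=1):
            loss = loss_fn(model(x), y) / accum_steps  # scale so grads average
            loss.backward()                            # gradients accumulate in .grad
            if step % accum_steps == 0:                # one update per accum_steps batches
                optimizer.step()
                optimizer.zero_grad()
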
nll_ar = self.action_nll(ar, a_R.detach(), self.model.topic_pad_idx) + if contrast_loss is None: + loss = nll_ar + else: + loss = (nll_ar + contrast_loss) + train_metrics['topic_Loss'] += loss.tolist() + loss = loss / float(self.args.gradient_stack) + + loss.backward(retain_graph=False) + if self.global_step % self.args.gradient_stack == 0: + self.optimizer.step() + self.optimizer.zero_grad() + + self.global_step += 1 + + train_metrics['topic_Loss'] = round(train_metrics['topic_Loss'] / train_metrics['topic_count'], 4) + train_metrics['TopicId_Hits@1'] = round(train_metrics['TopicId_Hits@1'] / train_metrics['topic_count'], 4) + train_metrics['TopicId_Hits@3'] = round(train_metrics['TopicId_Hits@3'] / train_metrics['topic_count'], 4) + train_metrics['TopicId_Hits@5'] = round(train_metrics['TopicId_Hits@5'] / train_metrics['topic_count'], 4) + print(get_time(), 'train:', train_metrics) + print(get_time(), 'valid:', end=' ') + metric = self.test(valid_loader) + print(get_time(), 'test:', end=' ') + metric = self.test(test_loader) + if metric["TopicId_Hits@3"] > best_metrics["TopicId_Hits@3"]: + best_metrics["TopicId_Hits@1"] = metric["TopicId_Hits@1"] + best_metrics["TopicId_Hits@3"] = metric["TopicId_Hits@3"] + best_metrics["TopicId_Hits@5"] = metric["TopicId_Hits@5"] + best_metrics['best_epoch'] = e + torch.save(self.model.state_dict(), 'saved_model/best_topic_model_{}.pkl'.format(self.args.dataset)) + print('!!! new best on epoch ' + str(e)) + elif e - best_metrics['best_epoch'] >= early_stop_num: + print('early stop the best epoch is', str(e-early_stop_num)) + break + self.model.load_state_dict(torch.load('saved_model/best_topic_model_{}.pkl'.format(self.args.dataset)), strict=False) + print('test:', end=' ') + metric = self.test(test_loader) + print("train finished ! 
") + print(get_time(), 'best test:' + str(best_metrics)) + + def test(self, dataloader, in_tqdm=False): + self.model.eval() + metrics = {"topic_Loss": 0, "TopicId_Hits@1": 0, "TopicId_Hits@3": 0, "TopicId_Hits@5": 0, "topic_count": 0} + with open('predict_res.jsonl', 'w') as f: + with torch.no_grad(): + if self.args.otp: + hcar2topic_att = self.model.character_emb.weight.matmul(self.model.topic_emb.weight.T) + self.model.p_l.topic_att, self.model.p_l.topic_ids = IPconstraint_and_solve(hcar2topic_att) + else: + self.model.p_l.topic_att, self.model.p_l.topic_ids = None, None + for index, input in enumerate(dataloader): + if input[0].size(0) != self.args.batch_size: + break + input = [data.to(device='cuda:0') for data in input] + user_id, context_all_idx, context_all_len, context_idx, context_len, state_U, state_U_len, related_topics, related_topics_len, a_R, a_R_len, all_topic, all_topic_len, topic2context, final, response_idx = input + ar, ar_probs, m, l, selected_personas = self.model.forward(user_id=user_id, + all_topic=all_topic, all_topic_len=all_topic_len, + context_all=context_all_idx, + context_all_len=context_all_len, + context=context_idx, context_len=context_len, + tp_path=state_U, tp_path_len=state_U_len, + ar_gth=a_R, ar_gth_len=a_R_len, + related_topics=related_topics, + related_topics_len=related_topics_len, + final=final, topic2context=topic2context, + mode='test') + self.compute_metrics(ar_probs, a_R, a_R_len, metrics) + loss = self.action_nll(ar_probs, a_R.detach(), self.model.topic_pad_idx) + metrics['topic_Loss'] += loss.tolist() + see_bad_case = False + if see_bad_case: + character_ids = self.model.p_l.user2character_metric[user_id].tolist() + for i in range(self.args.batch_size): + persona_ids = character_ids[i] + personas = [self.vocab.idx_to_userSent[idx] for idx in persona_ids] + if ar_probs[i, 0].topk(1, -1)[1].tolist()[0] != a_R[i, 1].tolist() or True: + context = self.vocab.index2word(context_idx[i].tolist()) + context = ''.join([word if word!='[sent]' else '\n' for word in context if word!='[PAD]']) + if i < self.args.batch_size - 1: + response = self.vocab.index2word(response_idx[i].tolist()) + response = ''.join([word if word!='[sent]' else '\n' for word in response if word not in ['[s_response>]', '[PAD]', '[/s_response]']]) + topic_his = self.vocab.index2topic(state_U[i].tolist()) + topic_his = ' '.join([t for t in topic_his if t!='[PAD]']) + pred_topic = self.vocab.index2topic(ar_probs[i, 0].topk(1, -1)[1].tolist())[0] + gold_topic = self.vocab.index2topic(a_R[i, 1].tolist()) + all_topics = self.vocab.index2topic(all_topic[i].tolist()) + all_topics = ' '.join([t for t in all_topics if t!='[PAD]']) + print('all personas:') + print('\n'.join(personas)) + print('selected personas:') + print('\n'.join([personas[idx] for idx in selected_personas[i]])) + print('context:') + print(context) + print('response:') + print(response) + print('topic_his:', topic_his) + print('pred_topic:', pred_topic) + print('gold_topic:', gold_topic) + print('all_topic:', all_topics) + print() + save_data = { + 'user_id': user_id.tolist()[i], + 'personas': personas, + 'selected_personas': [personas[idx] for idx in selected_personas[i]], + 'context': context.split('\n'), + 'response': response.split('\n'), + 'topic_his': topic_his.split(' '), + 'gold_topic': gold_topic, + 'pred_topic': pred_topic, + 'all_topics': all_topics.split(' '), + } + f.write(json.dumps(save_data, ensure_ascii=False)+'\n') + + metrics['topic_Loss'] = round(metrics['topic_Loss'] / metrics['topic_count'], 4) + 
metrics['TopicId_Hits@1'] = round(metrics['TopicId_Hits@1'] / metrics['topic_count'], 4) + metrics['TopicId_Hits@3'] = round(metrics['TopicId_Hits@3'] / metrics['topic_count'], 4) + metrics['TopicId_Hits@5'] = round(metrics['TopicId_Hits@5'] / metrics['topic_count'], 4) + if in_tqdm: + tqdm.write(str(metrics)) + else: + print(metrics) + self.model.train() + return metrics + + def compute_metrics(self, ar_probs, ar_gth, a_R_len, metrics): + tanlun = self.vocab.topic2index('谈论') + qingqiutuijian = self.vocab.topic2index('请求推荐') + + def _topic_prediction(tar, gen, metrics): + metrics['topic_count'] += 1 + for k in [1, 3, 5]: + pred, pred_id = torch.topk(gen, k, -1) + pred_id = pred_id.tolist() + if tar in pred_id: + metrics["TopicId_Hits@{}".format(k)] += 1 + + def _movie_recommendation(tar, gen, metrics): + metrics['rec_count'] += 1 + for k in [1, 10, 50]: + pred, pred_id = torch.topk(gen, k, -1) + pred_id = pred_id.tolist() + if tar in pred_id: + rank = pred_id.index(tar) + metrics['NDCG{}'.format(k)] += 1.0 / math.log(rank + 2.0, 2) + metrics['MRR{}'.format(k)] += 1.0 / (rank + 1.0) + + for i, gt in enumerate(ar_gth): + ar_gen = ar_probs[i, :] + gt_len = int(a_R_len[i]) + for j in range(0, gt_len, 2): + action_type = gt[j] + if action_type == self.vocab.topic2index('推荐电影'): + _movie_recommendation(gt[j + 1], ar_gen[int(j / 2)], metrics) + else: + _topic_prediction(gt[j + 1], ar_gen[int(j / 2)], metrics) + if tanlun in gt and qingqiutuijian in gt: + break + + def kl_loss(self, prior_dist, posterior_dist): + bias = 1e-24 + if (len(prior_dist.shape) >= 3) and self.args.hungary: + B, S = prior_dist.size(0), prior_dist.size(1) + expand_prior_dist = prior_dist.unsqueeze(2).expand(-1, -1, S, -1).reshape(B, S * S, -1) + expand_posterior_dist = posterior_dist.unsqueeze(1).expand(-1, S, -1, -1).reshape(B, S * S, -1) + cost_vector = F.kl_div((expand_prior_dist + bias).log(), expand_posterior_dist, reduce=False).sum(-1) + cost_matrix = cost_vector.reshape(-1, S, S) + cost_matrix_np = cost_matrix.detach().cpu().numpy() + row_idx, col_idx = zip(*[optimize.linear_sum_assignment(cost_matrix_np[i]) for i in range(B)]) + col_idx = torch.tensor(col_idx, dtype=torch.long) + posterior_dist = Tools.nested_index_select(posterior_dist, col_idx) + flat_prior_dist = prior_dist.reshape(-1, prior_dist.size(-1)) + flat_posterior_dist = posterior_dist.reshape(-1, posterior_dist.size(-1)) + kl_div = F.kl_div((flat_prior_dist + bias).log(), flat_posterior_dist, reduce=False).sum(-1) + kl_div = kl_div.mean() + return kl_div + + def nll_loss(self, hypothesis, target, pad_id): + eps = 1e-9 + B, T = target.shape + hypothesis = hypothesis.reshape(-1, hypothesis.size(-1)) + target = target[:, 1:] + padding = torch.ones(target.size(0), 1, dtype=torch.long) * pad_id + padding = padding.cuda() + target = torch.cat([target, padding], 1) + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis + 1e-20), target, ignore_index=pad_id, reduce=False) + not_ignore_tag = (target != pad_id).float() + not_ignore_num = not_ignore_tag.reshape(B, T).sum(-1) + sum_nll_loss = nll_loss.reshape(B, T).sum(-1) + nll_loss_vector = sum_nll_loss / (not_ignore_num + eps) + nll_loss = nll_loss_vector.mean() + return nll_loss, nll_loss_vector.detach() + + def regularization_loss(self, dist): + entropy_loss, repeat_loss = torch.tensor(0.), torch.tensor(0.) 
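
In _movie_recommendation above, rank is the gold movie's 0-based position inside the top-k list, so a hit contributes 1/log2(rank + 2) to NDCG@k (single relevant item, ideal DCG of 1) and 1/(rank + 1) to MRR@k. A worked standalone check, not part of this diff:

    import math

    def ndcg_mrr_at_k(pred_ids, gold_id, k):
        top = pred_ids[:k]
        if gold_id not in top:
            return 0.0, 0.0
        rank = top.index(gold_id)                 # 0-based position in the top-k list
        return 1.0 / math.log2(rank + 2.0), 1.0 / (rank + 1.0)

    # Gold item ranked 2nd among the top-10 candidates:
    print(ndcg_mrr_at_k([7, 3, 9], 3, k=10))      # (0.6309..., 0.5)
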
+ if not self.args.wo_entropy_restrain: + entropy_loss = Tools.entropy_restrain(dist) + if not self.args.wo_repeat_penalty: + repeat_loss = Tools.repeat_penalty(dist) + regularization = entropy_loss + repeat_loss + return regularization + + def get_mask_via_len(self, length, max_len): + B = length.size(0) + mask = torch.ones([B, max_len]).cuda() + mask = torch.cumsum(mask, 1) + mask = mask <= length.unsqueeze(1) + mask = mask.unsqueeze(-2) + return mask + + + def get_default_tensor(self, shape, dtype, pad_idx=None): + pad_tensor = torch.zeros(shape, dtype=dtype) + pad_tensor[..., pad_idx] = 1.0 if dtype == torch.float else 1 + pad_tensor = pad_tensor.cuda() + return pad_tensor + + + def sparse_prefix_pad(self, inp, sos_idx): + n_vocab = inp.size(2) + pad = inp.new_ones(inp.size(0), 1, dtype=torch.long) * sos_idx + sparse_pad = Tools.one_hot(pad, n_vocab).cuda() + tensor = torch.cat([sparse_pad, inp], 1) + return tensor + + + def one_hot_scatter(self, indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + + + def action_nll(self, hypothesis, target, pad_idx): + eps = 1e-9 + hypothesis = hypothesis.reshape(-1, hypothesis.size(-1)) + target = target[:, [1, 3, 5, 7, 9]] + target = target.reshape(-1) + if self.args.infoNCE_num != 0: + sample_num = self.args.infoNCE_num + sample_mask = torch.zeros_like(hypothesis, dtype=torch.bool) + select_indices = torch.randint(hypothesis.size(-1), (hypothesis.size(0), sample_num)).to(device=hypothesis.device) + for i in range(hypothesis.size(0)): + sample_mask[i, target[i]] = True + sample_mask[i, select_indices[i]] = True + hypothesis[~sample_mask] = 0 + nll_loss = F.nll_loss(torch.log(hypothesis + eps), target, ignore_index=pad_idx) + return nll_loss + + diff --git a/upcrtopic_with_early_stop.py b/upcrtopic_with_early_stop.py new file mode 100644 index 0000000..0261cfc --- /dev/null +++ b/upcrtopic_with_early_stop.py @@ -0,0 +1,406 @@ +import torch.nn as nn +import torch +from option import option as op +from tau_scheduler import TauScheduler +from copy_scheduler import CopyScheduler +from transformer.Models import Encoder +from transformer.Models import Decoder +from ProfileTopic import PriorProfile,PosteriorProfile +from PreferenceTopic import PriorPreference,PosteriorPreference +from gumbel_softmax import GumbelSoftmax +from ActionTopic import Action +from tools import Tools +from Vocab import Vocab +from scipy import optimize +import torch.nn.functional as F +from transformer.Optim import ScheduledOptim +from DataLoaderTopic import DataLoaderTopic +import math +import sys +from tqdm import tqdm + +class Upcrtopic(nn.Module): + def __init__(self,vocab:Vocab,user_cont,n_layers=6,p_layers=3, + d_word_vec=768,d_model=768, d_inner=3072,beam_width=1, + n_head=8, d_k=64, d_v=64, dropout=0.1): + super(Upcrtopic, self).__init__() + self.vocab = vocab + self.glo2loc , self.loc2glo = vocab.vocab_transfer() + self.glo2loc = torch.tensor(self.glo2loc).cuda() + self.loc2glo = torch.tensor(self.loc2glo).cuda() + self.topic_num = vocab.topic_num() + self.word_vocab, self.word_len, self.topic_vocab, self.topic_len = vocab.get_vocab(task='topic') + self.word_pad_idx = vocab.get_word_pad() + self.topic_pad_idx = vocab.get_topic_pad() + self.m_bos_idx = vocab.topic2index(op.BOS_PRE) + self.l_bos_idx = vocab.topic2index(op.BOS_PRO) + 
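
When args.infoNCE_num != 0, the Engine.action_nll defined just above (at the end of upcrtopic.py) contrasts each gold topic against only infoNCE_num randomly drawn vocabulary columns: everything outside the gold column and the sampled set is zeroed before the NLL, a crude sampled-softmax/InfoNCE-style approximation that avoids scoring the full topic vocabulary. A vectorised standalone sketch of that masking, not part of this diff (the original loops per row):

    import torch
    import torch.nn.functional as F

    def sampled_action_nll(probs, target, n_samples, pad_idx, eps=1e-9):
        # probs: (N, V) rows summing to 1; target: (N,) gold topic indices
        neg = torch.randint(probs.size(-1), (probs.size(0), n_samples),
                            device=probs.device)
        keep = torch.zeros_like(probs, dtype=torch.bool)
        keep.scatter_(1, neg, True)                   # sampled negative columns
        keep.scatter_(1, target.unsqueeze(1), True)   # always keep the gold column
        probs = probs.masked_fill(~keep, 0.0)         # drop everything else
        return F.nll_loss(torch.log(probs + eps), target, ignore_index=pad_idx)
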
self.a_bos_idx = vocab.topic2index(op.BOS_ACTION) + self.r_bos_idx = vocab.word2index(op.BOS_RESPONSE) + self.r_eos_idx = vocab.word2index(op.EOS_RESPONSE) + self.beam_width = beam_width + self.pro_tau_scheduler = TauScheduler(op.init_tau, op.tau_mini, op.tau_decay_total_steps) + self.pre_tau_scheduler = TauScheduler(op.init_tau, op.tau_mini, op.tau_decay_total_steps) + self.m_copy_scheduler = CopyScheduler(op.s_copy_lambda, op.copy_lambda_mini, op.copy_lambda_decay_steps) + self.l_copy_scheduler = CopyScheduler(op.a_copy_lambda, op.copy_lambda_mini, op.copy_lambda_decay_steps) + self.word_emb = nn.Embedding(self.word_len,d_word_vec,padding_idx=self.word_pad_idx) + self.topic_emb = nn.Embedding(self.topic_len,d_word_vec,padding_idx=self.topic_pad_idx) + self.user_emb = nn.Embedding(user_cont,d_word_vec) + self.gumbel_softmax = GumbelSoftmax() + self.global_step = 0 + self.main_tfr_encoder = Encoder( + n_src_vocab=self.word_len, n_position=op.conv_max_len, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.word_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.word_emb + ).cuda() + self.u_tfr_encoder4p = Encoder( + n_src_vocab=user_cont, n_position=1, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.user_emb + ).cuda() + self.u_tfr_encoder4q = Encoder( + n_src_vocab=user_cont, n_position=1, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.user_emb + ).cuda() + self.p_tfr_encoder4p = Encoder( + n_src_vocab=self.topic_len, n_position=200, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_decoder4p = Decoder( + n_trg_vocab=self.topic_len, n_position=30, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_encoder4q = Encoder( + n_src_vocab=self.topic_len, n_position=30, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.p_tfr_decoder4q = Decoder( + n_trg_vocab=self.topic_len, n_position=30, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + self.a_tfr_decoder = Decoder( + n_trg_vocab=self.topic_len, n_position=op.action_num, + d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, + n_layers=p_layers, n_head=n_head, d_k=d_k, d_v=d_v, + pad_idx=self.topic_pad_idx, dropout=dropout, scale_emb=False, + word_emb=self.topic_emb + ).cuda() + if op.wo_l: + self.p_l = None + self.q_l = None + else: + self.p_l = PriorProfile(encoder=self.u_tfr_encoder4p,decoder=self.p_tfr_decoder4p, + hidden_size=d_model,n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx,max_seq_len=op.profile_num, + gs=self.gumbel_softmax,ts=self.pro_tau_scheduler).cuda() + + self.q_l = 
PosteriorProfile(main_encoder=self.main_tfr_encoder,id_encoder = self.u_tfr_encoder4q, + topic_encoder=self.p_tfr_encoder4q, decoder=self.p_tfr_decoder4q, + hidden_size=d_model,n_topic_vocab=self.topic_len, + trg_bos_idx=self.l_bos_idx,max_seq_len=op.profile_num, + gs=self.gumbel_softmax,ts=self.pro_tau_scheduler, + glo2loc=self.glo2loc,loc2glo=self.loc2glo).cuda() + if op.wo_m: + self.p_mt = None + self.q_mt = None + else: + self.p_mt = PriorPreference(encoder=self.p_tfr_encoder4p,decoder=self.p_tfr_decoder4p, + main_tfr_encoder=self.main_tfr_encoder, + hidden_size=d_model,n_topic_vocab=self.topic_len,trg_bos_idx=self.m_bos_idx, + max_seq_len=op.preference_num,gs=self.gumbel_softmax,glo2loc=self.glo2loc, + loc2glo=self.loc2glo,ts=self.pre_tau_scheduler).cuda() + self.q_mt = PosteriorPreference(encoder=self.p_tfr_encoder4q,main_encoder=self.main_tfr_encoder, + decoder=self.p_tfr_decoder4q, + hidden_size=d_model,n_topic_vocab=self.topic_len,trg_bos_idx=self.m_bos_idx, + max_seq_len=op.preference_num,gs=self.gumbel_softmax,glo2loc=self.glo2loc, + loc2glo=self.loc2glo,ts=self.pre_tau_scheduler).cuda() + self.action = Action(p_encoder=self.p_tfr_encoder4p,main_encoder=self.main_tfr_encoder, + a_decoder=self.a_tfr_decoder,hidden_size=d_model, + n_topic_vocab=self.topic_len,bos_idx=self.a_bos_idx,vocab=self.vocab, + max_len=op.action_num,glo2loc=self.glo2loc,loc2glo=self.loc2glo).cuda() + def forward(self,user_id,all_topic, all_topic_len,context, context_len,tp_path, tp_path_len,ar_gth, ar_gth_len, + related_topics, related_topics_len,final,pv_m,mode='train'): + assert mode in ['train','valid','test'] + pv_m, pv_m_mask = self.mask_preference(pv_m, final) + if mode == 'train': + self.global_step += 1 + p_l, p_l_gumbel = self.p_l.forward(id=user_id) + q_l, q_l_gumbel = self.q_l.forward(id=user_id, topics=all_topic, topics_len=all_topic_len) + p_m, p_m_gumbel = self.p_mt.forward(context=context, context_len=context_len,pv_m=pv_m,pv_m_mask=pv_m_mask, + tp_path=tp_path,tp_path_len=tp_path_len) + q_m, q_m_gumbel = self.q_mt.forward(context=context, context_len=context_len,pv_m=pv_m, + pv_m_mask=pv_m_mask,ar_gth=ar_gth,ar_gth_len=ar_gth_len, + tp_path=tp_path,tp_path_len=tp_path_len) + ar = self.action.forward(m=q_m_gumbel,l=q_l_gumbel,context=context,context_len=context_len,ar_gth=ar_gth,ar_gth_len=ar_gth_len, + tp_path=tp_path,tp_path_len=tp_path_len,related_topics=related_topics,related_topics_len=related_topics_len, + mode='train') + return p_l, q_l, p_m, q_m, ar, q_m_gumbel + else: + p_l = self.p_l.forward(id=user_id) + p_m = self.p_mt.forward(context=context,context_len=context_len,pv_m=pv_m,pv_m_mask=pv_m_mask,tp_path=tp_path,tp_path_len=tp_path_len) + ar,ar_probs = self.action.forward(m=p_m,l=p_l,context=context,context_len=context_len,ar_gth=ar_gth,ar_gth_len=ar_gth_len, + tp_path=tp_path,tp_path_len=tp_path_len,related_topics=related_topics, related_topics_len=related_topics_len, + mode='test') + return ar, ar_probs, p_m, p_l + + def mask_preference(self, pv_m, final): + b = range(op.batch_size) + b = [i + 1 for i in b] + b = torch.tensor(b).cuda().tolist() + final = list(final) + final = [int(i) for i in final] + c = [i * j for i, j in zip(final, b)] + c = list(c) + d = [] + for i in c: + if i != 0: + d.append(c.index(i)) + if d: + d = torch.tensor(d).cuda() + pv_m[d, :, :] = 0 + pv_m[d, :, self.topic_pad_idx] = 1.0 + pv_m_mask = pv_m.new_ones(pv_m.size(0), 1, pv_m.size(1)) + pv_m_mask[d,:,:] = 0 + return pv_m, pv_m_mask + + def topictensor2nl(self,tensor): + words = 
tensor.detach().cpu().numpy() + words = self.vocab.index2topic(words) + return words + + def wordtensor2nl(self,tensor): + words = tensor.detach().cpu().numpy() + words = self.vocab.index2word(words) + return words + +class Engine(): + def __init__(self,model:torch.nn.Module,vocab): + self.model = model + lr = op.lr + self.optimizer = torch.optim.Adam(self.model.parameters(), lr, betas=(0.9, 0.98), eps=1e-9) + self.optimizer = ScheduledOptim(self.optimizer, 0.5, op.d_model, op.n_warmup_steps) + self.vocab = vocab + self.topic_pad_idx = self.vocab.topic2index(op.PAD_WORD) + self.global_step = 0 + self.action_loss = 0 + self.kl_l_loss = 0 + self.kl_m_loss =0 + + def train(self, train_set, test_set, valid_set, early_stop_num=10): + best_metrics = { + "TopicId_Hits@1": 0, + "TopicId_Hits@3": 0, + "TopicId_Hits@5": 0, + "best_epoch": 0, + } + for e in range(op.epoch): + print("epoch : {}".format(e)) + train_loader = DataLoaderTopic(train_set, self.vocab) + pv_m = get_default_tensor([op.batch_size, op.preference_num, self.model.topic_len], torch.float,pad_idx=self.topic_pad_idx) + self.optimizer.zero_grad() + for index,input in enumerate(tqdm(train_loader)): + if input[0].size(0) != op.batch_size: + break + id, context_idx, context_len, state_U, state_U_len, related_topics, related_topics_len, \ + a_R, a_R_len, all_topic, all_topic_len,final = input + p_l, q_l, p_m, q_m, ar, m= self.model.forward(user_id=id,all_topic=all_topic,all_topic_len=all_topic_len,context=context_idx, context_len=context_len, + tp_path=state_U,tp_path_len=state_U_len,ar_gth=a_R, ar_gth_len=a_R_len,related_topics=related_topics, + related_topics_len=related_topics_len,final=final,pv_m=pv_m) + kl_l = kl_loss(p_l, q_l.detach()) + self.kl_l_loss += kl_l.item() + kl_m = kl_loss(p_m, q_m.detach()) + self.kl_m_loss += kl_m.item() + nll_ar = action_nll(ar, a_R.detach(), self.model.topic_pad_idx) + self.action_loss += nll_ar.item() + p_l_reg, q_l_reg = regularization_loss(p_l), regularization_loss(q_l) + p_m_reg, q_m_reg = regularization_loss(p_m), regularization_loss(q_m) + reg_loss = op.reg_lambda * (p_l_reg + q_l_reg + p_m_reg + q_m_reg ) + loss = 0.5 * kl_l + nll_ar + reg_loss + 0.5 * kl_m + + + + + + + loss = loss / float(op.gradient_stack) + loss.backward(retain_graph=False) + if self.global_step % op.gradient_stack == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.global_step += 1 + pv_m = m.detach() + print('valid:', end=' ') + metric = self.test(valid_set) + if metric["TopicId_Hits@3"] > best_metrics["TopicId_Hits@3"]: + best_metrics["TopicId_Hits@1"] = metric["TopicId_Hits@1"] + best_metrics["TopicId_Hits@3"] = metric["TopicId_Hits@3"] + best_metrics["TopicId_Hits@5"] = metric["TopicId_Hits@5"] + best_metrics['best_epoch'] = e + torch.save(self.model, 'best_topic_model.pkl') + elif e - best_metrics['best_epoch'] >= early_stop_num: + print('early stop at epoch', e) + break + self.model = torch.load('best_topic_model.pkl') + print('test:', end=' ') + metric = self.test(train_set) + print("train finished ! 
") + print('best valid:' + str(best_metrics)) + + + + def test(self, test_set): + self.model.eval() + dataloader = DataLoaderTopic(test_set,self.vocab) + metrics = { + "topic_Loss": 0, + "TopicId_Hits@1": 0, + "TopicId_Hits@3": 0, + "TopicId_Hits@5": 0, + "topic_count": 0, + } + pv_m = get_default_tensor([op.batch_size, op.preference_num, self.model.topic_len], torch.float,pad_idx=self.model.topic_pad_idx) + with torch.no_grad(): + for index, data in enumerate(dataloader): + if data[0].size(0) != op.batch_size: + break + id, context_idx, context_len, state_U, state_U_len, related_topics, related_topics_len, \ + a_R, a_R_len, all_topic, all_topic_len, final = data + ar, ar_probs, m , l = self.model.forward(user_id=id,all_topic=all_topic,all_topic_len=all_topic_len,context=context_idx, context_len=context_len, + tp_path=state_U,tp_path_len=state_U_len,ar_gth=a_R, ar_gth_len=a_R_len,related_topics=related_topics, + related_topics_len=related_topics_len,final=final,pv_m=pv_m,mode='test') + self.compute_metrics(ar_probs, a_R, a_R_len, metrics) + metrics['TopicId_Hits@1'] = round(metrics['TopicId_Hits@1'] / metrics['topic_count'], 4) + metrics['TopicId_Hits@3'] = round(metrics['TopicId_Hits@3'] / metrics['topic_count'], 4) + metrics['TopicId_Hits@5'] = round(metrics['TopicId_Hits@5'] / metrics['topic_count'], 4) + print(metrics) + self.model.train() + return metrics + + def compute_metrics(self,ar_probs, ar_gth, a_R_len, metrics): + tanlun = self.vocab.topic2index('谈论') + qingqiutuijian = self.vocab.topic2index('请求推荐') + def _topic_prediction(tar,gen,metrics): + metrics['topic_count'] += 1 + for k in [1,3,5]: + pred, pred_id = torch.topk(gen,k,-1) + pred_id = pred_id.tolist() + if tar in pred_id: + metrics["TopicId_Hits@{}".format(k)] += 1 + def _movie_recommendation(tar,gen,metrics): + metrics['rec_count'] += 1 + for k in [1,10,50]: + pred, pred_id = torch.topk(gen,k,-1) + pred_id = pred_id.tolist() + if tar in pred_id: + rank = pred_id.index(tar) + metrics['NDCG{}'.format(k)] += 1.0 / math.log(rank + 2.0, 2) + metrics['MRR{}'.format(k)] += 1.0 / (rank + 1.0) + for i, gt in enumerate(ar_gth): + ar_gen = ar_probs[i,:] + gt_len = int(a_R_len[i]) + for j in range(0,gt_len,2): + action_type = gt[j] + if action_type == self.vocab.topic2index('推荐电影'): + _movie_recommendation(gt[j+1],ar_gen[int(j/2)],metrics) + else: + _topic_prediction(gt[j+1],ar_gen[int(j/2)],metrics) + if tanlun in gt and qingqiutuijian in gt: + break + +def get_mask_via_len(length, max_len): + B = length.size(0) + mask = torch.ones([B, max_len]).cuda() + mask = torch.cumsum(mask, 1) + mask = mask <= length.unsqueeze(1) + mask = mask.unsqueeze(-2) + return mask + +def get_default_tensor(shape, dtype, pad_idx=None): + pad_tensor = torch.zeros(shape, dtype=dtype) + pad_tensor[..., pad_idx] = 1.0 if dtype == torch.float else 1 + pad_tensor = pad_tensor.cuda() + return pad_tensor + +def sparse_prefix_pad(inp, sos_idx): + n_vocab = inp.size(2) + pad = inp.new_ones(inp.size(0), 1, dtype=torch.long) * sos_idx + sparse_pad = Tools.one_hot(pad, n_vocab).cuda() + tensor = torch.cat([sparse_pad, inp], 1) + return tensor + +def one_hot_scatter(indice, num_classes, dtype=torch.float): + indice_shape = list(indice.shape) + placeholder = torch.zeros(*(indice_shape + [num_classes]), device=indice.device, dtype=dtype) + v = 1 if dtype == torch.long else 1.0 + placeholder.scatter_(-1, indice.unsqueeze(-1), v) + return placeholder + +def kl_loss(prior_dist, posterior_dist): + bias = 1e-24 + if (len(prior_dist.shape) >= 3) and op.hungary: + B, S = 
prior_dist.size(0), prior_dist.size(1) + expand_prior_dist = prior_dist.unsqueeze(2).expand(-1, -1, S, -1).reshape(B, S * S, -1) + expand_posterior_dist = posterior_dist.unsqueeze(1).expand(-1, S, -1, -1).reshape(B, S * S, -1) + cost_vector = F.kl_div((expand_prior_dist + bias).log(), expand_posterior_dist, reduce=False).sum(-1) + cost_matrix = cost_vector.reshape(-1, S, S) + cost_matrix_np = cost_matrix.detach().cpu().numpy() + row_idx, col_idx = zip(*[optimize.linear_sum_assignment(cost_matrix_np[i]) for i in range(B)]) + col_idx = torch.tensor(col_idx, dtype=torch.long) + posterior_dist = Tools.nested_index_select(posterior_dist, col_idx) + flat_prior_dist = prior_dist.reshape(-1, prior_dist.size(-1)) + flat_posterior_dist = posterior_dist.reshape(-1, posterior_dist.size(-1)) + kl_div = F.kl_div((flat_prior_dist + bias).log(), flat_posterior_dist, reduce=False).sum(-1) + kl_div = kl_div.mean() + return kl_div + +def nll_loss(hypothesis, target, pad_id ): + eps = 1e-9 + B, T = target.shape + hypothesis = hypothesis.reshape(-1, hypothesis.size(-1)) + target = target[:,1:] + padding = torch.ones(target.size(0),1,dtype=torch.long) * pad_id + padding = padding.cuda() + target = torch.cat([target,padding],1) + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis + 1e-20), target, ignore_index=pad_id, reduce=False) + not_ignore_tag = (target != pad_id).float() + not_ignore_num = not_ignore_tag.reshape(B, T).sum(-1) + sum_nll_loss = nll_loss.reshape(B, T).sum(-1) + nll_loss_vector = sum_nll_loss / (not_ignore_num + eps) + nll_loss = nll_loss_vector.mean() + return nll_loss, nll_loss_vector.detach() + +def regularization_loss(dist): + entropy_loss, repeat_loss = torch.tensor(0.), torch.tensor(0.) + if not op.wo_entropy_restrain: + entropy_loss = Tools.entropy_restrain(dist) + if not op.wo_repeat_penalty: + repeat_loss = Tools.repeat_penalty(dist) + regularization = entropy_loss + repeat_loss + return regularization + +def action_nll(hypothesis,target,pad_idx): + eps = 1e-9 + hypothesis = hypothesis.reshape(-1,hypothesis.size(-1)) + target = target[:,[1,3,5,7,9]] + target = target.reshape(-1) + nll_loss = F.nll_loss(torch.log(hypothesis+eps),target,ignore_index=pad_idx) + return nll_loss \ No newline at end of file diff --git a/vocabulary/vocab_small.txt b/vocabulary/vocab_small.txt new file mode 100644 index 0000000..d35a4fd --- /dev/null +++ b/vocabulary/vocab_small.txt @@ -0,0 +1,13317 @@ +[PAD] +“ +” +’ +‘ +[unused5] +[unused6] +[unused7] +[unused8] +[unused9] +[unused10] +[unused11] +[unused12] +[unused13] +[unused14] +[unused15] +[unused16] +[unused17] +[unused18] +[unused19] +[unused20] +[unused21] +[unused22] +[unused23] +[unused24] +[unused25] +[unused26] +[unused27] +[unused28] +[unused29] +[unused30] +[unused31] +[unused32] +[unused33] +[unused34] +[unused35] +[unused36] +[unused37] +[unused38] +[unused39] +[unused40] +[unused41] +[unused42] +[unused43] +[unused44] +[unused45] +[unused46] +[unused47] +[unused48] +[unused49] +[unused50] +[unused51] +[unused52] +[unused53] +[unused54] +[unused55] +[unused56] +[unused57] +[unused58] +[unused59] +[unused60] +[unused61] +[unused62] +[unused63] +[unused64] +[unused65] +[unused66] +[unused67] +[unused68] +[unused69] +[unused70] +[unused71] +[unused72] +[unused73] +[unused74] +[unused75] +[unused76] +[unused77] +[unused78] +[unused79] +[unused80] +[unused81] +[unused82] +[unused83] +[unused84] +[unused85] +[unused86] +[unused87] +[unused88] +[unused89] +[unused90] +[unused91] +[unused92] +[unused93] +[unused94] +[unused95] 
+[unused96] +[unused97] +[unused98] +[unused99] +[UNK] +[CLS] +[SEP] +[MASK] + + +! +" + +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +[ +\ +] +^ +_ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +£ +¤ +¥ +§ +© +« +® +° +± +² +³ +µ +· +¹ +º +» +¼ +× +ß +æ +÷ +ø +đ +ŋ +ɔ +ə +ɡ +ʰ +ˇ +ˈ +ˊ +ˋ +ˍ +ː +˙ +˚ +ˢ +α +β +γ +δ +ε +η +θ +ι +κ +λ +μ +ν +ο +π +ρ +ς +σ +τ +υ +φ +χ +ψ +ω +а +б +в +г +д +е +ж +з +и +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +ы +ь +я +і +ا +ب +ة +ت +د +ر +س +ع +ل +م +ن +ه +و +ي +۩ +ก +ง +น +ม +ย +ร +อ +า +เ +๑ +་ +ღ +ᄀ +ᄁ +ᄂ +ᄃ +ᄅ +ᄆ +ᄇ +ᄈ +ᄉ +ᄋ +ᄌ +ᄎ +ᄏ +ᄐ +ᄑ +ᄒ +ᅡ +ᅢ +ᅣ +ᅥ +ᅦ +ᅧ +ᅨ +ᅩ +ᅪ +ᅬ +ᅭ +ᅮ +ᅯ +ᅲ +ᅳ +ᅴ +ᅵ +ᆨ +ᆫ +ᆯ +ᆷ +ᆸ +ᆺ +ᆻ +ᆼ +ᗜ +ᵃ +ᵉ +ᵍ +ᵏ +ᵐ +ᵒ +ᵘ +‖ +„ +† +• +‥ +‧ +
 +‰ +′ +″ +‹ +› +※ +‿ +⁄ +ⁱ +⁺ +ⁿ +₁ +₂ +₃ +₄ +€ +℃ +№ +™ +ⅰ +ⅱ +ⅲ +ⅳ +ⅴ +← +↑ +→ +↓ +↔ +↗ +↘ +⇒ +∀ +− +∕ +∙ +√ +∞ +∟ +∠ +∣ +∥ +∩ +∮ +∶ +∼ +∽ +≈ +≒ +≡ +≤ +≥ +≦ +≧ +≪ +≫ +⊙ +⋅ +⋈ +⋯ +⌒ +① +② +③ +④ +⑤ +⑥ +⑦ +⑧ +⑨ +⑩ +⑴ +⑵ +⑶ +⑷ +⑸ +⒈ +⒉ +⒊ +⒋ +ⓒ +ⓔ +ⓘ +─ +━ +│ +┃ +┅ +┆ +┊ +┌ +└ +├ +┣ +═ +║ +╚ +╞ +╠ +╭ +╮ +╯ +╰ +╱ +╳ +▂ +▃ +▅ +▇ +█ +▉ +▋ +▌ +▍ +▎ +■ +□ +▪ +▫ +▬ +▲ +△ +▶ +► +▼ +▽ +◆ +◇ +○ +◎ +● +◕ +◠ +◢ +◤ +☀ +★ +☆ +☕ +☞ +☺ +☼ +♀ +♂ +♠ +♡ +♣ +♥ +♦ +♪ +♫ +♬ +✈ +✔ +✕ +✖ +✦ +✨ +✪ +✰ +✿ +❀ +❤ +➜ +➤ +⦿ +、 +。 +〃 +々 +〇 +〈 +〉 +《 +》 +「 +」 +『 +』 +【 +】 +〓 +〔 +〕 +〖 +〗 +〜 +〝 +〞 +ぁ +あ +ぃ +い +う +ぇ +え +お +か +き +く +け +こ +さ +し +す +せ +そ +た +ち +っ +つ +て +と +な +に +ぬ +ね +の +は +ひ +ふ +へ +ほ +ま +み +む +め +も +ゃ +や +ゅ +ゆ +ょ +よ +ら +り +る +れ +ろ +わ +を +ん +゜ +ゝ +ァ +ア +ィ +イ +ゥ +ウ +ェ +エ +ォ +オ +カ +キ +ク +ケ +コ +サ +シ +ス +セ +ソ +タ +チ +ッ +ツ +テ +ト +ナ +ニ +ヌ +ネ +ノ +ハ +ヒ +フ +ヘ +ホ +マ +ミ +ム +メ +モ +ャ +ヤ +ュ +ユ +ョ +ヨ +ラ +リ +ル +レ +ロ +ワ +ヲ +ン +ヶ +・ +ー +ヽ +ㄅ +ㄆ +ㄇ +ㄉ +ㄋ +ㄌ +ㄍ +ㄎ +ㄏ +ㄒ +ㄚ +ㄛ +ㄞ +ㄟ +ㄢ +ㄤ +ㄥ +ㄧ +ㄨ +ㆍ +㈦ +㊣ +㎡ +㗎 +一 +丁 +七 +万 +丈 +三 +上 +下 +不 +与 +丐 +丑 +专 +且 +丕 +世 +丘 +丙 +业 +丛 +东 +丝 +丞 +丟 +両 +丢 +两 +严 +並 +丧 +丨 +个 +丫 +中 +丰 +串 +临 +丶 +丸 +丹 +为 +主 +丼 +丽 +举 +丿 +乂 +乃 +久 +么 +义 +之 +乌 +乍 +乎 +乏 +乐 +乒 +乓 +乔 +乖 +乗 +乘 +乙 +乜 +九 +乞 +也 +习 +乡 +书 +乩 +买 +乱 +乳 +乾 +亀 +亂 +了 +予 +争 +事 +二 +于 +亏 +云 +互 +五 +井 +亘 +亙 +亚 +些 +亜 +亞 +亟 +亡 +亢 +交 +亥 +亦 +产 +亨 +亩 +享 +京 +亭 +亮 +亲 +亳 +亵 +人 +亿 +什 +仁 +仃 +仄 +仅 +仆 +仇 +今 +介 +仍 +从 +仏 +仑 +仓 +仔 +仕 +他 +仗 +付 +仙 +仝 +仞 +仟 +代 +令 +以 +仨 +仪 +们 +仮 +仰 +仲 +件 +价 +任 +份 +仿 +企 +伉 +伊 +伍 +伎 +伏 +伐 +休 +伕 +众 +优 +伙 +会 +伝 +伞 +伟 +传 +伢 +伤 +伦 +伪 +伫 +伯 +估 +伴 +伶 +伸 +伺 +似 +伽 +佃 +但 +佇 +佈 +位 +低 +住 +佐 +佑 +体 +佔 +何 +佗 +佘 +余 +佚 +佛 +作 +佝 +佞 +佟 +你 +佢 +佣 +佤 +佥 +佩 +佬 +佯 +佰 +佳 +併 +佶 +佻 +佼 +使 +侃 +侄 +來 +侈 +例 +侍 +侏 +侑 +侖 +侗 +供 +依 +侠 +価 +侣 +侥 +侦 +侧 +侨 +侬 +侮 +侯 +侵 +侶 +侷 +便 +係 +促 +俄 +俊 +俎 +俏 +俐 +俑 +俗 +俘 +俚 +保 +俞 +俟 +俠 +信 +俨 +俩 +俪 +俬 +俭 +修 +俯 +俱 +俳 +俸 +俺 +俾 +倆 +倉 +個 +倌 +倍 +倏 +們 +倒 +倔 +倖 +倘 +候 +倚 +倜 +借 +倡 +値 +倦 +倩 +倪 +倫 +倬 +倭 +倶 +债 +值 +倾 +偃 +假 +偈 +偉 +偌 +偎 +偏 +偕 +做 +停 +健 +側 +偵 +偶 +偷 +偻 +偽 +偿 +傀 +傅 +傍 +傑 +傘 +備 +傚 +傢 +傣 +傥 +储 +傩 +催 +傭 +傲 +傳 +債 +傷 +傻 +傾 +僅 +働 +像 +僑 +僕 +僖 +僚 +僥 +僧 +僭 +僮 +僱 +僵 +價 +僻 +儀 +儂 +億 +儆 +儉 +儋 +儒 +儕 +儘 +償 +儡 +優 +儲 +儷 +儼 +儿 +兀 +允 +元 +兄 +充 +兆 +兇 +先 +光 +克 +兌 +免 +児 +兑 +兒 +兔 +兖 +党 +兜 +兢 +入 +內 +全 +兩 +八 +公 +六 +兮 +兰 +共 +兲 +关 +兴 +兵 +其 +具 +典 +兹 +养 +兼 +兽 +冀 +内 +円 +冇 +冈 +冉 +冊 +册 +再 +冏 +冒 +冕 +冗 +写 +军 +农 +冠 +冢 +冤 +冥 +冨 +冪 +冬 +冯 +冰 +冲 +决 +况 +冶 +冷 +冻 +冼 +冽 +冾 +净 +凄 +准 +凇 +凈 +凉 +凋 +凌 +凍 +减 +凑 +凛 +凜 +凝 +几 +凡 +凤 +処 +凪 +凭 +凯 +凰 +凱 +凳 +凶 +凸 +凹 +出 +击 +函 +凿 +刀 +刁 +刃 +分 +切 +刈 +刊 +刍 +刎 +刑 +划 +列 +刘 +则 +刚 +创 +初 +删 +判 +別 +刨 +利 +刪 +别 +刮 +到 +制 +刷 +券 +刹 +刺 +刻 +刽 +剁 +剂 +剃 +則 +剉 +削 +剋 +剌 +前 +剎 +剐 +剑 +剔 +剖 +剛 +剜 +剝 +剣 +剤 +剥 +剧 +剩 +剪 +副 +割 +創 +剷 +剽 +剿 +劃 +劇 +劈 +劉 +劊 +劍 +劏 +劑 +力 +劝 +办 +功 +加 +务 +劣 +动 +助 +努 +劫 +劭 +励 +劲 +劳 +労 +劵 +効 +劾 +势 +勁 +勃 +勇 +勉 +勋 +勐 +勒 +動 +勖 +勘 +務 +勛 +勝 +勞 +募 +勢 +勤 +勧 +勳 +勵 +勸 +勺 +勻 +勾 +勿 +匀 +包 +匆 +匈 +匍 +匐 +匕 +化 +北 +匙 +匝 +匠 +匡 +匣 +匪 +匮 +匯 +匱 +匹 +区 +医 +匾 +匿 +區 +十 +千 +卅 +升 +午 +卉 +半 +卍 +华 +协 +卑 +卒 +卓 +協 +单 +卖 +南 +単 +博 +卜 +卞 +卟 +占 +卡 +卢 +卤 +卦 +卧 +卫 +卮 +卯 +印 +危 +即 +却 +卵 +卷 +卸 +卻 +卿 +厂 +厄 +厅 +历 +厉 +压 +厌 +厕 +厘 +厚 +厝 +原 +厢 +厥 +厦 +厨 +厩 +厭 +厮 +厲 +厳 +去 +县 +叁 +参 +參 +又 +叉 +及 +友 +双 +反 +収 +发 +叔 +取 +受 +变 +叙 +叛 +叟 +叠 +叡 +叢 +口 +古 +句 +另 +叨 +叩 +只 +叫 +召 +叭 +叮 +可 +台 +叱 +史 +右 +叵 +叶 +号 +司 +叹 +叻 +叼 +叽 +吁 +吃 +各 +吆 +合 +吉 +吊 +吋 +同 +名 +后 +吏 +吐 +向 +吒 +吓 +吕 +吖 +吗 +君 +吝 +吞 +吟 +吠 +吡 +否 +吧 +吨 +吩 +含 +听 +吭 +吮 +启 +吱 +吳 +吴 +吵 +吶 +吸 +吹 +吻 +吼 +吽 +吾 +呀 +呂 +呃 +呆 +呈 +告 +呋 +呎 +呐 +呓 +呕 +呗 +员 +呛 +呜 +呢 +呤 +呦 +周 +呱 +呲 +味 +呵 +呷 +呸 +呻 +呼 +命 +咀 +咁 +咂 +咄 +咆 +咋 +和 +咎 +咏 +咐 +咒 +咔 +咕 +咖 +咗 +咘 +咙 +咚 +咛 +咣 +咤 +咦 +咧 +咨 +咩 +咪 +咫 +咬 +咭 +咯 +咱 +咲 +咳 +咸 +咻 +咽 +咿 +哀 +品 +哂 +哄 +哆 +哇 +哈 +哉 +哋 +哌 +响 +哎 +哏 +哐 +哑 +哒 +哔 +哗 +哟 +員 +哥 +哦 +哧 +哨 +哩 +哪 +哭 +哮 
+[… remainder of the added vocabulary file, one token per "+" diff line: several thousand further single-character tokens (CJK ideographs from 哲 through 龟, fullwidth/halfwidth punctuation, ASCII letters, digits, and symbols, Japanese kana, and emoji such as 👍 🔥 😂 😎), followed by multi-character word tokens (years and numbers such as 2017, 100, 1990; English and web tokens such as facebook, iphone, wordpress, tripadvisor, pixnet), ending with the token fishbase; some entries appear as bare "+" lines where the original token did not survive extraction …]
\ No newline at end of file