From a1dce465e5ba213918a64776dfe58da56e0bf191 Mon Sep 17 00:00:00 2001 From: donghyeon_kim Date: Sat, 12 May 2018 22:02:47 +0900 Subject: [PATCH] Rename snapshot -> context --- README.md | 4 +- dataset.py | 78 +++++------ model.py | 373 +++++++++++++++++++++++++++-------------------------- test.py | 14 +- utils.py | 10 -- 5 files changed, 234 insertions(+), 245 deletions(-) diff --git a/README.md b/README.md index 7dfbfe3..fb73fac 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,12 @@ $ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/sample_data.csv # Download word, character dictionaries ``` -$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/preprocess_20180429_dict.pkl +$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/preprocess_180504_dict.pkl ``` # Download pretrained NETS model ``` -$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/nets_gradclip_180501_5_1.pth +$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/nets_180512_0.pth ``` # Run NETS w/ sample data diff --git a/dataset.py b/dataset.py index edf0f6a..dff7dca 100644 --- a/dataset.py +++ b/dataset.py @@ -72,7 +72,7 @@ def initial_settings(self): self.max_rs_dist = 2 # reg-st week distance self.class_div = 2 # 168 output self.slot_size = 336 - self.max_snapshot = float("inf") # 35 + self.max_context = float("inf") # 35 self.min_word_cnt = 0 self.max_title_len = 50 self.max_word_len = 50 @@ -100,7 +100,7 @@ def initialize_dictionary(self): self.initial_word_dict = {} self.invalid_weeks = [] self.user_event_cnt = {} - + def update_dictionary(self, key, mode=None): # update dictionary given a key if mode == 'c': @@ -182,7 +182,8 @@ def check_maxlen(text, w_key): for single_what in prev_what_list: what_split = nltk.word_tokenize(single_what) if self.config.word2vec_type == 6: - what_split = [w.lower() for w in what_split] + what_split = [word.lower() for word + in what_split] for word in what_split: if word not in self.initial_word_dict: self.initial_word_dict[word] = ( @@ -222,8 +223,8 @@ def get_pretrained_word(self, path): widx2vec = [] unk_cnt = 0 - widx2vec.append([0.0] * self.config.word_embed_dim) # UNK - widx2vec.append([0.0] * self.config.word_embed_dim) # PAD + widx2vec.append([0.] * self.config.word_embed_dim) # PAD + widx2vec.append([1.] * self.config.word_embed_dim) # UNK for word, (word_idx, word_cnt) in self.initial_word_dict.items(): if word != 'UNK' and word != 'PAD': @@ -236,14 +237,14 @@ def get_pretrained_word(self, path): self.widx2vec = widx2vec - print('pretrained vectors', np.asarray(widx2vec).shape, 'unk', unk_cnt) + print('pretrained vectors', np.asarray(widx2vec).shape, '#unk', unk_cnt) print('dictionary change', len(self.initial_word_dict), 'to', len(self.word2idx), len(self.idx2word), end='\n\n') def process_data(self, path, update_dict=False): print('### processing %s' % path) total_data = [] - max_wordlen = max_sentlen = max_dur = max_snapshot = 0 + max_wordlen = max_sentlen = max_dur = max_context = 0 min_dur = float("inf") max_slot_idx = (self.slot_size // self.class_div) - 1 @@ -265,7 +266,7 @@ def process_data(self, path, update_dict=False): """ prev_user = '' prev_st_yw = ('', '') - saved_snapshot = [] + saved_context = [] calendar_data = csv.reader(f, quotechar='"') for k, features in enumerate(calendar_data): @@ -310,7 +311,7 @@ def process_data(self, path, update_dict=False): # process title feature what_split = nltk.word_tokenize(what) if self.config.word2vec_type == 6: - what_split = [w.lower() for w in what_split] + what_split = [word.lower() for word in what_split] for word in what_split: max_wordlen = \ len(word) if len(word) > max_wordlen else max_wordlen @@ -357,26 +358,26 @@ def process_data(self, path, update_dict=False): input_slot = st_slot // self.class_div target_slot = st_slot // self.class_div - # process snapshot + # process context if reg_seq == 0: # start of a new week assert curr_user != prev_user or curr_st_yw != prev_st_yw prev_user = curr_user prev_st_yw = curr_st_yw # prev_grid = [] - input_snapshot = [] - saved_snapshot = [[input_title, fine_duration, input_slot]] + input_context = [] + saved_context = [[input_title, fine_duration, input_slot]] else: # same as the prev week assert curr_user == prev_user and curr_st_yw == prev_st_yw - # input_snapshot = copy.deepcopy(saved_snapshot) - prev_grid = [svs[2] for svs in saved_snapshot] + # input_context = copy.deepcopy(saved_context) + prev_grid = [svs[2] for svs in saved_context] if input_slot in prev_grid: continue - input_snapshot = saved_snapshot[:] - saved_snapshot.append( + input_context = saved_context[:] + saved_context.append( [input_title, fine_duration, input_slot]) - # transform snapshot features into slot grid - # snapshot slots w/ durations + # transform context features into slot grid + # context slots w/ durations target_n_slot = \ int(math.ceil(input_duration / (30 * self.class_div))) @@ -387,7 +388,7 @@ def process_data(self, path, update_dict=False): targets_w_duration.append(target_slot + shift) input_grid = list() - for ips in input_snapshot: + for ips in input_context: n_slots = int(math.ceil(ips[1] / (30 * self.class_div))) for slot_idx in range(n_slots): slot = ips[2] + slot_idx @@ -402,16 +403,16 @@ def process_data(self, path, update_dict=False): assert target_slot not in input_grid - # filter by register distance & max_snapshot & recurrent + # filter by register distance & max_context & recurrent if (reg_st_week_dist <= self.max_rs_dist - and len(input_snapshot) <= self.max_snapshot + and len(input_context) <= self.max_context and 'False' == is_recurrent): - max_snapshot = max_snapshot \ - if max_snapshot > len(input_snapshot) \ - else len(input_snapshot) + max_context = max_context \ + if max_context > len(input_context) \ + else len(input_context) total_data.append( [input_user, input_title, input_duration, - input_snapshot, input_grid, target_slot]) + input_context, input_grid, target_slot]) if user_id not in self.user_event_cnt: self.user_event_cnt[user_id] = 1 @@ -431,7 +432,7 @@ def process_data(self, path, update_dict=False): print('data size', len(total_data)) print('max duration', max_dur) print('min duration', min_dur) - print('max snapshot', max_snapshot) + print('max context', max_context) print('max wordlen', max_wordlen) print('max sentlen', max_sentlen, end='\n\n') @@ -452,7 +453,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=train_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True, + pin_memory=True ) else: train_loader = None @@ -468,7 +469,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=valid_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True, + pin_memory=True ) else: valid_loader = None @@ -483,7 +484,7 @@ def get_dataloader(self, batch_size=None, shuffle=True): sampler=test_sampler, num_workers=self.config.data_workers, collate_fn=self.batchify, - pin_memory=True, + pin_memory=True ) return train_loader, valid_loader, test_loader @@ -549,14 +550,14 @@ def __getitem__(self, index): tw = title[1] tl = title[2] - # Snapshot (title, duration, slot) - snapshot = example[3] + # context (title, duration, slot) + context = example[3] stc = [] stw = [] stl = [] sdur = [] sslot = [] - for _, event in enumerate(snapshot): + for _, event in enumerate(context): stc.append(event[0][0]) stw.append(event[0][1]) stl.append(event[0][2]) @@ -564,8 +565,7 @@ def __getitem__(self, index): sslot.append(event[2]) # Grid - grid = torch.zeros( - self.config.sm_day_num * self.config.sm_slot_num) + grid = torch.zeros(self.config.sm_day_num * self.config.sm_slot_num) if len(example[4]) > 0: grid[example[4]] = 1 @@ -576,12 +576,12 @@ def __getitem__(self, index): return user, dur, tc, tw, tl, stc, stw, stl, sdur, sslot, grid, target def lengths(self): - def maxlen_from_snapshot(snapshots): - if len(snapshots) > 0: - return max([s[0][2] for s in snapshots]) + def maxlen_from_context(contexts): + if len(contexts) > 0: + return max([s[0][2] for s in contexts]) else: return 0 - return [(example[1][2], maxlen_from_snapshot(example[3])) + return [(example[1][2], maxlen_from_context(example[3])) for example in self.examples] @@ -635,7 +635,7 @@ def __init__(self): self.sm_day_num = 7 self.sm_slot_num = 24 self.preprocess_save_path = './data/preprocess_tmp.pkl' - self.preprocess_load_path = './data/preprocess_20180429.pkl' + self.preprocess_load_path = './data/preprocess_.pkl' if __name__ == '__main__': diff --git a/model.py b/model.py index 14eacc0..3d20170 100644 --- a/model.py +++ b/model.py @@ -29,12 +29,12 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): config.word_embed_dim, padding_idx=0) - if not config.no_intention or not config.no_snapshot: + if not config.no_intention or not config.no_context: self.user_embed = nn.Embedding(config.user_size, config.user_embed_dim) if not config.no_intention: self.dur_embed = nn.Embedding(config.dur_size, config.dur_embed_dim) - if not config.no_snapshot: + if not config.no_context: self.slot_embed = nn.Embedding(self.n_classes, config.slot_embed_dim) self.emtpy_long = torch.LongTensor([]).to(self.device) @@ -44,12 +44,13 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): self.t_rnn_idim = config.word_embed_dim + sum(config.tc_conv_fn) self.st_rnn_idim = config.word_embed_dim + sum(config.tc_conv_fn) self.sm_conv1_idim = config.user_embed_dim + config.slot_embed_dim - if not config.no_snapshot and not config.no_snapshot_title: + if not config.no_context and not config.no_context_title: self.sm_conv1_idim += config.st_rnn_hdim * self.num_directions self.empty_st_rnn_output = \ - torch.zeros(1, self.config.st_rnn_hdim * self.num_directions)\ + torch.zeros(1, self.config.st_rnn_hdim * self.num_directions) \ .to(self.device) - self.sm_conv2_idim = sum(config.sm_conv_fn[:len(config.sm_conv_fn)//2]) + self.sm_conv2_idim = sum( + config.sm_conv_fn[:len(config.sm_conv_fn) // 2]) self.it_idim = config.user_embed_dim + config.dur_embed_dim if not config.no_title: @@ -60,12 +61,12 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): else: if not config.no_title: self.mt_idim += config.t_rnn_hdim * self.num_directions - if not config.no_snapshot: - self.snapshot_odim = sum( + if not config.no_context: + self.context_odim = sum( config.sm_conv_fn[len(config.sm_conv_fn) // 2:]) self.mt_idim += config.sm_day_num * config.sm_slot_num - self.mt_idim += self.snapshot_odim + self.mt_idim += self.context_odim # convolution layers self.tc_conv = nn.ModuleList( @@ -77,22 +78,22 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): for num_tc_conv_f in config.tc_conv_fn]) self.tc_conv_min_dim = len(config.tc_conv_fn) + 1 - if not config.no_snapshot: + if not config.no_context: self.sm_conv1 = nn.ModuleList([nn.Conv2d( - self.sm_conv1_idim, config.sm_conv_fn[i], - (config.sm_conv_fh[i], config.sm_conv_fw[i]), - stride=1, padding=config.sm_conv_pd[i]) - for i in range(0, len(config.sm_conv_fn)//2)]) + self.sm_conv1_idim, config.sm_conv_fn[i], + (config.sm_conv_fh[i], config.sm_conv_fw[i]), + stride=1, padding=config.sm_conv_pd[i]) + for i in range(0, len(config.sm_conv_fn) // 2)]) self.sm_mp1 = nn.MaxPool2d(2) self.sm_conv1_bn = nn.BatchNorm2d(self.sm_conv2_idim) self.sm_conv2 = nn.ModuleList([nn.Conv2d( - self.sm_conv2_idim, - config.sm_conv_fn[i + len(config.sm_conv_fn)//2], - (config.sm_conv_fh[i], config.sm_conv_fw[i]), - stride=1, padding=config.sm_conv_pd[i]) - for i in range(len(config.sm_conv_fn)//2)]) + self.sm_conv2_idim, + config.sm_conv_fn[i + len(config.sm_conv_fn) // 2], + (config.sm_conv_fh[i], config.sm_conv_fw[i]), + stride=1, padding=config.sm_conv_pd[i]) + for i in range(len(config.sm_conv_fn) // 2)]) self.sm_mp2 = nn.MaxPool2d(2) - self.sm_conv2_bn = nn.BatchNorm2d(self.snapshot_odim) + self.sm_conv2_bn = nn.BatchNorm2d(self.context_odim) # rnn layers self.batch_first = False @@ -106,7 +107,7 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): batch_first=self.batch_first, bidirectional=self.bidirectional) - if not config.no_snapshot and not config.no_snapshot_title: + if not config.no_context and not config.no_context_title: self.st_rnn = nn.LSTM(self.st_rnn_idim, config.st_rnn_hdim, config.st_rnn_ln, dropout=config.st_rnn_dr, @@ -119,7 +120,7 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): dropout=config.t_rnn_dr, batch_first=self.batch_first, bidirectional=self.bidirectional) - if not config.no_snapshot and not config.no_snapshot_title: + if not config.no_context and not config.no_context_title: self.st_rnn = nn.GRU(self.st_rnn_idim, config.st_rnn_hdim, config.st_rnn_ln, dropout=config.st_rnn_dr, @@ -144,7 +145,8 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): params = self.model_params(debug=False) self.optimizer = optim.Adam(params, lr=config.lr, - weight_decay=config.wd) + weight_decay=config.wd, + amsgrad=True) self.scheduler = \ optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, @@ -153,10 +155,10 @@ def __init__(self, config, widx2vec, class_weight=None, idx=None): # https://discuss.pytorch.org/t/loss-weighting-imbalanced-data/11698 self.criterion = nn.CrossEntropyLoss(weight=class_weight) - if config.summary and idx is not None: - summary_path = 'runs/' + config.model_name + '_' + str(idx) - self.train_writer = SummaryWriter(log_dir=summary_path + '/train') - self.valid_writer = SummaryWriter(log_dir=summary_path + '/valid') + if config.summary: + summary_path = 'runs/' + config.model_name + \ + ('_%d' % idx if idx is not None else '') + self.summary_writer = SummaryWriter(log_dir=summary_path) def init_word_embed(self, widx2vec): self.word_embed.weight.data.copy_(torch.from_numpy(np.array(widx2vec))) @@ -168,9 +170,10 @@ def init_conv_list(conv_list): # https://discuss.pytorch.org/t/weight-initilzation/157/9 nn.init.xavier_uniform_(conv.weight.data) nn.init.uniform_(conv.bias.data) + init_conv_list(self.tc_conv) - if not self.config.no_snapshot: + if not self.config.no_context: init_conv_list(self.sm_conv1) init_conv_list(self.sm_conv2) @@ -294,7 +297,7 @@ def get_rnn_out(self, batch_size, batch_max_seqlen, tl, # https://github.com/pytorch/pytorch/issues/3587#issuecomment-354284160 if rnn.bidirectional: bw_idxes = \ - torch.arange(0, batch_size, dtype=torch.long)\ + torch.arange(0, batch_size, dtype=torch.long) \ .to(self.device) * batch_max_seqlen selected_bw = rnn_out[bw_idxes] @@ -310,7 +313,7 @@ def get_rnn_out(self, batch_size, batch_max_seqlen, tl, @Profile(__name__) def title_layer(self, tc, tw, tl, mode='t'): - # it's snapshot size if mode='st' + # it's context size if mode='st' tl = torch.LongTensor(tl).to(self.device) batch_size = tl.size(0) batch_max_seqlen = tl.max() @@ -330,7 +333,7 @@ def title_layer(self, tc, tw, tl, mode='t'): # (B, L (batch_max_seqlen), max_wordlen) tc_tensor = torch.zeros((batch_size, batch_max_seqlen, - batch_max_wordlen), dtype=torch.long)\ + batch_max_wordlen), dtype=torch.long) \ .to(self.device) for b_idx, (seq, seqlen) in enumerate(zip(tc, tl)): for w_idx in range(seqlen): @@ -341,7 +344,7 @@ def title_layer(self, tc, tw, tl, mode='t'): # assure that dataset.word2idx[self.PAD] is 0 # (B, L (batch_max_seqlen)) tw_tensor = torch.zeros((batch_size, - batch_max_seqlen), dtype=torch.long)\ + batch_max_seqlen), dtype=torch.long) \ .to(self.device) for idx, (seq, seqlen) in enumerate(zip(tw, tl)): tw_tensor[idx, :seqlen] = \ @@ -416,10 +419,10 @@ def title_layer(self, tc, tw, tl, mode='t'): self.config.t_rnn_hdim, self.config.t_rnn_out_dr, self.config.t_rnn_ln) - # for snapshot title + # for context title elif mode == 'st': - assert not self.config.no_snapshot \ - and not self.config.no_snapshot_title + assert not self.config.no_context \ + and not self.config.no_context_title return self.get_rnn_out(batch_size, batch_max_seqlen, tl, packed_lstm_input, idx_unsort, self.st_rnn, @@ -441,14 +444,14 @@ def intention_layer(self, user, dur, title): return torch.mul(gate, nonl) + torch.mul(1 - gate, concat) @Profile(__name__) - def snapshot_title_layer(self, stc, stw, stl): + def context_title_layer(self, stc, stw, stl): stacked_tc = [] stacked_tw = [] stacked_tl = [] split_idx = [0] split_titles = [] - # Stack snapshot features + # Stack context features for tc, tw, tl in zip(stc, stw, stl): stacked_tc += tc stacked_tw += tw @@ -458,33 +461,33 @@ def snapshot_title_layer(self, stc, stw, stl): # Run title layer once if len(stacked_tc) > 0: - snapshot_titles = self.title_layer( + context_titles = self.title_layer( stacked_tc, stacked_tw, stacked_tl, mode='st') else: - snapshot_titles = self.empty_st_rnn_output + context_titles = self.empty_st_rnn_output # Gather by split idx for s, e in zip(split_idx[:-1], split_idx[1:]): if s == e: split_titles.append(self.empty_st_rnn_output) else: - split_titles.append(snapshot_titles[s:e]) + split_titles.append(context_titles[s:e]) return split_titles @Profile(__name__) - def snapshot_layer(self, user_embed, stitle, sdur, sslot): + def context_layer(self, user_embed, stitle, sdur, sslot): # # test - # return torch.zeros(user_embed.size(0), self.snapshot_odim)\ + # return torch.zeros(user_embed.size(0), self.context_odim)\ # .to(self.device) - snapshot_rep_list = list() - if not self.config.no_snapshot_title: + context_rep_list = list() + if not self.config.no_context_title: for usr_emb, title, dur, slot \ in zip(user_embed, stitle, sdur, sslot): # if 0 == len(dur): - # snapshot_rep_list.append( - # torch.zeros(1, self.snapshot_odim).to(self.device)) + # context_rep_list.append( + # torch.zeros(1, self.context_odim).to(self.device)) # else: if 0 == len(dur): @@ -498,14 +501,14 @@ def snapshot_layer(self, user_embed, stitle, sdur, sslot): slot = torch.LongTensor(slot).to(self.device) usr_emb = torch.unsqueeze(usr_emb, 0) - snapshot_rep, _ = \ - self.snapshot_layer_core(usr_emb, title, dur, slot) - snapshot_rep_list.append(snapshot_rep) + context_rep, _ = \ + self.context_layer_core(usr_emb, title, dur, slot) + context_rep_list.append(context_rep) else: for usr_emb, dur, slot in zip(user_embed, sdur, sslot): # if 0 == len(dur): - # snapshot_rep_list.append( - # torch.zeros(1, self.snapshot_odim).to(self.device)) + # context_rep_list.append( + # torch.zeros(1, self.context_odim).to(self.device)) # else: if 0 == len(dur): @@ -519,26 +522,25 @@ def snapshot_layer(self, user_embed, stitle, sdur, sslot): slot = torch.LongTensor(slot).to(self.device) usr_emb = torch.unsqueeze(usr_emb, 0) - snapshot_rep, _ = \ - self.snapshot_layer_core(usr_emb, None, dur, slot) - snapshot_rep_list.append(snapshot_rep) - return torch.cat(snapshot_rep_list, dim=0) + context_rep, _ = \ + self.context_layer_core(usr_emb, None, dur, slot) + context_rep_list.append(context_rep) + return torch.cat(context_rep_list, dim=0) @Profile(__name__) - def snapshot_layer_core(self, user_embed, title, dur, slot): + def context_layer_core(self, user_embed, title, dur, slot): new_slot = None - snapshot_contents = None + context_contents = None - # ready for snapshot (contents) + # ready for context (contents) total_slots = self.config.sm_day_num * self.config.sm_slot_num saved_slot = list() has_preregistered_events = dur.size(0) > 0 if has_preregistered_events: - if not self.config.no_snapshot_title: + if not self.config.no_context_title: assert title is not None - title = title.data new_title = list() dur = dur / 30 - 1 new_slot = list() @@ -547,7 +549,7 @@ def snapshot_layer_core(self, user_embed, title, dur, slot): 't %d, d %d, s %d' % ( title.size(0), dur.size(0), slot.size(0)) - for i, (d, s) in enumerate(zip(dur.data, slot.data)): + for i, (d, s) in enumerate(zip(dur, slot)): new_slot.append(s) new_title.append(title[i]) for k in range(d): @@ -558,7 +560,7 @@ def snapshot_layer_core(self, user_embed, title, dur, slot): saved_slot = new_slot[:] new_slot = torch.LongTensor(new_slot).to(self.device) new_title = \ - torch.cat(new_title, 0).\ + torch.cat(new_title, 0). \ view(-1, self.config.st_rnn_hdim * self.num_directions) slot_embed = F.dropout(self.slot_embed(new_slot), p=self.config.slot_dr, @@ -567,9 +569,8 @@ def snapshot_layer_core(self, user_embed, title, dur, slot): # slot_embed = torch.zeros(slot_embed.size()).to(self.device) user_src_embed = user_embed.expand(slot_embed.size(0), user_embed.size(1)) - snapshot_contents = \ - torch.cat((new_title, user_src_embed.data, slot_embed.data), - 1) + context_contents = \ + torch.cat((new_title, user_src_embed, slot_embed), 1) else: dur = dur / 30 - 1 new_slot = list() @@ -577,7 +578,7 @@ def snapshot_layer_core(self, user_embed, title, dur, slot): assert dur.size(0) == slot.size(0), \ 'd %d, s %d' % (dur.size(0), slot.size(0)) - for i, (d, s) in enumerate(zip(dur.data, slot.data)): + for i, (d, s) in enumerate(zip(dur, slot)): new_slot.append(s) for k in range(d): if s + k + 1 < total_slots: @@ -592,83 +593,81 @@ def snapshot_layer_core(self, user_embed, title, dur, slot): # slot_embed = torch.zeros(slot_embed.size()).to(self.device) user_src_embed = user_embed.expand(slot_embed.size(0), user_embed.size(1)) - snapshot_contents = \ - torch.cat((user_src_embed.data, slot_embed.data), 1) + context_contents = torch.cat((user_src_embed, slot_embed), 1) saved_slot = torch.LongTensor(saved_slot).to(self.device) # ready for slot, user embed (base) - slot_all = torch.arange(0, total_slots, dtype=torch.long)\ + slot_all = torch.arange(0, total_slots, dtype=torch.long) \ .to(self.device) slot_all_embed = self.slot_embed(slot_all) user_all_embed = user_embed[0].expand(slot_all_embed.size(0), user_embed.size(1)) - if not self.config.no_snapshot_title: + if not self.config.no_context_title: zero_concat = \ torch.zeros( total_slots, - self.config.st_rnn_hdim * self.num_directions)\ + self.config.st_rnn_hdim * self.num_directions) \ .to(self.device) - snapshot_base = torch.cat((zero_concat, user_all_embed.data, - slot_all_embed.data), 1) + context_base = torch.cat((zero_concat, user_all_embed, + slot_all_embed), 1) else: - snapshot_base = torch.cat((user_all_embed.data, - slot_all_embed.data), 1) + context_base = torch.cat((user_all_embed, slot_all_embed), 1) - # ready for snapshot map (empty) - snapshot_map = torch.zeros(total_slots, self.sm_conv1_idim)\ + # ready for context map (empty) + context_map = torch.zeros(total_slots, self.sm_conv1_idim) \ .to(self.device) index = None if has_preregistered_events: - index = new_slot.data.unsqueeze(1) - index = index.expand_as(snapshot_contents) - slot_all = slot_all.data.unsqueeze(1) - slot_all = slot_all.expand_as(snapshot_base) + index = new_slot.unsqueeze(1) + index = index.expand_as(context_contents) + slot_all = slot_all.unsqueeze(1) + slot_all = slot_all.expand_as(context_base) # scatter base and then the contents - snapshot_map.data.scatter_(0, slot_all, snapshot_base) + context_map.scatter_(0, slot_all, context_base) if has_preregistered_events: - snapshot_map.data.scatter_(0, index, snapshot_contents) + context_map.scatter_(0, index, context_contents) # (sm_day_num, sm_slot_num, # user_embed_dim + slot_embed_dim + st_rnn_hdim * num_directions) - snapshot_map = snapshot_map.view(self.config.sm_day_num, - self.config.sm_slot_num, - self.sm_conv1_idim) + context_map = context_map.view(self.config.sm_day_num, + self.config.sm_slot_num, + self.sm_conv1_idim) # (user_embed_dim + slot_embed_dim + st_rnn_hdim * num_directions, # sm_day_num, # sm_slot_num) - snapshot_map = torch.transpose( - torch.transpose(snapshot_map, 0, 2), 1, 2) + context_map = torch.transpose( + torch.transpose(context_map, 0, 2), 1, 2) # multiple filter conv conv_list = [self.sm_conv1, self.sm_conv2] - snapshot_mf = torch.unsqueeze(snapshot_map, 0).to(self.device) + context_mf = torch.unsqueeze(context_map, 0).to(self.device) for layer_idx, sm_conv in enumerate(conv_list): conv_result = list() for filter_idx, conv in enumerate(sm_conv): - conv_out = conv(snapshot_mf) + conv_out = conv(context_mf) conv_result.append(conv_out) - snapshot_mf = torch.cat(conv_result, 1) + context_mf = torch.cat(conv_result, 1) if layer_idx == 0: - snapshot_mf = F.rrelu(self.sm_conv1_bn(snapshot_mf)) + context_mf = F.rrelu(self.sm_conv1_bn(context_mf)) else: # layer_idx == 1 - snapshot_mf = torch.max(self.sm_conv2_bn(snapshot_mf) - .view(1, snapshot_mf.size(1), -1), 2)[0] + context_mf = torch.max(self.sm_conv2_bn(context_mf) + .view(1, context_mf.size(1), -1), 2)[0] - return snapshot_mf, saved_slot + return context_mf, saved_slot @Profile(__name__) - def matching_layer(self, title, intention, snapshot_mf, grid): + def matching_layer(self, title, intention, context_mf, grid): # Highway network for mf concat_seq = list() - if not self.config.no_snapshot: + if not self.config.no_context: concat_seq.append(grid.to(self.device)) - concat_seq.insert(0, snapshot_mf) + concat_seq.insert(0, context_mf) if not self.config.no_intention: concat_seq.insert(0, intention) else: @@ -712,7 +711,7 @@ def forward(self, user, dur, tc, tw, tl, stc, stw, stl, sdur, sslot, gr): title_rep = self.title_layer(tc, tw, tl) user_embed = None - if not self.config.no_intention or not self.config.no_snapshot: + if not self.config.no_intention or not self.config.no_context: user_embed = self.user_embed(user.to(self.device)) # user_embed = torch.zeros(user_embed.size()).to(self.device) @@ -735,19 +734,18 @@ def forward(self, user, dur, tc, tw, tl, stc, stw, stl, sdur, sslot, gr): intention_rep = \ self.intention_layer(user_embed, dur_embed, title_rep) - if not self.config.no_snapshot: + if not self.config.no_context: stitle_rep = None - if not self.config.no_snapshot_title: - # (B, (VARIABLE snapshot length, st_rnn_hdim * num_directions)) - stitle_rep = self.snapshot_title_layer(stc, stw, stl) + if not self.config.no_context_title: + # (B, (VARIABLE context length, st_rnn_hdim * num_directions)) + stitle_rep = self.context_title_layer(stc, stw, stl) # (B, sum(config.sm_conv_fn[len(config.sm_conv_fn)//2:])) - snapshot_mf = \ - self.snapshot_layer(user_embed, stitle_rep, sdur, sslot) + context_mf = self.context_layer(user_embed, stitle_rep, sdur, sslot) # (B, config.sm_day_num * config.sm_slot_num) output = \ - self.matching_layer(title_rep, intention_rep, snapshot_mf, gr) + self.matching_layer(title_rep, intention_rep, context_mf, gr) else: output = self.matching_layer(title_rep, intention_rep, None, None) @@ -759,7 +757,7 @@ def get_regloss(self, weight_decay=None): if weight_decay is None: weight_decay = self.config.wd reg_loss = 0 - params = [self.output_fc1, self.output_fc2, self.it_nonl, self.it_gate] + params = [self.output_fc1, self.it_nonl, self.it_gate] for param in params: reg_loss += torch.norm(param.weight, 2) return reg_loss * weight_decay @@ -774,27 +772,33 @@ def decay_lr(self, lr_decay=None): print('\tlearning rate decay to %.3f' % self.config.lr) @Profile(__name__) - def get_metrics(self, outputs, targets, ex_targets=None): - outputs_max_idxes = torch.squeeze(torch.topk(outputs, 1)[1], dim=1).data - outputs_topall_idxes = torch.topk(outputs, self.n_classes)[1].data - targets = targets.data.cpu() - outputs = outputs.data.cpu() - - topk = 10 - ex1 = 0. - ex5 = 0. - ex10 = 0. - for o, et, t in zip(outputs, ex_targets, targets): - assert et[t] == 0, et[t] - output = o[:] - (et[:] * 1e16) - output_topk = torch.topk(output, topk)[1] - - if t == output_topk[0]: - ex1 += 1. - if t in output_topk[:5]: - ex5 += 1. - if t in output_topk: - ex10 += 1. + def get_metrics(self, outputs, targets, ex_targets=None, topk=5): + + def get_recalls(): + def get_r1_r5(_o, _t): + out_topk = torch.topk(_o, topk)[1] + if _t == out_topk[0]: + return 1., 1. + else: + if _t in out_topk: + return 0., 1. + return 0., 0. + + ex1 = 0. + ex5 = 0. + if ex_targets is None: + for o, t in zip(outputs, targets): + r1, r5 = get_r1_r5(o, t) + ex1 += r1 + ex5 += r5 + else: + for o, et, t in zip(outputs, ex_targets, targets): + assert et[t] == 0 + output = o[:] - (et[:] * 1e16) + r1, r5 = get_r1_r5(output, t) + ex1 += r1 + ex5 += r5 + return ex1, ex5 def ndcg_at_k(r, k): def get_dcg(_r, _k): @@ -814,46 +818,50 @@ def inverse_euclidean_distance(target, pred): ** 2) ** 0.5 return 1. / (euc + 1.) - mrr = 0. - ndcg_at_5 = 0. - ndcg_at_10 = 0. - for target_slot_idx, ota in zip(targets, outputs_topall_idxes): + def get_mrr_ndcg(calc_ndcg=False): + mrr_sum = 0. + ndcg_at_5_sum = 0. + outputs_topall_idxes = torch.topk(outputs, self.n_classes)[1] + # relevance vector for nDCG - relevance_vector = [0.] * self.n_classes + relevance_vector = [0.] * self.n_classes if calc_ndcg else None - target_rank_idx = -1 - for rank_idx, slot_idx in enumerate(ota): - if target_slot_idx.item() == slot_idx.item(): - target_rank_idx = rank_idx + for target_slot_idx, ota in zip(targets, outputs_topall_idxes): + target_rank_idx = -1 + for rank_idx, slot_idx in enumerate(ota): + if target_slot_idx.item() == slot_idx.item(): + target_rank_idx = rank_idx + if not calc_ndcg: + break - # assign ieuc - relevance_vector[rank_idx] = \ - inverse_euclidean_distance(target_slot_idx.item(), - slot_idx.item()) + if calc_ndcg: + # assign ieuc + relevance_vector[rank_idx] = \ + inverse_euclidean_distance(target_slot_idx.item(), + slot_idx.item()) - assert target_rank_idx > -1 + assert target_rank_idx > -1 - # MRR - mrr += 1. / (target_rank_idx + 1.) + # MRR + mrr_sum += 1. / (target_rank_idx + 1) - # nDCG@5 and nDCG@10 - ndcg_at_5 += ndcg_at_k(relevance_vector, 5) - ndcg_at_10 += ndcg_at_k(relevance_vector, 10) + if calc_ndcg: + # nDCG@5 + ndcg_at_5_sum += ndcg_at_k(relevance_vector, 5) + return mrr_sum, ndcg_at_5_sum - ieuc = 0. - for t, m in zip(targets, outputs_max_idxes): - ieuc += inverse_euclidean_distance(t.item(), m.item()) + def get_ieuc(): + ieuc_sum = 0. + outputs_max_idxes = torch.max(outputs, 1)[1] + for t, m in zip(targets, outputs_max_idxes): + ieuc_sum += inverse_euclidean_distance(t.item(), m.item()) + return ieuc_sum - len_outputs = len(outputs) - ex1 /= len_outputs - ex5 /= len_outputs - ex10 /= len_outputs - mrr /= len_outputs - ieuc /= len_outputs - ndcg_at_5 /= len_outputs - ndcg_at_10 /= len_outputs + recall1, recall5 = get_recalls() + mrr, ndcg_at_5 = get_mrr_ndcg() + ieuc = get_ieuc() - return ex1, ex5, ex10, mrr, ieuc, ndcg_at_5, ndcg_at_10 + return recall1, recall5, mrr, ieuc, ndcg_at_5 def save_checkpoint(self, state, filename=None): if filename is None: @@ -862,47 +870,40 @@ def save_checkpoint(self, state, filename=None): else: filename = os.path.join(self.config.checkpoint_dir, filename + '.pth') - print('\t=> save checkpoint %s' % filename) + print('\t-> save checkpoint %s' % filename) if not os.path.exists(self.config.checkpoint_dir): os.mkdir(self.config.checkpoint_dir) torch.save(state, filename) - def load_checkpoint(self, filename=None, map_location=None): + def load_checkpoint(self, filename=None): if filename is None: - filename = self.config.checkpoint_dir + self.config.model_name + filename = os.path.join(self.config.checkpoint_dir, + self.config.model_name + '.pth') else: - filename = self.config.checkpoint_dir + filename - print('\t=> load checkpoint %s' % filename) - checkpoint = torch.load(filename, map_location=map_location) + filename = os.path.join(self.config.checkpoint_dir, + filename + '.pth') + print('\t-> load checkpoint %s' % filename) + checkpoint = torch.load(filename, + map_location=None if 'cuda' == self.device.type + else 'cpu') self.load_state_dict(checkpoint['state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer']) # self.config = checkpoint['config'] @Profile(__name__) - def write_summary(self, mode, metrics, offset): - if mode == 'tr': - writer = self.train_writer - elif mode == 'va': - writer = self.valid_writer - elif mode == 'te': + def write_summary(self, mode, loss, metrics, offset, add_histogram=False): + if mode != 'tr': return - else: - raise ValueError('Invalid mode %s' % mode) - writer.add_scalar('loss', metrics[0], offset) - # writer.add_scalar('recall@1', metrics[1], offset) - # writer.add_scalar('recall@5', metrics[2], offset) - # writer.add_scalar('recall@10', metrics[3], offset) - writer.add_scalar('mrr', metrics[4], offset) - # writer.add_scalar('ieuc', metrics[5], offset) - # writer.add_scalar('ndcg@5', metrics[6], offset) - # writer.add_scalar('ndcg@10', metrics[7], offset) + self.summary_writer.add_scalar('loss', loss, offset) + self.summary_writer.add_scalar('mrr', metrics[2], offset) - for name, param in self.named_parameters(): - if not param.requires_grad: - continue - writer.add_histogram(name, param.clone().cpu().data.numpy(), offset) + if add_histogram: + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + self.summary_writer.add_histogram( + name, param.clone().cpu().data.numpy(), offset) def close_summary_writer(self): - self.train_writer.close() - self.valid_writer.close() + self.summary_writer.close() diff --git a/test.py b/test.py index f01307d..5cbbb15 100644 --- a/test.py +++ b/test.py @@ -29,8 +29,7 @@ def get_model(widx2vec, model_path, dvc, arg): model = NETS(ckpt_config, widx2vec).to(dvc) model.config.checkpoint_dir = model_dir + '/' - model.load_checkpoint(filename=model_filename, - map_location=None if 'cuda' == dvc.type else 'cpu') + model.load_checkpoint(filename=model_filename[:-4]) # .pth # import pprint # pprint.PrettyPrinter().pprint(_model.config.__dict__) return model @@ -52,12 +51,12 @@ def measure_performance(test_set, model, dvc, batch_size=1): for d_idx, ex in enumerate(test_loader): labels = ex[-1].to(dvc) outputs, reps = model(*ex[:-1]) - metrics = model.get_metrics(outputs, labels, ex[-2]) + metrics = model.get_metrics(outputs, labels, ex[-2].to(dvc)) performance_dict['recall1'] += metrics[0] performance_dict['recall5'] += metrics[1] - performance_dict['mrr'] += metrics[3] - performance_dict['ieuc'] += metrics[4] + performance_dict['mrr'] += metrics[2] + performance_dict['ieuc'] += metrics[3] performance_dict['count'] += outputs.data.size()[0] performance_dict['steps'] += 1. @@ -93,16 +92,15 @@ def set_seed_all(seed): arg_parser.add_argument("--serialized_data_path", type=str, default='./data/preprocess_test.pkl') arg_parser.add_argument("--model_path", type=str, - default='./data/nets_gradclip_180501_5_1.pth') + default='./data/nets_180512_0.pth') arg_parser.add_argument("--trained_dict_path", type=str, - default='./data/preprocess_20180429_dict.pkl') + default='./data/preprocess_180504_dict.pkl') arg_parser.add_argument("--seed", type=int, default=3) arg_parser.add_argument('--yes_cuda', type=int, default=1) args = arg_parser.parse_args() use_cuda = args.yes_cuda > 0 and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - print('CUDA device_count {0}'.format(torch.cuda.device_count()) if use_cuda else 'CPU') diff --git a/utils.py b/utils.py index bc70f7e..309ba5c 100644 --- a/utils.py +++ b/utils.py @@ -27,16 +27,6 @@ def with_profiling(*args, **kwargs): return with_profiling -def print_prof_data(): - for fname, data in sorted(PROF_DATA.items()): - max_time = max(data[1]) - avg_time = sum(data[1]) / len(data[1]) - total_time = sum(data[1]) - print("\n{} => called {} times.".format(fname, data[0])) - print("Time total: {:.3f}, max: {:.3f}, avg: {:.3f}".format( - total_time, max_time, avg_time)) - - def clear_prof_data(): global PROF_DATA PROF_DATA = {}