Commit a1dce46: Rename snapshot -> context
donghyeonk committed May 12, 2018
1 parent d529ba9
Showing 5 changed files with 234 additions and 245 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -15,12 +15,12 @@ $ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/sample_data.csv

# Download word, character dictionaries
```
-$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/preprocess_20180429_dict.pkl
+$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/preprocess_180504_dict.pkl
```
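
After downloading, a quick sanity check that the dictionary file is loadable; this is a minimal sketch, and the pickle's internal layout (e.g. word and character index maps) is an assumption, not documented here:

```python
# Minimal check for the downloaded dictionary pickle.
# Its exact contents (word/char dictionaries) are an assumption.
import pickle

with open('data/preprocess_180504_dict.pkl', 'rb') as f:
    dictionaries = pickle.load(f)

print(type(dictionaries))  # inspect the top-level structure first
```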

# Download pretrained NETS model
```
-$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/nets_gradclip_180501_5_1.pth
+$ wget -P data https://s3-us-west-1.amazonaws.com/ml-study0/nets/nets_180512_0.pth
```
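
A similar hedged sketch for verifying the checkpoint loads; whether the .pth file holds a bare state_dict or a fuller training checkpoint is an assumption:

```python
# Load the pretrained NETS checkpoint on CPU and inspect its type.
# Its layout (state_dict vs. full training checkpoint) is an assumption.
import torch

checkpoint = torch.load('data/nets_180512_0.pth', map_location='cpu')
print(type(checkpoint))
```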

# Run NETS w/ sample data
dataset.py: 78 changes (39 additions, 39 deletions)
@@ -72,7 +72,7 @@ def initial_settings(self):
self.max_rs_dist = 2 # reg-st week distance
self.class_div = 2 # 168 output
self.slot_size = 336
-self.max_snapshot = float("inf") # 35
+self.max_context = float("inf") # 35
self.min_word_cnt = 0
self.max_title_len = 50
self.max_word_len = 50
@@ -100,7 +100,7 @@ def initialize_dictionary(self):
self.initial_word_dict = {}
self.invalid_weeks = []
self.user_event_cnt = {}

def update_dictionary(self, key, mode=None):
# update dictionary given a key
if mode == 'c':
@@ -182,7 +182,8 @@ def check_maxlen(text, w_key):
for single_what in prev_what_list:
what_split = nltk.word_tokenize(single_what)
if self.config.word2vec_type == 6:
-what_split = [w.lower() for w in what_split]
+what_split = [word.lower() for word
+in what_split]
for word in what_split:
if word not in self.initial_word_dict:
self.initial_word_dict[word] = (
@@ -222,8 +223,8 @@ def get_pretrained_word(self, path):

widx2vec = []
unk_cnt = 0
-widx2vec.append([0.0] * self.config.word_embed_dim) # UNK
-widx2vec.append([0.0] * self.config.word_embed_dim) # PAD
+widx2vec.append([0.] * self.config.word_embed_dim) # PAD
+widx2vec.append([1.] * self.config.word_embed_dim) # UNK

for word, (word_idx, word_cnt) in self.initial_word_dict.items():
if word != 'UNK' and word != 'PAD':
@@ -236,14 +237,14 @@ def get_pretrained_word(self, path):

self.widx2vec = widx2vec

-print('pretrained vectors', np.asarray(widx2vec).shape, 'unk', unk_cnt)
+print('pretrained vectors', np.asarray(widx2vec).shape, '#unk', unk_cnt)
print('dictionary change', len(self.initial_word_dict),
'to', len(self.word2idx), len(self.idx2word), end='\n\n')

def process_data(self, path, update_dict=False):
print('### processing %s' % path)
total_data = []
-max_wordlen = max_sentlen = max_dur = max_snapshot = 0
+max_wordlen = max_sentlen = max_dur = max_context = 0
min_dur = float("inf")
max_slot_idx = (self.slot_size // self.class_div) - 1

@@ -265,7 +266,7 @@ def process_data(self, path, update_dict=False):
"""
prev_user = ''
prev_st_yw = ('', '')
-saved_snapshot = []
+saved_context = []
calendar_data = csv.reader(f, quotechar='"')

for k, features in enumerate(calendar_data):
@@ -310,7 +311,7 @@ def process_data(self, path, update_dict=False):
# process title feature
what_split = nltk.word_tokenize(what)
if self.config.word2vec_type == 6:
-what_split = [w.lower() for w in what_split]
+what_split = [word.lower() for word in what_split]
for word in what_split:
max_wordlen = \
len(word) if len(word) > max_wordlen else max_wordlen
@@ -357,26 +358,26 @@ def process_data(self, path, update_dict=False):
input_slot = st_slot // self.class_div
target_slot = st_slot // self.class_div

-# process snapshot
+# process context
if reg_seq == 0: # start of a new week
assert curr_user != prev_user or curr_st_yw != prev_st_yw
prev_user = curr_user
prev_st_yw = curr_st_yw
# prev_grid = []
-input_snapshot = []
-saved_snapshot = [[input_title, fine_duration, input_slot]]
+input_context = []
+saved_context = [[input_title, fine_duration, input_slot]]
else: # same as the prev week
assert curr_user == prev_user and curr_st_yw == prev_st_yw
-# input_snapshot = copy.deepcopy(saved_snapshot)
-prev_grid = [svs[2] for svs in saved_snapshot]
+# input_context = copy.deepcopy(saved_context)
+prev_grid = [svs[2] for svs in saved_context]
if input_slot in prev_grid:
continue
-input_snapshot = saved_snapshot[:]
-saved_snapshot.append(
+input_context = saved_context[:]
+saved_context.append(
[input_title, fine_duration, input_slot])

-# transform snapshot features into slot grid
-# snapshot slots w/ durations
+# transform context features into slot grid
+# context slots w/ durations

target_n_slot = \
int(math.ceil(input_duration / (30 * self.class_div)))
@@ -387,7 +388,7 @@ def process_data(self, path, update_dict=False):
targets_w_duration.append(target_slot + shift)

input_grid = list()
-for ips in input_snapshot:
+for ips in input_context:
n_slots = int(math.ceil(ips[1] / (30 * self.class_div)))
for slot_idx in range(n_slots):
slot = ips[2] + slot_idx
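
For readers following the loop above, a self-contained restatement of the slot-expansion arithmetic; the standalone function form and its name are illustrative, not the repository's code:

```python
# Each event expands into ceil(duration / (30 * class_div)) consecutive
# grid slots; with class_div=2 a slot spans one hour of the 168-slot week.
import math

def expand_to_slots(start_slot, duration_min, class_div=2):
    n_slots = int(math.ceil(duration_min / (30 * class_div)))
    return [start_slot + i for i in range(n_slots)]

print(expand_to_slots(10, 90))  # a 90-minute event covers slots [10, 11]
```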
@@ -402,16 +403,16 @@ def process_data(self, path, update_dict=False):

assert target_slot not in input_grid

-# filter by register distance & max_snapshot & recurrent
+# filter by register distance & max_context & recurrent
if (reg_st_week_dist <= self.max_rs_dist
-and len(input_snapshot) <= self.max_snapshot
+and len(input_context) <= self.max_context
and 'False' == is_recurrent):
-max_snapshot = max_snapshot \
-if max_snapshot > len(input_snapshot) \
-else len(input_snapshot)
+max_context = max_context \
+if max_context > len(input_context) \
+else len(input_context)
total_data.append(
[input_user, input_title, input_duration,
-input_snapshot, input_grid, target_slot])
+input_context, input_grid, target_slot])

if user_id not in self.user_event_cnt:
self.user_event_cnt[user_id] = 1
@@ -431,7 +432,7 @@ def process_data(self, path, update_dict=False):
print('data size', len(total_data))
print('max duration', max_dur)
print('min duration', min_dur)
-print('max snapshot', max_snapshot)
+print('max context', max_context)
print('max wordlen', max_wordlen)
print('max sentlen', max_sentlen, end='\n\n')

@@ -452,7 +453,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=train_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
-pin_memory=True,
+pin_memory=True
)
else:
train_loader = None
@@ -468,7 +469,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=valid_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
-pin_memory=True,
+pin_memory=True
)
else:
valid_loader = None
@@ -483,7 +484,7 @@ def get_dataloader(self, batch_size=None, shuffle=True):
sampler=test_sampler,
num_workers=self.config.data_workers,
collate_fn=self.batchify,
-pin_memory=True,
+pin_memory=True
)

return train_loader, valid_loader, test_loader
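
The three loaders share one construction pattern: a sampler plus collate_fn and pin_memory=True (the diff only drops a trailing comma, which changes nothing at runtime). A minimal self-contained sketch of that pattern, with generic tensor data standing in for the calendar examples:

```python
# Generic sketch of the sampler + pin_memory DataLoader pattern above;
# TensorDataset stands in for the repository's calendar dataset.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

data = TensorDataset(torch.arange(10, dtype=torch.float).unsqueeze(1))
loader = DataLoader(
    data,
    batch_size=4,
    sampler=RandomSampler(data),
    pin_memory=True
)
for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 1]) twice, then torch.Size([2, 1])
```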
@@ -549,23 +550,22 @@ def __getitem__(self, index):
tw = title[1]
tl = title[2]

-# Snapshot (title, duration, slot)
-snapshot = example[3]
+# context (title, duration, slot)
+context = example[3]
stc = []
stw = []
stl = []
sdur = []
sslot = []
-for _, event in enumerate(snapshot):
+for _, event in enumerate(context):
stc.append(event[0][0])
stw.append(event[0][1])
stl.append(event[0][2])
sdur.append(event[1])
sslot.append(event[2])

# Grid
-grid = torch.zeros(
-self.config.sm_day_num * self.config.sm_slot_num)
+grid = torch.zeros(self.config.sm_day_num * self.config.sm_slot_num)
if len(example[4]) > 0:
grid[example[4]] = 1

Expand All @@ -576,12 +576,12 @@ def __getitem__(self, index):
return user, dur, tc, tw, tl, stc, stw, stl, sdur, sslot, grid, target
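
A small self-contained illustration of the multi-hot grid built above; the occupied slot indices are made up for the example, and 7 x 24 matches the Config values shown later in the diff:

```python
# Illustrative multi-hot grid over 7 days x 24 slots, as in __getitem__;
# the occupied slot indices here are hypothetical.
import torch

sm_day_num, sm_slot_num = 7, 24
occupied = [3, 27, 100]

grid = torch.zeros(sm_day_num * sm_slot_num)
grid[occupied] = 1
print(int(grid.sum()))  # 3
```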

def lengths(self):
-def maxlen_from_snapshot(snapshots):
-if len(snapshots) > 0:
-return max([s[0][2] for s in snapshots])
+def maxlen_from_context(contexts):
+if len(contexts) > 0:
+return max([s[0][2] for s in contexts])
else:
return 0
-return [(example[1][2], maxlen_from_snapshot(example[3]))
+return [(example[1][2], maxlen_from_context(example[3]))
for example in self.examples]


@@ -635,7 +635,7 @@ def __init__(self):
self.sm_day_num = 7
self.sm_slot_num = 24
self.preprocess_save_path = './data/preprocess_tmp.pkl'
-self.preprocess_load_path = './data/preprocess_20180429.pkl'
+self.preprocess_load_path = './data/preprocess_.pkl'


if __name__ == '__main__':
