Skip to content

Commit

Permalink
reproducibility settings
Browse files Browse the repository at this point in the history
  • Loading branch information
donghyeonk committed Jul 1, 2023
1 parent 4e28f74 commit fadb8ad
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import numpy as np
import os
import pickle
import pprint
import random
import torch


def get_dataset(cfg, trained_dict_path):
    """Build the test dataset and cache it on disk as a pickle.

    Args:
        cfg: Dataset configuration. ``cfg.preprocess_save_path`` is the
            file the constructed dataset is pickled to.
        trained_dict_path: Path of the pickled dictionary that is passed
            to ``dataset.NETSDataset`` together with ``cfg``.

    Returns:
        The constructed ``dataset.NETSDataset``, or ``None`` when its
        ``test_data`` is empty (nothing is written to disk in that case).
    """
    print('Creating the test dataset pickle ...')

    # NOTE(security): pickle.load can execute arbitrary code embedded in
    # the file; only point this at dictionaries from a trusted source.
    # The context manager guarantees the handle is closed even if
    # unpickling raises.
    with open(trained_dict_path, 'rb') as f:
        nets_dictionary = pickle.load(f)

    test_set = dataset.NETSDataset(cfg, nets_dictionary)
    if len(test_set.test_data) == 0:
        print('no events')
        return None

    # Persist the preprocessed dataset; again closed deterministically so
    # the pickle is fully flushed before the function returns.
    with open(cfg.preprocess_save_path, 'wb') as f:
        pickle.dump(test_set, f)
    return test_set


Expand All @@ -33,23 +36,29 @@ def get_model(widx2vec, model_path, dvc, idx2dur, arg):
else None).to(dvc)
model.config.checkpoint_dir = model_dir + '/'
model.load_checkpoint(filename=model_filename[:-4]) # .pth
# import pprint
# pprint.PrettyPrinter().pprint(_model.config.__dict__)
pprint.PrettyPrinter().pprint(model.config.__dict__)
return model, ckpt_config


def measure_performance(test_set, model, conf, dvc, batch_size=1):
def measure_performance(test_set, model, conf, dvc, batch_size=1, write_log=False, debug=False):
model = model.eval()

performance_dict = dict()
performance_dict['recall1'] = 0.
performance_dict['recall5'] = 0.
performance_dict['mrr'] = 0.
performance_dict['ieuc'] = 0.

# debugging
sample_idx = -1
log_f = None
if debug:
sample_idx = 0
log_f = open(f'res_bs{batch_size}.log', 'w') if write_log else None

_, _, test_loader = test_set.get_dataloader(batch_size=batch_size)
with torch.inference_mode():
for d_idx, ex in enumerate(test_loader):
for batchidx, ex in enumerate(test_loader):
labels = ex[-1].to(dvc)
outputs = model(*ex[:-1])
metrics = get_metrics(outputs, labels, model.n_day_slots,
Expand All @@ -62,8 +71,20 @@ def measure_performance(test_set, model, conf, dvc, batch_size=1):
performance_dict['mrr'] += metrics[2]
performance_dict['ieuc'] += metrics[3]

if d_idx % 1000 == 0 and d_idx > 0:
print(d_idx)
if batchidx % 1000 == 0 and batchidx > 0:
print(batchidx)

if debug:
inout = ex + (outputs, [float("nan")] * labels.size()[0])
for io in zip(*inout):
# if 40 == sample_idx:
# print(sample_idx, io[:-1])
# print(sample_idx, io[:-1])
log_f.write(f'{sample_idx} {io[:-2]}\n')
sample_idx += 1

if write_log and log_f is not None:
log_f.close()

n_samples = len(test_loader.dataset)
recall1 = performance_dict['recall1'] / n_samples
Expand All @@ -79,6 +100,7 @@ def measure_performance(test_set, model, conf, dvc, batch_size=1):


def set_seed_all(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
Expand All @@ -104,21 +126,24 @@ def set_seed_all(seed):
if use_cuda else 'CPU')
print(torch.__version__)

set_seed_all(args.seed)
torch.use_deterministic_algorithms(True)
set_seed_all(args.seed) # reproducibility
torch.use_deterministic_algorithms(True) # reproducibility
torch.backends.cudnn.deterministic = True # reproducibility
torch.backends.cudnn.benchmark = False # reproducibility

config = dataset.Config()
config.test_path = args.input_path
config.preprocess_save_path = args.serialized_data_path
config.preprocess_load_path = args.serialized_data_path

print('Loading test dataset..')
print('Loading test dataset ...')
test_dataset = get_dataset(config, args.trained_dict_path)
assert test_dataset is not None
test_dataset.config.data_workers = 0 # reproducibility

print('Loading NESA model..')
print('Loading NESA model ...')
nesa_model, nesa_conf = get_model(test_dataset.widx2vec, args.model_path,
device, test_dataset.idx2dur, args)

print('\nMeasuring NESA performance on test data..')
measure_performance(test_dataset, nesa_model, nesa_conf, device)
print('\nMeasuring NESA performance on test data ...')
measure_performance(test_dataset, nesa_model, nesa_conf, device, batch_size=1)

0 comments on commit fadb8ad

Please sign in to comment.