From fadb8ad07c52f655bec87c1c9a1e3847e3e87806 Mon Sep 17 00:00:00 2001 From: Donghyeon Kim <12129692+donghyeonk@users.noreply.github.com> Date: Sat, 1 Jul 2023 13:08:03 +0900 Subject: [PATCH] reproducibility settings --- test.py | 55 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/test.py b/test.py index 8a787a5..ce1424d 100644 --- a/test.py +++ b/test.py @@ -4,18 +4,21 @@ import numpy as np import os import pickle +import pprint import random import torch def get_dataset(cfg, trained_dict_path): - print('Creating the test dataset pickle..', ) - nets_dictionary = pickle.load(open(trained_dict_path, 'rb')) + print('Creating the test dataset pickle ...', ) + with open(trained_dict_path, 'rb') as f: + nets_dictionary = pickle.load(f) test_set = dataset.NETSDataset(cfg, nets_dictionary) if len(test_set.test_data) == 0: print('no events') return None - pickle.dump(test_set, open(cfg.preprocess_save_path, 'wb')) + with open(cfg.preprocess_save_path, 'wb') as f: + pickle.dump(test_set, f) return test_set @@ -33,12 +36,11 @@ def get_model(widx2vec, model_path, dvc, idx2dur, arg): else None).to(dvc) model.config.checkpoint_dir = model_dir + '/' model.load_checkpoint(filename=model_filename[:-4]) # .pth - # import pprint - # pprint.PrettyPrinter().pprint(_model.config.__dict__) + pprint.PrettyPrinter().pprint(model.config.__dict__) return model, ckpt_config -def measure_performance(test_set, model, conf, dvc, batch_size=1): +def measure_performance(test_set, model, conf, dvc, batch_size=1, write_log=False, debug=False): model = model.eval() performance_dict = dict() @@ -46,10 +48,17 @@ def measure_performance(test_set, model, conf, dvc, batch_size=1): performance_dict['recall5'] = 0. performance_dict['mrr'] = 0. performance_dict['ieuc'] = 0. + + # debugging + sample_idx = -1 + log_f = None + if debug: + sample_idx = 0 + log_f = open(f'res_bs{batch_size}.log', 'w') if write_log else None _, _, test_loader = test_set.get_dataloader(batch_size=batch_size) with torch.inference_mode(): - for d_idx, ex in enumerate(test_loader): + for batchidx, ex in enumerate(test_loader): labels = ex[-1].to(dvc) outputs = model(*ex[:-1]) metrics = get_metrics(outputs, labels, model.n_day_slots, @@ -62,8 +71,20 @@ def measure_performance(test_set, model, conf, dvc, batch_size=1): performance_dict['mrr'] += metrics[2] performance_dict['ieuc'] += metrics[3] - if d_idx % 1000 == 0 and d_idx > 0: - print(d_idx) + if batchidx % 1000 == 0 and batchidx > 0: + print(batchidx) + + if debug: + inout = ex + (outputs, [float("nan")] * labels.size()[0]) + for io in zip(*inout): + # if 40 == sample_idx: + # print(sample_idx, io[:-1]) + # print(sample_idx, io[:-1]) + log_f.write(f'{sample_idx} {io[:-2]}\n') + sample_idx += 1 + + if write_log and log_f is not None: + log_f.close() n_samples = len(test_loader.dataset) recall1 = performance_dict['recall1'] / n_samples @@ -79,6 +100,7 @@ def measure_performance(test_set, model, conf, dvc, batch_size=1): def set_seed_all(seed): + os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) @@ -104,21 +126,24 @@ def set_seed_all(seed): if use_cuda else 'CPU') print(torch.__version__) - set_seed_all(args.seed) - torch.use_deterministic_algorithms(True) + set_seed_all(args.seed) # reproducibility + torch.use_deterministic_algorithms(True) # reproducibility + torch.backends.cudnn.deterministic = True # reproducibility + torch.backends.cudnn.benchmark = False # reproducibility config = dataset.Config() config.test_path = args.input_path config.preprocess_save_path = args.serialized_data_path config.preprocess_load_path = args.serialized_data_path - print('Loading test dataset..') + print('Loading test dataset ...') test_dataset = get_dataset(config, args.trained_dict_path) assert test_dataset is not None + test_dataset.config.data_workers = 0 # reproducibility - print('Loading NESA model..') + print('Loading NESA model ...') nesa_model, nesa_conf = get_model(test_dataset.widx2vec, args.model_path, device, test_dataset.idx2dur, args) - print('\nMeasuring NESA performance on test data..') - measure_performance(test_dataset, nesa_model, nesa_conf, device) + print('\nMeasuring NESA performance on test data ...') + measure_performance(test_dataset, nesa_model, nesa_conf, device, batch_size=1)