|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | + |
| 4 | +from torchvision import transforms |
| 5 | +from prefetch_generator import BackgroundGenerator |
| 6 | +import util |
| 7 | + |
| 8 | +import torchvision.datasets as torch_data |
| 9 | +from .time_series import uea as uea_data |
| 10 | +from .tabular import maf as maf_data |
| 11 | + |
| 12 | + |
| 13 | +def _gen_mini_dataset(dataset, dataset_ratio): |
| 14 | + n_dataset = dataset.shape[0] |
| 15 | + n_mini_dataset = int(dataset_ratio*n_dataset) |
| 16 | + s = torch.from_numpy(np.random.choice( |
| 17 | + np.arange(n_dataset, dtype=np.int64), n_mini_dataset, replace=False) |
| 18 | + ) |
| 19 | + return dataset[s] |
| 20 | + |
| 21 | + |
class DataLoaderX(torch.utils.data.DataLoader):
    """``DataLoader`` variant that prefetches batches in a background thread.

    Identical to ``torch.utils.data.DataLoader`` except that iteration is
    wrapped in a ``BackgroundGenerator``, which pre-loads upcoming batches
    while the consumer is still processing the current one.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
| 25 | + |
| 26 | + |
class TabularLoader:
    """Minimal in-memory batch iterator for tabular data.

    The full dataset tensor is moved to ``opt.device`` once up front; each
    iteration step yields ``((x, logp_diff_t1), p_z0)`` where

    - ``x`` is a ``(batch, dim)`` slice of the data,
    - ``logp_diff_t1`` is a zero tensor of shape ``(batch, 1)`` (the initial
      log-density accumulator for CNF-style training),
    - ``p_z0`` is a standard multivariate normal prior over ``dim`` features.

    Note: ``len(loader)`` and ``self.data_size`` are the number of *batches*
    per epoch, not the number of rows; the last batch may be smaller than
    ``batch_size`` since ``torch.Tensor.split`` keeps the remainder.
    """

    def __init__(self, opt, data, batch_size=None, shuffle=True):
        self.opt = opt
        self.device = opt.device

        self.data = data.to(opt.device)
        # Fall back to the experiment-wide batch size unless overridden.
        self.batch_size = opt.batch_size if batch_size is None else batch_size
        self.shuffle = shuffle

        self.input_dim = data.shape[-1]
        self.output_dim = [data.shape[-1]]

        # Standard-normal prior over the feature dimension.
        dim = data.shape[-1]
        loc = torch.zeros(dim).to(opt.device)
        covariance_matrix = torch.eye(dim).to(opt.device)  # TODO(Guan) scale down the cov ?
        self.p_z0 = torch.distributions.MultivariateNormal(
            loc=loc, covariance_matrix=covariance_matrix
        )

        self._reset_idxs()
        # Number of batches per epoch. (A redundant row-count assignment that
        # was immediately overwritten here has been removed.)
        self.data_size = len(self.idxs_by_batch_size)

    def _reset_idxs(self):
        """(Re)shuffle row indices and re-chunk them into batch-sized groups."""
        idxs = torch.randperm(self.data.shape[0]) if self.shuffle else torch.arange(self.data.shape[0])
        self.idxs_by_batch_size = idxs.split(self.batch_size)
        self.batch_idx = 0

    def __len__(self):
        # Number of batches per epoch (see class docstring).
        return self.data_size

    def __iter__(self):
        return self

    def __next__(self):
        # End of epoch: reshuffle for the next pass, then signal exhaustion.
        if self.batch_idx >= len(self.idxs_by_batch_size):
            self._reset_idxs()
            raise StopIteration

        s = self.idxs_by_batch_size[self.batch_idx]
        self.batch_idx += 1
        x = self.data[s]
        logp_diff_t1 = torch.zeros(x.shape[0], 1, device=x.device)
        return (x, logp_diff_t1), self.p_z0
| 68 | + |
| 69 | + |
def get_uea_loader(opt):
    """Build train/test dataloaders for a UEA time-series classification task.

    ``opt.problem`` selects the dataset (``'CharT'``, ``'ArtWR'`` or
    ``'SpoAD'``). Side effects: stores the observation ``times``, the number
    of classes (``output_dim``) and the channel count (``input_dim``) on
    ``opt``. Returns ``(train_dataloader, test_dataloader)``; the validation
    dataloader from ``uea_data.get_data`` is not returned.

    Raises:
        ValueError: if ``opt.problem`` is not a known UEA problem key.
    """
    print(util.magenta("loading uea data..."))

    name_by_problem = {
        'CharT': 'CharacterTrajectories',
        'ArtWR': 'ArticularyWordRecognition',
        'SpoAD': 'SpokenArabicDigits',
    }
    dataset_name = name_by_problem.get(opt.problem)
    if dataset_name is None:
        # Fail loudly here instead of passing ``None`` into uea_data.get_data.
        raise ValueError(f"unknown uea problem: {opt.problem!r}")

    missing_rate = 0.0
    device = opt.device
    intensity_data = True

    (times, train_dataloader, val_dataloader,
     test_dataloader, num_classes, input_channels) = uea_data.get_data(
        dataset_name, missing_rate, device,
        intensity=intensity_data,
        batch_size=opt.batch_size)

    # We'll return the dataloaders and store the rest in opt.
    opt.times = times
    opt.output_dim = num_classes
    opt.input_dim = input_channels
    return train_dataloader, test_dataloader
| 94 | + |
| 95 | + |
def get_tabular_loader(opt, test_batch_size=1000):
    """Build train/test ``TabularLoader``s for a MAF-style tabular dataset.

    ``opt.problem`` must be ``'gas'`` or ``'miniboone'``. When
    ``opt.dataset_ratio < 1.0`` every split is randomly subsampled by that
    ratio. Side effects: stores ``input_dim`` / ``output_dim`` on ``opt``.

    Raises:
        ValueError: if ``opt.problem`` is not a supported tabular dataset.
            (Previously an ``assert``, which is stripped under ``python -O``.)
    """
    if opt.problem not in ('gas', 'miniboone'):
        raise ValueError(f"unsupported tabular problem: {opt.problem!r}")
    print(util.magenta("loading tabular data..."))

    data = maf_data.get_data(opt.problem)
    data.trn.x = torch.from_numpy(data.trn.x)
    data.val.x = torch.from_numpy(data.val.x)
    data.tst.x = torch.from_numpy(data.tst.x)

    if opt.dataset_ratio < 1.0:
        # Subsample every split by the same ratio for quick experiments.
        data.trn.x = _gen_mini_dataset(data.trn.x, opt.dataset_ratio)
        data.val.x = _gen_mini_dataset(data.val.x, opt.dataset_ratio)
        data.tst.x = _gen_mini_dataset(data.tst.x, opt.dataset_ratio)

    train_loader = TabularLoader(opt, data.trn.x, shuffle=True)
    # NOTE(review): only train/test loaders are returned; the original also
    # built a validation TabularLoader here but never used it, so the dead
    # local has been removed. The val split itself is still prepared above.
    test_loader = TabularLoader(opt, data.tst.x, batch_size=test_batch_size, shuffle=False)

    opt.input_dim = train_loader.input_dim
    opt.output_dim = train_loader.output_dim

    return train_loader, test_loader
| 118 | + |
| 119 | + |
def get_img_loader(opt, test_batch_size=1000):
    """Build prefetching train/test loaders for an image dataset.

    ``opt.problem`` selects among ``'mnist'``, ``'SVHN'`` and ``'cifar10'``.
    Side effects: stores ``input_dim`` (CHW shape) and ``output_dim`` (number
    of classes) on ``opt``; downloads the dataset under ``data/img/`` on
    first use.
    """
    print(util.magenta("loading image data..."))

    dataset_builder, root, input_dim, output_dim = {
        'mnist':   [torch_data.MNIST,   'data/img/mnist',   [1, 28, 28], 10],
        'SVHN':    [torch_data.SVHN,    'data/img/svhn',    [3, 32, 32], 10],
        'cifar10': [torch_data.CIFAR10, 'data/img/cifar10', [3, 32, 32], 10],
    }.get(opt.problem)
    opt.input_dim = input_dim
    opt.output_dim = output_dim

    # NOTE(review): (0.1307, 0.3081) are the MNIST mean/std; they are applied
    # to SVHN/CIFAR10 as well — confirm this is intentional.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    feed_dict = dict(download=True, root=root, transform=transform)
    if opt.problem == 'SVHN':
        # Bug fix: torchvision's SVHN selects its subset via ``split=``, not
        # ``train=``. The original omitted it, so BOTH loaders received the
        # default 'train' split and the test loader served training data.
        train_dataset = dataset_builder(split='train', **feed_dict)
        test_dataset = dataset_builder(split='test', **feed_dict)
    else:
        train_dataset = dataset_builder(train=True, **feed_dict)
        test_dataset = dataset_builder(train=False, **feed_dict)

    feed_dict = dict(num_workers=2, drop_last=True)
    train_loader = DataLoaderX(train_dataset, batch_size=opt.batch_size, shuffle=True, **feed_dict)
    test_loader = DataLoaderX(test_dataset, batch_size=test_batch_size, shuffle=False, **feed_dict)

    return train_loader, test_loader
0 commit comments