-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_dataset.py
75 lines (69 loc) · 3.7 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
from torchvision import datasets, transforms
import torch
# Every dataset handled by load_dataset() has ten classes.
_NUM_CLASSES = 10
_SUPPORTED_DATASETS = ('MNIST', 'FashionMNIST', 'CIFAR10', 'SVHN')


def _build_train_test_sets(dataset, root, download):
    """Return the ``(train_set, test_set)`` torchvision datasets for ``dataset``.

    ``root`` is the directory the torchvision constructors read from (and
    download into when ``download`` is True).
    """
    if dataset in ('MNIST', 'FashionMNIST'):
        # The two datasets share an identical constructor signature and need no
        # augmentation or normalisation, only conversion to tensors.
        dataset_cls = datasets.MNIST if dataset == 'MNIST' else datasets.FashionMNIST
        train_set = dataset_cls(root, train=True, download=download,
                                transform=transforms.ToTensor())
        test_set = dataset_cls(root, train=False, download=download,
                               transform=transforms.ToTensor())
        return train_set, test_set

    if dataset == 'CIFAR10':
        # NOTE(review): these per-channel statistics differ from the commonly
        # quoted CIFAR-10 values (mean ~(0.491, 0.482, 0.447)); kept unchanged
        # because trained models depend on them -- confirm they are intentional.
        normalize = transforms.Normalize((0.424, 0.415, 0.384), (0.283, 0.278, 0.284))
        # Random crop + horizontal flip augmentation is applied to training data only.
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        train_set = datasets.CIFAR10(root, train=True, download=download,
                                     transform=train_transform)
        test_set = datasets.CIFAR10(root, train=False, download=download,
                                    transform=test_transform)
        return train_set, test_set

    if dataset == 'SVHN':
        # NOTE(review): normalisation statistics kept exactly as originally
        # written; confirm they match the intended SVHN mean/std.
        svhn_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.431, 0.430, 0.446), (0.197, 0.198, 0.199)),
        ])
        # The training set stays wrapped in a ConcatDataset (as before) so that
        # further splits, e.g. split='extra', can be appended; only 'train' is
        # used at present.
        train_set = torch.utils.data.ConcatDataset([
            datasets.SVHN(root, split='train', download=download, transform=svhn_transform),
        ])
        test_set = datasets.SVHN(root, split='test', download=download,
                                 transform=svhn_transform)
        return train_set, test_set

    raise ValueError(f'Unsupported dataset: {dataset!r}')


def load_dataset(dataset='MNIST', batch_size=100, dataset_path='../../data', is_cuda=False,
                 download=False):
    """Build train/test DataLoaders for one of the supported torchvision datasets.

    Args:
        dataset: One of 'MNIST', 'FashionMNIST', 'CIFAR10' or 'SVHN' (case-sensitive).
        batch_size: Batch size used for both the train and the test loader.
        dataset_path: Root directory; data is read from
            ``os.path.join(dataset_path, dataset)``.
        is_cuda: If True, both loaders use ``num_workers=0`` and ``pin_memory=True``;
            otherwise DataLoader defaults are used.
        download: Forwarded to the torchvision constructors of both splits.
            Defaults to False (the previous behaviour), so the data must already
            exist under ``dataset_path``.

    Returns:
        ``(train_loader, test_loader, num_classes)``. The train loader shuffles,
        the test loader does not, and ``num_classes`` is 10 for every supported dataset.

    Raises:
        ValueError: if ``dataset`` is not a supported name. ValueError subclasses
            Exception, so callers that caught the previous bare Exception still work.
    """
    # Validate before touching torchvision so a bad name fails fast and clearly.
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError(
            f'No valid dataset is specified: {dataset!r} '
            f'(expected one of {", ".join(_SUPPORTED_DATASETS)}).'
        )

    loader_kwargs = {'num_workers': 0, 'pin_memory': True} if is_cuda else {}
    root = os.path.join(dataset_path, dataset)
    train_set, test_set = _build_train_test_sets(dataset, root, download)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader, _NUM_CLASSES