"""Sliced Wasserstein Autoencoder (SWAE) PyTorch MNIST example.

mnist.py, forked from eifuentes/swae-pytorch.
"""
import argparse
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import datasets, transforms
# note: 'rand_cirlce2d' (sic) is the name as exported by swae.distributions
from swae.distributions import rand_cirlce2d, rand_ring2d, rand_uniform2d
from swae.models.mnist import MNISTAutoencoder
from swae.trainer import SWAEBatchTrainer
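
# Example invocation (a sketch; the flags correspond to the argparse options defined
# below, and '../data' is assumed to exist relative to this script as in the parent repo):
#   python mnist.py --batch-size 500 --epochs 30 --lr 0.001 --distribution circle
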
def main():
    # train args
    parser = argparse.ArgumentParser(description='Sliced Wasserstein Autoencoder PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=500, metavar='N',
                        help='input batch size for training (default: 500)')
    parser.add_argument('--epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--alpha', type=float, default=0.9, metavar='A',
                        help='RMSprop alpha/rho (default: 0.9)')
    parser.add_argument('--distribution', type=str, default='circle', metavar='DIST',
                        help='latent distribution (default: circle)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--num_workers', type=int, default=8, metavar='N',
                        help='number of dataloader workers if device is CPU (default: 8)')
    parser.add_argument('--seed', type=int, default=7, metavar='S',
                        help='random seed (default: 7)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='number of batches to log training status (default: 10)')
    args = parser.parse_args()
    # set random seed
    torch.manual_seed(args.seed)
    # determine device and device dependent args
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    dataloader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {'num_workers': args.num_workers, 'pin_memory': False}
    # log args
    print('batch size {}\nepochs {}\nRMSprop lr {} alpha {}\ndistribution {}\nusing device {}\nseed set to {}'.format(
        args.batch_size, args.epochs, args.lr, args.alpha, args.distribution, device.type, args.seed
    ))
    # build train and test set data loaders
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=args.batch_size, shuffle=True, **dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=64, shuffle=False, **dataloader_kwargs)
    # create encoder and decoder
    model = MNISTAutoencoder().to(device)
    print(model)
    # create optimizer
    # matching default Keras args for RMSprop
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, alpha=args.alpha)
    # determine latent distribution
    if args.distribution == 'circle':
        distribution_fn = rand_cirlce2d
    elif args.distribution == 'ring':
        distribution_fn = rand_ring2d
    else:
        # any other --distribution value falls back to the uniform prior
        distribution_fn = rand_uniform2d
    # create batch sliced_wasserstein autoencoder trainer
    trainer = SWAEBatchTrainer(model, optimizer, distribution_fn, device=device)
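    # Note: as used below, trainer.train_on_batch(x) is expected to return a dict
    # containing at least 'loss' and 'decode', and trainer.test_on_batch(x) a dict
    # containing at least 'loss' and 'encode' (the 2D latent codes).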
    # put networks in training mode
    model.train()
    # train networks for n epochs
    print('training...')
    for epoch in range(args.epochs):
        if epoch > 10:
            # after epoch 10, ramp up the sliced Wasserstein penalty weight by 10% per epoch
            trainer.weight_swd *= 1.1
        # train autoencoder on train dataset
        for batch_idx, (x, y) in enumerate(train_loader, start=0):
            batch = trainer.train_on_batch(x)
            if (batch_idx + 1) % args.log_interval == 0:
                print('Train Epoch: {} ({:.2f}%) [{}/{}]\tLoss: {:.6f}'.format(
                    epoch + 1, float(epoch + 1) / (args.epochs) * 100.,
                    (batch_idx + 1), len(train_loader),
                    batch['loss'].item()))
        # evaluate autoencoder on test dataset
        test_encode, test_targets, test_loss = list(), list(), 0.0
        with torch.no_grad():
            for test_batch_idx, (x_test, y_test) in enumerate(test_loader, start=0):
                test_evals = trainer.test_on_batch(x_test)
                test_encode.append(test_evals['encode'].detach())
                test_loss += test_evals['loss'].item()
                test_targets.append(y_test)
        test_encode, test_targets = torch.cat(test_encode).cpu().numpy(), torch.cat(test_targets).cpu().numpy()
        test_loss /= len(test_loader)
        print('Test Epoch: {} ({:.2f}%)\tLoss: {:.6f}'.format(
            epoch + 1, float(epoch + 1) / (args.epochs) * 100.,
            test_loss))
        # save encoded samples plot
        plt.figure(figsize=(10, 10))
        plt.scatter(test_encode[:, 0], -test_encode[:, 1], c=(10 * test_targets), cmap=plt.cm.Spectral)
        plt.xlim([-1.5, 1.5])
        plt.ylim([-1.5, 1.5])
        # plt.title('Test Latent Space\nLoss: {:.5f}'.format(test_loss))
        plt.savefig('../data/test_latent_epoch_{}.png'.format(epoch + 1))
        plt.close()
        # save sample input and reconstruction (note: x and batch come from the last training batch)
        vutils.save_image(x,
                          '../data/test_samples_epoch_{}.png'.format(epoch + 1))
        vutils.save_image(batch['decode'].detach(),
                          '../data/test_reconstructions_epoch_{}.png'.format(epoch + 1),
                          normalize=True)


if __name__ == '__main__':
    main()