forked from eifuentes/swae-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
165 lines (145 loc) · 6.17 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import numpy as np
import torch
import torch.nn.functional as F
from .distributions import rand_cirlce2d
def rand_projections(embedding_dim, num_samples=50):
    """Draw random directions uniformly from the latent space's unit sphere.

    Rows are sampled from an isotropic Gaussian (via ``np.random``) and then
    scaled to unit L2 norm.

    Args:
        embedding_dim (int): embedding dimension size
        num_samples (int): number of random projection samples

    Return:
        torch.Tensor: float32 CPU tensor of shape
            ``(num_samples, embedding_dim)`` whose rows have unit L2 norm
    """
    gaussian = np.random.normal(size=(num_samples, embedding_dim))
    # Normalise every row in one vectorised step. keepdims=True keeps the
    # norms as a column so the division broadcasts row-wise, and the result
    # stays 2-D even when num_samples is 0.
    row_norms = np.sqrt((gaussian ** 2).sum(axis=1, keepdims=True))
    theta = gaussian / row_norms
    return torch.from_numpy(theta).type(torch.FloatTensor)
def _sliced_wasserstein_distance(encoded_samples,
                                 distribution_samples,
                                 num_projections=50,
                                 p=2,
                                 device='cpu'):
    """Sliced Wasserstein Distance between encoded samples and drawn
    distribution samples.

    Both sample sets are projected onto the same random unit directions,
    sorted along the sample axis for each direction, and compared
    elementwise. The two sets are therefore expected to contain the same
    number of samples.

    Args:
        encoded_samples (torch.Tensor): embedded training tensor samples
        distribution_samples (torch.Tensor): distribution training tensor
            samples; its second dimension defines the latent size
        num_projections (int): number of projections to approximate sliced
            Wasserstein distance
        p (int): power of distance metric
        device: 'cuda' or 'cpu' (default 'cpu'); the random projections are
            moved to this device before use

    Return:
        torch.Tensor: scalar tensor, the mean of ``|difference| ** p`` over
            all projections and samples
    """
    # derive latent space dimension size from random samples drawn from a
    # distribution in it
    embedding_dim = distribution_samples.size(1)
    # generate random projections in latent space
    projections = rand_projections(embedding_dim, num_projections).to(device)
    directions = projections.transpose(0, 1)
    # project both sample sets onto the same random directions
    encoded_projections = encoded_samples.matmul(directions)
    distribution_projections = distribution_samples.matmul(directions)
    # sort the samples per projection so that matching ranks are compared
    sorted_encoded = torch.sort(
        encoded_projections.transpose(0, 1), dim=1)[0]
    sorted_distribution = torch.sort(
        distribution_projections.transpose(0, 1), dim=1)[0]
    wasserstein_distance = sorted_encoded - sorted_distribution
    # Take the magnitude before raising to p: without it, odd powers keep the
    # sign of each difference, so terms cancel and the result can be negative.
    # For even p (including the default p=2) this is a no-op.
    wasserstein_distance_p = torch.pow(torch.abs(wasserstein_distance), p)
    # approximate wasserstein_distance by averaging over every projection
    return wasserstein_distance_p.mean()
def sliced_wasserstein_distance(encoded_samples,
                                distribution_fn=rand_cirlce2d,
                                num_projections=50,
                                p=2,
                                device='cpu'):
    """Sliced Wasserstein Distance between encoded samples and samples
    freshly drawn from a prior distribution.

    Args:
        encoded_samples (torch.Tensor): embedded training tensor samples
        distribution_fn (callable): called with the batch size; must return
            a tensor of prior samples
        num_projections (int): number of projections to approximate sliced
            Wasserstein distance
        p (int): power of distance metric
        device: 'cuda' or 'cpu' (default 'cpu')

    Return:
        torch.Tensor
    """
    # one prior sample is drawn per encoded sample
    num_encoded = encoded_samples.size(0)
    prior_samples = distribution_fn(num_encoded)
    prior_samples = prior_samples.to(device)
    # delegate to the projection-averaged distance between the two sets
    return _sliced_wasserstein_distance(
        encoded_samples, prior_samples, num_projections, p, device)
class SWAEBatchTrainer:
    """Sliced Wasserstein Autoencoder Batch Trainer.

    Wraps an autoencoder and its optimizer and exposes per-batch training
    and evaluation steps. The loss is the sum of binary cross entropy,
    L1 reconstruction error and a weighted sliced Wasserstein term.

    Args:
        autoencoder (torch.nn.Module): module which implements autoencoder
            framework; calling it must return a ``(reconstruction, latent)``
            pair and it must expose ``encoder.embedding_dim_``
        optimizer (torch.optim.Optimizer): torch optimizer
        distribution_fn (callable): callable to draw random samples
        num_projections (int): number of projections to approximate sliced
            Wasserstein distance
        p (int): power of distance metric
        weight_swd (float): weight of divergence metric compared to
            reconstruction in loss
        device (torch.Device): torch device; falls back to CPU when not
            given
    """

    def __init__(self, autoencoder, optimizer, distribution_fn,
                 num_projections=50, p=2, weight_swd=10.0, device=None):
        self.model_ = autoencoder
        self.optimizer = optimizer
        self._distribution_fn = distribution_fn
        self.embedding_dim_ = self.model_.encoder.embedding_dim_
        self.num_projections_ = num_projections
        self.p_ = p
        self.weight_swd = weight_swd
        self._device = device if device else torch.device('cpu')

    def __call__(self, x):
        """Shortcut for :meth:`eval_on_batch`."""
        return self.eval_on_batch(x)

    def train_on_batch(self, x):
        """Run one optimisation step on batch ``x`` and return its evals."""
        # reset gradients
        self.optimizer.zero_grad()
        # autoencoder forward pass and loss
        evals = self.eval_on_batch(x)
        # backpropagate loss
        evals['loss'].backward()
        # update encoder and decoder parameters
        self.optimizer.step()
        return evals

    def test_on_batch(self, x):
        """Evaluate batch ``x`` without updating parameters.

        The forward pass runs under ``torch.no_grad()`` because nothing is
        back-propagated here; the returned tensors do not track gradients.
        """
        # reset gradients (kept so the optimizer state callers observe after
        # a test step is the same as before this method disabled autograd)
        self.optimizer.zero_grad()
        # autoencoder forward pass and loss, without building a graph
        with torch.no_grad():
            evals = self.eval_on_batch(x)
        return evals

    def eval_on_batch(self, x):
        """Forward ``x`` through the autoencoder and compute all loss terms.

        Returns:
            dict: ``'loss'`` (sum of the three terms), ``'bce'``, ``'l1'``,
                ``'w2'`` (sliced Wasserstein term already multiplied by
                ``weight_swd``), ``'encode'`` (latent codes) and
                ``'decode'`` (reconstruction).
        """
        x = x.to(self._device)
        recon_x, z = self.model_(x)
        # NOTE(review): binary_cross_entropy expects inputs and targets in
        # [0, 1]; presumably the decoder ends in a sigmoid and x is scaled
        # accordingly -- confirm against the model and data pipeline.
        bce = F.binary_cross_entropy(recon_x, x)
        l1 = F.l1_loss(recon_x, x)
        w2 = float(self.weight_swd) * sliced_wasserstein_distance(
            z,
            self._distribution_fn,
            self.num_projections_,
            self.p_,
            self._device)
        loss = bce + l1 + w2
        return {'loss': loss,
                'bce': bce,
                'l1': l1,
                'w2': w2,
                'encode': z,
                'decode': recon_x}