-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcriterion.py
116 lines (97 loc) · 3.72 KB
/
criterion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import torch.nn.functional as F
from vgg import Vgg16
class LossFunc(nn.Module):
    """
    Loss function for landmark prediction.

    Compares a predicted future image with the ground-truth future image using
    either a perceptual loss (MSE on the raw images and on VGG16 feature maps,
    each divided by a per-layer normaliser stored in ``self.ema``) or a plain
    pixel-wise mean-squared error.

    Args:
        loss_type: ``'perceptual'`` (default) or ``'l2'``. Any other value
            makes ``forward`` raise ``ValueError``.
    """

    def __init__(self, loss_type='perceptual'):
        super().__init__()
        self.loss_type = loss_type
        self.ema = EMA()
        # The VGG network is only constructed when the perceptual loss needs it.
        self.vggnet = Vgg16() if loss_type == 'perceptual' else None
        self._init_ema()

    def forward(self, future_im_pred, future_im, mask=None):
        """Compute the loss between a predicted and a target image.

        Args:
            future_im_pred: predicted image tensor.
            future_im: target image tensor, same shape as the prediction.
            mask: optional weight map. It is resized to each loss map's
                spatial size and multiplied in (assumes a float NCHW-style
                tensor that broadcasts against the loss maps; confirm with
                callers).

        Returns:
            tuple ``(loss, vgg_losses)``: the scalar loss tensor and a list of
            per-layer losses as Python floats (empty for ``'l2'``).

        Raises:
            ValueError: if ``loss_type`` is neither ``'perceptual'`` nor ``'l2'``.
        """
        return self._loss(future_im_pred, future_im, mask=mask)

    def _loss(self, future_im_pred, future_im, mask=None):
        "Dispatch on ``self.loss_type`` and return ``(loss, vgg_losses)``."
        vgg_losses = []
        if self.loss_type == 'perceptual':
            w_reconstruct = 1.
            reconstruction_loss, vgg_losses = (
                self._colorization_reconstruction_loss(
                    future_im, future_im_pred, mask=mask))
        elif self.loss_type == 'l2':
            # Raw MSE is scaled by 1/255 (presumably because pixel values
            # are in [0, 255]; TODO confirm against the data pipeline).
            w_reconstruct = 1. / 255.
            if mask is not None:
                per_pixel = F.mse_loss(
                    future_im_pred, future_im, reduction='none')
                reconstruction_loss = torch.mean(
                    self._loss_mask(per_pixel, mask))
            else:
                reconstruction_loss = F.mse_loss(future_im_pred, future_im)
        else:
            raise ValueError('Incorrect loss-type')
        loss = w_reconstruct * reconstruction_loss
        return loss, vgg_losses

    def _loss_mask(self, imap, mask):
        "Resize ``mask`` to the last two dims of ``imap`` and multiply it in."
        mask = F.interpolate(mask, size=imap.shape[-2:])
        return imap * mask

    def _colorization_reconstruction_loss(
            self, gt_image, pred_image, mask=None):
        """Perceptual loss: normalised MSE on images and VGG feature maps.

        The compared entries are the keys of ``self.ema`` in insertion order.
        The first key always stands for the raw images. Every other key is
        read as an attribute of the ``Vgg16`` output (assumes that output
        exposes ``conv1_2`` ... ``conv5_2`` attributes; see vgg.py). Each
        per-entry MSE map is divided by the value stored for that key and
        then averaged.

        Returns:
            tuple ``(loss, vgg_losses)``: the summed scalar loss and the
            per-entry losses as Python floats (for display).
        """
        names = list(self.ema.avgs)
        feats_gt = self.vggnet(gt_image)
        feats_pred = self.vggnet(pred_image)
        # Position 0 is the raw image; the network output has no such field.
        feat_gt = [gt_image] + [getattr(feats_gt, k) for k in names[1:]]
        feat_pred = [pred_image] + [getattr(feats_pred, k) for k in names[1:]]

        losses = []
        for name, f_pred, f_gt in zip(names, feat_pred, feat_gt):
            per_elem = F.mse_loss(f_pred, f_gt, reduction='none')
            if mask is not None:
                per_elem = self._loss_mask(per_elem, mask)
            # NOTE: the normalisers are read but never updated here, so they
            # stay at the values seeded by _init_ema.
            losses.append(torch.mean(per_elem / self.ema.get(name)))
        vgg_losses = [x.item() for x in losses]  # for display
        loss = torch.stack(losses).sum()
        return loss, vgg_losses

    def _init_ema(self, ws=(50., 40., 6., 3., 3., 1.),
                  names=('input', 'conv1_2', 'conv2_2', 'conv3_2',
                         'conv4_2', 'conv5_2')):
        """Seed ``self.ema`` with the per-entry normalisers of the perceptual loss.

        Args:
            ws: normaliser for each entry of ``names``; that entry's MSE is
                divided by it.
            names: entry names. The first stands for the raw images; the
                rest must be attributes of the ``Vgg16`` output.

        Raises:
            ValueError: if ``ws`` and ``names`` differ in length.
        """
        ws, names = list(ws), list(names)
        if len(ws) != len(names):
            raise ValueError(
                'ws and names must have the same length, got %d and %d'
                % (len(ws), len(names)))
        for name, w in zip(names, ws):
            # register(), not update(): the first update() of a name stores
            # its init_val instead of the given value, which would leave a
            # zero normaliser and a division by zero in the perceptual loss.
            self.ema.register(name, w)
class EMA:
    """Exponential running average, keyed by name.

    Each name maps to a running value ``avg``. ``update`` moves it towards a
    new observation ``x`` as ``avg + (1 - decay) * (x - avg)``.

    Args:
        decay: smoothing factor ``rho``; the closer to 1, the more slowly the
            average follows new observations.
    """

    def __init__(self, decay=0.99):
        self.rho = decay
        self.avgs = {}  # name -> current running value (insertion-ordered)

    def register(self, name, val):
        "Store ``val`` under ``name``, overwriting any previous value."
        self.avgs[name] = val

    def get(self, name):
        "Return the value stored under ``name``; raises ``KeyError`` if absent."
        return self.avgs[name]

    def __getitem__(self, name):
        "Support ``ema[name]`` as shorthand for ``ema.get(name)``."
        return self.get(name)

    def __contains__(self, name):
        "Support ``name in ema``."
        return name in self.avgs

    def update(self, name, x, init_val=0.):
        """Fold observation ``x`` into the running value for ``name``.

        If ``name`` has not been seen before, it is registered with
        ``init_val`` and ``init_val`` is returned; ``x`` is ignored on that
        first call.

        Returns:
            The new running value for ``name``.
        """
        if name not in self.avgs:
            self.register(name, init_val)
            return init_val
        x_avg = self.get(name)
        w_update = 1. - self.rho
        x_new = x_avg + w_update * (x - x_avg)
        self.register(name, x_new)
        return x_new