-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathedl_loss.py
223 lines (185 loc) · 8.4 KB
/
edl_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import torch
import torch.nn.functional as F
import numpy as np
from base import BaseWeightedLoss
def relu_evidence(y):
    """Turn raw scores into non-negative evidence with a ReLU.

    Negative scores contribute zero evidence; positive scores pass through.
    """
    evidence = F.relu(y)
    return evidence
def exp_evidence(y):
    """Turn raw scores into positive evidence with a clamped exponential.

    Scores are clipped to [-10, 10] before exponentiation, so the evidence
    always lies in [exp(-10), exp(10)] and cannot overflow.
    """
    bounded = y.clamp(min=-10, max=10)
    return bounded.exp()
def softplus_evidence(y):
    """Turn raw scores into non-negative evidence with softplus.

    Uses ``F.softplus`` with PyTorch's default ``beta``/``threshold``, i.e. a
    smooth approximation of ReLU, ``log(1 + exp(y))``.
    """
    evidence = F.softplus(y)
    return evidence
class EvidenceLoss(BaseWeightedLoss):
    """Evidential deep learning (EDL) loss.

    Raw network outputs are mapped to non-negative evidence (selected by
    ``evidence``), and ``alpha = evidence + 1`` is used as the concentration of
    a Dirichlet distribution over the classes. ``loss_type`` selects the data
    term:

    * ``'mse'``           -- squared error of the expected probabilities plus
      their variance;
    * ``'log'``           -- weighted ``log(S) - log(alpha)`` term;
    * ``'digamma'``       -- weighted ``digamma(S) - digamma(alpha)`` term;
    * ``'cross_entropy'`` -- negative log of the expected class probability.

    ``S`` is the per-sample sum of ``alpha``. An annealed KL regularizer
    towards the uniform Dirichlet is added when ``with_kldiv`` is set. An
    annealed accuracy-versus-uncertainty (AvU) term is added when
    ``with_avuloss`` is set; it is only implemented for ``'log'`` and
    ``'digamma'``.

    Args:
        num_classes (int): Number of classes ``K``.
        evidence (str): ``'relu'``, ``'exp'`` or ``'softplus'``.
        loss_type (str): ``'mse'``, ``'log'``, ``'digamma'`` or
            ``'cross_entropy'``.
        with_kldiv (bool): Add the annealed KL regularizer.
        with_avuloss (bool): Add the AvU term (``'log'``/``'digamma'`` only).
        disentangle (bool): Stored on the instance; not read by this class.
        annealing_method (str): ``'step'`` or ``'exp'``.
        annealing_start (float): Initial coefficient for ``'exp'`` annealing.
        annealing_step (int | float): Epochs needed to reach a coefficient
            of 1.0 with ``'step'`` annealing.
    """

    def __init__(self, num_classes,
                 evidence='relu',
                 loss_type='mse',
                 with_kldiv=True,
                 with_avuloss=False,
                 disentangle=False,
                 annealing_method='step',
                 annealing_start=0.01,
                 annealing_step=10):
        super().__init__()
        self.num_classes = num_classes
        self.evidence = evidence
        self.loss_type = loss_type
        self.with_kldiv = with_kldiv
        self.with_avuloss = with_avuloss
        self.disentangle = disentangle
        self.annealing_method = annealing_method
        self.annealing_start = annealing_start
        self.annealing_step = annealing_step
        # Small constant that keeps log() arguments away from zero in the AvU term.
        self.eps = 1e-10

    def kl_divergence(self, alpha):
        """KL divergence from ``Dir(alpha)`` to the uniform ``Dir(1, ..., 1)``.

        Args:
            alpha (torch.Tensor): Concentrations, shape (batch, num_classes).

        Returns:
            torch.Tensor: Per-sample KL divergence, shape (batch, 1).
        """
        # Build the uniform prior directly with alpha's dtype/device instead of
        # a float32 host tensor that is then copied (and possibly promoted).
        beta = torch.ones((1, self.num_classes), dtype=alpha.dtype,
                          device=alpha.device)
        S_alpha = torch.sum(alpha, dim=1, keepdim=True)
        S_beta = torch.sum(beta, dim=1, keepdim=True)
        # log of the Dirichlet normalizers of Dir(alpha) and Dir(beta).
        lnB = torch.lgamma(S_alpha) - \
            torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
        lnB_uni = torch.sum(torch.lgamma(beta), dim=1,
                            keepdim=True) - torch.lgamma(S_beta)
        dg0 = torch.digamma(S_alpha)
        dg1 = torch.digamma(alpha)
        kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1,
                       keepdim=True) + lnB + lnB_uni
        return kl

    def loglikelihood_loss(self, y, alpha):
        """Squared error and variance terms of the expected-MSE loss.

        Args:
            y (torch.Tensor): Label vectors, shape (batch, num_classes).
            alpha (torch.Tensor): Concentrations, shape (batch, num_classes).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(err, var)``, each of shape
            (batch, 1). ``err`` is ``sum((y - alpha/S)**2)`` and ``var`` is
            ``sum(alpha * (S - alpha) / (S**2 * (S + 1)))``.
        """
        S = torch.sum(alpha, dim=1, keepdim=True)
        loglikelihood_err = torch.sum(
            (y - (alpha / S)) ** 2, dim=1, keepdim=True)
        loglikelihood_var = torch.sum(
            alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True)
        return loglikelihood_err, loglikelihood_var

    def mse_loss(self, y, alpha, annealing_coef):
        """Expected-MSE evidential loss (``loss_type == 'mse'``).

        Args:
            y (torch.Tensor): Label vectors, shape (batch, num_classes).
            alpha (torch.Tensor): Concentrations, shape (batch, num_classes).
            annealing_coef (torch.Tensor): Scale applied to the KL term.

        Returns:
            dict: ``'loss_cls'`` and ``'loss_var'`` (each (batch, 1)),
            ``'lambda'`` (the annealing coefficient), and ``'loss_kl'`` when
            ``with_kldiv`` is set.
        """
        losses = {}
        loglikelihood_err, loglikelihood_var = self.loglikelihood_loss(y, alpha)
        losses.update({'loss_cls': loglikelihood_err, 'loss_var': loglikelihood_var})
        losses.update({'lambda': annealing_coef})
        if self.with_kldiv:
            # Remove the evidence of the true class(es) so that only
            # misleading evidence is pushed towards the uniform prior.
            kl_alpha = (alpha - 1) * (1 - y) + 1
            kl_div = annealing_coef * self.kl_divergence(kl_alpha)
            losses.update({'loss_kl': kl_div})
        # TODO: the AvU term is not implemented for 'mse'; with_avuloss is
        # ignored here. The previous branch only computed unused values.
        return losses

    def ce_loss(self, target, y, alpha, annealing_coef):
        """Expected cross-entropy loss (``loss_type == 'cross_entropy'``).

        Args:
            target (torch.Tensor): Either integer class indices of shape
                (batch,) or label vectors of shape (batch, num_classes).
                ``_forward`` passes the label vectors.
            y (torch.Tensor): Label vectors, shape (batch, num_classes).
            alpha (torch.Tensor): Concentrations (>= 1), shape
                (batch, num_classes).
            annealing_coef (torch.Tensor): Scale applied to the KL term.

        Returns:
            dict: ``'loss_cls'`` (shape (batch,) for index targets, otherwise
            (batch, 1)), ``'loss_var'`` (batch, 1), ``'loss_kl'`` when
            ``with_kldiv`` is set, and ``'lambda'``.
        """
        losses = {}
        # (1) the classification loss term
        S = torch.sum(alpha, dim=1, keepdim=True)
        pred_score = alpha / S
        log_score = torch.log(pred_score)
        if target.dim() == 1:
            # Integer class indices: the NLL of the expected probabilities.
            loss_cls = F.nll_loss(log_score, target, reduction='none')
        else:
            # Label vectors (what _forward passes). F.nll_loss rejects these,
            # so compute -sum(y * log p) per sample instead. For one-hot y
            # this equals the NLL above.
            loss_cls = -torch.sum(y * log_score, dim=1, keepdim=True)
        losses.update({'loss_cls': loss_cls})
        # (2) the likelihood variance term
        loglikelihood_var = torch.sum(
            alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True)
        losses.update({'loss_var': loglikelihood_var})
        # (3) the KL divergence term, gated by with_kldiv like the other losses
        if self.with_kldiv:
            kl_alpha = (alpha - 1) * (1 - y) + 1
            kl_div = annealing_coef * self.kl_divergence(kl_alpha)
            losses.update({'loss_kl': kl_div})
        losses.update({'lambda': annealing_coef})
        return losses

    def edl_loss(self, func, y, alpha, annealing_coef, target):
        """Evidential loss for ``loss_type`` ``'log'`` and ``'digamma'``.

        Args:
            func (callable): ``torch.log`` or ``torch.digamma``.
            y (torch.Tensor): Label vectors, shape (batch, num_classes).
            alpha (torch.Tensor): Concentrations, shape (batch, num_classes).
            annealing_coef (torch.Tensor): Annealing coefficient for the KL
                and AvU terms.
            target (torch.Tensor): Labels compared against ``alpha / S`` in
                the AvU term. ``_forward`` passes the same tensor as ``y``.

        Returns:
            dict: ``'loss_cls'`` (batch, 1), ``'lambda'``, ``'loss_kl'`` when
            ``with_kldiv`` is set, and ``'loss_avu'`` (batch, 1) when
            ``with_avuloss`` is set.
        """
        losses = {}
        S = torch.sum(alpha, dim=1, keepdim=True)
        uncertainty = self.num_classes / S
        label_num = torch.sum(y, dim=1, keepdim=True)
        # Per-label weights proportional to 1/alpha over the positive labels,
        # normalized to sum to label_num and scaled by (1 - uncertainty).
        # The uncertainty factor is detached so it receives no gradient.
        temp = 1 / alpha * y
        temp_sum = torch.sum(temp, dim=1, keepdim=True)
        # A row with no positive label has temp_sum == 0, and 0/0 would make
        # the whole row NaN. Divide by 1 there instead so the row gets zero weight.
        temp_sum = torch.where(temp_sum > 0, temp_sum, torch.ones_like(temp_sum))
        g = (1 - uncertainty.detach()) * label_num * (temp / temp_sum)
        A = torch.sum(g * (func(S) - func(alpha)), dim=1, keepdim=True)
        losses.update({'loss_cls': A})
        losses.update({'lambda': annealing_coef})
        if self.with_kldiv:
            kl_alpha = (alpha - 1) * (1 - y) + 1
            kl_div = annealing_coef * self.kl_divergence(kl_alpha)
            losses.update({'loss_kl': kl_div})
        if self.with_avuloss:
            pred = alpha / S
            # keepdim=True keeps this (batch, 1), matching `uncertainty`.
            # Without it, the (batch,) * (batch, 1) products below broadcast
            # to a (batch, batch) matrix.
            inacc_measure = torch.abs(pred - target).sum(dim=1, keepdim=True) / 2.0
            acc_uncertain = -(1 - inacc_measure) * torch.log(1 - uncertainty + self.eps)
            inacc_certain = -inacc_measure * torch.log(uncertainty + self.eps)
            batch_size = y.shape[0]
            # Fraction of samples counted as inaccurate (measure > 0.7); the
            # thresholding carries no gradient.
            inacc_rate = torch.sum(
                (inacc_measure > 0.7).to(inacc_measure.dtype)) / batch_size
            acc_match = 1 - inacc_rate
            avu_loss = annealing_coef * acc_match * acc_uncertain + \
                (1 - annealing_coef) * (1 - acc_match) * inacc_certain
            losses.update({'loss_avu': avu_loss})
        return losses

    def compute_annealing_coef(self, **kwargs):
        """Return the annealing coefficient for the current epoch.

        Keyword Args:
            epoch: Current epoch number (required).
            total_epoch: Total number of epochs (required).

        Returns:
            torch.Tensor: 0-dim float32 tensor. ``'step'`` gives
            ``min(1, epoch / annealing_step)``. ``'exp'`` gives
            ``annealing_start * exp(-log(annealing_start) * epoch / total_epoch)``,
            which rises from ``annealing_start`` to 1.

        Raises:
            AssertionError: If ``epoch`` or ``total_epoch`` is missing.
            NotImplementedError: For an unknown ``annealing_method``.
        """
        assert 'epoch' in kwargs, "epoch number is missing!"
        assert 'total_epoch' in kwargs, "total epoch number is missing!"
        epoch_num, total_epoch = kwargs['epoch'], kwargs['total_epoch']
        if self.annealing_method == 'step':
            annealing_coef = torch.min(
                torch.tensor(1.0, dtype=torch.float32),
                torch.tensor(epoch_num / self.annealing_step, dtype=torch.float32))
        elif self.annealing_method == 'exp':
            annealing_start = torch.tensor(self.annealing_start, dtype=torch.float32)
            annealing_coef = annealing_start * torch.exp(
                -torch.log(annealing_start) / total_epoch * epoch_num)
        else:
            raise NotImplementedError(
                f"Unsupported annealing_method: {self.annealing_method!r}")
        return annealing_coef

    def _forward(self, output, target, **kwargs):
        """Compute the evidential loss terms.

        Args:
            output (torch.Tensor): Raw class scores (before any activation),
                shape (batch, num_classes).
            target (torch.Tensor): Ground-truth label vectors, shape
                (batch, num_classes). They are used directly; no one-hot
                conversion is performed.
            **kwargs: Must contain ``epoch`` and ``total_epoch``; see
                :meth:`compute_annealing_coef`.

        Returns:
            dict: The loss terms of the selected ``loss_type`` plus
            ``'uncertainty'`` (``num_classes / S``, shape (batch, 1)).

        Raises:
            NotImplementedError: For an unknown ``evidence`` or ``loss_type``.
        """
        # get evidence
        evidence_funcs = {
            'relu': relu_evidence,
            'exp': exp_evidence,
            'softplus': softplus_evidence,
        }
        if self.evidence not in evidence_funcs:
            raise NotImplementedError(f"Unsupported evidence: {self.evidence!r}")
        evidence = evidence_funcs[self.evidence](output)
        alpha = evidence + 1
        # Targets are already label vectors, so they double as the one-hot `y`.
        y = target
        annealing_coef = self.compute_annealing_coef(**kwargs)
        # compute the EDL loss
        if self.loss_type == 'mse':
            results = self.mse_loss(y, alpha, annealing_coef)
        elif self.loss_type == 'log':
            results = self.edl_loss(torch.log, y, alpha, annealing_coef, target)
        elif self.loss_type == 'digamma':
            results = self.edl_loss(torch.digamma, y, alpha, annealing_coef, target)
        elif self.loss_type == 'cross_entropy':
            results = self.ce_loss(target, y, alpha, annealing_coef)
        else:
            raise NotImplementedError(f"Unsupported loss_type: {self.loss_type!r}")
        uncertainty = self.num_classes / torch.sum(alpha, dim=1, keepdim=True)
        results.update({'uncertainty': uncertainty})
        return results