-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss_weights.py
76 lines (52 loc) · 1.77 KB
/
loss_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
class ReverseLayerF(Function):
    """Gradient reversal layer.

    The forward pass is the identity. The backward pass multiplies the incoming
    gradient by ``-p``. The scalar ``p`` itself receives no gradient.
    Invoke it as ``ReverseLayerF.apply(x, p)``.
    """

    @staticmethod
    def forward(ctx, x, p):
        # Keep the scale for the backward pass.
        ctx.scale = p
        # Return a view of x rather than x itself, so the output is a
        # distinct tensor with the same data and shape.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the upstream gradient and scale it.
        flipped = grad_output.neg().mul(ctx.scale)
        # p is a non-differentiable hyper-parameter, so it gets None.
        return flipped, None
def weight_comp(y):
    """Return the average of ``y`` along axis 0.

    Sums ``y`` over its first axis and divides by ``y.shape[0]``. If the rows
    are one-hot or probability label vectors (an assumption, to be confirmed
    with callers), the result is the per-class frequency in the batch.

    Args:
        y: NumPy-compatible array with a ``shape`` attribute.

    Returns:
        Array of column averages.
    """
    n_rows = y.shape[0]
    column_totals = np.sum(y, axis=0)
    return column_totals / n_rows
def integrated_loss_weight(wt):
    """Rescale ``wt`` so that its largest entry becomes 1.

    Args:
        wt: Array of non-normalized weights.

    Returns:
        ``wt / max(wt)``.

    NOTE(review): if ``max(wt)`` is 0, the division yields nan/inf (NumPy
    warns rather than raises). Callers should make sure at least one weight
    is positive.
    """
    peak = np.amax(wt)
    return wt / peak
def loss_weight(yt, ysb):
    """Build per-sample loss weights from class frequencies in ``yt``.

    Class weights are the axis-0 averages of ``yt`` (``weight_comp``),
    rescaled so that the largest is 1 (``integrated_loss_weight``). Each row
    of ``ysb`` is then mapped to the weight of its argmax class.

    Both returned arrays have length ``2 * ysb.shape[0]``, and their first
    ``ysb.shape[0]`` entries hold those per-row weights. They differ only in
    the remaining half:

    * ``w_sample_class``: zeros in the second half.
    * ``w_sample_adv``: ones in the second half.

    Presumably the first half corresponds to a source batch and the second
    half to a target batch concatenated after it. Confirm this with the
    caller.

    Args:
        yt: Label matrix whose column averages define the class weights.
        ysb: Label/score matrix; ``argmax(1)`` picks each row's class.

    Returns:
        Tuple ``(w_sample_class, w_sample_adv)`` of float64 NumPy arrays.
    """
    n_rows = ysb.shape[0]
    class_weights = integrated_loss_weight(weight_comp(yt))
    row_classes = ysb.argmax(1)

    w_sample_class = np.zeros(2 * n_rows)
    w_sample_class[:n_rows] = class_weights[row_classes]

    w_sample_adv = np.ones(2 * n_rows)
    w_sample_adv[:n_rows] = class_weights[row_classes]

    return w_sample_class, w_sample_adv
def MSE(pred, real):
    """Mean squared error between ``pred`` and ``real``.

    The squared differences are summed and divided by the number of elements
    in the (broadcast) difference tensor.

    Returns:
        0-d tensor.
    """
    residual = real - pred
    squared_total = residual.pow(2).sum()
    return squared_total / residual.numel()
def SIMSE(pred, real):
    """Squared mean of the difference ``real - pred``.

    Computes ``(sum(real - pred)) ** 2 / n ** 2``, where ``n`` is the element
    count of the broadcast difference. By its name, this is presumably the
    correction term of a scale-invariant MSE.

    Returns:
        0-d tensor.
    """
    residual = real - pred
    count = residual.numel()
    return residual.sum().pow(2) / (count ** 2)
def DiffLoss(x, recon_x):
    """Squared cross-correlation penalty between two batches of features.

    Both inputs are reshaped to ``(x.size(0), -1)``. Each row is divided by
    its L2 norm plus ``1e-6``; the norm is detached, so no gradient flows
    through it. The loss is the mean of the squared entries of
    ``x_unit.T @ recon_unit``, a ``(d_x, d_recon)`` matrix. It is 0 when every
    feature column of one normalized batch is orthogonal (across the batch)
    to every feature column of the other.

    Returns:
        0-d tensor.
    """
    n = x.size(0)

    def unit_rows(t):
        # Flatten per sample. Note recon_x is also reshaped with x's batch size.
        flat = t.view(n, -1)
        norms = torch.norm(flat, p=2, dim=1, keepdim=True).detach()
        # The 1e-6 keeps all-zero rows from dividing by zero.
        return flat.div(norms.expand_as(flat) + 1e-6)

    correlation = unit_rows(x).t().mm(unit_rows(recon_x))
    return correlation.pow(2).mean()
def WLoss(source, target):
    """Return ``-sum(log(source)) - sum(log(1 - target))``.

    This has the form of a binary cross-entropy with every ``source`` entry
    labelled 1 and every ``target`` entry labelled 0. Presumably both inputs
    are discriminator probabilities in (0, 1); an exact 0 or 1 produces an
    infinite loss.

    NOTE(review): the terms are summed, not averaged, so the loss scale
    grows with batch size. Confirm this is intended.

    Returns:
        0-d tensor.
    """
    # Averaging a 0-d sum is the identity, so the plain sums are used directly.
    source_term = torch.log(source).sum()
    target_term = torch.log(1 - target).sum()
    return -source_term - target_term