-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtriplet_loss.py
121 lines (95 loc) · 4.19 KB
/
triplet_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn.functional as F
def euclidean_dist(x, y):
    """Return the pairwise Euclidean distance between rows of ``x`` and ``y``.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so the
    whole matrix is produced with a single matrix product.

    Args:
        x: 2-D tensor with one sample per row, shape [m, d].
        y: 2-D tensor with one sample per row, shape [n, d].

    Returns:
        Tensor of shape [m, n]; entry (i, j) is the distance between
        ``x[i]`` and ``y[j]``.
    """
    x_sq = x.pow(2).sum(dim=1, keepdim=True)      # [m, 1]
    y_sq = y.pow(2).sum(dim=1, keepdim=True).t()  # [1, n]
    # Broadcasting the column and row of squared norms yields the [m, n] sum.
    sq_dist = x_sq + y_sq - 2 * (x @ y.t())
    # Clamp before the square root: rounding can make the squared distance
    # zero or slightly negative, where sqrt has an infinite/undefined gradient.
    return torch.sqrt(torch.clamp(sq_dist, min=1e-12))
def cosine_dist(x, y):
    """Return a cosine-based pairwise distance between rows of ``x`` and ``y``.

    Both inputs are L2-normalised along dim 1 and the result is
    ``2 - 2 * cos(x_i, y_j)``, which ranges from 0 (same direction) to
    4 (opposite direction).

    Args:
        x: 2-D tensor with one sample per row, shape [m, d].
        y: 2-D tensor with one sample per row, shape [n, d].

    Returns:
        Tensor of shape [m, n].
    """
    x_unit = F.normalize(x, dim=1)
    y_unit = F.normalize(y, dim=1)
    cosine_sim = torch.mm(x_unit, y_unit.t())
    return 2 - 2 * cosine_sim
def softmax_weights(dist, mask):
    """Compute a row-wise softmax over ``dist`` restricted to ``mask``.

    Args:
        dist: 2-D tensor of scores.
        mask: tensor of the same shape as ``dist``; entries are multiplied
            into the exponentials, so a 0 removes that position from the
            softmax (expected to be a 0/1 mask).

    Returns:
        Tensor with the shape of ``dist``; masked-out positions are 0 and
        the remaining entries of each row sum to (almost) 1.
    """
    # Shift by the row maximum of the masked scores to keep exp() bounded.
    row_max = (dist * mask).max(dim=1, keepdim=True)[0]
    masked_exp = torch.exp(dist - row_max) * mask
    # The small constant avoids division by zero when a row is fully masked.
    normaliser = masked_exp.sum(dim=1, keepdim=True) + 1e-6
    return masked_exp / normaliser
def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, pick the hardest positive and hardest negative distance.

    Args:
        dist_mat: pairwise distance between samples, shape [N, M].
        is_pos: mask with shape [N, M]; nonzero where the column sample is a
            positive for the row anchor (expected to be a 0/1 float mask).
        is_neg: mask with shape [N, M]; nonzero where the column sample is a
            negative for the row anchor (expected to be a 0/1 float mask).

    Returns:
        dist_ap: largest anchor-positive distance per row; shape [N].
        dist_an: smallest anchor-negative distance per row; shape [N].

    Only the two distance tensors are returned; the indices of the selected
    samples are not.
    """
    assert dist_mat.dim() == 2

    # Hardest positive: the farthest sample among the positives. Non-positive
    # entries are zeroed by the mask before taking the row maximum.
    dist_ap = (dist_mat * is_pos).max(dim=1).values

    # Hardest negative: the closest sample among the negatives. Positive
    # entries are pushed to 1e9 so they can never win the row minimum.
    dist_an = (dist_mat * is_neg + is_pos * 1e9).min(dim=1).values

    return dist_ap, dist_an
def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, compute softmax-weighted positive and negative distances.

    Instead of selecting a single hardest sample, every positive (negative)
    contributes to the anchor's distance, weighted by a masked softmax so
    that farther positives and closer negatives receive more weight.

    Args:
        dist_mat: pairwise distance between samples, 2-D tensor.
        is_pos: mask of the same shape as ``dist_mat``; nonzero where the
            column sample is a positive for the row anchor.
        is_neg: mask of the same shape as ``dist_mat``; nonzero where the
            column sample is a negative for the row anchor.

    Returns:
        dist_ap: weighted anchor-positive distance per row; shape [N].
        dist_an: weighted anchor-negative distance per row; shape [N].
    """
    assert dist_mat.dim() == 2

    pos_dist = dist_mat * is_pos
    neg_dist = dist_mat * is_neg

    # Positives: larger distance -> larger weight.
    pos_weights = softmax_weights(pos_dist, is_pos)
    # Negatives: the sign flip makes smaller distance -> larger weight.
    neg_weights = softmax_weights(-neg_dist, is_neg)

    weighted_ap = (pos_dist * pos_weights).sum(dim=1)
    weighted_an = (neg_dist * neg_weights).sum(dim=1)
    return weighted_ap, weighted_an
def triplet_loss(embedding, targets, margin=0, norm_feat=False, hard_mining=False):
    r"""Compute a batch triplet loss over ``embedding`` labelled by ``targets``.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Args:
        embedding: 2-D tensor with one feature vector per row, shape [N, d].
        targets: tensor with N labels; samples sharing a label are positives
            of each other, all other pairs are negatives.
        margin: if greater than 0, a margin ranking loss with this margin is
            used; otherwise a soft-margin loss is used.
        norm_feat: if true, distances come from ``cosine_dist``; otherwise
            from ``euclidean_dist``.
        hard_mining: if true, use ``hard_example_mining``; otherwise use
            ``weighted_example_mining``.

    Returns:
        Scalar loss tensor.
    """
    pairwise_dist = cosine_dist if norm_feat else euclidean_dist
    dist_mat = pairwise_dist(embedding, embedding)

    # NOTE: features are not gathered across processes here; the distance
    # matrix (and therefore the mining) covers only the batch passed in.

    N = dist_mat.size(0)
    labels = targets.view(N, 1)
    # Broadcasting [N, 1] against [1, N] compares every pair of labels.
    is_pos = labels.eq(labels.t()).float()
    is_neg = labels.ne(labels.t()).float()

    mine = hard_example_mining if hard_mining else weighted_example_mining
    dist_ap, dist_an = mine(dist_mat, is_pos, is_neg)

    # Target of +1 everywhere: dist_an should rank above dist_ap.
    y = torch.ones_like(dist_an)

    if margin > 0:
        return F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)

    loss = F.soft_margin_loss(dist_an - dist_ap, y)
    # If the soft-margin loss evaluates to +inf, fall back to a margin
    # ranking loss with a fixed margin of 0.3.
    if loss == float('Inf'):
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
    return loss