-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathtriple_selector.py
75 lines (63 loc) · 2.7 KB
/
triple_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch
def get_all_triplets(dist_mat, pos_mask, neg_mask, is_inverted=False, margin=0.5, different_embedding=False):
    """Enumerate every (anchor, positive, negative) index triplet allowed by the masks.

    A triplet (a, p, n) is produced whenever ``pos_mask[a, p]`` and
    ``neg_mask[a, n]`` are both nonzero.

    Args:
        dist_mat: Unused here; accepted so every selector shares one signature.
        pos_mask: Anchor/positive mask indexed ``[anchor, positive]``.
        neg_mask: Anchor/negative mask indexed ``[anchor, negative]``.
        is_inverted: Unused (signature compatibility).
        margin: Unused (signature compatibility).
        different_embedding: When False, only the strict upper triangle of
            ``pos_mask`` is used, so each unordered positive pair appears once
            and self-pairs (the diagonal) are dropped.

    Returns:
        Tuple ``(a, p, n)`` of 1-D index tensors, in row-major order.
    """
    anchor_pos = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    # Broadcast to a cube: valid[a, p, n] = pos[a, p] * neg[a, n].
    valid = anchor_pos.unsqueeze(2) * neg_mask.unsqueeze(1)
    return torch.nonzero(valid, as_tuple=True)
def hardest_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pair every anchor/positive pair with the hardest negative of its anchor.

    "Hardest" means:
      * ``is_inverted=False``: the negative with the smallest ``dist_mat`` value;
      * ``is_inverted=True``: the negative with the largest ``dist_mat`` value
        (presumably ``dist_mat`` holds similarities in that mode — confirm with callers).

    Args:
        dist_mat: Score matrix indexed ``[anchor, candidate]``.
        pos_mask: Anchor/positive mask indexed ``[anchor, positive]``.
        neg_mask: Anchor/negative mask indexed ``[anchor, negative]``; bool or
            0/1 numeric (converted with ``.bool()``).
        is_inverted: Selects min (False) or max (True) as "hardest".
        margin: Unused here; accepted for signature compatibility.
        different_embedding: When False, only the strict upper triangle of
            ``pos_mask`` is used (each unordered pair once, no self-pairs).

    Returns:
        ``(a, p, n)`` index tensors. ``n`` is ``None`` when there are no
        negatives at all, or when some selected anchor has no negative. This
        follows the same convention as the sampling selectors, and it avoids
        returning a filler index that is not a real negative.
    """
    if not different_embedding:
        pos_mask = torch.triu(pos_mask, 1)
    a, p = torch.where(pos_mask)
    valid_neg = neg_mask.bool()
    # An anchor without negatives would otherwise "select" a masked-out filler entry.
    if not valid_neg.any() or not valid_neg[a].any(dim=1).all():
        return a, p, None
    if is_inverted:
        # Fill non-negatives strictly below every real value so they can never be the
        # max, even when all real scores are negative (multiplying by the mask
        # would leave 0s that could win).
        dist_neg = dist_mat.masked_fill(~valid_neg, dist_mat.min() - 1.)
        _, n = torch.max(dist_neg, dim=1)
    else:
        # Fill non-negatives strictly above every real value so they can never be the min.
        dist_neg = dist_mat.masked_fill(~valid_neg, dist_mat.max() + 1.)
        _, n = torch.min(dist_neg, dim=1)
    n = n[a]
    return a, p, n
def random_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Sample one negative per anchor/positive pair, preferring margin violators.

    For each pair (a, p), the triplet loss of every negative n of anchor a is
    ``d[a, p] - d[a, n] + margin``. When ``is_inverted`` is True it is
    ``-d[a, p] + d[a, n] + margin`` instead. A negative is drawn uniformly with
    ``np.random.choice`` from those with loss > 0. If no negative has loss > 0,
    it is drawn from all of the anchor's negatives.

    Sampling uses NumPy's global RNG, so results are reproducible only via
    ``np.random.seed``. Each pair pulls its candidates to the CPU.

    Returns:
        ``(a, p, n)`` index tensors. ``n`` is ``None`` as soon as any selected
        anchor has no negatives.
    """
    anchor_pos = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    a, p = torch.where(anchor_pos)

    picks = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if candidates.numel() == 0:
            return a, p, None
        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        losses = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)
        violating = losses > 0
        # Restrict to violating negatives only when at least one exists.
        if violating.any():
            candidates = candidates[violating]
        picks.append(np.random.choice(candidates.cpu().numpy()))

    return a, p, torch.tensor(picks, dtype=a.dtype, device=a.device)
def semihard_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Sample one negative per anchor/positive pair, preferring semi-hard negatives.

    The per-negative loss is ``d[a, p] - d[a, n] + margin``, or
    ``-d[a, p] + d[a, n] + margin`` when ``is_inverted`` is True. A negative is
    semi-hard when ``0 < loss < margin``: it violates the margin but is not
    harder than the positive. One negative is drawn uniformly with
    ``np.random.choice`` from the semi-hard ones. If none exist, it is drawn
    from all of the anchor's negatives.

    Sampling uses NumPy's global RNG, so results are reproducible only via
    ``np.random.seed``.

    Returns:
        ``(a, p, n)`` index tensors. ``n`` is ``None`` as soon as any selected
        anchor has no negatives.
    """
    anchor_pos = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    a, p = torch.where(anchor_pos)

    picks = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if candidates.numel() == 0:
            return a, p, None
        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        losses = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)
        in_window = (losses > 0) & (losses < margin)
        # Fall back to every negative of this anchor when the window is empty.
        if in_window.any():
            candidates = candidates[in_window]
        picks.append(np.random.choice(candidates.cpu().numpy()))

    return a, p, torch.tensor(picks, dtype=a.dtype, device=a.device)