-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathOPT.py
83 lines (61 loc) · 3.25 KB
/
OPT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from mip import Model, xsum, minimize, maximize, BINARY
def IPconstraint_and_solve(S, K=50, mean_row_min=1, mean_col_min=1, union=True, row_sum=10, col_sum=10):
    """Pick a fixed number of low-cost columns for every row of ``S`` with a 0/1 integer program.

    Candidate (row, column) pairs come from each row's ``K`` smallest entries
    combined (see ``union``) with the entries for which the row is among the
    column's ``K`` smallest.  One binary variable is created per candidate,
    the total cost ``sum(c[i][p] * x[i][p])`` is minimised and python-mip
    solves the model.

    Args:
        S: 2-D cost tensor; lower values are preferred by the objective.  The
            integer program runs on a detached CPU copy, while ``topic_att`` is
            gathered from the original tensor (so its autograd graph is kept).
        K: Maximum number of smallest entries taken per row and per column
            (clamped to the size of the respective axis).
        mean_row_min, mean_col_min: Unused; kept for call compatibility.
        union: If True, a row's candidates are the union of its row-wise and
            column-wise top-K sets; otherwise their intersection.
        row_sum: Number of variables each real row must switch on
            (previously hard-coded to 10).
        col_sum: Number of variables each real column constraint must switch
            on, counting the helper's trailing extra row (previously 10).

    Returns:
        ``(topic_att, topic_ids)`` where ``topic_ids`` is a CPU ``long`` tensor
        whose row ``i`` holds the column indices chosen for row ``i`` and
        ``topic_att[i, p] == S[i, topic_ids[i, p]]``.

    Raises:
        RuntimeError: if the solver finds no solution.
        IndexError: if a real row is assigned the dummy "empty" column, which
            has no entry in ``S`` to gather.
    """
    # ``S`` is only rebound below, never mutated, so no defensive clone is
    # needed; indexing the original keeps gradients flowing.
    S_raw = S
    S = S.detach().cpu()
    n_rows, n_cols = S.size(0), S.size(1)

    # ``topk`` raises when k exceeds the reduced axis, so clamp each call to its
    # own axis (the column-wise call reduces over rows, i.e. ``n_rows``).
    S_topk_row, top_id_row = S.topk(min(K, n_cols), dim=1, largest=False)
    S_topk_col, top_id_col = S.topk(min(K, n_rows), dim=0, largest=False)
    # Zero costs used by the helper for the dummy column and its extra row.
    S_empty_col = torch.zeros(n_rows).unsqueeze(1)
    S_empty_row = torch.zeros(n_cols).unsqueeze(0)
    c, top_id, col_constraint = quick_extract_col_constraint_list(
        S, S_topk_row, top_id_row, S_empty_row, S_topk_col, top_id_col, S_empty_col, union
    )

    IPmodel = Model()
    # One binary variable per candidate: x[i][p] stands for column top_id[i][p].
    x = [[IPmodel.add_var(var_type=BINARY) for _ in row_costs] for row_costs in c]
    IPmodel.objective = minimize(
        xsum(cost * var for row_costs, row_vars in zip(c, x) for cost, var in zip(row_costs, row_vars))
    )
    # The last row of ``x`` is the helper's extra row; only column constraints bind it.
    for row_vars in x[:-1]:
        IPmodel += xsum(row_vars) == row_sum
    for entries in col_constraint:
        IPmodel += xsum(x[row][pos] for row, pos in entries) == col_sum
    IPmodel.optimize()
    if IPmodel.num_solutions == 0:
        # Without a solution every ``var.x`` is None; fail loudly here instead of
        # with a TypeError while reading the assignment.
        raise RuntimeError("integer program has no feasible solution for the given constraints")

    # Solver values are floats within a tolerance of 0/1; round rather than test ``> 0``.
    selected = [
        [col for col, var in zip(cols, row_vars) if var.x >= 0.5]
        for cols, row_vars in zip(top_id[:-1], x[:-1])
    ]
    # Index ``n_cols`` is the helper's dummy column; it has no counterpart in S.
    # NOTE(review): confirm whether such rows should instead be dropped or padded.
    for i, cols in enumerate(selected):
        if n_cols in cols:
            raise IndexError(
                f"row {i} was assigned the dummy empty column (index {n_cols}), which has no entry in S"
            )

    topic_ids = torch.tensor(selected, device=S.device, dtype=torch.long)
    topic_att = torch.gather(S_raw, 1, topic_ids.to(S_raw.device))
    return topic_att, topic_ids
def quick_extract_col_constraint_list(S, S_topk_row, top_id_row, S_empty_row, S_topk_col, top_id_col, S_empty_col, union=True):
    """Build candidate columns, column-constraint indices and costs for the integer program.

    Args:
        S: 2-D cost tensor (rows x columns).
        S_topk_row, S_topk_col: Unused; accepted for call compatibility.
        top_id_row: Per-row tensor of column indices (row ``i`` lists the
            columns nearest to row ``i``).
        S_empty_row: Tensor whose rows are appended to ``c`` as the costs of the
            trailing extra row.
        top_id_col: Tensor of row indices, one column of indices per column of
            ``S`` (column ``j`` lists the rows nearest to column ``j``).
        S_empty_col: Tensor whose ``[i, 0]`` entry is the cost of row ``i``'s
            dummy column.
        union: If True, combine the row-wise and column-wise candidates with a
            union; otherwise with an intersection.

    Returns:
        ``(c, top_id, quick_col_constraint)``:
        - ``top_id``: for each row of ``S``, its sorted candidate columns
          followed by the dummy index ``S.size(1)``; then one extra row
          ``[0, ..., S.size(1) - 1]``.
        - ``c``: costs aligned with ``top_id``; for real rows ``S[i, j]`` per
          real column and ``S_empty_col[i, 0]`` for the dummy, then the rows of
          ``S_empty_row``.
        - ``quick_col_constraint``: for each real column ``j``, ``[row, pos]``
          pairs (``top_id[row][pos] == j``) in increasing ``row`` order.
    """
    n_rows, n_cols = S.size(0), S.size(1)
    dummy_col = n_cols

    # Candidate masks built in one vectorized pass each, instead of scanning
    # ``top_id_col`` once per row.
    near_by_row = torch.zeros(n_rows, n_cols, dtype=torch.bool)
    near_by_row[torch.arange(n_rows).unsqueeze(1), top_id_row] = True
    near_by_col = torch.zeros(n_rows, n_cols, dtype=torch.bool)
    near_by_col[top_id_col, torch.arange(n_cols).unsqueeze(0)] = True
    candidates = (near_by_row | near_by_col) if union else (near_by_row & near_by_col)

    top_id = []
    c = []
    for i in range(n_rows):
        # nonzero() yields ascending column indices, i.e. an already sorted list.
        cols = torch.nonzero(candidates[i]).flatten()
        top_id.append(cols.tolist() + [dummy_col])
        c.append(S[i, cols].tolist() + [S_empty_col[i, 0].item()])
    top_id.append(list(range(n_cols)))
    c += S_empty_row.tolist()

    # Rows are visited in increasing order, so each column's list is row-sorted.
    quick_col_constraint = [[] for _ in range(n_cols)]
    for row, cols in enumerate(top_id):
        for pos, col in enumerate(cols):
            if col < n_cols:
                quick_col_constraint[col].append([row, pos])
    return c, top_id, quick_col_constraint