-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
104 lines (72 loc) · 2.93 KB
/
Copy pathmain.py
File metadata and controls
104 lines (72 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
from tqdm import tqdm
import csv
import os
import torch
from torch.utils.data import dataset, dataloader
from torchvision import transforms
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from data import Market1501
from network import MGN
from utils.extract_feature import extract_feature
from opt import opt
import shutil
# Pin this process to GPU 0. This runs at import time, i.e. before main()
# moves the model to "cuda".
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def main():
    """Embed images with MGN, cluster them with k-means, and sort them on disk.

    Steps, all driven by the project-level ``opt`` object:

    1. Load the MGN weights from ``opt.weight`` and move the model to CUDA.
    2. For every entry of ``opt.data_path``, extract one embedding per image
       and remember the image paths (``testset.imgs``).
    3. Optionally save the raw embeddings and paths via ``np.save`` when
       ``opt.embeddings_path`` is not ``None``.
    4. Standardize the embeddings and pick the number of clusters ``k`` by
       silhouette score: first a coarse search over powers of two (2..64),
       then a fine search over ``[best // 2, best * 2]``.
    5. Write ``path,cluster_id`` rows to ``opt.output_file`` and copy every
       image into a directory named after its cluster id (created in the
       current working directory).

    Raises:
        ValueError: if ``opt.data_path`` yields no datasets, or if there are
            fewer than three images (the silhouette score needs
            ``2 <= k <= n_samples - 1``).
    """
    model = MGN()
    model.load_state_dict(torch.load(opt.weight))
    cuda_model = model.to("cuda")
    cuda_model.eval()

    # The preprocessing pipeline is the same for every dataset, so build it
    # once instead of once per loop iteration.
    # NOTE(review): interpolation=3 is presumably the PIL code for bicubic
    # resampling -- confirm against the installed torchvision version.
    test_transform = transforms.Compose([
        transforms.Resize((384, 128), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    feature_chunks = []
    name_chunks = []
    # `data_dir` rather than `dataset`: the latter would shadow the
    # `torch.utils.data.dataset` module imported at file level.
    for data_dir in opt.data_path:
        testset = Market1501(test_transform, 'test', data_dir)
        data_loader = dataloader.DataLoader(testset, batch_size=16, num_workers=8, pin_memory=True)
        print("generating embeddings for", data_dir)
        feature_chunks.append(extract_feature(cuda_model, tqdm(data_loader)).numpy())
        name_chunks.append(np.array(testset.imgs))

    if not feature_chunks:
        raise ValueError("opt.data_path is empty; there is nothing to cluster")

    embeddings = np.concatenate(feature_chunks)
    file_names = np.concatenate(name_chunks)

    if opt.embeddings_path is not None:
        # Saved before standardization, so the file holds the raw embeddings.
        np.save(opt.embeddings_path, embeddings)
        np.save(opt.embeddings_path + "_names", file_names)

    print("performing clustering")
    print("finding number of clusters")
    embeddings = StandardScaler().fit_transform(embeddings)

    # silhouette_score is only defined for 2 <= k <= n_samples - 1, and
    # KMeans cannot fit more clusters than samples.
    max_k = embeddings.shape[0] - 1
    if max_k < 2:
        raise ValueError(
            "need at least 3 images to choose a number of clusters, got "
            + str(embeddings.shape[0])
        )

    clusterings = {}
    scores = {}

    def get_score(x):
        """Return the silhouette score of a k-means fit with ``x`` clusters.

        Fits are memoized in ``clusterings`` / ``scores`` so the coarse and
        fine searches never refit the same ``x``. Values of ``x`` outside the
        valid range score ``-inf`` so they can never be selected.
        """
        x = int(x)
        if x < 2 or x > max_k:
            return float("-inf")
        if x not in clusterings:
            clusterings[x] = KMeans(x).fit(embeddings)
            scores[x] = silhouette_score(embeddings, clusterings[x].labels_)
        return scores[x]

    # Coarse search over powers of two, limited to what the data supports.
    # max() keeps the first best candidate, matching np.argmax tie-breaking.
    candidates = [2 ** x for x in range(1, 7) if 2 ** x <= max_k]
    candidate = max(candidates, key=get_score)

    # Fine search around the coarse winner, clamped to the valid range so the
    # selected k always has an entry in `scores` and `clusterings`.
    possible_ks = list(range(max(2, candidate // 2), min(candidate * 2, max_k) + 1))
    k = max(tqdm(possible_ks), key=get_score)
    print("found best k", k, "with silhouette score of", scores[k])

    clustering = clusterings[k]
    # newline="" is required by the csv module; without it every row is
    # followed by a blank line on Windows.
    with open(opt.output_file, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        for path, cluster_id in zip(file_names, clustering.labels_):
            cluster_dir = str(cluster_id)
            csv_writer.writerow([path, cluster_dir])
            os.makedirs(cluster_dir, exist_ok=True)
            # Accept both "\\" and "/" separators so the file name is
            # extracted correctly on Windows and on POSIX.
            base_name = path.replace("\\", "/").rsplit("/", 1)[-1]
            # NOTE(review): images from different datasets that share a file
            # name and a cluster overwrite each other here.
            shutil.copyfile(path, os.path.join(cluster_dir, base_name))
# Run the pipeline only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()