forked from cfchen-duke/ProtoPNet
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy path prune.py
109 lines (90 loc) · 5.54 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import os
import shutil
from collections import Counter
import numpy as np
import torch
from helpers import makedir
import find_nearest
def prune_prototypes(dataset,
                     prototype_network_parallel,
                     k,
                     prune_threshold,
                     preprocess_input_function,
                     original_model_dir,
                     epoch_number,
                     log=print,
                     copy_prototype_imgs=True):
    """Prune prototypes with too few same-class patches among their k nearest.

    For every prototype, the k nearest training patches are found with
    ``find_nearest.find_k_nearest_patches_to_prototypes``. The prototype's own
    class is the argmax of its row in ``prototype_class_identity``. A prototype
    is pruned when fewer than ``prune_threshold`` of its k nearest patches
    carry that class label.

    Side effects:
      * The nearest-patch search saves images under ``<original_model_dir>/img``
        (it is called with ``full_save=True``).
      * ``prune_info.npy`` is written to
        ``<original_model_dir>/pruned_prototypes_epoch{E}_k{k}_pt{T}/``.
      * The wrapped network is pruned in place via
        ``.module.prune_prototypes``.
      * The sorted indices of the kept prototypes are written as JSON to
        ``<original_model_dir>/prototypes_to_keep.json``.
      * If ``copy_prototype_imgs`` is true, the kept prototypes' per-prototype
        files are copied from ``img/epoch-{E}`` into the pruned directory and
        renumbered consecutively (0, 1, 2, ...). The ``bb{E}.npy`` and
        ``bb-receptive_field{E}.npy`` arrays are filtered to the kept rows.

    This function does not save the pruned model itself.

    Args:
        dataset: Passed through to the nearest-patch search.
        prototype_network_parallel: Wrapper exposing the network as
            ``.module``. Presumably ``torch.nn.DataParallel``; confirm against
            the caller. ``.module`` must provide ``num_prototypes``,
            ``prototype_class_identity`` and ``prune_prototypes``.
        k: Number of nearest training patches inspected per prototype.
        prune_threshold: Minimum number of same-class patches among the k
            nearest required to keep a prototype.
        preprocess_input_function: Passed through to the nearest-patch search.
        original_model_dir: Directory holding the model's ``img`` folder. It
            also receives the outputs listed above.
        epoch_number: Epoch whose prototype images are used. It is also part
            of the output directory and file names.
        log: Logging callable (defaults to ``print``).
        copy_prototype_imgs: Whether to copy the kept prototypes' artifacts.

    Returns:
        A numpy int64 array of shape ``(num_pruned, 2)``. Each row is
        ``[prototype_index, class_index]`` for one pruned prototype. The shape
        is ``(0, 2)`` when nothing is pruned.
    """
    # Output directory shared by prune_info.npy and the copied images.
    pruned_dir = os.path.join(
        original_model_dir,
        'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch_number, k, prune_threshold))

    ### run global analysis
    # NOTE(review): entry j is presumably the class labels of the k nearest
    # training patches to prototype j; confirm against find_nearest.
    nearest_train_patch_class_ids = \
        find_nearest.find_k_nearest_patches_to_prototypes(
            dataset=dataset,
            prototype_network_parallel=prototype_network_parallel,
            k=k,
            preprocess_input_function=preprocess_input_function,
            full_save=True,
            root_dir_for_saving_images=os.path.join(original_model_dir, 'img'),
            log=log)

    ### find prototypes to prune
    network = prototype_network_parallel.module
    original_num_prototypes = network.num_prototypes

    prototypes_to_prune = []
    class_of_prototypes_to_prune = []
    for j in range(original_num_prototypes):
        class_j = torch.argmax(network.prototype_class_identity[j]).item()
        # Counter returns 0 for labels that never occur. A prototype with no
        # same-class neighbours is therefore pruned for any positive threshold.
        if Counter(nearest_train_patch_class_ids[j])[class_j] < prune_threshold:
            prototypes_to_prune.append(j)
            class_of_prototypes_to_prune.append(class_j)

    log('k = {}, prune_threshold = {}'.format(k, prune_threshold))
    log('{} prototypes will be pruned'.format(len(prototypes_to_prune)))

    ### bookkeeping of prototypes to be pruned
    # Reuse the classes found above instead of re-indexing the identity matrix
    # with a (possibly empty) index list. The explicit dtype and reshape keep
    # the result an integer (N, 2) array even when N == 0; np.array([]) alone
    # would be float64.
    prune_info = np.array(
        list(zip(prototypes_to_prune, class_of_prototypes_to_prune)),
        dtype=np.int64).reshape(-1, 2)
    makedir(pruned_dir)
    np.save(os.path.join(pruned_dir, 'prune_info.npy'), prune_info)

    ### prune prototypes (in place on the wrapped network)
    network.prune_prototypes(prototypes_to_prune)

    prototypes_to_keep = sorted(set(range(original_num_prototypes)) - set(prototypes_to_prune))
    log(f'Prototypes to keep: {prototypes_to_keep}')
    with open(os.path.join(original_model_dir, 'prototypes_to_keep.json'), 'w') as fp:
        json.dump(prototypes_to_keep, fp)

    if copy_prototype_imgs:
        epoch_dir_name = 'epoch-%d' % epoch_number
        original_img_dir = os.path.join(original_model_dir, 'img', epoch_dir_name)
        dst_img_dir = os.path.join(pruned_dir, 'img', epoch_dir_name)
        makedir(dst_img_dir)

        # Per-prototype artifacts. Kept prototypes are renumbered 0..n_kept-1
        # so the file indices match the pruned network's prototype order.
        per_prototype_files = ('prototype-img%d.png',
                               'prototype-img-original%d.png',
                               'prototype-img-original_with_self_act%d.png',
                               'prototype-self-act%d.npy')
        for new_idx, old_idx in enumerate(prototypes_to_keep):
            for template in per_prototype_files:
                shutil.copyfile(src=os.path.join(original_img_dir, template % old_idx),
                                dst=os.path.join(dst_img_dir, template % new_idx))

        # These per-epoch arrays are indexed by prototype along axis 0. Keep
        # only the surviving rows, in the same sorted order as the renumbering.
        for template in ('bb%d.npy', 'bb-receptive_field%d.npy'):
            per_prototype_array = np.load(os.path.join(original_img_dir, template % epoch_number))
            np.save(os.path.join(dst_img_dir, template % epoch_number),
                    per_prototype_array[prototypes_to_keep])

    return prune_info