Commit 83c1bd6

Author: Sara Elkerdawy
Commit message: FTWT demo (initial commit, 0 parents)


49 files changed: +3306 −0 lines

README.md

+19
# Fire Together Wire Together

Sample training code for CIFAR for dynamic pruning with a self-supervised mask.

## Environment
```
virtualenv .envpy36 -p python3.6  # Initialize environment
source .envpy36/bin/activate
pip install -r req.txt  # Install dependencies
```

## Train baseline
```
sh job_baseline.sh  # You can change the model at line 5
```

## Train dynamic
```
sh job_dynamic.sh  # You can change the model at line 5 and the threshold at line 40
```
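Both job scripts resume and evaluate from a checkpoint named `model_best.pth.tar` inside their checkpoint directory. The training driver `cifar.py` is not rendered in this commit view, so the snippet below is only a hedged sketch for inspecting such a file; the exact dictionary layout is defined in `cifar.py` and is an assumption here.

```python
# Hypothetical inspection of a saved checkpoint; the path follows job_dynamic.sh's
# pretrained() convention, and the dict layout is an assumption to be checked.
import torch

ckpt = torch.load('pretrained/cifar10/mobilenetv1/model_best.pth.tar', map_location='cpu')
print(ckpt.keys() if isinstance(ckpt, dict) else type(ckpt))
```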

cifar.py

+607
Large diffs are not rendered by default.

job_baseline.sh

+54
```bash
#!/bin/bash

datasetdir='./data'
dataset=cifar10
model=mobilenetv1 #vgg16_bn #resnet56 #mobilenetv2
mlr=1e-1
task=train
wd=5e-4


scratch(){
    initpath='None'
    init='scratch'
}

train(){
    task=train
    epochs=200
    lr=0.1
    extra='--tb'
    schedule='81 122 151'
    lr_scheduler_b='step'
    #lr_scheduler_b='cosine' #for MbnetV2
}

evaluate(){
    task=evaluate
}

scratch
train
#evaluate

bs=128
extra='--baseline'
echo $initpath
chkpnt='pretrained/'$dataset'/'$model'/'

if [ $task != evaluate ]
then
    python cifar.py -a $model --dataset $dataset -p $datasetdir \
        --gpu-id 0,1,2,3 \
        --checkpoint $chkpnt --init $initpath \
        --epochs $epochs --lr $lr --mlr $mlr --wd $wd \
        --train-batch $bs --test-batch $bs \
        --schedule $schedule --lr_scheduler_b $lr_scheduler_b \
        $extra
else
    modelbest=$chkpnt'/model_best.pth.tar'
    python cifar.py -a $model --dataset $dataset -p $datasetdir --checkpoint $chkpnt \
        --evaluate --test-batch 100 \
        --init $initpath --resume $modelbest --tb \
        $extra
fi
```
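The `--schedule '81 122 151'` and `--lr_scheduler_b step` flags request a step learning-rate decay at those epochs. Since `cifar.py` is not rendered in this commit view, the sketch below only illustrates what such a schedule typically looks like in PyTorch; the decay factor (0.1) and the SGD momentum are assumptions, not values taken from this commit.

```python
# Hypothetical sketch of the 'step' LR schedule requested by job_baseline.sh.
# gamma=0.1 and momentum=0.9 are assumed; the real values live in cifar.py.
import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,       # lr=0.1 as in the script
                            momentum=0.9, weight_decay=5e-4)  # wd=5e-4 as in the script
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[81, 122, 151], gamma=0.1)

for epoch in range(200):  # epochs=200 as in the script
    # ... train one epoch ...
    scheduler.step()
# lr: 0.1 until epoch 81, then 0.01, then 0.001 at 122, then 0.0001 at 151
```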

job_dynamic.sh

+63
```bash
#!/bin/bash

datasetdir='./data'
dataset=cifar10
model=mobilenetv1 #vgg16_bn #resnet56
mlr=1e-1
task=train
wd=5e-4

pretrained(){
    initpath='pretrained/'$dataset'/'$model'/model_best.pth.tar'
    init='pretrained'
}
trainwithpred(){
    task=train
    epochs=200
    lr=1e-2
    extra='--tb'
    schedule='81 122 151'
    lr_scheduler_b='step'
    #lr_scheduler_b='cosine' #for MbnetV2
}
finetunewithpred(){
    task=finetune
    epochs=50
    lr=1e-3
    extra='--tb'
    schedule='29 39 49'
    lr_scheduler_b='step'
}
evaluate(){
    task=evaluate
}

pretrained
trainwithpred

bs=128
mode='decoupled' #Or joint
gttype='mass' #Or uniform
mthresh=1.0 #Keep top {mthresh}% of heatmap mass in case of gttype=mass, or top {mthresh}% filters (uniform pruning) in case of gttype=uniform

echo $initpath
chkpnt='dynamic-ftwt/'$dataset'/'$task'_'$model'_lr'$lr'_mthresh'$mthresh'_'$mode'_'$gttype'_'$lr_scheduler_b

if [ $task != evaluate ] #Train or finetune
then
    python cifar.py -a $model --dataset $dataset -p $datasetdir \
        --gpu-id 0,1,2,3 \
        --checkpoint $chkpnt --init $initpath \
        --epochs $epochs --lr $lr --mlr $mlr --wd $wd \
        --train-batch $bs --test-batch $bs \
        --schedule $schedule --lr_scheduler_b $lr_scheduler_b \
        --mthresh $mthresh --mode $mode --gt-type $gttype \
        $extra
else
    modelbest=$chkpnt'/model_best.pth.tar'
    python cifar.py -a $model --dataset $dataset -p $datasetdir --checkpoint $chkpnt \
        --evaluate --test-batch 100 \
        --init $initpath --resume $modelbest --tb \
        $extra
fi
```
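The `--gt-type`/`--mthresh` pair controls how many filters the self-supervised mask keeps per layer. The actual logic lives in code not shown in this commit view (e.g. `utils/layers.py`), so the snippet below is only an illustration of the two settings described in the script comment above, under the assumption that `mthresh` is the fraction of heatmap mass (or of the filter count) to keep; `keep_mask` is a hypothetical helper, not the repo's implementation.

```python
# Illustration only, not the repo's code: how 'mass' vs 'uniform' keep rules could differ.
import math
import torch

def keep_mask(heatmap: torch.Tensor, mthresh: float, gt_type: str) -> torch.Tensor:
    """Return a boolean keep-mask over a layer's filters, given a per-filter heatmap [C]."""
    C = heatmap.numel()
    vals, order = torch.sort(heatmap, descending=True)
    if gt_type == 'uniform':
        # Keep the top mthresh fraction of filters outright.
        k = max(1, math.ceil(mthresh * C))
    else:  # 'mass'
        # Keep the smallest prefix of filters covering mthresh of the total heatmap mass.
        csum = torch.cumsum(vals, dim=0)
        k = int((csum < mthresh * heatmap.sum()).sum().item()) + 1
    k = min(k, C)
    mask = torch.zeros(C, dtype=torch.bool)
    mask[order[:k]] = True
    return mask

# With mthresh=1.0 (the value set above) both modes keep every filter.
print(keep_mask(torch.rand(8), mthresh=1.0, gt_type='mass').sum().item())  # 8
```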

models/__init__.py

Whitespace-only changes.

models/cifar/__init__.py

+67
```python
from __future__ import absolute_import

"""The models subpackage contains definitions for the following model
architectures for CIFAR10/CIFAR100:

-  `AlexNet`_
-  `VGG`_
-  `ResNet`_
-  `SqueezeNet`_
-  `DenseNet`_

You can construct a model with random weights by calling its constructor:

.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18()
    alexnet = models.alexnet()
    squeezenet = models.squeezenet1_0()
    densenet = models.densenet_161()

We provide pre-trained models for the ResNet variants and AlexNet, using the
PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
``pretrained=True``:

.. code:: python

    import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)

ImageNet 1-crop error rates (224x224)

======================== ============= =============
Network                  Top-1 error   Top-5 error
======================== ============= =============
ResNet-18                30.24         10.92
ResNet-34                26.70         8.58
ResNet-50                23.85         7.13
ResNet-101               22.63         6.44
ResNet-152               21.69         5.94
Inception v3             22.55         6.44
AlexNet                  43.45         20.91
VGG-11                   30.98         11.37
VGG-13                   30.07         10.75
VGG-16                   28.41         9.62
VGG-19                   27.62         9.12
SqueezeNet 1.0           41.90         19.58
SqueezeNet 1.1           41.81         19.38
Densenet-121             25.35         7.83
Densenet-169             24.00         7.00
Densenet-201             22.80         6.43
Densenet-161             22.35         6.20
======================== ============= =============

.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
"""

from .vgg import *
from .resnet import *
from .mobilenet import *
from .mobilenetv2 import *
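```

The docstring above is inherited from torchvision and describes ImageNet models; in this repo the subpackage actually exposes the CIFAR constructors imported at the bottom of the file. A minimal usage sketch, assuming the repo root is on `PYTHONPATH` and its `utils` package is importable:

```python
# Sketch only: mobilenetv1's signature is shown in mobilenet.py below; the vgg, resnet
# and mobilenetv2 modules are part of this commit but not reproduced in this view.
import models.cifar as models

net = models.mobilenetv1(num_classes=10)        # CIFAR-10 head
small = models.mobilenetv1_50(num_classes=100)  # depth_multiplier=0.5 variant, CIFAR-100 head
print(sum(p.numel() for p in net.parameters()), 'parameters')
```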

models/cifar/mobilenet.py

+174
```python
import torch.nn as nn
import torch.nn.functional as F
import pdb
from utils.layers import ConvWithMask

__all__ = ['mobilenetv1', 'mobilenetv1_75', 'mobilenetv1_50']

pretrained_TF = False
relu_fn = None
import torch

class TFSamePad(nn.Module):
    def __init__(self, kernel_size, stride):
        super(TFSamePad, self).__init__()
        self.stride = stride
        if kernel_size != 3:
            raise NotImplementedError('only support kernel_size == 3')

    def forward(self, x):
        if self.stride == 2:
            return F.pad(x, (0, 1, 0, 1))
        elif self.stride == 1:
            return F.pad(x, (1, 1, 1, 1))
        else:
            raise NotImplementedError('only support stride == 1 or 2')

def relu(relu6):
    if relu6:
        return nn.ReLU6(inplace=True)
    else:
        return nn.ReLU(inplace=True)

class MobileNet(nn.Module):
    def __init__(self, num_classes=1000, dropout=False, from_TF=False, depth_multiplier=1.0):
        super(MobileNet, self).__init__()

        self.nmasked_layers = 0
        self.baseline = True  # For profile in thop
        self.d = depth_multiplier

        global pretrained_TF, relu_fn, cfg
        pretrained_TF = from_TF
        relu_fn = relu(from_TF)

        if num_classes == 1000:
            self.cfg = cfg['imagenet']
            self.pool_k = 7
        else:
            self.cfg = cfg['cifar']
            self.pool_k = 2

        self.model = self._make_layers(self.cfg, self.d)
        self.pool = nn.AvgPool2d(self.pool_k)
        self.dropout = nn.Dropout(0.2) if dropout else nn.Identity()
        last_layer = int(self.d * 1024)
        self.fc = nn.Linear(last_layer, num_classes)

    def _make_layers(self, cfg, d):

        conv_bn = self.conv_bn
        conv_dw = self.conv_dw

        layers = []
        in_planes = 3
        for i, x in enumerate(self.cfg):
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            if i == 0:  # First layer is normal conv
                layers.append(conv_bn(in_planes, out_planes, stride, d))
            else:
                layers.append(conv_dw(in_planes, out_planes, stride, d))

            in_planes = out_planes

        return nn.Sequential(*layers)

    @staticmethod
    def conv_bn(inp, oup, stride, d):

        oup = int(d * oup)
        layers = []
        pad = 1

        # PyTorch BN defaults
        eps = 1e-5
        momentum = 0.1
        if pretrained_TF:
            layers += [TFSamePad(3, stride)]
            pad = 0
            # TF BN defaults
            eps = 1e-3
            momentum = 1e-3

        layers += [
            nn.Conv2d(inp, oup, 3, stride, pad, bias=False),
            nn.BatchNorm2d(oup, eps=eps, momentum=momentum),
            relu_fn]
        return nn.Sequential(*layers)

    @staticmethod
    def conv_dw(inp, oup, stride, d):
        inp = int(d * inp)
        oup = int(d * oup)
        layers = []
        pad = 1

        # PyTorch BN defaults
        eps = 1e-5
        momentum = 0.1
        if pretrained_TF:
            layers += [TFSamePad(3, stride)]
            pad = 0
            # TF BN defaults
            eps = 1e-3
            momentum = 1e-3

        layers += [
            nn.Conv2d(inp, inp, 3, stride, pad, groups=inp, bias=False),
            nn.BatchNorm2d(inp, eps=eps, momentum=momentum),
            relu_fn,

            nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup, eps=eps, momentum=momentum),
            relu_fn]

        return nn.Sequential(*layers)

    def forward_baseline(self, x):
        x = self.model(x)
        x = self.dropout(self.pool(x))
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    def forward_mask(self, x):
        all_mask_logits, all_gt_mask = [], []
        FLOPs = 0
        bs = x.shape[0]
        prev = torch.ones(bs) * x.shape[1]

        for i, m in enumerate(self.model.children()):
            if isinstance(m, ConvWithMask):
                x, all_mask_logits, prev, all_gt_mask, cur_flops = m(x, all_mask_logits, all_gt_mask, prev)
                FLOPs += cur_flops
            else:
                x = m(x)

        x = self.dropout(self.pool(x))
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        FLOPs = FLOPs.to(x.device)

        return x, all_mask_logits, all_gt_mask, FLOPs

    def forward(self, x):
        if self.baseline:
            return self.forward_baseline(x)
        else:
            return self.forward_mask(x)

cfg = {
    'cifar': [(32,1), 64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024],
    'imagenet': [(32,2), 64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024],
}


def mobilenetv1(num_classes, dropout=False, from_TF=False):
    return MobileNet(num_classes, dropout, from_TF, depth_multiplier=1.)

def mobilenetv1_75(num_classes, dropout=False, from_TF=False):
    return MobileNet(num_classes, dropout, from_TF, depth_multiplier=0.75)

def mobilenetv1_50(num_classes, dropout=False, from_TF=False):
    return MobileNet(num_classes, dropout, from_TF, depth_multiplier=0.50)
```
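As a quick sanity check of the baseline path defined above (a sketch; it assumes `utils.layers.ConvWithMask` is importable, since this module imports it at load time):

```python
# Minimal shape check of the baseline forward path on a CIFAR-sized input.
import torch
from models.cifar.mobilenet import mobilenetv1

net = mobilenetv1(num_classes=10)   # num_classes != 1000 -> CIFAR cfg, AvgPool2d(2)
x = torch.randn(2, 3, 32, 32)       # CIFAR-sized input
out = net(x)                        # baseline=True by default -> forward_baseline
# Four stride-2 stages: 32 -> 16 -> 8 -> 4 -> 2, then AvgPool2d(2) -> 1x1, then fc.
print(out.shape)                    # torch.Size([2, 10])
```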
