Skip to content

Commit

Permalink
add icnet
Browse files Browse the repository at this point in the history
  • Loading branch information
hellochick committed Jan 26, 2018
1 parent 58f836e commit c5931f4
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 6 deletions.
11 changes: 7 additions & 4 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import numpy as np
from scipy import misc

from model import FCN8s, PSPNet50, ENet
from model import FCN8s, PSPNet50, ENet, ICNet

save_dir = './output/'
model_path = {'pspnet': './model/pspnet50.npy',
'fcn': './model/fcn.npy',
'enet': './model/cityscapes/enet.ckpt'}
'enet': './model/cityscapes/enet.ckpt',
'icnet': './model/cityscapes/icnet.npy'}

def get_arguments():
parser = argparse.ArgumentParser(description="Reproduced PSPNet")
Expand All @@ -25,7 +26,7 @@ def get_arguments():
help="Path to save output.")
parser.add_argument("--model", type=str, default='',
help="pspnet or fcn",
choices=['pspnet', 'fcn', 'enet'],
choices=['pspnet', 'fcn', 'enet', 'icnet'],
required=True)

return parser.parse_args()
Expand All @@ -39,7 +40,9 @@ def main():
model = FCN8s()
elif args.model == 'enet':
model = ENet()

elif args.model == 'icnet':
model = ICNet()

model.read_input(args.img_path)

# Init tf Session
Expand Down
246 changes: 245 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,4 +893,248 @@ def get_variable_weight_decay(self, name, shape, initializer, loss_category,
weight_decay = self.wd*tf.nn.l2_loss(variable)
tf.add_to_collection(loss_category, weight_decay)

return variable
return variabl

class ICNet(Network):
    """ICNet semantic-segmentation graph, set up here for single-image inference.

    The constructor creates an image placeholder, runs it through
    ``preprocess(..., 'icnet')`` and hands the result to the ``Network`` base
    class as the ``'data'`` input. ``setup`` then declares every layer and
    stores the colour-decoded prediction in ``self.pred``.

    NOTE(review): the layer names below presumably key the pretrained
    weights loaded elsewhere -- do not rename them without checking the
    weight file.
    """

    def __init__(self, is_training=False, num_classes=19, input_size=(1024, 2048)):
        """Create the input placeholder and build the network.

        Args:
            is_training: forwarded unchanged to ``Network.__init__``.
            num_classes: number of output channels of the final classifier
                and number of labels passed to ``decode_labels``.
            input_size: two-element ``[height, width]`` the input image is
                resized to before entering the network.
        """
        # Copy into a fresh list: avoids sharing one mutable default object
        # between instances while still handing a list to ``preprocess``.
        self.input_size = list(input_size)

        # Single image without a batch axis; height and width are left open.
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, None, 3])
        self.img_tf, self.shape = preprocess(self.x, self.input_size, 'icnet')

        # presumably Network.__init__ invokes self.setup(...) -- TODO confirm
        super().__init__({'data': self.img_tf}, num_classes, is_training)

    def setup(self, is_training, num_classes):
        """Declare all ICNet layers and build ``self.pred``.

        Args:
            is_training: accepted for interface compatibility; not used here.
            num_classes: channel count of ``conv6_cls`` and label count used
                when decoding the prediction.
        """
        # ---- stem on the 'data_sub2' input (interp factor 0.5) ----
        (self.feed('data')
             .interp(factor=0.5, name='data_sub2')
             .conv(3, 3, 32, 2, 2, biased=True, padding='SAME', relu=True, name='conv1_1_3x3_s2')
             .conv(3, 3, 32, 1, 1, biased=True, padding='SAME', relu=True, name='conv1_2_3x3')
             .conv(3, 3, 64, 1, 1, biased=True, padding='SAME', relu=True, name='conv1_3_3x3')
             .max_pool(3, 3, 2, 2, name='pool1_3x3_s2')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv2_1_1x1_proj'))

        # ---- conv2_x: bottleneck units (1x1 reduce -> 3x3 -> 1x1 increase),
        #      each summed with its shortcut branch ----
        (self.feed('pool1_3x3_s2')
             .conv(1, 1, 32, 1, 1, biased=True, relu=True, name='conv2_1_1x1_reduce')
             .zero_padding(paddings=1, name='padding1')
             .conv(3, 3, 32, 1, 1, biased=True, relu=True, name='conv2_1_3x3')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv2_1_1x1_increase'))

        (self.feed('conv2_1_1x1_proj',
                   'conv2_1_1x1_increase')
             .add(name='conv2_1')
             .relu(name='conv2_1/relu')
             .conv(1, 1, 32, 1, 1, biased=True, relu=True, name='conv2_2_1x1_reduce')
             .zero_padding(paddings=1, name='padding2')
             .conv(3, 3, 32, 1, 1, biased=True, relu=True, name='conv2_2_3x3')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv2_2_1x1_increase'))

        (self.feed('conv2_1/relu',
                   'conv2_2_1x1_increase')
             .add(name='conv2_2')
             .relu(name='conv2_2/relu')
             .conv(1, 1, 32, 1, 1, biased=True, relu=True, name='conv2_3_1x1_reduce')
             .zero_padding(paddings=1, name='padding3')
             .conv(3, 3, 32, 1, 1, biased=True, relu=True, name='conv2_3_3x3')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv2_3_1x1_increase'))

        # ---- conv3_x: conv3_1 uses stride 2 on both the projection and the
        #      reduce conv; its output is additionally interpolated by 0.5
        #      ('conv3_1_sub4') before the remaining units ----
        (self.feed('conv2_2/relu',
                   'conv2_3_1x1_increase')
             .add(name='conv2_3')
             .relu(name='conv2_3/relu')
             .conv(1, 1, 256, 2, 2, biased=True, relu=False, name='conv3_1_1x1_proj'))

        (self.feed('conv2_3/relu')
             .conv(1, 1, 64, 2, 2, biased=True, relu=True, name='conv3_1_1x1_reduce')
             .zero_padding(paddings=1, name='padding4')
             .conv(3, 3, 64, 1, 1, biased=True, relu=True, name='conv3_1_3x3')
             .conv(1, 1, 256, 1, 1, biased=True, relu=False, name='conv3_1_1x1_increase'))

        (self.feed('conv3_1_1x1_proj',
                   'conv3_1_1x1_increase')
             .add(name='conv3_1')
             .relu(name='conv3_1/relu')
             .interp(factor=0.5, name='conv3_1_sub4')
             .conv(1, 1, 64, 1, 1, biased=True, relu=True, name='conv3_2_1x1_reduce')
             .zero_padding(paddings=1, name='padding5')
             .conv(3, 3, 64, 1, 1, biased=True, relu=True, name='conv3_2_3x3')
             .conv(1, 1, 256, 1, 1, biased=True, relu=False, name='conv3_2_1x1_increase'))

        (self.feed('conv3_1_sub4',
                   'conv3_2_1x1_increase')
             .add(name='conv3_2')
             .relu(name='conv3_2/relu')
             .conv(1, 1, 64, 1, 1, biased=True, relu=True, name='conv3_3_1x1_reduce')
             .zero_padding(paddings=1, name='padding6')
             .conv(3, 3, 64, 1, 1, biased=True, relu=True, name='conv3_3_3x3')
             .conv(1, 1, 256, 1, 1, biased=True, relu=False, name='conv3_3_1x1_increase'))

        (self.feed('conv3_2/relu',
                   'conv3_3_1x1_increase')
             .add(name='conv3_3')
             .relu(name='conv3_3/relu')
             .conv(1, 1, 64, 1, 1, biased=True, relu=True, name='conv3_4_1x1_reduce')
             .zero_padding(paddings=1, name='padding7')
             .conv(3, 3, 64, 1, 1, biased=True, relu=True, name='conv3_4_3x3')
             .conv(1, 1, 256, 1, 1, biased=True, relu=False, name='conv3_4_1x1_increase'))

        # ---- conv4_x: six bottleneck units whose 3x3 conv is an atrous_conv
        #      with argument 2 (presumably the dilation rate, matched by the
        #      explicit zero padding of 2 -- TODO confirm) ----
        (self.feed('conv3_3/relu',
                   'conv3_4_1x1_increase')
             .add(name='conv3_4')
             .relu(name='conv3_4/relu')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_1_1x1_proj'))

        (self.feed('conv3_4/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_1_1x1_reduce')
             .zero_padding(paddings=2, name='padding8')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_1_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_1_1x1_increase'))

        (self.feed('conv4_1_1x1_proj',
                   'conv4_1_1x1_increase')
             .add(name='conv4_1')
             .relu(name='conv4_1/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_2_1x1_reduce')
             .zero_padding(paddings=2, name='padding9')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_2_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_2_1x1_increase'))

        (self.feed('conv4_1/relu',
                   'conv4_2_1x1_increase')
             .add(name='conv4_2')
             .relu(name='conv4_2/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_3_1x1_reduce')
             .zero_padding(paddings=2, name='padding10')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_3_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_3_1x1_increase'))

        (self.feed('conv4_2/relu',
                   'conv4_3_1x1_increase')
             .add(name='conv4_3')
             .relu(name='conv4_3/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_4_1x1_reduce')
             .zero_padding(paddings=2, name='padding11')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_4_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_4_1x1_increase'))

        (self.feed('conv4_3/relu',
                   'conv4_4_1x1_increase')
             .add(name='conv4_4')
             .relu(name='conv4_4/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_5_1x1_reduce')
             .zero_padding(paddings=2, name='padding12')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_5_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_5_1x1_increase'))

        (self.feed('conv4_4/relu',
                   'conv4_5_1x1_increase')
             .add(name='conv4_5')
             .relu(name='conv4_5/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=True, name='conv4_6_1x1_reduce')
             .zero_padding(paddings=2, name='padding13')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=True, name='conv4_6_3x3')
             .conv(1, 1, 512, 1, 1, biased=True, relu=False, name='conv4_6_1x1_increase'))

        # ---- conv5_x: three bottleneck units, atrous_conv argument 4 with
        #      zero padding of 4 ----
        (self.feed('conv4_5/relu',
                   'conv4_6_1x1_increase')
             .add(name='conv4_6')
             .relu(name='conv4_6/relu')
             .conv(1, 1, 1024, 1, 1, biased=True, relu=False, name='conv5_1_1x1_proj'))

        (self.feed('conv4_6/relu')
             .conv(1, 1, 256, 1, 1, biased=True, relu=True, name='conv5_1_1x1_reduce')
             .zero_padding(paddings=4, name='padding14')
             .atrous_conv(3, 3, 256, 4, biased=True, relu=True, name='conv5_1_3x3')
             .conv(1, 1, 1024, 1, 1, biased=True, relu=False, name='conv5_1_1x1_increase'))

        (self.feed('conv5_1_1x1_proj',
                   'conv5_1_1x1_increase')
             .add(name='conv5_1')
             .relu(name='conv5_1/relu')
             .conv(1, 1, 256, 1, 1, biased=True, relu=True, name='conv5_2_1x1_reduce')
             .zero_padding(paddings=4, name='padding15')
             .atrous_conv(3, 3, 256, 4, biased=True, relu=True, name='conv5_2_3x3')
             .conv(1, 1, 1024, 1, 1, biased=True, relu=False, name='conv5_2_1x1_increase'))

        (self.feed('conv5_1/relu',
                   'conv5_2_1x1_increase')
             .add(name='conv5_2')
             .relu(name='conv5_2/relu')
             .conv(1, 1, 256, 1, 1, biased=True, relu=True, name='conv5_3_1x1_reduce')
             .zero_padding(paddings=4, name='padding16')
             .atrous_conv(3, 3, 256, 4, biased=True, relu=True, name='conv5_3_3x3')
             .conv(1, 1, 1024, 1, 1, biased=True, relu=False, name='conv5_3_1x1_increase'))

        (self.feed('conv5_2/relu',
                   'conv5_3_1x1_increase')
             .add(name='conv5_3')
             .relu(name='conv5_3/relu'))

        # ---- pyramid pooling over conv5_3/relu ----
        # Static spatial size (axes 1 and 2) of the feature map; every pooled
        # map is resized back to this size before the sum.
        shape = self.layers['conv5_3/relu'].get_shape().as_list()[1:3]
        h, w = shape

        (self.feed('conv5_3/relu')
             .avg_pool(h, w, h, w, name='conv5_3_pool1')
             .resize_bilinear(shape, name='conv5_3_pool1_interp'))

        # Floor division keeps the pooling sizes integral under Python 3;
        # true division would hand floats (e.g. h / 3) to the pooling op.
        (self.feed('conv5_3/relu')
             .avg_pool(h // 2, w // 2, h // 2, w // 2, name='conv5_3_pool2')
             .resize_bilinear(shape, name='conv5_3_pool2_interp'))

        (self.feed('conv5_3/relu')
             .avg_pool(h // 3, w // 3, h // 3, w // 3, name='conv5_3_pool3')
             .resize_bilinear(shape, name='conv5_3_pool3_interp'))

        # Named 'pool6' but pools with a quarter-size window, as in the original.
        (self.feed('conv5_3/relu')
             .avg_pool(h // 4, w // 4, h // 4, w // 4, name='conv5_3_pool6')
             .resize_bilinear(shape, name='conv5_3_pool6_interp'))

        # ---- fuse pyramid levels, then interpolate by 2 and merge with conv3_1 ----
        (self.feed('conv5_3/relu',
                   'conv5_3_pool6_interp',
                   'conv5_3_pool3_interp',
                   'conv5_3_pool2_interp',
                   'conv5_3_pool1_interp')
             .add(name='conv5_3_sum')
             .conv(1, 1, 256, 1, 1, biased=True, relu=True, name='conv5_4_k1')
             .interp(factor=2.0, name='conv5_4_interp')
             .zero_padding(paddings=2, name='padding17')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=False, name='conv_sub4'))

        (self.feed('conv3_1/relu')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv3_1_sub2_proj'))

        (self.feed('conv_sub4',
                   'conv3_1_sub2_proj')
             .add(name='sub24_sum')
             .relu(name='sub24_sum/relu')
             .interp(factor=2.0, name='sub24_sum_interp')
             .zero_padding(paddings=2, name='padding18')
             .atrous_conv(3, 3, 128, 2, biased=True, relu=False, name='conv_sub2'))

        # ---- branch fed directly from 'data': three stride-2 3x3 convs ----
        (self.feed('data')
             .conv(3, 3, 32, 2, 2, biased=True, padding='SAME', relu=True, name='conv1_sub1')
             .conv(3, 3, 32, 2, 2, biased=True, padding='SAME', relu=True, name='conv2_sub1')
             .conv(3, 3, 64, 2, 2, biased=True, padding='SAME', relu=True, name='conv3_sub1')
             .conv(1, 1, 128, 1, 1, biased=True, relu=False, name='conv3_sub1_proj'))

        # ---- final fusion and per-class scores ----
        (self.feed('conv_sub2',
                   'conv3_sub1_proj')
             .add(name='sub12_sum')
             .relu(name='sub12_sum/relu')
             .interp(factor=2.0, name='sub12_sum_interp')
             .conv(1, 1, num_classes, 1, 1, biased=True, relu=False, name='conv6_cls'))

        # Resize the scores to self.shape, take the arg-max over the channel
        # axis (axis 3) and let decode_labels turn the label map into self.pred.
        raw_output = self.layers['conv6_cls']
        raw_output_up = tf.image.resize_bilinear(raw_output, size=self.shape, align_corners=True)
        # `axis` replaces the deprecated `dimension` keyword of tf.argmax.
        raw_output_up = tf.argmax(raw_output_up, axis=3)
        self.pred = decode_labels(raw_output_up, self.shape, num_classes)

    def read_input(self, img_path):
        """Load the image at ``img_path`` via ``load_img`` into ``self.img`` / ``self.img_name``."""
        self.img, self.img_name = load_img(img_path)

    def forward(self, sess):
        """Run ``self.pred`` in ``sess`` on the image loaded by ``read_input`` and return the result."""
        return sess.run(self.pred, feed_dict={self.x: self.img})
Binary file added output/icnet_outdoor_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def decode_labels(mask, img_shape, num_classes):
elif num_classes == 20: # cityscapes including background
color_table = label_colours + [[255, 255, 255]]
color_table = [tuple(color_table[i]) for i in range(len(color_table))]

elif num_classes == 19:
color_table = label_colours

color_mat = tf.constant(color_table, dtype=tf.float32)
onehot_output = tf.one_hot(mask, depth=num_classes)
onehot_output = tf.reshape(onehot_output, (-1, num_classes))
Expand Down Expand Up @@ -93,4 +95,11 @@ def preprocess(img, input_size, model):

return output, h, w, shape

elif model == 'icnet':
img = tf.expand_dims(img, dim=0)
output = tf.image.resize_bilinear(img, input_size)

return output, input_size



0 comments on commit c5931f4

Please sign in to comment.