Skip to content

Commit

Permalink
add enet
Browse files Browse the repository at this point in the history
  • Loading branch information
hellochick committed Jan 26, 2018
1 parent 5782813 commit ba6c4b2
Show file tree
Hide file tree
Showing 10 changed files with 614 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.pyc
*.npy
__pycache__
*.ckpt
9 changes: 6 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import numpy as np
from scipy import misc

from model import FCN8s, PSPNet50
from model import FCN8s, PSPNet50, ENet

save_dir = './output/'
model_path = {'pspnet': './model/pspnet50.npy',
'fcn': './model/fcn.npy'}
'fcn': './model/fcn.npy',
'enet': './model/cityscapes/enet.ckpt'}

def get_arguments():
parser = argparse.ArgumentParser(description="Reproduced PSPNet")
Expand All @@ -24,7 +25,7 @@ def get_arguments():
help="Path to save output.")
parser.add_argument("--model", type=str, default='',
help="pspnet or fcn",
choices=['pspnet', 'fcn'],
choices=['pspnet', 'fcn', 'enet'],
required=True)

return parser.parse_args()
Expand All @@ -36,6 +37,8 @@ def main():
model = PSPNet50()
elif args.model == 'fcn':
model = FCN8s()
elif args.model == 'enet':
model = ENet()

model.read_input(args.img_path)

Expand Down
528 changes: 528 additions & 0 deletions model.py

Large diffs are not rendered by default.

Binary file added model/cityscapes/enet.ckpt.data-00000-of-00001
Binary file not shown.
Binary file added model/cityscapes/enet.ckpt.index
Binary file not shown.
Binary file added model/cityscapes/enet.ckpt.meta
Binary file not shown.
59 changes: 59 additions & 0 deletions network.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,62 @@ def interp(self, input, factor, name):
resize_shape = [(int)(ori_h * factor), (int)(ori_w * factor)]

return tf.image.resize_bilinear(input, size=resize_shape, align_corners=True, name=name)

def PReLU(x, scope):
    """Apply a parametric ReLU with one learnable negative slope.

    The result is ``x`` where ``x > 0`` and ``alpha * x`` elsewhere.

    Args:
        x: Input tensor. The slope variable is float32, so ``x`` is
            presumably float32 as well -- TODO confirm against callers.
        scope: Name prefix; the slope is obtained through
            ``tf.get_variable`` under the name ``scope + "/alpha"``.

    Returns:
        The activated tensor.
    """
    slope = tf.get_variable(
        scope + "/alpha",
        shape=[1],
        initializer=tf.constant_initializer(0),
        dtype=tf.float32,
    )

    # (x - |x|) / 2 is zero for positive inputs and x for negative ones,
    # so the slope only ever scales the negative half of the input.
    positive_part = tf.nn.relu(x)
    negative_part = slope * (x - abs(x)) * 0.5

    return positive_part + negative_part

# function for 2D spatial dropout:
def spatial_dropout(x, drop_prob):
    """Drop whole feature maps of a 4-D tensor with probability ``drop_prob``.

    The dropout mask has shape ``[batch_size, 1, 1, channels]``, so a single
    keep/drop decision is shared by every spatial position of a channel.

    Args:
        x: Tensor of shape ``[batch_size, height, width, channels]``.
        drop_prob: Probability of dropping each channel.

    Returns:
        The result of ``tf.nn.dropout`` applied to ``x`` with the
        channel-wise noise shape.
    """
    keep_prob = 1.0 - drop_prob
    static_shape = x.get_shape().as_list()

    batch_size = static_shape[0]
    channels = static_shape[3]

    if batch_size is None or channels is None:
        # A statically unknown dimension (e.g. a placeholder with a variable
        # batch size) cannot be baked into a constant; read it from the
        # runtime shape instead.
        dynamic_shape = tf.shape(x)
        if batch_size is None:
            batch_size = dynamic_shape[0]
        if channels is None:
            channels = dynamic_shape[3]
        noise_shape = tf.stack([batch_size, 1, 1, channels])
    else:
        # drop each channel with probability drop_prob:
        noise_shape = tf.constant(value=[batch_size, 1, 1, channels])

    return tf.nn.dropout(x, keep_prob, noise_shape=noise_shape)

# function for unpooling max_pool:
def max_unpool(inputs, pooling_indices, output_shape=None, k_size=(1, 2, 2, 1)):
    """Scatter pooled values back to the positions recorded by max pooling.

    NOTE! this function is based on the implementation by kwotsin in
    https://github.com/kwotsin/TensorFlow-ENet

    Args:
        inputs: Tensor of shape ``[batch_size, height, width, channels]``.
        pooling_indices: Pooling indices of the previously max-pooled layer,
            one per element of ``inputs``.
        output_shape: Shape ``[batch, height, width, channels]`` the returned
            tensor should have. When omitted, it is derived from the shape
            of ``inputs`` by scaling height and width with ``k_size``.
        k_size: Pooling window ``[1, k_h, k_w, 1]``; only used when
            ``output_shape`` is not given.

    Returns:
        Tensor of shape ``output_shape`` holding the values of ``inputs`` at
        the decoded positions; ``tf.scatter_nd`` leaves the rest at zero.
    """
    pooling_indices = tf.cast(pooling_indices, tf.int32)
    input_shape = tf.shape(inputs, out_type=tf.int32)

    if output_shape is None:
        # Previously the default crashed on ``output_shape[2]``; fall back to
        # undoing a pooling step with window/stride ``k_size``.
        output_shape = (
            input_shape[0],
            input_shape[1] * k_size[1],
            input_shape[2] * k_size[2],
            input_shape[3],
        )

    # Batch coordinate: every element simply keeps its own batch position.
    ones = tf.ones_like(pooling_indices, dtype=tf.int32)
    batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], 0)
    batch_range = tf.reshape(
        tf.range(input_shape[0], dtype=tf.int32), shape=batch_shape)
    b = ones * batch_range

    # Row/column coordinates decoded from the flat index.
    # NOTE(review): this treats each index as (y * width + x) * channels + c
    # within a single image, i.e. without a batch offset -- confirm this
    # matches the op that produced ``pooling_indices``.
    y = pooling_indices // (output_shape[2] * output_shape[3])
    x = (pooling_indices // output_shape[3]) % output_shape[2]

    # Channel coordinate: broadcast 0..channels-1 along the last axis.
    feature_range = tf.range(output_shape[3], dtype=tf.int32)
    f = ones * feature_range

    # Flatten to one (b, y, x, f) row per input value and scatter.
    inputs_size = tf.size(inputs)
    indices = tf.transpose(
        tf.reshape(tf.stack([b, y, x, f]), [4, inputs_size]))
    values = tf.reshape(inputs, [inputs_size])

    return tf.scatter_nd(indices, values, output_shape)
Binary file added output/enet_test_1024x2048.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output/enet_test_720x720.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 20 additions & 3 deletions tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@

IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
matfn = './utils/color150.mat'
# Cityscapes palette: entry i is the colour triple for class id i.
label_colours = [
    [128, 64, 128],   # 0 = road
    [244, 35, 231],   # 1 = sidewalk
    [69, 69, 69],     # 2 = building
    [102, 102, 156],  # 3 = wall
    [190, 153, 153],  # 4 = fence
    [153, 153, 153],  # 5 = pole
    [250, 170, 29],   # 6 = traffic light
    [219, 219, 0],    # 7 = traffic sign
    [106, 142, 35],   # 8 = vegetation
    [152, 250, 152],  # 9 = terrain
    [69, 129, 180],   # 10 = sky
    [219, 19, 60],    # 11 = person
    [255, 0, 0],      # 12 = rider
    [0, 0, 142],      # 13 = car
    [0, 0, 69],       # 14 = truck
    [0, 60, 100],     # 15 = bus
    [0, 79, 100],     # 16 = train
    [0, 0, 230],      # 17 = motorcycle
    [119, 10, 32],    # 18 = bicycle
]

def read_labelcolours(matfn, append_background=False):
mat = sio.loadmat(matfn)
Expand All @@ -20,11 +34,14 @@ def read_labelcolours(matfn, append_background=False):
return color_list

def decode_labels(mask, img_shape, num_classes):
if num_classes == 151:
if num_classes == 151: # ade20k including background
color_table = read_labelcolours(matfn, append_background=True)
else:
elif num_classes == 150: # ade20k excluding background
color_table = read_labelcolours(matfn)

elif num_classes == 20: # cityscapes includin background
color_table = label_colours + [[255, 255, 255]]
color_table = [tuple(color_table[i]) for i in range(len(color_table))]

color_mat = tf.constant(color_table, dtype=tf.float32)
onehot_output = tf.one_hot(mask, depth=num_classes)
onehot_output = tf.reshape(onehot_output, (-1, num_classes))
Expand Down

0 comments on commit ba6c4b2

Please sign in to comment.