Skip to content

Commit

Permalink
add inference phase for ade20k dataset ( using pspnet50 )
Browse files Browse the repository at this point in the history
  • Loading branch information
hellochick committed Jan 19, 2018
1 parent 69aff22 commit 624f571
Show file tree
Hide file tree
Showing 17 changed files with 390 additions and 116 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
./ade20k_model
./cityscapes_model
model.ckpt-*
checkpoint
11 changes: 4 additions & 7 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import tensorflow as tf
import numpy as np

from model import PSPNet
from tools import decode_labels
from model import PSPNet101
from image_reader import ImageReader

IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
Expand Down Expand Up @@ -80,13 +79,12 @@ def main():
image_batch, label_batch = tf.expand_dims(image, dim=0), tf.expand_dims(label, dim=0) # Add one batch dimension.

# Create network.
net = PSPNet({'data': image_batch}, is_training=False, num_classes=num_classes)
net = PSPNet101({'data': image_batch}, is_training=False, num_classes=num_classes)

with tf.variable_scope('', reuse=True):
flipped_img = tf.image.flip_left_right(image)
flipped_img = tf.expand_dims(flipped_img, dim=0)
net2 = PSPNet({'data': flipped_img}, is_training=False, num_classes=num_classes)

net2 = PSPNet101({'data': flipped_img}, is_training=False, num_classes=num_classes)

# Which variables to load.
restore_var = tf.global_variables()
Expand Down Expand Up @@ -142,9 +140,8 @@ def main():

if step % 10 == 0:
print('Finish {0}/{1}'.format(step, num_steps))
print('step {0} mIoU: {1}'.format(step, sess.run(mIoU)))

print('step {0} mIoU: {1}'.format(step, sess.run(mIoU)))
print('mIoU: {1}'.format(step, sess.run(mIoU)))

coord.request_stop()
coord.join(threads)
Expand Down
81 changes: 31 additions & 50 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,36 @@
import os
import sys
import time
from PIL import Image
import tensorflow as tf
import numpy as np
from scipy import misc

from model import PSPNet
from tools import decode_labels
from model import PSPNet101, PSPNet50
from tools import *

IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
input_size = [1024, 2048]
num_classes = 19
ADE20k_param = {'crop_size': [473, 473],
'num_classes': 150,
'model': PSPNet50}
cityscapes_param = {'crop_size': [720, 720],
'num_classes': 19,
'model': PSPNet101}

SAVE_DIR = './output/'
SNAPSHOT_DIR = './model/'
crop_size = [720, 720]

def get_arguments():
parser = argparse.ArgumentParser(description="Reproduced PSPNet")
parser.add_argument("--img-path", type=str, default='',
help="Path to the RGB image file.")
parser.add_argument("--model", type=str, default=SNAPSHOT_DIR,
parser.add_argument("--checkpoints", type=str, default=SNAPSHOT_DIR,
help="Path to restore weights.")
parser.add_argument("--save-dir", type=str, default=SAVE_DIR,
help="Path to save output.")
parser.add_argument("--flipped-eval", action="store_true",
help="whether to evaluate with flipped img.")
parser.add_argument("--dataset", type=str, default='',
choices=['ade20k', 'cityscapes'],
required=True)

return parser.parse_args()

Expand All @@ -45,44 +50,23 @@ def load(saver, sess, ckpt_path):
saver.restore(sess, ckpt_path)
print("Restored model parameters from {}".format(ckpt_path))

def load_img(img_path):
if os.path.isfile(img_path):
print('successful load img: {0}'.format(img_path))
else:
print('not found file: {0}'.format(img_path))
sys.exit(0)

filename = img_path.split('/')[-1]
ext = filename.split('.')[-1]

if ext.lower() == 'png':
img = tf.image.decode_png(tf.read_file(img_path), channels=3)
elif ext.lower() == 'jpg':
img = tf.image.decode_jpeg(tf.read_file(img_path), channels=3)
else:
print('cannot process {0} file.'.format(file_type))

return img, filename

def preprocess(img, h, w):
# Convert RGB to BGR
img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
# Extract mean.
img -= IMG_MEAN

pad_img = tf.image.pad_to_bounding_box(img, 0, 0, h, w)
pad_img = tf.expand_dims(pad_img, dim=0)

return pad_img

def main():
args = get_arguments()

# load parameters
if args.dataset == 'ade20k':
param = ADE20k_param
elif args.dataset == 'cityscapes':
param = cityscapes_param

crop_size = param['crop_size']
num_classes = param['num_classes']
PSPNet = param['model']

# preprocess images
img, filename = load_img(args.img_path)
img_shape = tf.shape(img)
h, w = (tf.maximum(crop_size[0], img_shape[0]), tf.maximum(crop_size[1], img_shape[1]))

img = preprocess(img, h, w)

# Create network.
Expand All @@ -92,8 +76,9 @@ def main():
flipped_img = tf.expand_dims(flipped_img, dim=0)
net2 = PSPNet({'data': flipped_img}, is_training=False, num_classes=num_classes)


raw_output = net.layers['conv6']

# Do flipped eval or not
if args.flipped_eval:
flipped_output = tf.image.flip_left_right(tf.squeeze(net2.layers['conv6']))
flipped_output = tf.expand_dims(flipped_output, dim=0)
Expand All @@ -103,7 +88,7 @@ def main():
raw_output_up = tf.image.resize_bilinear(raw_output, size=[h, w], align_corners=True)
raw_output_up = tf.image.crop_to_bounding_box(raw_output_up, 0, 0, img_shape[0], img_shape[1])
raw_output_up = tf.argmax(raw_output_up, dimension=3)
pred = tf.expand_dims(raw_output_up, dim=3)
pred = decode_labels(raw_output_up, img_shape, num_classes)

# Init tf Session
config = tf.ConfigProto()
Expand All @@ -113,11 +98,9 @@ def main():

sess.run(init)

saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

restore_var = tf.global_variables()

ckpt = tf.train.get_checkpoint_state(args.model)
ckpt = tf.train.get_checkpoint_state(args.checkpoints)
if ckpt and ckpt.model_checkpoint_path:
loader = tf.train.Saver(var_list=restore_var)
load_step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
Expand All @@ -127,11 +110,9 @@ def main():

preds = sess.run(pred)

msk = decode_labels(preds, num_classes=num_classes)
im = Image.fromarray(msk[0])
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
im.save(args.save_dir + filename)
misc.imsave(args.save_dir + filename, preds[0])

if __name__ == '__main__':
main()
main()
Binary file removed input/._test_256x512.png
Binary file not shown.
Binary file removed input/._test_720x720.png
Binary file not shown.
Binary file added input/indoor_1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added input/indoor_2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 624f571

Please sign in to comment.