5 changes: 4 additions & 1 deletion tf_unet/layers.py
@@ -36,7 +36,10 @@ def conv2d(x, W, b, keep_prob_):
     with tf.name_scope("conv2d"):
         conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
         conv_2d_b = tf.nn.bias_add(conv_2d, b)
-        return tf.nn.dropout(conv_2d_b, keep_prob_)
+        if keep_prob_ < 1.0:
+            return tf.nn.dropout(conv_2d_b, keep_prob_)
+        else:
+            return conv_2d_b
 
 def deconv2d(x, W,stride):
     with tf.name_scope("deconv2d"):
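Note on this change: the keep probability is now tested with a plain Python `if` at graph-construction time, so callers must pass a Python float rather than a tensor (a `tf.Tensor` cannot be used as a Python bool in TF 1.x graph mode), which is presumably why the call sites below switch from `tf.constant(1.0)` to `1.0`. A minimal standalone sketch, not part of the patch, of how the revised helper behaves (the `x`/`W`/`b` values are illustrative):

# Sketch only; assumes TF 1.x as used by tf_unet.
import tensorflow as tf

def conv2d(x, W, b, keep_prob_):
    with tf.name_scope("conv2d"):
        conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
        conv_2d_b = tf.nn.bias_add(conv_2d, b)
        if keep_prob_ < 1.0:
            # a dropout op is only added to the graph when it can actually drop units
            return tf.nn.dropout(conv_2d_b, keep_prob_)
        else:
            # keep_prob == 1.0: no dropout node is built at all
            return conv_2d_b

x = tf.placeholder(tf.float32, [None, 64, 64, 3])
W = tf.Variable(tf.truncated_normal([3, 3, 3, 16], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[16]))

train_node = conv2d(x, W, b, 0.75)  # graph contains a dropout op
infer_node = conv2d(x, W, b, 1.0)   # graph contains no dropout op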
29 changes: 14 additions & 15 deletions tf_unet/unet.py
@@ -141,7 +141,7 @@ def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16,
     with tf.name_scope("output_map"):
         weight = weight_variable([1, 1, features_root, n_class], stddev)
         bias = bias_variable([n_class], name="bias")
-        conv = conv2d(in_node, weight, bias, tf.constant(1.0))
+        conv = conv2d(in_node, weight, bias, 1.0)
         output_map = tf.nn.relu(conv)
         up_h_convs["out"] = output_map
 
@@ -185,15 +185,15 @@ class Unet(object):
     :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options
     """
 
-    def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, **kwargs):
+    def __init__(self, channels=3, n_class=2, keep_prob=1.0, cost="cross_entropy", cost_kwargs={}, **kwargs):
         tf.reset_default_graph()
 
         self.n_class = n_class
         self.summaries = kwargs.get("summaries", True)
 
         self.x = tf.placeholder("float", shape=[None, None, None, channels], name="x")
         self.y = tf.placeholder("float", shape=[None, None, None, n_class], name="y")
-        self.keep_prob = tf.placeholder(tf.float32, name="dropout_probability")  # dropout (keep probability)
+        self.keep_prob = keep_prob  # dropout (keep probability)
 
         logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs)
 
@@ -274,7 +274,7 @@ def predict(self, model_path, x_test):
             self.restore(sess, model_path)
 
             y_dummy = np.empty((x_test.shape[0], x_test.shape[1], x_test.shape[2], self.n_class))
-            prediction = sess.run(self.predicter, feed_dict={self.x: x_test, self.y: y_dummy, self.keep_prob: 1.})
+            prediction = sess.run(self.predicter, feed_dict={self.x: x_test, self.y: y_dummy})
 
             return prediction
 
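With the keep probability fixed at construction time rather than fed per call, typical usage of the revised API looks roughly like the sketch below (paths, shapes, and hyperparameters are illustrative, not from this PR). One consequence worth flagging: a network built with keep_prob < 1.0 bakes the dropout op into the graph, so unlike the old placeholder-based version it will also apply dropout during predict().

# Usage sketch against the revised constructor/predict signatures (illustrative values).
import numpy as np
from tf_unet import unet

net = unet.Unet(channels=3, n_class=2, keep_prob=0.75, layers=3, features_root=16)

x_test = np.zeros((1, 572, 572, 3), dtype=np.float32)          # dummy input for illustration
prediction = net.predict("./unet_trained/model.ckpt", x_test)  # no keep_prob feed needed any more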
@@ -387,8 +387,8 @@ def _initialize(self, training_iters, output_path, restore, prediction_path):
 
         return init
 
-    def train(self, data_provider, output_path, training_iters=10, epochs=100, dropout=0.75, display_step=1,
-              restore=False, write_graph=False, prediction_path='prediction'):
+    def train(self, data_provider, output_path, training_iters=10, epochs=100, display_step=1,
+              restore=False, write_graph=False, prediction_path='prediction', inter=1, intra=56):
         """
         Lauches the training process
 
@@ -401,14 +401,17 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo
         :param restore: Flag if previous model should be restored
         :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
         :param prediction_path: path where to save predictions on each epoch
+        :param inter: number of inter-op parallelism threads to use for training
+        :param intra: number of intra-op parallelism threads to use for training
         """
         save_path = os.path.join(output_path, "model.ckpt")
         if epochs == 0:
             return save_path
 
         init = self._initialize(training_iters, output_path, restore, prediction_path)
 
-        with tf.Session() as sess:
+        config = tf.ConfigProto(intra_op_parallelism_threads=intra, inter_op_parallelism_threads=inter)
+        with tf.Session(config=config) as sess:
             if write_graph:
                 tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)
 
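For context: intra_op_parallelism_threads sizes the thread pool used inside a single op (e.g. one convolution), while inter_op_parallelism_threads limits how many independent ops run concurrently. The defaults added here (inter=1, intra=56) look sized for a large multi-core CPU, which is an assumption about the target machine rather than something stated in the PR. A minimal standalone sketch of the same session setup:

# Minimal TF 1.x sketch of the session configuration introduced above (values illustrative).
import tensorflow as tf

inter, intra = 1, 8  # pick per machine; the PR's defaults are inter=1, intra=56
config = tf.ConfigProto(intra_op_parallelism_threads=intra,
                        inter_op_parallelism_threads=inter)
with tf.Session(config=config) as sess:
    sess.run(tf.constant(0.0))  # any graph work now uses the configured thread pools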
@@ -435,8 +438,7 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo
                     _, loss, lr, gradients = sess.run(
                         (self.optimizer, self.net.cost, self.learning_rate_node, self.net.gradients_node),
                         feed_dict={self.net.x: batch_x,
-                                   self.net.y: util.crop_to_shape(batch_y, pred_shape),
-                                   self.net.keep_prob: dropout})
+                                   self.net.y: util.crop_to_shape(batch_y, pred_shape)})
 
                     if self.net.summaries and self.norm_grads:
                         avg_gradients = _update_avg_gradients(avg_gradients, gradients, step)
@@ -459,13 +461,11 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo
 
     def store_prediction(self, sess, batch_x, batch_y, name):
         prediction = sess.run(self.net.predicter, feed_dict={self.net.x: batch_x,
-                                                             self.net.y: batch_y,
-                                                             self.net.keep_prob: 1.})
+                                                             self.net.y: batch_y})
         pred_shape = prediction.shape
 
         loss = sess.run(self.net.cost, feed_dict={self.net.x: batch_x,
-                                                  self.net.y: util.crop_to_shape(batch_y, pred_shape),
-                                                  self.net.keep_prob: 1.})
+                                                  self.net.y: util.crop_to_shape(batch_y, pred_shape)})
 
         logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(error_rate(prediction,
                                                                                    util.crop_to_shape(batch_y,
@@ -488,8 +488,7 @@ def output_minibatch_stats(self, sess, summary_writer, step, batch_x, batch_y):
                                                        self.net.accuracy,
                                                        self.net.predicter],
                                                       feed_dict={self.net.x: batch_x,
-                                                                 self.net.y: batch_y,
-                                                                 self.net.keep_prob: 1.})
+                                                                 self.net.y: batch_y})
         summary_writer.add_summary(summary_str, step)
         summary_writer.flush()
         logging.info(
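Taken together, a caller of the revised Trainer no longer passes dropout to train(); the keep probability is chosen when the Unet is constructed, and the new thread-count arguments are passed to train() instead. A hedged end-to-end sketch of the updated call (the data path and hyperparameters are illustrative, not from this PR):

# Illustrative call against the signatures changed in this PR; paths/values are placeholders.
from tf_unet import unet, image_util

data_provider = image_util.ImageDataProvider("data/train/*.tif")   # hypothetical training data

net = unet.Unet(channels=3, n_class=2, keep_prob=0.75)             # dropout set here, not in train()
trainer = unet.Trainer(net, batch_size=4, optimizer="momentum")

model_path = trainer.train(data_provider, "./unet_trained",
                           training_iters=32, epochs=10,
                           inter=1, intra=8)                        # thread counts replace dropout=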