Description
The function 'linear' in cnn.py is defined as:
def linear(input_,
           output_size,
           weights_initializer=initializers.xavier_initializer(),
           biases_initializer=tf.zeros_initializer,
           activation_fn=None,
           trainable=True,
           name='linear'):
    ...
Its sixth parameter is the boolean flag 'trainable', but the following two calls pass 'data_format' in that position:
self.l4, self.var['l4_w'], self.var['l4_b'] = \
    linear(self.l3, 512, weights_initializer, biases_initializer,
           hidden_activation_fn, data_format, name='l4_conv')
self.l3, self.var['l3_w'], self.var['l3_b'] = \
    linear(self.l2, 256, weights_initializer, biases_initializer,
           hidden_activation_fn, data_format, name='l3_conv')
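To see where the value ends up, the positional binding can be checked with Python's inspect module. This is a minimal sketch: the stub 'linear' below is hypothetical, copies only the parameter list of the real function, and replaces the TensorFlow default initializers with None so it runs without TensorFlow. The strings stand in for the real tensors and objects in the l4 call.

```python
import inspect


# Hypothetical stand-in with the same parameter list as 'linear' in cnn.py.
# TensorFlow defaults are replaced by None so this runs without TensorFlow.
def linear(input_,
           output_size,
           weights_initializer=None,
           biases_initializer=None,
           activation_fn=None,
           trainable=True,
           name='linear'):
    pass


# Same positional layout as the l4 call; 'NCHW' plays the role of data_format.
bound = inspect.signature(linear).bind(
    'l3', 512, 'weights_initializer', 'biases_initializer',
    'hidden_activation_fn', 'NCHW', name='l4_conv')

# Print which parameter each argument was bound to.
for param, value in bound.arguments.items():
    print(f'{param} = {value!r}')
```

Among the printed lines is trainable = 'NCHW': the sixth positional argument is bound to 'trainable', not to any data-format parameter.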
Because 'data_format' is a string such as 'NCHW', this causes the error "'NCHW' has type str, but expected one of: int, long, bool" when saving model files.
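Since the error only shows up at save time, far from the faulty call, a type check at the top of 'linear' would surface the mistake where it happens. The helper below is a hypothetical sketch and is not part of cnn.py; it assumes 'trainable' should always be a plain Python bool.

```python
# Hypothetical guard (not in cnn.py): reject non-bool 'trainable' values at the
# call site instead of letting them fail later when the model is saved.
def check_trainable(trainable):
    if not isinstance(trainable, bool):
        raise TypeError(
            f"'trainable' must be a bool, got {type(trainable).__name__}: {trainable!r}")
    return trainable


print(check_trainable(True))  # True

try:
    check_trainable('NCHW')
except TypeError as err:
    print(err)  # 'trainable' must be a bool, got str: 'NCHW'
```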
I suggest passing 'trainable' in that position instead:
self.l4, self.var['l4_w'], self.var['l4_b'] = \
    linear(self.l3, 512, weights_initializer, biases_initializer,
           hidden_activation_fn, trainable, name='l4_conv')
self.l3, self.var['l3_w'], self.var['l3_b'] = \
    linear(self.l2, 256, weights_initializer, biases_initializer,
           hidden_activation_fn, trainable, name='l3_conv')
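A more robust option might be to pass these arguments by keyword, or to make them keyword-only so positional mix-ups cannot compile into a silent bug. The sketch below uses a hypothetical variant of the 'linear' signature (TensorFlow defaults replaced by None, body returning its arguments for inspection); it is not the current cnn.py code.

```python
# Hypothetical variant of 'linear': the bare '*' makes every parameter after
# output_size keyword-only. The body just echoes a few arguments for inspection.
def linear(input_, output_size, *,
           weights_initializer=None,
           biases_initializer=None,
           activation_fn=None,
           trainable=True,
           name='linear'):
    return {'activation_fn': activation_fn, 'trainable': trainable, 'name': name}


# Fixed call written with keywords: each value is tied to its parameter by name.
cfg = linear('l3', 512,
             weights_initializer='w_init', biases_initializer='b_init',
             activation_fn='relu', trainable=True, name='l4_conv')
print(cfg['trainable'])  # True

# The original positional layout is now rejected at the call site.
try:
    linear('l3', 512, 'w_init', 'b_init', 'relu', 'NCHW', name='l4_conv')
except TypeError as err:
    print('rejected:', err)
```

The trade-off is that every existing positional call to 'linear' in the repository would have to be updated to use keywords.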