
Incorrect assignment of parameter 'trainable' for function 'linear' in cnn.py #3

@chch1970

The function 'linear' in cnn.py is defined as:
    def linear(input_,
               output_size,
               weights_initializer=initializers.xavier_initializer(),
               biases_initializer=tf.zeros_initializer,
               activation_fn=None,
               trainable=True,
               name='linear'):
        ...

Its sixth parameter, 'trainable', is a boolean flag, but the following two calls pass 'data_format' in that position:
    self.l4, self.var['l4_w'], self.var['l4_b'] = linear(
        self.l3, 512, weights_initializer, biases_initializer,
        hidden_activation_fn, data_format, name='l4_conv')
    self.l3, self.var['l3_w'], self.var['l3_b'] = linear(
        self.l2, 256, weights_initializer, biases_initializer,
        hidden_activation_fn, data_format, name='l3_conv')
With data_format set to 'NCHW', this causes the error "'NCHW' has type str, but expected one of: int, long, bool" when saving model files.
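To make the mis-binding concrete, here is a minimal sketch that uses a hypothetical stub, linear_stub, with the same parameter list as linear but no TensorFlow dependency (the defaults and placeholder arguments are simplifications, not the real initializers). Python binds positional arguments strictly by position, so the sixth one ends up in trainable:

```python
import inspect

# Hypothetical stand-in with the same parameter list as `linear` in cnn.py.
# Defaults are simplified so the sketch runs without TensorFlow.
def linear_stub(input_, output_size, weights_initializer=None,
                biases_initializer=None, activation_fn=None,
                trainable=True, name='linear'):
    # Return what the parameters were bound to, so the call can be inspected.
    return {'trainable': trainable, 'name': name}

# Mirror the buggy call: data_format is passed as the 6th positional argument.
data_format = 'NCHW'
bound = linear_stub('l3', 512, 'w_init', 'b_init', 'relu', data_format,
                    name='l4_conv')
print(bound['trainable'])  # NCHW

# The signature confirms that position 6 (index 5) is `trainable`.
params = list(inspect.signature(linear_stub).parameters)
print(params[5])  # trainable
```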

I suggest changing these calls as follows:
    self.l4, self.var['l4_w'], self.var['l4_b'] = linear(
        self.l3, 512, weights_initializer, biases_initializer,
        hidden_activation_fn, trainable, name='l4_conv')
    self.l3, self.var['l3_w'], self.var['l3_b'] = linear(
        self.l2, 256, weights_initializer, biases_initializer,
        hidden_activation_fn, trainable, name='l3_conv')
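A possible hardening, sketched under assumptions: pass every argument after output_size by keyword so the binding is explicit, and optionally reject non-bool values early. linear_checked and its isinstance guard are hypothetical (cnn.py has no such check), and the sketch assumes a boolean trainable is available at the call site, as the suggested change above does.

```python
# Hypothetical stand-in for `linear` with an added type check (not in cnn.py).
def linear_checked(input_, output_size, weights_initializer=None,
                   biases_initializer=None, activation_fn=None,
                   trainable=True, name='linear'):
    # Fail at call time instead of much later when the model is saved.
    if not isinstance(trainable, bool):
        raise TypeError(
            f"trainable must be a bool, got {type(trainable).__name__}: {trainable!r}")
    return {'trainable': trainable, 'name': name}

trainable = True
# Keyword arguments make it impossible to shift data_format into `trainable`.
ok = linear_checked('l3', 512,
                    weights_initializer='w_init',
                    biases_initializer='b_init',
                    activation_fn='relu',
                    trainable=trainable,
                    name='l4_conv')

# The original positional call now fails immediately with a clear message.
try:
    linear_checked('l3', 512, 'w_init', 'b_init', 'relu', 'NCHW', name='l4_conv')
except TypeError as e:
    print(e)  # trainable must be a bool, got str: 'NCHW'
```

Another option is to make the parameters after output_size keyword-only by adding `*` to the signature (def linear(input_, output_size, *, weights_initializer=..., ...)); the original positional call would then raise a TypeError for too many positional arguments.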
