Skip to content

Commit

Permalink
add flag --train-beta-gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
hellochick committed Nov 24, 2017
1 parent 2e9436e commit 69aff22
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def get_arguments():
help="Regularisation parameter for L2-loss.")
parser.add_argument("--update-mean-var", action="store_true",
help="whether to get update_op from tf.Graphic_Keys")
parser.add_argument("--train-beta-gamma", action="store_true",
help="whether to train beta & gamma in bn layer")
return parser.parse_args()

def save(saver, sess, logdir, step):
Expand Down Expand Up @@ -125,7 +127,7 @@ def main():
# According to the prototxt in the Caffe implementation, the learning rate must be multiplied by 10.0 in the pyramid module
fc_list = ['conv5_3_pool1_conv', 'conv5_3_pool2_conv', 'conv5_3_pool3_conv', 'conv5_3_pool6_conv', 'conv6', 'conv5_4']
restore_var = [v for v in tf.global_variables()]
all_trainable = [v for v in tf.trainable_variables() if 'gamma' not in v.name and 'beta' not in v.name]
all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
fc_trainable = [v for v in all_trainable if v.name.split('/')[0] in fc_list]
conv_trainable = [v for v in all_trainable if v.name.split('/')[0] not in fc_list] # lr * 1.0
fc_w_trainable = [v for v in fc_trainable if 'weights' in v.name] # lr * 10.0
Expand Down

0 comments on commit 69aff22

Please sign in to comment.