@@ -36,7 +36,7 @@ def downsampling_residual_branch(x, conv_filters):
     return tf.concat([x1, x2], axis=3)
 
 
-def shake_shake_block(x, conv_filters, stride, mode):
+def shake_shake_block(x, conv_filters, stride, hparams):
   with tf.variable_scope('branch_1'):
     branch1 = shake_shake_block_branch(x, conv_filters, stride)
   with tf.variable_scope('branch_2'):
@@ -47,21 +47,28 @@ def shake_shake_block(x, conv_filters, stride, mode):
     skip = downsampling_residual_branch(x, conv_filters)
 
   # TODO(rshin): Use different alpha for each image in batch.
-  if mode == tf.contrib.learn.ModeKeys.TRAIN:
-    shaken = common_layers.shakeshake2(branch1, branch2)
+  if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
+    if hparams.shakeshake_type == 'batch':
+      shaken = common_layers.shakeshake2(branch1, branch2)
+    elif hparams.shakeshake_type == 'image':
+      shaken = common_layers.shakeshake2_indiv(branch1, branch2)
+    elif hparams.shakeshake_type == 'equal':
+      shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
+    else:
+      raise ValueError('Invalid shakeshake_type: {!r}'.format(hparams.shakeshake_type))
   else:
-    shaken = common_layers.shakeshake2_eqforward(branch1, branch2)
+    shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
   shaken.set_shape(branch1.get_shape())
 
   return skip + shaken
 
 
-def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, mode):
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams):
   with tf.variable_scope('block_0'):
-    x = shake_shake_block(x, conv_filters, initial_stride, mode)
+    x = shake_shake_block(x, conv_filters, initial_stride, hparams)
   for i in xrange(1, num_blocks):
     with tf.variable_scope('block_{}'.format(i)):
-      x = shake_shake_block(x, conv_filters, 1, mode)
+      x = shake_shake_block(x, conv_filters, 1, hparams)
   return x
 
 
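Note on the three modes: 'batch' samples one random mixing coefficient shared by the whole batch, 'image' samples one per image, and 'equal' presumably fixes the forward coefficient at 0.5 (the eval path now calls the same shakeshake2_py(..., equal=True) op). As a minimal sketch of the underlying shake-shake rule, not tensor2tensor's actual shakeshake2 kernels: the forward pass mixes with a random alpha while gradients flow through an independent beta, via the stop_gradient trick.

import tensorflow as tf

def shake_shake_combine(branch1, branch2, per_image=False):
  # Sample a forward coefficient alpha and an independent backward
  # coefficient beta; per_image=True gives each image its own pair,
  # mirroring the 'image' vs. 'batch' distinction above.
  shape = [tf.shape(branch1)[0], 1, 1, 1] if per_image else []
  alpha = tf.random_uniform(shape)
  beta = tf.random_uniform(shape)
  forward = alpha * branch1 + (1.0 - alpha) * branch2
  backward = beta * branch1 + (1.0 - beta) * branch2
  # The value equals `forward`, but the gradient flows through `backward`.
  return backward + tf.stop_gradient(forward - backward)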
@@ -76,6 +83,7 @@ class ShakeShake(t2t_model.T2TModel):
 
   def model_fn_body(self, features):
     hparams = self._hparams
+    print(hparams.learning_rate)
 
     inputs = features["inputs"]
     assert (hparams.num_hidden_layers - 2) % 6 == 0
@@ -87,13 +95,14 @@ def model_fn_body(self, features):
     x = inputs
     mode = hparams.mode
     with tf.variable_scope('shake_shake_stage_1'):
-      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, mode)
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1,
+                            hparams)
     with tf.variable_scope('shake_shake_stage_2'):
       x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
-                            mode)
+                            hparams)
     with tf.variable_scope('shake_shake_stage_3'):
       x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
-                            mode)
+                            hparams)
 
     # For canonical Shake-Shake, we should perform 8x8 average pooling and then
     # have a fully-connected layer (which produces the logits for each class).
@@ -130,4 +139,5 @@ def shakeshake_cifar10():
   hparams.optimizer = "Momentum"
   hparams.optimizer_momentum_momentum = 0.9
   hparams.add_hparam('base_filters', 16)
+  hparams.add_hparam('shakeshake_type', 'batch')
   return hparams
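Since the new hparam defaults to 'batch', the registered shakeshake_cifar10 set keeps the previous behavior, and the other variants can be selected by overriding the hparam instead of editing code. A hypothetical override using only names from this file:

hparams = shakeshake_cifar10()
# 'image' or 'equal'; any other value raises ValueError in shake_shake_block.
hparams.shakeshake_type = 'image'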