This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2adf3ae

Add shakeshake_type hparam: batch, image, equal
1 parent 8287094 commit 2adf3ae

2 files changed: +39 −15 lines changed


tensor2tensor/models/common_layers.py

Lines changed: 19 additions & 5 deletions
@@ -60,7 +60,13 @@ def inverse_exp_decay(max_step, min_value=0.01):
 
 def shakeshake2_py(x, y, equal=False, individual=False):
   """The shake-shake sum of 2 tensors, python version."""
-  alpha = 0.5 if equal else tf.random_uniform([])
+  if equal:
+    alpha = 0.5
+  elif individual:
+    alpha = tf.random_uniform([tf.shape(x)[0], 1, 1, 1])
+  else:
+    alpha = tf.random_uniform([])
+
   return alpha * x + (1.0 - alpha) * y
 
 
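For the new 'image' case, the per-example alpha must be shaped to broadcast over the height, width, and channel dimensions of the activations. A minimal numpy sketch of the three alpha modes, assuming NHWC inputs (all names below are illustrative, not part of the diff):

import numpy as np

batch, h, w, c = 4, 2, 2, 3
x = np.ones((batch, h, w, c))
y = np.zeros((batch, h, w, c))

alpha_equal = 0.5                                       # 'equal': fixed 0.5 mix
alpha_batch = np.random.uniform()                       # 'batch': one scalar for the whole batch
alpha_image = np.random.uniform(size=(batch, 1, 1, 1))  # 'image': one alpha per example

out = alpha_image * x + (1.0 - alpha_image) * y
print(out[:, 0, 0, 0])  # four distinct values, one alpha per image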

@@ -72,6 +78,14 @@ def shakeshake2_grad(x1, x2, dy):
   return dx
 
 
+@function.Defun()
+def shakeshake2_indiv_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2, individual=True)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
 @function.Defun()
 def shakeshake2_equal_grad(x1, x2, dy):
   """Overriding gradient for shake-shake of 2 tensors."""
@@ -85,10 +99,10 @@ def shakeshake2(x1, x2):
   """The shake-shake function with a different alpha for forward/backward."""
   return shakeshake2_py(x1, x2)
 
-@function.Defun(grad_func=shakeshake2_grad)
-def shakeshake2_eqforward(x1, x2):
-  """The shake-shake function with a different alpha for forward/backward."""
-  return shakeshake2_py(x1, x2, equal=True)
+
+@function.Defun(grad_func=shakeshake2_indiv_grad)
+def shakeshake2_indiv(x1, x2):
+  return shakeshake2_py(x1, x2, individual=True)
 
 
 @function.Defun(grad_func=shakeshake2_equal_grad)
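One practical caveat with these Defun-wrapped ops: the returned tensor loses its static shape, which is why the caller in shake_shake.py below restores it with set_shape(). A hypothetical call site:

x1 = tf.random_normal([8, 32, 32, 16])
x2 = tf.random_normal([8, 32, 32, 16])

out = shakeshake2_indiv(x1, x2)      # forward: one alpha per image
out.set_shape(x1.get_shape())        # Defun output comes back with unknown shape
grads = tf.gradients(out, [x1, x2])  # backward: an independently drawn alpha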

tensor2tensor/models/shake_shake.py

Lines changed: 20 additions & 10 deletions
@@ -36,7 +36,7 @@ def downsampling_residual_branch(x, conv_filters):
   return tf.concat([x1, x2], axis=3)
 
 
-def shake_shake_block(x, conv_filters, stride, mode):
+def shake_shake_block(x, conv_filters, stride, hparams):
   with tf.variable_scope('branch_1'):
     branch1 = shake_shake_block_branch(x, conv_filters, stride)
   with tf.variable_scope('branch_2'):
@@ -47,21 +47,28 @@ def shake_shake_block(x, conv_filters, stride, hparams):
     skip = downsampling_residual_branch(x, conv_filters)
 
   # TODO(rshin): Use different alpha for each image in batch.
-  if mode == tf.contrib.learn.ModeKeys.TRAIN:
-    shaken = common_layers.shakeshake2(branch1, branch2)
+  if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
+    if hparams.shakeshake_type == 'batch':
+      shaken = common_layers.shakeshake2(branch1, branch2)
+    elif hparams.shakeshake_type == 'image':
+      shaken = common_layers.shakeshake2_indiv(branch1, branch2)
+    elif hparams.shakeshake_type == 'equal':
+      shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
+    else:
+      raise ValueError('Invalid shakeshake_type: {!r}'.format(hparams.shakeshake_type))
   else:
-    shaken = common_layers.shakeshake2_eqforward(branch1, branch2)
+    shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
   shaken.set_shape(branch1.get_shape())
 
   return skip + shaken
 
 
-def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, mode):
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams):
   with tf.variable_scope('block_0'):
-    x = shake_shake_block(x, conv_filters, initial_stride, mode)
+    x = shake_shake_block(x, conv_filters, initial_stride, hparams)
   for i in xrange(1, num_blocks):
     with tf.variable_scope('block_{}'.format(i)):
-      x = shake_shake_block(x, conv_filters, 1, mode)
+      x = shake_shake_block(x, conv_filters, 1, hparams)
   return x
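Note that shakeshake_type only affects training: outside TRAIN mode the block always falls back to the deterministic 0.5 average, matching standard Shake-Shake inference. The dispatch above can be summarized as (hypothetical helper, not in the commit):

def _mix_fn(hparams):
  # Eval/predict: always the deterministic average.
  if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
    return lambda a, b: common_layers.shakeshake2_py(a, b, equal=True)
  # Train: pick the stochastic variant named by the hparam.
  return {
      'batch': common_layers.shakeshake2,
      'image': common_layers.shakeshake2_indiv,
      'equal': lambda a, b: common_layers.shakeshake2_py(a, b, equal=True),
  }[hparams.shakeshake_type]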

@@ -76,6 +83,7 @@ class ShakeShake(t2t_model.T2TModel):
 
   def model_fn_body(self, features):
     hparams = self._hparams
+    print(hparams.learning_rate)
 
     inputs = features["inputs"]
     assert (hparams.num_hidden_layers - 2) % 6 == 0
@@ -87,13 +95,14 @@ def model_fn_body(self, features):
     x = inputs
     mode = hparams.mode
     with tf.variable_scope('shake_shake_stage_1'):
-      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, mode)
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1,
+                            hparams)
     with tf.variable_scope('shake_shake_stage_2'):
       x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
-                            mode)
+                            hparams)
     with tf.variable_scope('shake_shake_stage_3'):
       x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
-                            mode)
+                            hparams)
 
     # For canonical Shake-Shake, we should perform 8x8 average pooling and then
     # have a fully-connected layer (which produces the logits for each class).
@@ -130,4 +139,5 @@ def shakeshake_cifar10():
   hparams.optimizer = "Momentum"
   hparams.optimizer_momentum_momentum = 0.9
   hparams.add_hparam('base_filters', 16)
+  hparams.add_hparam('shakeshake_type', 'batch')
   return hparams
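To try the other modes, the new hparam can be overridden on the constructed hparams object (a minimal sketch, assuming the usual HParams attribute access):

hparams = shakeshake_cifar10()
hparams.shakeshake_type = 'image'  # one of: 'batch', 'image', 'equal'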
