
Commit 4617c01

Merge pull request #153 from rshin/shakeshake
Add Shake-Shake model
2 parents 28e0e4e + 2adf3ae

5 files changed: +169 −2 lines

tensor2tensor/models/common_hparams.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ def basic_params1():
       weight_noise=0.0,
       learning_rate_decay_scheme="none",
       learning_rate_warmup_steps=100,
+      learning_rate_cosine_cycle_steps=250000,
       learning_rate=0.1,
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"

tensor2tensor/models/common_layers.py

Lines changed: 21 additions & 2 deletions
@@ -58,9 +58,15 @@ def inverse_exp_decay(max_step, min_value=0.01):
   return inv_base**tf.maximum(float(max_step) - step, 0.0)


-def shakeshake2_py(x, y, equal=False):
+def shakeshake2_py(x, y, equal=False, individual=False):
   """The shake-shake sum of 2 tensors, python version."""
-  alpha = 0.5 if equal else tf.random_uniform([])
+  if equal:
+    alpha = 0.5
+  if individual:
+    alpha = tf.random_uniform(tf.get_shape(x)[:1])
+  else:
+    alpha = tf.random_uniform([])
+
   return alpha * x + (1.0 - alpha) * y


@@ -72,6 +78,14 @@ def shakeshake2_grad(x1, x2, dy):
   return dx


+@function.Defun()
+def shakeshake2_indiv_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2, individual=True)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
 @function.Defun()
 def shakeshake2_equal_grad(x1, x2, dy):
   """Overriding gradient for shake-shake of 2 tensors."""

@@ -86,6 +100,11 @@ def shakeshake2(x1, x2):
   return shakeshake2_py(x1, x2)


+@function.Defun(grad_func=shakeshake2_indiv_grad)
+def shakeshake2_indiv(x1, x2):
+  return shakeshake2_py(x1, x2, individual=True)
+
+
 @function.Defun(grad_func=shakeshake2_equal_grad)
 def shakeshake2_eqgrad(x1, x2):
   """The shake-shake function with a different alpha for forward/backward."""

tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 from tensor2tensor.models import modalities
 from tensor2tensor.models import multimodel
 from tensor2tensor.models import neural_gpu
+from tensor2tensor.models import shake_shake
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
 from tensor2tensor.models import transformer_alternative
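
Because models.py now imports shake_shake, the @registry.register_model and @registry.register_hparams decorators in the new file run at import time. A hypothetical lookup sketch (assuming the registry's model() accessor used elsewhere in tensor2tensor; registered names are derived from the class and function names):

# Importing the models module registers ShakeShake (as "shake_shake") and
# shakeshake_cifar10 as a side effect of the decorators in the new file.
from tensor2tensor.models import models  # noqa: F401
from tensor2tensor.utils import registry

model_cls = registry.model("shake_shake")  # resolves to the ShakeShake class
print(model_cls)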
tensor2tensor/models/shake_shake.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensor2tensor.models import common_hparams
+from tensor2tensor.models import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+def shake_shake_block_branch(x, conv_filters, stride):
+  x = tf.nn.relu(x)
+  x = tf.layers.conv2d(
+      x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME')
+  x = tf.layers.batch_normalization(x)
+  x = tf.nn.relu(x)
+  x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME')
+  x = tf.layers.batch_normalization(x)
+  return x
+
+
+def downsampling_residual_branch(x, conv_filters):
+  x = tf.nn.relu(x)
+
+  x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
+  x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME')
+
+  x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
+  x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2))
+  x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME')
+
+  return tf.concat([x1, x2], axis=3)
+
+
+def shake_shake_block(x, conv_filters, stride, hparams):
+  with tf.variable_scope('branch_1'):
+    branch1 = shake_shake_block_branch(x, conv_filters, stride)
+  with tf.variable_scope('branch_2'):
+    branch2 = shake_shake_block_branch(x, conv_filters, stride)
+  if x.shape[-1] == conv_filters:
+    skip = tf.identity(x)
+  else:
+    skip = downsampling_residual_branch(x, conv_filters)
+
+  # TODO(rshin): Use different alpha for each image in batch.
+  if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN:
+    if hparams.shakeshake_type == 'batch':
+      shaken = common_layers.shakeshake2(branch1, branch2)
+    elif hparams.shakeshake_type == 'image':
+      shaken = common_layers.shakeshake2_indiv(branch1, branch2)
+    elif hparams.shakeshake_type == 'equal':
+      shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
+    else:
+      raise ValueError('Invalid shakeshake_type: {!r}'.format(shaken))
+  else:
+    shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True)
+  shaken.set_shape(branch1.get_shape())
+
+  return skip + shaken
+
+
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams):
+  with tf.variable_scope('block_0'):
+    x = shake_shake_block(x, conv_filters, initial_stride, hparams)
+  for i in xrange(1, num_blocks):
+    with tf.variable_scope('block_{}'.format(i)):
+      x = shake_shake_block(x, conv_filters, 1, hparams)
+  return x
+
+
+@registry.register_model
+class ShakeShake(t2t_model.T2TModel):
+  '''Implements the Shake-Shake architecture.
+
+  From <https://arxiv.org/pdf/1705.07485.pdf>
+  This is intended to match the CIFAR-10 version, and correspond to
+  "Shake-Shake-Batch" in Table 1.
+  '''
+
+  def model_fn_body(self, features):
+    hparams = self._hparams
+    print(hparams.learning_rate)
+
+    inputs = features["inputs"]
+    assert (hparams.num_hidden_layers - 2) % 6 == 0
+    blocks_per_stage = (hparams.num_hidden_layers - 2) // 6
+
+    # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16
+    # filters then a batch norm. Instead we will rely on the one in
+    # SmallImageModality, which seems to instead use a layer norm.
+    x = inputs
+    mode = hparams.mode
+    with tf.variable_scope('shake_shake_stage_1'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1,
+                            hparams)
+    with tf.variable_scope('shake_shake_stage_2'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
+                            hparams)
+    with tf.variable_scope('shake_shake_stage_3'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
+                            hparams)
+
+    # For canonical Shake-Shake, we should perform 8x8 average pooling and then
+    # have a fully-connected layer (which produces the logits for each class).
+    # Instead, we rely on the Xception exit flow in ClassLabelModality.
+    #
+    # Also, this model_fn does not return an extra_loss. However, TensorBoard
+    # reports an exponential moving average for extra_loss, where the initial
+    # value for the moving average may be a large number, so extra_loss will
+    # look large at the beginning of training.
+    return x
+
+
+@registry.register_hparams
+def shakeshake_cifar10():
+  hparams = common_hparams.basic_params1()
+  # This leads to effective batch size 128 when number of GPUs is 1
+  hparams.batch_size = 4096 * 8
+  hparams.hidden_size = 16
+  hparams.dropout = 0
+  hparams.label_smoothing = 0.0
+  hparams.clip_grad_norm = 2.0
+  hparams.num_hidden_layers = 26
+  hparams.kernel_height = -1  # Unused
+  hparams.kernel_width = -1  # Unused
+  hparams.learning_rate_decay_scheme = "cosine"
+  # Model should be run for 700000 steps with batch size 128 (~1800 epochs)
+  hparams.learning_rate_cosine_cycle_steps = 700000
+  hparams.learning_rate = 0.2
+  hparams.learning_rate_warmup_steps = 3000
+  hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  # TODO(rshin): Adjust so that effective value becomes ~1e-4
+  hparams.weight_decay = 3.0
+  hparams.optimizer = "Momentum"
+  hparams.optimizer_momentum_momentum = 0.9
+  hparams.add_hparam('base_filters', 16)
+  hparams.add_hparam('shakeshake_type', 'batch')
+  return hparams
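
As a quick sanity check on the hparams above, a small sketch (stage_plan is a hypothetical helper written for this note, not part of the model): num_hidden_layers=26 yields (26 - 2) // 6 = 4 shake-shake blocks per stage, and the three stages use base_filters, 2 * base_filters, and 4 * base_filters filters.

def stage_plan(num_hidden_layers=26, base_filters=16):
  # Mirrors the arithmetic in model_fn_body: three stages, each with
  # (num_hidden_layers - 2) // 6 blocks, doubling the filter count per stage.
  assert (num_hidden_layers - 2) % 6 == 0
  blocks_per_stage = (num_hidden_layers - 2) // 6
  return [(blocks_per_stage, base_filters * mult) for mult in (1, 2, 4)]

print(stage_plan())  # [(4, 16), (4, 32), (4, 64)]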

tensor2tensor/utils/trainer_utils.py

Lines changed: 3 additions & 0 deletions
@@ -321,6 +321,9 @@ def learning_rate_decay():
         (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
   elif hparams.learning_rate_decay_scheme == "exp100k":
     return 0.94**(step // 100000)
+  elif hparams.learning_rate_decay_scheme == "cosine":
+    cycle_steps = hparams.learning_rate_cosine_cycle_steps
+    return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps))

   inv_base = tf.exp(tf.log(0.01) / warmup_steps)
   inv_decay = inv_base**(warmup_steps - step)
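
For reference, a standalone NumPy sketch (not the trainer code) of the cosine factor added above, evaluated with the 700000-step cycle from shakeshake_cifar10: the factor starts at 1.0, passes 0.5 at mid-cycle, and approaches 0.0 at the end of the cycle, with the trainer then scaling it by the base learning_rate.

import numpy as np

def cosine_decay_factor(step, cycle_steps):
  # Same expression as the "cosine" branch above, in plain NumPy.
  return 0.5 * (1 + np.cos(np.pi * (step % cycle_steps) / cycle_steps))

for step in (0, 175000, 350000, 699999):
  print(step, round(cosine_decay_factor(step, 700000), 3))
# 0 -> 1.0, 175000 -> 0.854, 350000 -> 0.5, 699999 -> ~0.0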
