|
| 1 | +from __future__ import absolute_import |
| 2 | +from __future__ import division |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +from six.moves import xrange # pylint: disable=redefined-builtin |
| 6 | + |
| 7 | +from tensor2tensor.models import common_hparams |
| 8 | +from tensor2tensor.models import common_layers |
| 9 | +from tensor2tensor.utils import registry |
| 10 | +from tensor2tensor.utils import t2t_model |
| 11 | + |
| 12 | +import tensorflow as tf |
| 13 | + |
| 14 | + |
| 15 | +def shake_shake_block_branch(x, conv_filters, stride): |
| 16 | + x = tf.nn.relu(x) |
| 17 | + x = tf.layers.conv2d( |
| 18 | + x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME') |
| 19 | + x = tf.layers.batch_normalization(x) |
| 20 | + x = tf.nn.relu(x) |
| 21 | + x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME') |
| 22 | + x = tf.layers.batch_normalization(x) |
| 23 | + return x |
| 24 | + |
| 25 | + |
| 26 | +def downsampling_residual_branch(x, conv_filters): |
| 27 | + x = tf.nn.relu(x) |
| 28 | + |
| 29 | + x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2)) |
| 30 | + x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME') |
| 31 | + |
| 32 | + x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]]) |
| 33 | + x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2)) |
| 34 | + x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME') |
| 35 | + |
| 36 | + return tf.concat([x1, x2], axis=3) |
| 37 | + |
| 38 | + |
| 39 | +def shake_shake_block(x, conv_filters, stride, hparams): |
| 40 | + with tf.variable_scope('branch_1'): |
| 41 | + branch1 = shake_shake_block_branch(x, conv_filters, stride) |
| 42 | + with tf.variable_scope('branch_2'): |
| 43 | + branch2 = shake_shake_block_branch(x, conv_filters, stride) |
| 44 | + if x.shape[-1] == conv_filters: |
| 45 | + skip = tf.identity(x) |
| 46 | + else: |
| 47 | + skip = downsampling_residual_branch(x, conv_filters) |
| 48 | + |
| 49 | + # TODO(rshin): Use different alpha for each image in batch. |
| 50 | + if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN: |
| 51 | + if hparams.shakeshake_type == 'batch': |
| 52 | + shaken = common_layers.shakeshake2(branch1, branch2) |
| 53 | + elif hparams.shakeshake_type == 'image': |
| 54 | + shaken = common_layers.shakeshake2_indiv(branch1, branch2) |
| 55 | + elif hparams.shakeshake_type == 'equal': |
| 56 | + shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) |
| 57 | + else: |
| 58 | + raise ValueError('Invalid shakeshake_type: {!r}'.format(shaken)) |
| 59 | + else: |
| 60 | + shaken = common_layers.shakeshake2_py(branch1, branch2, equal=True) |
| 61 | + shaken.set_shape(branch1.get_shape()) |
| 62 | + |
| 63 | + return skip + shaken |
| 64 | + |
| 65 | + |
| 66 | +def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, hparams): |
| 67 | + with tf.variable_scope('block_0'): |
| 68 | + x = shake_shake_block(x, conv_filters, initial_stride, hparams) |
| 69 | + for i in xrange(1, num_blocks): |
| 70 | + with tf.variable_scope('block_{}'.format(i)): |
| 71 | + x = shake_shake_block(x, conv_filters, 1, hparams) |
| 72 | + return x |
| 73 | + |
| 74 | + |
| 75 | +@registry.register_model |
| 76 | +class ShakeShake(t2t_model.T2TModel): |
| 77 | + '''Implements the Shake-Shake architecture. |
| 78 | +
|
| 79 | + From <https://arxiv.org/pdf/1705.07485.pdf> |
| 80 | + This is intended to match the CIFAR-10 version, and correspond to |
| 81 | + "Shake-Shake-Batch" in Table 1. |
| 82 | + ''' |
| 83 | + |
| 84 | + def model_fn_body(self, features): |
| 85 | + hparams = self._hparams |
| 86 | + print(hparams.learning_rate) |
| 87 | + |
| 88 | + inputs = features["inputs"] |
| 89 | + assert (hparams.num_hidden_layers - 2) % 6 == 0 |
| 90 | + blocks_per_stage = (hparams.num_hidden_layers - 2) // 6 |
| 91 | + |
| 92 | + # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16 |
| 93 | + # filters then a batch norm. Instead we will rely on the one in |
| 94 | + # SmallImageModality, which seems to instead use a layer norm. |
| 95 | + x = inputs |
| 96 | + mode = hparams.mode |
| 97 | + with tf.variable_scope('shake_shake_stage_1'): |
| 98 | + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, |
| 99 | + hparams) |
| 100 | + with tf.variable_scope('shake_shake_stage_2'): |
| 101 | + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2, |
| 102 | + hparams) |
| 103 | + with tf.variable_scope('shake_shake_stage_3'): |
| 104 | + x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2, |
| 105 | + hparams) |
| 106 | + |
| 107 | + # For canonical Shake-Shake, we should perform 8x8 average pooling and then |
| 108 | + # have a fully-connected layer (which produces the logits for each class). |
| 109 | + # Instead, we rely on the Xception exit flow in ClassLabelModality. |
| 110 | + # |
| 111 | + # Also, this model_fn does not return an extra_loss. However, TensorBoard |
| 112 | + # reports an exponential moving average for extra_loss, where the initial |
| 113 | + # value for the moving average may be a large number, so extra_loss will |
| 114 | + # look large at the beginning of training. |
| 115 | + return x |
| 116 | + |
| 117 | + |
| 118 | +@registry.register_hparams |
| 119 | +def shakeshake_cifar10(): |
| 120 | + hparams = common_hparams.basic_params1() |
| 121 | + # This leads to effective batch size 128 when number of GPUs is 1 |
| 122 | + hparams.batch_size = 4096 * 8 |
| 123 | + hparams.hidden_size = 16 |
| 124 | + hparams.dropout = 0 |
| 125 | + hparams.label_smoothing = 0.0 |
| 126 | + hparams.clip_grad_norm = 2.0 |
| 127 | + hparams.num_hidden_layers = 26 |
| 128 | + hparams.kernel_height = -1 # Unused |
| 129 | + hparams.kernel_width = -1 # Unused |
| 130 | + hparams.learning_rate_decay_scheme = "cosine" |
| 131 | + # Model should be run for 700000 steps with batch size 128 (~1800 epochs) |
| 132 | + hparams.learning_rate_cosine_cycle_steps = 700000 |
| 133 | + hparams.learning_rate = 0.2 |
| 134 | + hparams.learning_rate_warmup_steps = 3000 |
| 135 | + hparams.initializer = "uniform_unit_scaling" |
| 136 | + hparams.initializer_gain = 1.0 |
| 137 | + # TODO(rshin): Adjust so that effective value becomes ~1e-4 |
| 138 | + hparams.weight_decay = 3.0 |
| 139 | + hparams.optimizer = "Momentum" |
| 140 | + hparams.optimizer_momentum_momentum = 0.9 |
| 141 | + hparams.add_hparam('base_filters', 16) |
| 142 | + hparams.add_hparam('shakeshake_type', 'batch') |
| 143 | + return hparams |
0 commit comments