|
24 | 24 | from __future__ import division |
25 | 25 | from __future__ import print_function |
26 | 26 |
|
27 | | -import copy |
28 | | - |
29 | 27 | # Dependency imports |
30 | 28 |
|
31 | 29 | from six.moves import xrange # pylint: disable=redefined-builtin |
|
43 | 41 | class AttentionLmMoe(t2t_model.T2TModel): |
44 | 42 | """Attention net. See file docstring.""" |
45 | 43 |
|
46 | | - def model_fn_body_sharded(self, sharded_features, train): |
| 44 | + def model_fn_body_sharded(self, sharded_features): |
47 | 45 | # Dropout is disabled when not training (hparams ending in "dropout" are zeroed automatically) |
48 | | - hparams = copy.copy(self._hparams) |
49 | | - if not train: |
50 | | - hparams.attention_dropout = 0. |
51 | | - hparams.relu_dropout = 0. |
52 | | - hparams.residual_dropout = 0. |
| 46 | + hparams = self._hparams |
53 | 47 | dp = self._data_parallelism |
54 | 48 | targets = sharded_features["targets"] |
55 | 49 | targets = dp(tf.squeeze, targets, 2) |
@@ -81,7 +75,9 @@ def residual_fn(x, y): |
81 | 75 | with tf.variable_scope("ffn"): |
82 | 76 | if str(layer) in hparams.moe_layers.split(","): |
83 | 77 | y, loss = common_layers.moe_layer( |
84 | | - dp, self._ps_devices, x, train, hparams.hidden_size, |
| 78 | + dp, self._ps_devices, x, |
| 79 | + hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, |
| 80 | + hparams.hidden_size, |
85 | 81 | hparams.moe_hidden_size, hparams.moe_n1, hparams.moe_n2, |
86 | 82 | hparams.moe_loss_coef) |
87 | 83 | extra_loss += loss |
@@ -162,10 +158,12 @@ def attention_lm_moe_base(): |
162 | 158 | hparams.add_hparam("num_heads", 8) |
163 | 159 | hparams.add_hparam("attention_key_channels", 0) |
164 | 160 | hparams.add_hparam("attention_value_channels", 0) |
| 161 | + # All hyperparameters ending in "dropout" are automatically set to 0.0 |
| 162 | + # when not in training mode. |
165 | 163 | hparams.add_hparam("attention_dropout", 0.0) |
166 | 164 | hparams.add_hparam("relu_dropout", 0.0) |
167 | | - hparams.add_hparam("pos", "timing") # timing, none |
168 | 165 | hparams.add_hparam("residual_dropout", 0.1) |
| 166 | + hparams.add_hparam("pos", "timing") # timing, none |
169 | 167 | return hparams |
170 | 168 |
|
171 | 169 |
|
|
0 commit comments