This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 889fc84

TF Eager improvements for T2TModel

Ryan Sepassi authored and committed.

PiperOrigin-RevId: 177641254
Parent: 654f74e

File tree

8 files changed: +661 −222 lines changed

tensor2tensor/layers/common_hparams.py

Lines changed: 0 additions & 3 deletions
@@ -184,9 +184,6 @@ def basic_params1():
       # This is the actual batch size, *not* tokens per batch (i.e. for
       # language models this is the number of sentences in the batch)
       tpu_batch_size_per_shard=24,
-      # Things not compatible with eager mode use this flag to implement
-      # alternative functionality. We expect this to go away soon.
-      use_eager_mode=False,
       # Set by tpu_trainer to let the model know whether we are on TPU.
       # Switching on/off tpu should not invalidate checkpoints.
       use_tpu=False,
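
With the hparam gone, eager-vs-graph behavior is detected at runtime instead. A minimal sketch of that check, assuming a 2017-era TF 1.x build that exposes tensorflow.python.eager.context (the same module the imports below pull in); the helper name is illustrative:

# Minimal sketch of the runtime check that replaces the use_eager_mode hparam.
# Assumes a TF 1.x build that ships tensorflow.python.eager.context, matching
# the imports added elsewhere in this commit; the helper name is illustrative.
import tensorflow as tf
from tensorflow.python.eager import context


def apply_graph_only_rewrite(x):
  """Applies a graph-only rewrite unless we are executing eagerly."""
  if not context.in_eager_mode():
    # Graph-only functionality (e.g. gradient rewrites, summaries) goes here.
    x = tf.identity(x, name="graph_only_path")
  return x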

tensor2tensor/layers/common_layers.py

Lines changed: 3 additions & 3 deletions
@@ -32,6 +32,7 @@
 
 import tensorflow as tf
 
+from tensorflow.python.eager import context as tfe_context
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 
@@ -200,16 +201,15 @@ def flatten4d3d(x):
   return result
 
 
-def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0,
-              use_eager_mode=False):
+def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0):
   """Embed x of type int64 into dense vectors, reducing to max 4 dimensions."""
   with tf.variable_scope(
       name, default_name="embedding", values=[x], reuse=reuse):
     embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
     # On the backwards pass, we want to convert the gradient from
     # an indexed-slices to a regular tensor before sending it back to the
     # parameter server. This avoids excess computation on the parameter server.
-    if not use_eager_mode:
+    if not tfe_context.in_eager_mode():
       embedding_var = eu.convert_gradient_to_tensor(embedding_var)
     emb_x = tf.gather(embedding_var, x)
     if multiplier != 1.0:
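
Call sites elsewhere in this commit simply drop the old keyword argument. A hypothetical call under the new signature (the toy ids and sizes are illustrative, not from the source):

# Hypothetical call site under the new signature: no use_eager_mode kwarg;
# the eager check now happens inside embedding() via tfe_context.in_eager_mode().
import tensorflow as tf
from tensor2tensor.layers import common_layers

ids = tf.constant([[3, 1, 4]], dtype=tf.int64)  # toy token ids (illustrative)
emb = common_layers.embedding(
    ids, vocab_size=16, dense_size=8, name="toy_embedding", multiplier=2.0)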

tensor2tensor/layers/modalities.py

Lines changed: 5 additions & 4 deletions
@@ -29,6 +29,8 @@
 
 import tensorflow as tf
 
+from tensorflow.python.eager import context
+
 
 # TODO(noam): remove this function after TPUs do gather faster.
 def tpu_gather(params, indices):
@@ -96,7 +98,7 @@ def _get_weights(self, hidden_dim=None):
     else:
       ret = tf.concat(shards, 0)
     # Convert ret to tensor.
-    if not self._model_hparams.use_eager_mode:
+    if not context.in_eager_mode():
       ret = eu.convert_gradient_to_tensor(ret)
     return ret
 
@@ -205,7 +207,7 @@ class ImageModality(modality.Modality):
   def bottom(self, inputs):
     with tf.variable_scope(self.name):
       inputs = common_layers.standardize_images(inputs)
-      if not self._model_hparams.use_eager_mode:
+      if not context.in_eager_mode():
         tf.summary.image("inputs", inputs, max_outputs=2)
       return tf.to_float(inputs)
 
@@ -216,8 +218,7 @@ def targets_bottom(self, inputs):
           tf.to_int32(common_layers.flatten4d3d(inputs)),
           self.top_dimensionality,
           self._body_input_depth,
-          name="input_rgb_embedding",
-          use_eager_mode=self._model_hparams.use_eager_mode)
+          name="input_rgb_embedding")
       if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
         ret *= self._body_input_depth**0.5
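The same runtime check also gates graph-only side effects such as image summaries, which require a graph and a summary writer in TF 1.x. A simplified sketch of that guard; the standalone function is illustrative, since in the source the logic sits inside ImageModality.bottom:

# Simplified sketch of the summary guard: graph-style tf.summary calls are
# skipped when executing eagerly. The standalone function is illustrative; in
# the source this logic lives inside ImageModality.bottom.
import tensorflow as tf
from tensorflow.python.eager import context


def image_bottom(inputs):
  """Writes an image summary only in graph mode, then casts to float."""
  if not context.in_eager_mode():
    tf.summary.image("inputs", inputs, max_outputs=2)
  return tf.to_float(inputs)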

tensor2tensor/layers/modalities_test.py

Lines changed: 0 additions & 3 deletions
@@ -43,7 +43,6 @@ def testSymbolModalityInputs(self):
         symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0,
         prepend_mode="none",
-        use_eager_mode=False,
         use_tpu=False)
     x = -1 + np.random.random_integers(
         vocab_size, size=(batch_size, length, 1, 1))
@@ -74,7 +73,6 @@ def testSymbolModalityTargets(self):
         factored_logits=0,
         mode=tf.estimator.ModeKeys.TRAIN,
         prepend_mode="none",
-        use_eager_mode=False,
         use_tpu=False)
     body_output = -1 + np.random.random_integers(
         100, size=(batch_size, length, height, hidden_size))
@@ -112,7 +110,6 @@ def testSymbolModalityTargetsFactored(self):
         factored_logits=1,
         mode=tf.estimator.ModeKeys.TRAIN,
         prepend_mode="none",
-        use_eager_mode=False,
         use_tpu=False)
     body_output = -1 + np.random.random_integers(
         100, size=(batch_size, length, height, hidden_size))

tensor2tensor/models/cycle_gan.py

Lines changed: 2 additions & 3 deletions
@@ -66,11 +66,10 @@ def cycle_gan_internal(inputs, targets, _, hparams):
     # Embed inputs and targets.
     inputs_orig, targets_orig = tf.to_int32(inputs), tf.to_int32(targets)
     inputs = common_layers.embedding(
-        inputs_orig, hparams.vocab_size, hparams.hidden_size, "embed",
-        use_eager_mode=hparams.use_eager_mode)
+        inputs_orig, hparams.vocab_size, hparams.hidden_size, "embed")
     targets = common_layers.embedding(
         targets_orig, hparams.vocab_size, hparams.hidden_size,
-        "embed", reuse=True, use_eager_mode=hparams.use_eager_mode)
+        "embed", reuse=True)
 
     # Split the batch into input-input and target-target parts.
     inputs1, _ = split_on_batch(inputs)
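
Note that the two calls above share a single embedding table by reusing the "embed" variable scope. A toy sketch of that sharing pattern under the new signature (vocabulary and hidden sizes are illustrative):

# Toy sketch of sharing one embedding table across two calls via reuse=True,
# as cycle_gan_internal does above. Sizes and input tensors are illustrative.
import tensorflow as tf
from tensor2tensor.layers import common_layers

inputs_orig = tf.constant([[1, 2, 3]], dtype=tf.int32)
targets_orig = tf.constant([[4, 5, 6]], dtype=tf.int32)

inputs = common_layers.embedding(inputs_orig, 32, 16, "embed")
targets = common_layers.embedding(
    targets_orig, 32, 16, "embed", reuse=True)  # reuses the same "kernel" variable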

tensor2tensor/models/transformer.py

Lines changed: 3 additions & 3 deletions
@@ -37,6 +37,7 @@
 
 import tensorflow as tf
 
+from tensorflow.python.eager import context
 from tensorflow.python.util import nest
 
 
@@ -324,7 +325,7 @@ def symbols_to_logits_fn(ids, i, cache):
       # Note: Tensor.set_shape() does not work here since it merges shape info.
      # TODO(llion); Find a more robust solution.
       # pylint: disable=protected-access
-      if not self._hparams.use_eager_mode:
+      if not context.in_eager_mode():
         for layer in cache:
           cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
           cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
@@ -452,8 +453,7 @@ def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
       common_layers.shape_list(inputs)[1])
   # Append target_space_id embedding to inputs.
   emb_target_space = common_layers.embedding(
-      target_space, 32, ishape_static[-1], name="target_space_embedding",
-      use_eager_mode=hparams.use_eager_mode)
+      target_space, 32, ishape_static[-1], name="target_space_embedding")
   emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
   encoder_input += emb_target_space
   if hparams.pos == "timing":
