Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e133a1a

Browse files
Niki Parmar and Ryan Sepassi
authored and committed
Enable Transformer fast decoding in eager mode
PiperOrigin-RevId: 177554962
1 parent c93a188 commit e133a1a

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

tensor2tensor/models/transformer.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,9 @@ def _greedy_infer(self, features, decode_length):
171171
Raises:
172172
NotImplementedError: If there are multiple data shards.
173173
"""
174-
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
175-
# with accessing _shape which is used in fast decoding currently.
176-
if self._hparams.use_eager_mode:
177-
return self._slow_greedy_infer(features, decode_length)
178-
else:
179-
with tf.variable_scope(self.name):
180-
decoded_ids, _ = self._fast_decode(features, decode_length)
181-
return decoded_ids, None, None
174+
with tf.variable_scope(self.name):
175+
decoded_ids, _ = self._fast_decode(features, decode_length)
176+
return decoded_ids, None, None
182177

183178
def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
184179
"""Beam search decoding.
@@ -194,16 +189,10 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
194189
Returns:
195190
samples: an integer `Tensor`. Top samples from the beam search
196191
"""
197-
# TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
198-
# with accessing _shape which is used in fast decoding currently.
199-
if self._hparams.use_eager_mode:
200-
return self._beam_decode_slow(
201-
features, decode_length, beam_size, top_beams, alpha)
202-
else:
203-
with tf.variable_scope(self.name):
204-
decoded_ids, scores = self._fast_decode(features, decode_length,
205-
beam_size, top_beams, alpha)
206-
return {"outputs": decoded_ids, "scores": scores}
192+
with tf.variable_scope(self.name):
193+
decoded_ids, scores = self._fast_decode(features, decode_length,
194+
beam_size, top_beams, alpha)
195+
return {"outputs": decoded_ids, "scores": scores}
207196

208197
def _fast_decode(self,
209198
features,
@@ -335,9 +324,10 @@ def symbols_to_logits_fn(ids, i, cache):
335324
# Note: Tensor.set_shape() does not work here since it merges shape info.
336325
# TODO(llion); Find a more robust solution.
337326
# pylint: disable=protected-access
338-
for layer in cache:
339-
cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
340-
cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
327+
if not self._hparams.use_eager_mode:
328+
for layer in cache:
329+
cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
330+
cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
341331
# pylint: enable=protected-access
342332
cache["encoder_output"] = encoder_output
343333
cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

0 commit comments

Comments (0)