@@ -171,14 +171,9 @@ def _greedy_infer(self, features, decode_length):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
-    # TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
-    # with accessing _shape which is used in fast decoding currently.
-    if self._hparams.use_eager_mode:
-      return self._slow_greedy_infer(features, decode_length)
-    else:
-      with tf.variable_scope(self.name):
-        decoded_ids, _ = self._fast_decode(features, decode_length)
-        return decoded_ids, None, None
+    with tf.variable_scope(self.name):
+      decoded_ids, _ = self._fast_decode(features, decode_length)
+      return decoded_ids, None, None
 
   def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     """Beam search decoding.
@@ -194,16 +189,10 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
-    # TODO(nikip): Remove slow decoding for eager. Eager mode doesn't work
-    # with accessing _shape which is used in fast decoding currently.
-    if self._hparams.use_eager_mode:
-      return self._beam_decode_slow(
-          features, decode_length, beam_size, top_beams, alpha)
-    else:
-      with tf.variable_scope(self.name):
-        decoded_ids, scores = self._fast_decode(features, decode_length,
-                                                beam_size, top_beams, alpha)
-        return {"outputs": decoded_ids, "scores": scores}
+    with tf.variable_scope(self.name):
+      decoded_ids, scores = self._fast_decode(features, decode_length,
+                                              beam_size, top_beams, alpha)
+      return {"outputs": decoded_ids, "scores": scores}
 
   def _fast_decode(self,
                    features,
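
Both `_greedy_infer` and `_beam_decode` keep wrapping the fast path in `tf.variable_scope(self.name)`. As a side note, below is a minimal standalone sketch of the scope-reuse mechanism that named scopes like this make possible, assuming TF1-style graph-mode variable semantics; the `"transformer"`/`"w"` names and the explicit `reuse=True` are illustrative only and not taken from this change:

```python
import tensorflow as tf

with tf.variable_scope("transformer"):
  w = tf.get_variable("w", shape=[4, 4])        # created on first entry

with tf.variable_scope("transformer", reuse=True):
  w_again = tf.get_variable("w", shape=[4, 4])  # same variable handed back

print(w is w_again)  # True: re-entering the named scope reuses, not recreates
```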
@@ -335,9 +324,10 @@ def symbols_to_logits_fn(ids, i, cache):
     # Note: Tensor.set_shape() does not work here since it merges shape info.
     # TODO(llion); Find a more robust solution.
     # pylint: disable=protected-access
-    for layer in cache:
-      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
-      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
+    if not self._hparams.use_eager_mode:
+      for layer in cache:
+        cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
+        cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
     # pylint: enable=protected-access
     cache["encoder_output"] = encoder_output
     cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
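
The new guard and the removal of the slow-decode fallbacks in the first two hunks both hinge on the workaround the guard protects. As the in-code note says, `Tensor.set_shape()` merges shape info, so it cannot relax dimensions of the cache tensors that already have a known static size, which is why the private `_shape` attribute is overwritten directly; the removed TODOs point out that this `_shape` access is exactly what breaks under eager execution. A minimal standalone sketch of that merge behavior (the tensor and shapes below are made up for illustration, not taken from this change):

```python
import tensorflow as tf

# set_shape() merges the requested shape with the known static shape: it can
# refine unknown dimensions, but it can never relax or change known ones.
x = tf.zeros([2, 4, 8])           # static shape (2, 4, 8)
x.set_shape([None, None, 8])      # compatible merge: succeeds, shape unchanged
print(x.shape)                    # -> (2, 4, 8), not (None, None, 8)

try:
  x.set_shape([None, None, 16])   # incompatible with the known last dimension
except ValueError as e:
  print("set_shape() cannot override a known shape:", e)
```

Since the `_shape` override is now skipped when `use_eager_mode` is set, the eager-only slow paths removed above are presumably no longer needed, which matches the cleanup the removed TODOs asked for.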