@@ -143,52 +143,46 @@ def model_fn_body(self, features):
                        encoder_decoder_attention_bias,
                        decoder_self_attention_bias, hparams)
 
-  def _greedy_infer(self, features, decode_length, last_position_only=True):
+  def _greedy_infer(self, features, decode_length):
     """Fast version of greedy decoding.
 
     Args:
       features: an map of string to `Tensor`
       decode_length: an integer. How many additional timesteps to decode.
-      last_position_only: MUST be true for fast decoding!
 
     Returns:
       samples: [batch_size, input_length + decode_length]
       logits: Not returned
       losses: Not returned
 
     Raises:
-      ValueError: If last_position_only if False
       NotImplementedError: If there are multiple data shards.
     """
-    decoded_ids, _ = self._fast_decode(
-        features, decode_length, last_position_only)
+    decoded_ids, _ = self._fast_decode(features, decode_length)
     return decoded_ids, None, None
 
   def _beam_decode(self, features, decode_length, beam_size, top_beams,
-                   last_position_only, alpha):
+                   alpha):
     """Beam search decoding.
 
     Args:
       features: an map of string to `Tensor`
       decode_length: an integer. How many additional timesteps to decode.
       beam_size: number of beams.
       top_beams: an integer. How many of the beams to return.
-      last_position_only: MUST be true for fast decoding!
       alpha: Float that controls the length penalty. larger the alpha, stronger
         the preference for slonger translations.
 
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
     decoded_ids, scores = self._fast_decode(
-        features, decode_length, last_position_only, beam_size, top_beams,
-        alpha)
+        features, decode_length, beam_size, top_beams, alpha)
     return {"outputs": decoded_ids, "scores": scores}
 
   def _fast_decode(self,
                    features,
                    decode_length,
-                   last_position_only=True,
                    beam_size=1,
                    top_beams=1,
                    alpha=1.0):
@@ -200,7 +194,6 @@ def _fast_decode(self,
     Args:
       features: a map of string to model features.
       decode_length: an integer. How many additional timesteps to decode.
-      last_position_only: MUST be true for fast decoding!
       beam_size: number of beams.
       top_beams: an integer. How many of the beams to return.
       alpha: Float that controls the length penalty. larger the alpha, stronger
@@ -210,11 +203,8 @@ def _fast_decode(self,
       samples: an integer `Tensor`. Top samples from the beam search
 
     Raises:
-      ValueError: If last_position_only if False
       NotImplementedError: If there are multiple data shards.
     """
-    if not last_position_only:
-      raise ValueError("Fast decoding only deals with the last positions!")
     if self._num_datashards != 1:
       raise NotImplementedError("Fast decoding only supports a single shard.")
     dp = self._data_parallelism
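
For readers tracking the API change, below is a minimal usage sketch of the trimmed decoding entry points. It assumes an already-built Transformer instance (`model`) and a `features` dict as described in the docstrings above; `run_fast_decoding` and its default argument values are hypothetical, not part of the commit.

def run_fast_decoding(model, features, decode_length=50,
                      beam_size=4, top_beams=1, alpha=0.6):
  """Illustrative only: exercises the signatures without last_position_only."""
  if beam_size == 1:
    # Greedy path now takes just (features, decode_length); logits and
    # losses come back as None, matching the docstring above.
    samples, _, _ = model._greedy_infer(features, decode_length)
    return {"outputs": samples, "scores": None}
  # Beam path likewise drops last_position_only and returns
  # {"outputs": decoded_ids, "scores": scores}.
  return model._beam_decode(features, decode_length, beam_size, top_beams,
                            alpha)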