12
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
-
16
15
r"""Decode from trained T2T models.
17
16
18
17
This binary performs inference using the Estimator API.
@@ -82,9 +81,13 @@ def create_decode_hparams():
82
81
83
82
def decode (estimator , hparams , decode_hp ):
84
83
if FLAGS .decode_interactive :
84
+ if estimator .config .use_tpu :
85
+ raise ValueError ("TPU can only decode from dataset." )
85
86
decoding .decode_interactively (estimator , hparams , decode_hp ,
86
87
checkpoint_path = FLAGS .checkpoint_path )
87
88
elif FLAGS .decode_from_file :
89
+ if estimator .config .use_tpu :
90
+ raise ValueError ("TPU can only decode from dataset." )
88
91
decoding .decode_from_file (estimator , FLAGS .decode_from_file , hparams ,
89
92
decode_hp , FLAGS .decode_to_file ,
90
93
checkpoint_path = FLAGS .checkpoint_path )
@@ -160,7 +163,6 @@ def main(_):
160
163
tf .logging .set_verbosity (tf .logging .INFO )
161
164
trainer_lib .set_random_seed (FLAGS .random_seed )
162
165
usr_dir .import_usr_dir (FLAGS .t2t_usr_dir )
163
- FLAGS .use_tpu = False # decoding not supported on TPU
164
166
165
167
if FLAGS .score_file :
166
168
filename = os .path .expanduser (FLAGS .score_file )
@@ -183,7 +185,7 @@ def main(_):
183
185
hp ,
184
186
t2t_trainer .create_run_config (hp ),
185
187
decode_hparams = decode_hp ,
186
- use_tpu = False )
188
+ use_tpu = FLAGS . use_tpu )
187
189
188
190
decode (estimator , hp , decode_hp )
189
191
0 commit comments