#!/usr/bin/env python
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
yields for each training example a dictionary mapping string feature names to
lists of {string, int, float}. The generator will be run once for each mode.
"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

import random
import tempfile
@@ -34,6 +37,7 @@ import numpy as np

from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
+from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
@@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry

import tensorflow as tf

@@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
-    "algorithmic_identity_binary40": (
-        lambda: algorithmic.identity_generator(2, 40, 100000),
-        lambda: algorithmic.identity_generator(2, 400, 10000)),
-    "algorithmic_identity_decimal40": (
-        lambda: algorithmic.identity_generator(10, 40, 100000),
-        lambda: algorithmic.identity_generator(10, 400, 10000)),
    "algorithmic_shift_decimal40": (
        lambda: algorithmic.shift_generator(20, 10, 40, 100000),
        lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
@@ -104,9 +103,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
    "ice_parsing_tokens": (
        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                  True, "ice", 2**13, 2**8),
+                                                   True, "ice", 2**13, 2**8),
        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                  False, "ice", 2**13, 2**8)),
+                                                   False, "ice", 2**13, 2**8)),
    "ice_parsing_characters": (
        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
@@ -118,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
                                                    2**14, 2**9),
        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                    2**14, 2**9)),
-    "wsj_parsing_tokens_32k": (
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
-                                                    2**15, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
-                                                    2**15, 2**9)),
    "wmt_enfr_characters": (
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
        lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
@@ -140,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
140134 "wmt_ende_bpe32k" : (
141135 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , True ),
142136 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , False )),
143- "wmt_ende_tokens_8k" : (
144- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
145- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )
146- ),
147- "wmt_ende_tokens_32k" : (
148- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
149- lambda : wmt .ende_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
150- ),
151137 "wmt_zhen_tokens_32k" : (
152138 lambda : wmt .zhen_wordpiece_token_generator (FLAGS .tmp_dir , True ,
153139 2 ** 15 , 2 ** 15 ),
@@ -174,26 +160,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
174160 "image_cifar10_test" : (
175161 lambda : image .cifar10_generator (FLAGS .tmp_dir , True , 50000 ),
176162 lambda : image .cifar10_generator (FLAGS .tmp_dir , False , 10000 )),
177- "image_mscoco_characters_tune" : (
178- lambda : image .mscoco_generator (FLAGS .tmp_dir , True , 70000 ),
179- lambda : image .mscoco_generator (FLAGS .tmp_dir , True , 10000 , 70000 )),
180163 "image_mscoco_characters_test" : (
181164 lambda : image .mscoco_generator (FLAGS .tmp_dir , True , 80000 ),
182165 lambda : image .mscoco_generator (FLAGS .tmp_dir , False , 40000 )),
183- "image_mscoco_tokens_8k_tune" : (
184- lambda : image .mscoco_generator (
185- FLAGS .tmp_dir ,
186- True ,
187- 70000 ,
188- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
189- vocab_size = 2 ** 13 ),
190- lambda : image .mscoco_generator (
191- FLAGS .tmp_dir ,
192- True ,
193- 10000 ,
194- 70000 ,
195- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
196- vocab_size = 2 ** 13 )),
197166 "image_mscoco_tokens_8k_test" : (
198167 lambda : image .mscoco_generator (
199168 FLAGS .tmp_dir ,
@@ -207,20 +176,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
            40000,
            vocab_filename="tokens.vocab.%d" % 2**13,
            vocab_size=2**13)),
-    "image_mscoco_tokens_32k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
    "image_mscoco_tokens_32k_test": (
        lambda: image.mscoco_generator(
            FLAGS.tmp_dir,
@@ -308,8 +263,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {

# pylint: enable=g-long-lambda

-UNSHUFFLED_SUFFIX = "-unshuffled"
-

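Note that UNSHUFFLED_SUFFIX moves into generator_utils in this commit, along with the shuffling step (now generator_utils.shuffle_dataset). Assuming shuffle_dataset mirrors the inline loop removed from main() near the end of this diff, its behavior is roughly the following sketch:

import random

import tensorflow as tf

from tensor2tensor.data_generators import generator_utils

def shuffle_dataset_sketch(filenames):
  # Read each unshuffled file, shuffle its records in memory, write them to
  # the corresponding filename without the "-unshuffled" suffix, and delete
  # the original. This mirrors the loop removed from main() below.
  for fname in filenames:
    records = generator_utils.read_records(fname)
    random.shuffle(records)
    out_fname = fname.replace(generator_utils.UNSHUFFLED_SUFFIX, "")
    generator_utils.write_records(records, out_fname)
    tf.gfile.Remove(fname)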
def set_random_seed():
  """Set the random seed from flag everywhere."""
@@ -322,13 +275,15 @@ def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Calculate the list of problems to generate.
-  problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+  problems = sorted(
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
  if FLAGS.problem and FLAGS.problem[-1] == "*":
    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
  elif FLAGS.problem:
    problems = [p for p in problems if p == FLAGS.problem]
  else:
    problems = []
+
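For example, given the selection logic above, --problem=wmt_ende* would select every problem whose name starts with wmt_ende (drawn from both the dict and the registry), --problem=wmt_ende_bpe32k would select exactly that one, and omitting --problem selects nothing.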
  # Remove TIMIT if paths are not given.
  if not FLAGS.timit_paths:
    problems = [p for p in problems if "timit" not in p]
@@ -340,7 +295,8 @@ def main(_):
    problems = [p for p in problems if "ende_bpe" not in p]

  if not problems:
-    problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+    problems_str = "\n * ".join(
+        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n * " + problems_str + "\n")
    error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
@@ -357,40 +313,50 @@ def main(_):
  for problem in problems:
    set_random_seed()

-    training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
-
-    if isinstance(dev_gen, int):
-      # The dev set and test sets are generated as extra shards using the
-      # training generator. The integer specifies the number of training
-      # shards. FLAGS.num_shards is ignored.
-      num_training_shards = dev_gen
-      tf.logging.info("Generating data for %s.", problem)
-      all_output_files = generator_utils.combined_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
-      generator_utils.generate_files(
-          training_gen(), all_output_files, FLAGS.max_cases)
+    if problem in _SUPPORTED_PROBLEM_GENERATORS:
+      generate_data_for_problem(problem)
    else:
-      # usual case - train data and dev data are generated using separate
-      # generators.
-      tf.logging.info("Generating training data for %s.", problem)
-      train_output_files = generator_utils.train_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
-      generator_utils.generate_files(
-          training_gen(), train_output_files, FLAGS.max_cases)
-      tf.logging.info("Generating development data for %s.", problem)
-      dev_shards = 10 if "coco" in problem else 1
-      dev_output_files = generator_utils.dev_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
-      generator_utils.generate_files(dev_gen(), dev_output_files)
-      all_output_files = train_output_files + dev_output_files
+      generate_data_for_registered_problem(problem)
+
+
+def generate_data_for_problem(problem):
+  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
+  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
+
+  if isinstance(dev_gen, int):
+    # The dev set and test sets are generated as extra shards using the
+    # training generator. The integer specifies the number of training
+    # shards. FLAGS.num_shards is ignored.
+    num_training_shards = dev_gen
+    tf.logging.info("Generating data for %s.", problem)
+    all_output_files = generator_utils.combined_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        num_training_shards)
+    generator_utils.generate_files(training_gen(), all_output_files,
+                                   FLAGS.max_cases)
+  else:
+    # usual case - train data and dev data are generated using separate
+    # generators.
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.train_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        FLAGS.num_shards)
+    generator_utils.generate_files(training_gen(), train_output_files,
+                                   FLAGS.max_cases)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_shards = 10 if "coco" in problem else 1
+    dev_output_files = generator_utils.dev_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
+    generator_utils.generate_files(dev_gen(), dev_output_files)
+    all_output_files = train_output_files + dev_output_files
+
+  tf.logging.info("Shuffling data...")
+  generator_utils.shuffle_dataset(all_output_files)
+

-    tf.logging.info("Shuffling data...")
-    for fname in all_output_files:
-      records = generator_utils.read_records(fname)
-      random.shuffle(records)
-      out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
-      generator_utils.write_records(records, out_fname)
-      tf.gfile.Remove(fname)
+def generate_data_for_registered_problem(problem_name):
+  problem = registry.problem(problem_name)
+  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
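The registry path delegates generation entirely to the problem object. A rough sketch of the definition side, assuming a register_problem decorator and a Problem base class whose generate_data takes (data_dir, tmp_dir) as called above (the class name, module path, and decorator usage here are illustrative assumptions, not taken from this diff):

from tensor2tensor.data_generators import problem  # assumed location of the base class
from tensor2tensor.utils import registry

@registry.register_problem  # assumed decorator; presumably registers "my_toy_problem"
class MyToyProblem(problem.Problem):
  """Hypothetical problem used only to illustrate the plumbing."""

  def generate_data(self, data_dir, tmp_dir):
    # Expected to write sharded, shuffled example files under data_dir,
    # using tmp_dir as scratch space for downloads.
    raise NotImplementedError()

generate_data_for_registered_problem("my_toy_problem") would then resolve the name through registry.problem and invoke generate_data, with no entry needed in _SUPPORTED_PROBLEM_GENERATORS.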


if __name__ == "__main__":