@@ -28,6 +28,7 @@ from __future__ import absolute_import
2828from __future__ import division
2929from __future__ import print_function
3030
31+ import os
3132import random
3233import tempfile
3334
@@ -79,24 +80,30 @@ _SUPPORTED_PROBLEM_GENERATORS = {
7980 lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
8081 lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
8182 "ice_parsing_tokens" : (
82- lambda : wmt .tabbed_parsing_token_generator (FLAGS . tmp_dir ,
83- True , "ice" , 2 ** 13 , 2 ** 8 ),
84- lambda : wmt .tabbed_parsing_token_generator (FLAGS . tmp_dir ,
85- False , "ice" , 2 ** 13 , 2 ** 8 )),
83+ lambda : wmt .tabbed_parsing_token_generator (
84+ FLAGS . data_dir , FLAGS . tmp_dir , True , "ice" , 2 ** 13 , 2 ** 8 ),
85+ lambda : wmt .tabbed_parsing_token_generator (
86+ FLAGS . data_dir , FLAGS . tmp_dir , False , "ice" , 2 ** 13 , 2 ** 8 )),
8687 "ice_parsing_characters" : (
87- lambda : wmt .tabbed_parsing_character_generator (FLAGS .tmp_dir , True ),
88- lambda : wmt .tabbed_parsing_character_generator (FLAGS .tmp_dir , False )),
88+ lambda : wmt .tabbed_parsing_character_generator (
89+ FLAGS .data_dir , FLAGS .tmp_dir , True ),
90+ lambda : wmt .tabbed_parsing_character_generator (
91+ FLAGS .data_dir , FLAGS .tmp_dir , False )),
8992 "wmt_parsing_tokens_8k" : (
90- lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
91- lambda : wmt .parsing_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )),
93+ lambda : wmt .parsing_token_generator (
94+ FLAGS .data_dir , FLAGS .tmp_dir , True , 2 ** 13 ),
95+ lambda : wmt .parsing_token_generator (
96+ FLAGS .data_dir , FLAGS .tmp_dir , False , 2 ** 13 )),
9297 "wsj_parsing_tokens_16k" : (
93- lambda : wsj_parsing .parsing_token_generator (FLAGS . tmp_dir , True ,
94- 2 ** 14 , 2 ** 9 ),
95- lambda : wsj_parsing .parsing_token_generator (FLAGS . tmp_dir , False ,
96- 2 ** 14 , 2 ** 9 )),
98+ lambda : wsj_parsing .parsing_token_generator (
99+ FLAGS . data_dir , FLAGS . tmp_dir , True , 2 ** 14 , 2 ** 9 ),
100+ lambda : wsj_parsing .parsing_token_generator (
101+ FLAGS . data_dir , FLAGS . tmp_dir , False , 2 ** 14 , 2 ** 9 )),
97102 "wmt_ende_bpe32k" : (
98- lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , True ),
99- lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , False )),
103+ lambda : wmt .ende_bpe_token_generator (
104+ FLAGS .data_dir , FLAGS .tmp_dir , True ),
105+ lambda : wmt .ende_bpe_token_generator (
106+ FLAGS .data_dir , FLAGS .tmp_dir , False )),
100107 "lm1b_32k" : (
101108 lambda : lm1b .generator (FLAGS .tmp_dir , True ),
102109 lambda : lm1b .generator (FLAGS .tmp_dir , False )
@@ -118,98 +125,50 @@ _SUPPORTED_PROBLEM_GENERATORS = {
118125 lambda : image .cifar10_generator (FLAGS .tmp_dir , True , 50000 ),
119126 lambda : image .cifar10_generator (FLAGS .tmp_dir , False , 10000 )),
120127 "image_mscoco_characters_test" : (
121- lambda : image .mscoco_generator (FLAGS .tmp_dir , True , 80000 ),
122- lambda : image .mscoco_generator (FLAGS .tmp_dir , False , 40000 )),
128+ lambda : image .mscoco_generator (
129+ FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ),
130+ lambda : image .mscoco_generator (
131+ FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 )),
132+ "image_celeba_tune" : (
133+ lambda : image .celeba_generator (FLAGS .tmp_dir , 162770 ),
134+ lambda : image .celeba_generator (FLAGS .tmp_dir , 19867 , 162770 )),
123135 "image_mscoco_tokens_8k_test" : (
124136 lambda : image .mscoco_generator (
125- FLAGS .tmp_dir ,
126- True ,
127- 80000 ,
128- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
129- vocab_size = 2 ** 13 ),
137+ FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ,
138+ vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 ),
130139 lambda : image .mscoco_generator (
131- FLAGS .tmp_dir ,
132- False ,
133- 40000 ,
134- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
135- vocab_size = 2 ** 13 )),
140+ FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 ,
141+ vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 )),
136142 "image_mscoco_tokens_32k_test" : (
137143 lambda : image .mscoco_generator (
138- FLAGS .tmp_dir ,
139- True ,
140- 80000 ,
141- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
142- vocab_size = 2 ** 15 ),
144+ FLAGS .data_dir , FLAGS .tmp_dir , True , 80000 ,
145+ vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 ),
143146 lambda : image .mscoco_generator (
144- FLAGS .tmp_dir ,
145- False ,
146- 40000 ,
147- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
148- vocab_size = 2 ** 15 )),
147+ FLAGS .data_dir , FLAGS .tmp_dir , False , 40000 ,
148+ vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 )),
149149 "snli_32k" : (
150150 lambda : snli .snli_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
151151 lambda : snli .snli_token_generator (FLAGS .tmp_dir , False , 2 ** 15 ),
152152 ),
153- "audio_timit_characters_tune" : (
154- lambda : audio .timit_generator (FLAGS .tmp_dir , True , 1374 ),
155- lambda : audio .timit_generator (FLAGS .tmp_dir , True , 344 , 1374 )),
156153 "audio_timit_characters_test" : (
157- lambda : audio .timit_generator (FLAGS .tmp_dir , True , 1718 ),
158- lambda : audio .timit_generator (FLAGS .tmp_dir , False , 626 )),
159- "audio_timit_tokens_8k_tune" : (
160154 lambda : audio .timit_generator (
161- FLAGS .tmp_dir ,
162- True ,
163- 1374 ,
164- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
165- vocab_size = 2 ** 13 ),
155+ FLAGS .data_dir , FLAGS .tmp_dir , True , 1718 ),
166156 lambda : audio .timit_generator (
167- FLAGS .tmp_dir ,
168- True ,
169- 344 ,
170- 1374 ,
171- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
172- vocab_size = 2 ** 13 )),
157+ FLAGS .data_dir , FLAGS .tmp_dir , False , 626 )),
173158 "audio_timit_tokens_8k_test" : (
174159 lambda : audio .timit_generator (
175- FLAGS .tmp_dir ,
176- True ,
177- 1718 ,
178- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
179- vocab_size = 2 ** 13 ),
180- lambda : audio .timit_generator (
181- FLAGS .tmp_dir ,
182- False ,
183- 626 ,
184- vocab_filename = "tokens.vocab.%d" % 2 ** 13 ,
185- vocab_size = 2 ** 13 )),
186- "audio_timit_tokens_32k_tune" : (
187- lambda : audio .timit_generator (
188- FLAGS .tmp_dir ,
189- True ,
190- 1374 ,
191- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
192- vocab_size = 2 ** 15 ),
160+ FLAGS .data_dir , FLAGS .tmp_dir , True , 1718 ,
161+ vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 ),
193162 lambda : audio .timit_generator (
194- FLAGS .tmp_dir ,
195- True ,
196- 344 ,
197- 1374 ,
198- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
199- vocab_size = 2 ** 15 )),
163+ FLAGS .data_dir , FLAGS .tmp_dir , False , 626 ,
164+ vocab_filename = "vocab.endefr.%d" % 2 ** 13 , vocab_size = 2 ** 13 )),
200165 "audio_timit_tokens_32k_test" : (
201166 lambda : audio .timit_generator (
202- FLAGS .tmp_dir ,
203- True ,
204- 1718 ,
205- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
206- vocab_size = 2 ** 15 ),
167+ FLAGS .data_dir , FLAGS .tmp_dir , True , 1718 ,
168+ vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 ),
207169 lambda : audio .timit_generator (
208- FLAGS .tmp_dir ,
209- False ,
210- 626 ,
211- vocab_filename = "tokens.vocab.%d" % 2 ** 15 ,
212- vocab_size = 2 ** 15 )),
170+ FLAGS .data_dir , FLAGS .tmp_dir , False , 626 ,
171+ vocab_filename = "vocab.endefr.%d" % 2 ** 15 , vocab_size = 2 ** 15 )),
213172 "lmptb_10k" : (
214173 lambda : ptb .train_generator (
215174 FLAGS .tmp_dir ,
@@ -317,7 +276,9 @@ def generate_data_for_problem(problem):
317276
318277def generate_data_for_registered_problem (problem_name ):
319278 problem = registry .problem (problem_name )
320- problem .generate_data (FLAGS .data_dir , FLAGS .tmp_dir , FLAGS .num_shards )
279+ problem .generate_data (os .path .expanduser (FLAGS .data_dir ),
280+ os .path .expanduser (FLAGS .tmp_dir ),
281+ FLAGS .num_shards )
321282
322283
323284if __name__ == "__main__" :
0 commit comments