@@ -77,6 +77,37 @@ def _get_token_encoder(vocab_dir, vocab_name, filename):
7777 return text_encoder .TokenTextEncoder (vocab_path )
7878
7979
def _maybe_download_corpus(tmp_dir, vocab_type):
  """Download and unpack the PTB corpus, returning the relevant file names.

  Args:
    tmp_dir: directory where the corpus archive is (or will be) stored and
      extracted.
    vocab_type: a text_problems.VocabType value selecting between the
      character-level and word-level variants of the dataset.

  Returns:
    A list of extracted PTB text-file names: the "char" variants when
    vocab_type is VocabType.CHARACTER, the word-level files otherwise.
  """
  filename = os.path.basename(PTB_URL)
  compressed_filepath = generator_utils.maybe_download(
      tmp_dir, filename, PTB_URL)
  ptb_files = []
  ptb_char_files = []

  with tarfile.open(compressed_filepath, "r:gz") as tgz:
    files = []
    # Selecting only relevant files: the archive contains more than the
    # PTB .txt splits we need.
    for m in tgz.getmembers():
      if "ptb" in m.name and ".txt" in m.name:
        if "char" in m.name:
          ptb_char_files += [m.name]
        else:
          ptb_files += [m.name]
        files += [m]

    # NOTE(review): extractall trusts member paths from the archive; the
    # name filter above does not rule out "../" components. Consider
    # validating member names before extraction.
    tgz.extractall(tmp_dir, members=files)

  if vocab_type == text_problems.VocabType.CHARACTER:
    return ptb_char_files
  return ptb_files
110+
80111@registry .register_problem
81112class LanguagemodelPtb10k (text_problems .Text2SelfProblem ):
82113 """PTB, 10k vocab."""
@@ -91,6 +122,10 @@ def dataset_splits(self):
91122 "shards" : 1 ,
92123 }]
93124
  @property
  def is_generate_per_split(self):
    # PTB ships separate pre-made train and valid files, and
    # generate_samples selects a source file per dataset_split, so each
    # split is generated directly rather than carved from one stream.
    return True
128+
  @property
  def vocab_filename(self):
    # Name of the vocab file for the 10k-entry PTB token vocabulary.
    return "vocab.lmptb.10000"
@@ -100,28 +135,7 @@ def vocab_type(self):
100135 return text_problems .VocabType .TOKEN
101136
102137 def generate_samples (self , data_dir , tmp_dir , dataset_split ):
103- filename = os .path .basename (PTB_URL )
104- compressed_filepath = generator_utils .maybe_download (
105- tmp_dir , filename , PTB_URL )
106- ptb_files = []
107- ptb_char_files = []
108- with tarfile .open (compressed_filepath , "r:gz" ) as tgz :
109- files = []
110- # Selecting only relevant files.
111- for m in tgz .getmembers ():
112- if "ptb" in m .name and ".txt" in m .name :
113- if "char" in m .name :
114- ptb_char_files += [m .name ]
115- else :
116- ptb_files += [m .name ]
117- files += [m ]
118-
119- tgz .extractall (tmp_dir , members = files )
120-
121- if self .vocab_type == text_problems .VocabType .CHARACTER :
122- files = ptb_char_files
123- else :
124- files = ptb_files
138+ files = _maybe_download_corpus (tmp_dir , self .vocab_type )
125139
126140 train_file , valid_file = None , None
127141 for filename in files :
@@ -138,10 +152,13 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
138152 train = dataset_split == problem .DatasetSplit .TRAIN
139153 filepath = train_file if train else valid_file
140154
141- with tf .gfile .GFile (filepath , "r" ) as f :
142- for line in f :
143- line = " " .join (line .replace ("\n " , " %s " % EOS ).split ())
144- yield {"targets" : line }
    def _generate_samples():
      # Inner generator so the enclosing generate_samples call runs its
      # setup (download, train/valid file selection) eagerly, instead of
      # deferring all of it until the first iteration.
      with tf.gfile.GFile(filepath, "r") as f:
        for line in f:
          # Collapse runs of whitespace and replace each newline with the
          # EOS marker token.
          line = " ".join(line.replace("\n", " %s " % EOS).split())
          yield {"targets": line}

    return _generate_samples()
145162
146163
147164@registry .register_problem
0 commit comments