2323import pprint
2424import tempfile
2525
26+
2627import tensorflow as tf
2728import tensorflow_transform as tft
2829from apache_beam .io import textio
REVIEW_WEIGHT = 'review_weight'
LABEL_COLUMN = 'label'

# Schema of the raw input data: a free-form text review plus an int64 label,
# each a scalar ("fixed") column.  Shared by the shuffle and transform
# pipelines so both code Example protos the same way.
_RAW_COLUMN_DTYPES = {
    REVIEW_COLUMN: tf.string,
    LABEL_COLUMN: tf.int64,
}
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(dataset_schema.Schema({
    column: dataset_schema.ColumnSchema(
        dtype, [], dataset_schema.FixedColumnRepresentation())
    for column, dtype in _RAW_COLUMN_DTYPES.items()
}))

# Characters treated as word delimiters when tokenizing reviews.
DELIMITERS = '.,!?() '
5361
5462
@@ -99,13 +107,13 @@ def ReadAndShuffleData(pcoll, filepatterns):
99107 lambda p : {REVIEW_COLUMN : p [0 ], LABEL_COLUMN : p [1 ]})
100108
101109
def read_and_shuffle_data(
    train_neg_filepattern, train_pos_filepattern, test_neg_filepattern,
    test_pos_filepattern, shuffled_train_filebase, shuffled_test_filebase):
  """Read and shuffle the data and write out as a TFRecord of Example protos.

  Read in the data from the positive and negative examples on disk for both
  the train and test sets, shuffle each set, and write each out in TFRecord
  format coded with the raw-data schema (RAW_DATA_METADATA).  No other
  preprocessing happens here; tokenization and vocabulary mapping are done
  later by transform_data.

  Args:
    train_neg_filepattern: Filepattern for training data negative examples
    train_pos_filepattern: Filepattern for training data positive examples
    test_neg_filepattern: Filepattern for test data negative examples
    test_pos_filepattern: Filepattern for test data positive examples
    shuffled_train_filebase: Base filename for shuffled training data shards
    shuffled_test_filebase: Base filename for shuffled test data shards
  """
  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'ReadAndShuffleTrain' >> ReadAndShuffleData(
            (train_neg_filepattern, train_pos_filepattern))
        | 'WriteTrainData' >> tfrecordio.WriteToTFRecord(
            shuffled_train_filebase,
            coder=example_proto_coder.ExampleProtoCoder(
                RAW_DATA_METADATA.schema)))
    _ = (
        pipeline
        | 'ReadAndShuffleTest' >> ReadAndShuffleData(
            (test_neg_filepattern, test_pos_filepattern))
        | 'WriteTestData' >> tfrecordio.WriteToTFRecord(
            shuffled_test_filebase,
            coder=example_proto_coder.ExampleProtoCoder(
                RAW_DATA_METADATA.schema)))
    # pylint: enable=no-value-for-parameter
148+
149+ def transform_data (shuffled_train_filepattern , shuffled_test_filepattern ,
150+ transformed_train_filebase , transformed_test_filebase ,
151+ transformed_metadata_dir ):
152+ """Transform the data and write out as a TFRecord of Example protos.
153+
154+ Read in the data from the positive and negative examples on disk, and
155+ transform it using a preprocessing pipeline that removes punctuation,
156+ tokenizes and maps tokens to int64 values indices.
157+
158+ Args:
159+ shuffled_train_filepattern: Base filename for shuffled training data shards
160+ shuffled_test_filepattern: Base filename for shuffled test data shards
117161 transformed_train_filebase: Base filename for transformed training data
118162 shards
119163 transformed_test_filebase: Base filename for transformed test data shards
@@ -123,19 +167,19 @@ def transform_data(train_neg_filepattern, train_pos_filepattern,
123167
124168 with beam .Pipeline () as pipeline :
125169 with beam_impl .Context (temp_dir = tempfile .mkdtemp ()):
126- # pylint: disable=no-value-for-parameter
127- train_data = pipeline | 'ReadTrain' >> ReadAndShuffleData (
128- ( train_neg_filepattern , train_pos_filepattern ))
129- # pylint: disable=no-value-for-parameter
130- test_data = pipeline | 'ReadTest' >> ReadAndShuffleData (
131- ( test_neg_filepattern , test_pos_filepattern ))
132-
133- metadata = dataset_metadata . DatasetMetadata ( dataset_schema . Schema ({
134- REVIEW_COLUMN : dataset_schema . ColumnSchema (
135- tf . string , [], dataset_schema . FixedColumnRepresentation ()),
136- LABEL_COLUMN : dataset_schema . ColumnSchema (
137- tf . int64 , [], dataset_schema . FixedColumnRepresentation ()),
138- } ))
170+ train_data = (
171+ pipeline |
172+ 'ReadTrain' >> tfrecordio . ReadFromTFRecord (
173+ shuffled_train_filepattern ,
174+ coder = example_proto_coder . ExampleProtoCoder (
175+ RAW_DATA_METADATA . schema ) ))
176+
177+ test_data = (
178+ pipeline |
179+ 'ReadTest' >> tfrecordio . ReadFromTFRecord (
180+ shuffled_test_filepattern ,
181+ coder = example_proto_coder . ExampleProtoCoder (
182+ RAW_DATA_METADATA . schema ) ))
139183
140184 def preprocessing_fn (inputs ):
141185 """Preprocess input columns into transformed columns."""
@@ -153,12 +197,12 @@ def preprocessing_fn(inputs):
153197 }
154198
155199 (transformed_train_data , transformed_metadata ), transform_fn = (
156- (train_data , metadata )
200+ (train_data , RAW_DATA_METADATA )
157201 | 'AnalyzeAndTransform' >> beam_impl .AnalyzeAndTransformDataset (
158202 preprocessing_fn ))
159203
160204 transformed_test_data , _ = (
161- ((test_data , metadata ), transform_fn )
205+ ((test_data , RAW_DATA_METADATA ), transform_fn )
162206 | 'Transform' >> beam_impl .TransformDataset ())
163207
164208 _ = (
@@ -183,7 +227,9 @@ def preprocessing_fn(inputs):
183227
184228
185229def train_and_evaluate (transformed_train_filepattern ,
186- transformed_test_filepattern , transformed_metadata_dir ):
230+ transformed_test_filepattern , transformed_metadata_dir ,
231+ num_train_instances = NUM_TRAIN_INSTANCES ,
232+ num_test_instances = NUM_TEST_INSTANCES ):
187233 """Train the model on training data and evaluate on evaluation data.
188234
189235 Args:
@@ -192,6 +238,8 @@ def train_and_evaluate(transformed_train_filepattern,
192238 transformed_test_filepattern: Base filename for transformed evaluation data
193239 shards
194240 transformed_metadata_dir: Directory containing transformed data metadata
241+ num_train_instances: Number of instances in train set
242+ num_test_instances: Number of instances in test set
195243
196244 Returns:
197245 The results from the estimator's 'evaluate' method
@@ -219,7 +267,7 @@ def train_and_evaluate(transformed_train_filepattern,
219267 # Estimate the model using the default optimizer.
220268 estimator .fit (
221269 input_fn = train_input_fn ,
222- max_steps = TRAIN_NUM_EPOCHS * NUM_TRAIN_INSTANCES / TRAIN_BATCH_SIZE )
270+ max_steps = TRAIN_NUM_EPOCHS * num_train_instances / TRAIN_BATCH_SIZE )
223271
224272 # Evaluate model on eval dataset.
225273 eval_input_fn = input_fn_maker .build_training_input_fn (
@@ -228,7 +276,7 @@ def train_and_evaluate(transformed_train_filepattern,
228276 training_batch_size = 1 ,
229277 label_keys = [LABEL_COLUMN ])
230278
231- return estimator .evaluate (input_fn = eval_input_fn , steps = NUM_TEST_INSTANCES )
279+ return estimator .evaluate (input_fn = eval_input_fn , steps = num_test_instances )
232280
233281
234282def main ():
@@ -248,14 +296,19 @@ def main():
248296 train_pos_filepattern = os .path .join (args .input_data_dir , 'train/pos/*' )
249297 test_neg_filepattern = os .path .join (args .input_data_dir , 'test/neg/*' )
250298 test_pos_filepattern = os .path .join (args .input_data_dir , 'test/pos/*' )
299+ shuffled_train_filebase = os .path .join (transformed_data_dir , 'train_shuffled' )
300+ shuffled_test_filebase = os .path .join (transformed_data_dir , 'test_shuffled' )
251301 transformed_train_filebase = os .path .join (transformed_data_dir ,
252302 'train_transformed' )
253303 transformed_test_filebase = os .path .join (transformed_data_dir ,
254304 'test_transformed' )
255305 transformed_metadata_dir = os .path .join (transformed_data_dir , 'metadata' )
256306
257- transform_data (train_neg_filepattern , train_pos_filepattern ,
258- test_neg_filepattern , test_pos_filepattern ,
307+ read_and_shuffle_data (train_neg_filepattern , train_pos_filepattern ,
308+ test_neg_filepattern , test_pos_filepattern ,
309+ shuffled_train_filebase , shuffled_test_filebase )
310+
311+ transform_data (shuffled_train_filebase + '*' , shuffled_test_filebase + '*' ,
259312 transformed_train_filebase , transformed_test_filebase ,
260313 transformed_metadata_dir )
261314
0 commit comments