
Commit 7779bee

Merge pull request #25 from google/write-if-non-empty

Generate TFRecords only if data exists in a split.

2 parents f5d5c2d + a912554

File tree

7 files changed: +218 -54 lines changed

  Makefile
  tfrecorder/beam_image.py
  tfrecorder/beam_pipeline.py
  tfrecorder/beam_pipeline_test.py
  tfrecorder/client.py
  tfrecorder/test_utils.py
  tfrecorder/types.py


Makefile

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ init:
 	pip install -r requirements.txt
 
 test:
-	nosetests --with-coverage --nocapture -v --cover-package=tfrecorder
+	nosetests --with-coverage -v --cover-package=tfrecorder
 
 pylint:
 	pylint tfrecorder

tfrecorder/beam_image.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ def process(
       logging.warning('Could not load image: %s', image_uri)
       logging.error('Exception was: %s', str(e))
       self.image_bad_counter.inc()
+      d['split'] = 'DISCARD'
 
     element.update(d)
     yield element
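
Note: with this one-line addition, rows whose image fails to load are
re-labeled 'DISCARD' so that the downstream beam.Partition step routes them
out of the real splits. A minimal sketch of the pattern, not the
repository's actual DoFn; the loader and the split ordering are assumed for
illustration:

import logging

import apache_beam as beam

SPLITS = ['TRAIN', 'VALIDATION', 'TEST', 'DISCARD']  # assumed ordering

def load_image(uri):
  """Toy stand-in for the real loader; raises for a 'corrupt' URI."""
  if uri.endswith('corrupt.jpg'):
    raise IOError(f'cannot load {uri}')

class TagUnloadableImages(beam.DoFn):
  """Re-labels rows whose image cannot be loaded as DISCARD."""

  def process(self, element):
    try:
      load_image(element['image_uri'])
    except IOError as e:
      logging.warning('Could not load image: %s', element['image_uri'])
      logging.error('Exception was: %s', str(e))
      element['split'] = 'DISCARD'
    yield element

def partition_fn(element, num_partitions):
  """Routes each row to the index of its split, DISCARD included."""
  return SPLITS.index(element['split'])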

tfrecorder/beam_pipeline.py

Lines changed: 68 additions & 51 deletions
@@ -19,20 +19,22 @@
 This file implements the full Beam pipeline for TFRecorder.
 """
 
-from typing import Any, Dict, Generator, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union
 
 import functools
 import logging
 import os
 
 import apache_beam as beam
+from apache_beam import pvalue
 import pandas as pd
 import tensorflow_transform as tft
 from tensorflow_transform import beam as tft_beam
 
 from tfrecorder import beam_image
 from tfrecorder import common
 from tfrecorder import constants
+from tfrecorder import types
 
 
 def _get_job_name(job_label: str = None) -> str:
@@ -138,7 +140,7 @@ def _get_write_to_tfrecord(output_dir: str,
       num_shards=num_shards,
   )
 
-def _preprocessing_fn(inputs, integer_label: bool = False):
+def _preprocessing_fn(inputs: Dict[str, Any], integer_label: bool = False):
   """TensorFlow Transform preprocessing function."""
 
   outputs = inputs.copy()
@@ -166,7 +168,7 @@ def __init__(self):
   # pylint: disable=arguments-differ
   def process(
       self,
-      element: Dict[str, Any]
+      element: List[str],
   ) -> Generator[Dict[str, Any], None, None]:
     """Loads image and creates image features.
 
@@ -178,6 +180,43 @@ def process(
     yield element
 
 
+def get_split_counts(df: pd.DataFrame):
+  """Returns number of rows for each data split type given dataframe."""
+  assert constants.SPLIT_KEY in df.columns
+  return df[constants.SPLIT_KEY].value_counts().to_dict()
+
+
+def _transform_and_write_tfr(
+    dataset: pvalue.PCollection,
+    tfr_writer: Callable = None,
+    preprocessing_fn: Optional[Callable] = None,
+    transform_fn: Optional[types.TransformFn] = None,
+    label: str = 'data'):
+  """Applies TF Transform to dataset and outputs it as TFRecords."""
+
+  dataset_metadata = (dataset, constants.RAW_METADATA)
+
+  if transform_fn:
+    transformed_dataset, transformed_metadata = (
+        (dataset_metadata, transform_fn)
+        | f'Transform{label}' >> tft_beam.TransformDataset())
+  else:
+    if not preprocessing_fn:
+      preprocessing_fn = lambda x: x
+    (transformed_dataset, transformed_metadata), transform_fn = (
+        dataset_metadata
+        | f'AnalyzeAndTransform{label}' >>
+        tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
+
+  transformed_data_coder = tft.coders.ExampleProtoCoder(
+      transformed_metadata.schema)
+
+  _ = (
+      transformed_dataset
+      | f'Encode{label}' >> beam.Map(transformed_data_coder.encode)
+      | f'Write{label}' >> tfr_writer(prefix=label.lower()))
+
+  return transform_fn
+
 
 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-locals
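
Note: `get_split_counts` is a plain pandas helper, usable outside the Beam
pipeline. A quick runnable illustration of what it computes, assuming
constants.SPLIT_KEY names a 'split' column as the tests below suggest:

import pandas as pd

df = pd.DataFrame({'split': ['TRAIN', 'TRAIN', 'VALIDATION'],
                   'image_uri': ['a.jpg', 'b.jpg', 'c.jpg']})
print(df['split'].value_counts().to_dict())  # {'TRAIN': 2, 'VALIDATION': 1}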
@@ -246,71 +285,49 @@ def build_pipeline(
       | 'ReadImage' >> beam.ParDo(extract_images_fn)
   )
 
-  # Split dataset into train and validation.
+  # Note: This will not always reflect actual number of samples per dataset
+  # written as TFRecords. The succeeding `Partition` operation may mark
+  # additional samples from other splits as discarded. If a split has all
+  # its samples discarded, the pipeline will still generate a TFRecord
+  # file for that split, albeit empty.
+  split_counts = get_split_counts(df)
+
+  # Require training set to be available in the input data. The transform_fn
+  # and transformed_metadata will be generated from the training set and
+  # applied to the other datasets, if any
+  assert 'TRAIN' in split_counts
+
   train_data, val_data, test_data, discard_data = (
       image_csv_data | 'SplitDataset' >> beam.Partition(
           _partition_fn, len(constants.SPLIT_VALUES))
   )
 
-  train_dataset = (train_data, constants.RAW_METADATA)
-  val_dataset = (val_data, constants.RAW_METADATA)
-  test_dataset = (test_data, constants.RAW_METADATA)
-
-  # TensorFlow Transform applied to all datasets.
   preprocessing_fn = functools.partial(
       _preprocessing_fn,
       integer_label=integer_label)
-  transformed_train_dataset, transform_fn = (
-      train_dataset
-      | 'AnalyzeAndTransformTrain' >> tft_beam.AnalyzeAndTransformDataset(
-          preprocessing_fn))
-
-  transformed_train_data, transformed_metadata = transformed_train_dataset
-  transformed_data_coder = tft.coders.ExampleProtoCoder(
-      transformed_metadata.schema)
-
-  transformed_val_data, _ = (
-      (val_dataset, transform_fn)
-      | 'TransformVal' >> tft_beam.TransformDataset()
-  )
 
-  transformed_test_data, _ = (
-      (test_dataset, transform_fn)
-      | 'TransformTest' >> tft_beam.TransformDataset()
-  )
+  tfr_writer = functools.partial(
+      _get_write_to_tfrecord, output_dir=job_dir, compress=compression,
+      num_shards=num_shards)
+  transform_fn = _transform_and_write_tfr(
+      train_data, tfr_writer, preprocessing_fn=preprocessing_fn,
+      label='Train')
 
-  # Sinks for TFRecords and metadata.
-  tfr_writer = functools.partial(_get_write_to_tfrecord,
-                                 output_dir=job_dir,
-                                 compress=compression,
-                                 num_shards=num_shards)
+  if 'VALIDATION' in split_counts:
+    _transform_and_write_tfr(
+        val_data, tfr_writer, transform_fn=transform_fn, label='Validation')
 
-  _ = (
-      transformed_train_data
-      | 'EncodeTrainData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteTrainData' >> tfr_writer(prefix='train'))
-
-  _ = (
-      transformed_val_data
-      | 'EncodeValData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteValData' >> tfr_writer(prefix='val'))
-
-  _ = (
-      transformed_test_data
-      | 'EncodeTestData' >> beam.Map(transformed_data_coder.encode)
-      | 'WriteTestData' >> tfr_writer(prefix='test'))
+  if 'TEST' in split_counts:
+    _transform_and_write_tfr(
+        test_data, tfr_writer, transform_fn=transform_fn, label='Test')
 
   _ = (
       discard_data
-      | 'DiscardDataWriter' >> beam.io.WriteToText(
+      | 'WriteDiscardedData' >> beam.io.WriteToText(
          os.path.join(job_dir, 'discarded-data')))
 
-  # Output transform function and metadata
+  # Note: `transform_fn` already contains the transformed metadata
   _ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
       job_dir))
 
-  # Output metadata schema
-  _ = (transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
-      job_dir, pipeline=p))
-
   return p
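
Note: the net effect of the refactor above is that analysis happens once,
on the training split, and the validation/test splits are transformed and
written only when the input actually contains rows for them. A runnable toy
of the gating rule, with the split counts hard-coded for illustration:

split_counts = {'TRAIN': 2, 'TEST': 2}  # e.g. from get_split_counts(df)

assert 'TRAIN' in split_counts  # the transform_fn is learned from TRAIN
written = ['train']  # the training split is always analyzed and written
for split in ('VALIDATION', 'TEST'):
  if split in split_counts:
    written.append(split.lower())
print(written)  # ['train', 'test']: no empty validation shards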

tfrecorder/beam_pipeline_test.py

Lines changed: 97 additions & 0 deletions
@@ -16,13 +16,21 @@
 
 """Tests for beam_pipeline."""
 
+import functools
+import glob
+import os
+import tempfile
 import unittest
 from unittest import mock
 
 import apache_beam as beam
 import tensorflow as tf
+import tensorflow_transform as tft
+from tensorflow_transform import beam as tft_beam
 
 from tfrecorder import beam_pipeline
+from tfrecorder import constants
+from tfrecorder import test_utils
 
 
 # pylint: disable=protected-access
@@ -78,5 +86,94 @@ def test_partition_fn(self):
           '{} should be index {} but was index {}'.format(part, i, index))
 
 
+class GetSplitCountsTest(unittest.TestCase):
+  """Tests `get_split_counts` function."""
+
+  def setUp(self):
+    self.df = test_utils.get_test_df()
+
+  def test_all_splits(self):
+    """Tests case where train, validation and test data exist."""
+    expected = {'TRAIN': 2, 'VALIDATION': 2, 'TEST': 2}
+    actual = beam_pipeline.get_split_counts(self.df)
+    self.assertEqual(actual, expected)
+
+  def test_one_split(self):
+    """Tests case where only one split (train) exists."""
+    df = self.df[self.df.split == 'TRAIN']
+    expected = {'TRAIN': 2}
+    actual = beam_pipeline.get_split_counts(df)
+    self.assertEqual(actual, expected)
+
+  def test_error_no_split_key(self):
+    """Tests case where no split key/column exists."""
+    df = self.df.drop(constants.SPLIT_KEY, axis=1)
+    with self.assertRaises(AssertionError):
+      beam_pipeline.get_split_counts(df)
+
+
+class TransformAndWriteTfrTest(unittest.TestCase):
+  """Tests `_transform_and_write_tfr` function."""
+
+  def setUp(self):
+    self.pipeline = test_utils.get_test_pipeline()
+    self.raw_df = test_utils.get_raw_feature_df()
+    self.temp_dir_obj = tempfile.TemporaryDirectory(dir='/tmp', prefix='test-')
+    self.test_dir = self.temp_dir_obj.name
+    self.tfr_writer = functools.partial(
+        beam_pipeline._get_write_to_tfrecord, output_dir=self.test_dir,
+        compress='gzip', num_shards=2)
+    self.converter = tft.coders.CsvCoder(
+        constants.RAW_FEATURE_SPEC.keys(), constants.RAW_METADATA.schema)
+    self.transform_fn_path = './tfrecorder/test_data/sample_tfrecords'
+
+  def tearDown(self):
+    self.temp_dir_obj.cleanup()
+
+  def _get_dataset(self, pipeline, df):
+    """Returns dataset `PCollection`."""
+    return (pipeline
+            | beam.Create(df.values.tolist())
+            | beam.ParDo(beam_pipeline.ToCSVRows())
+            | beam.Map(self.converter.decode))
+
+  def test_train(self):
+    """Tests case where training data is passed."""
+
+    with self.pipeline as p:
+      with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):
+        df = self.raw_df[self.raw_df.split == 'TRAIN']
+        dataset = self._get_dataset(p, df)
+        transform_fn = (
+            beam_pipeline._transform_and_write_tfr(
+                dataset, self.tfr_writer, label='Train'))
+        _ = transform_fn | tft_beam.WriteTransformFn(self.test_dir)
+
+    self.assertTrue(
+        os.path.isdir(os.path.join(self.test_dir, 'transform_fn')))
+    self.assertTrue(
+        os.path.isdir(os.path.join(self.test_dir, 'transformed_metadata')))
+    self.assertTrue(glob.glob(os.path.join(self.test_dir, 'train*.gz')))
+    self.assertFalse(glob.glob(os.path.join(self.test_dir, 'validation*.gz')))
+    self.assertFalse(glob.glob(os.path.join(self.test_dir, 'test*.gz')))
+
+  def test_non_training(self):
+    """Tests case where dataset contains non-training (e.g. test) data."""
+
+    with self.pipeline as p:
+      with tft_beam.Context(temp_dir=os.path.join(self.test_dir, 'tmp')):
+
+        df = self.raw_df[self.raw_df.split == 'TEST']
+        dataset = self._get_dataset(p, df)
+        transform_fn = p | tft_beam.ReadTransformFn(self.transform_fn_path)
+        beam_pipeline._transform_and_write_tfr(
+            dataset, self.tfr_writer, transform_fn=transform_fn,
+            label='Test')
+
+    self.assertFalse(glob.glob(os.path.join(self.test_dir, 'train*.gz')))
+    self.assertFalse(glob.glob(os.path.join(self.test_dir, 'validation*.gz')))
+    self.assertTrue(glob.glob(os.path.join(self.test_dir, 'test*.gz')))
+
+
 if __name__ == '__main__':
   unittest.main()

tfrecorder/client.py

Lines changed: 3 additions & 1 deletion
@@ -73,7 +73,9 @@ def _validate_runner(
 
   if (runner == 'DataflowRunner') & (not tfrecorder_wheel):
     raise AttributeError(
-      'DataflowRunner requires a tfrecorder whl file for remote execution.')
+        'DataflowRunner requires a tfrecorder whl file for remote execution.')
+
+
 # def read_image_directory(dirpath) -> pd.DataFrame:
 #   """Reads image data from a directory into a Pandas DataFrame."""
 #
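
Note: the guard in `_validate_runner`, visible in the context lines above,
combines two booleans with the bitwise `&`. That happens to evaluate
correctly for bools, but the idiomatic, short-circuiting spelling is `and`.
A sketch with a simplified, assumed signature:

def _validate_runner(runner: str, tfrecorder_wheel: str = None) -> None:
  """Raises if a remote runner is requested without a wheel file."""
  if runner == 'DataflowRunner' and not tfrecorder_wheel:
    raise AttributeError(
        'DataflowRunner requires a tfrecorder whl file for remote execution.')

# _validate_runner('DataflowRunner')  # would raise AttributeError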

tfrecorder/test_utils.py

Lines changed: 21 additions & 1 deletion
@@ -26,11 +26,13 @@
 from apache_beam.testing import test_pipeline
 import pandas as pd
 
+from tfrecorder import constants
+
 
 TEST_DIR = 'tfrecorder/test_data'
 
 
-def get_test_df():
+def get_test_df() -> pd.DataFrame:
   """Gets a test dataframe that works with the data in test_data/."""
   return pd.read_csv(os.path.join(TEST_DIR, 'data.csv'))
 
@@ -41,6 +43,24 @@ def get_test_data() -> Dict[str, List[Any]]:
   return get_test_df().to_dict(orient='list')
 
 
+def get_raw_feature_df() -> pd.DataFrame:
+  """Returns test dataframe having raw feature spec schema."""
+
+  df = get_test_df()
+  df.drop(constants.IMAGE_URI_KEY, axis=1, inplace=True)
+  df['image_name'] = 'image_name'
+  df['image'] = 'image'
+  # Note: The TF Transform parser expects string values in the input. They
+  # will be parsed based on the raw feature spec that is passed together
+  # with the data.
+  df['image_height'] = '48'
+  df['image_width'] = '48'
+  df['image_channels'] = '3'
+  df = df[constants.RAW_FEATURE_SPEC.keys()]
+
+  return df
+
+
 def get_test_pipeline():
   """Gets a test pipeline."""
   return test_pipeline.TestPipeline(runner='DirectRunner')
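
Note: `get_raw_feature_df` stores even the numeric columns as strings
because tft's CsvCoder, used in the tests above, parses each field against
the raw feature spec. A hedged sketch of what such a spec could look like;
the project's actual constants.RAW_FEATURE_SPEC may differ in keys and
dtypes:

import tensorflow as tf

RAW_FEATURE_SPEC = {
    'split': tf.io.FixedLenFeature([], tf.string),
    'image_name': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
    'image_height': tf.io.FixedLenFeature([], tf.int64),
    'image_width': tf.io.FixedLenFeature([], tf.int64),
    'image_channels': tf.io.FixedLenFeature([], tf.int64),
}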

tfrecorder/types.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+# Lint as: python3
+
+# Copyright 2020 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom types."""
+
+from typing import Tuple
+
+from apache_beam.pvalue import PCollection
+from tensorflow_transform import beam as tft_beam
+
+
+BeamDatasetMetadata = tft_beam.tft_beam_io.beam_metadata_io.BeamDatasetMetadata
+TransformedMetadata = BeamDatasetMetadata
+TransformFn = Tuple[PCollection, TransformedMetadata]
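
Note: `TransformFn` models the pair handed around in beam_pipeline.py's
`_transform_and_write_tfr`: a PCollection pointing at the transform
artifacts plus the accompanying metadata. An illustrative signature using
the alias; the function name here is hypothetical:

from typing import Optional

from tfrecorder import types

def write_artifacts(transform_fn: types.TransformFn, output_dir: str) -> None:
  """Persists a previously computed transform_fn under output_dir."""
  ...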
