Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit aae9966

Browse files
authored
Merge pull request #75 from rshin/master
Add image augmentation for CIFAR-10
2 parents f95b7c9 + 24571fb commit aae9966

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tensor2tensor/models/common_layers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,17 @@ def image_augmentation(images, do_colors=False):
132132
return images
133133

134134

135+
def cifar_image_augmentation(images):
136+
"""Image augmentation suitable for CIFAR-10/100.
137+
138+
As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5)."""
139+
images = tf.image.resize_image_with_crop_or_pad(
140+
images, 40, 40)
141+
images = tf.random_crop(images, [32, 32, 3])
142+
images = tf.image.random_flip_left_right(images)
143+
return images
144+
145+
135146
def flatten4d3d(x):
136147
"""Flatten a 4d-tensor into a 3d-tensor by joining width and height."""
137148
xshape = tf.shape(x)

tensor2tensor/utils/data_reader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ def preprocess(img):
203203
lambda img=inputs: resize(img))
204204
else:
205205
examples["inputs"] = tf.to_int64(resize(inputs))
206+
207+
elif ("image_cifar10" in data_file_pattern
208+
and mode == tf.contrib.learn.ModeKeys.TRAIN):
209+
examples["inputs"] = common_layers.cifar_image_augmentation(
210+
examples["inputs"])
211+
206212
elif "audio" in data_file_pattern:
207213
# Reshape audio to proper shape
208214
sample_count = tf.to_int32(examples.pop("audio/sample_count"))

0 commit comments

Comments
 (0)