
Commit aa2c0b7

Author: Ryan Sepassi (committed)

T2T depends on TF 1.4+, daisy_chain_getter bug fix, some Eager-mode improvements/fixes

PiperOrigin-RevId: 177538074

1 parent 24c1fd7 · commit aa2c0b7

File tree

10 files changed, +998 -70 lines changed


docs/example_life.md
Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ hooks in the `Problem` class and the model's `HParams` object (typically
 registered in the model's file and specified by the `--hparams_set` flag).

 The entire input pipeline is implemented with the new `tf.data.Dataset` API
-(previously `tf.contrib.data.Dataset`).
+(previously `tf.data.Dataset`).

 The key function in the codebase for the input pipeline is
 [`data_reader.input_pipeline`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/data_reader.py).
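
As a side note on the renamed API: the sketch below shows the bare TF 1.4 `tf.data` pattern for reading TFRecord files. It is not t2t's actual `data_reader.input_pipeline`, and the shard name and feature spec are made up for illustration.

import tensorflow as tf

# Hypothetical shard name and feature spec, for illustration only.
filenames = ["/tmp/my_problem-train-00000-of-00001"]
dataset = tf.data.TFRecordDataset(filenames)  # was tf.contrib.data.TFRecordDataset before TF 1.4
dataset = dataset.map(
    lambda record: tf.parse_single_example(
        record,
        {"inputs": tf.VarLenFeature(tf.int64),
         "targets": tf.VarLenFeature(tf.int64)}))
features = dataset.make_one_shot_iterator().get_next()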

setup.py
Lines changed: 5 additions & 5 deletions

@@ -15,8 +15,7 @@
     package_data={
         'tensor2tensor.data_generators': ['test_data/*'],
         'tensor2tensor.visualization': [
-            'attention.js',
-            'TransformerVisualization.ipynb'
+            'attention.js', 'TransformerVisualization.ipynb'
         ],
     },
     scripts=[
@@ -34,8 +33,8 @@
         'six',
     ],
     extras_require={
-        'tensorflow': ['tensorflow>=1.3.0'],
-        'tensorflow_gpu': ['tensorflow-gpu>=1.3.0'],
+        'tensorflow': ['tensorflow>=1.4.0'],
+        'tensorflow_gpu': ['tensorflow-gpu>=1.4.0'],
         'tests': ['pytest', 'h5py', 'mock'],
     },
     classifiers=[
@@ -45,4 +44,5 @@
         'License :: OSI Approved :: Apache Software License',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
-    keywords='tensorflow machine learning',)
+    keywords='tensorflow machine learning',
+)
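
A practical note on the extras above: with pip's extras syntax, `pip install tensor2tensor[tensorflow]` now pulls in `tensorflow>=1.4.0`, and `pip install tensor2tensor[tensorflow_gpu]` pulls in `tensorflow-gpu>=1.4.0`, matching the commit's move to a TF 1.4+ requirement.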

tensor2tensor/data_generators/generator_utils.py
Lines changed: 13 additions & 0 deletions

@@ -125,6 +125,13 @@ def shard_filepath(fname, num_shards):
   ]


+def outputs_exist(filenames):
+  for out_fname in filenames:
+    out_fname = out_fname.replace(UNSHUFFLED_SUFFIX, "")
+    if tf.gfile.Exists(out_fname):
+      return out_fname
+
+
 def generate_files(generator, output_filenames, max_cases=None):
   """Generate cases from a generator and save as TFRecord files.

@@ -137,6 +144,9 @@ def generate_files(generator, output_filenames, max_cases=None):
     max_cases: maximum number of cases to get from the generator;
       if None (default), we use the generator until StopIteration is raised.
   """
+  if outputs_exist(output_filenames):
+    tf.logging.info("Skipping generator because outputs files exist")
+    return
   num_shards = len(output_filenames)
   writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
   counter, shard = 0, 0
@@ -440,6 +450,9 @@ def generate_dataset_and_shuffle(train_gen,


 def shuffle_dataset(filenames):
+  if outputs_exist(filenames):
+    tf.logging.info("Skipping shuffle because output files exist")
+    return
   tf.logging.info("Shuffling data...")
   for fname in filenames:
     records = read_records(fname)
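
The new `outputs_exist` check makes data generation idempotent: if the target TFRecord files (with any `UNSHUFFLED_SUFFIX` stripped) already exist, `generate_files` and `shuffle_dataset` return early instead of regenerating. A hedged usage sketch, with a toy generator and made-up shard paths:

from tensor2tensor.data_generators import generator_utils

def toy_generator():
  # The feature names here are illustrative, not required by generate_files.
  for i in range(10):
    yield {"inputs": [i], "targets": [i + 1]}

paths = ["/tmp/toy-train-00000-of-00002", "/tmp/toy-train-00001-of-00002"]
generator_utils.generate_files(toy_generator(), paths)  # writes two TFRecord shards
generator_utils.generate_files(toy_generator(), paths)  # now a no-op: logs "Skipping generator..." and returns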

tensor2tensor/data_generators/image.py
Lines changed: 26 additions & 14 deletions

@@ -42,6 +42,8 @@

 import tensorflow as tf

+from tensorflow.python.eager import context
+

 def resize_by_area(img, size):
   """image resize function used by quite a few image problems."""
@@ -463,6 +465,21 @@ def hparams(self, defaults, unused_model_hparams):
     p.target_space_id = 1


+def _encoded_images(images):
+  if context.in_eager_mode():
+    for image in images:
+      yield tf.image.encode_png(image).numpy()
+  else:
+    (width, height, channels) = images[0].shape
+    with tf.Graph().as_default():
+      image_t = tf.placeholder(dtype=tf.uint8, shape=(width, height, channels))
+      encoded_image_t = tf.image.encode_png(image_t)
+      with tf.Session() as sess:
+        for image in images:
+          enc_string = sess.run(encoded_image_t, feed_dict={image_t: image})
+          yield enc_string
+
+
 def image_generator(images, labels):
   """Generator for images that takes image and labels lists and creates pngs.
@@ -484,20 +501,15 @@ def image_generator(images, labels):
   """
   if not images:
     raise ValueError("Must provide some images for the generator.")
-  (width, height, channels) = images[0].shape
-  with tf.Graph().as_default():
-    image_t = tf.placeholder(dtype=tf.uint8, shape=(width, height, channels))
-    encoded_image_t = tf.image.encode_png(image_t)
-    with tf.Session() as sess:
-      for (image, label) in zip(images, labels):
-        enc_string = sess.run(encoded_image_t, feed_dict={image_t: image})
-        yield {
-            "image/encoded": [enc_string],
-            "image/format": ["png"],
-            "image/class/label": [int(label)],
-            "image/height": [height],
-            "image/width": [width]
-        }
+  width, height, _ = images[0].shape
+  for (enc_image, label) in zip(_encoded_images(images), labels):
+    yield {
+        "image/encoded": [enc_image],
+        "image/format": ["png"],
+        "image/class/label": [int(label)],
+        "image/height": [height],
+        "image/width": [width]
+    }


 # URLs and filenames for MNIST data.
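
With `_encoded_images` choosing between Eager and graph-mode PNG encoding, `image_generator` now works in both modes. A small sketch of feeding it toy data; the arrays and labels below are made up, while the dict keys come from the code above:

import numpy as np

from tensor2tensor.data_generators import image as image_module

# Two tiny random uint8 images and their (made-up) class labels.
images = [np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8) for _ in range(2)]
labels = [0, 1]
for example in image_module.image_generator(images, labels):
  # Each example carries image/encoded, image/format, image/class/label,
  # image/height and image/width, matching the dict yielded above.
  assert example["image/format"] == ["png"]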

tensor2tensor/data_generators/problem.py
Lines changed: 6 additions & 7 deletions

@@ -382,7 +382,7 @@ def dataset(self,
                                  data_filepattern)
     if shuffle_files or shuffle_files is None and is_training:
       random.shuffle(data_files)
-    dataset = tf.contrib.data.TFRecordDataset(data_files)
+    dataset = tf.data.TFRecordDataset(data_files)

     def decode_record(record):
       """Serialized Example to dict of <feature name, Tensor>."""
@@ -399,13 +399,12 @@ def _preprocess(example):
       self.maybe_copy_features(example)
       return example

-    dataset = dataset.map(decode_record, num_threads=num_threads)
+    dataset = dataset.map(decode_record, num_parallel_calls=num_threads)

     if preprocess:
-      dataset = dataset.map(
-          _preprocess,
-          num_threads=num_threads,
-          output_buffer_size=output_buffer_size)
+      dataset = dataset.map(_preprocess, num_parallel_calls=num_threads)
+      if output_buffer_size:
+        dataset = dataset.prefetch(output_buffer_size)

     return dataset

@@ -517,7 +516,7 @@ def define_shapes(example):
     dataset = self.dataset(
         mode=mode, data_dir=data_dir, num_threads=num_threads, hparams=hparams)
     dataset = dataset.map(
-        data_reader.cast_int64_to_int32, num_threads=num_threads)
+        data_reader.cast_int64_to_int32, num_parallel_calls=num_threads)
     if is_training:
       dataset = dataset.repeat(None)
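
The same migration in isolation, as a sketch rather than t2t's `Problem.dataset`: with TF 1.4, the `tf.contrib.data` map arguments `num_threads` and `output_buffer_size` become `num_parallel_calls` plus an explicit `prefetch` call.

import tensorflow as tf

ds = tf.data.Dataset.range(100)                     # was tf.contrib.data.Dataset.range(100)
ds = ds.map(lambda x: x * 2, num_parallel_calls=4)  # was map(..., num_threads=4)
ds = ds.prefetch(32)                                # was map(..., output_buffer_size=32)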

tensor2tensor/layers/rev_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def grad_fn(inputs, variables, outputs, output_grads):
399399
@common_layers.fn_with_custom_grad(grad_fn)
400400
def fn_with_recompute(*args):
401401
cached_vs.append(tf.get_variable_scope())
402-
# TODO(rsepassi): Rm conditional in TF 1.4
402+
# TODO(rsepassi): Rm conditional in TF 1.5
403403
if hasattr(tf.contrib.framework, "current_arg_scope"):
404404
cached_arg_scope.append(tf.contrib.framework.current_arg_scope())
405405
else:
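
The `hasattr` check here is a version guard around `tf.contrib.framework.current_arg_scope`: the TODO's removal target simply moves from TF 1.4 to TF 1.5, so the fallback branch stays in place for now even though the package itself now requires TF 1.4+.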
