Skip to content

Commit 3e59f05

Browse files
aselletensorflower-gardener
authored andcommitted
Basic version of TensorFlow 1.0 Upgrade Script.
This script currently is minimally tested. It is a work in progress currently. Change: 144125570
1 parent e273097 commit 3e59f05

File tree

6 files changed

+975
-0
lines changed

6 files changed

+975
-0
lines changed

tensorflow/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ filegroup(
207207
"//tensorflow/tensorboard/lib/python:all_files",
208208
"//tensorflow/tensorboard/scripts:all_files",
209209
"//tensorflow/tools/common:all_files",
210+
"//tensorflow/tools/compatibility:all_files",
210211
"//tensorflow/tools/dist_test/server:all_files",
211212
"//tensorflow/tools/docker:all_files",
212213
"//tensorflow/tools/docker/notebooks:all_files",

tensorflow/tools/compatibility/BUILD

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
package(default_visibility = ["//tensorflow:internal"])
4+
5+
load(
6+
"//tensorflow:tensorflow.bzl",
7+
"tf_copts", # @unused
8+
"tf_cc_test", # @unused
9+
)
10+
11+
py_binary(
12+
name = "tf_upgrade",
13+
srcs = ["tf_upgrade.py"],
14+
srcs_version = "PY2AND3",
15+
)
16+
17+
py_test(
18+
name = "tf_upgrade_test",
19+
srcs = ["tf_upgrade_test.py"],
20+
srcs_version = "PY2AND3",
21+
deps = [
22+
"tf_upgrade",
23+
"//tensorflow:tensorflow_py",
24+
],
25+
)
26+
27+
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
28+
# py_test(
29+
# name = "test_file_v0_11",
30+
# size = "small",
31+
# srcs = ["testdata/test_file_v0_11.py"],
32+
# srcs_version = "PY2AND3",
33+
# deps = [
34+
# "//tensorflow:tensorflow_py",
35+
# ],
36+
# )
37+
38+
genrule(
39+
name = "generate_upgraded_file",
40+
testonly = 1,
41+
srcs = ["testdata/test_file_v0_11.py"],
42+
outs = [
43+
"test_file_v1_0.py",
44+
"report.txt",
45+
],
46+
cmd = ("$(location tf_upgrade)" +
47+
" --infile $(location testdata/test_file_v0_11.py)" +
48+
" --outfile $(location test_file_v1_0.py)" +
49+
" --reportfile $(location report.txt)"),
50+
tools = ["tf_upgrade"],
51+
)
52+
53+
py_test(
54+
name = "test_file_v1_0",
55+
size = "small",
56+
srcs = ["test_file_v1_0.py"],
57+
srcs_version = "PY2AND3",
58+
deps = [
59+
"//tensorflow:tensorflow_py",
60+
],
61+
)
62+
63+
exports_files(
64+
[
65+
"tf_upgrade.py",
66+
"testdata/test_file_v0_11.py",
67+
],
68+
)
69+
70+
# -----------------------------------------------------------------------------
71+
# Google-internal targets. These must be at the end for syncrepo.
72+
73+
filegroup(
74+
name = "all_files",
75+
srcs = glob(
76+
["**/*"],
77+
exclude = [
78+
"**/METADATA",
79+
"**/OWNERS",
80+
],
81+
),
82+
visibility = ["//tensorflow:__subpackages__"],
83+
)
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# TensorFlow Python API Upgrade Utility
2+
3+
This tool allows you to upgrade your existing TensorFlow Python scripts.
4+
This script can be run on a single Python file:
5+
6+
```
7+
tf_upgrade.py --infile foo.py --outfile foo-upgraded.py
8+
```
9+
10+
It will print a list of errors it finds that it can't fix. You can also run
11+
it on a directory tree:
12+
13+
```
14+
tf_upgrade.py --intree coolcode -outtree coolcode-upgraded
15+
```
16+
17+
In either case, it will also dump out a report e.g. which will detail changes
18+
e.g.:
19+
20+
```
21+
third_party/tensorflow/tools/compatibility/test_file_v0.11.py Line 125
22+
23+
Renamed keyword argument from `dim` to `axis`
24+
Renamed keyword argument from `squeeze_dims` to `axis`
25+
26+
Old: [[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(),
27+
~~~~ ~~~~~~~~~~~~~
28+
New: [[1, 2, 3]], axis=1), axis=[1]).eval(),
29+
~~~~~ ~~~~~
30+
```
31+
32+
## Caveats
33+
34+
- Don't update parts of your code manually before running this script. In
35+
particular, functions that have had reordered arguments like `tf.concat`,
36+
`tf.split` will cause the script to incorrectly add keyword arguments that
37+
mismap arguments.
38+
39+
- This script is not able to upgrade all functions. One notable example is
40+
`tf.reverse()` which has been changed to take a list of indices rather than
41+
a tensor of bools. If the script detects this, it will report this to stdout
42+
(and in the report), and you can fix it manually. For example if you have
43+
`tf.reverse(a, [False, True, True])` you will need to manually change it to
44+
`tf.reverse(a, [1, 2])`.
45+
46+
47+
48+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for tf upgrader."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
import shutil
21+
import tempfile
22+
import numpy as np
23+
import tensorflow as tf
24+
from tensorflow.python.framework import test_util
25+
from tensorflow.python.platform import test as test_lib
26+
27+
28+
class TestUpgrade(test_util.TensorFlowTestCase):
29+
"""Test various APIs that have been changed in 1.0.
30+
31+
This test will not run in current TensorFlow, but did run in 0.11.
32+
This file is intended to be converted by a genrule() that uses the converter
33+
so that a 1.0 compatible version of this file is generated. That is run as
34+
a unit test if the converter is successful.
35+
"""
36+
37+
def testArgRenames(self):
38+
with self.test_session():
39+
40+
a = [[1., 2., 3.], [4., 5., 6.]]
41+
b = [[True, False, False], [False, True, True]]
42+
dim0 = [1]
43+
dim1 = [1]
44+
45+
self.assertAllEqual(
46+
tf.reduce_any(
47+
b, reduction_indices=dim0).eval(), [True, True])
48+
self.assertAllEqual(
49+
tf.reduce_all(
50+
b, reduction_indices=[0]).eval(), [False, False, False])
51+
self.assertAllEqual(
52+
tf.reduce_all(
53+
b, reduction_indices=dim1).eval(), [False, False])
54+
self.assertAllEqual(
55+
tf.reduce_sum(
56+
a, reduction_indices=[1]).eval(), [6., 15.])
57+
self.assertAllEqual(
58+
tf.reduce_sum(
59+
a, reduction_indices=[0, 1]).eval(), 21.0)
60+
self.assertAllEqual(tf.reduce_sum(a, [0, 1]).eval(), 21.0)
61+
self.assertAllEqual(
62+
tf.reduce_prod(
63+
a, reduction_indices=[1]).eval(), [6., 120.])
64+
self.assertAllEqual(
65+
tf.reduce_prod(
66+
a, reduction_indices=[0, 1]).eval(), 720.0)
67+
self.assertAllEqual(tf.reduce_prod(a, [0, 1]).eval(), 720.0)
68+
self.assertAllEqual(
69+
tf.reduce_mean(
70+
a, reduction_indices=[1]).eval(), [2., 5.])
71+
self.assertAllEqual(
72+
tf.reduce_mean(
73+
a, reduction_indices=[0, 1]).eval(), 3.5)
74+
self.assertAllEqual(tf.reduce_mean(a, [0, 1]).eval(), 3.5)
75+
self.assertAllEqual(
76+
tf.reduce_min(
77+
a, reduction_indices=[1]).eval(), [1., 4.])
78+
self.assertAllEqual(
79+
tf.reduce_min(
80+
a, reduction_indices=[0, 1]).eval(), 1.0)
81+
self.assertAllEqual(tf.reduce_min(a, [0, 1]).eval(), 1.0)
82+
self.assertAllEqual(
83+
tf.reduce_max(
84+
a, reduction_indices=[1]).eval(), [3., 6.])
85+
self.assertAllEqual(
86+
tf.reduce_max(
87+
a, reduction_indices=[0, 1]).eval(), 6.0)
88+
self.assertAllEqual(tf.reduce_max(a, [0, 1]).eval(), 6.0)
89+
self.assertAllClose(tf.reduce_logsumexp(a, reduction_indices=[1]).eval(),
90+
[3.40760589, 6.40760612])
91+
self.assertAllClose(
92+
tf.reduce_logsumexp(a, reduction_indices=[0, 1]).eval(),
93+
6.45619344711)
94+
self.assertAllClose(
95+
tf.reduce_logsumexp(a, [0, 1]).eval(), 6.45619344711)
96+
self.assertAllEqual(
97+
tf.expand_dims([[1, 2], [3, 4]], dim=1).eval(),
98+
[[[1, 2]], [[3, 4]]])
99+
100+
def testArgMinMax(self):
101+
with self.test_session():
102+
self.assertAllEqual(
103+
tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(),
104+
[0, 2])
105+
self.assertAllEqual(
106+
tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=0).eval(),
107+
[0, 1, 1])
108+
self.assertAllEqual(
109+
tf.argmax([[1, 2, 3], [4, 1, 0]], dimension=1).eval(),
110+
[2, 0])
111+
self.assertAllEqual(
112+
tf.argmax([[1, 2, 3], [4, 1, 0]], dimension=0).eval(),
113+
[1, 0, 0])
114+
115+
def testExpandAndSqueeze(self):
116+
with self.test_session():
117+
118+
# TODO(aselle): sparse_split, sparse_reduce_sum,
119+
# sparse_reduce_sum_sparse, reduce_join
120+
a = [[1, 2, 3]]
121+
self.assertAllEqual(tf.expand_dims(tf.squeeze(a, [0]), 0).eval(),
122+
a)
123+
self.assertAllEqual(tf.squeeze(tf.expand_dims(a, 1), [1]).eval(),
124+
a)
125+
self.assertAllEqual(
126+
tf.expand_dims(
127+
tf.squeeze(
128+
[[1, 2, 3]], squeeze_dims=[0]), dim=0).eval(),
129+
a)
130+
self.assertAllEqual(
131+
tf.squeeze(
132+
tf.expand_dims(
133+
[[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(),
134+
a)
135+
136+
self.assertAllEqual(
137+
tf.squeeze(
138+
tf.expand_dims(
139+
[[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(),
140+
a)
141+
142+
def testArithmeticRenames(self):
143+
with self.test_session() as s:
144+
stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]])
145+
vals = s.run(stuff)
146+
self.assertAllEqual(vals,
147+
[[[1, 2], [4, 5]], [[3, 4], [6, 7]]])
148+
self.assertAllEqual(
149+
tf.neg(tf.mul(tf.add(1, 2), tf.sub(5, 3))).eval(),
150+
-6)
151+
self.assertAllEqual(
152+
s.run(tf.listdiff([1, 2, 3], [3, 3, 4]))[0], [1, 2])
153+
self.assertAllEqual(
154+
tf.list_diff([1, 2, 3], [3, 3, 4])[0].eval(), [1, 2])
155+
a = [[1., 2., 3.], [4., 5., 6.]]
156+
foo = np.where(np.less(a, 2), np.negative(a), a)
157+
self.assertAllEqual(
158+
tf.select(tf.less(a, 2), tf.neg(a), a).eval(),
159+
foo)
160+
self.assertAllEqual(
161+
tf.complex_abs(tf.constant(3 + 4.j)).eval(),
162+
5)
163+
# # TODO(aselle): (tf.batch_*)
164+
# ]
165+
166+
def testVariables(self):
167+
with self.test_session() as s:
168+
169+
# make some variables
170+
_ = [tf.Variable([1, 2, 3], dtype=tf.float32),
171+
tf.Variable([1, 2, 3], dtype=tf.int32)]
172+
s.run(tf.initialize_all_variables())
173+
_ = [v.name for v in tf.all_variables()]
174+
_ = [v.name for v in tf.local_variables()]
175+
176+
def testSummaries(self):
177+
with self.test_session() as s:
178+
var = tf.Variable([1, 2, 3], dtype=tf.float32)
179+
s.run(tf.initialize_all_variables())
180+
x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256))
181+
image = np.sin(x**2 + y**2) / np.sqrt(x**2 + y**2) * .5 + .5
182+
image = image[None, :, :, None]
183+
184+
# make a dummy sound
185+
freq = 440 # A = 440Hz
186+
sampling_frequency = 11000
187+
audio = np.sin(2 * np.pi * np.linspace(0, 1, sampling_frequency) * freq)
188+
audio = audio[None, :, None]
189+
test_dir = tempfile.mkdtemp()
190+
# test summaries
191+
writer = tf.train.SummaryWriter(test_dir)
192+
summaries = [
193+
tf.scalar_summary("scalar_var", var[0]),
194+
tf.scalar_summary("scalar_reduce_var", tf.reduce_sum(var)),
195+
tf.histogram_summary("var_histogram", var),
196+
tf.image_summary("sin_image", image),
197+
tf.audio_summary("sin_wave", audio, sampling_frequency),
198+
]
199+
run_summaries = s.run(summaries)
200+
writer.add_summary(s.run(tf.merge_summary(inputs=run_summaries)))
201+
# This is redundant, but we want to be able to rewrite the command
202+
writer.add_summary(s.run(tf.merge_all_summaries()))
203+
writer.close()
204+
shutil.rmtree(test_dir)
205+
206+
207+
if __name__ == "__main__":
208+
test_lib.main()

0 commit comments

Comments
 (0)