diff --git a/api/common/env.py b/api/common/env.py
new file mode 100644
index 0000000000..523273302b
--- /dev/null
+++ b/api/common/env.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import os
+
+
+def tf_use_xla():
+    return os.environ.get("TF_USE_XLA", "0").lower() in ("1", "true")
+
+
+def benchmark_use_feed_fetch():
+    return os.environ.get("BENCHMARK_USE_FEED_FETCH", "0").lower() in ("1", "true")
diff --git a/api/common/main.py b/api/common/main.py
index 3fab250920..24036a3cf8 100644
--- a/api/common/main.py
+++ b/api/common/main.py
@@ -22,6 +22,7 @@
 import collections
 import numpy as np
 
+from common import env
 from common import utils
 from common import system
 from common import api_param
@@ -228,7 +229,8 @@ def test_main_without_json(pd_obj=None,
 
     _adaptive_repeat(config, args)
     config.backward = args.backward
-    use_feed_fetch = True if args.task == "accuracy" else False
+    use_feed_fetch = (True if args.task == "accuracy" else
+                      env.benchmark_use_feed_fetch())
     feeder_adapter = None
 
     if _is_tensorflow_enabled(args, config):
diff --git a/api/common/tensorflow_api_benchmark.py b/api/common/tensorflow_api_benchmark.py
index 9e3a65f01b..c734a69779 100644
--- a/api/common/tensorflow_api_benchmark.py
+++ b/api/common/tensorflow_api_benchmark.py
@@ -22,6 +22,7 @@
 import numpy as np
 
 from common import special_op_list
+from . import env
 from . import utils
 from . import api_param
 from . import feeder
@@ -384,6 +385,8 @@ def _init_session(self, use_gpu):
         config = tf.compat.v1.ConfigProto()
         if use_gpu:
             config.gpu_options.allow_growth = self.allow_growth
+            if env.tf_use_xla():
+                config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
         else:
             # In default, TF use full cpu cores, but Paddle use one cpu core.
             # To make the same experiment, set TF use one cpu core as well.
diff --git a/api/tests/batch_norm.py b/api/tests/batch_norm.py
index 123688f8cb..3aaf63f1fd 100644
--- a/api/tests/batch_norm.py
+++ b/api/tests/batch_norm.py
@@ -22,9 +22,21 @@ def __init__(self):
     def init_from_json(self, filename, config_id=0, unknown_dim=16):
         super(BatchNormConfig, self).init_from_json(filename, config_id,
                                                     unknown_dim)
-        # TFBatchNorm does not have data_layout param, it only support NHWC format.
-        if self.data_layout == "NCHW":
-            self.run_tf = False
+        # tf's batch_norm does not have a data_format param; it only supports the NHWC format.
+        if self.data_layout == "NCHW" and len(self.input_shape) == 4:
+            self.feed_spec = [
+                {
+                    "range": [-1, 1],
+                    "permute": [0, 2, 3, 1]
+                },  # input
+                {
+                    "range": [-1, 1],
+                },  # scale
+                {
+                    "range": [-1, 1],
+                }  # bias
+            ]
+
         if len(self.input_shape) == 4:
             if self.data_layout == "NCHW":
                 self.num_channels = self.input_shape[1]
@@ -35,6 +47,17 @@ def init_from_json(self, filename, config_id=0, unknown_dim=16):
 
     def to_tensorflow(self):
         tf_config = super(BatchNormConfig, self).to_tensorflow()
+        if self.data_layout == "NCHW" and len(self.input_shape) == 4:
+            print(
+                "Warning:\n"
+                "  1. tf's batch_norm does not have a data_format param; it only supports the NHWC format. The benchmark test actually runs with the NHWC format.\n"
+            )
+            tf_config.data_layout = "NHWC"
+            tf_config.input_shape = [
+                self.input_shape[0], self.input_shape[2], self.input_shape[3],
+                self.input_shape[1]
+            ]
+
         if len(tf_config.input_shape) == 4:
             tf_config.axes = [0, 1, 2]
         else:
@@ -72,20 +95,20 @@ class TFBatchNorm(TensorflowAPIBenchmarkBase):
     def build_graph(self, config):
         input = self.variable(
             name='input', shape=config.input_shape, dtype=config.input_dtype)
+        bias = self.variable(
+            name='bias', shape=[config.num_channels], dtype=config.input_dtype)
         scale = self.variable(
             name='scale', shape=[config.num_channels], dtype=config.input_dtype)
-        bias = self.variable(
-            name='bias', shape=[config.num_channels], dtype=config.input_dtype)
 
         mean, var = tf.nn.moments(
             x=input, axes=config.axes, shift=None, keepdims=False)
 
         result = tf.nn.batch_normalization(
             x=input,
             mean=mean,
             variance=var,
-            offset=bias,
-            scale=scale,
+            offset=bias,  # beta
+            scale=scale,  # gamma
             variance_epsilon=config.epsilon)
 
         self.feed_list = [input, scale, bias]
diff --git a/api/tests/configs/batch_norm.json b/api/tests/configs/batch_norm.json
index 47313e9a10..ad981830dd 100644
--- a/api/tests/configs/batch_norm.json
+++ b/api/tests/configs/batch_norm.json
@@ -38,7 +38,8 @@
       "type": "bool",
       "value": "False"
     }
-  }
+  },
+  "atol": 1E-05
 }, {
   "op": "batch_norm",
   "param_info": {
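
Usage note (a sketch, not part of the patch): the two new switches are read from environment variables, so a run would typically be driven as below. The variable names and the env helpers come from the diff above; the example values and the assumption that the benchmark's api/ directory is on PYTHONPATH are illustrative only.

    import os

    # Illustrative values: enable the XLA JIT for the TensorFlow session and
    # leave feed/fetch off for speed runs (the "accuracy" task forces it on).
    os.environ["TF_USE_XLA"] = "1"
    os.environ["BENCHMARK_USE_FEED_FETCH"] = "0"

    from common import env  # assumes api/ is on PYTHONPATH

    assert env.tf_use_xla()
    assert not env.benchmark_use_feed_fetch()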