Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions api/common/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc, six
import importlib


@six.add_metaclass(abc.ABCMeta)
class BenchmarkBase(object):
def __init__(self, framework, testing_mode):
self.name = self.__class__.__name__
self.feed_list = None
self.fetch_list = None
self._backward = False
self._framework = framework
self._testing_mode = testing_mode

@property
def backward(self):
return self._backward

def compute_flop_and_byte(self, config):
""" flop is used as a metric for op's performance and it is optional.
"""
return None, None

@abc.abstractmethod
def build_graph(self, config=None):
pass

@abc.abstractmethod
def variable(self, name, shape, dtype, value=None, stop_gradient=False):
pass

@abc.abstractmethod
def layers(self, api_name, module_name=None, **kwargs):
pass

@abc.abstractmethod
def append_gradients(self, targets, inputs):
pass

def get_running_stats(self, use_gpu, config, runtimes, walltimes=None):
try:
module_name = "torch" if self._framework == "pytorch" else self._framework
module = importlib.import_module(module_name)
version = module.__version__
except Exception:
version = "none"
print("Failed to call %s.__version__" % (self._framework))

stats = {
"framework": self._framework,
"version": version,
"name": self.name,
"device": "GPU" if use_gpu else "CPU",
"backward": self._backward,
"total": runtimes
}

if walltimes is not None:
stats["wall_time"] = walltimes

flop, byte = self.compute_flop_and_byte(config)
if flop is not None:
stats["flop"] = flop
if byte is not None:
stats["byte"] = byte
return stats
98 changes: 92 additions & 6 deletions api/common/feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,99 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import collections
import numpy as np

from . import paddle_api_benchmark as paddle_api
from . import tensorflow_api_benchmark as tensorflow_api

def _convert_paddle_dtype(dtype, to_string=True):
import paddle

def _trans(to_string, dtype_str, np_dtype):
dtype = dtype_str if to_string else np.dtype(np_dtype)
return dtype

if not isinstance(dtype, paddle.fluid.core.VarDesc.VarType):
raise TypeError("dtype is not of type fluid.core.VarDesc.VarType")
if dtype == paddle.fluid.core.VarDesc.VarType.FP32:
return _trans(to_string, "float32", np.float32)
elif dtype == paddle.fluid.core.VarDesc.VarType.FP64:
return _trans(to_string, "float64", np.float64)
elif dtype == paddle.fluid.core.VarDesc.VarType.FP16:
return _trans(to_string, "float16", np.float16)
elif dtype == paddle.fluid.core.VarDesc.VarType.INT32:
return _trans(to_string, "int32", np.int32)
elif dtype == paddle.fluid.core.VarDesc.VarType.INT16:
return _trans(to_string, "int16", np.int16)
elif dtype == paddle.fluid.core.VarDesc.VarType.INT64:
return _trans(to_string, "int64", np.int64)
elif dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
return _trans(to_string, "bool", np.bool)
elif dtype == paddle.fluid.core.VarDesc.VarType.INT16:
return _trans(to_string, "uint16", np.uint16)
elif dtype == paddle.fluid.core.VarDesc.VarType.UINT8:
return _trans(to_string, "uint8", np.uint8)
elif dtype == paddle.fluid.core.VarDesc.VarType.INT8:
return _trans(to_string, "int8", np.int8)
else:
raise ValueError("Unsupported dtype %s" % dtype)


def _convert_tensorflow_dtype(dtype, to_string=True):
import tensorflow as tf

def _trans(to_string, dtype_str, np_dtype):
dtype = dtype_str if to_string else np.dtype(np_dtype)
return dtype

if dtype == tf.float16:
# tf.float16: 16-bit half-precision floating-point.
return _trans(to_string, "float16", np.float16)
elif dtype == tf.float32:
# tf.float32: 32-bit single-precision floating-point.
return _trans(to_string, "float32", np.float32)
elif dtype == tf.float64:
# tf.float64: 64-bit double-precision floating-point.
return _trans(to_string, "float64", np.float64)
elif dtype == tf.int8:
# tf.int8: 8-bit signed integer.
return _trans(to_string, "int8", np.int8)
elif dtype == tf.uint8:
# tf.uint8: 8-bit unsigned integer.
return _trans(to_string, "uint8", np.uint8)
elif dtype == tf.uint16:
# tf.uint16: 16-bit unsigned integer.
return _trans(to_string, "uint16", np.uint16)
elif dtype == tf.uint32:
# tf.uint32: 32-bit unsigned integer.
return _trans(to_string, "uint32", np.uint32)
elif dtype == tf.uint64:
# tf.uint64: 64-bit unsigned integer.
return _trans(to_string, "uint64", np.uint64)
elif dtype == tf.int16:
# tf.int16: 16-bit signed integer.
return _trans(to_string, "int16", np.int16)
elif dtype == tf.int32:
# tf.int32: 32-bit signed integer.
return _trans(to_string, "int32", np.int32)
elif dtype == tf.int64:
# tf.int64: 64-bit signed integer.
return _trans(to_string, "int64", np.int64)
elif dtype == tf.bool:
# tf.bool: Boolean.
return _trans(to_string, "bool", np.bool)
else:
# tf.bfloat16: 16-bit truncated floating-point.
# tf.complex64: 64-bit single-precision complex.
# tf.complex128: 128-bit double-precision complex.
# tf.string: String.
# tf.qint8: Quantized 8-bit signed integer.
# tf.quint8: Quantized 8-bit unsigned integer.
# tf.qint16: Quantized 16-bit signed integer.
# tf.quint16: Quantized 16-bit unsigned integer.
# tf.qint32: Quantized 32-bit signed integer.
# tf.resource: Handle to a mutable resource.
# tf.variant: Values of arbitrary types.
raise ValueError("Unsupported dtype %s" % dtype)


def copy_feed_spec(feed_spec):
Expand Down Expand Up @@ -132,7 +218,7 @@ def to_paddle(self, feed_vars=None):

# Check shape and dtype
var_shape = var.shape
var_dtype = paddle_api.convert_dtype(
var_dtype = _convert_paddle_dtype(
var.dtype, to_string=True)
value = check_shape_and_dtype(var_shape, var_dtype, value)

Expand Down Expand Up @@ -173,7 +259,7 @@ def _to_other(self, target_framework, feed_vars=None):
var = feed_list[i]
var_shape = var.shape
if target_framework == "tensorflow":
var_dtype = tensorflow_api.convert_dtype(
var_dtype = _convert_tensorflow_dtype(
var.dtype, to_string=True)
value = check_shape_and_dtype(var_shape, var_dtype, value)

Expand Down
21 changes: 9 additions & 12 deletions api/common/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@

from common import utils
from common import system
from common import api_param
from common import special_op_list
from common import pytorch_api_benchmark
from common import paddle_dynamic_api_benchmark


def _check_gpu_device(use_gpu):
Expand Down Expand Up @@ -170,13 +167,13 @@ def _test_with_json_impl(filename, config_id, unknown_dim,
test_main_without_json(pd_obj, tf_obj, pd_dy_obj, torch_obj, config)


def _is_paddle_enabled(args, config):
def is_paddle_enabled(args, config):
if args.task == "accuracy" or args.framework in ["paddle", "both"]:
return True
return False


def _is_tensorflow_enabled(args, config):
def is_tensorflow_enabled(args, config):
if config.run_tf and args.testing_mode == "static":
if args.task == "accuracy" or args.framework in [
"tensorflow", "tf", "both"
Expand All @@ -185,7 +182,7 @@ def _is_tensorflow_enabled(args, config):
return False


def _is_torch_enabled(args, config):
def is_torch_enabled(args, config):
if config.run_torch and args.testing_mode == "dynamic":
if args.task == "accuracy" or args.framework in [
"torch", "pytorch", "both"
Expand Down Expand Up @@ -231,7 +228,7 @@ def test_main_without_json(pd_obj=None,
use_feed_fetch = True if args.task == "accuracy" else False

feeder_adapter = None
if _is_tensorflow_enabled(args, config):
if is_tensorflow_enabled(args, config):
assert tf_obj is not None, "TensorFlow object is None."
tf_config = config.to_tensorflow()
print(tf_config)
Expand All @@ -246,7 +243,7 @@ def test_main_without_json(pd_obj=None,
log_level=args.log_level,
config_params=config.to_string())

if _is_paddle_enabled(args, config) and args.testing_mode == "static":
if is_paddle_enabled(args, config) and args.testing_mode == "static":
assert pd_obj is not None, "Paddle object is None."
print(config)
pd_outputs, pd_stats = pd_obj.run(config, args, use_feed_fetch,
Expand All @@ -262,7 +259,7 @@ def test_main_without_json(pd_obj=None,
if pd_outputs == False:
sys.exit(1)

if _is_torch_enabled(args, config):
if is_torch_enabled(args, config):
assert torch_obj is not None, "PyTorch object is None."
import torch
try:
Expand All @@ -286,11 +283,11 @@ def test_main_without_json(pd_obj=None,
log_level=args.log_level,
config_params=config.to_string())

if _is_paddle_enabled(args, config) and args.testing_mode == "dynamic":
if is_paddle_enabled(args, config) and args.testing_mode == "dynamic":
assert pd_dy_obj is not None, "Paddle dynamic object is None."
print(config)
pd_dy_outputs, pd_dy_stats = pd_dy_obj.run(config, args,
feeder_adapter)
pd_dy_outputs, pd_dy_stats = pd_dy_obj.run(
config, args, feeder_adapter=feeder_adapter)

if args.task == "speed":
pd_dy_stats["gpu_time"] = args.gpu_time
Expand Down
Loading