From be24f0258a520a48555c9baec9d2f737ba1c2ca0 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 21 Jun 2023 14:44:18 -0400 Subject: [PATCH] Use Black to format Python files (#14161) Switch from yapf to Black to better align with the LLVM and broader Python community. I decided not to go with Pyink as it seems much less popular and differs in formatting style beyond indentation. - Reformat all python files outside of `third_party` with black. - Update the lint workflow to use black. This only considers files modified by the PR. - Delete old dotfiles. The command used to reformat all files at once: ```shell fd -e py --exclude third_party | xargs black ``` To learn more about Back, see: https://black.readthedocs.io/en/stable/ and https://github.com/psf/black. In the next PR, once the commit SHA of this PR is finalized, I plan to add this commit to `.git-blame-ignore-revs` to keep the blame history clean. Issue: https://github.com/openxla/iree/issues/14135 --- .github/workflows/lint.yml | 16 +- .pylintrc | 1 - .style.yapf | 4 - build_tools/bazel_to_cmake/bazel_to_cmake.py | 531 ++--- .../bazel_to_cmake_converter.py | 1766 ++++++++-------- .../bazel_to_cmake/bazel_to_cmake_targets.py | 452 ++-- build_tools/benchmarks/benchmark_helper.py | 294 +-- .../collect_compilation_statistics.py | 384 ++-- .../collect_compilation_statistics_test.py | 296 +-- .../benchmarks/common/android_device_utils.py | 79 +- .../benchmarks/common/benchmark_config.py | 147 +- .../common/benchmark_config_test.py | 194 +- .../benchmarks/common/benchmark_definition.py | 1062 +++++----- .../benchmarks/common/benchmark_driver.py | 512 ++--- .../common/benchmark_driver_test.py | 529 ++--- .../common/benchmark_presentation.py | 964 +++++---- .../benchmarks/common/benchmark_suite.py | 310 +-- .../benchmarks/common/benchmark_suite_test.py | 426 ++-- .../benchmarks/common/benchmark_thresholds.py | 163 +- .../benchmarks/common/common_arguments.py | 309 +-- .../common/common_arguments_test.py | 115 +- .../benchmarks/common/linux_device_utils.py | 75 +- .../common/linux_device_utils_test.py | 35 +- .../comparisons/common/benchmark_command.py | 236 +-- .../common/benchmark_command_factory.py | 23 +- .../comparisons/common/benchmark_runner.py | 81 +- .../benchmarks/comparisons/common/utils.py | 10 +- .../comparisons/mobilebert_fp32_commands.py | 404 ++-- .../comparisons/mobilebert_int8_commands.py | 370 ++-- .../benchmarks/comparisons/run_benchmarks.py | 378 ++-- .../benchmarks/comparisons/simple_commands.py | 462 +++-- .../benchmarks/diff_local_benchmarks.py | 171 +- .../benchmarks/export_benchmark_config.py | 368 ++-- .../export_benchmark_config_test.py | 469 +++-- .../benchmarks/generate_benchmark_comment.py | 551 ++--- .../benchmarks/post_benchmark_comment.py | 398 ++-- .../benchmarks/post_benchmark_comment_test.py | 308 ++- .../reporting/parse_shark_benchmarks.py | 605 +++--- .../reporting/parse_tflite_benchmarks.py | 748 +++---- .../benchmarks/run_benchmarks_on_android.py | 678 +++--- .../benchmarks/run_benchmarks_on_linux.py | 325 +-- .../upload_benchmarks_to_dashboard.py | 690 ++++--- build_tools/docker/get_image_name.py | 31 +- build_tools/docker/manage_images.py | 351 ++-- build_tools/docker/utils.py | 53 +- build_tools/github_actions/build_dist.py | 215 +- build_tools/github_actions/cmake_ci.py | 283 +-- build_tools/github_actions/configure_ci.py | 441 ++-- .../github_actions/configure_ci_test.py | 168 +- .../config/health_server/health_server.py | 96 +- .../runner/gcp/update_instance_groups.py | 688 ++++--- .../runner/gcp/update_runner_version.py | 117 +- .../runner/instance_deleter/main.py | 427 ++-- .../runner/instance_deleter/main_test.py | 970 ++++----- .../iree/adreno_benchmarks.py | 159 +- .../iree/armv8_a_benchmarks.py | 194 +- .../iree/benchmark_collections.py | 115 +- .../benchmark_suites/iree/cuda_benchmarks.py | 170 +- .../benchmark_suites/iree/mali_benchmarks.py | 251 +-- .../iree/module_execution_configs.py | 99 +- .../benchmark_suites/iree/riscv_benchmarks.py | 124 +- .../python/benchmark_suites/iree/utils.py | 33 +- .../benchmark_suites/iree/vmvx_benchmarks.py | 76 +- .../iree/vulkan_nvidia_benchmarks.py | 173 +- .../iree/x86_64_benchmarks.py | 217 +- build_tools/python/cmake_builder/rules.py | 319 +-- .../python/cmake_builder/rules_test.py | 213 +- .../python/e2e_model_tests/cmake_generator.py | 96 +- .../e2e_model_tests/run_module_utils.py | 21 +- .../e2e_model_tests/run_module_utils_test.py | 25 +- .../e2e_model_tests/test_definitions.py | 98 +- .../cmake_generator/iree_rule_generator.py | 353 ++-- .../iree_rule_generator_test.py | 364 ++-- .../cmake_generator/model_rule_generator.py | 62 +- .../model_rule_generator_test.py | 70 +- .../e2e_test_artifacts/iree_artifacts.py | 93 +- .../e2e_test_artifacts/iree_artifacts_test.py | 260 +-- .../e2e_test_artifacts/model_artifacts.py | 38 +- .../model_artifacts_test.py | 78 +- .../definitions/common_definitions.py | 310 +-- .../definitions/iree_definitions.py | 743 +++---- .../definitions/iree_definitions_test.py | 309 +-- .../e2e_test_framework/definitions/utils.py | 59 +- .../definitions/utils_test.py | 76 +- .../device_specs/device_collections.py | 71 +- .../device_specs/device_collections_test.py | 134 +- .../device_specs/gcp_specs.py | 6 +- .../device_specs/moto_edge_x30_specs.py | 3 +- .../device_specs/pixel_4_specs.py | 6 +- .../device_specs/pixel_6_pro_specs.py | 9 +- .../e2e_test_framework/models/jax_models.py | 44 +- .../e2e_test_framework/models/matmul.py | 60 +- .../e2e_test_framework/models/model_groups.py | 211 +- .../e2e_test_framework/models/tf_models.py | 77 +- .../models/tflite_models.py | 72 +- .../e2e_test_framework/models/torch_models.py | 74 +- .../python/e2e_test_framework/models/utils.py | 60 +- .../e2e_test_framework/models/utils_test.py | 102 +- .../e2e_test_framework/serialization.py | 449 ++-- .../e2e_test_framework/serialization_test.py | 235 ++- .../python/e2e_test_framework/unique_ids.py | 138 +- .../e2e_test_framework/unique_ids_test.py | 58 +- .../python/reporting/benchmark_comment.py | 19 +- .../python/reporting/common/html_utils.py | 234 ++- build_tools/scripts/add_license_header.py | 300 +-- build_tools/scripts/check_path_lengths.py | 113 +- build_tools/scripts/download_file.py | 143 +- .../scripts/generate_compilation_flagfile.py | 29 +- build_tools/scripts/generate_flagfile.py | 101 +- build_tools/scripts/generate_release_index.py | 99 +- build_tools/scripts/get_e2e_artifacts.py | 257 +-- .../scripts/git/check_submodule_init.py | 72 +- build_tools/scripts/integrate/bump_llvm.py | 177 +- build_tools/scripts/integrate/iree_modules.py | 69 +- build_tools/scripts/integrate/iree_utils.py | 280 +-- build_tools/scripts/integrate/patch_module.py | 123 +- build_tools/scripts/ir_to_markdown.py | 133 +- build_tools/scripts/local_web_server.py | 80 +- build_tools/scripts/update_tflite_models.py | 51 +- build_tools/scripts/utils.py | 68 +- build_tools/testing/gen_test_matrix.py | 488 ++--- .../testing/generate_cmake_e2e_model_tests.py | 34 +- ...generate_cmake_e2e_test_artifacts_suite.py | 85 +- .../python/iree/compiler/tools/binaries.py | 449 ++-- .../python/iree/compiler/tools/core.py | 504 ++--- .../python/iree/compiler/tools/debugging.py | 320 +-- .../compiler/tools/scripts/ireec/__main__.py | 10 +- .../bindings/python/iree/compiler/tools/tf.py | 259 +-- .../python/iree/compiler/tools/tflite.py | 187 +- .../python/test/ir/registration_test.py | 8 +- .../python/test/tools/compiler_core_test.py | 468 +++-- .../python/test/tools/compiler_tf_test.py | 108 +- .../python/test/tools/compiler_tflite_test.py | 144 +- .../test/tools/testdata/generate_tflite.py | 15 +- .../test/tools/testdata/generate_xla.py | 4 +- compiler/lit.cfg.py | 22 +- compiler/setup.py | 517 ++--- .../src/iree/compiler/API/generate_exports.py | 189 +- configure_bazel.py | 91 +- docs/api_docs/python/conf.py | 28 +- .../dispatch_profiler/batch_matmul.py | 411 ++-- experimental/dispatch_profiler/compile.py | 85 +- experimental/dispatch_profiler/dispatch.py | 91 +- experimental/dispatch_profiler/generator.py | 21 +- experimental/dispatch_profiler/launchers.py | 431 ++-- experimental/dispatch_profiler/library.py | 287 +-- experimental/dispatch_profiler/manifest.py | 475 ++--- experimental/dispatch_profiler/matmul.py | 1120 +++++----- experimental/dispatch_profiler/options.py | 447 ++-- .../dispatch_profiler/performance_report.py | 275 +-- experimental/dispatch_profiler/profiler.py | 117 +- .../dispatch_profiler/split_k_matmul.py | 126 +- experimental/web/testing/parse_test_list.py | 196 +- integrations/tensorflow/lit.cfg.py | 8 +- .../iree_tf/iree/tf/support/module_utils.py | 1817 +++++++++-------- .../iree/tf/support/module_utils_test.py | 152 +- .../iree_tf/iree/tf/support/tf_test_utils.py | 1132 +++++----- .../iree/tf/support/tf_test_utils_test.py | 139 +- .../iree_tf/iree/tf/support/tf_utils.py | 469 +++-- .../iree_tf/iree/tf/support/tf_utils_test.py | 146 +- .../iree_tf/iree/tf/support/trace_utils.py | 756 +++---- .../iree/tf/support/trace_utils_test.py | 244 +-- .../iree_tf/iree/tools/tf/__init__.py | 10 +- .../tf/scripts/iree_import_tf/__main__.py | 180 +- .../python_projects/iree_tf/setup.py | 30 +- .../iree_tflite/iree/tools/tflite/__init__.py | 10 +- .../scripts/iree_import_tflite/__main__.py | 97 +- .../python_projects/iree_tflite/setup.py | 28 +- .../update_tflite_model_documentation.py | 83 +- integrations/tensorflow/test/lit.cfg.py | 42 +- .../tensorflow/test/python/generate_runner.py | 107 +- .../python/iree_tf_tests/batch_norm_test.py | 73 +- .../iree_tf_tests/batch_to_space_nd_test.py | 43 +- .../python/iree_tf_tests/broadcast_to_test.py | 48 +- .../python/iree_tf_tests/broadcasting_test.py | 73 +- .../test/python/iree_tf_tests/concat_test.py | 134 +- .../python/iree_tf_tests/control_flow_test.py | 66 +- .../test/python/iree_tf_tests/conv_test.py | 240 ++- .../iree_tf_tests/conv_transpose_test.py | 144 +- .../python/iree_tf_tests/depth_conv_test.py | 194 +- .../iree_tf_tests/dynamic_mlp_relu_test.py | 94 +- .../python/iree_tf_tests/dynamic_mlp_test.py | 94 +- .../iree_tf_tests/einsum_dynamic_test.py | 248 ++- .../iree_tf_tests/einsum_static_test.py | 430 ++-- .../iree_tf_tests/einsum_vector_test.py | 182 +- .../test/python/iree_tf_tests/fft_test.py | 172 +- .../test/python/iree_tf_tests/fill_test.py | 48 +- .../test/python/iree_tf_tests/gather_test.py | 201 +- .../python/iree_tf_tests/image_resize_test.py | 62 +- .../python/iree_tf_tests/linspace_test.py | 54 +- .../python/iree_tf_tests/mandelbrot_test.py | 161 +- .../iree_tf_tests/matrix_ops_dynamic_test.py | 132 +- .../iree_tf_tests/matrix_ops_static_test.py | 145 +- .../iree_tf_tests/mobile_bert_squad_test.py | 102 +- .../test/python/iree_tf_tests/pytree_test.py | 42 +- .../iree_tf_tests/quantization_dyn_test.py | 42 +- .../python/iree_tf_tests/quantization_test.py | 40 +- .../test/python/iree_tf_tests/range_test.py | 55 +- .../python/iree_tf_tests/resource_ops_test.py | 64 +- .../python/iree_tf_tests/ring_buffer_test.py | 356 ++-- .../iree_tf_tests/scatter_update_test.py | 129 +- .../iree_tf_tests/simple_arithmetic_test.py | 77 +- .../iree_tf_tests/simple_stateful_test.py | 49 +- .../iree_tf_tests/sliding_window_test.py | 115 +- .../iree_tf_tests/space_to_batch_nd_test.py | 43 +- .../python/iree_tfl_tests/cartoon_gan_test.py | 13 +- .../iree_tfl_tests/east_text_detector_test.py | 44 +- .../test/python/iree_tfl_tests/gpt2_test.py | 55 +- .../iree_tfl_tests/imagenet_test_data.py | 14 +- .../python/iree_tfl_tests/mnasnet_test.py | 13 +- .../mobilebert_tf2_quant_test.py | 84 +- .../iree_tfl_tests/mobilenet_v1_test.py | 25 +- .../mobilenet_v3-large_uint8_test.py | 53 +- .../iree_tfl_tests/mobilenet_v3_test.py | 25 +- .../iree_tfl_tests/person_detect_test.py | 90 +- .../python/iree_tfl_tests/posenet_i8_test.py | 68 +- .../iree_tfl_tests/resnet_50_int8_test.py | 32 +- .../python/iree_tfl_tests/squad_test_data.py | 1203 ++++++++++- .../test/python/iree_tfl_tests/test_util.py | 276 +-- .../_iree_linalg_transform_ops_ext.py | 117 +- .../_iree_structured_transform_ops_ext.py | 92 +- .../iree-dialects/test/lit.cfg.py | 59 +- .../iree-dialects/test/python/smoketest.py | 6 +- .../python/iree/runtime/array_interop.py | 375 ++-- .../bindings/python/iree/runtime/benchmark.py | 159 +- .../bindings/python/iree/runtime/function.py | 546 ++--- .../scripts/iree_benchmark_module/__main__.py | 11 +- .../scripts/iree_benchmark_trace/__main__.py | 11 +- .../scripts/iree_run_module/__main__.py | 10 +- .../scripts/iree_run_trace/__main__.py | 10 +- .../scripts/iree_tracy_capture/__main__.py | 11 +- .../python/iree/runtime/system_api.py | 471 ++--- .../python/iree/runtime/system_setup.py | 138 +- .../bindings/python/iree/runtime/tracing.py | 261 +-- .../python/tests/array_interop_test.py | 300 +-- .../bindings/python/tests/devices_cli_test.py | 91 +- runtime/bindings/python/tests/flags_test.py | 15 +- .../bindings/python/tests/function_test.py | 1217 ++++++----- runtime/bindings/python/tests/hal_test.py | 265 +-- runtime/bindings/python/tests/package_test.py | 29 +- .../bindings/python/tests/py_module_test.py | 560 +++-- .../bindings/python/tests/system_api_test.py | 242 ++- .../python/tests/system_setup_test.py | 101 +- runtime/bindings/python/tests/vm_test.py | 377 ++-- .../bindings/python/tests/vm_types_test.py | 169 +- runtime/lit.cfg.py | 20 +- runtime/setup.py | 470 +++-- .../testdata/npy/generate_npy_files.py | 62 +- samples/colab/test_notebooks.py | 51 +- .../simple_io_sample/test/run_mock.py | 20 +- samples/lit.cfg.py | 20 +- .../py_custom_module/decode_secret_message.py | 202 +- samples/vision_inference/convert_image.py | 18 +- tests/e2e/matmul/generate_e2e_matmul_tests.py | 970 ++++----- tests/e2e/models/mnist_train_test/datasets.py | 101 +- .../mnist_train_test/generate_test_data.py | 308 +-- .../mnist_train_test/mnist_train_test.py | 160 +- tests/lit.cfg.py | 22 +- tools/lit.cfg.py | 28 +- tools/test/echo_npy.py | 12 +- 260 files changed, 30427 insertions(+), 26423 deletions(-) delete mode 100644 .pylintrc delete mode 100644 .style.yapf diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9d5ca9ccaf87..91f58d549adc 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -49,7 +49,7 @@ jobs: git diff --exit-code exit "${EXIT_CODE?}" - yapf: + black: runs-on: ubuntu-20.04 steps: - name: Checking out repository @@ -59,18 +59,18 @@ jobs: - name: Fetching Base Branch # We have to explicitly fetch the base branch as well run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}" - - name: Install yapf + - name: Install black run: | - python3 -m pip install yapf==0.30.0 - - name: Run format_diff.py with yapf + python3 -m pip install black==23.3 + - name: Check if modified files are formatted run: | - git diff -U0 "${GITHUB_BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i - git diff --exit-code + git diff "${GITHUB_BASE_REF?}" --name-only -- '*.py' ':!third_party' \ + | xargs black --check --diff --verbose - name: Instructions for fixing the above linting errors if: failure() run: | - printf "You can fix the lint errors above by running\n" - printf " git diff -U0 "${GITHUB_BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n" + printf "You can fix formatting by running 'black' on the modified python files:\n" + printf " git diff ${GITHUB_BASE_REF?} --name-only -- '*.py' ':!third_party' | xargs black\n" pytype: runs-on: ubuntu-20.04 diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index f7ad5d7958ee..000000000000 --- a/.pylintrc +++ /dev/null @@ -1 +0,0 @@ -indent-string=' ' diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 9ef1dc15ba62..000000000000 --- a/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake.py b/build_tools/bazel_to_cmake/bazel_to_cmake.py index 44e2b8ad2c3e..efb95292ed8c 100755 --- a/build_tools/bazel_to_cmake/bazel_to_cmake.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake.py @@ -59,8 +59,8 @@ repo_cfg = None EDIT_BLOCKING_PATTERN = re.compile( - r"bazel[\s_]*to[\s_]*cmake[\s_]*:?[\s_]*do[\s_]*not[\s_]*edit", - flags=re.IGNORECASE) + r"bazel[\s_]*to[\s_]*cmake[\s_]*:?[\s_]*do[\s_]*not[\s_]*edit", flags=re.IGNORECASE +) PRESERVE_ABOVE_TAG = "### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_ABOVE_THIS_LINE ###" PRESERVE_BELOW_TAG = "### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###" @@ -69,274 +69,299 @@ class Status(Enum): - UPDATED = 1 - NOOP = 2 - FAILED = 3 - SKIPPED = 4 - NO_BUILD_FILE = 5 + UPDATED = 1 + NOOP = 2 + FAILED = 3 + SKIPPED = 4 + NO_BUILD_FILE = 5 def parse_arguments(): - parser = argparse.ArgumentParser( - description="Bazel to CMake conversion helper.") - parser.add_argument("--preview", - help="Prints results instead of writing files", - action="store_true", - default=False) - parser.add_argument( - "--allow_partial_conversion", - help="Generates partial files, ignoring errors during conversion.", - action="store_true", - default=False) - parser.add_argument( - "--verbosity", - "-v", - type=int, - default=0, - help="Specify verbosity level where higher verbosity emits more logging." - " 0 (default): Only output errors and summary statistics." - " 1: Also output the name of each directory as it's being processed and" - " whether the directory is skipped." - " 2: Also output when conversion was successful.") - - # Specify only one of these (defaults to --root_dir=
). - group = parser.add_mutually_exclusive_group() - group.add_argument("--dir", - help="Converts the BUILD file in the given directory", - default=None) - default_root_dirs = (repo_cfg.DEFAULT_ROOT_DIRS if hasattr( - repo_cfg, "DEFAULT_ROOT_DIRS") else []) - group.add_argument("--root_dir", - nargs="+", - help="Converts all BUILD files under a root directory", - default=default_root_dirs) - - args = parser.parse_args() - - # --dir takes precedence over --root_dir. - # They are mutually exclusive, but the default value is still set. - if args.dir: - args.root_dir = None - - return args + parser = argparse.ArgumentParser(description="Bazel to CMake conversion helper.") + parser.add_argument( + "--preview", + help="Prints results instead of writing files", + action="store_true", + default=False, + ) + parser.add_argument( + "--allow_partial_conversion", + help="Generates partial files, ignoring errors during conversion.", + action="store_true", + default=False, + ) + parser.add_argument( + "--verbosity", + "-v", + type=int, + default=0, + help="Specify verbosity level where higher verbosity emits more logging." + " 0 (default): Only output errors and summary statistics." + " 1: Also output the name of each directory as it's being processed and" + " whether the directory is skipped." + " 2: Also output when conversion was successful.", + ) + + # Specify only one of these (defaults to --root_dir=
). + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--dir", help="Converts the BUILD file in the given directory", default=None + ) + default_root_dirs = ( + repo_cfg.DEFAULT_ROOT_DIRS if hasattr(repo_cfg, "DEFAULT_ROOT_DIRS") else [] + ) + group.add_argument( + "--root_dir", + nargs="+", + help="Converts all BUILD files under a root directory", + default=default_root_dirs, + ) + + args = parser.parse_args() + + # --dir takes precedence over --root_dir. + # They are mutually exclusive, but the default value is still set. + if args.dir: + args.root_dir = None + + return args def setup_environment(): - """Sets up some environment globals.""" - global repo_root - global repo_cfg - - # Scan up the directory tree for a repo config file. - check_dir = os.getcwd() - while not os.path.exists(os.path.join(check_dir, REPO_CFG_FILE)): - new_check_dir = os.path.dirname(check_dir) - if not new_check_dir or new_check_dir == check_dir: - print(f"ERROR: Could not find {REPO_CFG_FILE} in a parent directory " - f"of {os.getcwd()}") - sys.exit(1) - check_dir = new_check_dir - repo_root = check_dir - log(f"Using repo root {repo_root}") - - # Dynamically load the config file as a module. - orig_dont_write_bytecode = sys.dont_write_bytecode - sys.dont_write_bytecode = True # Don't generate __pycache__ dir - repo_cfg_path = os.path.join(repo_root, REPO_CFG_FILE) - spec = importlib.util.spec_from_file_location(REPO_CFG_MODULE_NAME, - repo_cfg_path) - if spec and spec.loader: - repo_cfg = importlib.util.module_from_spec(spec) - sys.modules[REPO_CFG_MODULE_NAME] = repo_cfg - spec.loader.exec_module(repo_cfg) - sys.dont_write_bytecode = orig_dont_write_bytecode - else: - print(f"INTERNAL ERROR: Could not evaluate {repo_cfg_path} as module") - sys.exit(1) + """Sets up some environment globals.""" + global repo_root + global repo_cfg + + # Scan up the directory tree for a repo config file. + check_dir = os.getcwd() + while not os.path.exists(os.path.join(check_dir, REPO_CFG_FILE)): + new_check_dir = os.path.dirname(check_dir) + if not new_check_dir or new_check_dir == check_dir: + print( + f"ERROR: Could not find {REPO_CFG_FILE} in a parent directory " + f"of {os.getcwd()}" + ) + sys.exit(1) + check_dir = new_check_dir + repo_root = check_dir + log(f"Using repo root {repo_root}") + + # Dynamically load the config file as a module. + orig_dont_write_bytecode = sys.dont_write_bytecode + sys.dont_write_bytecode = True # Don't generate __pycache__ dir + repo_cfg_path = os.path.join(repo_root, REPO_CFG_FILE) + spec = importlib.util.spec_from_file_location(REPO_CFG_MODULE_NAME, repo_cfg_path) + if spec and spec.loader: + repo_cfg = importlib.util.module_from_spec(spec) + sys.modules[REPO_CFG_MODULE_NAME] = repo_cfg + spec.loader.exec_module(repo_cfg) + sys.dont_write_bytecode = orig_dont_write_bytecode + else: + print(f"INTERNAL ERROR: Could not evaluate {repo_cfg_path} as module") + sys.exit(1) def repo_relpath(path): - return os.path.relpath(path, repo_root).replace("\\", "/") + return os.path.relpath(path, repo_root).replace("\\", "/") def log(string, *args, indent=0, **kwargs): - print(textwrap.indent(string, prefix=(indent * " ")), - *args, - **kwargs, - file=sys.stderr) - - -def convert_directories(directories, write_files, allow_partial_conversion, - verbosity): - failure_dirs = [] - skip_count = 0 - success_count = 0 - noop_count = 0 - for directory in directories: - status = convert_directory( - directory, - write_files=write_files, - allow_partial_conversion=allow_partial_conversion, - verbosity=verbosity) - if status == Status.FAILED: - failure_dirs.append(repo_relpath(directory)) - elif status == Status.SKIPPED: - skip_count += 1 - elif status == Status.UPDATED: - success_count += 1 - elif status == Status.NOOP: - noop_count += 1 - - log(f"{success_count} CMakeLists.txt files were updated, {skip_count} were" - f" skipped, and {noop_count} required no change.") - if failure_dirs: - log(f"ERROR: Encountered unexpected errors converting {len(failure_dirs)}" - " directories:") - log("\n".join(failure_dirs), indent=2) - sys.exit(1) - - -def convert_directory(directory_path, write_files, allow_partial_conversion, - verbosity): - if not os.path.isdir(directory_path): - raise FileNotFoundError(f"Cannot find directory '{directory_path}'") - - rel_dir_path = repo_relpath(directory_path) - if verbosity >= 1: - log(f"Processing {rel_dir_path}") - - # Scan for a BUILD file. - build_file_found = False - build_file_basenames = ["BUILD", "BUILD.bazel"] - for build_file_basename in build_file_basenames: - build_file_path = os.path.join(directory_path, build_file_basename) - - rel_build_file_path = repo_relpath(build_file_path) - if os.path.isfile(build_file_path): - build_file_found = True - break - cmakelists_file_path = os.path.join(directory_path, "CMakeLists.txt") - rel_cmakelists_file_path = repo_relpath(cmakelists_file_path) - - if not build_file_found: - return Status.NO_BUILD_FILE - - autogeneration_tag = f"Autogenerated by {repo_relpath(os.path.abspath(__file__))}" - - header = "\n".join(["#" * 80] + [ - l.ljust(79) + "#" for l in [ - f"# {autogeneration_tag} from", - f"# {rel_build_file_path}", - "#", - "# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary", - "# CMake-only content.", - "#", - f"# To disable autogeneration for this file entirely, delete this header.", - ] - ] + ["#" * 80]) - - old_lines = [] - possible_preserved_header_lines = [] - preserved_footer_lines = ["\n" + PRESERVE_BELOW_TAG + "\n"] - - # Read CMakeLists.txt and check if it has the auto-generated header. - found_preserve_below_tag = False - found_preserve_above_tag = False - if os.path.isfile(cmakelists_file_path): - found_autogeneration_tag = False - with open(cmakelists_file_path) as f: - old_lines = f.readlines() - - for line in old_lines: - if not found_preserve_above_tag: - possible_preserved_header_lines.append(line) - if not found_autogeneration_tag and autogeneration_tag in line: - found_autogeneration_tag = True - if not found_preserve_below_tag and PRESERVE_BELOW_TAG in line: - found_preserve_below_tag = True - elif not found_preserve_above_tag and PRESERVE_ABOVE_TAG in line: - found_preserve_above_tag = True - elif found_preserve_below_tag: - preserved_footer_lines.append(line) - if not found_autogeneration_tag: - if verbosity >= 1: - log(f"Skipped. Did not find autogeneration line.", indent=2) - return Status.SKIPPED - preserved_header = ("".join(possible_preserved_header_lines) - if found_preserve_above_tag else "") - preserved_footer = "".join(preserved_footer_lines) - - # Read the Bazel BUILD file and interpret it. - with open(build_file_path, "rt") as build_file: - build_file_contents = build_file.read() - if "bazel-to-cmake: skip" in build_file_contents: - return Status.SKIPPED - build_file_code = compile(build_file_contents, build_file_path, "exec") - try: - converted_build_file = bazel_to_cmake_converter.convert_build_file( - build_file_code, - repo_cfg=repo_cfg, - allow_partial_conversion=allow_partial_conversion) - except (NameError, NotImplementedError) as e: - log( - f"ERROR generating {rel_dir_path}.\n" - f"Missing a rule handler in bazel_to_cmake_converter.py?\n" - f"Reason: `{type(e).__name__}: {e}`", - indent=2) - return Status.FAILED - except KeyError as e: - log( - f"ERROR generating {rel_dir_path}.\n" - f"Missing a conversion in bazel_to_cmake_targets.py?\n" - f"Reason: `{type(e).__name__}: {e}`", - indent=2) - return Status.FAILED - converted_content = (preserved_header + header + converted_build_file + - preserved_footer) - if write_files: - with open(cmakelists_file_path, "wt") as cmakelists_file: - cmakelists_file.write(converted_content) - else: - print(converted_content, end="") - - if converted_content == "".join(old_lines): - if verbosity >= 2: - log(f"{rel_cmakelists_file_path} required no update", indent=2) - return Status.NOOP + print( + textwrap.indent(string, prefix=(indent * " ")), *args, **kwargs, file=sys.stderr + ) + + +def convert_directories(directories, write_files, allow_partial_conversion, verbosity): + failure_dirs = [] + skip_count = 0 + success_count = 0 + noop_count = 0 + for directory in directories: + status = convert_directory( + directory, + write_files=write_files, + allow_partial_conversion=allow_partial_conversion, + verbosity=verbosity, + ) + if status == Status.FAILED: + failure_dirs.append(repo_relpath(directory)) + elif status == Status.SKIPPED: + skip_count += 1 + elif status == Status.UPDATED: + success_count += 1 + elif status == Status.NOOP: + noop_count += 1 - if verbosity >= 2: log( - f"Successfly generated {rel_cmakelists_file_path}" - f" from {rel_build_file_path}", - indent=2) - return Status.UPDATED + f"{success_count} CMakeLists.txt files were updated, {skip_count} were" + f" skipped, and {noop_count} required no change." + ) + if failure_dirs: + log( + f"ERROR: Encountered unexpected errors converting {len(failure_dirs)}" + " directories:" + ) + log("\n".join(failure_dirs), indent=2) + sys.exit(1) + + +def convert_directory(directory_path, write_files, allow_partial_conversion, verbosity): + if not os.path.isdir(directory_path): + raise FileNotFoundError(f"Cannot find directory '{directory_path}'") + + rel_dir_path = repo_relpath(directory_path) + if verbosity >= 1: + log(f"Processing {rel_dir_path}") + + # Scan for a BUILD file. + build_file_found = False + build_file_basenames = ["BUILD", "BUILD.bazel"] + for build_file_basename in build_file_basenames: + build_file_path = os.path.join(directory_path, build_file_basename) + + rel_build_file_path = repo_relpath(build_file_path) + if os.path.isfile(build_file_path): + build_file_found = True + break + cmakelists_file_path = os.path.join(directory_path, "CMakeLists.txt") + rel_cmakelists_file_path = repo_relpath(cmakelists_file_path) + + if not build_file_found: + return Status.NO_BUILD_FILE + + autogeneration_tag = f"Autogenerated by {repo_relpath(os.path.abspath(__file__))}" + + header = "\n".join( + ["#" * 80] + + [ + l.ljust(79) + "#" + for l in [ + f"# {autogeneration_tag} from", + f"# {rel_build_file_path}", + "#", + "# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary", + "# CMake-only content.", + "#", + f"# To disable autogeneration for this file entirely, delete this header.", + ] + ] + + ["#" * 80] + ) + + old_lines = [] + possible_preserved_header_lines = [] + preserved_footer_lines = ["\n" + PRESERVE_BELOW_TAG + "\n"] + + # Read CMakeLists.txt and check if it has the auto-generated header. + found_preserve_below_tag = False + found_preserve_above_tag = False + if os.path.isfile(cmakelists_file_path): + found_autogeneration_tag = False + with open(cmakelists_file_path) as f: + old_lines = f.readlines() + + for line in old_lines: + if not found_preserve_above_tag: + possible_preserved_header_lines.append(line) + if not found_autogeneration_tag and autogeneration_tag in line: + found_autogeneration_tag = True + if not found_preserve_below_tag and PRESERVE_BELOW_TAG in line: + found_preserve_below_tag = True + elif not found_preserve_above_tag and PRESERVE_ABOVE_TAG in line: + found_preserve_above_tag = True + elif found_preserve_below_tag: + preserved_footer_lines.append(line) + if not found_autogeneration_tag: + if verbosity >= 1: + log(f"Skipped. Did not find autogeneration line.", indent=2) + return Status.SKIPPED + preserved_header = ( + "".join(possible_preserved_header_lines) if found_preserve_above_tag else "" + ) + preserved_footer = "".join(preserved_footer_lines) + + # Read the Bazel BUILD file and interpret it. + with open(build_file_path, "rt") as build_file: + build_file_contents = build_file.read() + if "bazel-to-cmake: skip" in build_file_contents: + return Status.SKIPPED + build_file_code = compile(build_file_contents, build_file_path, "exec") + try: + converted_build_file = bazel_to_cmake_converter.convert_build_file( + build_file_code, + repo_cfg=repo_cfg, + allow_partial_conversion=allow_partial_conversion, + ) + except (NameError, NotImplementedError) as e: + log( + f"ERROR generating {rel_dir_path}.\n" + f"Missing a rule handler in bazel_to_cmake_converter.py?\n" + f"Reason: `{type(e).__name__}: {e}`", + indent=2, + ) + return Status.FAILED + except KeyError as e: + log( + f"ERROR generating {rel_dir_path}.\n" + f"Missing a conversion in bazel_to_cmake_targets.py?\n" + f"Reason: `{type(e).__name__}: {e}`", + indent=2, + ) + return Status.FAILED + converted_content = ( + preserved_header + header + converted_build_file + preserved_footer + ) + if write_files: + with open(cmakelists_file_path, "wt") as cmakelists_file: + cmakelists_file.write(converted_content) + else: + print(converted_content, end="") + + if converted_content == "".join(old_lines): + if verbosity >= 2: + log(f"{rel_cmakelists_file_path} required no update", indent=2) + return Status.NOOP + + if verbosity >= 2: + log( + f"Successfly generated {rel_cmakelists_file_path}" + f" from {rel_build_file_path}", + indent=2, + ) + return Status.UPDATED def main(args): - """Runs Bazel to CMake conversion.""" - global repo_root - - write_files = not args.preview - - if args.root_dir: - for root_dir in args.root_dir: - root_directory_path = os.path.join(repo_root, root_dir) - log(f"Converting directory tree rooted at: {root_directory_path}") - convert_directories( - (root for root, _, _ in os.walk(root_directory_path)), - write_files=write_files, - allow_partial_conversion=args.allow_partial_conversion, - verbosity=args.verbosity) - elif args.dir: - convert_directories([os.path.join(repo_root, args.dir)], - write_files=write_files, - allow_partial_conversion=args.allow_partial_conversion, - verbosity=args.verbosity) - else: - log(f"ERROR: None of --root-dir, --dir arguments or DEFAULT_ROOT_DIRS in " - f".bazel_to_cmake.cfg.py: No conversion will be done") - sys.exit(1) + """Runs Bazel to CMake conversion.""" + global repo_root + + write_files = not args.preview + + if args.root_dir: + for root_dir in args.root_dir: + root_directory_path = os.path.join(repo_root, root_dir) + log(f"Converting directory tree rooted at: {root_directory_path}") + convert_directories( + (root for root, _, _ in os.walk(root_directory_path)), + write_files=write_files, + allow_partial_conversion=args.allow_partial_conversion, + verbosity=args.verbosity, + ) + elif args.dir: + convert_directories( + [os.path.join(repo_root, args.dir)], + write_files=write_files, + allow_partial_conversion=args.allow_partial_conversion, + verbosity=args.verbosity, + ) + else: + log( + f"ERROR: None of --root-dir, --dir arguments or DEFAULT_ROOT_DIRS in " + f".bazel_to_cmake.cfg.py: No conversion will be done" + ) + sys.exit(1) if __name__ == "__main__": - setup_environment() - main(parse_arguments()) + setup_environment() + main(parse_arguments()) diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index 023dc4d2abeb..f6d6cb543ff4 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py @@ -20,875 +20,921 @@ class BuildFileFunctions(object): - """Object passed to `exec` that has handlers for BUILD file functions.""" - - def __init__(self, *, converter: "Converter", - targets: bazel_to_cmake_targets.TargetConverter): - self._converter = converter - self._targets = targets - self._custom_initialize() - - def _custom_initialize(self): - pass - - # ------------------------------------------------------------------------- # - # Conversion utilities, written to reduce boilerplate and allow for reuse # - # between similar rule conversions (e.g. cc_library and cc_binary). # - # ------------------------------------------------------------------------- # - - def _expand_cmake_var(self, var): - return "${" + var + "}" - - def _convert_string_arg_block(self, name, value, quote=True): - # NAME - # "value" - if value is None: - return "" - if quote: - return f' {name}\n "{value}"\n' - else: - return f" {name}\n {value}\n" - - # Match Bazel's timeout values - # https://docs.bazel.build/versions/main/test-encyclopedia.html - _timeout_map = { - "short": 60, - "moderate": 300, - "long": 900, - "eternal": 3600, - } - - def _should_skip_target(self, tags=None, **kwargs): - if tags and "skip-bazel_to_cmake" in tags: - return True - return False - - def _convert_timeout_arg_block(self, name, value): - if value is None: - return "" - value = self._timeout_map[value] - return f" {name}\n {value}\n" - - def _convert_string_list_block(self, name, values, quote=True, sort=False): - # Note this deliberately distinguishes between an empty list (argument - # explicitly specified) and None (argument left as default). - if values is None: - return "" - - if sort: - values = sorted(values) - - if quote: - values_list = "\n".join([f' "{v}"' for v in values]) - else: - values_list = "\n".join([f" {v}" for v in values]) - - return f" {name}\n{values_list}\n" - - def _convert_option_block(self, option, option_value): - if option_value: - # Note: this is a truthiness check as well as an existence check, e.g. - # Bazel `testonly = False` will be handled correctly by this condition. - return f" {option}\n" - else: - return "" - - def _convert_target_block(self, name, target): - if target is None: - return "" - - # Convert the target name from its Bazel name to the corresponding CMake name. - # The specific conversion pattern depends on the target location. In general, - # Bazel targets are fully qualified and use slashes as delimiters, while - # targets in CMake are rooted on subtrees and use _ (with :: aliases). - cmake_aliases = self._targets.convert_target(target) - if len(cmake_aliases) != 1: - raise ValueError( - f"Expected a CMake alias from {target}. Got {cmake_aliases}") - target = cmake_aliases[0] - # Replace aliased :: target names with their explicit _ names. - target = target.replace("::", "_") - return self._convert_string_arg_block(name, target, quote=False) - - def _convert_srcs_block(self, srcs): - if not srcs: - return "" - # Bazel allows srcs to reference targets in the current package (leading - # ':') or in other packages (leading '//'). We map that to paths by: - # - dropping any leading ':' as in: - # ':generated.c' -> 'generated.c' - # - dropping any leading '//', and internal ':' by '/', as in: - # '//path/to/package:generated.c' -> 'path/to/package/generated.c' - srcs = [s.lstrip('//').lstrip(':').replace(':', '/') for s in srcs] - - return self._convert_string_list_block("SRCS", srcs, sort=True) - - def _convert_td_file_block(self, td_file): - if td_file.startswith("//iree"): - # TODO: This should be generalized for out of tree. - # Bazel `//iree/dir/td_file.td` - # -> CMake `${IREE_ROOT_DIR}/iree/dir/td_file.td - # Bazel `//iree/dir/IR:td_file.td` - # -> CMake `${IREE_ROOT_DIR}/iree/dir/IR/td_file.td - td_file = td_file.replace("//iree", "${IREE_ROOT_DIR}/iree") - td_file = td_file.replace(":", "/") - return self._convert_string_arg_block("TD_FILE", td_file) - - def _convert_tbl_outs_block(self, tbl_outs): - outs_list = "\n".join( - [f" {' '.join(flags)} {value}" for flags, value in tbl_outs]) - return f" OUTS\n{outs_list}\n" - - def _convert_tblgen_block(self, tblgen): - if tblgen.endswith("iree-tblgen"): - return " TBLGEN\n IREE\n" - else: - return "" - - def _convert_target(self, target): - """Returns a list of targets that correspond to the specified Bazel target. - Note that this must be a list because some targets have a one to many mapping. - """ - return self._targets.convert_target(target) - - def _convert_single_target(self, target): - replacement_targets = self._convert_target(target) - if len(replacement_targets) != 1: - raise RuntimeError(f"Expected single target replacement for {target}," - f" but got multiple: {replacement_targets}") - return replacement_targets[0] - - def _convert_single_target_block(self, name, target): - mapped_target = self._convert_single_target(target) - return self._convert_string_arg_block(name, mapped_target, quote=False) - - def _convert_target_list_block(self, list_name, targets): - if targets is None: - return "" - - # DEPS - # package1::target1 - # package1::target2 - # package2::target - targets = [self._convert_target(t) for t in targets] - # Flatten lists - targets = list(itertools.chain.from_iterable(targets)) - # Remove duplicates - targets = set(targets) - # Remove Falsey (None and empty string) values - targets = filter(None, targets) - - return self._convert_string_list_block(list_name, - targets, - sort=True, - quote=False) - - def _convert_includes_block(self, includes): - if not includes: - return "" - dirs = [] - for include in includes: - dirs.append("$" % - (include,)) - dirs.append("$" % - (include,)) - return self._convert_string_list_block("INCLUDES", - dirs, - sort=False, - quote=True) - - def _convert_unimplemented_function(self, function, details=""): - message = f"Unimplemented {function}: {details}" - if not self._converter.first_error: - self._converter.first_error = NotImplementedError(message) - # Avoid submitting the raw results from non-strict runs. These are still - # useful but are generally not safe to submit as-is. An upstream check - # prevents changes with this phrase from being submitted. - # Written as separate literals to avoid the check triggering here. - submit_blocker = "DO" + " NOT" + " SUBMIT." - self._converter.body += f"# {submit_blocker} {message}\n" - - # ------------------------------------------------------------------------- # - # Function handlers that convert BUILD definitions to CMake definitions. # - # # - # Names and signatures must match 1:1 with those expected in BUILD files # - # except that default values for optional arguments should generally be # - # `None` so we don't set them unnecessarily in the CMakeLists.txt files. # - # Each function that may be found in a BUILD file must be listed here. # - # ------------------------------------------------------------------------- # - - # Functions with no mapping to CMake. Just ignore these. - def alias(self, *args, **kwargs): - pass - - def bool_flag(self, *args, **kwargs): - pass - - def load(self, *args, **kwargs): - pass - - def package(self, **kwargs): - pass - - def iree_build_test(self, **kwargs): - pass - - def test_suite(self, **kwargs): - pass - - def config_setting(self, **kwargs): - pass - - def exports_files(self, *args, **kwargs): - pass - - def iree_td_library(self, *args, **kwargs): - pass - - # Technically we could do something with a CMake equivalent but we have no use - # case. - def py_binary(self, *args, **kwargs): - pass - - def filegroup(self, name, **kwargs): - # Not implemented, but allowed for Bazel-only uses, such as declaring internal - # headers and other kinds of files that Bazel enforces but CMake doesn't care - # about. If we ever need to implement this, this might be a no-op, or may - # want to evaluate the srcs attribute and pass them along to any targets - # that depend on the filegroup. - # Cross-package dependencies and complicated globs could be hard to handle. - pass - - def sh_binary(self, name, **kwargs): - if self._should_skip_target(**kwargs): - return - self._convert_unimplemented_function("sh_binary", name) - - def enforce_glob(self, files, **kwargs): - return files - - def glob(self, include, exclude=None, exclude_directories=1): - if exclude_directories != 1: - self._convert_unimplemented_function("glob", "with exclude_directories") - if exclude is None: - exclude = [] - - glob_vars = [] - for pattern in include: - if "**" in pattern: - # bazel's glob has some specific restrictions about crossing package - # boundaries. We have no uses of recursive globs. Rather than try to - # emulate them or silently give different behavior, just error out. - # See https://docs.bazel.build/versions/master/be/functions.html#glob - raise NotImplementedError("Recursive globs not supported") - # Bazel `*.mlir` glob -> CMake Variable `_GLOB_X_MLIR` - var = "_GLOB_" + pattern.replace("*", "X").replace(".", "_").upper() - glob_vars.append(var) - self._converter.body += ( - f"file(GLOB {var} LIST_DIRECTORIES false" - f" RELATIVE {self._expand_cmake_var('CMAKE_CURRENT_SOURCE_DIR')}" - f" CONFIGURE_DEPENDS {pattern})\n") - for pattern in exclude: - if "**" in pattern: - raise NotImplementedError("Recursive globs not supported") - exclude_var = ("_GLOB_" + - pattern.replace("*", "X").replace(".", "_").upper()) - self._converter.body += ( - f"file(GLOB {exclude_var} LIST_DIRECTORIES false" - f" RELATIVE {self._expand_cmake_var('CMAKE_CURRENT_SOURCE_DIR')}" - f" CONFIGURE_DEPENDS {pattern})\n") - for glob_var in glob_vars: + """Object passed to `exec` that has handlers for BUILD file functions.""" + + def __init__( + self, *, converter: "Converter", targets: bazel_to_cmake_targets.TargetConverter + ): + self._converter = converter + self._targets = targets + self._custom_initialize() + + def _custom_initialize(self): + pass + + # ------------------------------------------------------------------------- # + # Conversion utilities, written to reduce boilerplate and allow for reuse # + # between similar rule conversions (e.g. cc_library and cc_binary). # + # ------------------------------------------------------------------------- # + + def _expand_cmake_var(self, var): + return "${" + var + "}" + + def _convert_string_arg_block(self, name, value, quote=True): + # NAME + # "value" + if value is None: + return "" + if quote: + return f' {name}\n "{value}"\n' + else: + return f" {name}\n {value}\n" + + # Match Bazel's timeout values + # https://docs.bazel.build/versions/main/test-encyclopedia.html + _timeout_map = { + "short": 60, + "moderate": 300, + "long": 900, + "eternal": 3600, + } + + def _should_skip_target(self, tags=None, **kwargs): + if tags and "skip-bazel_to_cmake" in tags: + return True + return False + + def _convert_timeout_arg_block(self, name, value): + if value is None: + return "" + value = self._timeout_map[value] + return f" {name}\n {value}\n" + + def _convert_string_list_block(self, name, values, quote=True, sort=False): + # Note this deliberately distinguishes between an empty list (argument + # explicitly specified) and None (argument left as default). + if values is None: + return "" + + if sort: + values = sorted(values) + + if quote: + values_list = "\n".join([f' "{v}"' for v in values]) + else: + values_list = "\n".join([f" {v}" for v in values]) + + return f" {name}\n{values_list}\n" + + def _convert_option_block(self, option, option_value): + if option_value: + # Note: this is a truthiness check as well as an existence check, e.g. + # Bazel `testonly = False` will be handled correctly by this condition. + return f" {option}\n" + else: + return "" + + def _convert_target_block(self, name, target): + if target is None: + return "" + + # Convert the target name from its Bazel name to the corresponding CMake name. + # The specific conversion pattern depends on the target location. In general, + # Bazel targets are fully qualified and use slashes as delimiters, while + # targets in CMake are rooted on subtrees and use _ (with :: aliases). + cmake_aliases = self._targets.convert_target(target) + if len(cmake_aliases) != 1: + raise ValueError( + f"Expected a CMake alias from {target}. Got {cmake_aliases}" + ) + target = cmake_aliases[0] + # Replace aliased :: target names with their explicit _ names. + target = target.replace("::", "_") + return self._convert_string_arg_block(name, target, quote=False) + + def _convert_srcs_block(self, srcs): + if not srcs: + return "" + # Bazel allows srcs to reference targets in the current package (leading + # ':') or in other packages (leading '//'). We map that to paths by: + # - dropping any leading ':' as in: + # ':generated.c' -> 'generated.c' + # - dropping any leading '//', and internal ':' by '/', as in: + # '//path/to/package:generated.c' -> 'path/to/package/generated.c' + srcs = [s.lstrip("//").lstrip(":").replace(":", "/") for s in srcs] + + return self._convert_string_list_block("SRCS", srcs, sort=True) + + def _convert_td_file_block(self, td_file): + if td_file.startswith("//iree"): + # TODO: This should be generalized for out of tree. + # Bazel `//iree/dir/td_file.td` + # -> CMake `${IREE_ROOT_DIR}/iree/dir/td_file.td + # Bazel `//iree/dir/IR:td_file.td` + # -> CMake `${IREE_ROOT_DIR}/iree/dir/IR/td_file.td + td_file = td_file.replace("//iree", "${IREE_ROOT_DIR}/iree") + td_file = td_file.replace(":", "/") + return self._convert_string_arg_block("TD_FILE", td_file) + + def _convert_tbl_outs_block(self, tbl_outs): + outs_list = "\n".join( + [f" {' '.join(flags)} {value}" for flags, value in tbl_outs] + ) + return f" OUTS\n{outs_list}\n" + + def _convert_tblgen_block(self, tblgen): + if tblgen.endswith("iree-tblgen"): + return " TBLGEN\n IREE\n" + else: + return "" + + def _convert_target(self, target): + """Returns a list of targets that correspond to the specified Bazel target. + Note that this must be a list because some targets have a one to many mapping. + """ + return self._targets.convert_target(target) + + def _convert_single_target(self, target): + replacement_targets = self._convert_target(target) + if len(replacement_targets) != 1: + raise RuntimeError( + f"Expected single target replacement for {target}," + f" but got multiple: {replacement_targets}" + ) + return replacement_targets[0] + + def _convert_single_target_block(self, name, target): + mapped_target = self._convert_single_target(target) + return self._convert_string_arg_block(name, mapped_target, quote=False) + + def _convert_target_list_block(self, list_name, targets): + if targets is None: + return "" + + # DEPS + # package1::target1 + # package1::target2 + # package2::target + targets = [self._convert_target(t) for t in targets] + # Flatten lists + targets = list(itertools.chain.from_iterable(targets)) + # Remove duplicates + targets = set(targets) + # Remove Falsey (None and empty string) values + targets = filter(None, targets) + + return self._convert_string_list_block( + list_name, targets, sort=True, quote=False + ) + + def _convert_includes_block(self, includes): + if not includes: + return "" + dirs = [] + for include in includes: + dirs.append( + "$" % (include,) + ) + dirs.append( + "$" % (include,) + ) + return self._convert_string_list_block("INCLUDES", dirs, sort=False, quote=True) + + def _convert_unimplemented_function(self, function, details=""): + message = f"Unimplemented {function}: {details}" + if not self._converter.first_error: + self._converter.first_error = NotImplementedError(message) + # Avoid submitting the raw results from non-strict runs. These are still + # useful but are generally not safe to submit as-is. An upstream check + # prevents changes with this phrase from being submitted. + # Written as separate literals to avoid the check triggering here. + submit_blocker = "DO" + " NOT" + " SUBMIT." + self._converter.body += f"# {submit_blocker} {message}\n" + + # ------------------------------------------------------------------------- # + # Function handlers that convert BUILD definitions to CMake definitions. # + # # + # Names and signatures must match 1:1 with those expected in BUILD files # + # except that default values for optional arguments should generally be # + # `None` so we don't set them unnecessarily in the CMakeLists.txt files. # + # Each function that may be found in a BUILD file must be listed here. # + # ------------------------------------------------------------------------- # + + # Functions with no mapping to CMake. Just ignore these. + def alias(self, *args, **kwargs): + pass + + def bool_flag(self, *args, **kwargs): + pass + + def load(self, *args, **kwargs): + pass + + def package(self, **kwargs): + pass + + def iree_build_test(self, **kwargs): + pass + + def test_suite(self, **kwargs): + pass + + def config_setting(self, **kwargs): + pass + + def exports_files(self, *args, **kwargs): + pass + + def iree_td_library(self, *args, **kwargs): + pass + + # Technically we could do something with a CMake equivalent but we have no use + # case. + def py_binary(self, *args, **kwargs): + pass + + def filegroup(self, name, **kwargs): + # Not implemented, but allowed for Bazel-only uses, such as declaring internal + # headers and other kinds of files that Bazel enforces but CMake doesn't care + # about. If we ever need to implement this, this might be a no-op, or may + # want to evaluate the srcs attribute and pass them along to any targets + # that depend on the filegroup. + # Cross-package dependencies and complicated globs could be hard to handle. + pass + + def sh_binary(self, name, **kwargs): + if self._should_skip_target(**kwargs): + return + self._convert_unimplemented_function("sh_binary", name) + + def enforce_glob(self, files, **kwargs): + return files + + def glob(self, include, exclude=None, exclude_directories=1): + if exclude_directories != 1: + self._convert_unimplemented_function("glob", "with exclude_directories") + if exclude is None: + exclude = [] + + glob_vars = [] + for pattern in include: + if "**" in pattern: + # bazel's glob has some specific restrictions about crossing package + # boundaries. We have no uses of recursive globs. Rather than try to + # emulate them or silently give different behavior, just error out. + # See https://docs.bazel.build/versions/master/be/functions.html#glob + raise NotImplementedError("Recursive globs not supported") + # Bazel `*.mlir` glob -> CMake Variable `_GLOB_X_MLIR` + var = "_GLOB_" + pattern.replace("*", "X").replace(".", "_").upper() + glob_vars.append(var) + self._converter.body += ( + f"file(GLOB {var} LIST_DIRECTORIES false" + f" RELATIVE {self._expand_cmake_var('CMAKE_CURRENT_SOURCE_DIR')}" + f" CONFIGURE_DEPENDS {pattern})\n" + ) + for pattern in exclude: + if "**" in pattern: + raise NotImplementedError("Recursive globs not supported") + exclude_var = "_GLOB_" + pattern.replace("*", "X").replace(".", "_").upper() + self._converter.body += ( + f"file(GLOB {exclude_var} LIST_DIRECTORIES false" + f" RELATIVE {self._expand_cmake_var('CMAKE_CURRENT_SOURCE_DIR')}" + f" CONFIGURE_DEPENDS {pattern})\n" + ) + for glob_var in glob_vars: + self._converter.body += f"list(REMOVE_ITEM {glob_var} {self._expand_cmake_var(exclude_var)})\n" + return [self._expand_cmake_var(var) for var in glob_vars] + + # TODO(gcmn) implement these types of functions in a less hard-coded way + def platform_trampoline_deps(self, basename, path="base"): + return [f"//{path}/internal:{basename}_internal"] + + def select(self, d): + self._convert_unimplemented_function("select", str(d)) + return d["//conditions:default"] + + def defaulting_select(self, selector): + """Defined in build_defs.oss.bzl as a scoped alternative to select.""" + default_value = selector.get("//conditions:default") + if default_value is None: + raise ValueError("bazel_to_cmake can only convert selects with a default") + return default_value + + def cc_library( + self, + name, + hdrs=None, + textual_hdrs=None, + srcs=None, + copts=None, + defines=None, + data=None, + deps=None, + testonly=None, + linkopts=None, + includes=None, + **kwargs, + ): + if self._should_skip_target(**kwargs): + return + if linkopts: + self._convert_unimplemented_function("linkopts") + name_block = self._convert_string_arg_block("NAME", name, quote=False) + hdrs_block = self._convert_string_list_block("HDRS", hdrs, sort=True) + textual_hdrs_block = self._convert_string_list_block( + "TEXTUAL_HDRS", textual_hdrs, sort=True + ) + srcs_block = self._convert_srcs_block(srcs) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + testonly_block = self._convert_option_block("TESTONLY", testonly) + includes_block = self._convert_includes_block(includes) + + self._converter.body += ( + f"iree_cc_library(\n" + f"{name_block}" + f"{copts_block}" + f"{hdrs_block}" + f"{textual_hdrs_block}" + f"{srcs_block}" + f"{data_block}" + f"{deps_block}" + f"{defines_block}" + f"{testonly_block}" + f"{includes_block}" + f" PUBLIC\n)\n\n" + ) + + def iree_compiler_register_plugin(self, plugin_id, target): + plugin_id_block = self._convert_string_arg_block( + "PLUGIN_ID", plugin_id, quote=False + ) + target_block = self._convert_single_target_block("TARGET", target) + self._converter.body += ( + f"iree_compiler_register_plugin(\n" + f"{plugin_id_block}" + f"{target_block}" + f")\n\n" + ) + + def cc_test( + self, + name, + hdrs=None, + srcs=None, + copts=None, + defines=None, + data=None, + deps=None, + timeout=None, + args=None, + tags=None, + includes=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + hdrs_block = self._convert_string_list_block("HDRS", hdrs, sort=True) + srcs_block = self._convert_srcs_block(srcs) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + args_block = self._convert_string_list_block("ARGS", args) + labels_block = self._convert_string_list_block("LABELS", tags) + timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) + includes_block = self._convert_includes_block(includes) + + self._converter.body += ( + f"iree_cc_test(\n" + f"{name_block}" + f"{hdrs_block}" + f"{srcs_block}" + f"{copts_block}" + f"{defines_block}" + f"{data_block}" + f"{deps_block}" + f"{args_block}" + f"{labels_block}" + f"{timeout_block}" + f"{includes_block}" + f")\n\n" + ) + + def cc_binary( + self, + name, + srcs=None, + data=None, + deps=None, + copts=None, + defines=None, + linkopts=None, + testonly=None, + includes=None, + **kwargs, + ): + if self._should_skip_target(**kwargs): + return + if linkopts: + self._convert_unimplemented_function("linkopts") + name_block = self._convert_string_arg_block("NAME", name, quote=False) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + srcs_block = self._convert_srcs_block(srcs) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + testonly_block = self._convert_option_block("TESTONLY", testonly) + includes_block = self._convert_includes_block(includes) + + self._converter.body += ( + f"iree_cc_binary(\n" + f"{name_block}" + f"{srcs_block}" + f"{copts_block}" + f"{defines_block}" + f"{data_block}" + f"{deps_block}" + f"{testonly_block}" + f"{includes_block}" + f")\n\n" + ) + + def c_embed_data( + self, + name, + srcs, + c_file_output, + h_file_output, + testonly=None, + strip_prefix=None, + flatten=None, + identifier=None, + deps=None, + **kwargs, + ): + if self._should_skip_target(**kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + c_file_output_block = self._convert_string_arg_block( + "C_FILE_OUTPUT", c_file_output + ) + h_file_output_block = self._convert_string_arg_block( + "H_FILE_OUTPUT", h_file_output + ) + testonly_block = self._convert_option_block("TESTONLY", testonly) + identifier_block = self._convert_string_arg_block("IDENTIFIER", identifier) + flatten_block = self._convert_option_block("FLATTEN", flatten) + deps_block = self._convert_target_list_block("DEPS", deps) + + self._converter.body += ( + f"iree_c_embed_data(\n" + f"{name_block}" + f"{srcs_block}" + f"{deps_block}" + f"{c_file_output_block}" + f"{h_file_output_block}" + f"{identifier_block}" + f"{testonly_block}" + f"{flatten_block}" + f" PUBLIC\n)\n\n" + ) + + def iree_bitcode_library(self, name, arch, srcs, internal_hdrs=None, copts=None): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + arch_block = self._convert_string_arg_block("ARCH", arch, quote=False) + srcs_block = self._convert_srcs_block(srcs) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + + self._converter.body += ( + f"iree_bitcode_library(\n" + f"{name_block}" + f"{arch_block}" + f"{srcs_block}" + f"{copts_block}" + f")\n\n" + ) + + def iree_link_bitcode(self, name, bitcode_files): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + bitcode_files_block = self._convert_srcs_block( + [f.replace(":", "/") for f in bitcode_files] + ) + + self._converter.body += ( + f"iree_link_bitcode(\n" f"{name_block}" f"{bitcode_files_block}" f"\n)\n\n" + ) + + def iree_bytecode_module( + self, + name, + src, + module_name=None, + flags=None, + compile_tool=None, + c_identifier=None, + static_lib_path=None, + deps=None, + testonly=None, + ): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + src_block = self._convert_string_arg_block("SRC", src) + module_name_block = self._convert_string_arg_block( + "MODULE_FILE_NAME", module_name + ) + c_identifier_block = self._convert_string_arg_block( + "C_IDENTIFIER", c_identifier + ) + static_lib_block = self._convert_string_arg_block( + "STATIC_LIB_PATH", static_lib_path + ) + compile_tool_block = self._convert_target_block("COMPILE_TOOL", compile_tool) + flags_block = self._convert_string_list_block("FLAGS", flags) + deps_block = self._convert_target_list_block("DEPS", deps) + testonly_block = self._convert_option_block("TESTONLY", testonly) + + self._converter.body += ( + f"iree_bytecode_module(\n" + f"{name_block}" + f"{src_block}" + f"{module_name_block}" + f"{c_identifier_block}" + f"{compile_tool_block}" + f"{static_lib_block}" + f"{flags_block}" + f"{deps_block}" + f"{testonly_block}" + f" PUBLIC\n)\n\n" + ) + + def iree_flatbuffer_c_library(self, name, srcs, flatcc_args=None): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + flatcc_args_block = self._convert_string_list_block("FLATCC_ARGS", flatcc_args) + + self._converter.body += ( + f"flatbuffer_c_library(\n" + f"{name_block}" + f"{srcs_block}" + f"{flatcc_args_block}" + f" PUBLIC\n)\n\n" + ) + + def gentbl_cc_library( + self, + name, + tblgen, + td_file, + tbl_outs, + td_srcs=None, + deps=None, + includes=None, + strip_include_prefix=None, + test=None, + ): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + tblgen_block = self._convert_tblgen_block(tblgen) + td_file_block = self._convert_td_file_block(td_file) + outs_block = self._convert_tbl_outs_block(tbl_outs) + + self._converter.body += ( + f"iree_tablegen_library(\n" + f"{name_block}" + f"{td_file_block}" + f"{outs_block}" + f"{tblgen_block}" + f")\n\n" + ) + + def iree_gentbl_cc_library(self, **kwargs): + if self._should_skip_target(**kwargs): + return + # The bazel version of this rule adds some include directories and defs + # that are implicitly handled by the cmake version. + self.gentbl_cc_library(**kwargs) + + def iree_tablegen_doc( + self, + name, + tblgen, + td_file, + tbl_outs, + td_srcs=None, + includes=None, + deps=None, + test=None, + ): + name_block = self._convert_string_arg_block("NAME", name, quote=False) + tblgen_block = self._convert_tblgen_block(tblgen) + td_file_block = self._convert_td_file_block(td_file) + outs_block = self._convert_tbl_outs_block(tbl_outs) + self._converter.body += ( - f"list(REMOVE_ITEM {glob_var} {self._expand_cmake_var(exclude_var)})\n" - ) - return [self._expand_cmake_var(var) for var in glob_vars] - - # TODO(gcmn) implement these types of functions in a less hard-coded way - def platform_trampoline_deps(self, basename, path="base"): - return [f"//{path}/internal:{basename}_internal"] - - def select(self, d): - self._convert_unimplemented_function("select", str(d)) - return d["//conditions:default"] - - def defaulting_select(self, selector): - """Defined in build_defs.oss.bzl as a scoped alternative to select.""" - default_value = selector.get("//conditions:default") - if default_value is None: - raise ValueError("bazel_to_cmake can only convert selects with a default") - return default_value - - def cc_library(self, - name, - hdrs=None, - textual_hdrs=None, - srcs=None, - copts=None, - defines=None, - data=None, - deps=None, - testonly=None, - linkopts=None, - includes=None, - **kwargs): - if self._should_skip_target(**kwargs): - return - if linkopts: - self._convert_unimplemented_function("linkopts") - name_block = self._convert_string_arg_block("NAME", name, quote=False) - hdrs_block = self._convert_string_list_block("HDRS", hdrs, sort=True) - textual_hdrs_block = self._convert_string_list_block("TEXTUAL_HDRS", - textual_hdrs, - sort=True) - srcs_block = self._convert_srcs_block(srcs) - copts_block = self._convert_string_list_block("COPTS", copts, sort=False) - defines_block = self._convert_string_list_block("DEFINES", defines) - data_block = self._convert_target_list_block("DATA", data) - deps_block = self._convert_target_list_block("DEPS", deps) - testonly_block = self._convert_option_block("TESTONLY", testonly) - includes_block = self._convert_includes_block(includes) - - self._converter.body += (f"iree_cc_library(\n" - f"{name_block}" - f"{copts_block}" - f"{hdrs_block}" - f"{textual_hdrs_block}" - f"{srcs_block}" - f"{data_block}" - f"{deps_block}" - f"{defines_block}" - f"{testonly_block}" - f"{includes_block}" - f" PUBLIC\n)\n\n") - - def iree_compiler_register_plugin(self, plugin_id, target): - plugin_id_block = self._convert_string_arg_block("PLUGIN_ID", - plugin_id, - quote=False) - target_block = self._convert_single_target_block("TARGET", target) - self._converter.body += (f"iree_compiler_register_plugin(\n" - f"{plugin_id_block}" - f"{target_block}" - f")\n\n") - - def cc_test(self, - name, - hdrs=None, - srcs=None, - copts=None, - defines=None, - data=None, - deps=None, - timeout=None, - args=None, - tags=None, - includes=None, - **kwargs): - if self._should_skip_target(tags=tags, **kwargs): - return - name_block = self._convert_string_arg_block("NAME", name, quote=False) - hdrs_block = self._convert_string_list_block("HDRS", hdrs, sort=True) - srcs_block = self._convert_srcs_block(srcs) - copts_block = self._convert_string_list_block("COPTS", copts, sort=False) - defines_block = self._convert_string_list_block("DEFINES", defines) - data_block = self._convert_target_list_block("DATA", data) - deps_block = self._convert_target_list_block("DEPS", deps) - args_block = self._convert_string_list_block("ARGS", args) - labels_block = self._convert_string_list_block("LABELS", tags) - timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) - includes_block = self._convert_includes_block(includes) - - self._converter.body += (f"iree_cc_test(\n" - f"{name_block}" - f"{hdrs_block}" - f"{srcs_block}" - f"{copts_block}" - f"{defines_block}" - f"{data_block}" - f"{deps_block}" - f"{args_block}" - f"{labels_block}" - f"{timeout_block}" - f"{includes_block}" - f")\n\n") - - def cc_binary(self, - name, - srcs=None, - data=None, - deps=None, - copts=None, - defines=None, - linkopts=None, - testonly=None, - includes=None, - **kwargs): - if self._should_skip_target(**kwargs): - return - if linkopts: - self._convert_unimplemented_function("linkopts") - name_block = self._convert_string_arg_block("NAME", name, quote=False) - copts_block = self._convert_string_list_block("COPTS", copts, sort=False) - defines_block = self._convert_string_list_block("DEFINES", defines) - srcs_block = self._convert_srcs_block(srcs) - data_block = self._convert_target_list_block("DATA", data) - deps_block = self._convert_target_list_block("DEPS", deps) - testonly_block = self._convert_option_block("TESTONLY", testonly) - includes_block = self._convert_includes_block(includes) - - self._converter.body += (f"iree_cc_binary(\n" - f"{name_block}" - f"{srcs_block}" - f"{copts_block}" - f"{defines_block}" - f"{data_block}" - f"{deps_block}" - f"{testonly_block}" - f"{includes_block}" - f")\n\n") - - def c_embed_data(self, - name, - srcs, - c_file_output, - h_file_output, - testonly=None, - strip_prefix=None, - flatten=None, - identifier=None, - deps=None, - **kwargs): - if self._should_skip_target(**kwargs): - return - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - c_file_output_block = self._convert_string_arg_block( - "C_FILE_OUTPUT", c_file_output) - h_file_output_block = self._convert_string_arg_block( - "H_FILE_OUTPUT", h_file_output) - testonly_block = self._convert_option_block("TESTONLY", testonly) - identifier_block = self._convert_string_arg_block("IDENTIFIER", identifier) - flatten_block = self._convert_option_block("FLATTEN", flatten) - deps_block = self._convert_target_list_block("DEPS", deps) - - self._converter.body += (f"iree_c_embed_data(\n" - f"{name_block}" - f"{srcs_block}" - f"{deps_block}" - f"{c_file_output_block}" - f"{h_file_output_block}" - f"{identifier_block}" - f"{testonly_block}" - f"{flatten_block}" - f" PUBLIC\n)\n\n") - - def iree_bitcode_library(self, - name, - arch, - srcs, - internal_hdrs=None, - copts=None): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - arch_block = self._convert_string_arg_block("ARCH", arch, quote=False) - srcs_block = self._convert_srcs_block(srcs) - copts_block = self._convert_string_list_block("COPTS", copts, sort=False) - - self._converter.body += (f"iree_bitcode_library(\n" - f"{name_block}" - f"{arch_block}" - f"{srcs_block}" - f"{copts_block}" - f")\n\n") - - def iree_link_bitcode(self, name, bitcode_files): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - bitcode_files_block = self._convert_srcs_block( - [f.replace(":", "/") for f in bitcode_files]) - - self._converter.body += (f"iree_link_bitcode(\n" - f"{name_block}" - f"{bitcode_files_block}" - f"\n)\n\n") - - def iree_bytecode_module(self, - name, - src, - module_name=None, - flags=None, - compile_tool=None, - c_identifier=None, - static_lib_path=None, - deps=None, - testonly=None): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - src_block = self._convert_string_arg_block("SRC", src) - module_name_block = self._convert_string_arg_block("MODULE_FILE_NAME", - module_name) - c_identifier_block = self._convert_string_arg_block("C_IDENTIFIER", - c_identifier) - static_lib_block = self._convert_string_arg_block("STATIC_LIB_PATH", - static_lib_path) - compile_tool_block = self._convert_target_block("COMPILE_TOOL", - compile_tool) - flags_block = self._convert_string_list_block("FLAGS", flags) - deps_block = self._convert_target_list_block("DEPS", deps) - testonly_block = self._convert_option_block("TESTONLY", testonly) - - self._converter.body += (f"iree_bytecode_module(\n" - f"{name_block}" - f"{src_block}" - f"{module_name_block}" - f"{c_identifier_block}" - f"{compile_tool_block}" - f"{static_lib_block}" - f"{flags_block}" - f"{deps_block}" - f"{testonly_block}" - f" PUBLIC\n)\n\n") - - def iree_flatbuffer_c_library(self, name, srcs, flatcc_args=None): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - flatcc_args_block = self._convert_string_list_block("FLATCC_ARGS", - flatcc_args) - - self._converter.body += (f"flatbuffer_c_library(\n" - f"{name_block}" - f"{srcs_block}" - f"{flatcc_args_block}" - f" PUBLIC\n)\n\n") - - def gentbl_cc_library(self, - name, - tblgen, - td_file, - tbl_outs, - td_srcs=None, - deps=None, - includes=None, - strip_include_prefix=None, - test=None): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - tblgen_block = self._convert_tblgen_block(tblgen) - td_file_block = self._convert_td_file_block(td_file) - outs_block = self._convert_tbl_outs_block(tbl_outs) - - self._converter.body += (f"iree_tablegen_library(\n" - f"{name_block}" - f"{td_file_block}" - f"{outs_block}" - f"{tblgen_block}" - f")\n\n") - - def iree_gentbl_cc_library(self, **kwargs): - if self._should_skip_target(**kwargs): - return - # The bazel version of this rule adds some include directories and defs - # that are implicitly handled by the cmake version. - self.gentbl_cc_library(**kwargs) - - def iree_tablegen_doc(self, - name, - tblgen, - td_file, - tbl_outs, - td_srcs=None, - includes=None, - deps=None, - test=None): - name_block = self._convert_string_arg_block("NAME", name, quote=False) - tblgen_block = self._convert_tblgen_block(tblgen) - td_file_block = self._convert_td_file_block(td_file) - outs_block = self._convert_tbl_outs_block(tbl_outs) - - self._converter.body += (f"iree_tablegen_doc(\n" - f"{name_block}" - f"{td_file_block}" - f"{outs_block}" - f"{tblgen_block}" - f")\n\n") - - def iree_lit_test_suite(self, - name, - srcs, - tools=None, - data=None, - tags=None, - timeout=None, - **kwargs): - if self._should_skip_target(tags=tags, **kwargs): - return - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - tools_block = self._convert_target_list_block("TOOLS", tools) - data_block = self._convert_target_list_block("DATA", data) - labels_block = self._convert_string_list_block("LABELS", tags) - timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) - - self._converter.body += (f"iree_lit_test_suite(\n" - f"{name_block}" - f"{srcs_block}" - f"{tools_block}" - f"{data_block}" - f"{labels_block}" - f"{timeout_block}" - f")\n\n") - - def iree_check_single_backend_test_suite(self, - name, - srcs, - target_backend, - driver=None, - compiler_flags=None, - target_backends_and_drivers=None, - runner_args=None, - tags=None, - target_cpu_features=None, - timeout=None, - **kwargs): - if self._should_skip_target(tags=tags, **kwargs): - return - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - target_backend_block = self._convert_string_arg_block( - "TARGET_BACKEND", target_backend) - driver_block = self._convert_string_arg_block("DRIVER", driver) - compiler_flags_block = self._convert_string_list_block( - "COMPILER_FLAGS", compiler_flags) - runner_args_block = self._convert_string_list_block("RUNNER_ARGS", - runner_args) - labels_block = self._convert_string_list_block("LABELS", tags) - target_cpu_features_block = self._convert_string_arg_block( - "TARGET_CPU_FEATURES", target_cpu_features) - timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) - - self._converter.body += (f"iree_check_single_backend_test_suite(\n" - f"{name_block}" - f"{srcs_block}" - f"{target_backend_block}" - f"{driver_block}" - f"{compiler_flags_block}" - f"{runner_args_block}" - f"{labels_block}" - f"{target_cpu_features_block}" - f"{timeout_block}" - f")\n\n") - - def iree_check_test_suite(self, - name, - srcs, - target_backends_and_drivers=None, - compiler_flags=None, - runner_args=None, - tags=None, - target_cpu_features_variants=None, - timeout=None, - **kwargs): - if self._should_skip_target(tags=tags, **kwargs): - return - target_backends = None - drivers = None - if target_backends_and_drivers is not None: - target_backends = [it[0] for it in target_backends_and_drivers] - drivers = [it[1] for it in target_backends_and_drivers] - - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - target_backends_block = self._convert_string_list_block( - "TARGET_BACKENDS", target_backends) - drivers_block = self._convert_string_list_block("DRIVERS", drivers) - compiler_flags_block = self._convert_string_list_block( - "COMPILER_FLAGS", compiler_flags) - runner_args_block = self._convert_string_list_block("RUNNER_ARGS", - runner_args) - labels_block = self._convert_string_list_block("LABELS", tags) - target_cpu_features_variants_block = self._convert_string_list_block( - "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants) - timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) - - self._converter.body += (f"iree_check_test_suite(\n" - f"{name_block}" - f"{srcs_block}" - f"{target_backends_block}" - f"{drivers_block}" - f"{compiler_flags_block}" - f"{runner_args_block}" - f"{labels_block}" - f"{target_cpu_features_variants_block}" - f"{timeout_block}" - f")\n\n") - - def iree_generated_trace_runner_test(self, - name, - generator, - generator_args=None, - trace_runner=None, - target_backends_and_drivers=None, - compiler_flags=None, - runner_args=None, - tags=None, - target_cpu_features_variants=None, - **kwargs): - if self._should_skip_target(tags=tags, **kwargs): - return - target_backends = None - drivers = None - if target_backends_and_drivers is not None: - target_backends = [it[0] for it in target_backends_and_drivers] - drivers = [it[1] for it in target_backends_and_drivers] - - name_block = self._convert_string_arg_block("NAME", name, quote=False) - # For now we assume that the generator target is a py_binary with a single - # source .py file named like it. - generator_py = f"{generator.split(':')[-1]}.py" - generator_block = self._convert_string_arg_block("GENERATOR", - generator_py, - quote=True) - generator_args_block = self._convert_string_list_block( - "GENERATOR_ARGS", generator_args) - trace_runner_block = self._convert_target_block("TRACE_RUNNER", - trace_runner) - target_backends_block = self._convert_string_list_block( - "TARGET_BACKENDS", target_backends) - drivers_block = self._convert_string_list_block("DRIVERS", drivers) - compiler_flags_block = self._convert_string_list_block( - "COMPILER_FLAGS", compiler_flags) - runner_args_block = self._convert_string_list_block("RUNNER_ARGS", - runner_args) - labels_block = self._convert_string_list_block("LABELS", tags) - target_cpu_features_variants_block = self._convert_string_list_block( - "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants) - - self._converter.body += (f"iree_generated_trace_runner_test(\n" - f"{name_block}" - f"{generator_block}" - f"{generator_args_block}" - f"{trace_runner_block}" - f"{target_backends_block}" - f"{drivers_block}" - f"{compiler_flags_block}" - f"{runner_args_block}" - f"{labels_block}" - f"{target_cpu_features_variants_block}" - f")\n\n") - - def native_test(self, - name, - src, - args=None, - data=None, - tags=None, - timeout=None): - if self._should_skip_target(tags=tags): - return - if data is not None: - self._convert_unimplemented_function("native_test", name + " has data") - - name_block = self._convert_string_arg_block("NAME", name) - test_binary_block = self._convert_single_target_block("SRC", src) - args_block = self._convert_string_list_block("ARGS", args) - labels_block = self._convert_string_list_block("LABELS", tags) - timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) - - self._converter.body += (f"iree_native_test(\n" - f"{name_block}" - f"{args_block}" - f"{test_binary_block}" - f"{labels_block}" - f")\n\n") - - def cc_binary_benchmark( - self, - name, - srcs=None, - data=None, - deps=None, - copts=None, - defines=None, - linkopts=None, - tags=None, - testonly=True, - # unused - size="small", - timeout=None): - if self._should_skip_target(tags=tags): - return - name_block = self._convert_string_arg_block("NAME", name, quote=False) - srcs_block = self._convert_srcs_block(srcs) - data_block = self._convert_target_list_block("DATA", data) - deps_block = self._convert_target_list_block("DEPS", deps) - copts_block = self._convert_string_list_block("COPTS", copts, sort=False) - defines_block = self._convert_string_list_block("DEFINES", defines) - defines_block = self._convert_string_list_block("LINKOPTS", linkopts) - testonly_block = self._convert_option_block("TESTONLY", testonly) - labels_block = self._convert_string_list_block("LABELS", tags) - - self._converter.body += (f"iree_cc_binary_benchmark(\n" - f"{name_block}" - f"{srcs_block}" - f"{data_block}" - f"{deps_block}" - f"{copts_block}" - f"{defines_block}" - f"{defines_block}" - f"{testonly_block}" - f"{labels_block}" - f")\n\n") - - def iree_cmake_extra_content(self, content, inline=False): - if inline: - self._converter.body += (f"\n{content}\n") - else: - self._converter.header += (f"\n{content}\n") + f"iree_tablegen_doc(\n" + f"{name_block}" + f"{td_file_block}" + f"{outs_block}" + f"{tblgen_block}" + f")\n\n" + ) + + def iree_lit_test_suite( + self, name, srcs, tools=None, data=None, tags=None, timeout=None, **kwargs + ): + if self._should_skip_target(tags=tags, **kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + tools_block = self._convert_target_list_block("TOOLS", tools) + data_block = self._convert_target_list_block("DATA", data) + labels_block = self._convert_string_list_block("LABELS", tags) + timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) + + self._converter.body += ( + f"iree_lit_test_suite(\n" + f"{name_block}" + f"{srcs_block}" + f"{tools_block}" + f"{data_block}" + f"{labels_block}" + f"{timeout_block}" + f")\n\n" + ) + + def iree_check_single_backend_test_suite( + self, + name, + srcs, + target_backend, + driver=None, + compiler_flags=None, + target_backends_and_drivers=None, + runner_args=None, + tags=None, + target_cpu_features=None, + timeout=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + target_backend_block = self._convert_string_arg_block( + "TARGET_BACKEND", target_backend + ) + driver_block = self._convert_string_arg_block("DRIVER", driver) + compiler_flags_block = self._convert_string_list_block( + "COMPILER_FLAGS", compiler_flags + ) + runner_args_block = self._convert_string_list_block("RUNNER_ARGS", runner_args) + labels_block = self._convert_string_list_block("LABELS", tags) + target_cpu_features_block = self._convert_string_arg_block( + "TARGET_CPU_FEATURES", target_cpu_features + ) + timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) + + self._converter.body += ( + f"iree_check_single_backend_test_suite(\n" + f"{name_block}" + f"{srcs_block}" + f"{target_backend_block}" + f"{driver_block}" + f"{compiler_flags_block}" + f"{runner_args_block}" + f"{labels_block}" + f"{target_cpu_features_block}" + f"{timeout_block}" + f")\n\n" + ) + + def iree_check_test_suite( + self, + name, + srcs, + target_backends_and_drivers=None, + compiler_flags=None, + runner_args=None, + tags=None, + target_cpu_features_variants=None, + timeout=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + target_backends = None + drivers = None + if target_backends_and_drivers is not None: + target_backends = [it[0] for it in target_backends_and_drivers] + drivers = [it[1] for it in target_backends_and_drivers] + + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + target_backends_block = self._convert_string_list_block( + "TARGET_BACKENDS", target_backends + ) + drivers_block = self._convert_string_list_block("DRIVERS", drivers) + compiler_flags_block = self._convert_string_list_block( + "COMPILER_FLAGS", compiler_flags + ) + runner_args_block = self._convert_string_list_block("RUNNER_ARGS", runner_args) + labels_block = self._convert_string_list_block("LABELS", tags) + target_cpu_features_variants_block = self._convert_string_list_block( + "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants + ) + timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) + + self._converter.body += ( + f"iree_check_test_suite(\n" + f"{name_block}" + f"{srcs_block}" + f"{target_backends_block}" + f"{drivers_block}" + f"{compiler_flags_block}" + f"{runner_args_block}" + f"{labels_block}" + f"{target_cpu_features_variants_block}" + f"{timeout_block}" + f")\n\n" + ) + + def iree_generated_trace_runner_test( + self, + name, + generator, + generator_args=None, + trace_runner=None, + target_backends_and_drivers=None, + compiler_flags=None, + runner_args=None, + tags=None, + target_cpu_features_variants=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + target_backends = None + drivers = None + if target_backends_and_drivers is not None: + target_backends = [it[0] for it in target_backends_and_drivers] + drivers = [it[1] for it in target_backends_and_drivers] + + name_block = self._convert_string_arg_block("NAME", name, quote=False) + # For now we assume that the generator target is a py_binary with a single + # source .py file named like it. + generator_py = f"{generator.split(':')[-1]}.py" + generator_block = self._convert_string_arg_block( + "GENERATOR", generator_py, quote=True + ) + generator_args_block = self._convert_string_list_block( + "GENERATOR_ARGS", generator_args + ) + trace_runner_block = self._convert_target_block("TRACE_RUNNER", trace_runner) + target_backends_block = self._convert_string_list_block( + "TARGET_BACKENDS", target_backends + ) + drivers_block = self._convert_string_list_block("DRIVERS", drivers) + compiler_flags_block = self._convert_string_list_block( + "COMPILER_FLAGS", compiler_flags + ) + runner_args_block = self._convert_string_list_block("RUNNER_ARGS", runner_args) + labels_block = self._convert_string_list_block("LABELS", tags) + target_cpu_features_variants_block = self._convert_string_list_block( + "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants + ) + + self._converter.body += ( + f"iree_generated_trace_runner_test(\n" + f"{name_block}" + f"{generator_block}" + f"{generator_args_block}" + f"{trace_runner_block}" + f"{target_backends_block}" + f"{drivers_block}" + f"{compiler_flags_block}" + f"{runner_args_block}" + f"{labels_block}" + f"{target_cpu_features_variants_block}" + f")\n\n" + ) + + def native_test(self, name, src, args=None, data=None, tags=None, timeout=None): + if self._should_skip_target(tags=tags): + return + if data is not None: + self._convert_unimplemented_function("native_test", name + " has data") + + name_block = self._convert_string_arg_block("NAME", name) + test_binary_block = self._convert_single_target_block("SRC", src) + args_block = self._convert_string_list_block("ARGS", args) + labels_block = self._convert_string_list_block("LABELS", tags) + timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout) + + self._converter.body += ( + f"iree_native_test(\n" + f"{name_block}" + f"{args_block}" + f"{test_binary_block}" + f"{labels_block}" + f")\n\n" + ) + + def cc_binary_benchmark( + self, + name, + srcs=None, + data=None, + deps=None, + copts=None, + defines=None, + linkopts=None, + tags=None, + testonly=True, + # unused + size="small", + timeout=None, + ): + if self._should_skip_target(tags=tags): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + defines_block = self._convert_string_list_block("LINKOPTS", linkopts) + testonly_block = self._convert_option_block("TESTONLY", testonly) + labels_block = self._convert_string_list_block("LABELS", tags) + + self._converter.body += ( + f"iree_cc_binary_benchmark(\n" + f"{name_block}" + f"{srcs_block}" + f"{data_block}" + f"{deps_block}" + f"{copts_block}" + f"{defines_block}" + f"{defines_block}" + f"{testonly_block}" + f"{labels_block}" + f")\n\n" + ) + + def iree_cmake_extra_content(self, content, inline=False): + if inline: + self._converter.body += f"\n{content}\n" + else: + self._converter.header += f"\n{content}\n" class Converter(object): - """Conversion state tracking and full file template substitution.""" + """Conversion state tracking and full file template substitution.""" - def __init__(self): - # Header appears after the license block but before `iree_add_all_subdirs`. - self.header = "" - # Body appears after `iree_add_all_subdirs`. - self.body = "" + def __init__(self): + # Header appears after the license block but before `iree_add_all_subdirs`. + self.header = "" + # Body appears after `iree_add_all_subdirs`. + self.body = "" - self.first_error = None + self.first_error = None - def convert(self): - converted_content = (f"{self.header}\n\n" - f"iree_add_all_subdirs()\n\n" - f"{self.body}") + def convert(self): + converted_content = ( + f"{self.header}\n\n" f"iree_add_all_subdirs()\n\n" f"{self.body}" + ) - # Cleanup newline characters. This is more convenient than ensuring all - # conversions are careful with where they insert newlines. - converted_content = converted_content.replace("\n\n\n", "\n") - converted_content = converted_content.rstrip() + "\n" + # Cleanup newline characters. This is more convenient than ensuring all + # conversions are careful with where they insert newlines. + converted_content = converted_content.replace("\n\n\n", "\n") + converted_content = converted_content.rstrip() + "\n" - return converted_content + return converted_content def GetDict(obj): - ret = {} - for k in dir(obj): - if not k.startswith("_"): - ret[k] = getattr(obj, k) - return ret - - -def convert_build_file(build_file_code, - repo_cfg, - allow_partial_conversion=False): - converter = Converter() - # Allow overrides of TargetConverter and BuildFileFunctions from repo cfg. - repo_map = getattr(repo_cfg, "REPO_MAP", {}) - target_converter = getattr( - repo_cfg, "CustomTargetConverter", - bazel_to_cmake_targets.TargetConverter)(repo_map=repo_map) - build_file_functions = getattr(repo_cfg, "CustomBuildFileFunctions", - BuildFileFunctions)(converter=converter, - targets=target_converter) - - exec(build_file_code, GetDict(build_file_functions)) - converted_text = converter.convert() - if not allow_partial_conversion and converter.first_error: - raise converter.first_error # pylint: disable=raising-bad-type - return converted_text + ret = {} + for k in dir(obj): + if not k.startswith("_"): + ret[k] = getattr(obj, k) + return ret + + +def convert_build_file(build_file_code, repo_cfg, allow_partial_conversion=False): + converter = Converter() + # Allow overrides of TargetConverter and BuildFileFunctions from repo cfg. + repo_map = getattr(repo_cfg, "REPO_MAP", {}) + target_converter = getattr( + repo_cfg, "CustomTargetConverter", bazel_to_cmake_targets.TargetConverter + )(repo_map=repo_map) + build_file_functions = getattr( + repo_cfg, "CustomBuildFileFunctions", BuildFileFunctions + )(converter=converter, targets=target_converter) + + exec(build_file_code, GetDict(build_file_functions)) + converted_text = converter.convert() + if not allow_partial_conversion and converter.first_error: + raise converter.first_error # pylint: disable=raising-bad-type + return converted_text diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index ca6c38560b8e..14d56bc89a86 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -9,230 +9,228 @@ class TargetConverter: - - def __init__(self, repo_map: Dict[str, str]): - self._explicit_target_mapping = {} - self._repo_map = repo_map - - iree_core_repo = self._repo_alias("@iree_core") - self._update_target_mappings({ - # Internal utilities to emulate various binary/library options. - f"{iree_core_repo}//build_tools:default_linkopts": [], - f"{iree_core_repo}//build_tools:dl": ["${CMAKE_DL_LIBS}"], - f"{iree_core_repo}//compiler/src/iree/compiler/API:CAPI": [ - "IREECompilerCAPILib" - ], - - # IREE llvm-external-projects - f"{iree_core_repo}//llvm-external-projects/iree-dialects:CAPI": [ - "IREEDialectsCAPI" - ], - - # Disable all hard-coded codegen targets (they are expanded dynamically - # in CMake). - "@llvm-project//llvm:AArch64AsmParser": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:AArch64CodeGen": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:ARMAsmParser": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:ARMCodeGen": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:RISCVAsmParser": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:RISCVCodeGen": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:WebAssemblyAsmParser": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:WebAssemblyCodeGen": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:X86AsmParser": ["IREELLVMCPUTargetDeps"], - "@llvm-project//llvm:X86CodeGen": ["IREELLVMCPUTargetDeps"], - - # Clang - "@llvm-project//clang": ["${IREE_CLANG_TARGET}"], - - # LLD - "@llvm-project//lld": ["${IREE_LLD_TARGET}"], - "@llvm-project//lld:COFF": ["lldCOFF"], - "@llvm-project//lld:Common": ["lldCommon"], - "@llvm-project//lld:ELF": ["lldELF"], - "@llvm-project//lld:MachO": ["lldMachO"], - "@llvm-project//lld:Wasm": ["lldWasm"], - - # LLVM - "@llvm-project//llvm:config": [], - "@llvm-project//llvm:IPO": ["LLVMipo"], - "@llvm-project//llvm:FileCheck": ["FileCheck"], - "@llvm-project//llvm:not": ["not"], - "@llvm-project//llvm:llvm-link": ["${IREE_LLVM_LINK_TARGET}"], - "@llvm-project//llvm:NVPTXUtilsAndDesc": ["LLVMNVPTXDesc",], - - # MLIR - "@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"], - "@llvm-project//mlir:CommonFolders": [""], - "@llvm-project//mlir:DialectUtils": [""], - "@llvm-project//mlir:GPUDialect": ["MLIRGPUDialect"], - "@llvm-project//mlir:GPUTransforms": ["MLIRGPUTransforms"], - "@llvm-project//mlir:LinalgStructuredOpsIncGen": [ - "MLIRLinalgStructuredOpsIncGenLib" - ], - "@llvm-project//mlir:ShapeTransforms": ["MLIRShapeOpsTransforms"], - "@llvm-project//mlir:ToLLVMIRTranslation": ["MLIRTargetLLVMIRExport"], - "@llvm-project//mlir:mlir-translate": ["mlir-translate"], - "@llvm-project//mlir:MlirLspServerLib": ["MLIRLspServerLib"], - "@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"], - "@llvm-project//mlir:MlirOptLib": ["MLIROptLib"], - "@llvm-project//mlir:VectorOps": ["MLIRVector"], - - # StableHLO. - "@stablehlo//:chlo_ops": ["ChloOps",], - "@stablehlo//:stablehlo_ops": ["StablehloOps",], - "@stablehlo//:broadcast_utils": ["StablehloBroadcastUtils",], - - # NCCL - "@nccl//:headers": ["nccl::headers",], - - # Torch-MLIR. - "@torch-mlir-dialects//:TorchMLIRTMTensorDialect": [ - "TorchMLIRTMTensorDialect" - ], - - # Tracy. - "@tracy_client//:runtime": ["tracy_client::runtime"], - - # Vulkan - "@vulkan_headers": ["Vulkan::Headers"], - # Misc single targets - "@com_google_benchmark//:benchmark": ["benchmark"], - "@com_github_dvidelabs_flatcc//:flatcc": ["flatcc"], - "@com_github_dvidelabs_flatcc//:parsing": ["flatcc::parsing"], - "@com_github_dvidelabs_flatcc//:runtime": ["flatcc::runtime"], - "@com_github_yaml_libyaml//:yaml": ["yaml"], - "@com_google_googletest//:gtest": ["gmock", "gtest"], - "@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"], - "@cpuinfo": ["${IREE_CPUINFO_TARGET}"], - "@vulkan_memory_allocator//:impl_header_only": [ - "vulkan_memory_allocator" - ], - "@webgpu_headers": [], - }) - - self._initialize() - - def _initialize(self): - pass - - def _repo_alias(self, repo_name: str) -> str: - """Returns the prefix of a repo (i.e. '@iree_core') given the repo map.""" - return self._repo_map.get(repo_name, repo_name) - - def _update_target_mappings(self, mappings: Dict[str, List[str]]): - self._explicit_target_mapping.update(mappings) - - def _convert_mlir_target(self, target): - # Default to a pattern substitution approach. - # Take "MLIR" and append the name part of the full target identifier, e.g. - # "@llvm-project//mlir:IR" -> "MLIRIR" - # "@llvm-project//mlir:Pass" -> "MLIRPass" - # MLIR does not have header-only targets apart from the libraries. Here - # we redirect any request for a CAPI{Name}Headers to a target within IREE - # that sets this up. - label = target.rsplit(":")[-1] - if label.startswith("CAPI") and label.endswith("Headers"): - return [f"IREELLVMIncludeSetup"] - else: - return [f"MLIR{label}"] - - def _convert_llvm_target(self, target): - # Default to a pattern substitution approach. - # Prepend "LLVM" to the Bazel target name. - # "@llvm-project//llvm:AsmParser" -> "LLVMAsmParser" - # "@llvm-project//llvm:Core" -> "LLVMCore" - return ["LLVM" + target.rsplit(":")[-1]] - - def _convert_iree_cuda_target(self, target): - # Convert like: - # @iree_cuda//:libdevice_embedded -> iree_cuda::libdevice_embedded - label = target.rsplit(":")[-1] - return [f"iree_cuda::{label}"] - - def _convert_iree_dialects_target(self, target): - # Just take the target name as-is. - return [target.rsplit(":")[-1]] - - def _convert_to_cmake_path(self, bazel_path_fragment: str) -> str: - cmake_path = bazel_path_fragment - # Bazel `//iree/base` -> CMake `iree::base` - # Bazel `//iree/base:foo` -> CMake `iree::base::foo` - if cmake_path.startswith("//"): - cmake_path = cmake_path[len("//"):] - cmake_path = cmake_path.replace(":", "::") # iree/base::foo or ::foo - cmake_path = cmake_path.replace("/", "::") # iree::base - return cmake_path - - def convert_target(self, target): - """Converts a Bazel target to a list of CMake targets. - - IREE targets are expected to follow a standard form between Bazel and CMake - that facilitates conversion. External targets *may* have their own patterns, - or they may be purely special cases. - - Multiple target in Bazel may map to a single target in CMake and a Bazel - target may map to multiple CMake targets. - - Returns: - A list of converted targets if it was successfully converted. - - Raises: - KeyError: No conversion was found for the target. - """ - iree_core_repo = self._repo_alias("@iree_core") - if target in self._explicit_target_mapping: - return self._explicit_target_mapping[target] - if target.startswith("@llvm-project//llvm"): - return self._convert_llvm_target(target) - if target.startswith("@llvm-project//mlir"): - return self._convert_mlir_target(target) - if target.startswith("@iree_cuda//"): - return self._convert_iree_cuda_target(target) - if target.startswith(f"{iree_core_repo}//"): - return self._convert_iree_core_target(target) - if target.startswith("@"): - raise KeyError(f"No conversion found for target '{target}'") - - # Pass through package-relative targets - # :target_name - # file_name.txt - if target.startswith(":") or (":" not in target and - not target.startswith("/")): - return [self._convert_to_cmake_path(target)] - - return self._convert_unmatched_target(target) - - def _convert_iree_core_target(self, target): - iree_core_repo = self._repo_alias("@iree_core") - if target.startswith( - f"{iree_core_repo}//llvm-external-projects/iree-dialects"): - return self._convert_iree_dialects_target(target) - - # IREE root paths map to package names based on explicit rules. - # * src/iree/ directories (compiler/src/iree/ and runtime/src/iree/) - # creating their own root paths by trimming down to just "iree" - # * tools/ uses an empty root, for binary targets names like "iree-compile" - # * other top level directories add back an 'iree' prefix - # If changing these, make the corresponding change in iree_macros.cmake - # (iree_package_ns function). - - # Map //compiler/src/iree/(.*) -> iree::\1 (i.e. iree::compiler::\1) - m = re.match(f"^{iree_core_repo}//compiler/src/iree/(.+)", target) - if m: - return ["iree::" + self._convert_to_cmake_path(m.group(1))] - - # Map //runtime/src/iree/(.*) -> iree::\1 - m = re.match(f"^{iree_core_repo}//runtime/src/iree/(.+)", target) - if m: - return ["iree::" + self._convert_to_cmake_path(m.group(1))] - - # Map //tools/(.*) -> \1 - m = re.match(f"^{iree_core_repo}//tools[/|:](.+)", target) - if m: - return [self._convert_to_cmake_path(m.group(1))] - - return self._convert_unmatched_target(target) - - def _convert_unmatched_target(self, target: str) -> str: - """Converts unmatched targets in a repo specific way.""" - raise ValueError(f"No target matching for {target}") + def __init__(self, repo_map: Dict[str, str]): + self._explicit_target_mapping = {} + self._repo_map = repo_map + + iree_core_repo = self._repo_alias("@iree_core") + self._update_target_mappings( + { + # Internal utilities to emulate various binary/library options. + f"{iree_core_repo}//build_tools:default_linkopts": [], + f"{iree_core_repo}//build_tools:dl": ["${CMAKE_DL_LIBS}"], + f"{iree_core_repo}//compiler/src/iree/compiler/API:CAPI": [ + "IREECompilerCAPILib" + ], + # IREE llvm-external-projects + f"{iree_core_repo}//llvm-external-projects/iree-dialects:CAPI": [ + "IREEDialectsCAPI" + ], + # Disable all hard-coded codegen targets (they are expanded dynamically + # in CMake). + "@llvm-project//llvm:AArch64AsmParser": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:AArch64CodeGen": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:ARMAsmParser": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:ARMCodeGen": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:RISCVAsmParser": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:RISCVCodeGen": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:WebAssemblyAsmParser": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:WebAssemblyCodeGen": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:X86AsmParser": ["IREELLVMCPUTargetDeps"], + "@llvm-project//llvm:X86CodeGen": ["IREELLVMCPUTargetDeps"], + # Clang + "@llvm-project//clang": ["${IREE_CLANG_TARGET}"], + # LLD + "@llvm-project//lld": ["${IREE_LLD_TARGET}"], + "@llvm-project//lld:COFF": ["lldCOFF"], + "@llvm-project//lld:Common": ["lldCommon"], + "@llvm-project//lld:ELF": ["lldELF"], + "@llvm-project//lld:MachO": ["lldMachO"], + "@llvm-project//lld:Wasm": ["lldWasm"], + # LLVM + "@llvm-project//llvm:config": [], + "@llvm-project//llvm:IPO": ["LLVMipo"], + "@llvm-project//llvm:FileCheck": ["FileCheck"], + "@llvm-project//llvm:not": ["not"], + "@llvm-project//llvm:llvm-link": ["${IREE_LLVM_LINK_TARGET}"], + "@llvm-project//llvm:NVPTXUtilsAndDesc": [ + "LLVMNVPTXDesc", + ], + # MLIR + "@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"], + "@llvm-project//mlir:CommonFolders": [""], + "@llvm-project//mlir:DialectUtils": [""], + "@llvm-project//mlir:GPUDialect": ["MLIRGPUDialect"], + "@llvm-project//mlir:GPUTransforms": ["MLIRGPUTransforms"], + "@llvm-project//mlir:LinalgStructuredOpsIncGen": [ + "MLIRLinalgStructuredOpsIncGenLib" + ], + "@llvm-project//mlir:ShapeTransforms": ["MLIRShapeOpsTransforms"], + "@llvm-project//mlir:ToLLVMIRTranslation": ["MLIRTargetLLVMIRExport"], + "@llvm-project//mlir:mlir-translate": ["mlir-translate"], + "@llvm-project//mlir:MlirLspServerLib": ["MLIRLspServerLib"], + "@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"], + "@llvm-project//mlir:MlirOptLib": ["MLIROptLib"], + "@llvm-project//mlir:VectorOps": ["MLIRVector"], + # StableHLO. + "@stablehlo//:chlo_ops": [ + "ChloOps", + ], + "@stablehlo//:stablehlo_ops": [ + "StablehloOps", + ], + "@stablehlo//:broadcast_utils": [ + "StablehloBroadcastUtils", + ], + # NCCL + "@nccl//:headers": [ + "nccl::headers", + ], + # Torch-MLIR. + "@torch-mlir-dialects//:TorchMLIRTMTensorDialect": [ + "TorchMLIRTMTensorDialect" + ], + # Tracy. + "@tracy_client//:runtime": ["tracy_client::runtime"], + # Vulkan + "@vulkan_headers": ["Vulkan::Headers"], + # Misc single targets + "@com_google_benchmark//:benchmark": ["benchmark"], + "@com_github_dvidelabs_flatcc//:flatcc": ["flatcc"], + "@com_github_dvidelabs_flatcc//:parsing": ["flatcc::parsing"], + "@com_github_dvidelabs_flatcc//:runtime": ["flatcc::runtime"], + "@com_github_yaml_libyaml//:yaml": ["yaml"], + "@com_google_googletest//:gtest": ["gmock", "gtest"], + "@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"], + "@cpuinfo": ["${IREE_CPUINFO_TARGET}"], + "@vulkan_memory_allocator//:impl_header_only": [ + "vulkan_memory_allocator" + ], + "@webgpu_headers": [], + } + ) + + self._initialize() + + def _initialize(self): + pass + + def _repo_alias(self, repo_name: str) -> str: + """Returns the prefix of a repo (i.e. '@iree_core') given the repo map.""" + return self._repo_map.get(repo_name, repo_name) + + def _update_target_mappings(self, mappings: Dict[str, List[str]]): + self._explicit_target_mapping.update(mappings) + + def _convert_mlir_target(self, target): + # Default to a pattern substitution approach. + # Take "MLIR" and append the name part of the full target identifier, e.g. + # "@llvm-project//mlir:IR" -> "MLIRIR" + # "@llvm-project//mlir:Pass" -> "MLIRPass" + # MLIR does not have header-only targets apart from the libraries. Here + # we redirect any request for a CAPI{Name}Headers to a target within IREE + # that sets this up. + label = target.rsplit(":")[-1] + if label.startswith("CAPI") and label.endswith("Headers"): + return [f"IREELLVMIncludeSetup"] + else: + return [f"MLIR{label}"] + + def _convert_llvm_target(self, target): + # Default to a pattern substitution approach. + # Prepend "LLVM" to the Bazel target name. + # "@llvm-project//llvm:AsmParser" -> "LLVMAsmParser" + # "@llvm-project//llvm:Core" -> "LLVMCore" + return ["LLVM" + target.rsplit(":")[-1]] + + def _convert_iree_cuda_target(self, target): + # Convert like: + # @iree_cuda//:libdevice_embedded -> iree_cuda::libdevice_embedded + label = target.rsplit(":")[-1] + return [f"iree_cuda::{label}"] + + def _convert_iree_dialects_target(self, target): + # Just take the target name as-is. + return [target.rsplit(":")[-1]] + + def _convert_to_cmake_path(self, bazel_path_fragment: str) -> str: + cmake_path = bazel_path_fragment + # Bazel `//iree/base` -> CMake `iree::base` + # Bazel `//iree/base:foo` -> CMake `iree::base::foo` + if cmake_path.startswith("//"): + cmake_path = cmake_path[len("//") :] + cmake_path = cmake_path.replace(":", "::") # iree/base::foo or ::foo + cmake_path = cmake_path.replace("/", "::") # iree::base + return cmake_path + + def convert_target(self, target): + """Converts a Bazel target to a list of CMake targets. + + IREE targets are expected to follow a standard form between Bazel and CMake + that facilitates conversion. External targets *may* have their own patterns, + or they may be purely special cases. + + Multiple target in Bazel may map to a single target in CMake and a Bazel + target may map to multiple CMake targets. + + Returns: + A list of converted targets if it was successfully converted. + + Raises: + KeyError: No conversion was found for the target. + """ + iree_core_repo = self._repo_alias("@iree_core") + if target in self._explicit_target_mapping: + return self._explicit_target_mapping[target] + if target.startswith("@llvm-project//llvm"): + return self._convert_llvm_target(target) + if target.startswith("@llvm-project//mlir"): + return self._convert_mlir_target(target) + if target.startswith("@iree_cuda//"): + return self._convert_iree_cuda_target(target) + if target.startswith(f"{iree_core_repo}//"): + return self._convert_iree_core_target(target) + if target.startswith("@"): + raise KeyError(f"No conversion found for target '{target}'") + + # Pass through package-relative targets + # :target_name + # file_name.txt + if target.startswith(":") or (":" not in target and not target.startswith("/")): + return [self._convert_to_cmake_path(target)] + + return self._convert_unmatched_target(target) + + def _convert_iree_core_target(self, target): + iree_core_repo = self._repo_alias("@iree_core") + if target.startswith(f"{iree_core_repo}//llvm-external-projects/iree-dialects"): + return self._convert_iree_dialects_target(target) + + # IREE root paths map to package names based on explicit rules. + # * src/iree/ directories (compiler/src/iree/ and runtime/src/iree/) + # creating their own root paths by trimming down to just "iree" + # * tools/ uses an empty root, for binary targets names like "iree-compile" + # * other top level directories add back an 'iree' prefix + # If changing these, make the corresponding change in iree_macros.cmake + # (iree_package_ns function). + + # Map //compiler/src/iree/(.*) -> iree::\1 (i.e. iree::compiler::\1) + m = re.match(f"^{iree_core_repo}//compiler/src/iree/(.+)", target) + if m: + return ["iree::" + self._convert_to_cmake_path(m.group(1))] + + # Map //runtime/src/iree/(.*) -> iree::\1 + m = re.match(f"^{iree_core_repo}//runtime/src/iree/(.+)", target) + if m: + return ["iree::" + self._convert_to_cmake_path(m.group(1))] + + # Map //tools/(.*) -> \1 + m = re.match(f"^{iree_core_repo}//tools[/|:](.+)", target) + if m: + return [self._convert_to_cmake_path(m.group(1))] + + return self._convert_unmatched_target(target) + + def _convert_unmatched_target(self, target: str) -> str: + """Converts unmatched targets in a repo specific way.""" + raise ValueError(f"No target matching for {target}") diff --git a/build_tools/benchmarks/benchmark_helper.py b/build_tools/benchmarks/benchmark_helper.py index 847c94ebb7a4..5a5d377cde73 100755 --- a/build_tools/benchmarks/benchmark_helper.py +++ b/build_tools/benchmarks/benchmark_helper.py @@ -27,159 +27,179 @@ def _convert_to_cmd_string(cmds: Sequence[str]) -> str: - if os.name == "nt": - # list2cmdline is an undocumented method for Windows command lines. Python - # doesn't provide an official method for quoting Windows command lines and - # the correct implementation is slightly non-trivial. Use the undocumented - # method for now and can be rewritten with our own implementation later. - # See https://learn.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way - return subprocess.list2cmdline(cmds) + if os.name == "nt": + # list2cmdline is an undocumented method for Windows command lines. Python + # doesn't provide an official method for quoting Windows command lines and + # the correct implementation is slightly non-trivial. Use the undocumented + # method for now and can be rewritten with our own implementation later. + # See https://learn.microsoft.com/en-us/archive/blogs/twistylittlepassagesallalike/everyone-quotes-command-line-arguments-the-wrong-way + return subprocess.list2cmdline(cmds) - return " ".join(shlex.quote(cmd) for cmd in cmds) + return " ".join(shlex.quote(cmd) for cmd in cmds) def _dump_cmds_of_generation_config( gen_config: iree_definitions.ModuleGenerationConfig, - root_path: pathlib.PurePath = pathlib.PurePath()): - - imported_model = gen_config.imported_model - imported_model_path = iree_artifacts.get_imported_model_path( - imported_model=imported_model, root_path=root_path) - module_dir_path = iree_artifacts.get_module_dir_path( - module_generation_config=gen_config, root_path=root_path) - module_path = module_dir_path / iree_artifacts.MODULE_FILENAME - compile_cmds = [ - IREE_COMPILER_NAME, - str(imported_model_path), "-o", - str(module_path) - ] - compile_cmds += gen_config.materialize_compile_flags( - module_dir_path=module_dir_path) - compile_cmd_str = _convert_to_cmd_string(compile_cmds) - - if imported_model.import_config.tool == iree_definitions.ImportTool.NONE: - import_cmd_str = "# (Source model is already in MLIR)" - else: - source_model_path = model_artifacts.get_model_path( - model=imported_model.model, root_path=root_path) - import_cmds = [ - imported_model.import_config.tool.value, - str(source_model_path), "-o", - str(imported_model_path) + root_path: pathlib.PurePath = pathlib.PurePath(), +): + imported_model = gen_config.imported_model + imported_model_path = iree_artifacts.get_imported_model_path( + imported_model=imported_model, root_path=root_path + ) + module_dir_path = iree_artifacts.get_module_dir_path( + module_generation_config=gen_config, root_path=root_path + ) + module_path = module_dir_path / iree_artifacts.MODULE_FILENAME + compile_cmds = [ + IREE_COMPILER_NAME, + str(imported_model_path), + "-o", + str(module_path), ] - import_cmds += imported_model.import_config.materialize_import_flags( - model=imported_model.model) - import_cmd_str = _convert_to_cmd_string(import_cmds) - - # Insert a blank line after each command to help read with line wrap. - return [ - "Compile Module:", compile_cmd_str, "", "Import Model:", import_cmd_str, - "" - ] + compile_cmds += gen_config.materialize_compile_flags( + module_dir_path=module_dir_path + ) + compile_cmd_str = _convert_to_cmd_string(compile_cmds) + + if imported_model.import_config.tool == iree_definitions.ImportTool.NONE: + import_cmd_str = "# (Source model is already in MLIR)" + else: + source_model_path = model_artifacts.get_model_path( + model=imported_model.model, root_path=root_path + ) + import_cmds = [ + imported_model.import_config.tool.value, + str(source_model_path), + "-o", + str(imported_model_path), + ] + import_cmds += imported_model.import_config.materialize_import_flags( + model=imported_model.model + ) + import_cmd_str = _convert_to_cmd_string(import_cmds) + + # Insert a blank line after each command to help read with line wrap. + return ["Compile Module:", compile_cmd_str, "", "Import Model:", import_cmd_str, ""] def _dump_cmds_from_run_config( run_config: iree_definitions.E2EModelRunConfig, - root_path: pathlib.PurePath = pathlib.PurePath()): - - gen_config = run_config.module_generation_config - module_path = iree_artifacts.get_module_dir_path( - module_generation_config=gen_config, - root_path=root_path) / iree_artifacts.MODULE_FILENAME - - run_cmds = [run_config.tool.value, f"--module={module_path}"] - run_cmds += run_config.materialize_run_flags() - # Insert a blank line after the command to help read with line wrap. - lines = ["Run Module:", _convert_to_cmd_string(run_cmds), ""] - lines += _dump_cmds_of_generation_config(gen_config=gen_config, - root_path=root_path) - return lines - - -def _dump_cmds_handler(e2e_test_artifacts_dir: pathlib.Path, - execution_benchmark_config: Optional[pathlib.Path], - compilation_benchmark_config: Optional[pathlib.Path], - benchmark_id: Optional[str], **_unused_args): - lines = [] - - if execution_benchmark_config is not None: - benchmark_groups = json.loads(execution_benchmark_config.read_text()) - for target_device, benchmark_group in benchmark_groups.items(): - run_configs = serialization.unpack_and_deserialize( - data=benchmark_group["run_configs"], - root_type=List[iree_definitions.E2EModelRunConfig]) - for run_config in run_configs: - if benchmark_id is not None and benchmark_id != run_config.composite_id: - continue - - lines.append("################") - lines.append("") - lines.append(f"Execution Benchmark ID: {run_config.composite_id}") - lines.append(f"Name: {run_config}") - lines.append(f"Target Device: {target_device}") - lines.append("") - lines += _dump_cmds_from_run_config(run_config=run_config, - root_path=e2e_test_artifacts_dir) - - if compilation_benchmark_config is not None: - benchmark_config = json.loads(compilation_benchmark_config.read_text()) - gen_configs = serialization.unpack_and_deserialize( - data=benchmark_config["generation_configs"], - root_type=List[iree_definitions.ModuleGenerationConfig]) - for gen_config in gen_configs: - if benchmark_id is not None and benchmark_id != gen_config.composite_id: - continue - - lines.append("################") - lines.append("") - lines.append(f"Compilation Benchmark ID: {gen_config.composite_id}") - lines.append(f"Name: {gen_config}") - lines.append("") - lines += _dump_cmds_of_generation_config(gen_config=gen_config, - root_path=e2e_test_artifacts_dir) - - print(*lines, sep="\n") + root_path: pathlib.PurePath = pathlib.PurePath(), +): + gen_config = run_config.module_generation_config + module_path = ( + iree_artifacts.get_module_dir_path( + module_generation_config=gen_config, root_path=root_path + ) + / iree_artifacts.MODULE_FILENAME + ) + + run_cmds = [run_config.tool.value, f"--module={module_path}"] + run_cmds += run_config.materialize_run_flags() + # Insert a blank line after the command to help read with line wrap. + lines = ["Run Module:", _convert_to_cmd_string(run_cmds), ""] + lines += _dump_cmds_of_generation_config(gen_config=gen_config, root_path=root_path) + return lines + + +def _dump_cmds_handler( + e2e_test_artifacts_dir: pathlib.Path, + execution_benchmark_config: Optional[pathlib.Path], + compilation_benchmark_config: Optional[pathlib.Path], + benchmark_id: Optional[str], + **_unused_args, +): + lines = [] + + if execution_benchmark_config is not None: + benchmark_groups = json.loads(execution_benchmark_config.read_text()) + for target_device, benchmark_group in benchmark_groups.items(): + run_configs = serialization.unpack_and_deserialize( + data=benchmark_group["run_configs"], + root_type=List[iree_definitions.E2EModelRunConfig], + ) + for run_config in run_configs: + if benchmark_id is not None and benchmark_id != run_config.composite_id: + continue + + lines.append("################") + lines.append("") + lines.append(f"Execution Benchmark ID: {run_config.composite_id}") + lines.append(f"Name: {run_config}") + lines.append(f"Target Device: {target_device}") + lines.append("") + lines += _dump_cmds_from_run_config( + run_config=run_config, root_path=e2e_test_artifacts_dir + ) + + if compilation_benchmark_config is not None: + benchmark_config = json.loads(compilation_benchmark_config.read_text()) + gen_configs = serialization.unpack_and_deserialize( + data=benchmark_config["generation_configs"], + root_type=List[iree_definitions.ModuleGenerationConfig], + ) + for gen_config in gen_configs: + if benchmark_id is not None and benchmark_id != gen_config.composite_id: + continue + + lines.append("################") + lines.append("") + lines.append(f"Compilation Benchmark ID: {gen_config.composite_id}") + lines.append(f"Name: {gen_config}") + lines.append("") + lines += _dump_cmds_of_generation_config( + gen_config=gen_config, root_path=e2e_test_artifacts_dir + ) + + print(*lines, sep="\n") def _parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description= - "Miscellaneous tool to help work with benchmark suite and benchmark CI.") - - subparser = parser.add_subparsers(required=True, title="operation") - dump_cmds_parser = subparser.add_parser( - "dump-cmds", - help="Dump the commands to compile and run benchmarks manually.") - dump_cmds_parser.add_argument( - "--e2e_test_artifacts_dir", - type=pathlib.PurePath, - default=pathlib.Path(), - help="E2E test artifacts root path used in the outputs of artifact paths") - dump_cmds_parser.add_argument("--benchmark_id", - type=str, - help="Only dump the benchmark with this id") - dump_cmds_parser.add_argument( - "--execution_benchmark_config", - type=pathlib.Path, - help="Config file exported from export_benchmark_config.py execution") - dump_cmds_parser.add_argument( - "--compilation_benchmark_config", - type=pathlib.Path, - help="Config file exported from export_benchmark_config.py compilation") - dump_cmds_parser.set_defaults(handler=_dump_cmds_handler) - - args = parser.parse_args() - if (args.execution_benchmark_config is None and - args.compilation_benchmark_config is None): - parser.error("At least one of --execution_benchmark_config or " - "--compilation_benchmark_config must be set.") - - return args + parser = argparse.ArgumentParser( + description="Miscellaneous tool to help work with benchmark suite and benchmark CI." + ) + + subparser = parser.add_subparsers(required=True, title="operation") + dump_cmds_parser = subparser.add_parser( + "dump-cmds", help="Dump the commands to compile and run benchmarks manually." + ) + dump_cmds_parser.add_argument( + "--e2e_test_artifacts_dir", + type=pathlib.PurePath, + default=pathlib.Path(), + help="E2E test artifacts root path used in the outputs of artifact paths", + ) + dump_cmds_parser.add_argument( + "--benchmark_id", type=str, help="Only dump the benchmark with this id" + ) + dump_cmds_parser.add_argument( + "--execution_benchmark_config", + type=pathlib.Path, + help="Config file exported from export_benchmark_config.py execution", + ) + dump_cmds_parser.add_argument( + "--compilation_benchmark_config", + type=pathlib.Path, + help="Config file exported from export_benchmark_config.py compilation", + ) + dump_cmds_parser.set_defaults(handler=_dump_cmds_handler) + + args = parser.parse_args() + if ( + args.execution_benchmark_config is None + and args.compilation_benchmark_config is None + ): + parser.error( + "At least one of --execution_benchmark_config or " + "--compilation_benchmark_config must be set." + ) + + return args def main(args: argparse.Namespace): - args.handler(**vars(args)) + args.handler(**vars(args)) if __name__ == "__main__": - main(_parse_arguments()) + main(_parse_arguments()) diff --git a/build_tools/benchmarks/collect_compilation_statistics.py b/build_tools/benchmarks/collect_compilation_statistics.py index 19c80e9471d7..057368da5f65 100755 --- a/build_tools/benchmarks/collect_compilation_statistics.py +++ b/build_tools/benchmarks/collect_compilation_statistics.py @@ -26,7 +26,13 @@ from typing import BinaryIO, Dict, List, Optional, TextIO from common import benchmark_definition -from common.benchmark_definition import CompilationInfo, CompilationResults, CompilationStatistics, ModuleComponentSizes, get_git_commit_hash +from common.benchmark_definition import ( + CompilationInfo, + CompilationResults, + CompilationStatistics, + ModuleComponentSizes, + get_git_commit_hash, +) from common import benchmark_config from e2e_test_artifacts import iree_artifacts from e2e_test_framework import serialization @@ -52,208 +58,226 @@ @dataclass(frozen=True) class ModuleInfo(object): - module_path: pathlib.Path - stream_stats_path: pathlib.Path + module_path: pathlib.Path + stream_stats_path: pathlib.Path def match_module_cmake_target(module_path: pathlib.PurePath) -> Optional[str]: - if module_path.match(f"{E2E_TEST_ARTIFACTS_REL_PATH}/iree_*/" - f"{iree_artifacts.MODULE_FILENAME}"): - # /iree_/ - path_parts = module_path.parts[-3:] - # Join to get the CMake target name. This is *not* a filesystem path, so we - # don't want \ separators on Windows that we would get with os.path.join(). - return '/'.join(path_parts) + if module_path.match( + f"{E2E_TEST_ARTIFACTS_REL_PATH}/iree_*/" f"{iree_artifacts.MODULE_FILENAME}" + ): + # /iree_/ + path_parts = module_path.parts[-3:] + # Join to get the CMake target name. This is *not* a filesystem path, so we + # don't want \ separators on Windows that we would get with os.path.join(). + return "/".join(path_parts) - return None + return None def parse_compilation_time_from_ninja_log(log: TextIO) -> Dict[str, int]: - """Retrieve the compilation time (ms) from the Ninja build log. - - Returns: - Map of target name and compilation time in ms. - """ - - target_build_time_map = {} - header = log.readline() - if NINJA_LOG_HEADER not in header: - raise NotImplementedError(f"Unsupported ninja log version: {header}") - - for line in log: - start_time, end_time, _, target, _ = line.strip().split("\t") - cmake_target = match_module_cmake_target(pathlib.PurePath(target)) - if cmake_target is None: - continue - - start_time = int(start_time) - end_time = int(end_time) - target_build_time_map[cmake_target] = end_time - start_time - - return target_build_time_map - - -def get_module_component_info(module: BinaryIO, - module_file_bytes: int) -> ModuleComponentSizes: - with zipfile.ZipFile(module) as module_zipfile: - size_map = dict( - (info.filename, info.file_size) for info in module_zipfile.infolist()) - - identified_names = set() - if VM_COMPONENT_NAME in size_map: - vm_component_bytes = size_map[VM_COMPONENT_NAME] - identified_names.add(VM_COMPONENT_NAME) - else: - vm_component_bytes = 0 - - if CONST_COMPONENT_NAME in size_map: - const_component_bytes = size_map[CONST_COMPONENT_NAME] - identified_names.add(CONST_COMPONENT_NAME) - else: - const_component_bytes = 0 - - total_dispatch_component_bytes = 0 - for filename, size in size_map.items(): - for pattern in DISPATCH_COMPONENT_PATTERNS: - if re.match(pattern, filename): - total_dispatch_component_bytes += size - identified_names.add(filename) - break - - if identified_names != set(size_map.keys()): - raise RuntimeError( - f"Unrecognized components in the module: {size_map.keys()}.") - - return ModuleComponentSizes( - file_bytes=module_file_bytes, - vm_component_bytes=vm_component_bytes, - const_component_bytes=const_component_bytes, - total_dispatch_component_bytes=total_dispatch_component_bytes) + """Retrieve the compilation time (ms) from the Ninja build log. + + Returns: + Map of target name and compilation time in ms. + """ + + target_build_time_map = {} + header = log.readline() + if NINJA_LOG_HEADER not in header: + raise NotImplementedError(f"Unsupported ninja log version: {header}") + + for line in log: + start_time, end_time, _, target, _ = line.strip().split("\t") + cmake_target = match_module_cmake_target(pathlib.PurePath(target)) + if cmake_target is None: + continue + + start_time = int(start_time) + end_time = int(end_time) + target_build_time_map[cmake_target] = end_time - start_time + + return target_build_time_map + + +def get_module_component_info( + module: BinaryIO, module_file_bytes: int +) -> ModuleComponentSizes: + with zipfile.ZipFile(module) as module_zipfile: + size_map = dict( + (info.filename, info.file_size) for info in module_zipfile.infolist() + ) + + identified_names = set() + if VM_COMPONENT_NAME in size_map: + vm_component_bytes = size_map[VM_COMPONENT_NAME] + identified_names.add(VM_COMPONENT_NAME) + else: + vm_component_bytes = 0 + + if CONST_COMPONENT_NAME in size_map: + const_component_bytes = size_map[CONST_COMPONENT_NAME] + identified_names.add(CONST_COMPONENT_NAME) + else: + const_component_bytes = 0 + + total_dispatch_component_bytes = 0 + for filename, size in size_map.items(): + for pattern in DISPATCH_COMPONENT_PATTERNS: + if re.match(pattern, filename): + total_dispatch_component_bytes += size + identified_names.add(filename) + break + + if identified_names != set(size_map.keys()): + raise RuntimeError(f"Unrecognized components in the module: {size_map.keys()}.") + + return ModuleComponentSizes( + file_bytes=module_file_bytes, + vm_component_bytes=vm_component_bytes, + const_component_bytes=const_component_bytes, + total_dispatch_component_bytes=total_dispatch_component_bytes, + ) def get_module_map_from_compilation_benchmark_config( - compilation_benchmark_config_data: TextIO, - e2e_test_artifacts_dir: pathlib.PurePath + compilation_benchmark_config_data: TextIO, e2e_test_artifacts_dir: pathlib.PurePath ) -> Dict[CompilationInfo, ModuleInfo]: - benchmark_config = json.load(compilation_benchmark_config_data) - gen_configs = serialization.unpack_and_deserialize( - data=benchmark_config["generation_configs"], - root_type=List[iree_definitions.ModuleGenerationConfig]) - module_map = {} - for gen_config in gen_configs: - model = gen_config.imported_model.model - compile_config = gen_config.compile_config - target_archs = [] - for compile_target in compile_config.compile_targets: - arch = compile_target.target_architecture - target_archs.append( - (f"{arch.type.value}-{arch.architecture}-{arch.microarchitecture}-" - f"{compile_target.target_abi.value}")) - compilation_info = CompilationInfo( - name=gen_config.name, - model_name=model.name, - model_tags=tuple(model.tags), - model_source=model.source_type.value, - target_arch=f"[{','.join(target_archs)}]", - compile_tags=tuple(compile_config.tags), - gen_config_id=gen_config.composite_id) - module_dir_path = pathlib.Path( - iree_artifacts.get_module_dir_path(module_generation_config=gen_config, - root_path=e2e_test_artifacts_dir)) - module_path = module_dir_path / iree_artifacts.MODULE_FILENAME - stream_stats_path = (module_dir_path / - iree_artifacts.SCHEDULING_STATS_FILENAME) - module_map[compilation_info] = ModuleInfo( - module_path=module_path, stream_stats_path=stream_stats_path) - - return module_map + benchmark_config = json.load(compilation_benchmark_config_data) + gen_configs = serialization.unpack_and_deserialize( + data=benchmark_config["generation_configs"], + root_type=List[iree_definitions.ModuleGenerationConfig], + ) + module_map = {} + for gen_config in gen_configs: + model = gen_config.imported_model.model + compile_config = gen_config.compile_config + target_archs = [] + for compile_target in compile_config.compile_targets: + arch = compile_target.target_architecture + target_archs.append( + ( + f"{arch.type.value}-{arch.architecture}-{arch.microarchitecture}-" + f"{compile_target.target_abi.value}" + ) + ) + compilation_info = CompilationInfo( + name=gen_config.name, + model_name=model.name, + model_tags=tuple(model.tags), + model_source=model.source_type.value, + target_arch=f"[{','.join(target_archs)}]", + compile_tags=tuple(compile_config.tags), + gen_config_id=gen_config.composite_id, + ) + module_dir_path = pathlib.Path( + iree_artifacts.get_module_dir_path( + module_generation_config=gen_config, root_path=e2e_test_artifacts_dir + ) + ) + module_path = module_dir_path / iree_artifacts.MODULE_FILENAME + stream_stats_path = module_dir_path / iree_artifacts.SCHEDULING_STATS_FILENAME + module_map[compilation_info] = ModuleInfo( + module_path=module_path, stream_stats_path=stream_stats_path + ) + + return module_map def _check_dir_path(path_str: str) -> pathlib.Path: - path = pathlib.Path(path_str) - if not path.is_dir(): - raise argparse.ArgumentTypeError(f"{path} is not a directory.") - return path + path = pathlib.Path(path_str) + if not path.is_dir(): + raise argparse.ArgumentTypeError(f"{path} is not a directory.") + return path def _check_file_path(path_str: str) -> pathlib.Path: - path = pathlib.Path(path_str) - if not path.is_file(): - raise argparse.ArgumentTypeError(f"{path} is not a file.") - return path + path = pathlib.Path(path_str) + if not path.is_file(): + raise argparse.ArgumentTypeError(f"{path} is not a file.") + return path def _parse_arguments(): - """Returns an argument parser with common options.""" - - parser = argparse.ArgumentParser( - description="Collect compilation statistics from benchmark suites.") - parser.add_argument( - "--compilation_benchmark_config", - type=_check_file_path, - required=True, - help="Exported compilation benchmark config of e2e test artifacts.") - parser.add_argument("--build_log", - type=_check_file_path, - required=True, - help="Path to the ninja build log.") - parser.add_argument("--e2e_test_artifacts_dir", - type=_check_dir_path, - required=True, - help="Path to the e2e test artifacts directory.") - parser.add_argument("--output", - type=pathlib.Path, - help="Path to output JSON file.") - - return parser.parse_args() + """Returns an argument parser with common options.""" + + parser = argparse.ArgumentParser( + description="Collect compilation statistics from benchmark suites." + ) + parser.add_argument( + "--compilation_benchmark_config", + type=_check_file_path, + required=True, + help="Exported compilation benchmark config of e2e test artifacts.", + ) + parser.add_argument( + "--build_log", + type=_check_file_path, + required=True, + help="Path to the ninja build log.", + ) + parser.add_argument( + "--e2e_test_artifacts_dir", + type=_check_dir_path, + required=True, + help="Path to the e2e test artifacts directory.", + ) + parser.add_argument("--output", type=pathlib.Path, help="Path to output JSON file.") + + return parser.parse_args() def main(args: argparse.Namespace): - config_data = args.compilation_benchmark_config.open("r") - module_map = get_module_map_from_compilation_benchmark_config( - compilation_benchmark_config_data=config_data, - e2e_test_artifacts_dir=args.e2e_test_artifacts_dir) - build_log_path = args.build_log - - with build_log_path.open("r") as log_file: - target_build_time_map = parse_compilation_time_from_ninja_log(log_file) - - compilation_statistics_list = [] - for compilation_info, module_info in module_map.items(): - module_path = module_info.module_path - with module_path.open("rb") as module_file: - module_component_sizes = get_module_component_info( - module_file, - module_path.stat().st_size) - - cmake_target = match_module_cmake_target(module_path) - if cmake_target is None: - raise RuntimeError( - f"Module path isn't a module cmake target: {module_path}") - compilation_time_ms = target_build_time_map[cmake_target] - - stream_stats_json = json.loads(module_info.stream_stats_path.read_text()) - exec_stats_json = stream_stats_json["stream-aggregate"]["execution"] - ir_stats = benchmark_definition.IRStatistics( - stream_dispatch_count=exec_stats_json["dispatch-count"]) - - compilation_statistics = CompilationStatistics( - compilation_info=compilation_info, - module_component_sizes=module_component_sizes, - compilation_time_ms=compilation_time_ms, - ir_stats=ir_stats) - compilation_statistics_list.append(compilation_statistics) - - commit = get_git_commit_hash("HEAD") - compilation_results = CompilationResults( - commit=commit, compilation_statistics=compilation_statistics_list) - - json_output = json.dumps(asdict(compilation_results), indent=2) - if args.output is None: - print(json_output) - else: - args.output.write_text(json_output) + config_data = args.compilation_benchmark_config.open("r") + module_map = get_module_map_from_compilation_benchmark_config( + compilation_benchmark_config_data=config_data, + e2e_test_artifacts_dir=args.e2e_test_artifacts_dir, + ) + build_log_path = args.build_log + + with build_log_path.open("r") as log_file: + target_build_time_map = parse_compilation_time_from_ninja_log(log_file) + + compilation_statistics_list = [] + for compilation_info, module_info in module_map.items(): + module_path = module_info.module_path + with module_path.open("rb") as module_file: + module_component_sizes = get_module_component_info( + module_file, module_path.stat().st_size + ) + + cmake_target = match_module_cmake_target(module_path) + if cmake_target is None: + raise RuntimeError( + f"Module path isn't a module cmake target: {module_path}" + ) + compilation_time_ms = target_build_time_map[cmake_target] + + stream_stats_json = json.loads(module_info.stream_stats_path.read_text()) + exec_stats_json = stream_stats_json["stream-aggregate"]["execution"] + ir_stats = benchmark_definition.IRStatistics( + stream_dispatch_count=exec_stats_json["dispatch-count"] + ) + + compilation_statistics = CompilationStatistics( + compilation_info=compilation_info, + module_component_sizes=module_component_sizes, + compilation_time_ms=compilation_time_ms, + ir_stats=ir_stats, + ) + compilation_statistics_list.append(compilation_statistics) + + commit = get_git_commit_hash("HEAD") + compilation_results = CompilationResults( + commit=commit, compilation_statistics=compilation_statistics_list + ) + + json_output = json.dumps(asdict(compilation_results), indent=2) + if args.output is None: + print(json_output) + else: + args.output.write_text(json_output) if __name__ == "__main__": - main(_parse_arguments()) + main(_parse_arguments()) diff --git a/build_tools/benchmarks/collect_compilation_statistics_test.py b/build_tools/benchmarks/collect_compilation_statistics_test.py index 63287577ed7c..d2aff95cc3da 100644 --- a/build_tools/benchmarks/collect_compilation_statistics_test.py +++ b/build_tools/benchmarks/collect_compilation_statistics_test.py @@ -12,7 +12,12 @@ import zipfile from common.benchmark_definition import ModuleComponentSizes -from collect_compilation_statistics import CONST_COMPONENT_NAME, VM_COMPONENT_NAME, get_module_component_info, parse_compilation_time_from_ninja_log +from collect_compilation_statistics import ( + CONST_COMPONENT_NAME, + VM_COMPONENT_NAME, + get_module_component_info, + parse_compilation_time_from_ninja_log, +) from e2e_test_artifacts import iree_artifacts from e2e_test_framework import serialization from e2e_test_framework.definitions import common_definitions, iree_definitions @@ -21,140 +26,161 @@ class CollectCompilationStatistics(unittest.TestCase): - - def test_match_module_cmake_target_with_e2e_test_artifacts(self): - target = collect_compilation_statistics.match_module_cmake_target( - pathlib.PurePath("e2e_test_artifacts/iree_abcd/module.vmfb")) - - self.assertEqual(target, "e2e_test_artifacts/iree_abcd/module.vmfb") - - def test_match_module_cmake_target_not_match(self): - target = collect_compilation_statistics.match_module_cmake_target( - pathlib.PurePath("other/target.vmfb")) - - self.assertIsNone(target) - - def test_parse_compilation_time_from_ninja_log(self): - target1 = "e2e_test_artifacts/iree_deeplabv3/module.vmfb" - target2 = "e2e_test_artifacts/iree_mobilessd/module.vmfb" - ninja_log = StringIO("# ninja log v5\n" - f"0\t100\taaa\tbuild/{target1}\taaa\n" - f"130\t200\tbbb\tbuild/{target2}\tbbb\n") - - target_map = parse_compilation_time_from_ninja_log(ninja_log) - - self.assertEqual(target_map, {target1: 100, target2: 70}) - - def test_get_module_component_info(self): - module_file = BytesIO() - with zipfile.ZipFile(module_file, "w") as zip: - zip.writestr(VM_COMPONENT_NAME, b"abcd") - zip.writestr(CONST_COMPONENT_NAME, b"123") - zip.writestr("main_dispatch_0_vulkan_spirv_fb.fb", b"bindata0") - zip.writestr("main_dispatch_1_vulkan_spirv_fb.fb", b"bindata1") - zip.writestr("predict_dispatch_2_cuda_nvptx_fb.fb", b"bindata2") - zip.writestr("dispatch_3_embedded_elf_x86_64.so", b"bindata3") - module_file_data = module_file.getvalue() - - component_sizes = get_module_component_info(BytesIO(module_file_data), - len(module_file_data)) - - self.assertEqual( - component_sizes, - ModuleComponentSizes(file_bytes=len(module_file_data), - vm_component_bytes=4, - const_component_bytes=3, - total_dispatch_component_bytes=32)) - - def test_get_module_component_info_unknown_components(self): - module_file = BytesIO() - with zipfile.ZipFile(module_file, "w") as zip: - zip.writestr(VM_COMPONENT_NAME, b"abcd") - zip.writestr(CONST_COMPONENT_NAME, b"123") - zip.writestr("main_dispatch_0_unknown.fb", b"bindata") - module_file_data = module_file.getvalue() - - self.assertRaises( - RuntimeError, lambda: get_module_component_info( - BytesIO(module_file_data), len(module_file_data))) - - def test_get_module_map_from_compilation_benchmark_config(self): - model_a = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - imported_model_a = iree_definitions.ImportedModel.from_model(model_a) - compile_config_a = iree_definitions.CompileConfig.build( - id="config_a", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - compile_config_b = iree_definitions.CompileConfig.build( - id="config_b", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - RV64_GENERIC, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - gen_config_a = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_a, compile_config=compile_config_a) - gen_config_b = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_a, compile_config=compile_config_b) - benchmark_config = dict(generation_configs=serialization.serialize_and_pack( - [gen_config_a, gen_config_b]), - module_dir_paths=["a", "b"]) - root_dir = pathlib.PurePath("artifacts_dir") - - module_map = collect_compilation_statistics.get_module_map_from_compilation_benchmark_config( - compilation_benchmark_config_data=StringIO( - json.dumps(benchmark_config)), - e2e_test_artifacts_dir=root_dir) - - compile_info_a = common.benchmark_definition.CompilationInfo( - name=gen_config_a.name, - model_name=model_a.name, - model_tags=tuple(model_a.tags), - model_source=model_a.source_type.value, - target_arch=f"[cpu-x86_64-cascadelake-linux-gnu]", - compile_tags=tuple(gen_config_a.compile_config.tags), - gen_config_id=gen_config_a.composite_id) - module_dir_a = pathlib.Path( - iree_artifacts.get_module_dir_path(gen_config_a, root_dir)) - module_info_a = collect_compilation_statistics.ModuleInfo( - module_path=module_dir_a / iree_artifacts.MODULE_FILENAME, - stream_stats_path=module_dir_a / - iree_artifacts.SCHEDULING_STATS_FILENAME) - compile_info_b = common.benchmark_definition.CompilationInfo( - name=gen_config_b.name, - model_name=model_a.name, - model_tags=tuple(model_a.tags), - model_source=model_a.source_type.value, - target_arch=f"[cpu-riscv_64-generic-linux-gnu]", - compile_tags=tuple(gen_config_a.compile_config.tags), - gen_config_id=gen_config_b.composite_id) - module_dir_b = pathlib.Path( - iree_artifacts.get_module_dir_path(gen_config_b, root_dir)) - module_info_b = collect_compilation_statistics.ModuleInfo( - module_path=module_dir_b / iree_artifacts.MODULE_FILENAME, - stream_stats_path=module_dir_b / - iree_artifacts.SCHEDULING_STATS_FILENAME) - self.assertEqual(module_map, { - compile_info_a: module_info_a, - compile_info_b: module_info_b - }) + def test_match_module_cmake_target_with_e2e_test_artifacts(self): + target = collect_compilation_statistics.match_module_cmake_target( + pathlib.PurePath("e2e_test_artifacts/iree_abcd/module.vmfb") + ) + + self.assertEqual(target, "e2e_test_artifacts/iree_abcd/module.vmfb") + + def test_match_module_cmake_target_not_match(self): + target = collect_compilation_statistics.match_module_cmake_target( + pathlib.PurePath("other/target.vmfb") + ) + + self.assertIsNone(target) + + def test_parse_compilation_time_from_ninja_log(self): + target1 = "e2e_test_artifacts/iree_deeplabv3/module.vmfb" + target2 = "e2e_test_artifacts/iree_mobilessd/module.vmfb" + ninja_log = StringIO( + "# ninja log v5\n" + f"0\t100\taaa\tbuild/{target1}\taaa\n" + f"130\t200\tbbb\tbuild/{target2}\tbbb\n" + ) + + target_map = parse_compilation_time_from_ninja_log(ninja_log) + + self.assertEqual(target_map, {target1: 100, target2: 70}) + + def test_get_module_component_info(self): + module_file = BytesIO() + with zipfile.ZipFile(module_file, "w") as zip: + zip.writestr(VM_COMPONENT_NAME, b"abcd") + zip.writestr(CONST_COMPONENT_NAME, b"123") + zip.writestr("main_dispatch_0_vulkan_spirv_fb.fb", b"bindata0") + zip.writestr("main_dispatch_1_vulkan_spirv_fb.fb", b"bindata1") + zip.writestr("predict_dispatch_2_cuda_nvptx_fb.fb", b"bindata2") + zip.writestr("dispatch_3_embedded_elf_x86_64.so", b"bindata3") + module_file_data = module_file.getvalue() + + component_sizes = get_module_component_info( + BytesIO(module_file_data), len(module_file_data) + ) + + self.assertEqual( + component_sizes, + ModuleComponentSizes( + file_bytes=len(module_file_data), + vm_component_bytes=4, + const_component_bytes=3, + total_dispatch_component_bytes=32, + ), + ) + + def test_get_module_component_info_unknown_components(self): + module_file = BytesIO() + with zipfile.ZipFile(module_file, "w") as zip: + zip.writestr(VM_COMPONENT_NAME, b"abcd") + zip.writestr(CONST_COMPONENT_NAME, b"123") + zip.writestr("main_dispatch_0_unknown.fb", b"bindata") + module_file_data = module_file.getvalue() + + self.assertRaises( + RuntimeError, + lambda: get_module_component_info( + BytesIO(module_file_data), len(module_file_data) + ), + ) + + def test_get_module_map_from_compilation_benchmark_config(self): + model_a = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + imported_model_a = iree_definitions.ImportedModel.from_model(model_a) + compile_config_a = iree_definitions.CompileConfig.build( + id="config_a", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + compile_config_b = iree_definitions.CompileConfig.build( + id="config_b", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + gen_config_a = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_a, compile_config=compile_config_a + ) + gen_config_b = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_a, compile_config=compile_config_b + ) + benchmark_config = dict( + generation_configs=serialization.serialize_and_pack( + [gen_config_a, gen_config_b] + ), + module_dir_paths=["a", "b"], + ) + root_dir = pathlib.PurePath("artifacts_dir") + + module_map = collect_compilation_statistics.get_module_map_from_compilation_benchmark_config( + compilation_benchmark_config_data=StringIO(json.dumps(benchmark_config)), + e2e_test_artifacts_dir=root_dir, + ) + + compile_info_a = common.benchmark_definition.CompilationInfo( + name=gen_config_a.name, + model_name=model_a.name, + model_tags=tuple(model_a.tags), + model_source=model_a.source_type.value, + target_arch=f"[cpu-x86_64-cascadelake-linux-gnu]", + compile_tags=tuple(gen_config_a.compile_config.tags), + gen_config_id=gen_config_a.composite_id, + ) + module_dir_a = pathlib.Path( + iree_artifacts.get_module_dir_path(gen_config_a, root_dir) + ) + module_info_a = collect_compilation_statistics.ModuleInfo( + module_path=module_dir_a / iree_artifacts.MODULE_FILENAME, + stream_stats_path=module_dir_a / iree_artifacts.SCHEDULING_STATS_FILENAME, + ) + compile_info_b = common.benchmark_definition.CompilationInfo( + name=gen_config_b.name, + model_name=model_a.name, + model_tags=tuple(model_a.tags), + model_source=model_a.source_type.value, + target_arch=f"[cpu-riscv_64-generic-linux-gnu]", + compile_tags=tuple(gen_config_a.compile_config.tags), + gen_config_id=gen_config_b.composite_id, + ) + module_dir_b = pathlib.Path( + iree_artifacts.get_module_dir_path(gen_config_b, root_dir) + ) + module_info_b = collect_compilation_statistics.ModuleInfo( + module_path=module_dir_b / iree_artifacts.MODULE_FILENAME, + stream_stats_path=module_dir_b / iree_artifacts.SCHEDULING_STATS_FILENAME, + ) + self.assertEqual( + module_map, {compile_info_a: module_info_a, compile_info_b: module_info_b} + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/common/android_device_utils.py b/build_tools/benchmarks/common/android_device_utils.py index a0f86047e558..5149da8d1d33 100644 --- a/build_tools/benchmarks/common/android_device_utils.py +++ b/build_tools/benchmarks/common/android_device_utils.py @@ -10,59 +10,64 @@ import re from typing import Sequence -from .benchmark_definition import (execute_cmd_and_get_stdout, DeviceInfo, - PlatformType) +from .benchmark_definition import execute_cmd_and_get_stdout, DeviceInfo, PlatformType def get_android_device_model(verbose: bool = False) -> str: - """Returns the Android device model.""" - model = execute_cmd_and_get_stdout( - ["adb", "shell", "getprop", "ro.product.model"], verbose=verbose) - model = re.sub(r"\W+", "-", model) - return model + """Returns the Android device model.""" + model = execute_cmd_and_get_stdout( + ["adb", "shell", "getprop", "ro.product.model"], verbose=verbose + ) + model = re.sub(r"\W+", "-", model) + return model def get_android_cpu_abi(verbose: bool = False) -> str: - """Returns the CPU ABI for the Android device.""" - return execute_cmd_and_get_stdout( - ["adb", "shell", "getprop", "ro.product.cpu.abi"], verbose=verbose) + """Returns the CPU ABI for the Android device.""" + return execute_cmd_and_get_stdout( + ["adb", "shell", "getprop", "ro.product.cpu.abi"], verbose=verbose + ) def get_android_cpu_features(verbose: bool = False) -> Sequence[str]: - """Returns the CPU features for the Android device.""" - cpuinfo = execute_cmd_and_get_stdout(["adb", "shell", "cat", "/proc/cpuinfo"], - verbose=verbose) - features = [] - for line in cpuinfo.splitlines(): - if line.startswith("Features"): - _, features = line.split(":") - return features.strip().split() - return features + """Returns the CPU features for the Android device.""" + cpuinfo = execute_cmd_and_get_stdout( + ["adb", "shell", "cat", "/proc/cpuinfo"], verbose=verbose + ) + features = [] + for line in cpuinfo.splitlines(): + if line.startswith("Features"): + _, features = line.split(":") + return features.strip().split() + return features def get_android_gpu_name(verbose: bool = False) -> str: - """Returns the GPU name for the Android device.""" - vkjson = execute_cmd_and_get_stdout(["adb", "shell", "cmd", "gpu", "vkjson"], - verbose=verbose) - vkjson = json.loads(vkjson) - name = vkjson["devices"][0]["properties"]["deviceName"] + """Returns the GPU name for the Android device.""" + vkjson = execute_cmd_and_get_stdout( + ["adb", "shell", "cmd", "gpu", "vkjson"], verbose=verbose + ) + vkjson = json.loads(vkjson) + name = vkjson["devices"][0]["properties"]["deviceName"] - # Perform some canonicalization: + # Perform some canonicalization: - # - Adreno GPUs have raw names like "Adreno (TM) 650". - name = name.replace("(TM)", "") + # - Adreno GPUs have raw names like "Adreno (TM) 650". + name = name.replace("(TM)", "") - # Replace all consecutive non-word characters with a single hyphen. - name = re.sub(r"\W+", "-", name) + # Replace all consecutive non-word characters with a single hyphen. + name = re.sub(r"\W+", "-", name) - return name + return name def get_android_device_info(verbose: bool = False) -> DeviceInfo: - """Returns device info for the Android device.""" - return DeviceInfo(platform_type=PlatformType.ANDROID, - model=get_android_device_model(verbose), - cpu_abi=get_android_cpu_abi(verbose), - cpu_uarch=None, - cpu_features=get_android_cpu_features(verbose), - gpu_name=get_android_gpu_name(verbose)) + """Returns device info for the Android device.""" + return DeviceInfo( + platform_type=PlatformType.ANDROID, + model=get_android_device_model(verbose), + cpu_abi=get_android_cpu_abi(verbose), + cpu_uarch=None, + cpu_features=get_android_cpu_features(verbose), + gpu_name=get_android_gpu_name(verbose), + ) diff --git a/build_tools/benchmarks/common/benchmark_config.py b/build_tools/benchmarks/common/benchmark_config.py index 2d08d4e1e1f7..f9fb2ddb28f7 100644 --- a/build_tools/benchmarks/common/benchmark_config.py +++ b/build_tools/benchmarks/common/benchmark_config.py @@ -16,24 +16,24 @@ @dataclass class TraceCaptureConfig: - """Represents the settings for capturing traces during benchamrking. + """Represents the settings for capturing traces during benchamrking. traced_benchmark_tool_dir: the path to the tracing-enabled benchmark tool directory. trace_capture_tool: the path to the tool for collecting captured traces. capture_tarball: the path of capture tar archive. capture_tmp_dir: the temporary directory to store captured traces. - """ + """ - traced_benchmark_tool_dir: pathlib.Path - trace_capture_tool: pathlib.Path - capture_tarball: pathlib.Path - capture_tmp_dir: pathlib.Path + traced_benchmark_tool_dir: pathlib.Path + trace_capture_tool: pathlib.Path + capture_tarball: pathlib.Path + capture_tmp_dir: pathlib.Path @dataclass class BenchmarkConfig: - """Represents the settings to run benchmarks. + """Represents the settings to run benchmarks. root_benchmark_dir: the root directory containing the built benchmark suites. @@ -56,71 +56,72 @@ class BenchmarkConfig: times. continue_from_previous: skip the benchmarks if their results are found in the benchmark_results_dir. - """ - - root_benchmark_dir: pathlib.Path - benchmark_results_dir: pathlib.Path - git_commit_hash: str - - normal_benchmark_tool_dir: Optional[pathlib.Path] = None - trace_capture_config: Optional[TraceCaptureConfig] = None - - driver_filter: Optional[str] = None - model_name_filter: Optional[str] = None - mode_filter: Optional[str] = None - use_compatible_filter: bool = False - - keep_going: bool = False - benchmark_min_time: float = 0 - continue_from_previous: bool = False - - @staticmethod - def build_from_args(args: Namespace, git_commit_hash: str): - """Build config from command arguments. - - Args: - args: the command arguments. - git_commit_hash: the git commit hash of IREE. """ - def real_path_or_none( - path: Optional[pathlib.Path]) -> Optional[pathlib.Path]: - return path.resolve() if path else None - - if not args.normal_benchmark_tool_dir and not args.traced_benchmark_tool_dir: - raise ValueError( - "At least one of --normal_benchmark_tool_dir or --traced_benchmark_tool_dir should be specified." - ) - if not ((args.traced_benchmark_tool_dir is None) == - (args.trace_capture_tool is None) == - (args.capture_tarball is None)): - raise ValueError( - "The following 3 flags should be simultaneously all specified or all unspecified: --traced_benchmark_tool_dir, --trace_capture_tool, --capture_tarball" - ) - - per_commit_tmp_dir: pathlib.Path = (args.tmp_dir / - git_commit_hash).resolve() - - if args.traced_benchmark_tool_dir is None: - trace_capture_config = None - else: - trace_capture_config = TraceCaptureConfig( - traced_benchmark_tool_dir=args.traced_benchmark_tool_dir.resolve(), - trace_capture_tool=args.trace_capture_tool.resolve(), - capture_tarball=args.capture_tarball.resolve(), - capture_tmp_dir=per_commit_tmp_dir / CAPTURES_REL_PATH) - - return BenchmarkConfig(root_benchmark_dir=args.e2e_test_artifacts_dir, - benchmark_results_dir=per_commit_tmp_dir / - BENCHMARK_RESULTS_REL_PATH, - git_commit_hash=git_commit_hash, - normal_benchmark_tool_dir=real_path_or_none( - args.normal_benchmark_tool_dir), - trace_capture_config=trace_capture_config, - driver_filter=args.driver_filter_regex, - model_name_filter=args.model_name_regex, - mode_filter=args.mode_regex, - use_compatible_filter=args.compatible_only, - keep_going=args.keep_going, - benchmark_min_time=args.benchmark_min_time, - continue_from_previous=args.continue_from_previous) + root_benchmark_dir: pathlib.Path + benchmark_results_dir: pathlib.Path + git_commit_hash: str + + normal_benchmark_tool_dir: Optional[pathlib.Path] = None + trace_capture_config: Optional[TraceCaptureConfig] = None + + driver_filter: Optional[str] = None + model_name_filter: Optional[str] = None + mode_filter: Optional[str] = None + use_compatible_filter: bool = False + + keep_going: bool = False + benchmark_min_time: float = 0 + continue_from_previous: bool = False + + @staticmethod + def build_from_args(args: Namespace, git_commit_hash: str): + """Build config from command arguments. + + Args: + args: the command arguments. + git_commit_hash: the git commit hash of IREE. + """ + + def real_path_or_none(path: Optional[pathlib.Path]) -> Optional[pathlib.Path]: + return path.resolve() if path else None + + if not args.normal_benchmark_tool_dir and not args.traced_benchmark_tool_dir: + raise ValueError( + "At least one of --normal_benchmark_tool_dir or --traced_benchmark_tool_dir should be specified." + ) + if not ( + (args.traced_benchmark_tool_dir is None) + == (args.trace_capture_tool is None) + == (args.capture_tarball is None) + ): + raise ValueError( + "The following 3 flags should be simultaneously all specified or all unspecified: --traced_benchmark_tool_dir, --trace_capture_tool, --capture_tarball" + ) + + per_commit_tmp_dir: pathlib.Path = (args.tmp_dir / git_commit_hash).resolve() + + if args.traced_benchmark_tool_dir is None: + trace_capture_config = None + else: + trace_capture_config = TraceCaptureConfig( + traced_benchmark_tool_dir=args.traced_benchmark_tool_dir.resolve(), + trace_capture_tool=args.trace_capture_tool.resolve(), + capture_tarball=args.capture_tarball.resolve(), + capture_tmp_dir=per_commit_tmp_dir / CAPTURES_REL_PATH, + ) + + return BenchmarkConfig( + root_benchmark_dir=args.e2e_test_artifacts_dir, + benchmark_results_dir=per_commit_tmp_dir / BENCHMARK_RESULTS_REL_PATH, + git_commit_hash=git_commit_hash, + normal_benchmark_tool_dir=real_path_or_none(args.normal_benchmark_tool_dir), + trace_capture_config=trace_capture_config, + driver_filter=args.driver_filter_regex, + model_name_filter=args.model_name_regex, + mode_filter=args.mode_regex, + use_compatible_filter=args.compatible_only, + keep_going=args.keep_going, + benchmark_min_time=args.benchmark_min_time, + continue_from_previous=args.continue_from_previous, + ) diff --git a/build_tools/benchmarks/common/benchmark_config_test.py b/build_tools/benchmarks/common/benchmark_config_test.py index 32d238752df0..2a446ab729ce 100644 --- a/build_tools/benchmarks/common/benchmark_config_test.py +++ b/build_tools/benchmarks/common/benchmark_config_test.py @@ -15,97 +15,109 @@ class BenchmarkConfigTest(unittest.TestCase): - - def setUp(self): - self._tmp_dir_manager = tempfile.TemporaryDirectory() - self.tmp_dir = pathlib.Path(self._tmp_dir_manager.name).resolve() - self._build_dir_manager = tempfile.TemporaryDirectory() - self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve() - self.e2e_test_artifacts_dir = self.build_dir / "e2e_test_artifacts" - self.e2e_test_artifacts_dir.mkdir() - self.normal_tool_dir = self.build_dir / "normal_tool" - self.normal_tool_dir.mkdir() - self.traced_tool_dir = self.build_dir / "traced_tool" - self.traced_tool_dir.mkdir() - self.trace_capture_tool = self.build_dir / "tracy_capture" - # Create capture tool with executable file mode. - self.trace_capture_tool.touch(mode=0o755) - self.execution_config = self.build_dir / "execution_config.json" - self.execution_config.touch() - - def tearDown(self): - self._build_dir_manager.cleanup() - self._tmp_dir_manager.cleanup() - - def test_build_from_args(self): - args = common_arguments.Parser().parse_args([ - f"--tmp_dir={self.tmp_dir}", - f"--normal_benchmark_tool_dir={self.normal_tool_dir}", - f"--traced_benchmark_tool_dir={self.traced_tool_dir}", - f"--trace_capture_tool={self.trace_capture_tool}", - f"--capture_tarball=capture.tar", - f"--driver_filter_regex=a", - f"--model_name_regex=b", - f"--mode_regex=c", - f"--keep_going", - f"--benchmark_min_time=10", - f"--compatible_only", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) - - config = benchmark_config.BenchmarkConfig.build_from_args( - args=args, git_commit_hash="abcd") - - per_commit_tmp_dir = self.tmp_dir / "abcd" - expected_trace_capture_config = benchmark_config.TraceCaptureConfig( - traced_benchmark_tool_dir=self.traced_tool_dir, - trace_capture_tool=pathlib.Path(self.trace_capture_tool).resolve(), - capture_tarball=pathlib.Path("capture.tar").resolve(), - capture_tmp_dir=per_commit_tmp_dir / "captures") - expected_config = benchmark_config.BenchmarkConfig( - root_benchmark_dir=self.e2e_test_artifacts_dir, - benchmark_results_dir=per_commit_tmp_dir / "benchmark-results", - git_commit_hash="abcd", - normal_benchmark_tool_dir=self.normal_tool_dir, - trace_capture_config=expected_trace_capture_config, - driver_filter="a", - model_name_filter="b", - mode_filter="c", - keep_going=True, - benchmark_min_time=10, - use_compatible_filter=True) - self.assertEqual(config, expected_config) - - def test_build_from_args_benchmark_only(self): - args = common_arguments.Parser().parse_args([ - f"--tmp_dir={self.tmp_dir}", - f"--normal_benchmark_tool_dir={self.normal_tool_dir}", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) - - config = benchmark_config.BenchmarkConfig.build_from_args( - args=args, git_commit_hash="abcd") - - self.assertIsNone(config.trace_capture_config) - - def test_build_from_args_invalid_capture_args(self): - args = common_arguments.Parser().parse_args([ - f"--tmp_dir={self.tmp_dir}", - f"--normal_benchmark_tool_dir={self.normal_tool_dir}", - f"--traced_benchmark_tool_dir={self.traced_tool_dir}", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) - - self.assertRaises( - ValueError, lambda: benchmark_config.BenchmarkConfig.build_from_args( - args=args, git_commit_hash="abcd")) + def setUp(self): + self._tmp_dir_manager = tempfile.TemporaryDirectory() + self.tmp_dir = pathlib.Path(self._tmp_dir_manager.name).resolve() + self._build_dir_manager = tempfile.TemporaryDirectory() + self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve() + self.e2e_test_artifacts_dir = self.build_dir / "e2e_test_artifacts" + self.e2e_test_artifacts_dir.mkdir() + self.normal_tool_dir = self.build_dir / "normal_tool" + self.normal_tool_dir.mkdir() + self.traced_tool_dir = self.build_dir / "traced_tool" + self.traced_tool_dir.mkdir() + self.trace_capture_tool = self.build_dir / "tracy_capture" + # Create capture tool with executable file mode. + self.trace_capture_tool.touch(mode=0o755) + self.execution_config = self.build_dir / "execution_config.json" + self.execution_config.touch() + + def tearDown(self): + self._build_dir_manager.cleanup() + self._tmp_dir_manager.cleanup() + + def test_build_from_args(self): + args = common_arguments.Parser().parse_args( + [ + f"--tmp_dir={self.tmp_dir}", + f"--normal_benchmark_tool_dir={self.normal_tool_dir}", + f"--traced_benchmark_tool_dir={self.traced_tool_dir}", + f"--trace_capture_tool={self.trace_capture_tool}", + f"--capture_tarball=capture.tar", + f"--driver_filter_regex=a", + f"--model_name_regex=b", + f"--mode_regex=c", + f"--keep_going", + f"--benchmark_min_time=10", + f"--compatible_only", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) + + config = benchmark_config.BenchmarkConfig.build_from_args( + args=args, git_commit_hash="abcd" + ) + + per_commit_tmp_dir = self.tmp_dir / "abcd" + expected_trace_capture_config = benchmark_config.TraceCaptureConfig( + traced_benchmark_tool_dir=self.traced_tool_dir, + trace_capture_tool=pathlib.Path(self.trace_capture_tool).resolve(), + capture_tarball=pathlib.Path("capture.tar").resolve(), + capture_tmp_dir=per_commit_tmp_dir / "captures", + ) + expected_config = benchmark_config.BenchmarkConfig( + root_benchmark_dir=self.e2e_test_artifacts_dir, + benchmark_results_dir=per_commit_tmp_dir / "benchmark-results", + git_commit_hash="abcd", + normal_benchmark_tool_dir=self.normal_tool_dir, + trace_capture_config=expected_trace_capture_config, + driver_filter="a", + model_name_filter="b", + mode_filter="c", + keep_going=True, + benchmark_min_time=10, + use_compatible_filter=True, + ) + self.assertEqual(config, expected_config) + + def test_build_from_args_benchmark_only(self): + args = common_arguments.Parser().parse_args( + [ + f"--tmp_dir={self.tmp_dir}", + f"--normal_benchmark_tool_dir={self.normal_tool_dir}", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) + + config = benchmark_config.BenchmarkConfig.build_from_args( + args=args, git_commit_hash="abcd" + ) + + self.assertIsNone(config.trace_capture_config) + + def test_build_from_args_invalid_capture_args(self): + args = common_arguments.Parser().parse_args( + [ + f"--tmp_dir={self.tmp_dir}", + f"--normal_benchmark_tool_dir={self.normal_tool_dir}", + f"--traced_benchmark_tool_dir={self.traced_tool_dir}", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) + + self.assertRaises( + ValueError, + lambda: benchmark_config.BenchmarkConfig.build_from_args( + args=args, git_commit_hash="abcd" + ), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/common/benchmark_definition.py b/build_tools/benchmarks/common/benchmark_definition.py index 1bd29ea2a161..bdfd5f7e77d4 100644 --- a/build_tools/benchmarks/common/benchmark_definition.py +++ b/build_tools/benchmarks/common/benchmark_definition.py @@ -23,68 +23,53 @@ # A map from CPU ABI to IREE's benchmark target architecture. CPU_ABI_TO_TARGET_ARCH_MAP = { - "arm64-v8a": - common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - "x86_64-cascadelake": - common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + "arm64-v8a": common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + "x86_64-cascadelake": common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, } # A map from GPU name to IREE's benchmark target architecture. GPU_NAME_TO_TARGET_ARCH_MAP = { - "adreno-640": - common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - "adreno-650": - common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - "adreno-660": - common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - "adreno-730": - common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - "mali-g77": - common_definitions.DeviceArchitecture.ARM_VALHALL, - "mali-g78": - common_definitions.DeviceArchitecture.ARM_VALHALL, - "tesla-v100-sxm2-16gb": - common_definitions.DeviceArchitecture.NVIDIA_PASCAL, - "nvidia-a100-sxm4-40gb": - common_definitions.DeviceArchitecture.NVIDIA_AMPERE, - "nvidia-geforce-rtx-3090": - common_definitions.DeviceArchitecture.NVIDIA_AMPERE, + "adreno-640": common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + "adreno-650": common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + "adreno-660": common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + "adreno-730": common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + "mali-g77": common_definitions.DeviceArchitecture.ARM_VALHALL, + "mali-g78": common_definitions.DeviceArchitecture.ARM_VALHALL, + "tesla-v100-sxm2-16gb": common_definitions.DeviceArchitecture.NVIDIA_PASCAL, + "nvidia-a100-sxm4-40gb": common_definitions.DeviceArchitecture.NVIDIA_AMPERE, + "nvidia-geforce-rtx-3090": common_definitions.DeviceArchitecture.NVIDIA_AMPERE, } @dataclasses.dataclass class DriverInfo: - """An object describing a IREE HAL driver. + """An object describing a IREE HAL driver. - It includes the following characteristics: - - pretty_name: the pretty name, e.g., 'IREE-LLVM-CPU' - - device_type: the targeted device type, e.g., 'CPU' - - driver_name: runtime driver flag, e.g., 'local-task' - - loader_name: executable loader name, if used - """ + It includes the following characteristics: + - pretty_name: the pretty name, e.g., 'IREE-LLVM-CPU' + - device_type: the targeted device type, e.g., 'CPU' + - driver_name: runtime driver flag, e.g., 'local-task' + - loader_name: executable loader name, if used + """ - pretty_name: str - device_type: str - driver_name: str - loader_name: str + pretty_name: str + device_type: str + driver_name: str + loader_name: str # A map for IREE driver names. This allows us to normalize driver names like # mapping to more friendly ones and detach to keep driver names used in # benchmark presentation stable. IREE_DRIVERS_INFOS = { - "iree-llvm-cpu": - DriverInfo("IREE-LLVM-CPU", "CPU", "local-task", "embedded-elf"), - "iree-llvm-cpu-sync": - DriverInfo("IREE-LLVM-CPU-Sync", "CPU", "local-sync", "embedded-elf"), - "iree-vmvx": - DriverInfo("IREE-VMVX", "CPU", "local-task", "vmvx-module"), - "iree-vmvx-sync": - DriverInfo("IREE-VMVX-Sync", "CPU", "local-sync", "vmvx-module"), - "iree-vulkan": - DriverInfo("IREE-Vulkan", "GPU", "vulkan", ""), - "iree-cuda": - DriverInfo("IREE-CUDA", "GPU", "cuda", ""), + "iree-llvm-cpu": DriverInfo("IREE-LLVM-CPU", "CPU", "local-task", "embedded-elf"), + "iree-llvm-cpu-sync": DriverInfo( + "IREE-LLVM-CPU-Sync", "CPU", "local-sync", "embedded-elf" + ), + "iree-vmvx": DriverInfo("IREE-VMVX", "CPU", "local-task", "vmvx-module"), + "iree-vmvx-sync": DriverInfo("IREE-VMVX-Sync", "CPU", "local-sync", "vmvx-module"), + "iree-vulkan": DriverInfo("IREE-Vulkan", "GPU", "vulkan", ""), + "iree-cuda": DriverInfo("IREE-CUDA", "GPU", "cuda", ""), } IREE_PRETTY_NAME_TO_DRIVER_NAME = { @@ -92,296 +77,318 @@ class DriverInfo: } -def execute_cmd(args: Sequence[Any], - verbose: bool = False, - **kwargs) -> subprocess.CompletedProcess: - """Executes a command and returns the completed process. - - A thin wrapper around subprocess.run that sets some useful defaults and - optionally prints out the command being run. - - Raises: - CalledProcessError if the command fails. - """ - if verbose: - print(f"cmd: {args}") - try: - return subprocess.run(args, check=True, text=True, **kwargs) - except subprocess.CalledProcessError as exc: - print((f"\n\nThe following command failed:\n\n{args}" - f"\n\nReturn code: {exc.returncode}\n\n")) - if exc.stdout: - print(f"Stdout:\n\n{exc.stdout}\n\n") - if exc.stderr: - print(f"Stderr:\n\n{exc.stderr}\n\n") - raise exc - - -def execute_cmd_and_get_output(args: Sequence[Any], - verbose: bool = False, - **kwargs) -> Tuple[str, str]: - """Executes a command and returns its stdout and stderr - - Same as execute_cmd except captures stdout and stderr. - """ - exc = execute_cmd(args, - verbose=verbose, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - **kwargs) - return exc.stdout.strip(), exc.stderr.strip() - - -def execute_cmd_and_get_stdout(args: Sequence[Any], - verbose: bool = False, - **kwargs) -> str: - """Executes a command and returns its stdout. - - Same as execute_cmd except captures stdout (and not stderr). - """ - stdout, _ = execute_cmd_and_get_output(args, verbose=verbose, **kwargs) - return stdout +def execute_cmd( + args: Sequence[Any], verbose: bool = False, **kwargs +) -> subprocess.CompletedProcess: + """Executes a command and returns the completed process. + + A thin wrapper around subprocess.run that sets some useful defaults and + optionally prints out the command being run. + + Raises: + CalledProcessError if the command fails. + """ + if verbose: + print(f"cmd: {args}") + try: + return subprocess.run(args, check=True, text=True, **kwargs) + except subprocess.CalledProcessError as exc: + print( + ( + f"\n\nThe following command failed:\n\n{args}" + f"\n\nReturn code: {exc.returncode}\n\n" + ) + ) + if exc.stdout: + print(f"Stdout:\n\n{exc.stdout}\n\n") + if exc.stderr: + print(f"Stderr:\n\n{exc.stderr}\n\n") + raise exc + + +def execute_cmd_and_get_output( + args: Sequence[Any], verbose: bool = False, **kwargs +) -> Tuple[str, str]: + """Executes a command and returns its stdout and stderr + + Same as execute_cmd except captures stdout and stderr. + """ + exc = execute_cmd( + args, verbose=verbose, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs + ) + return exc.stdout.strip(), exc.stderr.strip() + + +def execute_cmd_and_get_stdout( + args: Sequence[Any], verbose: bool = False, **kwargs +) -> str: + """Executes a command and returns its stdout. + + Same as execute_cmd except captures stdout (and not stderr). + """ + stdout, _ = execute_cmd_and_get_output(args, verbose=verbose, **kwargs) + return stdout def get_git_commit_hash(commit: str) -> str: - return execute_cmd_and_get_stdout(['git', 'rev-parse', commit], - cwd=pathlib.Path(__file__).resolve().parent) + return execute_cmd_and_get_stdout( + ["git", "rev-parse", commit], cwd=pathlib.Path(__file__).resolve().parent + ) def get_iree_benchmark_module_arguments( results_filename: str, driver_info: DriverInfo, - benchmark_min_time: Optional[float] = None): - """Returns the common arguments to run iree-benchmark-module.""" - - if driver_info.loader_name == "vmvx-module": - # VMVX is very unoptimized for now and can take a long time to run. - # Decrease the repetition for it until it's reasonably fast. - repetitions = 3 - else: - repetitions = 10 - - cmd = [ - "--time_unit=ns", - "--benchmark_format=json", - "--benchmark_out_format=json", - f"--benchmark_out={results_filename}", - "--print_statistics=true", - ] - if benchmark_min_time: - cmd.extend([ - f"--benchmark_min_time={benchmark_min_time}", - ]) - else: - cmd.extend([ - f"--benchmark_repetitions={repetitions}", - ]) - - return cmd - - -def wait_for_iree_benchmark_module_start(process: subprocess.Popen, - verbose: bool = False) -> None: - """Wait for the start of iree-benchmark module; otherwise will see connection - failure when opening the catpure tool.""" - - while True: - line = process.stdout.readline() # pytype: disable=attribute-error - if line == "" and process.poll() is not None: # Process completed - raise ValueError("Cannot find benchmark result line in the log!") - if verbose: - print(line.strip()) - # Result available - if re.match(r"^BM_.+/real_time", line) is not None: - break + benchmark_min_time: Optional[float] = None, +): + """Returns the common arguments to run iree-benchmark-module.""" + + if driver_info.loader_name == "vmvx-module": + # VMVX is very unoptimized for now and can take a long time to run. + # Decrease the repetition for it until it's reasonably fast. + repetitions = 3 + else: + repetitions = 10 + + cmd = [ + "--time_unit=ns", + "--benchmark_format=json", + "--benchmark_out_format=json", + f"--benchmark_out={results_filename}", + "--print_statistics=true", + ] + if benchmark_min_time: + cmd.extend( + [ + f"--benchmark_min_time={benchmark_min_time}", + ] + ) + else: + cmd.extend( + [ + f"--benchmark_repetitions={repetitions}", + ] + ) + + return cmd + + +def wait_for_iree_benchmark_module_start( + process: subprocess.Popen, verbose: bool = False +) -> None: + """Wait for the start of iree-benchmark module; otherwise will see connection + failure when opening the catpure tool.""" + + while True: + line = process.stdout.readline() # pytype: disable=attribute-error + if line == "" and process.poll() is not None: # Process completed + raise ValueError("Cannot find benchmark result line in the log!") + if verbose: + print(line.strip()) + # Result available + if re.match(r"^BM_.+/real_time", line) is not None: + break class PlatformType(Enum): - ANDROID = "Android" - LINUX = "Linux" + ANDROID = "Android" + LINUX = "Linux" @dataclasses.dataclass(frozen=True) class DeviceInfo: - """An object describing a device. - - It includes the following characteristics: - - platform_type: the OS platform, e.g., 'Android' - - model: the product model, e.g., 'Pixel-4' - - cpu_abi: the CPU ABI, e.g., 'arm64-v8a', 'x86_64' - - cpu_uarch: the CPU microarchitecture, e.g., 'CascadeLake' - - cpu_features: the detailed CPU features, e.g., ['fphp', 'sve'] - - gpu_name: the GPU name, e.g., 'Mali-G77' - """ - - platform_type: PlatformType - model: str - cpu_abi: str - cpu_uarch: Optional[str] - cpu_features: Sequence[str] - gpu_name: str - - def __str__(self): - features = ", ".join(self.cpu_features) - params = [ - f"model='{self.model}'", - f"cpu_abi='{self.cpu_abi}'", - f"cpu_uarch='{self.cpu_uarch}'", - f"gpu_name='{self.gpu_name}'", - f"cpu_features=[{features}]", - ] - params = ", ".join(params) - return f"{self.platform_type.value} device <{params}>" - - def get_cpu_arch(self) -> Optional[common_definitions.DeviceArchitecture]: - name = self.cpu_abi.lower() - if self.cpu_uarch: - name += f"-{self.cpu_uarch.lower()}" - - return CPU_ABI_TO_TARGET_ARCH_MAP.get(name) - - def get_gpu_arch(self) -> Optional[common_definitions.DeviceArchitecture]: - name = self.gpu_name.lower() - return GPU_NAME_TO_TARGET_ARCH_MAP.get(name) - - def get_detailed_cpu_arch_name(self) -> str: - """Returns the detailed architecture name.""" - - if self.cpu_abi == "arm64-v8a": - return self.__get_arm_cpu_arch_revision() - if self.cpu_abi == "x86_64": - return self.__get_x86_detailed_cpu_arch_name() - raise ValueError("Unrecognized CPU ABI; need to update the list") - - def to_json_object(self) -> Dict[str, Any]: - return { - "platform_type": self.platform_type.value, - "model": self.model, - "cpu_abi": self.cpu_abi, - "cpu_uarch": self.cpu_uarch if self.cpu_uarch else "", - "cpu_features": self.cpu_features, - "gpu_name": self.gpu_name, - } - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - cpu_uarch = json_object.get("cpu_uarch") - return DeviceInfo(PlatformType(json_object["platform_type"]), - json_object["model"], json_object["cpu_abi"], - None if cpu_uarch == "" else cpu_uarch, - json_object["cpu_features"], json_object["gpu_name"]) - - def __get_x86_detailed_cpu_arch_name(self) -> str: - """Returns the x86 architecture with microarchitecture name.""" - - if not self.cpu_uarch: - return self.cpu_abi - - return f"{self.cpu_abi}-{self.cpu_uarch}" - - def __get_arm_cpu_arch_revision(self) -> str: - """Returns the ARM architecture revision.""" - - # CPU features for ARMv8 revisions. - # From https://en.wikichip.org/wiki/arm/armv8#ARMv8_Extensions_and_Processor_Features - rev1_features = ["atomics", "asimdrdm"] - rev2_features = [ - "fphp", "dcpop", "sha3", "sm3", "sm4", "asimddp", "sha512", "sve" - ] + """An object describing a device. + + It includes the following characteristics: + - platform_type: the OS platform, e.g., 'Android' + - model: the product model, e.g., 'Pixel-4' + - cpu_abi: the CPU ABI, e.g., 'arm64-v8a', 'x86_64' + - cpu_uarch: the CPU microarchitecture, e.g., 'CascadeLake' + - cpu_features: the detailed CPU features, e.g., ['fphp', 'sve'] + - gpu_name: the GPU name, e.g., 'Mali-G77' + """ - rev = "ARMv8-A" - if any([f in self.cpu_features for f in rev1_features]): - rev = "ARMv8.1-A" - if any([f in self.cpu_features for f in rev2_features]): - rev = "ARMv8.2-A" - return rev + platform_type: PlatformType + model: str + cpu_abi: str + cpu_uarch: Optional[str] + cpu_features: Sequence[str] + gpu_name: str + + def __str__(self): + features = ", ".join(self.cpu_features) + params = [ + f"model='{self.model}'", + f"cpu_abi='{self.cpu_abi}'", + f"cpu_uarch='{self.cpu_uarch}'", + f"gpu_name='{self.gpu_name}'", + f"cpu_features=[{features}]", + ] + params = ", ".join(params) + return f"{self.platform_type.value} device <{params}>" + + def get_cpu_arch(self) -> Optional[common_definitions.DeviceArchitecture]: + name = self.cpu_abi.lower() + if self.cpu_uarch: + name += f"-{self.cpu_uarch.lower()}" + + return CPU_ABI_TO_TARGET_ARCH_MAP.get(name) + + def get_gpu_arch(self) -> Optional[common_definitions.DeviceArchitecture]: + name = self.gpu_name.lower() + return GPU_NAME_TO_TARGET_ARCH_MAP.get(name) + + def get_detailed_cpu_arch_name(self) -> str: + """Returns the detailed architecture name.""" + + if self.cpu_abi == "arm64-v8a": + return self.__get_arm_cpu_arch_revision() + if self.cpu_abi == "x86_64": + return self.__get_x86_detailed_cpu_arch_name() + raise ValueError("Unrecognized CPU ABI; need to update the list") + + def to_json_object(self) -> Dict[str, Any]: + return { + "platform_type": self.platform_type.value, + "model": self.model, + "cpu_abi": self.cpu_abi, + "cpu_uarch": self.cpu_uarch if self.cpu_uarch else "", + "cpu_features": self.cpu_features, + "gpu_name": self.gpu_name, + } + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + cpu_uarch = json_object.get("cpu_uarch") + return DeviceInfo( + PlatformType(json_object["platform_type"]), + json_object["model"], + json_object["cpu_abi"], + None if cpu_uarch == "" else cpu_uarch, + json_object["cpu_features"], + json_object["gpu_name"], + ) + + def __get_x86_detailed_cpu_arch_name(self) -> str: + """Returns the x86 architecture with microarchitecture name.""" + + if not self.cpu_uarch: + return self.cpu_abi + + return f"{self.cpu_abi}-{self.cpu_uarch}" + + def __get_arm_cpu_arch_revision(self) -> str: + """Returns the ARM architecture revision.""" + + # CPU features for ARMv8 revisions. + # From https://en.wikichip.org/wiki/arm/armv8#ARMv8_Extensions_and_Processor_Features + rev1_features = ["atomics", "asimdrdm"] + rev2_features = [ + "fphp", + "dcpop", + "sha3", + "sm3", + "sm4", + "asimddp", + "sha512", + "sve", + ] + + rev = "ARMv8-A" + if any([f in self.cpu_features for f in rev1_features]): + rev = "ARMv8.1-A" + if any([f in self.cpu_features for f in rev2_features]): + rev = "ARMv8.2-A" + return rev @dataclasses.dataclass(frozen=True) class BenchmarkInfo: - """An object describing the current benchmark. - - It includes the following benchmark characteristics: - - name: the benchmark name - - model_name: the model name, e.g., 'MobileNetV2' - - model_tags: a list of tags used to describe additional model information, - e.g., ['imagenet'] - - model_source: the source of the model, e.g., 'TensorFlow' - - bench_mode: a list of tags for benchmark mode, - e.g., ['1-thread', 'big-core', 'full-inference'] - - device_info: an DriverInfo object describing the IREE runtime dirver. - - device_info: an DeviceInfo object describing the device where benchmarks run - - compile_tags: an optional list of tags to describe the compile configs, - e.g., ['fuse-padding'] - - runner: which runner is used for benchmarking, e.g., 'iree_vulkan', 'tflite' - - run_config_id: ID of the corresponding iree_definitions.E2EModelRunConfig. - """ - - name: str - model_name: str - model_tags: Sequence[str] - model_source: str - bench_mode: Sequence[str] - driver_info: DriverInfo - device_info: DeviceInfo - compile_tags: Optional[Sequence[str]] = None - run_config_id: Optional[str] = None - - def __str__(self): - return self.name - - def to_json_object(self) -> Dict[str, Any]: - return { - "name": self.name, - "model_name": self.model_name, - "model_tags": self.model_tags, - "model_source": self.model_source, - "bench_mode": self.bench_mode, - "compile_tags": self.compile_tags, - # Get the "iree-*" driver name from the DriverInfo. - "runner": IREE_PRETTY_NAME_TO_DRIVER_NAME[self.driver_info.pretty_name], - "device_info": self.device_info.to_json_object(), - "run_config_id": self.run_config_id - } - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - driver_info = IREE_DRIVERS_INFOS.get(json_object["runner"]) - if not driver_info: - raise ValueError(f"Unrecognized runner: {json_object['runner']}") - - return BenchmarkInfo(name=json_object["name"], - model_name=json_object["model_name"], - model_tags=json_object["model_tags"], - model_source=json_object["model_source"], - bench_mode=json_object["bench_mode"], - compile_tags=json_object.get("compile_tags"), - driver_info=driver_info, - device_info=DeviceInfo.from_json_object( - json_object["device_info"]), - run_config_id=json_object.get("run_config_id")) + """An object describing the current benchmark. + + It includes the following benchmark characteristics: + - name: the benchmark name + - model_name: the model name, e.g., 'MobileNetV2' + - model_tags: a list of tags used to describe additional model information, + e.g., ['imagenet'] + - model_source: the source of the model, e.g., 'TensorFlow' + - bench_mode: a list of tags for benchmark mode, + e.g., ['1-thread', 'big-core', 'full-inference'] + - device_info: an DriverInfo object describing the IREE runtime dirver. + - device_info: an DeviceInfo object describing the device where benchmarks run + - compile_tags: an optional list of tags to describe the compile configs, + e.g., ['fuse-padding'] + - runner: which runner is used for benchmarking, e.g., 'iree_vulkan', 'tflite' + - run_config_id: ID of the corresponding iree_definitions.E2EModelRunConfig. + """ + + name: str + model_name: str + model_tags: Sequence[str] + model_source: str + bench_mode: Sequence[str] + driver_info: DriverInfo + device_info: DeviceInfo + compile_tags: Optional[Sequence[str]] = None + run_config_id: Optional[str] = None + + def __str__(self): + return self.name + + def to_json_object(self) -> Dict[str, Any]: + return { + "name": self.name, + "model_name": self.model_name, + "model_tags": self.model_tags, + "model_source": self.model_source, + "bench_mode": self.bench_mode, + "compile_tags": self.compile_tags, + # Get the "iree-*" driver name from the DriverInfo. + "runner": IREE_PRETTY_NAME_TO_DRIVER_NAME[self.driver_info.pretty_name], + "device_info": self.device_info.to_json_object(), + "run_config_id": self.run_config_id, + } + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + driver_info = IREE_DRIVERS_INFOS.get(json_object["runner"]) + if not driver_info: + raise ValueError(f"Unrecognized runner: {json_object['runner']}") + + return BenchmarkInfo( + name=json_object["name"], + model_name=json_object["model_name"], + model_tags=json_object["model_tags"], + model_source=json_object["model_source"], + bench_mode=json_object["bench_mode"], + compile_tags=json_object.get("compile_tags"), + driver_info=driver_info, + device_info=DeviceInfo.from_json_object(json_object["device_info"]), + run_config_id=json_object.get("run_config_id"), + ) @dataclasses.dataclass(frozen=True) class BenchmarkLatency: - """Stores latency statistics for a benchmark run.""" - mean: int - median: int - stddev: int - unit: str + """Stores latency statistics for a benchmark run.""" + + mean: int + median: int + stddev: int + unit: str - def to_json_object(self) -> Dict[str, Any]: - return dataclasses.asdict(self) + def to_json_object(self) -> Dict[str, Any]: + return dataclasses.asdict(self) - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return BenchmarkLatency(**json_object) + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return BenchmarkLatency(**json_object) def _get_google_benchmark_latencies( - benchmark_json: Dict[str, - Any]) -> Tuple[BenchmarkLatency, BenchmarkLatency]: - """Returns the Google Benchmark aggregate latencies. + benchmark_json: Dict[str, Any] +) -> Tuple[BenchmarkLatency, BenchmarkLatency]: + """Returns the Google Benchmark aggregate latencies. Args: benchmark_json: The JSON string or object returned by Google Benchmark. @@ -389,267 +396,276 @@ def _get_google_benchmark_latencies( Returns: Real time and CPU time BenchmarkLatency. """ - real_time_object: Dict[str, Any] = dict(unit="ns") - cpu_time_object: Dict[str, Any] = dict(unit="ns") - metrics = ["mean", "median", "stddev"] - for case in benchmark_json["benchmarks"]: - if any(case["name"].endswith(f"real_time_{m}") for m in metrics): - if case["time_unit"] != "ns": - raise ValueError(f"Expected ns as time unit") - metric = case["name"].split("_")[-1] - real_time_object[metric] = int(round(case["real_time"])) - cpu_time_object[metric] = int(round(case["cpu_time"])) - - # from_json_object implicitly validates that all metrics were found. - real_time = BenchmarkLatency.from_json_object(real_time_object) - cpu_time = BenchmarkLatency.from_json_object(cpu_time_object) - return real_time, cpu_time + real_time_object: Dict[str, Any] = dict(unit="ns") + cpu_time_object: Dict[str, Any] = dict(unit="ns") + metrics = ["mean", "median", "stddev"] + for case in benchmark_json["benchmarks"]: + if any(case["name"].endswith(f"real_time_{m}") for m in metrics): + if case["time_unit"] != "ns": + raise ValueError(f"Expected ns as time unit") + metric = case["name"].split("_")[-1] + real_time_object[metric] = int(round(case["real_time"])) + cpu_time_object[metric] = int(round(case["cpu_time"])) + + # from_json_object implicitly validates that all metrics were found. + real_time = BenchmarkLatency.from_json_object(real_time_object) + cpu_time = BenchmarkLatency.from_json_object(cpu_time_object) + return real_time, cpu_time @dataclasses.dataclass(frozen=True) class BenchmarkMemory: - """Stores memory statistics for a benchmark run.""" - peak: int - allocated: int - freed: int - live: int - unit: str - - def to_json_object(self) -> Dict[str, int]: - return dataclasses.asdict(self) - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return BenchmarkMemory(**json_object) - - -def _get_iree_memory_statistics(benchmark_stderr: str, - device: str) -> BenchmarkMemory: - """Extracts IREE's memory statistics for a given device.""" - # The memory statistics for each device are listed on their own line. - pattern = (rf"{device}:" - r"\s*(?P\d+)B peak /" - r"\s*(?P\d+)B allocated /" - r"\s*(?P\d+)B freed /" - r"\s*(?P\d+)B live") - match = re.search(pattern, benchmark_stderr) - if match is None: - raise ValueError( - f"Unable to find memory statistics in '{benchmark_stderr}'") - return BenchmarkMemory( - peak=int(match["peak"]), - allocated=int(match["allocated"]), - freed=int(match["freed"]), - live=int(match["live"]), - unit="bytes", - ) + """Stores memory statistics for a benchmark run.""" + + peak: int + allocated: int + freed: int + live: int + unit: str + + def to_json_object(self) -> Dict[str, int]: + return dataclasses.asdict(self) + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return BenchmarkMemory(**json_object) + + +def _get_iree_memory_statistics(benchmark_stderr: str, device: str) -> BenchmarkMemory: + """Extracts IREE's memory statistics for a given device.""" + # The memory statistics for each device are listed on their own line. + pattern = ( + rf"{device}:" + r"\s*(?P\d+)B peak /" + r"\s*(?P\d+)B allocated /" + r"\s*(?P\d+)B freed /" + r"\s*(?P\d+)B live" + ) + match = re.search(pattern, benchmark_stderr) + if match is None: + raise ValueError(f"Unable to find memory statistics in '{benchmark_stderr}'") + return BenchmarkMemory( + peak=int(match["peak"]), + allocated=int(match["allocated"]), + freed=int(match["freed"]), + live=int(match["live"]), + unit="bytes", + ) @dataclasses.dataclass(frozen=True) class BenchmarkMetrics(object): - """An object describing the results from a single benchmark. - - - real_time: the real time latency statistics returned by the benchmarking - framework. - - cpu_time: the cpu time latency statistics returned by the benchmarking - framework. - - host_memory: the host memory statistics returned by the benchmarking - framework. - - device_memory: the device memory statistics returned by the benchmarking - framework. - - raw_data: additional JSON-compatible raw results returned by the - benchmarking framework. - """ - real_time: BenchmarkLatency - cpu_time: BenchmarkLatency - host_memory: BenchmarkMemory - device_memory: BenchmarkMemory - raw_data: Dict[str, Any] - - def to_json_object(self) -> Dict[str, Any]: - return { - "real_time": self.real_time.to_json_object(), - "cpu_time": self.cpu_time.to_json_object(), - "host_memory": self.host_memory.to_json_object(), - "device_memory": self.device_memory.to_json_object(), - "raw_data": self.raw_data, - } - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): + """An object describing the results from a single benchmark. + + - real_time: the real time latency statistics returned by the benchmarking + framework. + - cpu_time: the cpu time latency statistics returned by the benchmarking + framework. + - host_memory: the host memory statistics returned by the benchmarking + framework. + - device_memory: the device memory statistics returned by the benchmarking + framework. + - raw_data: additional JSON-compatible raw results returned by the + benchmarking framework. + """ + + real_time: BenchmarkLatency + cpu_time: BenchmarkLatency + host_memory: BenchmarkMemory + device_memory: BenchmarkMemory + raw_data: Dict[str, Any] + + def to_json_object(self) -> Dict[str, Any]: + return { + "real_time": self.real_time.to_json_object(), + "cpu_time": self.cpu_time.to_json_object(), + "host_memory": self.host_memory.to_json_object(), + "device_memory": self.device_memory.to_json_object(), + "raw_data": self.raw_data, + } + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return BenchmarkMetrics( + real_time=BenchmarkLatency.from_json_object(json_object["real_time"]), + cpu_time=BenchmarkLatency.from_json_object(json_object["cpu_time"]), + host_memory=BenchmarkMemory.from_json_object(json_object["host_memory"]), + device_memory=BenchmarkMemory.from_json_object( + json_object["device_memory"] + ), + raw_data=json_object["raw_data"], + ) + + +def parse_iree_benchmark_metrics( + benchmark_stdout: str, benchmark_stderr: str +) -> BenchmarkMetrics: + """Extract benchmark metrics from the output of iree-benchmark-module. + + Args: + benchmark_stdout: The stdout of iree-benchmark-module with + --benchmark_format=json. + benchmark_stdout: The stderr of iree-benchmark-module with + --print_statistics=true. + + Returns: + A populated BenchmarkMetrics dataclass. + """ + benchmark_json = json.loads(benchmark_stdout) + real_time, cpu_time = _get_google_benchmark_latencies(benchmark_json) return BenchmarkMetrics( - real_time=BenchmarkLatency.from_json_object(json_object["real_time"]), - cpu_time=BenchmarkLatency.from_json_object(json_object["cpu_time"]), - host_memory=BenchmarkMemory.from_json_object( - json_object["host_memory"]), - device_memory=BenchmarkMemory.from_json_object( - json_object["device_memory"]), - raw_data=json_object["raw_data"], + real_time=real_time, + cpu_time=cpu_time, + host_memory=_get_iree_memory_statistics(benchmark_stderr, "HOST_LOCAL"), + device_memory=_get_iree_memory_statistics(benchmark_stderr, "DEVICE_LOCAL"), + raw_data=benchmark_json, ) -def parse_iree_benchmark_metrics(benchmark_stdout: str, - benchmark_stderr: str) -> BenchmarkMetrics: - """Extract benchmark metrics from the output of iree-benchmark-module. +@dataclasses.dataclass(frozen=True) +class BenchmarkRun(object): + """An object describing a single run of the benchmark binary. - Args: - benchmark_stdout: The stdout of iree-benchmark-module with - --benchmark_format=json. - benchmark_stdout: The stderr of iree-benchmark-module with - --print_statistics=true. + - info: a BenchmarkInfo object describing the benchmark setup. + - metrics: a BenchmarkMetrics object containing the results of the benchmark. + """ - Returns: - A populated BenchmarkMetrics dataclass. - """ - benchmark_json = json.loads(benchmark_stdout) - real_time, cpu_time = _get_google_benchmark_latencies(benchmark_json) - return BenchmarkMetrics( - real_time=real_time, - cpu_time=cpu_time, - host_memory=_get_iree_memory_statistics(benchmark_stderr, "HOST_LOCAL"), - device_memory=_get_iree_memory_statistics(benchmark_stderr, - "DEVICE_LOCAL"), - raw_data=benchmark_json, - ) + info: BenchmarkInfo + metrics: BenchmarkMetrics + def to_json_object(self) -> Dict[str, Any]: + return { + "info": self.info.to_json_object(), + "metrics": self.metrics.to_json_object(), + } -@dataclasses.dataclass(frozen=True) -class BenchmarkRun(object): - """An object describing a single run of the benchmark binary. - - - info: a BenchmarkInfo object describing the benchmark setup. - - metrics: a BenchmarkMetrics object containing the results of the benchmark. - """ - info: BenchmarkInfo - metrics: BenchmarkMetrics - - def to_json_object(self) -> Dict[str, Any]: - return { - "info": self.info.to_json_object(), - "metrics": self.metrics.to_json_object(), - } - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return BenchmarkRun( - BenchmarkInfo.from_json_object(json_object["info"]), - BenchmarkMetrics.from_json_object(json_object["metrics"]), - ) + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return BenchmarkRun( + BenchmarkInfo.from_json_object(json_object["info"]), + BenchmarkMetrics.from_json_object(json_object["metrics"]), + ) class BenchmarkResults(object): - """An object describing a set of benchmarks for one particular commit. + """An object describing a set of benchmarks for one particular commit. It contains the following fields: - commit: the commit SHA for this set of benchmarks. - benchmarks: a list of BenchmarkRun objects """ - def __init__(self): - self.commit: str = "" - self.benchmarks: List[BenchmarkRun] = [] - - def set_commit(self, commit: str): - self.commit = commit - - def merge(self, other): - if self.commit != other.commit: - raise ValueError("Inconsistent pull request commit") - self.benchmarks.extend(other.benchmarks) - - def to_json_str(self) -> str: - json_object = {"commit": self.commit, "benchmarks": []} - json_object["benchmarks"] = [b.to_json_object() for b in self.benchmarks] - return json.dumps(json_object, indent=2) - - @staticmethod - def from_json_str(json_str: str): - json_object = json.loads(json_str) - results = BenchmarkResults() - results.set_commit(json_object["commit"]) - results.benchmarks = [ - BenchmarkRun.from_json_object(b) for b in json_object["benchmarks"] - ] - return results + def __init__(self): + self.commit: str = "" + self.benchmarks: List[BenchmarkRun] = [] + + def set_commit(self, commit: str): + self.commit = commit + + def merge(self, other): + if self.commit != other.commit: + raise ValueError("Inconsistent pull request commit") + self.benchmarks.extend(other.benchmarks) + + def to_json_str(self) -> str: + json_object = {"commit": self.commit, "benchmarks": []} + json_object["benchmarks"] = [b.to_json_object() for b in self.benchmarks] + return json.dumps(json_object, indent=2) + + @staticmethod + def from_json_str(json_str: str): + json_object = json.loads(json_str) + results = BenchmarkResults() + results.set_commit(json_object["commit"]) + results.benchmarks = [ + BenchmarkRun.from_json_object(b) for b in json_object["benchmarks"] + ] + return results @dataclasses.dataclass(frozen=True) class CompilationInfo(object): - name: str - model_name: str - model_tags: Tuple[str] - model_source: str - target_arch: str - compile_tags: Tuple[str] - gen_config_id: Optional[str] = None - - def __str__(self): - return self.name - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return CompilationInfo(name=json_object["name"], - model_name=json_object["model_name"], - model_tags=tuple(json_object["model_tags"]), - model_source=json_object["model_source"], - target_arch=json_object["target_arch"], - compile_tags=tuple(json_object["compile_tags"]), - gen_config_id=json_object.get("gen_config_id")) + name: str + model_name: str + model_tags: Tuple[str] + model_source: str + target_arch: str + compile_tags: Tuple[str] + gen_config_id: Optional[str] = None + + def __str__(self): + return self.name + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return CompilationInfo( + name=json_object["name"], + model_name=json_object["model_name"], + model_tags=tuple(json_object["model_tags"]), + model_source=json_object["model_source"], + target_arch=json_object["target_arch"], + compile_tags=tuple(json_object["compile_tags"]), + gen_config_id=json_object.get("gen_config_id"), + ) @dataclasses.dataclass(frozen=True) class ModuleComponentSizes(object): - file_bytes: int - vm_component_bytes: int - const_component_bytes: int - total_dispatch_component_bytes: int + file_bytes: int + vm_component_bytes: int + const_component_bytes: int + total_dispatch_component_bytes: int - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return ModuleComponentSizes(**json_object) + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return ModuleComponentSizes(**json_object) @dataclasses.dataclass(frozen=True) class IRStatistics(object): - # Number of cmd.dispatch ops in IR. - stream_dispatch_count: int + # Number of cmd.dispatch ops in IR. + stream_dispatch_count: int - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return IRStatistics(**json_object) + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return IRStatistics(**json_object) @dataclasses.dataclass(frozen=True) class CompilationStatistics(object): - compilation_info: CompilationInfo - # Module file and component sizes. - module_component_sizes: ModuleComponentSizes - # Module compilation time in ms. - compilation_time_ms: int - # IR-level statistics - ir_stats: IRStatistics - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return CompilationStatistics( - compilation_info=CompilationInfo.from_json_object( - json_object["compilation_info"]), - module_component_sizes=ModuleComponentSizes.from_json_object( - json_object["module_component_sizes"]), - compilation_time_ms=json_object["compilation_time_ms"], - ir_stats=IRStatistics.from_json_object(json_object["ir_stats"])) + compilation_info: CompilationInfo + # Module file and component sizes. + module_component_sizes: ModuleComponentSizes + # Module compilation time in ms. + compilation_time_ms: int + # IR-level statistics + ir_stats: IRStatistics + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return CompilationStatistics( + compilation_info=CompilationInfo.from_json_object( + json_object["compilation_info"] + ), + module_component_sizes=ModuleComponentSizes.from_json_object( + json_object["module_component_sizes"] + ), + compilation_time_ms=json_object["compilation_time_ms"], + ir_stats=IRStatistics.from_json_object(json_object["ir_stats"]), + ) @dataclasses.dataclass(frozen=True) class CompilationResults(object): - commit: str - compilation_statistics: Sequence[CompilationStatistics] - - @staticmethod - def from_json_object(json_object: Dict[str, Any]): - return CompilationResults( - commit=json_object["commit"], - compilation_statistics=[ - CompilationStatistics.from_json_object(obj) - for obj in json_object["compilation_statistics"] - ]) + commit: str + compilation_statistics: Sequence[CompilationStatistics] + + @staticmethod + def from_json_object(json_object: Dict[str, Any]): + return CompilationResults( + commit=json_object["commit"], + compilation_statistics=[ + CompilationStatistics.from_json_object(obj) + for obj in json_object["compilation_statistics"] + ], + ) diff --git a/build_tools/benchmarks/common/benchmark_driver.py b/build_tools/benchmarks/common/benchmark_driver.py index efa168ca2dba..49b7fd679025 100644 --- a/build_tools/benchmarks/common/benchmark_driver.py +++ b/build_tools/benchmarks/common/benchmark_driver.py @@ -10,245 +10,279 @@ from typing import List, Optional, Sequence, Set, Tuple from common.benchmark_suite import BenchmarkCase, BenchmarkSuite from common.benchmark_config import BenchmarkConfig -from common.benchmark_definition import (BenchmarkInfo, BenchmarkResults, - BenchmarkMetrics, BenchmarkRun, - DeviceInfo) +from common.benchmark_definition import ( + BenchmarkInfo, + BenchmarkResults, + BenchmarkMetrics, + BenchmarkRun, + DeviceInfo, +) class BenchmarkDriver(object): - """Abstract driver runs the whole benchmark flow.""" - - def __init__(self, - device_info: DeviceInfo, - benchmark_config: BenchmarkConfig, - benchmark_suite: BenchmarkSuite, - benchmark_grace_time: float = 0.0, - verbose: bool = False): - self.device_info = device_info - self.config = benchmark_config - self.benchmark_suite = benchmark_suite - self.benchmark_grace_time = benchmark_grace_time - self.verbose = verbose - self.finished_benchmarks: List[Tuple[BenchmarkInfo, pathlib.Path]] = [] - self.finished_captures: List[pathlib.Path] = [] - self.benchmark_errors = [] - self._seen_benchmark_names: Set[str] = set() - - def run_benchmark_case(self, benchmark_case: BenchmarkCase, - benchmark_results_filename: Optional[pathlib.Path], - capture_filename: Optional[pathlib.Path]) -> None: - """Runs the benchmark case and serializes the results. - - Args: - benchmark_case: the benchmark_case. - benchmark_results_filename: the path to store the serialized - BenchmarkMetrics. Benchmarking is required if set. - capture_filename: the path to store captured trace. Trace capturing is - required if set. - - Raises: - Exception during benchmarking. - """ - raise NotImplementedError("Should be overwritten by a subclass.") - - def run(self) -> None: - """Execute the benchmark flow. - - It performs the following steps: - 1. Enumerate and filter benchmark cases. - 2. Call 'run_benchmark_case' for each benchmark case. - 3. Collect the benchmark results and captures. - """ - - self.config.benchmark_results_dir.mkdir(parents=True, exist_ok=True) - if self.config.trace_capture_config is not None: - self.config.trace_capture_config.capture_tmp_dir.mkdir(parents=True, - exist_ok=True) - - cpu_target_arch = self.device_info.get_cpu_arch() - gpu_target_arch = self.device_info.get_gpu_arch() - detected_architectures = [ - arch for arch in [cpu_target_arch, gpu_target_arch] if arch is not None - ] - if self.config.use_compatible_filter: - if cpu_target_arch is None: - print("INFO: Detected unsupported CPU architecture in" - f' "{self.device_info}", CPU benchmarking is disabled.') - if gpu_target_arch is None: - print("INFO: Detected unsupported GPU architecture in" - f' "{self.device_info}", GPU benchmarking is disabled.') - compatible_arch_filter = detected_architectures - else: - # No compatible filter on the target architectures. - compatible_arch_filter = None - - drivers, loaders = self.__get_available_drivers_and_loaders() - - benchmark_cases = self.benchmark_suite.filter_benchmarks( - available_drivers=drivers, - available_loaders=loaders, - target_architectures=compatible_arch_filter, - driver_filter=self.config.driver_filter, - mode_filter=self.config.mode_filter, - model_name_filter=self.config.model_name_filter) - - for benchmark_case in benchmark_cases: - benchmark_info = self.__get_benchmark_info_from_case( - benchmark_case=benchmark_case) - benchmark_name = str(benchmark_info) - - if benchmark_case.target_arch not in detected_architectures: - print(f"WARNING: Benchmark '{benchmark_name}' may be incompatible" - f" with the detected architectures '{detected_architectures}'" - f" on the device. Pass --compatible-only to skip incompatible" - f" benchmarks.") - - # Sanity check for the uniqueness of benchmark names. - if benchmark_name in self._seen_benchmark_names: - raise ValueError( - f"Found duplicate benchmark {benchmark_name} in the suites.") - self._seen_benchmark_names.add(benchmark_name) - - results_path, capture_path = self.__get_output_paths(benchmark_name) - # If we continue from the previous results, check and skip if the result - # files exist. - if self.config.continue_from_previous: - if results_path is not None and results_path.exists(): - self.finished_benchmarks.append((benchmark_info, results_path)) - results_path = None - - if capture_path is not None and capture_path.exists(): - self.finished_captures.append(capture_path) - capture_path = None - - # Skip if no need to benchmark and capture. - if results_path is None and capture_path is None: - continue - - print(f"--> Benchmark started: {benchmark_name} <--") - - try: - self.run_benchmark_case(benchmark_case, results_path, capture_path) - except Exception as e: - # Delete unfinished results if they exist. - if results_path is not None: - results_path.unlink(missing_ok=True) - if capture_path is not None: - capture_path.unlink(missing_ok=True) - - if not self.config.keep_going: - raise e - - print(f"Processing of benchmark failed with: {e}") - self.benchmark_errors.append(e) - continue - finally: - # Some grace time. - time.sleep(self.benchmark_grace_time) - - print("Benchmark completed") - - if results_path: - self.finished_benchmarks.append((benchmark_info, results_path)) - if capture_path: - self.finished_captures.append(capture_path) - - def get_benchmark_results(self) -> BenchmarkResults: - """Returns the finished benchmark results.""" - - results = BenchmarkResults() - results.set_commit(self.config.git_commit_hash) - - finished_benchmarks = sorted(self.finished_benchmarks, - key=lambda pair: str(pair[0])) - for info, path in finished_benchmarks: - benchmark_metrics_json_object = json.loads(path.read_text()) - benchmark_run = BenchmarkRun(info=info, - metrics=BenchmarkMetrics.from_json_object( - benchmark_metrics_json_object)) - results.benchmarks.append(benchmark_run) - - return results - - def get_benchmark_result_filenames(self) -> Sequence[pathlib.Path]: - """Returns the json file paths of finished benchmarks.""" - return [path for info, path in self.finished_benchmarks] - - def get_capture_filenames(self) -> Sequence[pathlib.Path]: - """Returns the tracy file paths of finished captures.""" - return self.finished_captures - - def get_benchmark_errors(self): - """Returns the exceptions captured during benchmarking.""" - return self.benchmark_errors - - def __get_output_paths(self, benchmark_name: str): - """Get output paths for the results and capture. The path of results/capture - is None if the benchmark/capture doesn't need to be run. - """ - - benchmark_results_filename = None - if self.config.normal_benchmark_tool_dir: - benchmark_results_filename = self.config.benchmark_results_dir / f"{benchmark_name}.json" - - capture_filename = None - if self.config.trace_capture_config: - capture_filename = self.config.trace_capture_config.capture_tmp_dir / f"{benchmark_name}.tracy" - - return (benchmark_results_filename, capture_filename) - - def __get_benchmark_info_from_case( - self, benchmark_case: BenchmarkCase) -> BenchmarkInfo: - run_config = benchmark_case.run_config - run_tags = run_config.module_execution_config.tags - gen_config = run_config.module_generation_config - model_source = str(gen_config.imported_model.model.source_type) - compile_tags = gen_config.compile_config.tags - return BenchmarkInfo(name=run_config.name, - model_name=benchmark_case.model_name, - model_tags=benchmark_case.model_tags, - model_source=model_source, - bench_mode=run_tags, - compile_tags=compile_tags, - driver_info=benchmark_case.driver_info, - device_info=self.device_info, - run_config_id=run_config.composite_id) - - def __get_available_drivers_and_loaders( - self) -> Tuple[Sequence[str], Sequence[str]]: - any_tool_dir = (self.config.normal_benchmark_tool_dir - if self.config.normal_benchmark_tool_dir else - self.config.trace_capture_config.traced_benchmark_tool_dir) - config_txt_file_path = any_tool_dir / "build_config.txt" - config_txt_file_lines = config_txt_file_path.read_text().splitlines() - - available_drivers = [] - available_loaders = [] - for line in config_txt_file_lines: - name, value = line.strip().split("=") - if value != "ON": - continue - if name == "IREE_HAL_DRIVER_CUDA": - available_drivers.append("cuda") - elif name == "IREE_HAL_DRIVER_LOCAL_SYNC": - available_drivers.append("local-sync") - elif name == "IREE_HAL_DRIVER_LOCAL_TASK": - available_drivers.append("local-task") - elif name == "IREE_HAL_DRIVER_VULKAN": - available_drivers.append("vulkan") - elif name == "IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF": - available_loaders.append("embedded-elf") - elif name == "IREE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY": - available_loaders.append("system-library") - elif name == "IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE": - available_loaders.append("vmvx-module") - else: - continue - - if self.verbose: - available_drivers_str = ', '.join(available_drivers) - print(f"Available drivers: {available_drivers_str}") - available_loaders_str = ', '.join(available_loaders) - print(f"Available loaders: {available_loaders_str}") - - return available_drivers, available_loaders + """Abstract driver runs the whole benchmark flow.""" + + def __init__( + self, + device_info: DeviceInfo, + benchmark_config: BenchmarkConfig, + benchmark_suite: BenchmarkSuite, + benchmark_grace_time: float = 0.0, + verbose: bool = False, + ): + self.device_info = device_info + self.config = benchmark_config + self.benchmark_suite = benchmark_suite + self.benchmark_grace_time = benchmark_grace_time + self.verbose = verbose + self.finished_benchmarks: List[Tuple[BenchmarkInfo, pathlib.Path]] = [] + self.finished_captures: List[pathlib.Path] = [] + self.benchmark_errors = [] + self._seen_benchmark_names: Set[str] = set() + + def run_benchmark_case( + self, + benchmark_case: BenchmarkCase, + benchmark_results_filename: Optional[pathlib.Path], + capture_filename: Optional[pathlib.Path], + ) -> None: + """Runs the benchmark case and serializes the results. + + Args: + benchmark_case: the benchmark_case. + benchmark_results_filename: the path to store the serialized + BenchmarkMetrics. Benchmarking is required if set. + capture_filename: the path to store captured trace. Trace capturing is + required if set. + + Raises: + Exception during benchmarking. + """ + raise NotImplementedError("Should be overwritten by a subclass.") + + def run(self) -> None: + """Execute the benchmark flow. + + It performs the following steps: + 1. Enumerate and filter benchmark cases. + 2. Call 'run_benchmark_case' for each benchmark case. + 3. Collect the benchmark results and captures. + """ + + self.config.benchmark_results_dir.mkdir(parents=True, exist_ok=True) + if self.config.trace_capture_config is not None: + self.config.trace_capture_config.capture_tmp_dir.mkdir( + parents=True, exist_ok=True + ) + + cpu_target_arch = self.device_info.get_cpu_arch() + gpu_target_arch = self.device_info.get_gpu_arch() + detected_architectures = [ + arch for arch in [cpu_target_arch, gpu_target_arch] if arch is not None + ] + if self.config.use_compatible_filter: + if cpu_target_arch is None: + print( + "INFO: Detected unsupported CPU architecture in" + f' "{self.device_info}", CPU benchmarking is disabled.' + ) + if gpu_target_arch is None: + print( + "INFO: Detected unsupported GPU architecture in" + f' "{self.device_info}", GPU benchmarking is disabled.' + ) + compatible_arch_filter = detected_architectures + else: + # No compatible filter on the target architectures. + compatible_arch_filter = None + + drivers, loaders = self.__get_available_drivers_and_loaders() + + benchmark_cases = self.benchmark_suite.filter_benchmarks( + available_drivers=drivers, + available_loaders=loaders, + target_architectures=compatible_arch_filter, + driver_filter=self.config.driver_filter, + mode_filter=self.config.mode_filter, + model_name_filter=self.config.model_name_filter, + ) + + for benchmark_case in benchmark_cases: + benchmark_info = self.__get_benchmark_info_from_case( + benchmark_case=benchmark_case + ) + benchmark_name = str(benchmark_info) + + if benchmark_case.target_arch not in detected_architectures: + print( + f"WARNING: Benchmark '{benchmark_name}' may be incompatible" + f" with the detected architectures '{detected_architectures}'" + f" on the device. Pass --compatible-only to skip incompatible" + f" benchmarks." + ) + + # Sanity check for the uniqueness of benchmark names. + if benchmark_name in self._seen_benchmark_names: + raise ValueError( + f"Found duplicate benchmark {benchmark_name} in the suites." + ) + self._seen_benchmark_names.add(benchmark_name) + + results_path, capture_path = self.__get_output_paths(benchmark_name) + # If we continue from the previous results, check and skip if the result + # files exist. + if self.config.continue_from_previous: + if results_path is not None and results_path.exists(): + self.finished_benchmarks.append((benchmark_info, results_path)) + results_path = None + + if capture_path is not None and capture_path.exists(): + self.finished_captures.append(capture_path) + capture_path = None + + # Skip if no need to benchmark and capture. + if results_path is None and capture_path is None: + continue + + print(f"--> Benchmark started: {benchmark_name} <--") + + try: + self.run_benchmark_case(benchmark_case, results_path, capture_path) + except Exception as e: + # Delete unfinished results if they exist. + if results_path is not None: + results_path.unlink(missing_ok=True) + if capture_path is not None: + capture_path.unlink(missing_ok=True) + + if not self.config.keep_going: + raise e + + print(f"Processing of benchmark failed with: {e}") + self.benchmark_errors.append(e) + continue + finally: + # Some grace time. + time.sleep(self.benchmark_grace_time) + + print("Benchmark completed") + + if results_path: + self.finished_benchmarks.append((benchmark_info, results_path)) + if capture_path: + self.finished_captures.append(capture_path) + + def get_benchmark_results(self) -> BenchmarkResults: + """Returns the finished benchmark results.""" + + results = BenchmarkResults() + results.set_commit(self.config.git_commit_hash) + + finished_benchmarks = sorted( + self.finished_benchmarks, key=lambda pair: str(pair[0]) + ) + for info, path in finished_benchmarks: + benchmark_metrics_json_object = json.loads(path.read_text()) + benchmark_run = BenchmarkRun( + info=info, + metrics=BenchmarkMetrics.from_json_object( + benchmark_metrics_json_object + ), + ) + results.benchmarks.append(benchmark_run) + + return results + + def get_benchmark_result_filenames(self) -> Sequence[pathlib.Path]: + """Returns the json file paths of finished benchmarks.""" + return [path for info, path in self.finished_benchmarks] + + def get_capture_filenames(self) -> Sequence[pathlib.Path]: + """Returns the tracy file paths of finished captures.""" + return self.finished_captures + + def get_benchmark_errors(self): + """Returns the exceptions captured during benchmarking.""" + return self.benchmark_errors + + def __get_output_paths(self, benchmark_name: str): + """Get output paths for the results and capture. The path of results/capture + is None if the benchmark/capture doesn't need to be run. + """ + + benchmark_results_filename = None + if self.config.normal_benchmark_tool_dir: + benchmark_results_filename = ( + self.config.benchmark_results_dir / f"{benchmark_name}.json" + ) + + capture_filename = None + if self.config.trace_capture_config: + capture_filename = ( + self.config.trace_capture_config.capture_tmp_dir + / f"{benchmark_name}.tracy" + ) + + return (benchmark_results_filename, capture_filename) + + def __get_benchmark_info_from_case( + self, benchmark_case: BenchmarkCase + ) -> BenchmarkInfo: + run_config = benchmark_case.run_config + run_tags = run_config.module_execution_config.tags + gen_config = run_config.module_generation_config + model_source = str(gen_config.imported_model.model.source_type) + compile_tags = gen_config.compile_config.tags + return BenchmarkInfo( + name=run_config.name, + model_name=benchmark_case.model_name, + model_tags=benchmark_case.model_tags, + model_source=model_source, + bench_mode=run_tags, + compile_tags=compile_tags, + driver_info=benchmark_case.driver_info, + device_info=self.device_info, + run_config_id=run_config.composite_id, + ) + + def __get_available_drivers_and_loaders( + self, + ) -> Tuple[Sequence[str], Sequence[str]]: + any_tool_dir = ( + self.config.normal_benchmark_tool_dir + if self.config.normal_benchmark_tool_dir + else self.config.trace_capture_config.traced_benchmark_tool_dir + ) + config_txt_file_path = any_tool_dir / "build_config.txt" + config_txt_file_lines = config_txt_file_path.read_text().splitlines() + + available_drivers = [] + available_loaders = [] + for line in config_txt_file_lines: + name, value = line.strip().split("=") + if value != "ON": + continue + if name == "IREE_HAL_DRIVER_CUDA": + available_drivers.append("cuda") + elif name == "IREE_HAL_DRIVER_LOCAL_SYNC": + available_drivers.append("local-sync") + elif name == "IREE_HAL_DRIVER_LOCAL_TASK": + available_drivers.append("local-task") + elif name == "IREE_HAL_DRIVER_VULKAN": + available_drivers.append("vulkan") + elif name == "IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF": + available_loaders.append("embedded-elf") + elif name == "IREE_HAL_EXECUTABLE_LOADER_SYSTEM_LIBRARY": + available_loaders.append("system-library") + elif name == "IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE": + available_loaders.append("vmvx-module") + else: + continue + + if self.verbose: + available_drivers_str = ", ".join(available_drivers) + print(f"Available drivers: {available_drivers_str}") + available_loaders_str = ", ".join(available_loaders) + print(f"Available loaders: {available_loaders_str}") + + return available_drivers, available_loaders diff --git a/build_tools/benchmarks/common/benchmark_driver_test.py b/build_tools/benchmarks/common/benchmark_driver_test.py index 106cb557831c..f2dd0761b581 100644 --- a/build_tools/benchmarks/common/benchmark_driver_test.py +++ b/build_tools/benchmarks/common/benchmark_driver_test.py @@ -14,254 +14,301 @@ from common import benchmark_config from common.benchmark_suite import BenchmarkCase, BenchmarkSuite from common.benchmark_driver import BenchmarkDriver -from common.benchmark_definition import (IREE_DRIVERS_INFOS, DeviceInfo, - PlatformType, BenchmarkLatency, - BenchmarkMemory, BenchmarkMetrics) +from common.benchmark_definition import ( + IREE_DRIVERS_INFOS, + DeviceInfo, + PlatformType, + BenchmarkLatency, + BenchmarkMemory, + BenchmarkMetrics, +) from e2e_test_framework.definitions import common_definitions, iree_definitions class FakeBenchmarkDriver(BenchmarkDriver): - - def __init__(self, - *args, - raise_exception_on_case: Optional[BenchmarkCase] = None, - **kwargs): - super().__init__(*args, **kwargs) - self.raise_exception_on_case = raise_exception_on_case - self.run_benchmark_cases = [] - - def run_benchmark_case(self, benchmark_case: BenchmarkCase, - benchmark_results_filename: Optional[pathlib.Path], - capture_filename: Optional[pathlib.Path]) -> None: - if self.raise_exception_on_case == benchmark_case: - raise Exception("fake exception") - - self.run_benchmark_cases.append(benchmark_case) - - if benchmark_results_filename: - fake_benchmark_metrics = BenchmarkMetrics( - real_time=BenchmarkLatency(0, 0, 0, "ns"), - cpu_time=BenchmarkLatency(0, 0, 0, "ns"), - host_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"), - device_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"), - raw_data={}, - ) - benchmark_results_filename.write_text( - json.dumps(fake_benchmark_metrics.to_json_object())) - if capture_filename: - capture_filename.write_text("{}") + def __init__( + self, *args, raise_exception_on_case: Optional[BenchmarkCase] = None, **kwargs + ): + super().__init__(*args, **kwargs) + self.raise_exception_on_case = raise_exception_on_case + self.run_benchmark_cases = [] + + def run_benchmark_case( + self, + benchmark_case: BenchmarkCase, + benchmark_results_filename: Optional[pathlib.Path], + capture_filename: Optional[pathlib.Path], + ) -> None: + if self.raise_exception_on_case == benchmark_case: + raise Exception("fake exception") + + self.run_benchmark_cases.append(benchmark_case) + + if benchmark_results_filename: + fake_benchmark_metrics = BenchmarkMetrics( + real_time=BenchmarkLatency(0, 0, 0, "ns"), + cpu_time=BenchmarkLatency(0, 0, 0, "ns"), + host_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"), + device_memory=BenchmarkMemory(0, 0, 0, 0, "bytes"), + raw_data={}, + ) + benchmark_results_filename.write_text( + json.dumps(fake_benchmark_metrics.to_json_object()) + ) + if capture_filename: + capture_filename.write_text("{}") class BenchmarkDriverTest(unittest.TestCase): - - def setUp(self): - self._tmp_dir_obj = tempfile.TemporaryDirectory() - self._root_dir_obj = tempfile.TemporaryDirectory() - - self.tmp_dir = pathlib.Path(self._tmp_dir_obj.name) - (self.tmp_dir / "build_config.txt").write_text( - "IREE_HAL_DRIVER_LOCAL_SYNC=ON\n" - "IREE_HAL_DRIVER_LOCAL_TASK=ON\n" - "IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF=ON\n") - - self.benchmark_results_dir = (self.tmp_dir / - benchmark_config.BENCHMARK_RESULTS_REL_PATH) - self.captures_dir = (self.tmp_dir / benchmark_config.CAPTURES_REL_PATH) - self.benchmark_results_dir.mkdir() - self.captures_dir.mkdir() - - self.config = benchmark_config.BenchmarkConfig( - root_benchmark_dir=pathlib.Path(self._root_dir_obj.name), - benchmark_results_dir=self.benchmark_results_dir, - git_commit_hash="abcd", - normal_benchmark_tool_dir=self.tmp_dir, - trace_capture_config=benchmark_config.TraceCaptureConfig( - traced_benchmark_tool_dir=self.tmp_dir, - trace_capture_tool=self.tmp_dir / "capture_tool", - capture_tarball=self.tmp_dir / "captures.tar", - capture_tmp_dir=self.captures_dir), - use_compatible_filter=True) - - self.device_info = DeviceInfo(platform_type=PlatformType.LINUX, - model="Unknown", - cpu_abi="x86_64", - cpu_uarch="CascadeLake", - cpu_features=[], - gpu_name="unknown") - - model_tflite = common_definitions.Model( - id="tflite", - name="model_tflite", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - device_spec = common_definitions.DeviceSpec.build( - id="dev", - device_name="test_dev", - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[], - tags=[]) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=( - common_definitions.DeviceArchitecture.X86_64_CASCADELAKE), - target_abi=iree_definitions.TargetABI.LINUX_GNU) - gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(model_tflite), - compile_config=iree_definitions.CompileConfig.build( - id="comp_a", tags=[], compile_targets=[compile_target])) - exec_config_a = iree_definitions.ModuleExecutionConfig.build( - id="exec_a", - tags=["sync"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_SYNC) - run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=gen_config, - module_execution_config=exec_config_a, - target_device_spec=device_spec, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - exec_config_b = iree_definitions.ModuleExecutionConfig.build( - id="exec_b", - tags=["task"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK) - run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=gen_config, - module_execution_config=exec_config_b, - target_device_spec=device_spec, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - self.case1 = BenchmarkCase( - model_name="model_tflite", - model_tags=[], - bench_mode=["sync"], - target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], - benchmark_case_dir=pathlib.Path("case1"), - benchmark_tool_name="tool", - run_config=run_config_a) - self.case2 = BenchmarkCase( - model_name="model_tflite", - model_tags=[], - bench_mode=["task"], - target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], - benchmark_case_dir=pathlib.Path("case2"), - benchmark_tool_name="tool", - run_config=run_config_b) - - compile_target_rv64 = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - gen_config_rv64 = iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(model_tflite), - compile_config=iree_definitions.CompileConfig.build( - id="comp_rv64", tags=[], compile_targets=[compile_target_rv64])) - device_spec_rv64 = common_definitions.DeviceSpec.build( - id="rv64_dev", - device_name="rv64_dev", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[], - tags=[]) - run_config_incompatible = iree_definitions.E2EModelRunConfig.build( - module_generation_config=gen_config_rv64, - module_execution_config=exec_config_b, - target_device_spec=device_spec_rv64, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - self.incompatible_case = BenchmarkCase( - model_name="model_tflite", - model_tags=[], - bench_mode=["task"], - target_arch=common_definitions.DeviceArchitecture.RV64_GENERIC, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], - benchmark_case_dir=pathlib.Path("incompatible_case"), - benchmark_tool_name="tool", - run_config=run_config_incompatible) - self.benchmark_suite = BenchmarkSuite([ - self.case1, - self.case2, - self.incompatible_case, - ]) - - def tearDown(self) -> None: - self._tmp_dir_obj.cleanup() - self._root_dir_obj.cleanup() - - def test_run(self): - driver = FakeBenchmarkDriver(self.device_info, self.config, - self.benchmark_suite) - - driver.run() - - self.assertEqual(driver.get_benchmark_results().commit, "abcd") - self.assertEqual(len(driver.get_benchmark_results().benchmarks), 2) - self.assertEqual( - driver.get_benchmark_results().benchmarks[0].metrics.raw_data, {}) - self.assertEqual(driver.get_benchmark_result_filenames(), [ - self.benchmark_results_dir / f"{self.case1.run_config}.json", - self.benchmark_results_dir / f"{self.case2.run_config}.json" - ]) - self.assertEqual(driver.get_capture_filenames(), [ - self.captures_dir / f"{self.case1.run_config}.tracy", - self.captures_dir / f"{self.case2.run_config}.tracy" - ]) - self.assertEqual(driver.get_benchmark_errors(), []) - - def test_run_disable_compatible_filter(self): - self.config.use_compatible_filter = False - driver = FakeBenchmarkDriver(self.device_info, self.config, - self.benchmark_suite) - - driver.run() - - self.assertEqual(len(driver.get_benchmark_results().benchmarks), 3) - - def test_run_with_no_capture(self): - self.config.trace_capture_config = None - driver = FakeBenchmarkDriver(self.device_info, self.config, - self.benchmark_suite) - - driver.run() - - self.assertEqual(len(driver.get_benchmark_result_filenames()), 2) - self.assertEqual(driver.get_capture_filenames(), []) - - def test_run_with_exception_and_keep_going(self): - self.config.keep_going = True - driver = FakeBenchmarkDriver(self.device_info, - self.config, - self.benchmark_suite, - raise_exception_on_case=self.case1) - - driver.run() - - self.assertEqual(len(driver.get_benchmark_errors()), 1) - self.assertEqual(len(driver.get_benchmark_result_filenames()), 1) - - def test_run_with_previous_benchmarks_and_captures(self): - benchmark_filename = (self.benchmark_results_dir / - f"{self.case1.run_config}.json") - benchmark_filename.touch() - capture_filename = self.captures_dir / f"{self.case1.run_config}.tracy" - capture_filename.touch() - config = dataclasses.replace(self.config, continue_from_previous=True) - driver = FakeBenchmarkDriver(device_info=self.device_info, - benchmark_config=config, - benchmark_suite=self.benchmark_suite) - - driver.run() - - self.assertEqual(len(driver.run_benchmark_cases), 1) - self.assertEqual(len(driver.get_benchmark_result_filenames()), 2) - self.assertEqual(len(driver.get_capture_filenames()), 2) + def setUp(self): + self._tmp_dir_obj = tempfile.TemporaryDirectory() + self._root_dir_obj = tempfile.TemporaryDirectory() + + self.tmp_dir = pathlib.Path(self._tmp_dir_obj.name) + (self.tmp_dir / "build_config.txt").write_text( + "IREE_HAL_DRIVER_LOCAL_SYNC=ON\n" + "IREE_HAL_DRIVER_LOCAL_TASK=ON\n" + "IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF=ON\n" + ) + + self.benchmark_results_dir = ( + self.tmp_dir / benchmark_config.BENCHMARK_RESULTS_REL_PATH + ) + self.captures_dir = self.tmp_dir / benchmark_config.CAPTURES_REL_PATH + self.benchmark_results_dir.mkdir() + self.captures_dir.mkdir() + + self.config = benchmark_config.BenchmarkConfig( + root_benchmark_dir=pathlib.Path(self._root_dir_obj.name), + benchmark_results_dir=self.benchmark_results_dir, + git_commit_hash="abcd", + normal_benchmark_tool_dir=self.tmp_dir, + trace_capture_config=benchmark_config.TraceCaptureConfig( + traced_benchmark_tool_dir=self.tmp_dir, + trace_capture_tool=self.tmp_dir / "capture_tool", + capture_tarball=self.tmp_dir / "captures.tar", + capture_tmp_dir=self.captures_dir, + ), + use_compatible_filter=True, + ) + + self.device_info = DeviceInfo( + platform_type=PlatformType.LINUX, + model="Unknown", + cpu_abi="x86_64", + cpu_uarch="CascadeLake", + cpu_features=[], + gpu_name="unknown", + ) + + model_tflite = common_definitions.Model( + id="tflite", + name="model_tflite", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + device_spec = common_definitions.DeviceSpec.build( + id="dev", + device_name="test_dev", + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[], + tags=[], + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=( + common_definitions.DeviceArchitecture.X86_64_CASCADELAKE + ), + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model_tflite), + compile_config=iree_definitions.CompileConfig.build( + id="comp_a", tags=[], compile_targets=[compile_target] + ), + ) + exec_config_a = iree_definitions.ModuleExecutionConfig.build( + id="exec_a", + tags=["sync"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, + ) + run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=gen_config, + module_execution_config=exec_config_a, + target_device_spec=device_spec, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + exec_config_b = iree_definitions.ModuleExecutionConfig.build( + id="exec_b", + tags=["task"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + ) + run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=gen_config, + module_execution_config=exec_config_b, + target_device_spec=device_spec, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + self.case1 = BenchmarkCase( + model_name="model_tflite", + model_tags=[], + bench_mode=["sync"], + target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], + benchmark_case_dir=pathlib.Path("case1"), + benchmark_tool_name="tool", + run_config=run_config_a, + ) + self.case2 = BenchmarkCase( + model_name="model_tflite", + model_tags=[], + bench_mode=["task"], + target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], + benchmark_case_dir=pathlib.Path("case2"), + benchmark_tool_name="tool", + run_config=run_config_b, + ) + + compile_target_rv64 = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + gen_config_rv64 = iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model_tflite), + compile_config=iree_definitions.CompileConfig.build( + id="comp_rv64", tags=[], compile_targets=[compile_target_rv64] + ), + ) + device_spec_rv64 = common_definitions.DeviceSpec.build( + id="rv64_dev", + device_name="rv64_dev", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[], + tags=[], + ) + run_config_incompatible = iree_definitions.E2EModelRunConfig.build( + module_generation_config=gen_config_rv64, + module_execution_config=exec_config_b, + target_device_spec=device_spec_rv64, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + self.incompatible_case = BenchmarkCase( + model_name="model_tflite", + model_tags=[], + bench_mode=["task"], + target_arch=common_definitions.DeviceArchitecture.RV64_GENERIC, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], + benchmark_case_dir=pathlib.Path("incompatible_case"), + benchmark_tool_name="tool", + run_config=run_config_incompatible, + ) + self.benchmark_suite = BenchmarkSuite( + [ + self.case1, + self.case2, + self.incompatible_case, + ] + ) + + def tearDown(self) -> None: + self._tmp_dir_obj.cleanup() + self._root_dir_obj.cleanup() + + def test_run(self): + driver = FakeBenchmarkDriver( + self.device_info, self.config, self.benchmark_suite + ) + + driver.run() + + self.assertEqual(driver.get_benchmark_results().commit, "abcd") + self.assertEqual(len(driver.get_benchmark_results().benchmarks), 2) + self.assertEqual( + driver.get_benchmark_results().benchmarks[0].metrics.raw_data, {} + ) + self.assertEqual( + driver.get_benchmark_result_filenames(), + [ + self.benchmark_results_dir / f"{self.case1.run_config}.json", + self.benchmark_results_dir / f"{self.case2.run_config}.json", + ], + ) + self.assertEqual( + driver.get_capture_filenames(), + [ + self.captures_dir / f"{self.case1.run_config}.tracy", + self.captures_dir / f"{self.case2.run_config}.tracy", + ], + ) + self.assertEqual(driver.get_benchmark_errors(), []) + + def test_run_disable_compatible_filter(self): + self.config.use_compatible_filter = False + driver = FakeBenchmarkDriver( + self.device_info, self.config, self.benchmark_suite + ) + + driver.run() + + self.assertEqual(len(driver.get_benchmark_results().benchmarks), 3) + + def test_run_with_no_capture(self): + self.config.trace_capture_config = None + driver = FakeBenchmarkDriver( + self.device_info, self.config, self.benchmark_suite + ) + + driver.run() + + self.assertEqual(len(driver.get_benchmark_result_filenames()), 2) + self.assertEqual(driver.get_capture_filenames(), []) + + def test_run_with_exception_and_keep_going(self): + self.config.keep_going = True + driver = FakeBenchmarkDriver( + self.device_info, + self.config, + self.benchmark_suite, + raise_exception_on_case=self.case1, + ) + + driver.run() + + self.assertEqual(len(driver.get_benchmark_errors()), 1) + self.assertEqual(len(driver.get_benchmark_result_filenames()), 1) + + def test_run_with_previous_benchmarks_and_captures(self): + benchmark_filename = ( + self.benchmark_results_dir / f"{self.case1.run_config}.json" + ) + benchmark_filename.touch() + capture_filename = self.captures_dir / f"{self.case1.run_config}.tracy" + capture_filename.touch() + config = dataclasses.replace(self.config, continue_from_previous=True) + driver = FakeBenchmarkDriver( + device_info=self.device_info, + benchmark_config=config, + benchmark_suite=self.benchmark_suite, + ) + + driver.run() + + self.assertEqual(len(driver.run_benchmark_cases), 1) + self.assertEqual(len(driver.get_benchmark_result_filenames()), 2) + self.assertEqual(len(driver.get_capture_filenames()), 2) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/common/benchmark_presentation.py b/build_tools/benchmarks/common/benchmark_presentation.py index 9cab2a79f61d..6eafb2bcec18 100644 --- a/build_tools/benchmarks/common/benchmark_presentation.py +++ b/build_tools/benchmarks/common/benchmark_presentation.py @@ -6,8 +6,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import (Any, Callable, Dict, Generic, List, Optional, Sequence, - Tuple, TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import pathlib import dataclasses import json @@ -16,11 +26,14 @@ import math from common import benchmark_definition, benchmark_thresholds -from common.benchmark_thresholds import (BENCHMARK_THRESHOLDS, - COMPILATION_TIME_THRESHOLDS, - TOTAL_ARTIFACT_SIZE_THRESHOLDS, - TOTAL_DISPATCH_SIZE_THRESHOLDS, - BenchmarkThreshold, ThresholdUnit) +from common.benchmark_thresholds import ( + BENCHMARK_THRESHOLDS, + COMPILATION_TIME_THRESHOLDS, + TOTAL_ARTIFACT_SIZE_THRESHOLDS, + TOTAL_DISPATCH_SIZE_THRESHOLDS, + BenchmarkThreshold, + ThresholdUnit, +) GetMetricFunc = Callable[[Any], Tuple[int, Optional[int]]] @@ -36,7 +49,9 @@ COMPILATION_TIME_METRIC_ID = "e54cd682-c079-4c42-b4ad-d92c4bedea13" COMPILATION_TIME_SERIES_SUFFIX = "compilation:module:compilation-time" TOTAL_DISPATCH_SIZE_METRIC_ID = "9e15f7e6-383c-47ec-bd38-ecba55a5f10a" -TOTAL_DISPATCH_SIZE_SERIES_SUFFIX = "compilation:module:component-size:total-dispatch-size" +TOTAL_DISPATCH_SIZE_SERIES_SUFFIX = ( + "compilation:module:component-size:total-dispatch-size" +) TOTAL_ARTIFACT_SIZE_METRIC_ID = "2c8a9198-c01c-45b9-a7da-69c82cf749f7" TOTAL_ARTIFACT_SIZE_SERIES_SUFFIX = "compilation:module:total-artifact-size" STREAM_IR_DISPATCH_COUNT_METRIC_ID = "7b72cd9e-43ed-4078-b6d3-20b810f9e4ad" @@ -45,290 +60,312 @@ @dataclass class AggregateBenchmarkLatency: - """An object for describing aggregate latency numbers for a benchmark.""" - name: str - benchmark_info: benchmark_definition.BenchmarkInfo - mean_time: int - median_time: int - stddev_time: int - # The average latency time for the base commit to compare against. - base_mean_time: Optional[int] = None + """An object for describing aggregate latency numbers for a benchmark.""" - def __str__(self) -> str: - return self.name + name: str + benchmark_info: benchmark_definition.BenchmarkInfo + mean_time: int + median_time: int + stddev_time: int + # The average latency time for the base commit to compare against. + base_mean_time: Optional[int] = None + + def __str__(self) -> str: + return self.name @dataclass(frozen=True) class CompilationMetrics: - """An object for describing the summary of statistics and the reference.""" - name: str - compilation_info: benchmark_definition.CompilationInfo - compilation_time_ms: int - total_dispatch_component_bytes: int - total_artifact_bytes: int - stream_ir_dispatch_count: int - base_compilation_time_ms: Optional[int] = None - base_total_artifact_bytes: Optional[int] = None - base_total_dispatch_component_bytes: Optional[int] = None - base_stream_ir_dispatch_count: Optional[int] = None - - def __str__(self) -> str: - return self.name + """An object for describing the summary of statistics and the reference.""" + + name: str + compilation_info: benchmark_definition.CompilationInfo + compilation_time_ms: int + total_dispatch_component_bytes: int + total_artifact_bytes: int + stream_ir_dispatch_count: int + base_compilation_time_ms: Optional[int] = None + base_total_artifact_bytes: Optional[int] = None + base_total_dispatch_component_bytes: Optional[int] = None + base_stream_ir_dispatch_count: Optional[int] = None + + def __str__(self) -> str: + return self.name T = TypeVar("T") class MetricsToTableMapper(ABC, Generic[T]): - """Abstract class to help map benchmark metrics to table. + """Abstract class to help map benchmark metrics to table. It contains a set of methods to help table generator get the required information for a metric. For example, extract the current and base metric value, the metric thresholds, the table header of the metrics, ... - """ + """ - @abstractmethod - def update_base_value(self, obj: T, base_value: Any) -> T: - """Sets the base value and returns the updated metric object.""" - raise NotImplementedError() + @abstractmethod + def update_base_value(self, obj: T, base_value: Any) -> T: + """Sets the base value and returns the updated metric object.""" + raise NotImplementedError() - @abstractmethod - def get_current_and_base_value(self, obj: T) -> Tuple[int, Optional[int]]: - """Returns the current and base (can be None) value.""" - raise NotImplementedError() + @abstractmethod + def get_current_and_base_value(self, obj: T) -> Tuple[int, Optional[int]]: + """Returns the current and base (can be None) value.""" + raise NotImplementedError() - def get_series_id(self, benchmark_id: str) -> str: - """Returns the dashboard series id.""" - return f"{benchmark_id}-{self.get_metric_id()}" + def get_series_id(self, benchmark_id: str) -> str: + """Returns the dashboard series id.""" + return f"{benchmark_id}-{self.get_metric_id()}" - @abstractmethod - def get_metric_id(self) -> str: - """Returns the dashboard series id.""" - raise NotImplementedError() + @abstractmethod + def get_metric_id(self) -> str: + """Returns the dashboard series id.""" + raise NotImplementedError() - @abstractmethod - def get_series_name(self, name: str) -> str: - """Returns the dashboard series name.""" - raise NotImplementedError() + @abstractmethod + def get_series_name(self, name: str) -> str: + """Returns the dashboard series name.""" + raise NotImplementedError() - @abstractmethod - def get_unit(self) -> str: - """Returns the unit of the metric value.""" - raise NotImplementedError() + @abstractmethod + def get_unit(self) -> str: + """Returns the unit of the metric value.""" + raise NotImplementedError() - @abstractmethod - def get_table_header(self) -> str: - """Returns the header of the table.""" - raise NotImplementedError() + @abstractmethod + def get_table_header(self) -> str: + """Returns the header of the table.""" + raise NotImplementedError() - @staticmethod - @abstractmethod - def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: - raise NotImplementedError() + @staticmethod + @abstractmethod + def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: + raise NotImplementedError() - @staticmethod - @abstractmethod - def get_table_title() -> str: - raise NotImplementedError() + @staticmethod + @abstractmethod + def get_table_title() -> str: + raise NotImplementedError() class CompilationTimeToTable(MetricsToTableMapper[CompilationMetrics]): - """Helper to map CompilationMetrics to compilation time column.""" + """Helper to map CompilationMetrics to compilation time column.""" - def update_base_value(self, compile_metrics: CompilationMetrics, - base_value: Any) -> CompilationMetrics: - return dataclasses.replace(compile_metrics, - base_compilation_time_ms=base_value) + def update_base_value( + self, compile_metrics: CompilationMetrics, base_value: Any + ) -> CompilationMetrics: + return dataclasses.replace(compile_metrics, base_compilation_time_ms=base_value) - def get_current_and_base_value( - self, compile_metrics: CompilationMetrics) -> Tuple[int, Optional[int]]: - return (compile_metrics.compilation_time_ms, - compile_metrics.base_compilation_time_ms) + def get_current_and_base_value( + self, compile_metrics: CompilationMetrics + ) -> Tuple[int, Optional[int]]: + return ( + compile_metrics.compilation_time_ms, + compile_metrics.base_compilation_time_ms, + ) - def get_metric_id(self) -> str: - return COMPILATION_TIME_METRIC_ID + def get_metric_id(self) -> str: + return COMPILATION_TIME_METRIC_ID - def get_series_name(self, name: str) -> str: - return f"{name} [{COMPILATION_TIME_SERIES_SUFFIX}]" + def get_series_name(self, name: str) -> str: + return f"{name} [{COMPILATION_TIME_SERIES_SUFFIX}]" - def get_unit(self) -> str: - return "ms" + def get_unit(self) -> str: + return "ms" - def get_table_header(self) -> str: - return f"Compilation Time ({self.get_unit()})" + def get_table_header(self) -> str: + return f"Compilation Time ({self.get_unit()})" - @staticmethod - def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: - return COMPILATION_TIME_THRESHOLDS + @staticmethod + def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: + return COMPILATION_TIME_THRESHOLDS - @staticmethod - def get_table_title() -> str: - return "Compilation Times" + @staticmethod + def get_table_title() -> str: + return "Compilation Times" class TotalDispatchSizeToTable(MetricsToTableMapper[CompilationMetrics]): - """Helper to map CompilationMetrics to total dispatch size column.""" + """Helper to map CompilationMetrics to total dispatch size column.""" - def update_base_value(self, compile_metrics: CompilationMetrics, - base_value: Any) -> CompilationMetrics: - return dataclasses.replace(compile_metrics, - base_total_dispatch_component_bytes=base_value) + def update_base_value( + self, compile_metrics: CompilationMetrics, base_value: Any + ) -> CompilationMetrics: + return dataclasses.replace( + compile_metrics, base_total_dispatch_component_bytes=base_value + ) - def get_current_and_base_value( - self, compile_metrics: CompilationMetrics) -> Tuple[int, Optional[int]]: - return (compile_metrics.total_dispatch_component_bytes, - compile_metrics.base_total_dispatch_component_bytes) + def get_current_and_base_value( + self, compile_metrics: CompilationMetrics + ) -> Tuple[int, Optional[int]]: + return ( + compile_metrics.total_dispatch_component_bytes, + compile_metrics.base_total_dispatch_component_bytes, + ) - def get_metric_id(self) -> str: - return TOTAL_DISPATCH_SIZE_METRIC_ID + def get_metric_id(self) -> str: + return TOTAL_DISPATCH_SIZE_METRIC_ID - def get_series_name(self, name: str) -> str: - return f"{name} [{TOTAL_DISPATCH_SIZE_SERIES_SUFFIX}]" + def get_series_name(self, name: str) -> str: + return f"{name} [{TOTAL_DISPATCH_SIZE_SERIES_SUFFIX}]" - def get_unit(self) -> str: - return "bytes" + def get_unit(self) -> str: + return "bytes" - def get_table_header(self) -> str: - return f"Total Dispatch Size ({self.get_unit()})" + def get_table_header(self) -> str: + return f"Total Dispatch Size ({self.get_unit()})" - @staticmethod - def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: - return TOTAL_DISPATCH_SIZE_THRESHOLDS + @staticmethod + def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: + return TOTAL_DISPATCH_SIZE_THRESHOLDS - @staticmethod - def get_table_title() -> str: - return "Total Dispatch Sizes" + @staticmethod + def get_table_title() -> str: + return "Total Dispatch Sizes" class TotalArtifactSizeToTable(MetricsToTableMapper[CompilationMetrics]): - """Helper to map CompilationMetrics to total artifact size column.""" + """Helper to map CompilationMetrics to total artifact size column.""" - def update_base_value(self, compile_metrics: CompilationMetrics, - base_value: Any) -> CompilationMetrics: - return dataclasses.replace(compile_metrics, - base_total_artifact_bytes=base_value) + def update_base_value( + self, compile_metrics: CompilationMetrics, base_value: Any + ) -> CompilationMetrics: + return dataclasses.replace( + compile_metrics, base_total_artifact_bytes=base_value + ) - def get_current_and_base_value( - self, compile_metrics: CompilationMetrics) -> Tuple[int, Optional[int]]: - return (compile_metrics.total_artifact_bytes, - compile_metrics.base_total_artifact_bytes) + def get_current_and_base_value( + self, compile_metrics: CompilationMetrics + ) -> Tuple[int, Optional[int]]: + return ( + compile_metrics.total_artifact_bytes, + compile_metrics.base_total_artifact_bytes, + ) - def get_metric_id(self) -> str: - return TOTAL_ARTIFACT_SIZE_METRIC_ID + def get_metric_id(self) -> str: + return TOTAL_ARTIFACT_SIZE_METRIC_ID - def get_series_name(self, name: str) -> str: - return f"{name} [{TOTAL_ARTIFACT_SIZE_SERIES_SUFFIX}]" + def get_series_name(self, name: str) -> str: + return f"{name} [{TOTAL_ARTIFACT_SIZE_SERIES_SUFFIX}]" - def get_unit(self) -> str: - return "bytes" + def get_unit(self) -> str: + return "bytes" - def get_table_header(self) -> str: - return f"Total Artifact Size ({self.get_unit()})" + def get_table_header(self) -> str: + return f"Total Artifact Size ({self.get_unit()})" - @staticmethod - def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: - return TOTAL_ARTIFACT_SIZE_THRESHOLDS + @staticmethod + def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: + return TOTAL_ARTIFACT_SIZE_THRESHOLDS - @staticmethod - def get_table_title() -> str: - return "Total Artifact Sizes" + @staticmethod + def get_table_title() -> str: + return "Total Artifact Sizes" class StreamIRDispatchCountToTable(MetricsToTableMapper[CompilationMetrics]): - """Helper to map CompilationMetrics to Stream IR Dispatch Count column.""" + """Helper to map CompilationMetrics to Stream IR Dispatch Count column.""" - def update_base_value(self, compile_metrics: CompilationMetrics, - base_value: Any) -> CompilationMetrics: - return dataclasses.replace(compile_metrics, - base_stream_ir_dispatch_count=base_value) + def update_base_value( + self, compile_metrics: CompilationMetrics, base_value: Any + ) -> CompilationMetrics: + return dataclasses.replace( + compile_metrics, base_stream_ir_dispatch_count=base_value + ) - def get_current_and_base_value( - self, compile_metrics: CompilationMetrics) -> Tuple[int, Optional[int]]: - return (compile_metrics.stream_ir_dispatch_count, - compile_metrics.base_stream_ir_dispatch_count) + def get_current_and_base_value( + self, compile_metrics: CompilationMetrics + ) -> Tuple[int, Optional[int]]: + return ( + compile_metrics.stream_ir_dispatch_count, + compile_metrics.base_stream_ir_dispatch_count, + ) - def get_metric_id(self) -> str: - return STREAM_IR_DISPATCH_COUNT_METRIC_ID + def get_metric_id(self) -> str: + return STREAM_IR_DISPATCH_COUNT_METRIC_ID - def get_series_name(self, name: str) -> str: - return f"{name} [{STREAM_IR_DISPATCH_COUNT_SERIES_SUFFIX}]" + def get_series_name(self, name: str) -> str: + return f"{name} [{STREAM_IR_DISPATCH_COUNT_SERIES_SUFFIX}]" - def get_unit(self) -> str: - return "number" + def get_unit(self) -> str: + return "number" - def get_table_header(self) -> str: - return f"Stream IR Dispatch Count (# of cmd.dispatch ops)" + def get_table_header(self) -> str: + return f"Stream IR Dispatch Count (# of cmd.dispatch ops)" - @staticmethod - def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: - return benchmark_thresholds.STREAM_IR_DISPATCH_COUNT_THRESHOLDS + @staticmethod + def get_metric_thresholds() -> Sequence[BenchmarkThreshold]: + return benchmark_thresholds.STREAM_IR_DISPATCH_COUNT_THRESHOLDS - @staticmethod - def get_table_title() -> str: - return "Stream IR Dispatch Count (# of cmd.dispatch ops)" + @staticmethod + def get_table_title() -> str: + return "Stream IR Dispatch Count (# of cmd.dispatch ops)" -COMPILATION_METRICS_TO_TABLE_MAPPERS: List[ - MetricsToTableMapper[CompilationMetrics]] = [ - CompilationTimeToTable(), - TotalDispatchSizeToTable(), - TotalArtifactSizeToTable(), - StreamIRDispatchCountToTable(), - ] +COMPILATION_METRICS_TO_TABLE_MAPPERS: List[MetricsToTableMapper[CompilationMetrics]] = [ + CompilationTimeToTable(), + TotalDispatchSizeToTable(), + TotalArtifactSizeToTable(), + StreamIRDispatchCountToTable(), +] def aggregate_all_benchmarks( - benchmark_files: Sequence[pathlib.Path], - expected_pr_commit: Optional[str] = None + benchmark_files: Sequence[pathlib.Path], expected_pr_commit: Optional[str] = None ) -> Dict[str, AggregateBenchmarkLatency]: - """Aggregates all benchmarks in the given files. + """Aggregates all benchmarks in the given files. - Args: - - benchmark_files: A list of JSON files, each can be decoded as a - BenchmarkResults. - - expected_pr_commit: An optional Git commit SHA to match against. + Args: + - benchmark_files: A list of JSON files, each can be decoded as a + BenchmarkResults. + - expected_pr_commit: An optional Git commit SHA to match against. - Returns: - - A dict of benchmark names to AggregateBenchmarkLatency numbers. - """ + Returns: + - A dict of benchmark names to AggregateBenchmarkLatency numbers. + """ - aggregate_results = {} - benchmark_names = set() - for benchmark_file in benchmark_files: - file_results = benchmark_definition.BenchmarkResults.from_json_str( - benchmark_file.read_text()) + aggregate_results = {} + benchmark_names = set() + for benchmark_file in benchmark_files: + file_results = benchmark_definition.BenchmarkResults.from_json_str( + benchmark_file.read_text() + ) - if ((expected_pr_commit is not None) and - (file_results.commit != expected_pr_commit)): - raise ValueError("Inconsistent pull request commit") + if (expected_pr_commit is not None) and ( + file_results.commit != expected_pr_commit + ): + raise ValueError("Inconsistent pull request commit") - for benchmark_index in range(len(file_results.benchmarks)): - benchmark_run = file_results.benchmarks[benchmark_index] + for benchmark_index in range(len(file_results.benchmarks)): + benchmark_run = file_results.benchmarks[benchmark_index] - series_name = str(benchmark_run.info) - # Make sure each benchmark has a unique name. - if series_name in benchmark_names: - raise ValueError(f"Duplicated benchmark name: {series_name}") - benchmark_names.add(series_name) + series_name = str(benchmark_run.info) + # Make sure each benchmark has a unique name. + if series_name in benchmark_names: + raise ValueError(f"Duplicated benchmark name: {series_name}") + benchmark_names.add(series_name) - series_id = benchmark_run.info.run_config_id - if series_id in aggregate_results: - raise ValueError(f"Duplicated benchmark id: {series_id}") + series_id = benchmark_run.info.run_config_id + if series_id in aggregate_results: + raise ValueError(f"Duplicated benchmark id: {series_id}") - aggregate_results[series_id] = AggregateBenchmarkLatency( - name=series_name, - benchmark_info=benchmark_run.info, - mean_time=benchmark_run.metrics.real_time.mean, - median_time=benchmark_run.metrics.real_time.median, - stddev_time=benchmark_run.metrics.real_time.stddev) + aggregate_results[series_id] = AggregateBenchmarkLatency( + name=series_name, + benchmark_info=benchmark_run.info, + mean_time=benchmark_run.metrics.real_time.mean, + median_time=benchmark_run.metrics.real_time.median, + stddev_time=benchmark_run.metrics.real_time.stddev, + ) - return aggregate_results + return aggregate_results def collect_all_compilation_metrics( compile_stats_files: Sequence[pathlib.Path], - expected_pr_commit: Optional[str] = None) -> Dict[str, CompilationMetrics]: - """Collects all compilation statistics in the given files. + expected_pr_commit: Optional[str] = None, +) -> Dict[str, CompilationMetrics]: + """Collects all compilation statistics in the given files. Args: compile_stats_files: A list of JSON files, each can be decoded as a @@ -337,80 +374,81 @@ def collect_all_compilation_metrics( Returns: A dict of benchmark names to CompilationMetrics. - """ - compile_metrics = {} - target_names = set() - for compile_stats_file in compile_stats_files: - with compile_stats_file.open("r") as f: - file_results = benchmark_definition.CompilationResults.from_json_object( - json.load(f)) - - if ((expected_pr_commit is not None) and - (file_results.commit != expected_pr_commit)): - raise ValueError("Inconsistent pull request commit") - - for compile_stats in file_results.compilation_statistics: - component_sizes = compile_stats.module_component_sizes - stream_dispatch_count = compile_stats.ir_stats.stream_dispatch_count - - target_name = str(compile_stats.compilation_info) - if target_name in target_names: - raise ValueError(f"Duplicated target name: {target_name}") - target_names.add(target_name) - - target_id = compile_stats.compilation_info.gen_config_id - if target_id in compile_metrics: - raise ValueError(f"Duplicated target id: {target_id}") - - compile_metrics[target_id] = CompilationMetrics( - name=target_name, - compilation_info=compile_stats.compilation_info, - compilation_time_ms=compile_stats.compilation_time_ms, - total_artifact_bytes=component_sizes.file_bytes, - total_dispatch_component_bytes=component_sizes. - total_dispatch_component_bytes, - stream_ir_dispatch_count=stream_dispatch_count) - - return compile_metrics + """ + compile_metrics = {} + target_names = set() + for compile_stats_file in compile_stats_files: + with compile_stats_file.open("r") as f: + file_results = benchmark_definition.CompilationResults.from_json_object( + json.load(f) + ) + + if (expected_pr_commit is not None) and ( + file_results.commit != expected_pr_commit + ): + raise ValueError("Inconsistent pull request commit") + + for compile_stats in file_results.compilation_statistics: + component_sizes = compile_stats.module_component_sizes + stream_dispatch_count = compile_stats.ir_stats.stream_dispatch_count + + target_name = str(compile_stats.compilation_info) + if target_name in target_names: + raise ValueError(f"Duplicated target name: {target_name}") + target_names.add(target_name) + + target_id = compile_stats.compilation_info.gen_config_id + if target_id in compile_metrics: + raise ValueError(f"Duplicated target id: {target_id}") + + compile_metrics[target_id] = CompilationMetrics( + name=target_name, + compilation_info=compile_stats.compilation_info, + compilation_time_ms=compile_stats.compilation_time_ms, + total_artifact_bytes=component_sizes.file_bytes, + total_dispatch_component_bytes=component_sizes.total_dispatch_component_bytes, + stream_ir_dispatch_count=stream_dispatch_count, + ) + + return compile_metrics def _make_series_link(name: str, series_id: str) -> str: - """Add link to the given benchmark name. + """Add link to the given benchmark name. Args: name: the text to show on the link. series_id: the dashboard series id. - """ - url = PERFBOARD_SERIES_PREFIX + urllib.parse.quote(series_id, safe="()[]@,") - return md.link(name, url) + """ + url = PERFBOARD_SERIES_PREFIX + urllib.parse.quote(series_id, safe="()[]@,") + return md.link(name, url) -def _add_header_and_get_markdown_table(headers: Sequence[str], - rows: Sequence[Tuple], - size_cut: Optional[int] = None) -> str: - """Generates a markdown table with headers. +def _add_header_and_get_markdown_table( + headers: Sequence[str], rows: Sequence[Tuple], size_cut: Optional[int] = None +) -> str: + """Generates a markdown table with headers. - Args: - headers: list of table headers. - rows: list of rows. Each row is a tuple with the same length as headers. - size_cut: If not None, only show the top N results for each table. - """ + Args: + headers: list of table headers. + rows: list of rows. Each row is a tuple with the same length as headers. + size_cut: If not None, only show the top N results for each table. + """ - total_size = len(rows) - if size_cut is not None: - rows = rows[0:size_cut] + total_size = len(rows) + if size_cut is not None: + rows = rows[0:size_cut] - columns = [[header] for header in headers] - for row in rows: - for column, item in zip(columns, row): - column.append(item) + columns = [[header] for header in headers] + for row in rows: + for column, item in zip(columns, row): + column.append(item) - table_str = md.table(columns) - if size_cut is not None and size_cut < total_size: - table_str += "\n\n" - table_str += md.italics( - f"[Top {size_cut} out of {total_size} results showed]") - return table_str + table_str = md.table(columns) + if size_cut is not None and size_cut < total_size: + table_str += "\n\n" + table_str += md.italics(f"[Top {size_cut} out of {total_size} results showed]") + return table_str T = TypeVar("T") @@ -422,7 +460,7 @@ def _categorize_on_single_metric( thresholds: Sequence[BenchmarkThreshold], metric_unit: str, ) -> Tuple[Dict[str, T], Dict[str, T], Dict[str, T], Dict[str, T]]: - """Categorize the metrics object into regressed, improved, similar, and the + """Categorize the metrics object into regressed, improved, similar, and the raw group (the group with no base to compare to). Args: @@ -431,98 +469,106 @@ def _categorize_on_single_metric( thresholds: list of threshold settings to match for categorizing. Returns: A tuple of (regressed, improved, similar, raw) groups. - """ - - regressed_map = {} - improved_map = {} - similar_map = {} - raw_map = {} - for series_id, metrics_obj in metrics_map.items(): - current, base = metric_func(metrics_obj) - if base is None: - raw_map[series_id] = metrics_obj - continue - - series_name = str(metrics_obj) - similar_threshold = None - for threshold in thresholds: - if threshold.regex.match(series_name): - similar_threshold = threshold - break - if similar_threshold is None: - raise ValueError(f"No matched threshold setting for: {series_name}") - - if similar_threshold.unit == ThresholdUnit.PERCENTAGE: - ratio = abs(current - base) / base * 100 - elif similar_threshold.unit.value == metric_unit: - ratio = abs(current - base) - else: - raise ValueError( - f"Mismatch between metric unit '{metric_unit}' and threshold unit '{similar_threshold.unit.value}'" - ) - - if ratio <= similar_threshold.threshold: - similar_map[series_id] = metrics_obj - elif current > base: - regressed_map[series_id] = metrics_obj - else: - improved_map[series_id] = metrics_obj - - return (regressed_map, improved_map, similar_map, raw_map) + """ + + regressed_map = {} + improved_map = {} + similar_map = {} + raw_map = {} + for series_id, metrics_obj in metrics_map.items(): + current, base = metric_func(metrics_obj) + if base is None: + raw_map[series_id] = metrics_obj + continue + + series_name = str(metrics_obj) + similar_threshold = None + for threshold in thresholds: + if threshold.regex.match(series_name): + similar_threshold = threshold + break + if similar_threshold is None: + raise ValueError(f"No matched threshold setting for: {series_name}") + + if similar_threshold.unit == ThresholdUnit.PERCENTAGE: + ratio = abs(current - base) / base * 100 + elif similar_threshold.unit.value == metric_unit: + ratio = abs(current - base) + else: + raise ValueError( + f"Mismatch between metric unit '{metric_unit}' and threshold unit '{similar_threshold.unit.value}'" + ) + + if ratio <= similar_threshold.threshold: + similar_map[series_id] = metrics_obj + elif current > base: + regressed_map[series_id] = metrics_obj + else: + improved_map[series_id] = metrics_obj + + return (regressed_map, improved_map, similar_map, raw_map) def _get_fixed_point_str(value: Union[int, float], digits=3) -> str: - if isinstance(value, int) or value.is_integer(): - return str(math.floor(value)) - return f"{{:.{digits}f}}".format(value) + if isinstance(value, int) or value.is_integer(): + return str(math.floor(value)) + return f"{{:.{digits}f}}".format(value) def _get_compare_text(current: float, base: Optional[int]) -> str: - """Generates the text of comparison between current and base value. Returns + """Generates the text of comparison between current and base value. Returns the current value if the base value is None. - """ - # If base is None, don't need to do compare. - if base is None: - return f"{_get_fixed_point_str(current)}" + """ + # If base is None, don't need to do compare. + if base is None: + return f"{_get_fixed_point_str(current)}" - ratio = abs(current - base) / base - direction = "↑" if current > base else ("↓" if current < base else "") - return f"{_get_fixed_point_str(current)} (vs. {_get_fixed_point_str(base)}, {ratio:.2%}{direction})" + ratio = abs(current - base) / base + direction = "↑" if current > base else ("↓" if current < base else "") + return f"{_get_fixed_point_str(current)} (vs. {_get_fixed_point_str(base)}, {ratio:.2%}{direction})" -def _sort_benchmarks_and_get_table(benchmarks: Dict[str, - AggregateBenchmarkLatency], - size_cut: Optional[int] = None) -> str: - """Sorts all benchmarks according to the improvement/regression ratio and +def _sort_benchmarks_and_get_table( + benchmarks: Dict[str, AggregateBenchmarkLatency], size_cut: Optional[int] = None +) -> str: + """Sorts all benchmarks according to the improvement/regression ratio and returns a markdown table for it. Args: benchmarks_map: map of (series_id, benchmark object). size_cut: If not None, only show the top N results for each table. - """ - sorted_rows = [] - for series_id, benchmark in benchmarks.items(): - current = benchmark.mean_time / 1e6 - base = benchmark.base_mean_time / 1e6 - ratio = abs(current - base) / base - str_mean = _get_compare_text(current, base) - clickable_name = _make_series_link(benchmark.name, series_id) - sorted_rows.append( - (ratio, (clickable_name, str_mean, - f"{_get_fixed_point_str(benchmark.median_time / 1e6)}", - f"{_get_fixed_point_str(benchmark.stddev_time / 1e6)}"))) - sorted_rows.sort(key=lambda row: row[0], reverse=True) - - return _add_header_and_get_markdown_table( - headers=BENCHMARK_RESULTS_HEADERS, - rows=[row[1] for row in sorted_rows], - size_cut=size_cut) - - -def categorize_benchmarks_into_tables(benchmarks: Dict[ - str, AggregateBenchmarkLatency], - size_cut: Optional[int] = None) -> str: - """Splits benchmarks into regressed/improved/similar/raw categories and + """ + sorted_rows = [] + for series_id, benchmark in benchmarks.items(): + current = benchmark.mean_time / 1e6 + base = benchmark.base_mean_time / 1e6 + ratio = abs(current - base) / base + str_mean = _get_compare_text(current, base) + clickable_name = _make_series_link(benchmark.name, series_id) + sorted_rows.append( + ( + ratio, + ( + clickable_name, + str_mean, + f"{_get_fixed_point_str(benchmark.median_time / 1e6)}", + f"{_get_fixed_point_str(benchmark.stddev_time / 1e6)}", + ), + ) + ) + sorted_rows.sort(key=lambda row: row[0], reverse=True) + + return _add_header_and_get_markdown_table( + headers=BENCHMARK_RESULTS_HEADERS, + rows=[row[1] for row in sorted_rows], + size_cut=size_cut, + ) + + +def categorize_benchmarks_into_tables( + benchmarks: Dict[str, AggregateBenchmarkLatency], size_cut: Optional[int] = None +) -> str: + """Splits benchmarks into regressed/improved/similar/raw categories and returns their markdown tables. If size_cut is None, the table includes regressed/improved/similar/raw @@ -531,41 +577,51 @@ def categorize_benchmarks_into_tables(benchmarks: Dict[ Args: benchmarks: A dictionary of benchmark names to its aggregate info. size_cut: If not None, only show the top N results for each table. - """ - regressed, improved, similar, raw = _categorize_on_single_metric( - benchmarks, lambda results: (results.mean_time, results.base_mean_time), - BENCHMARK_THRESHOLDS, "ns") - - tables = [] - if regressed: - tables.append(md.header("Regressed Latencies 🚩", 3)) - tables.append(_sort_benchmarks_and_get_table(regressed, size_cut)) - if improved: - tables.append(md.header("Improved Latencies 🎉", 3)) - tables.append(_sort_benchmarks_and_get_table(improved, size_cut)) - # If we want to abbreviate, similar results won't be interesting. - if similar and size_cut is None: - tables.append(md.header("Similar Latencies", 3)) - tables.append(_sort_benchmarks_and_get_table(similar, size_cut)) - if raw: - tables.append(md.header("Raw Latencies", 3)) - raw_list = [(_make_series_link(name=v.name, series_id=k), - f"{_get_fixed_point_str(v.mean_time / 1e6)}", - f"{_get_fixed_point_str(v.median_time / 1e6)}", - f"{_get_fixed_point_str(v.stddev_time / 1e6)}") - for k, v in raw.items()] - tables.append( - _add_header_and_get_markdown_table(BENCHMARK_RESULTS_HEADERS, - raw_list, - size_cut=size_cut)) - return "\n\n".join(tables) - - -def _sort_metrics_objects_and_get_table(metrics_objs: Dict[str, T], - mapper: MetricsToTableMapper[T], - headers: Sequence[str], - size_cut: Optional[int] = None) -> str: - """Sorts all metrics objects according to the improvement/regression ratio and + """ + regressed, improved, similar, raw = _categorize_on_single_metric( + benchmarks, + lambda results: (results.mean_time, results.base_mean_time), + BENCHMARK_THRESHOLDS, + "ns", + ) + + tables = [] + if regressed: + tables.append(md.header("Regressed Latencies 🚩", 3)) + tables.append(_sort_benchmarks_and_get_table(regressed, size_cut)) + if improved: + tables.append(md.header("Improved Latencies 🎉", 3)) + tables.append(_sort_benchmarks_and_get_table(improved, size_cut)) + # If we want to abbreviate, similar results won't be interesting. + if similar and size_cut is None: + tables.append(md.header("Similar Latencies", 3)) + tables.append(_sort_benchmarks_and_get_table(similar, size_cut)) + if raw: + tables.append(md.header("Raw Latencies", 3)) + raw_list = [ + ( + _make_series_link(name=v.name, series_id=k), + f"{_get_fixed_point_str(v.mean_time / 1e6)}", + f"{_get_fixed_point_str(v.median_time / 1e6)}", + f"{_get_fixed_point_str(v.stddev_time / 1e6)}", + ) + for k, v in raw.items() + ] + tables.append( + _add_header_and_get_markdown_table( + BENCHMARK_RESULTS_HEADERS, raw_list, size_cut=size_cut + ) + ) + return "\n\n".join(tables) + + +def _sort_metrics_objects_and_get_table( + metrics_objs: Dict[str, T], + mapper: MetricsToTableMapper[T], + headers: Sequence[str], + size_cut: Optional[int] = None, +) -> str: + """Sorts all metrics objects according to the improvement/regression ratio and returns a markdown table for it. Args: @@ -574,27 +630,35 @@ def _sort_metrics_objects_and_get_table(metrics_objs: Dict[str, T], mapper: MetricsToTableMapper for metrics_objs. headers: list of table headers. size_cut: If not None, only show the top N results for each table. - """ - sorted_rows = [] - for target_id, metrics_obj in metrics_objs.items(): - current, base = mapper.get_current_and_base_value(metrics_obj) - if base is None: - raise AssertionError("Base can't be None for sorting.") - ratio = abs(current - base) / base - sorted_rows.append((ratio, ( - _make_series_link(str(metrics_obj), mapper.get_series_id(target_id)), - _get_compare_text(current, base), - ))) - sorted_rows.sort(key=lambda row: row[0], reverse=True) - - return _add_header_and_get_markdown_table( - headers=headers, rows=[row[1] for row in sorted_rows], size_cut=size_cut) + """ + sorted_rows = [] + for target_id, metrics_obj in metrics_objs.items(): + current, base = mapper.get_current_and_base_value(metrics_obj) + if base is None: + raise AssertionError("Base can't be None for sorting.") + ratio = abs(current - base) / base + sorted_rows.append( + ( + ratio, + ( + _make_series_link( + str(metrics_obj), mapper.get_series_id(target_id) + ), + _get_compare_text(current, base), + ), + ) + ) + sorted_rows.sort(key=lambda row: row[0], reverse=True) + + return _add_header_and_get_markdown_table( + headers=headers, rows=[row[1] for row in sorted_rows], size_cut=size_cut + ) def categorize_compilation_metrics_into_tables( - compile_metrics_map: Dict[str, CompilationMetrics], - size_cut: Optional[int] = None) -> str: - """Splits compilation metrics into regressed/improved/all categories + compile_metrics_map: Dict[str, CompilationMetrics], size_cut: Optional[int] = None +) -> str: + """Splits compilation metrics into regressed/improved/all categories and returns their markdown tables. If size_cut is None, the table includes regressed/improved/all categories; @@ -604,51 +668,61 @@ def categorize_compilation_metrics_into_tables( compile_metrics_map: A dictionary of benchmark names to its compilation metrics. size_cut: If not None, only show the top N results for each table. - """ - - tables = [] - for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: - regressed, improved, _, _ = _categorize_on_single_metric( - compile_metrics_map, mapper.get_current_and_base_value, - mapper.get_metric_thresholds(), mapper.get_unit()) - - table_title = mapper.get_table_title() - table_header = mapper.get_table_header() - if regressed: - tables.append(md.header(f"Regressed {table_title} 🚩", 3)) - tables.append( - _sort_metrics_objects_and_get_table( - metrics_objs=regressed, - mapper=mapper, - headers=["Benchmark Name", table_header], - size_cut=size_cut)) - if improved: - tables.append(md.header(f"Improved {table_title} 🎉", 3)) - tables.append( - _sort_metrics_objects_and_get_table( - metrics_objs=improved, - mapper=mapper, - headers=["Benchmark Name", table_header], - size_cut=size_cut)) - - # If we want to abbreviate, similar results won't be interesting. - if size_cut is None and compile_metrics_map: - tables.append(md.header("All Compilation Metrics", 3)) - headers = ["Benchmark Name"] + [ - mapper.get_table_header() - for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS - ] - rows = [] - for target_id, metrics in compile_metrics_map.items(): - row = [metrics.name] - for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: - current, base = mapper.get_current_and_base_value(metrics) - row.append( - _make_series_link(_get_compare_text(current, base), - mapper.get_series_id(target_id))) - rows.append(tuple(row)) - - tables.append( - _add_header_and_get_markdown_table(headers, rows, size_cut=size_cut)) - - return "\n\n".join(tables) + """ + + tables = [] + for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: + regressed, improved, _, _ = _categorize_on_single_metric( + compile_metrics_map, + mapper.get_current_and_base_value, + mapper.get_metric_thresholds(), + mapper.get_unit(), + ) + + table_title = mapper.get_table_title() + table_header = mapper.get_table_header() + if regressed: + tables.append(md.header(f"Regressed {table_title} 🚩", 3)) + tables.append( + _sort_metrics_objects_and_get_table( + metrics_objs=regressed, + mapper=mapper, + headers=["Benchmark Name", table_header], + size_cut=size_cut, + ) + ) + if improved: + tables.append(md.header(f"Improved {table_title} 🎉", 3)) + tables.append( + _sort_metrics_objects_and_get_table( + metrics_objs=improved, + mapper=mapper, + headers=["Benchmark Name", table_header], + size_cut=size_cut, + ) + ) + + # If we want to abbreviate, similar results won't be interesting. + if size_cut is None and compile_metrics_map: + tables.append(md.header("All Compilation Metrics", 3)) + headers = ["Benchmark Name"] + [ + mapper.get_table_header() for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS + ] + rows = [] + for target_id, metrics in compile_metrics_map.items(): + row = [metrics.name] + for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: + current, base = mapper.get_current_and_base_value(metrics) + row.append( + _make_series_link( + _get_compare_text(current, base), + mapper.get_series_id(target_id), + ) + ) + rows.append(tuple(row)) + + tables.append( + _add_header_and_get_markdown_table(headers, rows, size_cut=size_cut) + ) + + return "\n\n".join(tables) diff --git a/build_tools/benchmarks/common/benchmark_suite.py b/build_tools/benchmarks/common/benchmark_suite.py index d673909827ad..da094d7fa606 100644 --- a/build_tools/benchmarks/common/benchmark_suite.py +++ b/build_tools/benchmarks/common/benchmark_suite.py @@ -25,7 +25,7 @@ @dataclass class BenchmarkCase: - """Represents a benchmark case. + """Represents a benchmark case. model_name: the source model, e.g., 'MobileSSD'. model_tags: the source model tags, e.g., ['f32']. @@ -35,148 +35,182 @@ class BenchmarkCase: benchmark_tool_name: the benchmark tool, e.g., 'iree-benchmark-module'. benchmark_case_dir: the path to benchmark case directory. run_config: the run config from e2e test framework. - """ + """ - model_name: str - model_tags: Sequence[str] - bench_mode: Sequence[str] - target_arch: common_definitions.DeviceArchitecture - driver_info: DriverInfo - benchmark_tool_name: str - benchmark_case_dir: pathlib.Path - run_config: iree_definitions.E2EModelRunConfig + model_name: str + model_tags: Sequence[str] + bench_mode: Sequence[str] + target_arch: common_definitions.DeviceArchitecture + driver_info: DriverInfo + benchmark_tool_name: str + benchmark_case_dir: pathlib.Path + run_config: iree_definitions.E2EModelRunConfig # A map from execution config to driver info. This is temporary during migration # before we can drop the DriverInfo. -EXECUTION_CONFIG_TO_DRIVER_INFO_KEY_MAP: Dict[Tuple[ - iree_definitions.RuntimeDriver, iree_definitions.RuntimeLoader], str] = { - (iree_definitions.RuntimeDriver.LOCAL_TASK, iree_definitions.RuntimeLoader.EMBEDDED_ELF): - "iree-llvm-cpu", - (iree_definitions.RuntimeDriver.LOCAL_SYNC, iree_definitions.RuntimeLoader.EMBEDDED_ELF): - "iree-llvm-cpu-sync", - (iree_definitions.RuntimeDriver.LOCAL_TASK, iree_definitions.RuntimeLoader.VMVX_MODULE): - "iree-vmvx", - (iree_definitions.RuntimeDriver.LOCAL_SYNC, iree_definitions.RuntimeLoader.VMVX_MODULE): - "iree-vmvx-sync", - (iree_definitions.RuntimeDriver.VULKAN, iree_definitions.RuntimeLoader.NONE): - "iree-vulkan", - (iree_definitions.RuntimeDriver.CUDA, iree_definitions.RuntimeLoader.NONE): - "iree-cuda", - } +EXECUTION_CONFIG_TO_DRIVER_INFO_KEY_MAP: Dict[ + Tuple[iree_definitions.RuntimeDriver, iree_definitions.RuntimeLoader], str +] = { + ( + iree_definitions.RuntimeDriver.LOCAL_TASK, + iree_definitions.RuntimeLoader.EMBEDDED_ELF, + ): "iree-llvm-cpu", + ( + iree_definitions.RuntimeDriver.LOCAL_SYNC, + iree_definitions.RuntimeLoader.EMBEDDED_ELF, + ): "iree-llvm-cpu-sync", + ( + iree_definitions.RuntimeDriver.LOCAL_TASK, + iree_definitions.RuntimeLoader.VMVX_MODULE, + ): "iree-vmvx", + ( + iree_definitions.RuntimeDriver.LOCAL_SYNC, + iree_definitions.RuntimeLoader.VMVX_MODULE, + ): "iree-vmvx-sync", + ( + iree_definitions.RuntimeDriver.VULKAN, + iree_definitions.RuntimeLoader.NONE, + ): "iree-vulkan", + ( + iree_definitions.RuntimeDriver.CUDA, + iree_definitions.RuntimeLoader.NONE, + ): "iree-cuda", +} class BenchmarkSuite(object): - """Represents the benchmarks in benchmark suite directory.""" - - def __init__(self, benchmark_cases: Sequence[BenchmarkCase]): - """Construct a benchmark suite. - - Args: - benchmark_cases: list of benchmark cases. - """ - self.benchmark_cases = list(benchmark_cases) - - def filter_benchmarks( - self, - available_drivers: Optional[Sequence[str]] = None, - available_loaders: Optional[Sequence[str]] = None, - target_architectures: Optional[Sequence[ - common_definitions.DeviceArchitecture]] = None, - driver_filter: Optional[str] = None, - mode_filter: Optional[str] = None, - model_name_filter: Optional[str] = None) -> Sequence[BenchmarkCase]: - """Filters benchmarks. - Args: - available_drivers: list of drivers supported by the tools. None means to - match any driver. - available_loaders: list of executable loaders supported by the tools. - None means to match any loader. - target_architectures: list of target architectures to be included. None - means no filter. - driver_filter: driver filter regex. - mode_filter: benchmark mode regex. - model_name_filter: model name regex. - Returns: - A list of matched benchmark cases. - """ - - chosen_cases = [] - for benchmark_case in self.benchmark_cases: - driver_info = benchmark_case.driver_info - - driver_name = driver_info.driver_name - matched_available_driver = (available_drivers is None or - driver_name in available_drivers) - matched_driver_filter = driver_filter is None or re.match( - driver_filter, driver_name) is not None - matched_driver = matched_available_driver and matched_driver_filter - - matched_loader = not driver_info.loader_name or available_loaders is None or ( - driver_info.loader_name in available_loaders) - - if target_architectures is None: - matched_arch = True - else: - matched_arch = benchmark_case.target_arch in target_architectures - - bench_mode = ','.join(benchmark_case.bench_mode) - matched_mode = (mode_filter is None or - re.match(mode_filter, bench_mode) is not None) - - model_name_with_tags = benchmark_case.model_name - if len(benchmark_case.model_tags) > 0: - model_name_with_tags += f"-{','.join(benchmark_case.model_tags)}" - matched_model_name = (model_name_filter is None or re.match( - model_name_filter, model_name_with_tags) is not None) - - if (matched_driver and matched_loader and matched_arch and - matched_model_name and matched_mode): - chosen_cases.append(benchmark_case) - - return chosen_cases - - @staticmethod - def load_from_run_configs( - run_configs: Sequence[iree_definitions.E2EModelRunConfig], - root_benchmark_dir: pathlib.Path): - """Loads the benchmarks from the run configs. - - Args: - run_configs: list of benchmark run configs. - Returns: - A benchmark suite. - """ - - benchmark_cases = [] - for run_config in run_configs: - module_gen_config = run_config.module_generation_config - module_exec_config = run_config.module_execution_config - target_device_spec = run_config.target_device_spec - - driver_info_key = EXECUTION_CONFIG_TO_DRIVER_INFO_KEY_MAP.get( - (module_exec_config.driver, module_exec_config.loader)) - if driver_info_key is None: - raise ValueError( - f"Can't map execution config to driver info: {module_exec_config}.") - driver_info = IREE_DRIVERS_INFOS[driver_info_key] - - target_arch = target_device_spec.architecture - model = module_gen_config.imported_model.model - - module_dir_path = iree_artifacts.get_module_dir_path( - module_generation_config=module_gen_config, - root_path=root_benchmark_dir) - module_dir_path = pathlib.Path(module_dir_path) - - benchmark_case = BenchmarkCase(model_name=model.name, - model_tags=model.tags, - bench_mode=module_exec_config.tags, - target_arch=target_arch, - driver_info=driver_info, - benchmark_tool_name=run_config.tool.value, - benchmark_case_dir=module_dir_path, - run_config=run_config) - benchmark_cases.append(benchmark_case) - - return BenchmarkSuite(benchmark_cases=benchmark_cases) + """Represents the benchmarks in benchmark suite directory.""" + + def __init__(self, benchmark_cases: Sequence[BenchmarkCase]): + """Construct a benchmark suite. + + Args: + benchmark_cases: list of benchmark cases. + """ + self.benchmark_cases = list(benchmark_cases) + + def filter_benchmarks( + self, + available_drivers: Optional[Sequence[str]] = None, + available_loaders: Optional[Sequence[str]] = None, + target_architectures: Optional[ + Sequence[common_definitions.DeviceArchitecture] + ] = None, + driver_filter: Optional[str] = None, + mode_filter: Optional[str] = None, + model_name_filter: Optional[str] = None, + ) -> Sequence[BenchmarkCase]: + """Filters benchmarks. + Args: + available_drivers: list of drivers supported by the tools. None means to + match any driver. + available_loaders: list of executable loaders supported by the tools. + None means to match any loader. + target_architectures: list of target architectures to be included. None + means no filter. + driver_filter: driver filter regex. + mode_filter: benchmark mode regex. + model_name_filter: model name regex. + Returns: + A list of matched benchmark cases. + """ + + chosen_cases = [] + for benchmark_case in self.benchmark_cases: + driver_info = benchmark_case.driver_info + + driver_name = driver_info.driver_name + matched_available_driver = ( + available_drivers is None or driver_name in available_drivers + ) + matched_driver_filter = ( + driver_filter is None + or re.match(driver_filter, driver_name) is not None + ) + matched_driver = matched_available_driver and matched_driver_filter + + matched_loader = ( + not driver_info.loader_name + or available_loaders is None + or (driver_info.loader_name in available_loaders) + ) + + if target_architectures is None: + matched_arch = True + else: + matched_arch = benchmark_case.target_arch in target_architectures + + bench_mode = ",".join(benchmark_case.bench_mode) + matched_mode = ( + mode_filter is None or re.match(mode_filter, bench_mode) is not None + ) + + model_name_with_tags = benchmark_case.model_name + if len(benchmark_case.model_tags) > 0: + model_name_with_tags += f"-{','.join(benchmark_case.model_tags)}" + matched_model_name = ( + model_name_filter is None + or re.match(model_name_filter, model_name_with_tags) is not None + ) + + if ( + matched_driver + and matched_loader + and matched_arch + and matched_model_name + and matched_mode + ): + chosen_cases.append(benchmark_case) + + return chosen_cases + + @staticmethod + def load_from_run_configs( + run_configs: Sequence[iree_definitions.E2EModelRunConfig], + root_benchmark_dir: pathlib.Path, + ): + """Loads the benchmarks from the run configs. + + Args: + run_configs: list of benchmark run configs. + Returns: + A benchmark suite. + """ + + benchmark_cases = [] + for run_config in run_configs: + module_gen_config = run_config.module_generation_config + module_exec_config = run_config.module_execution_config + target_device_spec = run_config.target_device_spec + + driver_info_key = EXECUTION_CONFIG_TO_DRIVER_INFO_KEY_MAP.get( + (module_exec_config.driver, module_exec_config.loader) + ) + if driver_info_key is None: + raise ValueError( + f"Can't map execution config to driver info: {module_exec_config}." + ) + driver_info = IREE_DRIVERS_INFOS[driver_info_key] + + target_arch = target_device_spec.architecture + model = module_gen_config.imported_model.model + + module_dir_path = iree_artifacts.get_module_dir_path( + module_generation_config=module_gen_config, root_path=root_benchmark_dir + ) + module_dir_path = pathlib.Path(module_dir_path) + + benchmark_case = BenchmarkCase( + model_name=model.name, + model_tags=model.tags, + bench_mode=module_exec_config.tags, + target_arch=target_arch, + driver_info=driver_info, + benchmark_tool_name=run_config.tool.value, + benchmark_case_dir=module_dir_path, + run_config=run_config, + ) + benchmark_cases.append(benchmark_case) + + return BenchmarkSuite(benchmark_cases=benchmark_cases) diff --git a/build_tools/benchmarks/common/benchmark_suite_test.py b/build_tools/benchmarks/common/benchmark_suite_test.py index 7a8d69f9fd7e..208c35ec7e37 100644 --- a/build_tools/benchmarks/common/benchmark_suite_test.py +++ b/build_tools/benchmarks/common/benchmark_suite_test.py @@ -14,208 +14,244 @@ class BenchmarkSuiteTest(unittest.TestCase): + def test_filter_benchmarks(self): + model = common_definitions.Model( + id="model", + name="model", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + exec_config = iree_definitions.ModuleExecutionConfig.build( + id="exec", + tags=[], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, + ) + device_spec = common_definitions.DeviceSpec.build( + id="dev", + device_name="dev", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[], + tags=[], + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + dummy_run_config = iree_definitions.E2EModelRunConfig.build( + module_generation_config=iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model), + compile_config=iree_definitions.CompileConfig.build( + id="1", tags=[], compile_targets=[compile_target] + ), + ), + module_execution_config=exec_config, + target_device_spec=device_spec, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) - def test_filter_benchmarks(self): - model = common_definitions.Model( - id="model", - name="model", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - exec_config = iree_definitions.ModuleExecutionConfig.build( - id="exec", - tags=[], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_SYNC) - device_spec = common_definitions.DeviceSpec.build( - id="dev", - device_name="dev", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[], - tags=[]) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - dummy_run_config = iree_definitions.E2EModelRunConfig.build( - module_generation_config=iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(model), - compile_config=iree_definitions.CompileConfig.build( - id="1", tags=[], compile_targets=[compile_target])), - module_execution_config=exec_config, - target_device_spec=device_spec, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) + case1 = BenchmarkCase( + model_name="deepnet", + model_tags=[], + bench_mode=["1-thread", "full-inference"], + target_arch=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], + benchmark_case_dir=pathlib.Path("case1"), + benchmark_tool_name="tool", + run_config=dummy_run_config, + ) + case2 = BenchmarkCase( + model_name="deepnetv2", + model_tags=["f32"], + bench_mode=["full-inference"], + target_arch=common_definitions.DeviceArchitecture.ARM_VALHALL, + driver_info=IREE_DRIVERS_INFOS["iree-vulkan"], + benchmark_case_dir=pathlib.Path("case2"), + benchmark_tool_name="tool", + run_config=dummy_run_config, + ) + case3 = BenchmarkCase( + model_name="deepnetv3", + model_tags=["f32"], + bench_mode=["full-inference"], + target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], + benchmark_case_dir=pathlib.Path("case3"), + benchmark_tool_name="tool", + run_config=dummy_run_config, + ) + suite = BenchmarkSuite([case1, case2, case3]) - case1 = BenchmarkCase( - model_name="deepnet", - model_tags=[], - bench_mode=["1-thread", "full-inference"], - target_arch=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu"], - benchmark_case_dir=pathlib.Path("case1"), - benchmark_tool_name="tool", - run_config=dummy_run_config) - case2 = BenchmarkCase( - model_name="deepnetv2", - model_tags=["f32"], - bench_mode=["full-inference"], - target_arch=common_definitions.DeviceArchitecture.ARM_VALHALL, - driver_info=IREE_DRIVERS_INFOS["iree-vulkan"], - benchmark_case_dir=pathlib.Path("case2"), - benchmark_tool_name="tool", - run_config=dummy_run_config) - case3 = BenchmarkCase( - model_name="deepnetv3", - model_tags=["f32"], - bench_mode=["full-inference"], - target_arch=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], - benchmark_case_dir=pathlib.Path("case3"), - benchmark_tool_name="tool", - run_config=dummy_run_config) - suite = BenchmarkSuite([case1, case2, case3]) - - cpu_and_gpu_benchmarks = suite.filter_benchmarks( - available_drivers=["local-task", "vulkan"], - available_loaders=["embedded-elf"], - target_architectures=[ - common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - common_definitions.DeviceArchitecture.ARM_VALHALL, - ], - driver_filter=None, - mode_filter=".*full-inference.*", - model_name_filter="deepnet.*") - gpu_benchmarks = suite.filter_benchmarks( - available_drivers=["local-task", "vulkan"], - available_loaders=["embedded-elf"], - target_architectures=[ - common_definitions.DeviceArchitecture.ARM_VALHALL, - ], - driver_filter="vulkan", - mode_filter=".*full-inference.*", - model_name_filter="deepnet.*") - all_benchmarks = suite.filter_benchmarks(available_drivers=None, - target_architectures=None, - driver_filter=None, - mode_filter=None, - model_name_filter=None) + cpu_and_gpu_benchmarks = suite.filter_benchmarks( + available_drivers=["local-task", "vulkan"], + available_loaders=["embedded-elf"], + target_architectures=[ + common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + common_definitions.DeviceArchitecture.ARM_VALHALL, + ], + driver_filter=None, + mode_filter=".*full-inference.*", + model_name_filter="deepnet.*", + ) + gpu_benchmarks = suite.filter_benchmarks( + available_drivers=["local-task", "vulkan"], + available_loaders=["embedded-elf"], + target_architectures=[ + common_definitions.DeviceArchitecture.ARM_VALHALL, + ], + driver_filter="vulkan", + mode_filter=".*full-inference.*", + model_name_filter="deepnet.*", + ) + all_benchmarks = suite.filter_benchmarks( + available_drivers=None, + target_architectures=None, + driver_filter=None, + mode_filter=None, + model_name_filter=None, + ) - self.assertEqual(cpu_and_gpu_benchmarks, [case1, case2]) - self.assertEqual(gpu_benchmarks, [case2]) - self.assertEqual(all_benchmarks, [case1, case2, case3]) + self.assertEqual(cpu_and_gpu_benchmarks, [case1, case2]) + self.assertEqual(gpu_benchmarks, [case2]) + self.assertEqual(all_benchmarks, [case1, case2, case3]) - def test_load_from_run_configs(self): - model_tflite = common_definitions.Model( - id="tflite", - name="model_tflite", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - model_tf = common_definitions.Model( - id="tf", - name="model_tf", - tags=["fp32"], - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - exec_config_a = iree_definitions.ModuleExecutionConfig.build( - id="exec_a", - tags=["defaults"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_SYNC) - exec_config_b = iree_definitions.ModuleExecutionConfig.build( - id="exec_b", - tags=["experimental"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK) - device_spec_a = common_definitions.DeviceSpec.build( - id="dev_a", - device_name="a", - architecture=common_definitions.DeviceArchitecture.RV32_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[], - tags=[]) - device_spec_b = common_definitions.DeviceSpec.build( - id="dev_b", - device_name="b", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[], - tags=[]) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model( - model_tflite), - compile_config=iree_definitions.CompileConfig.build( - id="1", tags=[], compile_targets=[compile_target])), - module_execution_config=exec_config_a, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model( - model_tflite), - compile_config=iree_definitions.CompileConfig.build( - id="2", tags=[], compile_targets=[compile_target])), - module_execution_config=exec_config_b, - target_device_spec=device_spec_b, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_c = iree_definitions.E2EModelRunConfig.build( - module_generation_config=iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(model_tf), - compile_config=iree_definitions.CompileConfig.build( - id="3", tags=[], compile_targets=[compile_target])), - module_execution_config=exec_config_a, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_configs = [run_config_a, run_config_b, run_config_c] - root_dir = pathlib.Path("root") + def test_load_from_run_configs(self): + model_tflite = common_definitions.Model( + id="tflite", + name="model_tflite", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + model_tf = common_definitions.Model( + id="tf", + name="model_tf", + tags=["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + exec_config_a = iree_definitions.ModuleExecutionConfig.build( + id="exec_a", + tags=["defaults"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, + ) + exec_config_b = iree_definitions.ModuleExecutionConfig.build( + id="exec_b", + tags=["experimental"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + ) + device_spec_a = common_definitions.DeviceSpec.build( + id="dev_a", + device_name="a", + architecture=common_definitions.DeviceArchitecture.RV32_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[], + tags=[], + ) + device_spec_b = common_definitions.DeviceSpec.build( + id="dev_b", + device_name="b", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[], + tags=[], + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model_tflite), + compile_config=iree_definitions.CompileConfig.build( + id="1", tags=[], compile_targets=[compile_target] + ), + ), + module_execution_config=exec_config_a, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model_tflite), + compile_config=iree_definitions.CompileConfig.build( + id="2", tags=[], compile_targets=[compile_target] + ), + ), + module_execution_config=exec_config_b, + target_device_spec=device_spec_b, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_c = iree_definitions.E2EModelRunConfig.build( + module_generation_config=iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(model_tf), + compile_config=iree_definitions.CompileConfig.build( + id="3", tags=[], compile_targets=[compile_target] + ), + ), + module_execution_config=exec_config_a, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_configs = [run_config_a, run_config_b, run_config_c] + root_dir = pathlib.Path("root") - suite = BenchmarkSuite.load_from_run_configs(run_configs=run_configs, - root_benchmark_dir=root_dir) + suite = BenchmarkSuite.load_from_run_configs( + run_configs=run_configs, root_benchmark_dir=root_dir + ) - loaded_run_configs = [case.run_config for case in suite.filter_benchmarks()] - self.assertEqual(loaded_run_configs, [ - run_config_a, - run_config_b, - run_config_c, - ]) - run_config_c_case_dir = pathlib.Path( - iree_artifacts.get_module_dir_path( - run_config_c.module_generation_config, root_dir)) - self.assertEqual( - suite.filter_benchmarks( - target_architectures=[ - common_definitions.DeviceArchitecture.RV32_GENERIC + loaded_run_configs = [case.run_config for case in suite.filter_benchmarks()] + self.assertEqual( + loaded_run_configs, + [ + run_config_a, + run_config_b, + run_config_c, + ], + ) + run_config_c_case_dir = pathlib.Path( + iree_artifacts.get_module_dir_path( + run_config_c.module_generation_config, root_dir + ) + ) + self.assertEqual( + suite.filter_benchmarks( + target_architectures=[ + common_definitions.DeviceArchitecture.RV32_GENERIC + ], + model_name_filter="model_tf.*fp32", + mode_filter="defaults", + ), + [ + BenchmarkCase( + model_name=model_tf.name, + model_tags=model_tf.tags, + bench_mode=exec_config_a.tags, + target_arch=common_definitions.DeviceArchitecture.RV32_GENERIC, + driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], + benchmark_tool_name="iree-benchmark-module", + benchmark_case_dir=run_config_c_case_dir, + run_config=run_config_c, + ) ], - model_name_filter="model_tf.*fp32", - mode_filter="defaults", - ), [ - BenchmarkCase( - model_name=model_tf.name, - model_tags=model_tf.tags, - bench_mode=exec_config_a.tags, - target_arch=common_definitions.DeviceArchitecture.RV32_GENERIC, - driver_info=IREE_DRIVERS_INFOS["iree-llvm-cpu-sync"], - benchmark_tool_name="iree-benchmark-module", - benchmark_case_dir=run_config_c_case_dir, - run_config=run_config_c) - ]) + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/common/benchmark_thresholds.py b/build_tools/benchmarks/common/benchmark_thresholds.py index f10afbd14d83..5d12cb934ea8 100644 --- a/build_tools/benchmarks/common/benchmark_thresholds.py +++ b/build_tools/benchmarks/common/benchmark_thresholds.py @@ -12,28 +12,29 @@ class ThresholdUnit(Enum): - PERCENTAGE = "%" # Percentage - VALUE_NS = "ns" # Absolute value in nanoseconds + PERCENTAGE = "%" # Percentage + VALUE_NS = "ns" # Absolute value in nanoseconds @dataclass class BenchmarkThreshold: - """Similarity threshold for benchmarks matching a regular expression.""" - # A regular expression to match against the benchmark identifier. - regex: re.Pattern - # A threshold for computing the benchmark value average. Benchmark sample - # values from consecutive runs and within the given range will be considered - # as similar (with some noise). They will be used to compute the moving - # average. The number will be interpreted according to the given unit. - # What value to set depends on the noise range of the particular benchmark. - threshold: int - unit: ThresholdUnit + """Similarity threshold for benchmarks matching a regular expression.""" - def get_threshold_str(self): - """Returns a string representation of the threshold.""" - if self.unit == ThresholdUnit.PERCENTAGE: - return f"{self.threshold}%" - return self.threshold + # A regular expression to match against the benchmark identifier. + regex: re.Pattern + # A threshold for computing the benchmark value average. Benchmark sample + # values from consecutive runs and within the given range will be considered + # as similar (with some noise). They will be used to compute the moving + # average. The number will be interpreted according to the given unit. + # What value to set depends on the noise range of the particular benchmark. + threshold: int + unit: ThresholdUnit + + def get_threshold_str(self): + """Returns a string representation of the threshold.""" + if self.unit == ThresholdUnit.PERCENTAGE: + return f"{self.threshold}%" + return self.threshold # A list of benchmarks and their similarity thresholds. @@ -41,63 +42,95 @@ def get_threshold_str(self): # match is used. BENCHMARK_THRESHOLDS = [ # Fluctuating benchmarks on ARM64 CPUs. - BenchmarkThreshold(re.compile(r"^DeepLabV3.*big-core.*LLVM-CPU.* @ Pixel"), - 20, ThresholdUnit.PERCENTAGE), - BenchmarkThreshold( - re.compile(r"^MobileBertSquad.*big-core.*LLVM-CPU-Sync @ Pixel-4"), 20, - ThresholdUnit.PERCENTAGE), - BenchmarkThreshold(re.compile(r"^MobileNetV2.*LLVM-CPU.* @ Pixel"), 15, - ThresholdUnit.PERCENTAGE), - BenchmarkThreshold(re.compile(r"^MobileNetV3Small.*LLVM-CPU.* @ Pixel"), 25, - ThresholdUnit.PERCENTAGE), - BenchmarkThreshold( - re.compile(r"^MobileSSD.*little-core.*LLVM-CPU.* @ Pixel-6"), 20, - ThresholdUnit.PERCENTAGE), - BenchmarkThreshold(re.compile(r"^PoseNet.*big-core.*LLVM-CPU.* @ Pixel"), - 15, ThresholdUnit.PERCENTAGE), - + BenchmarkThreshold( + re.compile(r"^DeepLabV3.*big-core.*LLVM-CPU.* @ Pixel"), + 20, + ThresholdUnit.PERCENTAGE, + ), + BenchmarkThreshold( + re.compile(r"^MobileBertSquad.*big-core.*LLVM-CPU-Sync @ Pixel-4"), + 20, + ThresholdUnit.PERCENTAGE, + ), + BenchmarkThreshold( + re.compile(r"^MobileNetV2.*LLVM-CPU.* @ Pixel"), 15, ThresholdUnit.PERCENTAGE + ), + BenchmarkThreshold( + re.compile(r"^MobileNetV3Small.*LLVM-CPU.* @ Pixel"), + 25, + ThresholdUnit.PERCENTAGE, + ), + BenchmarkThreshold( + re.compile(r"^MobileSSD.*little-core.*LLVM-CPU.* @ Pixel-6"), + 20, + ThresholdUnit.PERCENTAGE, + ), + BenchmarkThreshold( + re.compile(r"^PoseNet.*big-core.*LLVM-CPU.* @ Pixel"), + 15, + ThresholdUnit.PERCENTAGE, + ), # Benchmarks that complete <= 10ms on X86_64 CPUs; using percentage is not # suitable anymore. - BenchmarkThreshold(re.compile(r"^DeepLabV3_fp32.*x86_64"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^EfficientNet_int8.*x86_64"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^MobileNetV1_fp32.*x86_64"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^MobileNetV2_fp32.*x86_64"), 2 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^MobileNetV3Small_fp32.*x86_64"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^PersonDetect_int8.*x86_64"), 5 * 10**5, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^PoseNet_fp32.*x86_64"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - + BenchmarkThreshold( + re.compile(r"^DeepLabV3_fp32.*x86_64"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^EfficientNet_int8.*x86_64"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^MobileNetV1_fp32.*x86_64"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^MobileNetV2_fp32.*x86_64"), 2 * 10**6, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^MobileNetV3Small_fp32.*x86_64"), + 1 * 10**6, + ThresholdUnit.VALUE_NS, + ), + BenchmarkThreshold( + re.compile(r"^PersonDetect_int8.*x86_64"), 5 * 10**5, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^PoseNet_fp32.*x86_64"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), # Fluctuating benchmarks on mobile GPUs. BenchmarkThreshold( - re.compile(r"^MobileBertSquad.*int8.*full-inference.*GPU-Mali"), 10, - ThresholdUnit.PERCENTAGE), + re.compile(r"^MobileBertSquad.*int8.*full-inference.*GPU-Mali"), + 10, + ThresholdUnit.PERCENTAGE, + ), BenchmarkThreshold( - re.compile(r"^MobileBertSquad.*fp16.*full-inference.*GPU-Mali"), 10, - ThresholdUnit.PERCENTAGE), + re.compile(r"^MobileBertSquad.*fp16.*full-inference.*GPU-Mali"), + 10, + ThresholdUnit.PERCENTAGE, + ), BenchmarkThreshold( - re.compile(r"^MobileNetV3Small.*full-inference.*GPU-Mali"), 2 * 10**6, - ThresholdUnit.VALUE_NS), - + re.compile(r"^MobileNetV3Small.*full-inference.*GPU-Mali"), + 2 * 10**6, + ThresholdUnit.VALUE_NS, + ), # Benchmarks that complete <= 10ms on GPUs; using percentage is not # suitable anymore. - BenchmarkThreshold(re.compile(r"^DeepLabV3.*GPU-Mali"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^PersonDetect.*int8.*GPU-Mali"), 2 * 10**5, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^EfficientNet.*int8.*GPU-Mali"), 15 * 10**5, - ThresholdUnit.VALUE_NS), - BenchmarkThreshold(re.compile(r"^MobileNet.*GPU"), 1 * 10**6, - ThresholdUnit.VALUE_NS), - + BenchmarkThreshold( + re.compile(r"^DeepLabV3.*GPU-Mali"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), + BenchmarkThreshold( + re.compile(r"^PersonDetect.*int8.*GPU-Mali"), + 2 * 10**5, + ThresholdUnit.VALUE_NS, + ), + BenchmarkThreshold( + re.compile(r"^EfficientNet.*int8.*GPU-Mali"), + 15 * 10**5, + ThresholdUnit.VALUE_NS, + ), + BenchmarkThreshold( + re.compile(r"^MobileNet.*GPU"), 1 * 10**6, ThresholdUnit.VALUE_NS + ), # Default threshold for all ARM64/X86_64 benchmarks: 10%. - BenchmarkThreshold(re.compile(r".*CPU-ARM.*"), 10, - ThresholdUnit.PERCENTAGE), + BenchmarkThreshold(re.compile(r".*CPU-ARM.*"), 10, ThresholdUnit.PERCENTAGE), BenchmarkThreshold(re.compile(r".*x86_64.*"), 10, ThresholdUnit.PERCENTAGE), # Default threshold for all benchmarks: 5%. BenchmarkThreshold(re.compile(r".*"), 5, ThresholdUnit.PERCENTAGE), diff --git a/build_tools/benchmarks/common/common_arguments.py b/build_tools/benchmarks/common/common_arguments.py index 258265f6036b..fb38be90128b 100644 --- a/build_tools/benchmarks/common/common_arguments.py +++ b/build_tools/benchmarks/common/common_arguments.py @@ -13,160 +13,189 @@ def _check_dir_path(path): - path = pathlib.Path(path) - if path.is_dir(): - return path - else: - raise argparse.ArgumentTypeError(path) + path = pathlib.Path(path) + if path.is_dir(): + return path + else: + raise argparse.ArgumentTypeError(path) def _check_file_path(path): - path = pathlib.Path(path) - if path.is_file(): - return path - else: - raise argparse.ArgumentTypeError(f"'{path}' is not found") + path = pathlib.Path(path) + if path.is_file(): + return path + else: + raise argparse.ArgumentTypeError(f"'{path}' is not found") def _check_exe_path(path): - path = pathlib.Path(path) - if os.access(path, os.X_OK): - return path - else: - raise argparse.ArgumentTypeError(f"'{path}' is not an executable") + path = pathlib.Path(path) + if os.access(path, os.X_OK): + return path + else: + raise argparse.ArgumentTypeError(f"'{path}' is not an executable") class Parser(argparse.ArgumentParser): - """Argument parser that includes common arguments and does validation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.add_argument("--e2e_test_artifacts_dir", - metavar="", - type=_check_dir_path, - required=True, - help="Path to the IREE e2e test artifacts directory.") - - self.add_argument( - "--normal_benchmark_tool_dir", - "--normal-benchmark-tool-dir", - type=_check_dir_path, - default=None, - help="Path to the normal (non-tracing) iree tool directory") - self.add_argument("--traced_benchmark_tool_dir", - "--traced-benchmark-tool-dir", - type=_check_dir_path, - default=None, - help="Path to the tracing-enabled iree tool directory") - self.add_argument("--trace_capture_tool", - "--trace-capture-tool", - type=_check_exe_path, - default=None, - help="Path to the tool for collecting captured traces") - self.add_argument( - "--driver-filter-regex", - "--driver_filter_regex", - type=str, - default=None, - help="Only run benchmarks matching the given driver regex") - self.add_argument( - "--model-name-regex", - "--model_name_regex", - type=str, - default=None, - help="Only run benchmarks matching the given model name regex") - self.add_argument( - "--mode-regex", - "--mode_regex", - type=str, - default=None, - help="Only run benchmarks matching the given benchmarking mode regex") - self.add_argument("--output", - "-o", - default=None, - type=pathlib.Path, - help="Path to the output file") - self.add_argument("--capture_tarball", - "--capture-tarball", - default=None, - type=pathlib.Path, - help="Path to the tarball for captures") - self.add_argument("--no-clean", - action="store_true", - help="Do not clean up the temporary directory used for " - "benchmarking on the Android device") - self.add_argument("--verbose", - action="store_true", - help="Print internal information during execution") - self.add_argument( - "--pin-cpu-freq", - "--pin_cpu_freq", - action="store_true", - help="Pin CPU frequency for all cores to the maximum. Requires root") - self.add_argument("--pin-gpu-freq", - "--pin_gpu_freq", - action="store_true", - help="Pin GPU frequency to the maximum. Requires root") - self.add_argument( - "--keep_going", - "--keep-going", - action="store_true", - help="Continue running after a failed benchmark. The overall exit status" - " will still indicate failure and all errors will be reported at the end." - ) - self.add_argument( - "--tmp_dir", - "--tmp-dir", - "--tmpdir", - default=pathlib.Path("/tmp/iree-benchmarks"), - type=_check_dir_path, - help="Base directory in which to store temporary files. A subdirectory" - " with a name matching the git commit hash will be created.") - self.add_argument( - "--continue_from_previous", - "--continue-from-previous", - action="store_true", - help="Previous benchmark and capture results will be used and not " - "rerun if they are found in the benchmark results directory.") - self.add_argument( - "--benchmark_min_time", - "--benchmark-min-time", - default=0, - type=float, - help="If specified, this will be passed as --benchmark_min_time to the" - "iree-benchmark-module (minimum number of seconds to repeat running " - "for). In that case, no --benchmark_repetitions flag will be passed." - " If not specified, a --benchmark_repetitions will be passed " - "instead.") - self.add_argument( - "--compatible_only", - "--compatible-only", - action="store_true", - help="Only run compatible benchmarks based on the detected device " - "information") - self.add_argument("--execution_benchmark_config", - type=_check_file_path, - required=True, - help="JSON config for the execution benchmarks") - self.add_argument("--target_device_name", - type=str, - required=True, - help="Target device in benchmark config to run") + """Argument parser that includes common arguments and does validation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.add_argument( + "--e2e_test_artifacts_dir", + metavar="", + type=_check_dir_path, + required=True, + help="Path to the IREE e2e test artifacts directory.", + ) + + self.add_argument( + "--normal_benchmark_tool_dir", + "--normal-benchmark-tool-dir", + type=_check_dir_path, + default=None, + help="Path to the normal (non-tracing) iree tool directory", + ) + self.add_argument( + "--traced_benchmark_tool_dir", + "--traced-benchmark-tool-dir", + type=_check_dir_path, + default=None, + help="Path to the tracing-enabled iree tool directory", + ) + self.add_argument( + "--trace_capture_tool", + "--trace-capture-tool", + type=_check_exe_path, + default=None, + help="Path to the tool for collecting captured traces", + ) + self.add_argument( + "--driver-filter-regex", + "--driver_filter_regex", + type=str, + default=None, + help="Only run benchmarks matching the given driver regex", + ) + self.add_argument( + "--model-name-regex", + "--model_name_regex", + type=str, + default=None, + help="Only run benchmarks matching the given model name regex", + ) + self.add_argument( + "--mode-regex", + "--mode_regex", + type=str, + default=None, + help="Only run benchmarks matching the given benchmarking mode regex", + ) + self.add_argument( + "--output", + "-o", + default=None, + type=pathlib.Path, + help="Path to the output file", + ) + self.add_argument( + "--capture_tarball", + "--capture-tarball", + default=None, + type=pathlib.Path, + help="Path to the tarball for captures", + ) + self.add_argument( + "--no-clean", + action="store_true", + help="Do not clean up the temporary directory used for " + "benchmarking on the Android device", + ) + self.add_argument( + "--verbose", + action="store_true", + help="Print internal information during execution", + ) + self.add_argument( + "--pin-cpu-freq", + "--pin_cpu_freq", + action="store_true", + help="Pin CPU frequency for all cores to the maximum. Requires root", + ) + self.add_argument( + "--pin-gpu-freq", + "--pin_gpu_freq", + action="store_true", + help="Pin GPU frequency to the maximum. Requires root", + ) + self.add_argument( + "--keep_going", + "--keep-going", + action="store_true", + help="Continue running after a failed benchmark. The overall exit status" + " will still indicate failure and all errors will be reported at the end.", + ) + self.add_argument( + "--tmp_dir", + "--tmp-dir", + "--tmpdir", + default=pathlib.Path("/tmp/iree-benchmarks"), + type=_check_dir_path, + help="Base directory in which to store temporary files. A subdirectory" + " with a name matching the git commit hash will be created.", + ) + self.add_argument( + "--continue_from_previous", + "--continue-from-previous", + action="store_true", + help="Previous benchmark and capture results will be used and not " + "rerun if they are found in the benchmark results directory.", + ) + self.add_argument( + "--benchmark_min_time", + "--benchmark-min-time", + default=0, + type=float, + help="If specified, this will be passed as --benchmark_min_time to the" + "iree-benchmark-module (minimum number of seconds to repeat running " + "for). In that case, no --benchmark_repetitions flag will be passed." + " If not specified, a --benchmark_repetitions will be passed " + "instead.", + ) + self.add_argument( + "--compatible_only", + "--compatible-only", + action="store_true", + help="Only run compatible benchmarks based on the detected device " + "information", + ) + self.add_argument( + "--execution_benchmark_config", + type=_check_file_path, + required=True, + help="JSON config for the execution benchmarks", + ) + self.add_argument( + "--target_device_name", + type=str, + required=True, + help="Target device in benchmark config to run", + ) def expand_and_check_file_paths(paths: Sequence[str]) -> List[pathlib.Path]: - """Expands the wildcards in the paths and check if they are files. + """Expands the wildcards in the paths and check if they are files. Returns: List of expanded paths. - """ + """ - expanded_paths = [] - for path in paths: - expanded_paths += [pathlib.Path(path) for path in glob.glob(path)] + expanded_paths = [] + for path in paths: + expanded_paths += [pathlib.Path(path) for path in glob.glob(path)] - for path in expanded_paths: - if not path.is_file(): - raise ValueError(f"{path} is not a file.") + for path in expanded_paths: + if not path.is_file(): + raise ValueError(f"{path} is not a file.") - return expanded_paths + return expanded_paths diff --git a/build_tools/benchmarks/common/common_arguments_test.py b/build_tools/benchmarks/common/common_arguments_test.py index 1469261ecc2d..714a8b905cd8 100644 --- a/build_tools/benchmarks/common/common_arguments_test.py +++ b/build_tools/benchmarks/common/common_arguments_test.py @@ -13,65 +13,72 @@ class CommonArgumentsTest(unittest.TestCase): + def setUp(self): + self._build_dir_manager = tempfile.TemporaryDirectory() + self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve() + self.e2e_test_artifacts_dir = self.build_dir / "e2e_test_artifacts" + self.e2e_test_artifacts_dir.mkdir() + self.normal_tool_dir = self.build_dir / "normal_tool" + self.normal_tool_dir.mkdir() + self.traced_tool_dir = self.build_dir / "traced_tool" + self.traced_tool_dir.mkdir() + self.trace_capture_tool = self.build_dir / "tracy_capture" + # Create capture tool with executable file mode. + self.trace_capture_tool.touch(mode=0o755) + self.execution_config = self.build_dir / "execution_config.json" + self.execution_config.touch() - def setUp(self): - self._build_dir_manager = tempfile.TemporaryDirectory() - self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve() - self.e2e_test_artifacts_dir = self.build_dir / "e2e_test_artifacts" - self.e2e_test_artifacts_dir.mkdir() - self.normal_tool_dir = self.build_dir / "normal_tool" - self.normal_tool_dir.mkdir() - self.traced_tool_dir = self.build_dir / "traced_tool" - self.traced_tool_dir.mkdir() - self.trace_capture_tool = self.build_dir / "tracy_capture" - # Create capture tool with executable file mode. - self.trace_capture_tool.touch(mode=0o755) - self.execution_config = self.build_dir / "execution_config.json" - self.execution_config.touch() + def tearDown(self): + self._build_dir_manager.cleanup() - def tearDown(self): - self._build_dir_manager.cleanup() + def test_parser(self): + common.common_arguments.Parser().parse_args( + [ + f"--normal_benchmark_tool_dir={self.normal_tool_dir}", + f"--traced_benchmark_tool_dir={self.traced_tool_dir}", + f"--trace_capture_tool={self.trace_capture_tool}", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) - def test_parser(self): - common.common_arguments.Parser().parse_args([ - f"--normal_benchmark_tool_dir={self.normal_tool_dir}", - f"--traced_benchmark_tool_dir={self.traced_tool_dir}", - f"--trace_capture_tool={self.trace_capture_tool}", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) + def test_parser_check_normal_benchmark_tool(self): + arg_parser = common.common_arguments.Parser() + with self.assertRaises(SystemExit): + arg_parser.parse_args( + [ + "--normal_benchmark_tool_dir=nonexistent", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) - def test_parser_check_normal_benchmark_tool(self): - arg_parser = common.common_arguments.Parser() - with self.assertRaises(SystemExit): - arg_parser.parse_args([ - "--normal_benchmark_tool_dir=nonexistent", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) + def test_parser_check_traced_benchmark_tool(self): + arg_parser = common.common_arguments.Parser() + with self.assertRaises(SystemExit): + arg_parser.parse_args( + [ + "--traced_benchmark_tool_dir=nonexistent", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) - def test_parser_check_traced_benchmark_tool(self): - arg_parser = common.common_arguments.Parser() - with self.assertRaises(SystemExit): - arg_parser.parse_args([ - "--traced_benchmark_tool_dir=nonexistent", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) - - def test_parser_check_trace_capture_tool(self): - arg_parser = common.common_arguments.Parser() - with self.assertRaises(SystemExit): - arg_parser.parse_args([ - "--trace_capture_tool=nonexistent", - f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", - f"--execution_benchmark_config={self.execution_config}", - "--target_device=test", - ]) + def test_parser_check_trace_capture_tool(self): + arg_parser = common.common_arguments.Parser() + with self.assertRaises(SystemExit): + arg_parser.parse_args( + [ + "--trace_capture_tool=nonexistent", + f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}", + f"--execution_benchmark_config={self.execution_config}", + "--target_device=test", + ] + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/common/linux_device_utils.py b/build_tools/benchmarks/common/linux_device_utils.py index 9285782a9920..e72b576ef973 100644 --- a/build_tools/benchmarks/common/linux_device_utils.py +++ b/build_tools/benchmarks/common/linux_device_utils.py @@ -9,58 +9,65 @@ import re from typing import Optional, Sequence -from .benchmark_definition import (execute_cmd_and_get_stdout, DeviceInfo, - PlatformType) +from .benchmark_definition import execute_cmd_and_get_stdout, DeviceInfo, PlatformType def _get_lscpu_field(lscpu_output: str, field_name: str) -> str: - (value,) = re.findall(f"^{field_name}:\s*(.+)", lscpu_output, re.MULTILINE) - return value + (value,) = re.findall(f"^{field_name}:\s*(.+)", lscpu_output, re.MULTILINE) + return value def get_linux_cpu_arch(lscpu_output: str) -> str: - """Returns CPU Architecture, e.g., 'x86_64'.""" - return _get_lscpu_field(lscpu_output, "Architecture") + """Returns CPU Architecture, e.g., 'x86_64'.""" + return _get_lscpu_field(lscpu_output, "Architecture") def get_linux_cpu_features(lscpu_output: str) -> Sequence[str]: - """Returns CPU feature lists, e.g., ['mmx', 'fxsr', 'sse', 'sse2'].""" - return _get_lscpu_field(lscpu_output, "Flags").split(" ") + """Returns CPU feature lists, e.g., ['mmx', 'fxsr', 'sse', 'sse2'].""" + return _get_lscpu_field(lscpu_output, "Flags").split(" ") def canonicalize_gpu_name(gpu_name: str) -> str: - # Replace all consecutive non-word characters with a single hyphen. - return re.sub(r"\W+", "-", gpu_name) + # Replace all consecutive non-word characters with a single hyphen. + return re.sub(r"\W+", "-", gpu_name) -def get_linux_device_info(device_model: str = "Unknown", - cpu_uarch: Optional[str] = None, - gpu_id: str = "0", - verbose: bool = False) -> DeviceInfo: - """Returns device info for the Linux device. +def get_linux_device_info( + device_model: str = "Unknown", + cpu_uarch: Optional[str] = None, + gpu_id: str = "0", + verbose: bool = False, +) -> DeviceInfo: + """Returns device info for the Linux device. Args: - device_model: the device model name, e.g., 'ThinkStation P520' - cpu_uarch: the CPU microarchitecture, e.g., 'CascadeLake' - gpu_id: the target GPU ID, e.g., '0' or 'GPU-' - """ - lscpu_output = execute_cmd_and_get_stdout(["lscpu"], verbose) + """ + lscpu_output = execute_cmd_and_get_stdout(["lscpu"], verbose) - try: - gpu_name = execute_cmd_and_get_stdout([ - "nvidia-smi", "--query-gpu=name", "--format=csv,noheader", - f"--id={gpu_id}" - ], verbose) - except FileNotFoundError: - # Set GPU name to Unknown if the tool "nvidia-smi" doesn't exist. - gpu_name = "Unknown" + try: + gpu_name = execute_cmd_and_get_stdout( + [ + "nvidia-smi", + "--query-gpu=name", + "--format=csv,noheader", + f"--id={gpu_id}", + ], + verbose, + ) + except FileNotFoundError: + # Set GPU name to Unknown if the tool "nvidia-smi" doesn't exist. + gpu_name = "Unknown" - return DeviceInfo( - PlatformType.LINUX, - # Includes CPU model as it is the key factor of the device performance. - model=device_model, - # Currently we only have x86, so CPU ABI = CPU arch. - cpu_abi=get_linux_cpu_arch(lscpu_output), - cpu_uarch=cpu_uarch, - cpu_features=get_linux_cpu_features(lscpu_output), - gpu_name=canonicalize_gpu_name(gpu_name)) + return DeviceInfo( + PlatformType.LINUX, + # Includes CPU model as it is the key factor of the device performance. + model=device_model, + # Currently we only have x86, so CPU ABI = CPU arch. + cpu_abi=get_linux_cpu_arch(lscpu_output), + cpu_uarch=cpu_uarch, + cpu_features=get_linux_cpu_features(lscpu_output), + gpu_name=canonicalize_gpu_name(gpu_name), + ) diff --git a/build_tools/benchmarks/common/linux_device_utils_test.py b/build_tools/benchmarks/common/linux_device_utils_test.py index 80ec58d6696b..60e76c2e0651 100644 --- a/build_tools/benchmarks/common/linux_device_utils_test.py +++ b/build_tools/benchmarks/common/linux_device_utils_test.py @@ -10,26 +10,33 @@ from unittest import mock from common.benchmark_definition import DeviceInfo, PlatformType -from common.linux_device_utils import canonicalize_gpu_name, get_linux_cpu_arch, get_linux_cpu_features +from common.linux_device_utils import ( + canonicalize_gpu_name, + get_linux_cpu_arch, + get_linux_cpu_features, +) -LSCPU_OUTPUT = ("Architecture: x86_64\n" - "Vendor ID: AuthenticAMD\n" - "Flags: fpu vme de pse tsc\n") +LSCPU_OUTPUT = ( + "Architecture: x86_64\n" + "Vendor ID: AuthenticAMD\n" + "Flags: fpu vme de pse tsc\n" +) class LinuxDeviceUtilsTest(unittest.TestCase): + def test_get_linux_cpu_arch(self): + self.assertEqual(get_linux_cpu_arch(LSCPU_OUTPUT), "x86_64") - def test_get_linux_cpu_arch(self): - self.assertEqual(get_linux_cpu_arch(LSCPU_OUTPUT), "x86_64") + def test_get_linux_cpu_features(self): + self.assertEqual( + get_linux_cpu_features(LSCPU_OUTPUT), ["fpu", "vme", "de", "pse", "tsc"] + ) - def test_get_linux_cpu_features(self): - self.assertEqual(get_linux_cpu_features(LSCPU_OUTPUT), - ["fpu", "vme", "de", "pse", "tsc"]) - - def test_canonicalize_gpu_name(self): - self.assertEqual(canonicalize_gpu_name("Tesla V100-SXM2-16GB"), - "Tesla-V100-SXM2-16GB") + def test_canonicalize_gpu_name(self): + self.assertEqual( + canonicalize_gpu_name("Tesla V100-SXM2-16GB"), "Tesla-V100-SXM2-16GB" + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/comparisons/common/benchmark_command.py b/build_tools/benchmarks/comparisons/common/benchmark_command.py index bf3b78f774ae..a49cc69e76bd 100644 --- a/build_tools/benchmarks/comparisons/common/benchmark_command.py +++ b/build_tools/benchmarks/comparisons/common/benchmark_command.py @@ -11,125 +11,129 @@ class BenchmarkCommand(abc.ABC): - """Abstracts a benchmark command.""" - - def __init__(self, - benchmark_binary: str, - model_name: str, - num_threads: int, - num_runs: int, - driver: Optional[str] = None, - taskset: Optional[str] = None): - self.benchmark_binary = benchmark_binary - self.model_name = model_name - self.taskset = taskset - self.num_threads = num_threads - self.num_runs = num_runs - self.driver = driver - self.args = [] - - @property - @abc.abstractmethod - def runtime(self): - pass - - @abc.abstractmethod - def parse_latency_from_output(self, output: str) -> float: - pass - - def generate_benchmark_command(self) -> list[str]: - """Returns a list of strings that correspond to the command to be run.""" - command = [] - if self.taskset: - command.append("taskset") - command.append(str(self.taskset)) - command.append(self.benchmark_binary) - command.extend(self.args) - return command + """Abstracts a benchmark command.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + num_threads: int, + num_runs: int, + driver: Optional[str] = None, + taskset: Optional[str] = None, + ): + self.benchmark_binary = benchmark_binary + self.model_name = model_name + self.taskset = taskset + self.num_threads = num_threads + self.num_runs = num_runs + self.driver = driver + self.args = [] + + @property + @abc.abstractmethod + def runtime(self): + pass + + @abc.abstractmethod + def parse_latency_from_output(self, output: str) -> float: + pass + + def generate_benchmark_command(self) -> list[str]: + """Returns a list of strings that correspond to the command to be run.""" + command = [] + if self.taskset: + command.append("taskset") + command.append(str(self.taskset)) + command.append(self.benchmark_binary) + command.extend(self.args) + return command class TFLiteBenchmarkCommand(BenchmarkCommand): - """Represents a TFLite benchmark command.""" - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - num_threads: int, - num_runs: int, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - num_threads, - num_runs, - taskset=taskset) - self.args.append("--graph=" + model_path) - self._latency_large_regex = re.compile( - r".*?Inference \(avg\): (\d+.?\d*e\+?\d*).*") - self._latency_regex = re.compile(r".*?Inference \(avg\): (\d+).*") - - @property - def runtime(self): - return "tflite" - - def parse_latency_from_output(self, output: str) -> float: - # First match whether a large number has been recorded e.g. 1.18859e+06. - matches = self._latency_large_regex.search(output) - if not matches: - # Otherwise, regular number e.g. 71495.6. - matches = self._latency_regex.search(output) - - latency_ms = 0 - if matches: - latency_ms = float(matches.group(1)) / 1000 - else: - print("Warning! Could not parse latency. Defaulting to 0ms.") - return latency_ms - - def generate_benchmark_command(self) -> list[str]: - command = super().generate_benchmark_command() - if self.driver == "gpu": - command.append("--use_gpu=true") - command.append("--num_threads=" + str(self.num_threads)) - command.append("--num_runs=" + str(self.num_runs)) - return command + """Represents a TFLite benchmark command.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + num_threads: int, + num_runs: int, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, model_name, num_threads, num_runs, taskset=taskset + ) + self.args.append("--graph=" + model_path) + self._latency_large_regex = re.compile( + r".*?Inference \(avg\): (\d+.?\d*e\+?\d*).*" + ) + self._latency_regex = re.compile(r".*?Inference \(avg\): (\d+).*") + + @property + def runtime(self): + return "tflite" + + def parse_latency_from_output(self, output: str) -> float: + # First match whether a large number has been recorded e.g. 1.18859e+06. + matches = self._latency_large_regex.search(output) + if not matches: + # Otherwise, regular number e.g. 71495.6. + matches = self._latency_regex.search(output) + + latency_ms = 0 + if matches: + latency_ms = float(matches.group(1)) / 1000 + else: + print("Warning! Could not parse latency. Defaulting to 0ms.") + return latency_ms + + def generate_benchmark_command(self) -> list[str]: + command = super().generate_benchmark_command() + if self.driver == "gpu": + command.append("--use_gpu=true") + command.append("--num_threads=" + str(self.num_threads)) + command.append("--num_runs=" + str(self.num_runs)) + return command class IreeBenchmarkCommand(BenchmarkCommand): - """Represents an IREE benchmark command.""" - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - num_threads: int, - num_runs: int, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - num_threads, - num_runs, - taskset=taskset) - self.args.append("--module=" + model_path) - self._latency_regex = re.compile( - r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*") - - @property - def runtime(self): - return "iree" - - def parse_latency_from_output(self, output: str) -> float: - matches = self._latency_regex.search(output) - latency_ms = 0 - if matches: - latency_ms = float(matches.group(1)) - else: - print("Warning! Could not parse latency. Defaulting to 0ms.") - return latency_ms - - def generate_benchmark_command(self) -> list[str]: - command = super().generate_benchmark_command() - command.append("--device=" + self.driver) - command.append("--task_topology_max_group_count=" + str(self.num_threads)) - command.append("--benchmark_repetitions=" + str(self.num_runs)) - return command + """Represents an IREE benchmark command.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + num_threads: int, + num_runs: int, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, model_name, num_threads, num_runs, taskset=taskset + ) + self.args.append("--module=" + model_path) + self._latency_regex = re.compile( + r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*" + ) + + @property + def runtime(self): + return "iree" + + def parse_latency_from_output(self, output: str) -> float: + matches = self._latency_regex.search(output) + latency_ms = 0 + if matches: + latency_ms = float(matches.group(1)) + else: + print("Warning! Could not parse latency. Defaulting to 0ms.") + return latency_ms + + def generate_benchmark_command(self) -> list[str]: + command = super().generate_benchmark_command() + command.append("--device=" + self.driver) + command.append("--task_topology_max_group_count=" + str(self.num_threads)) + command.append("--benchmark_repetitions=" + str(self.num_runs)) + return command diff --git a/build_tools/benchmarks/comparisons/common/benchmark_command_factory.py b/build_tools/benchmarks/comparisons/common/benchmark_command_factory.py index f417d47288d2..d76ab5db4962 100644 --- a/build_tools/benchmarks/comparisons/common/benchmark_command_factory.py +++ b/build_tools/benchmarks/comparisons/common/benchmark_command_factory.py @@ -10,15 +10,16 @@ class BenchmarkCommandFactory(abc.ABC): - """ An abstract factory that generates commands depending on config. - Args: - device: Currently 'desktop' or 'mobile' are supported. - driver: Currently 'cpu' or 'gpu' are supported. - Returns: - An array containing `BenchmarkCommand` objects. - """ + """An abstract factory that generates commands depending on config. + Args: + device: Currently 'desktop' or 'mobile' are supported. + driver: Currently 'cpu' or 'gpu' are supported. + Returns: + An array containing `BenchmarkCommand` objects. + """ - @abc.abstractmethod - def generate_benchmark_commands(self, device: str, - driver: str) -> list[BenchmarkCommand]: - pass + @abc.abstractmethod + def generate_benchmark_commands( + self, device: str, driver: str + ) -> list[BenchmarkCommand]: + pass diff --git a/build_tools/benchmarks/comparisons/common/benchmark_runner.py b/build_tools/benchmarks/comparisons/common/benchmark_runner.py index a41f326ba120..3bc5db30da90 100644 --- a/build_tools/benchmarks/comparisons/common/benchmark_runner.py +++ b/build_tools/benchmarks/comparisons/common/benchmark_runner.py @@ -17,48 +17,51 @@ def run_command(benchmark_command: BenchmarkCommand) -> list[float]: - """Runs `benchmark_command` and polls for memory consumption statistics. - Args: - benchmark_command: A `BenchmarkCommand` object containing information on how to run the benchmark and parse the output. - Returns: - An array containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`] - """ - command = benchmark_command.generate_benchmark_command() - print("\n\nRunning command:\n" + " ".join(command)) - benchmark_process = subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) + """Runs `benchmark_command` and polls for memory consumption statistics. + Args: + benchmark_command: A `BenchmarkCommand` object containing information on how to run the benchmark and parse the output. + Returns: + An array containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`] + """ + command = benchmark_command.generate_benchmark_command() + print("\n\nRunning command:\n" + " ".join(command)) + benchmark_process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) - # Keep a record of the highest VmHWM corresponding VmRSS and RssFile values. - vmhwm = 0 - vmrss = 0 - rssfile = 0 - while benchmark_process.poll() is None: - pid_status = subprocess.run( - ["cat", "/proc/" + str(benchmark_process.pid) + "/status"], - capture_output=True) - output = pid_status.stdout.decode() - vmhwm_matches = _VMHWM_REGEX.search(output) - vmrss_matches = _VMRSS_REGEX.search(output) - rssfile_matches = _RSSFILE_REGEX.search(output) + # Keep a record of the highest VmHWM corresponding VmRSS and RssFile values. + vmhwm = 0 + vmrss = 0 + rssfile = 0 + while benchmark_process.poll() is None: + pid_status = subprocess.run( + ["cat", "/proc/" + str(benchmark_process.pid) + "/status"], + capture_output=True, + ) + output = pid_status.stdout.decode() + vmhwm_matches = _VMHWM_REGEX.search(output) + vmrss_matches = _VMRSS_REGEX.search(output) + rssfile_matches = _RSSFILE_REGEX.search(output) - if vmhwm_matches and vmrss_matches and rssfile_matches: - curr_vmhwm = float(vmhwm_matches.group(1)) - if curr_vmhwm > vmhwm: - vmhwm = curr_vmhwm - vmrss = float(vmrss_matches.group(1)) - rssfile = float(rssfile_matches.group(1)) + if vmhwm_matches and vmrss_matches and rssfile_matches: + curr_vmhwm = float(vmhwm_matches.group(1)) + if curr_vmhwm > vmhwm: + vmhwm = curr_vmhwm + vmrss = float(vmrss_matches.group(1)) + rssfile = float(rssfile_matches.group(1)) - time.sleep(0.5) + time.sleep(0.5) - stdout_data, _ = benchmark_process.communicate() + stdout_data, _ = benchmark_process.communicate() - if benchmark_process.returncode != 0: - print(f"Warning! Benchmark command failed with return code:" - f" {benchmark_process.returncode}") - return [0, 0, 0, 0] - else: - print(stdout_data.decode()) + if benchmark_process.returncode != 0: + print( + f"Warning! Benchmark command failed with return code:" + f" {benchmark_process.returncode}" + ) + return [0, 0, 0, 0] + else: + print(stdout_data.decode()) - latency_ms = benchmark_command.parse_latency_from_output(stdout_data.decode()) - return [latency_ms, vmhwm, vmrss, rssfile] + latency_ms = benchmark_command.parse_latency_from_output(stdout_data.decode()) + return [latency_ms, vmhwm, vmrss, rssfile] diff --git a/build_tools/benchmarks/comparisons/common/utils.py b/build_tools/benchmarks/comparisons/common/utils.py index 84b1f6fde9d6..43f92fe00a6e 100644 --- a/build_tools/benchmarks/comparisons/common/utils.py +++ b/build_tools/benchmarks/comparisons/common/utils.py @@ -6,8 +6,8 @@ def write_benchmark_result(result: list[str], save_path: str): - """Writes an array to file as a comma-separated line.""" - results_array = [str(i) for i in result] - print("Writing " + str(results_array)) - with open(save_path, "a") as f: - f.write(",".join(results_array) + "\n") + """Writes an array to file as a comma-separated line.""" + results_array = [str(i) for i in result] + print("Writing " + str(results_array)) + with open(save_path, "a") as f: + f.write(",".join(results_array) + "\n") diff --git a/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py b/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py index f8682a2cc379..b605e1abcad7 100644 --- a/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py +++ b/build_tools/benchmarks/comparisons/mobilebert_fp32_commands.py @@ -15,190 +15,232 @@ class TfliteMobilebertFP32(TFLiteBenchmarkCommand): - """ Specializes the benchmark command to use TFLite. """ - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - test_data_dir: str, - driver: str = "cpu", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - self.args.append("--input_layer=input_ids,input_mask,segment_ids") - self.args.append("--input_layer_value_files=input_ids:" + test_data_dir + - "/input_word_id.bin,input_mask:" + test_data_dir + - "/input_mask.bin,segment_ids:" + test_data_dir + - "/input_type_id.bin") - self.args.append("--input_layer_shape=1,384:1,384:1,384") + """Specializes the benchmark command to use TFLite.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + test_data_dir: str, + driver: str = "cpu", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + self.args.append("--input_layer=input_ids,input_mask,segment_ids") + self.args.append( + "--input_layer_value_files=input_ids:" + + test_data_dir + + "/input_word_id.bin,input_mask:" + + test_data_dir + + "/input_mask.bin,segment_ids:" + + test_data_dir + + "/input_type_id.bin" + ) + self.args.append("--input_layer_shape=1,384:1,384:1,384") class IreeMobilebertFP32(IreeBenchmarkCommand): - """ Specializes the benchmark command to use IREE. """ - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - driver: str = "local-task", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - self.args.append("--function=main") - self.args.append( - '--input="1x384xi32=101 2129 2116 19576 2015 2106 3854 4679 2486 1029 102 1996 14169 2165 2019 2220 2599 1999 3565 4605 2753 1998 2196 11145 1012 8446 2001 3132 2011 7573 1005 1055 3639 1010 2029 14159 2032 2698 2335 1998 3140 2032 2046 2093 20991 2015 1010 2164 1037 19576 2029 2027 6757 2005 1037 7921 1012 7573 15674 3854 4679 2001 2315 3565 4605 12041 1010 3405 2274 3948 10455 1010 1016 13714 14918 1010 1998 2048 3140 19576 2015 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) - self.args.append( - '--input="1x384xi32=0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) - self.args.append( - '--input="1x384xi32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) + """Specializes the benchmark command to use IREE.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + driver: str = "local-task", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + self.args.append("--function=main") + self.args.append( + '--input="1x384xi32=101 2129 2116 19576 2015 2106 3854 4679 2486 1029 102 1996 14169 2165 2019 2220 2599 1999 3565 4605 2753 1998 2196 11145 1012 8446 2001 3132 2011 7573 1005 1055 3639 1010 2029 14159 2032 2698 2335 1998 3140 2032 2046 2093 20991 2015 1010 2164 1037 19576 2029 2027 6757 2005 1037 7921 1012 7573 15674 3854 4679 2001 2315 3565 4605 12041 1010 3405 2274 3948 10455 1010 1016 13714 14918 1010 1998 2048 3140 19576 2015 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--input="1x384xi32=0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--input="1x384xi32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) class MobilebertFP32CommandFactory(BenchmarkCommandFactory): - """ Generates `BenchmarkCommand` objects specific to running MobileBert.""" - - def __init__(self, base_dir: str, model_name: str): - self._model_name = model_name - self._base_dir = base_dir - self._iree_benchmark_binary_path = os.path.join(base_dir, - "iree-benchmark-module") - self._tflite_benchmark_binary_path = os.path.join(base_dir, - "benchmark_model") - self._tflite_model_path = os.path.join(self._base_dir, "models", "tflite", - self._model_name + ".tflite") - self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data", - "squad") - - def generate_benchmark_commands(self, device: str, - driver: str) -> list[BenchmarkCommand]: - if device == "desktop" and driver == "cpu": - return self._generate_cpu(device) - elif device == "desktop" and driver == "gpu": - return self._generate_gpu("cuda") - elif device == "mobile" and driver == "cpu": - return self._generate_cpu(device) - elif device == "mobile" and driver == "gpu": - return self._generate_gpu("vulkan") - else: - print("Warning! Not a valid configuration.") - return [] - - def _generate_cpu(self, device: str): - # Generate TFLite benchmarks. - tflite_mobilebert = TfliteMobilebertFP32(self._tflite_benchmark_binary_path, - self._model_name, - self._tflite_model_path, - self._tflite_test_data_dir, - driver="cpu") - - tflite_mobilebert_noxnn = TfliteMobilebertFP32( - self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - self._tflite_model_path, - self._tflite_test_data_dir, - driver="cpu") - tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") - - # Generate IREE benchmarks. - driver = "local-task" - backend = "llvm-cpu" - iree_model_path = os.path.join(self._base_dir, "models", "iree", backend, - self._model_name + ".vmfb") - iree_mobilebert = IreeMobilebertFP32(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - driver=driver) - commands = [tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert] - - # Test mmt4d only on mobile. - if device == "mobile": - model_mmt4d_name = self._model_name + "_mmt4d" - iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree", - backend, model_mmt4d_name + ".vmfb") - iree_mmt4d_mobilebert = IreeMobilebertFP32( - self._iree_benchmark_binary_path, - model_mmt4d_name, - iree_mmt4d_model_path, - driver=driver) - commands.append(iree_mmt4d_mobilebert) - - model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" - iree_im2col_mmt4d_model_path = os.path.join( - self._base_dir, "models", "iree", backend, - model_im2col_mmt4d_name + ".vmfb") - iree_im2col_mmt4d_mobilebert = IreeMobilebertFP32( - self._iree_benchmark_binary_path, - model_im2col_mmt4d_name, - iree_im2col_mmt4d_model_path, - driver=driver) - commands.append(iree_im2col_mmt4d_mobilebert) - - return commands - - def _generate_gpu(self, driver: str): - tflite_mobilebert = TfliteMobilebertFP32(self._tflite_benchmark_binary_path, - self._model_name, - self._tflite_model_path, - self._tflite_test_data_dir, - driver="gpu") - tflite_mobilebert.args.append("--gpu_precision_loss_allowed=false") - - tflite_mobilebert_noxnn = TfliteMobilebertFP32( - self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - self._tflite_model_path, - self._tflite_test_data_dir, - driver="gpu") - tflite_mobilebert_noxnn.args.append("--gpu_precision_loss_allowed=false") - tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") - - tflite_mobilebert_fp16 = TfliteMobilebertFP32( - self._tflite_benchmark_binary_path, - self._model_name + "_fp16", - self._tflite_model_path, - self._tflite_test_data_dir, - driver="gpu") - tflite_mobilebert_fp16.args.append("--gpu_precision_loss_allowed=true") - - iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, - self._model_name + ".vmfb") - iree_mobilebert = IreeMobilebertFP32(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - driver=driver) - iree_fp16_model_path = os.path.join(self._base_dir, "models", "iree", - driver, self._model_name + "_fp16.vmfb") - iree_mobilebert_fp16 = IreeMobilebertFP32(self._iree_benchmark_binary_path, - self._model_name + "_fp16", - iree_fp16_model_path, - driver=driver) - iree_padfuse_model_path = os.path.join(self._base_dir, "models", "iree", - driver, - self._model_name + "_padfuse.vmfb") - iree_mobilebert_padfuse = IreeMobilebertFP32( - self._iree_benchmark_binary_path, - self._model_name + "_padfuse", - iree_padfuse_model_path, - driver=driver) - - return [ - tflite_mobilebert, tflite_mobilebert_noxnn, tflite_mobilebert_fp16, - iree_mobilebert, iree_mobilebert_fp16, iree_mobilebert_padfuse - ] + """Generates `BenchmarkCommand` objects specific to running MobileBert.""" + + def __init__(self, base_dir: str, model_name: str): + self._model_name = model_name + self._base_dir = base_dir + self._iree_benchmark_binary_path = os.path.join( + base_dir, "iree-benchmark-module" + ) + self._tflite_benchmark_binary_path = os.path.join(base_dir, "benchmark_model") + self._tflite_model_path = os.path.join( + self._base_dir, "models", "tflite", self._model_name + ".tflite" + ) + self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data", "squad") + + def generate_benchmark_commands( + self, device: str, driver: str + ) -> list[BenchmarkCommand]: + if device == "desktop" and driver == "cpu": + return self._generate_cpu(device) + elif device == "desktop" and driver == "gpu": + return self._generate_gpu("cuda") + elif device == "mobile" and driver == "cpu": + return self._generate_cpu(device) + elif device == "mobile" and driver == "gpu": + return self._generate_gpu("vulkan") + else: + print("Warning! Not a valid configuration.") + return [] + + def _generate_cpu(self, device: str): + # Generate TFLite benchmarks. + tflite_mobilebert = TfliteMobilebertFP32( + self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu", + ) + + tflite_mobilebert_noxnn = TfliteMobilebertFP32( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu", + ) + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + # Generate IREE benchmarks. + driver = "local-task" + backend = "llvm-cpu" + iree_model_path = os.path.join( + self._base_dir, "models", "iree", backend, self._model_name + ".vmfb" + ) + iree_mobilebert = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver, + ) + commands = [tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert] + + # Test mmt4d only on mobile. + if device == "mobile": + model_mmt4d_name = self._model_name + "_mmt4d" + iree_mmt4d_model_path = os.path.join( + self._base_dir, "models", "iree", backend, model_mmt4d_name + ".vmfb" + ) + iree_mmt4d_mobilebert = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + model_mmt4d_name, + iree_mmt4d_model_path, + driver=driver, + ) + commands.append(iree_mmt4d_mobilebert) + + model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" + iree_im2col_mmt4d_model_path = os.path.join( + self._base_dir, + "models", + "iree", + backend, + model_im2col_mmt4d_name + ".vmfb", + ) + iree_im2col_mmt4d_mobilebert = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + model_im2col_mmt4d_name, + iree_im2col_mmt4d_model_path, + driver=driver, + ) + commands.append(iree_im2col_mmt4d_mobilebert) + + return commands + + def _generate_gpu(self, driver: str): + tflite_mobilebert = TfliteMobilebertFP32( + self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu", + ) + tflite_mobilebert.args.append("--gpu_precision_loss_allowed=false") + + tflite_mobilebert_noxnn = TfliteMobilebertFP32( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu", + ) + tflite_mobilebert_noxnn.args.append("--gpu_precision_loss_allowed=false") + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + tflite_mobilebert_fp16 = TfliteMobilebertFP32( + self._tflite_benchmark_binary_path, + self._model_name + "_fp16", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu", + ) + tflite_mobilebert_fp16.args.append("--gpu_precision_loss_allowed=true") + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + ".vmfb" + ) + iree_mobilebert = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver, + ) + iree_fp16_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + "_fp16.vmfb" + ) + iree_mobilebert_fp16 = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + self._model_name + "_fp16", + iree_fp16_model_path, + driver=driver, + ) + iree_padfuse_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + "_padfuse.vmfb" + ) + iree_mobilebert_padfuse = IreeMobilebertFP32( + self._iree_benchmark_binary_path, + self._model_name + "_padfuse", + iree_padfuse_model_path, + driver=driver, + ) + + return [ + tflite_mobilebert, + tflite_mobilebert_noxnn, + tflite_mobilebert_fp16, + iree_mobilebert, + iree_mobilebert_fp16, + iree_mobilebert_padfuse, + ] diff --git a/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py b/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py index eb78e41fa712..309bffb563a6 100644 --- a/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py +++ b/build_tools/benchmarks/comparisons/mobilebert_int8_commands.py @@ -15,176 +15,212 @@ class TfliteMobilebertInt8(TFLiteBenchmarkCommand): - """ Specializes the benchmark command to use TFLite. """ - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - test_data_dir: str, - driver: str = "cpu", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - self.args.append("--input_layer=input_ids,segment_ids,input_mask") - self.args.append("--input_layer_value_files=input_ids:" + test_data_dir + - "/input_word_id.bin,segment_ids:" + test_data_dir + - "/input_type_id.bin,input_mask:" + test_data_dir + - "/input_mask.bin") - self.args.append("--input_layer_shape=1,384:1,384:1,384") + """Specializes the benchmark command to use TFLite.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + test_data_dir: str, + driver: str = "cpu", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + self.args.append("--input_layer=input_ids,segment_ids,input_mask") + self.args.append( + "--input_layer_value_files=input_ids:" + + test_data_dir + + "/input_word_id.bin,segment_ids:" + + test_data_dir + + "/input_type_id.bin,input_mask:" + + test_data_dir + + "/input_mask.bin" + ) + self.args.append("--input_layer_shape=1,384:1,384:1,384") class IreeMobilebertInt8(IreeBenchmarkCommand): - """ Specializes the benchmark command to use IREE. """ - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - driver: str = "local-task", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - self.args.append("--function=main") - self.args.append( - '--input="1x384xi32=101 2129 2116 19576 2015 2106 3854 4679 2486 1029 102 1996 14169 2165 2019 2220 2599 1999 3565 4605 2753 1998 2196 11145 1012 8446 2001 3132 2011 7573 1005 1055 3639 1010 2029 14159 2032 2698 2335 1998 3140 2032 2046 2093 20991 2015 1010 2164 1037 19576 2029 2027 6757 2005 1037 7921 1012 7573 15674 3854 4679 2001 2315 3565 4605 12041 1010 3405 2274 3948 10455 1010 1016 13714 14918 1010 1998 2048 3140 19576 2015 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) - self.args.append( - '--input="1x384xi32=0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) - self.args.append( - '--input="1x384xi32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' - ) + """Specializes the benchmark command to use IREE.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + driver: str = "local-task", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + self.args.append("--function=main") + self.args.append( + '--input="1x384xi32=101 2129 2116 19576 2015 2106 3854 4679 2486 1029 102 1996 14169 2165 2019 2220 2599 1999 3565 4605 2753 1998 2196 11145 1012 8446 2001 3132 2011 7573 1005 1055 3639 1010 2029 14159 2032 2698 2335 1998 3140 2032 2046 2093 20991 2015 1010 2164 1037 19576 2029 2027 6757 2005 1037 7921 1012 7573 15674 3854 4679 2001 2315 3565 4605 12041 1010 3405 2274 3948 10455 1010 1016 13714 14918 1010 1998 2048 3140 19576 2015 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--input="1x384xi32=0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) + self.args.append( + '--input="1x384xi32=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0"' + ) class MobilebertInt8CommandFactory(BenchmarkCommandFactory): - """ Generates `BenchmarkCommand` objects specific to running MobileBert.""" - - def __init__(self, base_dir: str): - self._model_name = "mobilebert-baseline-tf2-quant" - self._base_dir = base_dir - self._iree_benchmark_binary_path = os.path.join(base_dir, - "iree-benchmark-module") - self._tflite_benchmark_binary_path = os.path.join(base_dir, - "benchmark_model") - self._tflite_model_path = os.path.join(self._base_dir, "models", "tflite", - self._model_name + ".tflite") - self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data", - "squad") - - def generate_benchmark_commands(self, device: str, - driver: str) -> list[BenchmarkCommand]: - if device == "desktop" and driver == "cpu": - return self._generate_cpu(device) - elif device == "desktop" and driver == "gpu": - return self._generate_gpu("cuda") - elif device == "mobile" and driver == "cpu": - return self._generate_cpu(device) - elif device == "mobile" and driver == "gpu": - return self._generate_gpu("vulkan") - else: - print("Warning! Not a valid configuration.") - return [] - - def _generate_cpu(self, device: str): - # Generate TFLite benchmarks. - tflite_mobilebert = TfliteMobilebertInt8(self._tflite_benchmark_binary_path, - self._model_name, - self._tflite_model_path, - self._tflite_test_data_dir, - driver="cpu") - - tflite_mobilebert_noxnn = TfliteMobilebertInt8( - self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - self._tflite_model_path, - self._tflite_test_data_dir, - driver="cpu") - tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") - - # Generate IREE benchmarks. - driver = "local-task" - backend = "llvm-cpu" - iree_model_path = os.path.join(self._base_dir, "models", "iree", backend, - self._model_name + ".vmfb") - iree_mobilebert = IreeMobilebertInt8(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - driver=driver) - commands = [tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert] - - # Test mmt4d only on mobile. - if device == "mobile": - model_mmt4d_name = self._model_name + "_mmt4d" - iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree", - backend, model_mmt4d_name + ".vmfb") - iree_mmt4d_mobilebert = IreeMobilebertInt8( - self._iree_benchmark_binary_path, - model_mmt4d_name, - iree_mmt4d_model_path, - driver=driver) - commands.append(iree_mmt4d_mobilebert) - - model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" - iree_im2col_mmt4d_model_path = os.path.join( - self._base_dir, "models", "iree", backend, - model_im2col_mmt4d_name + ".vmfb") - iree_im2col_mmt4d_mobilebert = IreeMobilebertInt8( - self._iree_benchmark_binary_path, - model_im2col_mmt4d_name, - iree_im2col_mmt4d_model_path, - driver=driver) - commands.append(iree_im2col_mmt4d_mobilebert) - - return commands - - def _generate_gpu(self, driver: str): - tflite_mobilebert = TfliteMobilebertInt8(self._tflite_benchmark_binary_path, - self._model_name, - self._tflite_model_path, - self._tflite_test_data_dir, - driver="gpu") - tflite_mobilebert.args.append("--gpu_precision_loss_allowed=false") - - tflite_mobilebert_noxnn = TfliteMobilebertInt8( - self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - self._tflite_model_path, - self._tflite_test_data_dir, - driver="gpu") - tflite_mobilebert_noxnn.args.append("--gpu_precision_loss_allowed=false") - tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") - - iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, - self._model_name + ".vmfb") - iree_mobilebert = IreeMobilebertInt8(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - driver=driver) - - iree_padfuse_model_path = os.path.join(self._base_dir, "models", "iree", - driver, - self._model_name + "_padfuse.vmfb") - iree_padfuse_mobilebert = IreeMobilebertInt8( - self._iree_benchmark_binary_path, - self._model_name + "_padfuse", - iree_padfuse_model_path, - driver=driver) - return [ - tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert, - iree_padfuse_mobilebert - ] + """Generates `BenchmarkCommand` objects specific to running MobileBert.""" + + def __init__(self, base_dir: str): + self._model_name = "mobilebert-baseline-tf2-quant" + self._base_dir = base_dir + self._iree_benchmark_binary_path = os.path.join( + base_dir, "iree-benchmark-module" + ) + self._tflite_benchmark_binary_path = os.path.join(base_dir, "benchmark_model") + self._tflite_model_path = os.path.join( + self._base_dir, "models", "tflite", self._model_name + ".tflite" + ) + self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data", "squad") + + def generate_benchmark_commands( + self, device: str, driver: str + ) -> list[BenchmarkCommand]: + if device == "desktop" and driver == "cpu": + return self._generate_cpu(device) + elif device == "desktop" and driver == "gpu": + return self._generate_gpu("cuda") + elif device == "mobile" and driver == "cpu": + return self._generate_cpu(device) + elif device == "mobile" and driver == "gpu": + return self._generate_gpu("vulkan") + else: + print("Warning! Not a valid configuration.") + return [] + + def _generate_cpu(self, device: str): + # Generate TFLite benchmarks. + tflite_mobilebert = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu", + ) + + tflite_mobilebert_noxnn = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="cpu", + ) + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + # Generate IREE benchmarks. + driver = "local-task" + backend = "llvm-cpu" + iree_model_path = os.path.join( + self._base_dir, "models", "iree", backend, self._model_name + ".vmfb" + ) + iree_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver, + ) + commands = [tflite_mobilebert, tflite_mobilebert_noxnn, iree_mobilebert] + + # Test mmt4d only on mobile. + if device == "mobile": + model_mmt4d_name = self._model_name + "_mmt4d" + iree_mmt4d_model_path = os.path.join( + self._base_dir, "models", "iree", backend, model_mmt4d_name + ".vmfb" + ) + iree_mmt4d_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + model_mmt4d_name, + iree_mmt4d_model_path, + driver=driver, + ) + commands.append(iree_mmt4d_mobilebert) + + model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" + iree_im2col_mmt4d_model_path = os.path.join( + self._base_dir, + "models", + "iree", + backend, + model_im2col_mmt4d_name + ".vmfb", + ) + iree_im2col_mmt4d_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + model_im2col_mmt4d_name, + iree_im2col_mmt4d_model_path, + driver=driver, + ) + commands.append(iree_im2col_mmt4d_mobilebert) + + return commands + + def _generate_gpu(self, driver: str): + tflite_mobilebert = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name, + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu", + ) + tflite_mobilebert.args.append("--gpu_precision_loss_allowed=false") + + tflite_mobilebert_noxnn = TfliteMobilebertInt8( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + self._tflite_model_path, + self._tflite_test_data_dir, + driver="gpu", + ) + tflite_mobilebert_noxnn.args.append("--gpu_precision_loss_allowed=false") + tflite_mobilebert_noxnn.args.append("--use_xnnpack=false") + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + ".vmfb" + ) + iree_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + driver=driver, + ) + + iree_padfuse_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + "_padfuse.vmfb" + ) + iree_padfuse_mobilebert = IreeMobilebertInt8( + self._iree_benchmark_binary_path, + self._model_name + "_padfuse", + iree_padfuse_model_path, + driver=driver, + ) + return [ + tflite_mobilebert, + tflite_mobilebert_noxnn, + iree_mobilebert, + iree_padfuse_mobilebert, + ] diff --git a/build_tools/benchmarks/comparisons/run_benchmarks.py b/build_tools/benchmarks/comparisons/run_benchmarks.py index 00b0fb4050c7..24d1bb5de39d 100644 --- a/build_tools/benchmarks/comparisons/run_benchmarks.py +++ b/build_tools/benchmarks/comparisons/run_benchmarks.py @@ -31,178 +31,222 @@ from simple_commands import * -def benchmark_desktop_cpu(device_name: str, - command_factories: list[BenchmarkCommandFactory], - results_path: str): - benchmarks = [] - for factory in command_factories: - benchmarks.extend(factory.generate_benchmark_commands("desktop", "cpu")) - - for num_threads in [1, 2, 4, 8]: +def benchmark_desktop_cpu( + device_name: str, + command_factories: list[BenchmarkCommandFactory], + results_path: str, +): + benchmarks = [] + for factory in command_factories: + benchmarks.extend(factory.generate_benchmark_commands("desktop", "cpu")) + + for num_threads in [1, 2, 4, 8]: + for benchmark in benchmarks: + results_array = [ + device_name, + benchmark.model_name, + benchmark.runtime, + benchmark.driver, + num_threads, + ] + benchmark.num_threads = num_threads + results_array.extend(run_command(benchmark)) + write_benchmark_result(results_array, results_path) + + +def benchmark_desktop_gpu( + device_name: str, + command_factories: list[BenchmarkCommandFactory], + results_path: str, +): + benchmarks = [] + for factory in command_factories: + benchmarks.extend(factory.generate_benchmark_commands("desktop", "gpu")) for benchmark in benchmarks: - results_array = [ - device_name, benchmark.model_name, benchmark.runtime, - benchmark.driver, num_threads - ] - benchmark.num_threads = num_threads - results_array.extend(run_command(benchmark)) - write_benchmark_result(results_array, results_path) - - -def benchmark_desktop_gpu(device_name: str, - command_factories: list[BenchmarkCommandFactory], - results_path: str): - benchmarks = [] - for factory in command_factories: - benchmarks.extend(factory.generate_benchmark_commands("desktop", "gpu")) - for benchmark in benchmarks: - results_array = [ - device_name, benchmark.model_name, benchmark.runtime, benchmark.driver, - benchmark.num_threads - ] - results_array.extend(run_command(benchmark)) - write_benchmark_result(results_array, results_path) - - -def benchmark_mobile_cpu(device_name: str, - command_factories: list[BenchmarkCommandFactory], - results_path: str): - benchmarks = [] - for factory in command_factories: - benchmarks.extend(factory.generate_benchmark_commands("mobile", "cpu")) - - for _, tuple in enumerate([("80", 1), ("C0", 2), ("F0", 4), ("0F", 4), - ("FF", 8)]): - taskset = tuple[0] - num_threads = tuple[1] + results_array = [ + device_name, + benchmark.model_name, + benchmark.runtime, + benchmark.driver, + benchmark.num_threads, + ] + results_array.extend(run_command(benchmark)) + write_benchmark_result(results_array, results_path) + + +def benchmark_mobile_cpu( + device_name: str, + command_factories: list[BenchmarkCommandFactory], + results_path: str, +): + benchmarks = [] + for factory in command_factories: + benchmarks.extend(factory.generate_benchmark_commands("mobile", "cpu")) + + for _, tuple in enumerate([("80", 1), ("C0", 2), ("F0", 4), ("0F", 4), ("FF", 8)]): + taskset = tuple[0] + num_threads = tuple[1] + for benchmark in benchmarks: + results_array = [ + device_name, + benchmark.model_name, + benchmark.runtime, + benchmark.driver, + taskset, + num_threads, + ] + benchmark.taskset = taskset + benchmark.num_threads = num_threads + results_array.extend(run_command(benchmark)) + write_benchmark_result(results_array, results_path) + + +def benchmark_mobile_gpu( + device_name: str, + command_factories: list[BenchmarkCommandFactory], + results_path: str, +): + benchmarks = [] + for factory in command_factories: + benchmarks.extend(factory.generate_benchmark_commands("mobile", "gpu")) + + taskset = "80" + num_threads = 1 for benchmark in benchmarks: - results_array = [ - device_name, benchmark.model_name, benchmark.runtime, - benchmark.driver, taskset, num_threads - ] - benchmark.taskset = taskset - benchmark.num_threads = num_threads - results_array.extend(run_command(benchmark)) - write_benchmark_result(results_array, results_path) - - -def benchmark_mobile_gpu(device_name: str, - command_factories: list[BenchmarkCommandFactory], - results_path: str): - benchmarks = [] - for factory in command_factories: - benchmarks.extend(factory.generate_benchmark_commands("mobile", "gpu")) - - taskset = "80" - num_threads = 1 - for benchmark in benchmarks: - results_array = [ - device_name, benchmark.model_name, benchmark.runtime, benchmark.driver, - taskset, num_threads - ] - benchmark.taskset = taskset - benchmark.num_threads = num_threads - results_array.extend(run_command(benchmark)) - write_benchmark_result(results_array, results_path) + results_array = [ + device_name, + benchmark.model_name, + benchmark.runtime, + benchmark.driver, + taskset, + num_threads, + ] + benchmark.taskset = taskset + benchmark.num_threads = num_threads + results_array.extend(run_command(benchmark)) + write_benchmark_result(results_array, results_path) def main(args): - # Create factories for all models to be benchmarked. - command_factory = [] - command_factory.append( - MobilebertFP32CommandFactory(args.base_dir, "mobilebert_float_384_gpu")) - command_factory.append(MobilebertInt8CommandFactory(args.base_dir)) - command_factory.append( - MobilebertFP32CommandFactory(args.base_dir, "albert_lite_base_squadv1_1")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "mobilenet_v2_1.0_224", - "1x224x224x3xf32")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "mobilenet_v2_224_1.0_uint8", - "1x224x224x3xui8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "deeplabv3", "1x257x257x3xf32")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "person_detect", "1x96x96x1xi8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "ssd_mobilenet_v2_static_1.0_int8", - "1x320x320x3xi8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "resnet_v2_101_1_default_1", - "1x299x299x3xf32")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "ssd_mobilenet_v2_fpnlite_uint8", - "1x320x320x3xui8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "ssd_mobilenet_v2_fpnlite_fp32", - "1x320x320x3xf32")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "efficientnet_lite0_int8_2", - "1x224x224x3xui8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "efficientnet_lite0_fp32_2", - "1x224x224x3xf32")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "inception_v4_299_uint8", - "1x299x299x3xui8")) - command_factory.append( - SimpleCommandFactory(args.base_dir, "inception_v4_299_fp32", - "1x299x299x3xf32")) - - if args.mode == "desktop": - results_path = os.path.join(args.output_dir, "results.csv") - with open(results_path, "w") as f: - f.write( - "device,model,runtime,driver/delegate,threads,latency (ms),vmhwm (KB),vmrss (KB),rssfile (KB)\n" - ) - - if not args.disable_cpu: - benchmark_desktop_cpu(args.device_name, command_factory, results_path) - if not args.disable_gpu: - benchmark_desktop_gpu(args.device_name, command_factory, results_path) - else: - assert (args.mode == "mobile") - results_path = os.path.join(args.output_dir, "results.csv") - with open(results_path, "w") as f: - f.write( - "device,model,runtime,driver/delegate,taskset,threads,latency (ms),vmhwm (KB),vmrss (KB),rssfile (KB)\n" - ) - if not args.disable_cpu: - benchmark_mobile_cpu(args.device_name, command_factory, results_path) - if not args.disable_gpu: - benchmark_mobile_gpu(args.device_name, command_factory, results_path) + # Create factories for all models to be benchmarked. + command_factory = [] + command_factory.append( + MobilebertFP32CommandFactory(args.base_dir, "mobilebert_float_384_gpu") + ) + command_factory.append(MobilebertInt8CommandFactory(args.base_dir)) + command_factory.append( + MobilebertFP32CommandFactory(args.base_dir, "albert_lite_base_squadv1_1") + ) + command_factory.append( + SimpleCommandFactory(args.base_dir, "mobilenet_v2_1.0_224", "1x224x224x3xf32") + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "mobilenet_v2_224_1.0_uint8", "1x224x224x3xui8" + ) + ) + command_factory.append( + SimpleCommandFactory(args.base_dir, "deeplabv3", "1x257x257x3xf32") + ) + command_factory.append( + SimpleCommandFactory(args.base_dir, "person_detect", "1x96x96x1xi8") + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "ssd_mobilenet_v2_static_1.0_int8", "1x320x320x3xi8" + ) + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "resnet_v2_101_1_default_1", "1x299x299x3xf32" + ) + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "ssd_mobilenet_v2_fpnlite_uint8", "1x320x320x3xui8" + ) + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "ssd_mobilenet_v2_fpnlite_fp32", "1x320x320x3xf32" + ) + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "efficientnet_lite0_int8_2", "1x224x224x3xui8" + ) + ) + command_factory.append( + SimpleCommandFactory( + args.base_dir, "efficientnet_lite0_fp32_2", "1x224x224x3xf32" + ) + ) + command_factory.append( + SimpleCommandFactory(args.base_dir, "inception_v4_299_uint8", "1x299x299x3xui8") + ) + command_factory.append( + SimpleCommandFactory(args.base_dir, "inception_v4_299_fp32", "1x299x299x3xf32") + ) + + if args.mode == "desktop": + results_path = os.path.join(args.output_dir, "results.csv") + with open(results_path, "w") as f: + f.write( + "device,model,runtime,driver/delegate,threads,latency (ms),vmhwm (KB),vmrss (KB),rssfile (KB)\n" + ) + + if not args.disable_cpu: + benchmark_desktop_cpu(args.device_name, command_factory, results_path) + if not args.disable_gpu: + benchmark_desktop_gpu(args.device_name, command_factory, results_path) + else: + assert args.mode == "mobile" + results_path = os.path.join(args.output_dir, "results.csv") + with open(results_path, "w") as f: + f.write( + "device,model,runtime,driver/delegate,taskset,threads,latency (ms),vmhwm (KB),vmrss (KB),rssfile (KB)\n" + ) + if not args.disable_cpu: + benchmark_mobile_cpu(args.device_name, command_factory, results_path) + if not args.disable_gpu: + benchmark_mobile_gpu(args.device_name, command_factory, results_path) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--device_name", - type=str, - default=None, - help="The name of the device the benchmark is running on e.g. Pixel 6") - parser.add_argument( - "--base_dir", - type=str, - default=None, - help="The directory where all benchmarking artifacts are located.") - parser.add_argument("--output_dir", - type=str, - default=None, - help="The directory to save output artifacts into.") - parser.add_argument( - "--mode", - type=str, - choices=("desktop", "mobile"), - default="desktop", - help="The benchmarking mode to use. If mode is `mobile`, uses tasksets.") - parser.add_argument("--disable_cpu", - action="store_true", - help="Disables running benchmarks on CPU.") - parser.add_argument("--disable_gpu", - action="store_true", - help="Disables running benchmarks on GPU.") - return parser.parse_args() - - -if __name__ == '__main__': - main(parse_args()) + parser = argparse.ArgumentParser() + parser.add_argument( + "--device_name", + type=str, + default=None, + help="The name of the device the benchmark is running on e.g. Pixel 6", + ) + parser.add_argument( + "--base_dir", + type=str, + default=None, + help="The directory where all benchmarking artifacts are located.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The directory to save output artifacts into.", + ) + parser.add_argument( + "--mode", + type=str, + choices=("desktop", "mobile"), + default="desktop", + help="The benchmarking mode to use. If mode is `mobile`, uses tasksets.", + ) + parser.add_argument( + "--disable_cpu", action="store_true", help="Disables running benchmarks on CPU." + ) + parser.add_argument( + "--disable_gpu", action="store_true", help="Disables running benchmarks on GPU." + ) + return parser.parse_args() + + +if __name__ == "__main__": + main(parse_args()) diff --git a/build_tools/benchmarks/comparisons/simple_commands.py b/build_tools/benchmarks/comparisons/simple_commands.py index 2e49c3076cba..ef001bfeabec 100644 --- a/build_tools/benchmarks/comparisons/simple_commands.py +++ b/build_tools/benchmarks/comparisons/simple_commands.py @@ -15,217 +15,263 @@ class TfliteWrapper(TFLiteBenchmarkCommand): - """Specializes the benchmark command to use TFLite.""" - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - input_layer: Optional[str] = None, - input_shape: Optional[str] = None, - driver: str = "cpu", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - if input_layer and input_shape: - self.args.append("--input_layer=%s" % input_layer) - self.args.append("--input_layer_shape=%s" % input_shape) + """Specializes the benchmark command to use TFLite.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + input_layer: Optional[str] = None, + input_shape: Optional[str] = None, + driver: str = "cpu", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + if input_layer and input_shape: + self.args.append("--input_layer=%s" % input_layer) + self.args.append("--input_layer_shape=%s" % input_shape) class IreeWrapper(IreeBenchmarkCommand): - """Specializes the benchmark command to use IREE.""" - - def __init__(self, - benchmark_binary: str, - model_name: str, - model_path: str, - function_input: str, - driver: str = "local-task", - num_threads: int = _DEFAULT_NUM_THREADS, - num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, - taskset: Optional[str] = None): - super().__init__(benchmark_binary, - model_name, - model_path, - num_threads, - num_runs, - taskset=taskset) - self.driver = driver - self.args.append("--function=main") - self.args.append('--input="%s"' % function_input) + """Specializes the benchmark command to use IREE.""" + + def __init__( + self, + benchmark_binary: str, + model_name: str, + model_path: str, + function_input: str, + driver: str = "local-task", + num_threads: int = _DEFAULT_NUM_THREADS, + num_runs: int = _DEFAULT_NUM_BENCHMARK_RUNS, + taskset: Optional[str] = None, + ): + super().__init__( + benchmark_binary, + model_name, + model_path, + num_threads, + num_runs, + taskset=taskset, + ) + self.driver = driver + self.args.append("--function=main") + self.args.append('--input="%s"' % function_input) class SimpleCommandFactory(BenchmarkCommandFactory): - """ - Generates `BenchmarkCommand` objects specific to running series of simple models. - - A model is considered simple if its inputs can be generically generated based - on expected signature only without affecting behavior. - """ - - def __init__(self, - base_dir: str, - model_name: str, - function_input: str, - input_name: Optional[str] = None, - input_layer: Optional[str] = None): - self._model_name = model_name - self._function_input = function_input - self._input_name = input_name - self._input_layer = input_layer - self._base_dir = base_dir - self._iree_benchmark_binary_path = os.path.join(base_dir, - "iree-benchmark-module") - self._tflite_benchmark_binary_path = os.path.join(base_dir, - "benchmark_model") - # Required to be set, but no test data used yet. - self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data") - - def generate_benchmark_commands(self, device: str, - driver: str) -> list[BenchmarkCommand]: - if device == "desktop" and driver == "cpu": - return self._generate_cpu(device) - elif device == "desktop" and driver == "gpu": - return self._generate_gpu("cuda") - elif device == "mobile" and driver == "cpu": - return self._generate_cpu(device) - elif device == "mobile" and driver == "gpu": - return self._generate_gpu("vulkan") - else: - print("Warning! Not a valid configuration.") - return [] - - def _generate_cpu(self, device: str): - commands = [] - # Generate TFLite benchmarks. - tflite_model_path = os.path.join(self._base_dir, "models", "tflite", - self._model_name + ".tflite") - tflite = TfliteWrapper(self._tflite_benchmark_binary_path, - self._model_name, - tflite_model_path, - self._input_name, - driver="cpu") - commands.append(tflite) - - tflite_noxnn = TfliteWrapper(self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - tflite_model_path, - self._input_name, - driver="cpu") - tflite_noxnn.args.append("--use_xnnpack=false") - commands.append(tflite_noxnn) - - # Generate IREE benchmarks. - driver = "local-task" - backend = "llvm-cpu" - - iree_model_path = os.path.join(self._base_dir, "models", "iree", backend, - self._model_name + ".vmfb") - iree = IreeWrapper(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - self._function_input, - driver=driver) - commands.append(iree) - - model_padfuse_name = self._model_name + "_padfuse" - iree_padfuse_model_path = os.path.join(self._base_dir, "models", "iree", - backend, - model_padfuse_name + ".vmfb") - iree_padfuse = IreeWrapper(self._iree_benchmark_binary_path, - model_padfuse_name, - iree_padfuse_model_path, - self._function_input, - driver=driver) - commands.append(iree_padfuse) - - # Test mmt4d only on mobile. - if device == "mobile": - model_mmt4d_name = self._model_name + "_mmt4d" - iree_mmt4d_model_path = os.path.join(self._base_dir, "models", "iree", - backend, model_mmt4d_name + ".vmfb") - iree_mmt4d = IreeWrapper(self._iree_benchmark_binary_path, - model_mmt4d_name, - iree_mmt4d_model_path, - self._function_input, - driver=driver) - commands.append(iree_mmt4d) - - model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" - iree_im2col_mmt4d_model_path = os.path.join( - self._base_dir, "models", "iree", backend, - model_im2col_mmt4d_name + ".vmfb") - iree_im2col_mmt4d = IreeWrapper(self._iree_benchmark_binary_path, - model_im2col_mmt4d_name, - iree_im2col_mmt4d_model_path, - self._function_input, - driver=driver) - commands.append(iree_im2col_mmt4d) - - return commands - - def _generate_gpu(self, driver: str): - commands = [] - tflite_model_path = os.path.join(self._base_dir, "models", "tflite", - self._model_name + ".tflite") - tflite = TfliteWrapper(self._tflite_benchmark_binary_path, - self._model_name, - tflite_model_path, - self._input_name, - self._input_layer, - driver="gpu") - tflite.args.append("--gpu_precision_loss_allowed=false") - commands.append(tflite) - - tflite_noxnn = TfliteWrapper(self._tflite_benchmark_binary_path, - self._model_name + "_noxnn", - tflite_model_path, - self._input_name, - self._input_layer, - driver="gpu") - tflite.args.append("--use_xnnpack=false") - commands.append(tflite_noxnn) - - tflite_fp16 = TfliteWrapper(self._tflite_benchmark_binary_path, - self._model_name + "_fp16", - tflite_model_path, - self._input_name, - self._input_layer, - driver="gpu") - tflite.args.append("--gpu_precision_loss_allowed=true") - commands.append(tflite_fp16) - - iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, - self._model_name + ".vmfb") - iree = IreeWrapper(self._iree_benchmark_binary_path, - self._model_name, - iree_model_path, - self._function_input, - driver=driver) - commands.append(iree) - - iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, - self._model_name + "_fp16.vmfb") - iree = IreeWrapper(self._iree_benchmark_binary_path, - self._model_name + "_fp16", - iree_model_path, - self._function_input, - driver=driver) - commands.append(iree) - - iree_model_path = os.path.join(self._base_dir, "models", "iree", driver, - self._model_name + "_padfuse.vmfb") - iree = IreeWrapper(self._iree_benchmark_binary_path, - self._model_name + "_padfuse", - iree_model_path, - self._function_input, - driver=driver) - commands.append(iree) - return commands + """ + Generates `BenchmarkCommand` objects specific to running series of simple models. + + A model is considered simple if its inputs can be generically generated based + on expected signature only without affecting behavior. + """ + + def __init__( + self, + base_dir: str, + model_name: str, + function_input: str, + input_name: Optional[str] = None, + input_layer: Optional[str] = None, + ): + self._model_name = model_name + self._function_input = function_input + self._input_name = input_name + self._input_layer = input_layer + self._base_dir = base_dir + self._iree_benchmark_binary_path = os.path.join( + base_dir, "iree-benchmark-module" + ) + self._tflite_benchmark_binary_path = os.path.join(base_dir, "benchmark_model") + # Required to be set, but no test data used yet. + self._tflite_test_data_dir = os.path.join(self._base_dir, "test_data") + + def generate_benchmark_commands( + self, device: str, driver: str + ) -> list[BenchmarkCommand]: + if device == "desktop" and driver == "cpu": + return self._generate_cpu(device) + elif device == "desktop" and driver == "gpu": + return self._generate_gpu("cuda") + elif device == "mobile" and driver == "cpu": + return self._generate_cpu(device) + elif device == "mobile" and driver == "gpu": + return self._generate_gpu("vulkan") + else: + print("Warning! Not a valid configuration.") + return [] + + def _generate_cpu(self, device: str): + commands = [] + # Generate TFLite benchmarks. + tflite_model_path = os.path.join( + self._base_dir, "models", "tflite", self._model_name + ".tflite" + ) + tflite = TfliteWrapper( + self._tflite_benchmark_binary_path, + self._model_name, + tflite_model_path, + self._input_name, + driver="cpu", + ) + commands.append(tflite) + + tflite_noxnn = TfliteWrapper( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + tflite_model_path, + self._input_name, + driver="cpu", + ) + tflite_noxnn.args.append("--use_xnnpack=false") + commands.append(tflite_noxnn) + + # Generate IREE benchmarks. + driver = "local-task" + backend = "llvm-cpu" + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", backend, self._model_name + ".vmfb" + ) + iree = IreeWrapper( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree) + + model_padfuse_name = self._model_name + "_padfuse" + iree_padfuse_model_path = os.path.join( + self._base_dir, "models", "iree", backend, model_padfuse_name + ".vmfb" + ) + iree_padfuse = IreeWrapper( + self._iree_benchmark_binary_path, + model_padfuse_name, + iree_padfuse_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree_padfuse) + + # Test mmt4d only on mobile. + if device == "mobile": + model_mmt4d_name = self._model_name + "_mmt4d" + iree_mmt4d_model_path = os.path.join( + self._base_dir, "models", "iree", backend, model_mmt4d_name + ".vmfb" + ) + iree_mmt4d = IreeWrapper( + self._iree_benchmark_binary_path, + model_mmt4d_name, + iree_mmt4d_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree_mmt4d) + + model_im2col_mmt4d_name = self._model_name + "_im2col_mmt4d" + iree_im2col_mmt4d_model_path = os.path.join( + self._base_dir, + "models", + "iree", + backend, + model_im2col_mmt4d_name + ".vmfb", + ) + iree_im2col_mmt4d = IreeWrapper( + self._iree_benchmark_binary_path, + model_im2col_mmt4d_name, + iree_im2col_mmt4d_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree_im2col_mmt4d) + + return commands + + def _generate_gpu(self, driver: str): + commands = [] + tflite_model_path = os.path.join( + self._base_dir, "models", "tflite", self._model_name + ".tflite" + ) + tflite = TfliteWrapper( + self._tflite_benchmark_binary_path, + self._model_name, + tflite_model_path, + self._input_name, + self._input_layer, + driver="gpu", + ) + tflite.args.append("--gpu_precision_loss_allowed=false") + commands.append(tflite) + + tflite_noxnn = TfliteWrapper( + self._tflite_benchmark_binary_path, + self._model_name + "_noxnn", + tflite_model_path, + self._input_name, + self._input_layer, + driver="gpu", + ) + tflite.args.append("--use_xnnpack=false") + commands.append(tflite_noxnn) + + tflite_fp16 = TfliteWrapper( + self._tflite_benchmark_binary_path, + self._model_name + "_fp16", + tflite_model_path, + self._input_name, + self._input_layer, + driver="gpu", + ) + tflite.args.append("--gpu_precision_loss_allowed=true") + commands.append(tflite_fp16) + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + ".vmfb" + ) + iree = IreeWrapper( + self._iree_benchmark_binary_path, + self._model_name, + iree_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree) + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + "_fp16.vmfb" + ) + iree = IreeWrapper( + self._iree_benchmark_binary_path, + self._model_name + "_fp16", + iree_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree) + + iree_model_path = os.path.join( + self._base_dir, "models", "iree", driver, self._model_name + "_padfuse.vmfb" + ) + iree = IreeWrapper( + self._iree_benchmark_binary_path, + self._model_name + "_padfuse", + iree_model_path, + self._function_input, + driver=driver, + ) + commands.append(iree) + return commands diff --git a/build_tools/benchmarks/diff_local_benchmarks.py b/build_tools/benchmarks/diff_local_benchmarks.py index 6b458d769658..eeb43339a0b9 100755 --- a/build_tools/benchmarks/diff_local_benchmarks.py +++ b/build_tools/benchmarks/diff_local_benchmarks.py @@ -29,89 +29,100 @@ def get_benchmark_result_markdown( target_benchmark_file: Optional[pathlib.Path], base_compile_stats_file: Optional[pathlib.Path], target_compile_stats_file: Optional[pathlib.Path], - verbose: bool = False) -> str: - """Gets the full markdown summary of all benchmarks in files.""" - base_benchmarks = {} - target_benchmarks = {} - base_compilation_metrics = {} - target_compilation_metrics = {} - if base_benchmark_file and target_benchmark_file: - base_benchmarks = aggregate_all_benchmarks([base_benchmark_file]) - target_benchmarks = aggregate_all_benchmarks([target_benchmark_file]) - if base_compile_stats_file and target_compile_stats_file: - base_compilation_metrics = collect_all_compilation_metrics( - [base_compile_stats_file]) - target_compilation_metrics = collect_all_compilation_metrics( - [target_compile_stats_file]) - - # Update the target benchmarks with their corresponding base numbers. - for bench in base_benchmarks: - if bench in target_benchmarks: - target_benchmarks[bench].base_mean_time = base_benchmarks[bench].mean_time - - for target_name, base_metrics in base_compilation_metrics.items(): - updated_metrics = base_metrics - for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: - metric_key = mapper.get_series_name(target_name) - base_value, _ = mapper.get_current_and_base_value(base_metrics) - updated_metrics = mapper.update_base_value(updated_metrics, base_value) - target_compilation_metrics[target_name] = updated_metrics - - # Compose the full benchmark tables. - full_table = [md.header("Full Benchmark Summary", 2)] - full_table.append(categorize_benchmarks_into_tables(target_benchmarks)) - - # Compose the full compilation metrics tables. - full_table.append( - categorize_compilation_metrics_into_tables(target_compilation_metrics)) - - return "\n\n".join(full_table) + verbose: bool = False, +) -> str: + """Gets the full markdown summary of all benchmarks in files.""" + base_benchmarks = {} + target_benchmarks = {} + base_compilation_metrics = {} + target_compilation_metrics = {} + if base_benchmark_file and target_benchmark_file: + base_benchmarks = aggregate_all_benchmarks([base_benchmark_file]) + target_benchmarks = aggregate_all_benchmarks([target_benchmark_file]) + if base_compile_stats_file and target_compile_stats_file: + base_compilation_metrics = collect_all_compilation_metrics( + [base_compile_stats_file] + ) + target_compilation_metrics = collect_all_compilation_metrics( + [target_compile_stats_file] + ) + + # Update the target benchmarks with their corresponding base numbers. + for bench in base_benchmarks: + if bench in target_benchmarks: + target_benchmarks[bench].base_mean_time = base_benchmarks[bench].mean_time + + for target_name, base_metrics in base_compilation_metrics.items(): + updated_metrics = base_metrics + for mapper in COMPILATION_METRICS_TO_TABLE_MAPPERS: + metric_key = mapper.get_series_name(target_name) + base_value, _ = mapper.get_current_and_base_value(base_metrics) + updated_metrics = mapper.update_base_value(updated_metrics, base_value) + target_compilation_metrics[target_name] = updated_metrics + + # Compose the full benchmark tables. + full_table = [md.header("Full Benchmark Summary", 2)] + full_table.append(categorize_benchmarks_into_tables(target_benchmarks)) + + # Compose the full compilation metrics tables. + full_table.append( + categorize_compilation_metrics_into_tables(target_compilation_metrics) + ) + + return "\n\n".join(full_table) def parse_arguments(): - """Parses command-line options.""" - - def check_file_path(path): - path = pathlib.Path(path) - if path.is_file(): - return path - else: - raise ValueError(path) - - parser = argparse.ArgumentParser() - parser.add_argument("--base", - type=check_file_path, - help="Base benchmark results") - parser.add_argument("--target", - type=check_file_path, - help="Target benchmark results") - parser.add_argument("--base-compile-stats", - type=check_file_path, - help="Base compilation statistics") - parser.add_argument("--target-compile-stats", - type=check_file_path, - help="Target compilation statistics") - parser.add_argument("--verbose", - action="store_true", - help="Print internal information during execution") - args = parser.parse_args() - - return args + """Parses command-line options.""" + + def check_file_path(path): + path = pathlib.Path(path) + if path.is_file(): + return path + else: + raise ValueError(path) + + parser = argparse.ArgumentParser() + parser.add_argument("--base", type=check_file_path, help="Base benchmark results") + parser.add_argument( + "--target", type=check_file_path, help="Target benchmark results" + ) + parser.add_argument( + "--base-compile-stats", type=check_file_path, help="Base compilation statistics" + ) + parser.add_argument( + "--target-compile-stats", + type=check_file_path, + help="Target compilation statistics", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print internal information during execution", + ) + args = parser.parse_args() + + return args if __name__ == "__main__": - args = parse_arguments() - if args.base or args.target: - if not args.base or not args.target: - raise ValueError("--base and --target must be used together.") - if args.base_compile_stats or args.target_compile_stats: - if not args.base_compile_stats or not args.target_compile_stats: - raise ValueError("--base-compile-stats and --target-compile-stats must " - "be used together.") - - print( - get_benchmark_result_markdown(args.base, - args.target, - args.base_compile_stats, - args.target_compile_stats, - verbose=args.verbose)) + args = parse_arguments() + if args.base or args.target: + if not args.base or not args.target: + raise ValueError("--base and --target must be used together.") + if args.base_compile_stats or args.target_compile_stats: + if not args.base_compile_stats or not args.target_compile_stats: + raise ValueError( + "--base-compile-stats and --target-compile-stats must " + "be used together." + ) + + print( + get_benchmark_result_markdown( + args.base, + args.target, + args.base_compile_stats, + args.target_compile_stats, + verbose=args.verbose, + ) + ) diff --git a/build_tools/benchmarks/export_benchmark_config.py b/build_tools/benchmarks/export_benchmark_config.py index 2fbf27002120..1d0853b1dd45 100755 --- a/build_tools/benchmarks/export_benchmark_config.py +++ b/build_tools/benchmarks/export_benchmark_config.py @@ -47,172 +47,181 @@ PresetMatcher = Callable[[Any], bool] EXECUTION_BENCHMARK_PRESET_MATCHERS: Dict[str, PresetMatcher] = { - "x86_64": - lambda config: (benchmark_tags.X86_64 in config.tags and benchmark_tags. - LARGE not in config.tags), - "x86_64-large": - lambda config: (benchmark_tags.X86_64 in config.tags and benchmark_tags. - LARGE in config.tags), - "cuda": - lambda config: (benchmark_tags.CUDA in config.tags and benchmark_tags. - LARGE not in config.tags), - "cuda-large": - lambda config: (benchmark_tags.CUDA in config.tags and benchmark_tags. - LARGE in config.tags), - "vulkan-nvidia": - lambda config: benchmark_tags.VULKAN_NVIDIA in config.tags, - "android-cpu": - lambda config: - (config.target_device_spec.architecture.type == common_definitions. - ArchitectureType.CPU and config.target_device_spec.host_environment. - platform == "android"), - "android-gpu": - lambda config: - (config.target_device_spec.architecture.type == common_definitions. - ArchitectureType.GPU and config.target_device_spec.host_environment. - platform == "android"), + "x86_64": lambda config: ( + benchmark_tags.X86_64 in config.tags and benchmark_tags.LARGE not in config.tags + ), + "x86_64-large": lambda config: ( + benchmark_tags.X86_64 in config.tags and benchmark_tags.LARGE in config.tags + ), + "cuda": lambda config: ( + benchmark_tags.CUDA in config.tags and benchmark_tags.LARGE not in config.tags + ), + "cuda-large": lambda config: ( + benchmark_tags.CUDA in config.tags and benchmark_tags.LARGE in config.tags + ), + "vulkan-nvidia": lambda config: benchmark_tags.VULKAN_NVIDIA in config.tags, + "android-cpu": lambda config: ( + config.target_device_spec.architecture.type + == common_definitions.ArchitectureType.CPU + and config.target_device_spec.host_environment.platform == "android" + ), + "android-gpu": lambda config: ( + config.target_device_spec.architecture.type + == common_definitions.ArchitectureType.GPU + and config.target_device_spec.host_environment.platform == "android" + ), } COMPILATION_BENCHMARK_PRESET_MATCHERS: Dict[str, PresetMatcher] = { - "comp-stats": - lambda gen_config: benchmark_tags.LARGE not in gen_config.tags, - "comp-stats-large": - lambda gen_config: benchmark_tags.LARGE in gen_config.tags, + "comp-stats": lambda gen_config: benchmark_tags.LARGE not in gen_config.tags, + "comp-stats-large": lambda gen_config: benchmark_tags.LARGE in gen_config.tags, } def filter_and_group_run_configs( run_configs: List[iree_definitions.E2EModelRunConfig], target_device_names: Optional[Set[str]] = None, - preset_matchers: Optional[Sequence[PresetMatcher]] = None + preset_matchers: Optional[Sequence[PresetMatcher]] = None, ) -> Dict[str, List[iree_definitions.E2EModelRunConfig]]: - """Filters run configs and groups by target device name. - - Args: - run_configs: source e2e model run configs. - target_device_names: list of target device names, includes all if not set. - preset_matchers: list of preset matcher, matches all if not set. - - Returns: - A map of e2e model run configs keyed by target device name. - """ - grouped_run_config_map = collections.defaultdict(list) - - for run_config in run_configs: - device_name = run_config.target_device_spec.device_name - if (target_device_names is not None and - device_name not in target_device_names): - continue - if (preset_matchers is not None and - not any(matcher(run_config) for matcher in preset_matchers)): - continue - grouped_run_config_map[device_name].append(run_config) - - return grouped_run_config_map + """Filters run configs and groups by target device name. + + Args: + run_configs: source e2e model run configs. + target_device_names: list of target device names, includes all if not set. + preset_matchers: list of preset matcher, matches all if not set. + + Returns: + A map of e2e model run configs keyed by target device name. + """ + grouped_run_config_map = collections.defaultdict(list) + + for run_config in run_configs: + device_name = run_config.target_device_spec.device_name + if target_device_names is not None and device_name not in target_device_names: + continue + if preset_matchers is not None and not any( + matcher(run_config) for matcher in preset_matchers + ): + continue + grouped_run_config_map[device_name].append(run_config) + + return grouped_run_config_map def _get_distinct_module_dir_paths( - module_generation_configs: Iterable[ - iree_definitions.ModuleGenerationConfig], - root_path: pathlib.PurePath = pathlib.PurePath() + module_generation_configs: Iterable[iree_definitions.ModuleGenerationConfig], + root_path: pathlib.PurePath = pathlib.PurePath(), ) -> List[str]: - module_dir_paths = (str( - iree_artifacts.get_module_dir_path(config, root_path=root_path)) - for config in module_generation_configs) - return sorted(set(module_dir_paths)) + module_dir_paths = ( + str(iree_artifacts.get_module_dir_path(config, root_path=root_path)) + for config in module_generation_configs + ) + return sorted(set(module_dir_paths)) def _export_execution_handler( benchmark_presets: Optional[Sequence[PresetMatcher]] = None, target_device_names: Optional[Sequence[str]] = None, - **_unused_args): - _, all_run_configs = benchmark_collections.generate_benchmarks() - target_device_name_set = (None if target_device_names is None else - set(target_device_names)) - grouped_run_config_map = filter_and_group_run_configs( - all_run_configs, - target_device_names=target_device_name_set, - preset_matchers=benchmark_presets) - - output_map = {} - for device_name, run_configs in grouped_run_config_map.items(): - host_environments = set(run_config.target_device_spec.host_environment - for run_config in run_configs) - if len(host_environments) > 1: - raise ValueError( - "Device specs of the same device should have the same host environment." - ) - host_environment = host_environments.pop() - - distinct_module_dir_paths = _get_distinct_module_dir_paths( - config.module_generation_config for config in run_configs) - - output_map[device_name] = { - "host_environment": dataclasses.asdict(host_environment), - "module_dir_paths": distinct_module_dir_paths, - "run_configs": serialization.serialize_and_pack(run_configs), - } - - return output_map + **_unused_args, +): + _, all_run_configs = benchmark_collections.generate_benchmarks() + target_device_name_set = ( + None if target_device_names is None else set(target_device_names) + ) + grouped_run_config_map = filter_and_group_run_configs( + all_run_configs, + target_device_names=target_device_name_set, + preset_matchers=benchmark_presets, + ) + + output_map = {} + for device_name, run_configs in grouped_run_config_map.items(): + host_environments = set( + run_config.target_device_spec.host_environment for run_config in run_configs + ) + if len(host_environments) > 1: + raise ValueError( + "Device specs of the same device should have the same host environment." + ) + host_environment = host_environments.pop() + + distinct_module_dir_paths = _get_distinct_module_dir_paths( + config.module_generation_config for config in run_configs + ) + + output_map[device_name] = { + "host_environment": dataclasses.asdict(host_environment), + "module_dir_paths": distinct_module_dir_paths, + "run_configs": serialization.serialize_and_pack(run_configs), + } + + return output_map def _export_compilation_handler( - benchmark_presets: Optional[Sequence[PresetMatcher]] = None, - **_unused_args): - all_gen_configs, _ = benchmark_collections.generate_benchmarks() - compile_stats_gen_configs = [ - config for config in all_gen_configs - if benchmark_tags.COMPILE_STATS in config.compile_config.tags - ] - - if benchmark_presets is not None: - match_predicate = lambda gen_config: any( - matcher(gen_config) for matcher in benchmark_presets) + benchmark_presets: Optional[Sequence[PresetMatcher]] = None, **_unused_args +): + all_gen_configs, _ = benchmark_collections.generate_benchmarks() compile_stats_gen_configs = [ - gen_config for gen_config in compile_stats_gen_configs - if match_predicate(gen_config) + config + for config in all_gen_configs + if benchmark_tags.COMPILE_STATS in config.compile_config.tags ] - distinct_module_dir_paths = _get_distinct_module_dir_paths( - compile_stats_gen_configs) + if benchmark_presets is not None: + match_predicate = lambda gen_config: any( + matcher(gen_config) for matcher in benchmark_presets + ) + compile_stats_gen_configs = [ + gen_config + for gen_config in compile_stats_gen_configs + if match_predicate(gen_config) + ] - return { - "module_dir_paths": - distinct_module_dir_paths, - "generation_configs": - serialization.serialize_and_pack(compile_stats_gen_configs) - } + distinct_module_dir_paths = _get_distinct_module_dir_paths( + compile_stats_gen_configs + ) + + return { + "module_dir_paths": distinct_module_dir_paths, + "generation_configs": serialization.serialize_and_pack( + compile_stats_gen_configs + ), + } def _parse_and_strip_list_argument(arg: str) -> List[str]: - return [part.strip() for part in arg.split(",") if part != ""] + return [part.strip() for part in arg.split(",") if part != ""] def _parse_benchmark_presets( - arg: str, matcher_map: Dict[str, PresetMatcher]) -> List[PresetMatcher]: - matchers = [] - for preset in _parse_and_strip_list_argument(arg): - matcher = matcher_map.get(preset) - if matcher is None: - raise argparse.ArgumentTypeError( - f"Unrecognized benchmark preset: '{preset}'.") - matchers.append(matcher) - return matchers + arg: str, matcher_map: Dict[str, PresetMatcher] +) -> List[PresetMatcher]: + matchers = [] + for preset in _parse_and_strip_list_argument(arg): + matcher = matcher_map.get(preset) + if matcher is None: + raise argparse.ArgumentTypeError( + f"Unrecognized benchmark preset: '{preset}'." + ) + matchers.append(matcher) + return matchers def _parse_arguments(): - """Parses command-line options.""" - - # Makes global options come *after* command. - # See https://stackoverflow.com/q/23296695 - subparser_base = argparse.ArgumentParser(add_help=False) - subparser_base.add_argument("--output", - type=pathlib.Path, - help="Path to write the JSON output.") - - parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, - description=textwrap.dedent(""" + """Parses command-line options.""" + + # Makes global options come *after* command. + # See https://stackoverflow.com/q/23296695 + subparser_base = argparse.ArgumentParser(add_help=False) + subparser_base.add_argument( + "--output", type=pathlib.Path, help="Path to write the JSON output." + ) + + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description=textwrap.dedent( + """ Export type: "execution" outputs: [ : { @@ -231,53 +240,70 @@ def _parse_arguments(): } of generation configs defined for compilation statistics, to be used in build_tools/benchmarks/collect_compilation_statistics.py - """)) - - subparser = parser.add_subparsers(required=True, title="export type") - execution_parser = subparser.add_parser( - "execution", - parents=[subparser_base], - help="Export execution config to run benchmarks.") - execution_parser.set_defaults(handler=_export_execution_handler) - execution_parser.add_argument( - "--target_device_names", - type=_parse_and_strip_list_argument, - help=("Target device names, separated by comma, not specified means " - "including all devices.")) - execution_parser.add_argument( - "--benchmark_presets", - type=lambda arg: _parse_benchmark_presets( - arg, EXECUTION_BENCHMARK_PRESET_MATCHERS), - help=("Presets that select a bundle of benchmarks, separated by comma, " + """ + ), + ) + + subparser = parser.add_subparsers(required=True, title="export type") + execution_parser = subparser.add_parser( + "execution", + parents=[subparser_base], + help="Export execution config to run benchmarks.", + ) + execution_parser.set_defaults(handler=_export_execution_handler) + execution_parser.add_argument( + "--target_device_names", + type=_parse_and_strip_list_argument, + help=( + "Target device names, separated by comma, not specified means " + "including all devices." + ), + ) + execution_parser.add_argument( + "--benchmark_presets", + type=lambda arg: _parse_benchmark_presets( + arg, EXECUTION_BENCHMARK_PRESET_MATCHERS + ), + help=( + "Presets that select a bundle of benchmarks, separated by comma, " "multiple presets will be union. Available options: " - f"{','.join(EXECUTION_BENCHMARK_PRESET_MATCHERS.keys())}")) - - compilation_parser = subparser.add_parser( - "compilation", - parents=[subparser_base], - help=("Export serialized list of module generation configs defined for " - "compilation statistics.")) - compilation_parser.set_defaults(handler=_export_compilation_handler) - compilation_parser.add_argument( - "--benchmark_presets", - type=lambda arg: _parse_benchmark_presets( - arg, COMPILATION_BENCHMARK_PRESET_MATCHERS), - help=("Presets `comp-stats*` that select a bundle of compilation" + f"{','.join(EXECUTION_BENCHMARK_PRESET_MATCHERS.keys())}" + ), + ) + + compilation_parser = subparser.add_parser( + "compilation", + parents=[subparser_base], + help=( + "Export serialized list of module generation configs defined for " + "compilation statistics." + ), + ) + compilation_parser.set_defaults(handler=_export_compilation_handler) + compilation_parser.add_argument( + "--benchmark_presets", + type=lambda arg: _parse_benchmark_presets( + arg, COMPILATION_BENCHMARK_PRESET_MATCHERS + ), + help=( + "Presets `comp-stats*` that select a bundle of compilation" " benchmarks, separated by comma, multiple presets will be union." " Available options: " - f"{','.join(COMPILATION_BENCHMARK_PRESET_MATCHERS.keys())}")) + f"{','.join(COMPILATION_BENCHMARK_PRESET_MATCHERS.keys())}" + ), + ) - return parser.parse_args() + return parser.parse_args() def main(args: argparse.Namespace): - output_obj = args.handler(**vars(args)) - json_data = json.dumps(output_obj, indent=2) - if args.output is None: - print(json_data) - else: - args.output.write_text(json_data) + output_obj = args.handler(**vars(args)) + json_data = json.dumps(output_obj, indent=2) + if args.output is None: + print(json_data) + else: + args.output.write_text(json_data) if __name__ == "__main__": - main(_parse_arguments()) + main(_parse_arguments()) diff --git a/build_tools/benchmarks/export_benchmark_config_test.py b/build_tools/benchmarks/export_benchmark_config_test.py index 2ab342dc08db..b186301a806b 100644 --- a/build_tools/benchmarks/export_benchmark_config_test.py +++ b/build_tools/benchmarks/export_benchmark_config_test.py @@ -17,7 +17,8 @@ source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, source_url="", entry_function="predict", - input_types=["1xf32"]) + input_types=["1xf32"], +) COMMON_GEN_CONFIG = iree_definitions.ModuleGenerationConfig.build( imported_model=iree_definitions.ImportedModel.from_model(COMMON_MODEL), compile_config=iree_definitions.CompileConfig.build( @@ -26,221 +27,273 @@ compile_targets=[ iree_definitions.CompileTarget( target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture. - RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ])) + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ), +) COMMON_EXEC_CONFIG = iree_definitions.ModuleExecutionConfig.build( id="exec", tags=[], loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_SYNC) + driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, +) class ExportBenchmarkConfigTest(unittest.TestCase): + def test_filter_and_group_run_configs_set_all_filters(self): + device_spec_a = common_definitions.DeviceSpec.build( + id="dev_a_cpu", + device_name="dev_a_cpu", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_b = common_definitions.DeviceSpec.build( + id="dev_a_gpu", + device_name="dev_a_gpu", + architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_c = common_definitions.DeviceSpec.build( + id="dev_c", + device_name="dev_c", + architecture=common_definitions.DeviceArchitecture.CUDA_SM80, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + tags=[], + ) + matched_run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + unmatched_run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_b, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + matched_run_config_c = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_c, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + matchers = [ + ( + lambda config: config.target_device_spec.architecture.architecture + == "cuda" + ), + ( + lambda config: config.target_device_spec.host_environment.platform + == "android" + ), + ] - def test_filter_and_group_run_configs_set_all_filters(self): - device_spec_a = common_definitions.DeviceSpec.build( - id="dev_a_cpu", - device_name="dev_a_cpu", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_b = common_definitions.DeviceSpec.build( - id="dev_a_gpu", - device_name="dev_a_gpu", - architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_c = common_definitions.DeviceSpec.build( - id="dev_c", - device_name="dev_c", - architecture=common_definitions.DeviceArchitecture.CUDA_SM80, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - tags=[]) - matched_run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - unmatched_run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_b, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - matched_run_config_c = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_c, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - matchers = [(lambda config: config.target_device_spec.architecture. - architecture == "cuda"), - (lambda config: config.target_device_spec.host_environment. - platform == "android")] - - run_config_map = export_benchmark_config.filter_and_group_run_configs( - run_configs=[ - matched_run_config_a, unmatched_run_config_b, matched_run_config_c - ], - target_device_names={"dev_a_cpu", "dev_c"}, - preset_matchers=matchers) - - self.assertEqual(run_config_map, { - "dev_a_cpu": [matched_run_config_a], - "dev_c": [matched_run_config_c], - }) - - def test_filter_and_group_run_configs_include_all(self): - device_spec_a = common_definitions.DeviceSpec.build( - id="dev_a_cpu", - device_name="dev_a_cpu", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_b = common_definitions.DeviceSpec.build( - id="dev_a_gpu", - device_name="dev_a_gpu", - architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_c = common_definitions.DeviceSpec.build( - id="dev_a_second_gpu", - device_name="dev_a_gpu", - architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_b, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_c = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_c, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - - run_config_map = export_benchmark_config.filter_and_group_run_configs( - run_configs=[run_config_a, run_config_b, run_config_c]) - - self.maxDiff = 100000 - - self.assertEqual(run_config_map, { - "dev_a_cpu": [run_config_a], - "dev_a_gpu": [run_config_b, run_config_c], - }) - - def test_filter_and_group_run_configs_set_target_device_names(self): - device_spec_a = common_definitions.DeviceSpec.build( - id="dev_a", - device_name="dev_a", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_b = common_definitions.DeviceSpec.build( - id="dev_b", - device_name="dev_b", - architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=COMMON_GEN_CONFIG, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_b, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - - run_config_map = export_benchmark_config.filter_and_group_run_configs( - run_configs=[run_config_a, run_config_b], - target_device_names={"dev_a", "dev_b"}) - - self.assertEqual(run_config_map, { - "dev_a": [run_config_a], - "dev_b": [run_config_b], - }) - - def test_filter_and_group_run_configs_set_preset_matchers(self): - small_model = common_definitions.Model( - id="small_model", - name="small_model", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - big_model = common_definitions.Model( - id="big_model", - name="big_model", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="", - entry_function="predict", - input_types=["1xf32"]) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - compile_config = iree_definitions.CompileConfig.build( - id="1", tags=[], compile_targets=[compile_target]) - small_gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(small_model), - compile_config=compile_config) - big_gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=iree_definitions.ImportedModel.from_model(big_model), - compile_config=compile_config) - device_spec_a = common_definitions.DeviceSpec.build( - id="dev_a", - device_name="dev_a", - architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - device_spec_b = common_definitions.DeviceSpec.build( - id="dev_b", - device_name="dev_b", - architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - run_config_a = iree_definitions.E2EModelRunConfig.build( - module_generation_config=small_gen_config, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_a, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - run_config_b = iree_definitions.E2EModelRunConfig.build( - module_generation_config=big_gen_config, - module_execution_config=COMMON_EXEC_CONFIG, - target_device_spec=device_spec_b, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE) - - run_config_map = export_benchmark_config.filter_and_group_run_configs( - run_configs=[run_config_a, run_config_b], - preset_matchers=[ - lambda config: config.module_generation_config.imported_model.model. - id == "small_model" - ]) - - self.assertEqual(run_config_map, { - "dev_a": [run_config_a], - }) + run_config_map = export_benchmark_config.filter_and_group_run_configs( + run_configs=[ + matched_run_config_a, + unmatched_run_config_b, + matched_run_config_c, + ], + target_device_names={"dev_a_cpu", "dev_c"}, + preset_matchers=matchers, + ) + + self.assertEqual( + run_config_map, + { + "dev_a_cpu": [matched_run_config_a], + "dev_c": [matched_run_config_c], + }, + ) + + def test_filter_and_group_run_configs_include_all(self): + device_spec_a = common_definitions.DeviceSpec.build( + id="dev_a_cpu", + device_name="dev_a_cpu", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_b = common_definitions.DeviceSpec.build( + id="dev_a_gpu", + device_name="dev_a_gpu", + architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_c = common_definitions.DeviceSpec.build( + id="dev_a_second_gpu", + device_name="dev_a_gpu", + architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_b, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_c = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_c, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + + run_config_map = export_benchmark_config.filter_and_group_run_configs( + run_configs=[run_config_a, run_config_b, run_config_c] + ) + + self.maxDiff = 100000 + + self.assertEqual( + run_config_map, + { + "dev_a_cpu": [run_config_a], + "dev_a_gpu": [run_config_b, run_config_c], + }, + ) + + def test_filter_and_group_run_configs_set_target_device_names(self): + device_spec_a = common_definitions.DeviceSpec.build( + id="dev_a", + device_name="dev_a", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_b = common_definitions.DeviceSpec.build( + id="dev_b", + device_name="dev_b", + architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=COMMON_GEN_CONFIG, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_b, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + + run_config_map = export_benchmark_config.filter_and_group_run_configs( + run_configs=[run_config_a, run_config_b], + target_device_names={"dev_a", "dev_b"}, + ) + + self.assertEqual( + run_config_map, + { + "dev_a": [run_config_a], + "dev_b": [run_config_b], + }, + ) + + def test_filter_and_group_run_configs_set_preset_matchers(self): + small_model = common_definitions.Model( + id="small_model", + name="small_model", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + big_model = common_definitions.Model( + id="big_model", + name="big_model", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="", + entry_function="predict", + input_types=["1xf32"], + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + compile_config = iree_definitions.CompileConfig.build( + id="1", tags=[], compile_targets=[compile_target] + ) + small_gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(small_model), + compile_config=compile_config, + ) + big_gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=iree_definitions.ImportedModel.from_model(big_model), + compile_config=compile_config, + ) + device_spec_a = common_definitions.DeviceSpec.build( + id="dev_a", + device_name="dev_a", + architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + device_spec_b = common_definitions.DeviceSpec.build( + id="dev_b", + device_name="dev_b", + architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + run_config_a = iree_definitions.E2EModelRunConfig.build( + module_generation_config=small_gen_config, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_a, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + run_config_b = iree_definitions.E2EModelRunConfig.build( + module_generation_config=big_gen_config, + module_execution_config=COMMON_EXEC_CONFIG, + target_device_spec=device_spec_b, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, + ) + + run_config_map = export_benchmark_config.filter_and_group_run_configs( + run_configs=[run_config_a, run_config_b], + preset_matchers=[ + lambda config: config.module_generation_config.imported_model.model.id + == "small_model" + ], + ) + + self.assertEqual( + run_config_map, + { + "dev_a": [run_config_a], + }, + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/generate_benchmark_comment.py b/build_tools/benchmarks/generate_benchmark_comment.py index 4e2ab34810a0..92f326313a77 100755 --- a/build_tools/benchmarks/generate_benchmark_comment.py +++ b/build_tools/benchmarks/generate_benchmark_comment.py @@ -28,7 +28,7 @@ GITHUB_IREE_REPO_PREFIX = "https://github.com/openxla/iree" IREE_DASHBOARD_URL = "https://perf.iree.dev/apis/v2" -IREE_PROJECT_ID = 'IREE' +IREE_PROJECT_ID = "IREE" # The maximal numbers of trials when querying base commit benchmark results. MAX_BASE_COMMIT_QUERY_COUNT = 10 # The max number of rows to show per table. @@ -38,303 +38,340 @@ @dataclasses.dataclass(frozen=True) class CommentDef(object): - title: str - type_id: str + title: str + type_id: str # Map from comment type to comment definition. COMMENT_DEF_MAP = { - "android-benchmark-summary": - CommentDef(title="Abbreviated Android Benchmark Summary", - type_id="bf8cdf94-a992-466d-b11c-778cbd805a22"), - "linux-benchmark-summary": - CommentDef(title="Abbreviated Linux Benchmark Summary", - type_id="37549014-3c67-4e74-8d88-8e929231abe3"), - "benchmark-summary": - CommentDef(title="Abbreviated Benchmark Summary", - type_id="5b42cbfe-26a0-4164-a51c-07f06762e2dc") + "android-benchmark-summary": CommentDef( + title="Abbreviated Android Benchmark Summary", + type_id="bf8cdf94-a992-466d-b11c-778cbd805a22", + ), + "linux-benchmark-summary": CommentDef( + title="Abbreviated Linux Benchmark Summary", + type_id="37549014-3c67-4e74-8d88-8e929231abe3", + ), + "benchmark-summary": CommentDef( + title="Abbreviated Benchmark Summary", + type_id="5b42cbfe-26a0-4164-a51c-07f06762e2dc", + ), } def get_git_total_commit_count(commit: str, verbose: bool = False) -> int: - """Gets the total commit count in history ending with the given commit.""" - # TODO(#11703): Should use --first-parent here. See issue for the required - # work. - count = benchmark_definition.execute_cmd_and_get_stdout( - ['git', 'rev-list', '--count', commit], - cwd=THIS_DIRECTORY, - verbose=verbose) - return int(count) + """Gets the total commit count in history ending with the given commit.""" + # TODO(#11703): Should use --first-parent here. See issue for the required + # work. + count = benchmark_definition.execute_cmd_and_get_stdout( + ["git", "rev-list", "--count", commit], cwd=THIS_DIRECTORY, verbose=verbose + ) + return int(count) -def get_from_dashboard(url: str, - payload: Dict[str, Any], - verbose: bool = False) -> Dict[str, Dict[str, Any]]: - headers = {'Content-type': 'application/json'} - data = json.dumps(payload) +def get_from_dashboard( + url: str, payload: Dict[str, Any], verbose: bool = False +) -> Dict[str, Dict[str, Any]]: + headers = {"Content-type": "application/json"} + data = json.dumps(payload) - if verbose: - print(f'API request payload: {data}') + if verbose: + print(f"API request payload: {data}") - response = requests.get(url, data=data, headers=headers) - code = response.status_code - if code != 200: - raise requests.RequestException( - f'Failed to get from dashboard server with status code {code}') + response = requests.get(url, data=data, headers=headers) + code = response.status_code + if code != 200: + raise requests.RequestException( + f"Failed to get from dashboard server with status code {code}" + ) - data = response.json() - if verbose: - print(f'Queried base benchmark data: {data}') - return data + data = response.json() + if verbose: + print(f"Queried base benchmark data: {data}") + return data BenchmarkQueryResults = Dict[str, Dict[str, Any]] def query_base_benchmark_results( - commit: str, verbose: bool = False) -> BenchmarkQueryResults: - """Queries the benchmark results for the given commit.""" - build_id = get_git_total_commit_count(commit, verbose) - payload = {'projectId': IREE_PROJECT_ID, 'buildId': build_id} - return get_from_dashboard(f'{IREE_DASHBOARD_URL}/getBuild', - payload, - verbose=verbose) + commit: str, verbose: bool = False +) -> BenchmarkQueryResults: + """Queries the benchmark results for the given commit.""" + build_id = get_git_total_commit_count(commit, verbose) + payload = {"projectId": IREE_PROJECT_ID, "buildId": build_id} + return get_from_dashboard( + f"{IREE_DASHBOARD_URL}/getBuild", payload, verbose=verbose + ) @dataclasses.dataclass(frozen=True) class ComparableBenchmarkResults(object): - commit_sha: str - benchmark_results: BenchmarkQueryResults + commit_sha: str + benchmark_results: BenchmarkQueryResults def _find_comparable_benchmark_results( - start_commit: str, - required_benchmark_keys: Set[str], - verbose: bool = False) -> Optional[ComparableBenchmarkResults]: - cmds = [ - "git", "rev-list", "--first-parent", - f"--max-count={MAX_BASE_COMMIT_QUERY_COUNT}", start_commit - ] - output = benchmark_definition.execute_cmd_and_get_stdout(cmds, - cwd=THIS_DIRECTORY, - verbose=verbose) - previous_commits = output.splitlines() - # Try to query some base benchmark to diff against, from the top of the - # tree. Bail out if the maximal trial number is exceeded. - for base_commit in previous_commits: - base_benchmarks = query_base_benchmark_results(commit=base_commit, - verbose=verbose) - base_benchmark_keys = set(base_benchmarks.keys()) - if required_benchmark_keys <= base_benchmark_keys: - return ComparableBenchmarkResults(commit_sha=base_commit, - benchmark_results=base_benchmarks) - - return None + start_commit: str, required_benchmark_keys: Set[str], verbose: bool = False +) -> Optional[ComparableBenchmarkResults]: + cmds = [ + "git", + "rev-list", + "--first-parent", + f"--max-count={MAX_BASE_COMMIT_QUERY_COUNT}", + start_commit, + ] + output = benchmark_definition.execute_cmd_and_get_stdout( + cmds, cwd=THIS_DIRECTORY, verbose=verbose + ) + previous_commits = output.splitlines() + # Try to query some base benchmark to diff against, from the top of the + # tree. Bail out if the maximal trial number is exceeded. + for base_commit in previous_commits: + base_benchmarks = query_base_benchmark_results( + commit=base_commit, verbose=verbose + ) + base_benchmark_keys = set(base_benchmarks.keys()) + if required_benchmark_keys <= base_benchmark_keys: + return ComparableBenchmarkResults( + commit_sha=base_commit, benchmark_results=base_benchmarks + ) + + return None def _get_git_commit_hash(ref: str, verbose: bool = False) -> str: - """Gets the commit hash for the given commit.""" - return benchmark_definition.execute_cmd_and_get_stdout( - ['git', 'rev-parse', ref], cwd=THIS_DIRECTORY, verbose=verbose) + """Gets the commit hash for the given commit.""" + return benchmark_definition.execute_cmd_and_get_stdout( + ["git", "rev-parse", ref], cwd=THIS_DIRECTORY, verbose=verbose + ) -def _get_git_merge_base_commit(pr_commit: str, - target_branch: str, - verbose: bool = False) -> str: - return benchmark_definition.execute_cmd_and_get_stdout( - args=["git", "merge-base", target_branch, pr_commit], - cwd=THIS_DIRECTORY, - verbose=verbose) +def _get_git_merge_base_commit( + pr_commit: str, target_branch: str, verbose: bool = False +) -> str: + return benchmark_definition.execute_cmd_and_get_stdout( + args=["git", "merge-base", target_branch, pr_commit], + cwd=THIS_DIRECTORY, + verbose=verbose, + ) def _get_benchmark_result_markdown( - execution_benchmarks: Dict[ - str, benchmark_presentation.AggregateBenchmarkLatency], + execution_benchmarks: Dict[str, benchmark_presentation.AggregateBenchmarkLatency], compilation_metrics: Dict[str, benchmark_presentation.CompilationMetrics], - pr_url: str, build_url: str, comment_def: CommentDef, - commit_info_md: str) -> Tuple[str, str]: - """Gets the full/abbreviated markdown summary of all benchmarks in files.""" - - pr_info = md.link("Pull request", pr_url) - build_info = md.link("Build", build_url) - - # Compose the full benchmark tables. - full_table = [md.header("Full Benchmark Summary", 2)] - full_table.append(md.unordered_list([commit_info_md, pr_info, build_info])) - - # Compose the abbreviated benchmark tables. - abbr_table = [md.header(comment_def.title, 2)] - abbr_table.append(commit_info_md) - - if len(execution_benchmarks) > 0: - full_table.append( - benchmark_presentation.categorize_benchmarks_into_tables( - execution_benchmarks)) - - abbr_benchmarks_tables = benchmark_presentation.categorize_benchmarks_into_tables( - execution_benchmarks, TABLE_SIZE_CUT) - if len(abbr_benchmarks_tables) == 0: - abbr_table.append("No improved or regressed benchmarks 🏖️") - else: - abbr_table.append(abbr_benchmarks_tables) - - # Compose the full compilation metrics tables. - if len(compilation_metrics) > 0: - full_table.append( - benchmark_presentation.categorize_compilation_metrics_into_tables( - compilation_metrics)) - - abbr_compilation_metrics_tables = benchmark_presentation.categorize_compilation_metrics_into_tables( - compilation_metrics, TABLE_SIZE_CUT) - if len(abbr_compilation_metrics_tables) == 0: - abbr_table.append("No improved or regressed compilation metrics 🏖️") - else: - abbr_table.append(abbr_compilation_metrics_tables) - - abbr_table.append("For more information:") - # We don't know until a Gist is really created. Use a placeholder for now and - # replace later. - full_result_info = md.link("Full benchmark result tables", - benchmark_comment.GIST_LINK_PLACEHORDER) - abbr_table.append(md.unordered_list([full_result_info, build_info])) - - # Append the unique comment type id to help identify and update the existing - # comment. - abbr_table.append(f"") - - return "\n\n".join(full_table), "\n\n".join(abbr_table) + pr_url: str, + build_url: str, + comment_def: CommentDef, + commit_info_md: str, +) -> Tuple[str, str]: + """Gets the full/abbreviated markdown summary of all benchmarks in files.""" + + pr_info = md.link("Pull request", pr_url) + build_info = md.link("Build", build_url) + + # Compose the full benchmark tables. + full_table = [md.header("Full Benchmark Summary", 2)] + full_table.append(md.unordered_list([commit_info_md, pr_info, build_info])) + + # Compose the abbreviated benchmark tables. + abbr_table = [md.header(comment_def.title, 2)] + abbr_table.append(commit_info_md) + + if len(execution_benchmarks) > 0: + full_table.append( + benchmark_presentation.categorize_benchmarks_into_tables( + execution_benchmarks + ) + ) + + abbr_benchmarks_tables = ( + benchmark_presentation.categorize_benchmarks_into_tables( + execution_benchmarks, TABLE_SIZE_CUT + ) + ) + if len(abbr_benchmarks_tables) == 0: + abbr_table.append("No improved or regressed benchmarks 🏖️") + else: + abbr_table.append(abbr_benchmarks_tables) + + # Compose the full compilation metrics tables. + if len(compilation_metrics) > 0: + full_table.append( + benchmark_presentation.categorize_compilation_metrics_into_tables( + compilation_metrics + ) + ) + + abbr_compilation_metrics_tables = ( + benchmark_presentation.categorize_compilation_metrics_into_tables( + compilation_metrics, TABLE_SIZE_CUT + ) + ) + if len(abbr_compilation_metrics_tables) == 0: + abbr_table.append("No improved or regressed compilation metrics 🏖️") + else: + abbr_table.append(abbr_compilation_metrics_tables) + + abbr_table.append("For more information:") + # We don't know until a Gist is really created. Use a placeholder for now and + # replace later. + full_result_info = md.link( + "Full benchmark result tables", benchmark_comment.GIST_LINK_PLACEHORDER + ) + abbr_table.append(md.unordered_list([full_result_info, build_info])) + + # Append the unique comment type id to help identify and update the existing + # comment. + abbr_table.append(f"") + + return "\n\n".join(full_table), "\n\n".join(abbr_table) def parse_arguments(): - """Parses command-line options.""" - - parser = argparse.ArgumentParser() - parser.add_argument( - "--benchmark_files", - metavar="", - default=[], - action="append", - help=("Paths to the JSON files containing benchmark results, " - "accepts wildcards")) - parser.add_argument( - "--compile_stats_files", - metavar="", - default=[], - action="append", - help=("Paths to the JSON files containing compilation statistics, " - "accepts wildcards")) - parser.add_argument("--pr_number", required=True, type=int, help="PR number") - parser.add_argument("--pr_committish", - type=str, - default="HEAD", - help="PR commit hash or ref") - parser.add_argument("--pr_base_branch", - type=str, - default=None, - help="Base branch to merge the PR.") - parser.add_argument("--comment_type", - required=True, - choices=COMMENT_DEF_MAP.keys(), - help="Type of summary comment") - parser.add_argument("--build_url", - required=True, - type=str, - help="CI build page url to show in the report") - parser.add_argument("--output", type=pathlib.Path, default=None) - parser.add_argument("--verbose", - action="store_true", - help="Print internal information during execution") - - return parser.parse_args() + """Parses command-line options.""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark_files", + metavar="", + default=[], + action="append", + help=( + "Paths to the JSON files containing benchmark results, " "accepts wildcards" + ), + ) + parser.add_argument( + "--compile_stats_files", + metavar="", + default=[], + action="append", + help=( + "Paths to the JSON files containing compilation statistics, " + "accepts wildcards" + ), + ) + parser.add_argument("--pr_number", required=True, type=int, help="PR number") + parser.add_argument( + "--pr_committish", type=str, default="HEAD", help="PR commit hash or ref" + ) + parser.add_argument( + "--pr_base_branch", type=str, default=None, help="Base branch to merge the PR." + ) + parser.add_argument( + "--comment_type", + required=True, + choices=COMMENT_DEF_MAP.keys(), + help="Type of summary comment", + ) + parser.add_argument( + "--build_url", + required=True, + type=str, + help="CI build page url to show in the report", + ) + parser.add_argument("--output", type=pathlib.Path, default=None) + parser.add_argument( + "--verbose", + action="store_true", + help="Print internal information during execution", + ) + + return parser.parse_args() def main(args): - benchmark_files = common_arguments.expand_and_check_file_paths( - args.benchmark_files) - compile_stats_files = common_arguments.expand_and_check_file_paths( - args.compile_stats_files) - - pr_commit = _get_git_commit_hash(ref=args.pr_committish, verbose=args.verbose) - execution_benchmarks = benchmark_presentation.aggregate_all_benchmarks( - benchmark_files=benchmark_files, expected_pr_commit=pr_commit) - compilation_metrics = benchmark_presentation.collect_all_compilation_metrics( - compile_stats_files=compile_stats_files, expected_pr_commit=pr_commit) - - if args.pr_base_branch is None: - pr_base_commit = None - else: - pr_base_commit = _get_git_merge_base_commit( - pr_commit=pr_commit, - target_branch=args.pr_base_branch, - verbose=args.verbose) - - if pr_base_commit is None: - comparable_results = None - else: - required_benchmark_keys = set(execution_benchmarks.keys()) - for target_id in compilation_metrics: - for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: - required_benchmark_keys.add(mapper.get_series_id(target_id)) - - comparable_results = _find_comparable_benchmark_results( - start_commit=pr_base_commit, - required_benchmark_keys=required_benchmark_keys, - verbose=args.verbose) - - if comparable_results is None: - comparable_commit = None - else: - comparable_commit = comparable_results.commit_sha - # Update the execution benchmarks with base numbers. - for bench in execution_benchmarks: - base_benchmark = comparable_results.benchmark_results[bench] - if base_benchmark["sampleUnit"] != "ns": - raise ValueError("Only support nanoseconds for latency sample.") - execution_benchmarks[bench].base_mean_time = base_benchmark["sample"] - - # Update the compilation metrics with base numbers. - for target_id, metrics in compilation_metrics.items(): - updated_metrics = metrics - for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: - base_benchmark = comparable_results.benchmark_results[ - mapper.get_series_id(target_id)] - if base_benchmark["sampleUnit"] != mapper.get_unit(): - raise ValueError("Unit of the queried sample is mismatched.") - updated_metrics = mapper.update_base_value(updated_metrics, - base_benchmark["sample"]) - compilation_metrics[target_id] = updated_metrics - - pr_commit_link = md.link(pr_commit, - f"{GITHUB_IREE_REPO_PREFIX}/commit/{pr_commit}") - commit_info_md = f"@ commit {pr_commit_link}" - if comparable_commit is not None: - baseline_commit_link = md.link( - comparable_commit, - f"{GITHUB_IREE_REPO_PREFIX}/commit/{comparable_commit}") - commit_info_md += f" (vs. base {baseline_commit_link})" - elif pr_base_commit is not None: - commit_info_md += " (no previous benchmark results to compare)" - - comment_def = COMMENT_DEF_MAP[args.comment_type] - full_md, abbr_md = _get_benchmark_result_markdown( - execution_benchmarks=execution_benchmarks, - compilation_metrics=compilation_metrics, - pr_url=f"{GITHUB_IREE_REPO_PREFIX}/pull/{args.pr_number}", - build_url=args.build_url, - comment_def=comment_def, - commit_info_md=commit_info_md) - - comment_data = benchmark_comment.CommentData( - type_id=comment_def.type_id, - abbr_md=abbr_md, - full_md=full_md, - unverified_pr_number=args.pr_number) - comment_json_data = json.dumps(dataclasses.asdict(comment_data), indent=2) - if args.output is None: - print(comment_json_data) - else: - args.output.write_text(comment_json_data) + benchmark_files = common_arguments.expand_and_check_file_paths(args.benchmark_files) + compile_stats_files = common_arguments.expand_and_check_file_paths( + args.compile_stats_files + ) + + pr_commit = _get_git_commit_hash(ref=args.pr_committish, verbose=args.verbose) + execution_benchmarks = benchmark_presentation.aggregate_all_benchmarks( + benchmark_files=benchmark_files, expected_pr_commit=pr_commit + ) + compilation_metrics = benchmark_presentation.collect_all_compilation_metrics( + compile_stats_files=compile_stats_files, expected_pr_commit=pr_commit + ) + + if args.pr_base_branch is None: + pr_base_commit = None + else: + pr_base_commit = _get_git_merge_base_commit( + pr_commit=pr_commit, target_branch=args.pr_base_branch, verbose=args.verbose + ) + + if pr_base_commit is None: + comparable_results = None + else: + required_benchmark_keys = set(execution_benchmarks.keys()) + for target_id in compilation_metrics: + for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: + required_benchmark_keys.add(mapper.get_series_id(target_id)) + + comparable_results = _find_comparable_benchmark_results( + start_commit=pr_base_commit, + required_benchmark_keys=required_benchmark_keys, + verbose=args.verbose, + ) + + if comparable_results is None: + comparable_commit = None + else: + comparable_commit = comparable_results.commit_sha + # Update the execution benchmarks with base numbers. + for bench in execution_benchmarks: + base_benchmark = comparable_results.benchmark_results[bench] + if base_benchmark["sampleUnit"] != "ns": + raise ValueError("Only support nanoseconds for latency sample.") + execution_benchmarks[bench].base_mean_time = base_benchmark["sample"] + + # Update the compilation metrics with base numbers. + for target_id, metrics in compilation_metrics.items(): + updated_metrics = metrics + for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: + base_benchmark = comparable_results.benchmark_results[ + mapper.get_series_id(target_id) + ] + if base_benchmark["sampleUnit"] != mapper.get_unit(): + raise ValueError("Unit of the queried sample is mismatched.") + updated_metrics = mapper.update_base_value( + updated_metrics, base_benchmark["sample"] + ) + compilation_metrics[target_id] = updated_metrics + + pr_commit_link = md.link(pr_commit, f"{GITHUB_IREE_REPO_PREFIX}/commit/{pr_commit}") + commit_info_md = f"@ commit {pr_commit_link}" + if comparable_commit is not None: + baseline_commit_link = md.link( + comparable_commit, f"{GITHUB_IREE_REPO_PREFIX}/commit/{comparable_commit}" + ) + commit_info_md += f" (vs. base {baseline_commit_link})" + elif pr_base_commit is not None: + commit_info_md += " (no previous benchmark results to compare)" + + comment_def = COMMENT_DEF_MAP[args.comment_type] + full_md, abbr_md = _get_benchmark_result_markdown( + execution_benchmarks=execution_benchmarks, + compilation_metrics=compilation_metrics, + pr_url=f"{GITHUB_IREE_REPO_PREFIX}/pull/{args.pr_number}", + build_url=args.build_url, + comment_def=comment_def, + commit_info_md=commit_info_md, + ) + + comment_data = benchmark_comment.CommentData( + type_id=comment_def.type_id, + abbr_md=abbr_md, + full_md=full_md, + unverified_pr_number=args.pr_number, + ) + comment_json_data = json.dumps(dataclasses.asdict(comment_data), indent=2) + if args.output is None: + print(comment_json_data) + else: + args.output.write_text(comment_json_data) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/benchmarks/post_benchmark_comment.py b/build_tools/benchmarks/post_benchmark_comment.py index 9322d8571b33..8cd98da1955c 100755 --- a/build_tools/benchmarks/post_benchmark_comment.py +++ b/build_tools/benchmarks/post_benchmark_comment.py @@ -39,221 +39,227 @@ class APIRequester(object): - """REST API client that injects proper GitHub authentication headers.""" - - def __init__(self, github_token: str): - self._api_headers = { - "Accept": "application/vnd.github+json", - "Authorization": f"token {github_token}", - "X-GitHub-Api-Version": GITHUB_API_VERSION, - } - self._session = requests.session() - - def get(self, endpoint: str, payload: Any = {}) -> requests.Response: - return self._session.get(endpoint, - data=json.dumps(payload), - headers=self._api_headers) - - def post(self, endpoint: str, payload: Any = {}) -> requests.Response: - return self._session.post(endpoint, - data=json.dumps(payload), - headers=self._api_headers) - - def patch(self, endpoint: str, payload: Any = {}) -> requests.Response: - return self._session.patch(endpoint, - data=json.dumps(payload), - headers=self._api_headers) - + """REST API client that injects proper GitHub authentication headers.""" + + def __init__(self, github_token: str): + self._api_headers = { + "Accept": "application/vnd.github+json", + "Authorization": f"token {github_token}", + "X-GitHub-Api-Version": GITHUB_API_VERSION, + } + self._session = requests.session() + + def get(self, endpoint: str, payload: Any = {}) -> requests.Response: + return self._session.get( + endpoint, data=json.dumps(payload), headers=self._api_headers + ) -class GithubClient(object): - """Helper to call Github REST APIs.""" - - def __init__(self, requester: APIRequester): - self._requester = requester - - def post_to_gist(self, - filename: str, - content: str, - verbose: bool = False) -> str: - """Posts the given content to a new GitHub Gist and returns the URL to it.""" - - response = self._requester.post(endpoint=GITHUB_GIST_API, - payload={ - "public": True, - "files": { - filename: { - "content": content - } - } - }) - if response.status_code != http.client.CREATED: - raise RuntimeError( - f"Failed to create on gist; error code: {response.status_code} - {response.text}" - ) - - response = response.json() - if verbose: - print(f"Gist posting response: {response}") - - if response["truncated"]: - raise RuntimeError(f"Content is too large and was truncated") - - return response["html_url"] - - def get_previous_comment_on_pr(self, - pr_number: int, - comment_bot_user: str, - comment_type_id: str, - query_comment_per_page: int = 100, - max_pages_to_search: int = 10, - verbose: bool = False) -> Optional[int]: - """Gets the previous comment's id from GitHub.""" - - for page in range(1, max_pages_to_search + 1): - response = self._requester.get( - endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments", - payload={ - "per_page": query_comment_per_page, - "page": page, - "sort": "updated", - "direction": "desc" - }) - if response.status_code != http.client.OK: - raise RuntimeError( - f"Failed to get PR comments from GitHub; error code: {response.status_code} - {response.text}" + def post(self, endpoint: str, payload: Any = {}) -> requests.Response: + return self._session.post( + endpoint, data=json.dumps(payload), headers=self._api_headers ) - comments = response.json() - if verbose: - print(f"Previous comment query response on page {page}: {comments}") + def patch(self, endpoint: str, payload: Any = {}) -> requests.Response: + return self._session.patch( + endpoint, data=json.dumps(payload), headers=self._api_headers + ) - # Find the most recently updated comment that matches. - for comment in comments: - if (comment["user"]["login"] == comment_bot_user and - comment_type_id in comment["body"]): - return comment["id"] - if len(comments) < query_comment_per_page: - break +class GithubClient(object): + """Helper to call Github REST APIs.""" - return None + def __init__(self, requester: APIRequester): + self._requester = requester - def update_comment_on_pr(self, comment_id: int, content: str): - """Updates the content of the given comment id.""" + def post_to_gist(self, filename: str, content: str, verbose: bool = False) -> str: + """Posts the given content to a new GitHub Gist and returns the URL to it.""" - response = self._requester.patch( - endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/comments/{comment_id}", - payload={"body": content}) - if response.status_code != http.client.OK: - raise RuntimeError( - f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}" - ) + response = self._requester.post( + endpoint=GITHUB_GIST_API, + payload={"public": True, "files": {filename: {"content": content}}}, + ) + if response.status_code != http.client.CREATED: + raise RuntimeError( + f"Failed to create on gist; error code: {response.status_code} - {response.text}" + ) + + response = response.json() + if verbose: + print(f"Gist posting response: {response}") + + if response["truncated"]: + raise RuntimeError(f"Content is too large and was truncated") + + return response["html_url"] + + def get_previous_comment_on_pr( + self, + pr_number: int, + comment_bot_user: str, + comment_type_id: str, + query_comment_per_page: int = 100, + max_pages_to_search: int = 10, + verbose: bool = False, + ) -> Optional[int]: + """Gets the previous comment's id from GitHub.""" + + for page in range(1, max_pages_to_search + 1): + response = self._requester.get( + endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments", + payload={ + "per_page": query_comment_per_page, + "page": page, + "sort": "updated", + "direction": "desc", + }, + ) + if response.status_code != http.client.OK: + raise RuntimeError( + f"Failed to get PR comments from GitHub; error code: {response.status_code} - {response.text}" + ) + + comments = response.json() + if verbose: + print(f"Previous comment query response on page {page}: {comments}") + + # Find the most recently updated comment that matches. + for comment in comments: + if ( + comment["user"]["login"] == comment_bot_user + and comment_type_id in comment["body"] + ): + return comment["id"] + + if len(comments) < query_comment_per_page: + break + + return None + + def update_comment_on_pr(self, comment_id: int, content: str): + """Updates the content of the given comment id.""" + + response = self._requester.patch( + endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/comments/{comment_id}", + payload={"body": content}, + ) + if response.status_code != http.client.OK: + raise RuntimeError( + f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}" + ) - def create_comment_on_pr(self, pr_number: int, content: str): - """Posts the given content as comments to the current pull request.""" + def create_comment_on_pr(self, pr_number: int, content: str): + """Posts the given content as comments to the current pull request.""" - response = self._requester.post( - endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments", - payload={"body": content}) - if response.status_code != http.client.CREATED: - raise RuntimeError( - f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}" - ) + response = self._requester.post( + endpoint=f"{GITHUB_IREE_API_PREFIX}/issues/{pr_number}/comments", + payload={"body": content}, + ) + if response.status_code != http.client.CREATED: + raise RuntimeError( + f"Failed to comment on GitHub; error code: {response.status_code} - {response.text}" + ) - def get_pull_request_head_commit(self, pr_number: int) -> str: - """Get pull request head commit SHA.""" + def get_pull_request_head_commit(self, pr_number: int) -> str: + """Get pull request head commit SHA.""" - response = self._requester.get( - endpoint=f"{GITHUB_IREE_API_PREFIX}/pulls/{pr_number}") - if response.status_code != http.client.OK: - raise RuntimeError( - f"Failed to fetch the pull request: {pr_number}; " - f"error code: {response.status_code} - {response.text}") + response = self._requester.get( + endpoint=f"{GITHUB_IREE_API_PREFIX}/pulls/{pr_number}" + ) + if response.status_code != http.client.OK: + raise RuntimeError( + f"Failed to fetch the pull request: {pr_number}; " + f"error code: {response.status_code} - {response.text}" + ) - return response.json()["head"]["sha"] + return response.json()["head"]["sha"] def _parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument("comment_json", type=pathlib.Path) - parser.add_argument("--verbose", action="store_true") - verification_parser = parser.add_mutually_exclusive_group(required=True) - verification_parser.add_argument("--github_event_json", type=pathlib.Path) - # Temporary option for buildkite pipeline. - verification_parser.add_argument("--no_verify_pr", action="store_true") - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument("comment_json", type=pathlib.Path) + parser.add_argument("--verbose", action="store_true") + verification_parser = parser.add_mutually_exclusive_group(required=True) + verification_parser.add_argument("--github_event_json", type=pathlib.Path) + # Temporary option for buildkite pipeline. + verification_parser.add_argument("--no_verify_pr", action="store_true") + return parser.parse_args() def main(args: argparse.Namespace): - github_token = os.environ.get("GITHUB_TOKEN") - if github_token is None: - raise ValueError("GITHUB_TOKEN must be set.") - - comment_bot_user = os.environ.get("COMMENT_BOT_USER") - if comment_bot_user is None: - raise ValueError("COMMENT_BOT_USER must be set.") - - gist_bot_token = os.environ.get("GIST_BOT_TOKEN") - if gist_bot_token is None: - raise ValueError("GIST_BOT_TOKEN must be set.") - - comment_data = benchmark_comment.CommentData( - **json.loads(args.comment_json.read_text())) - # Sanitize the pr number to make sure it is an integer. - pr_number = int(comment_data.unverified_pr_number) - - pr_client = GithubClient(requester=APIRequester(github_token=github_token)) - if args.github_event_json is None: - github_event = None - else: - github_event = json.loads(args.github_event_json.read_text()) - workflow_run_sha = github_event["workflow_run"]["head_sha"] - pr_head_sha = pr_client.get_pull_request_head_commit(pr_number=pr_number) - # We can't get the trusted PR number of a workflow run from GitHub API. So we - # take the untrusted PR number from presubmit workflow and verify if the PR's - # current head SHA matches the commit SHA in the workflow run. It assumes - # that to generate the malicious comment data, attacker must modify the code - # and has a new commit SHA. So if the PR head commit matches the workflow - # run with attacker's commit, either the PR is created by the attacker or - # other's PR has the malicious commit. In both cases posting malicious - # comment is acceptable. - # - # Note that the collision of a target SHA1 is possible but GitHub has some - # protections (https://github.blog/2017-03-20-sha-1-collision-detection-on-github-com/). - # The assumption also only holds if files in GCS can't be overwritten (so the - # comment data can't be modified without changing the code). - # The check will also fail if the PR author pushes the new commit after the - # workflow is triggered. But pushing the new commit means to cancel the - # current CI run including the benchmarking. So it will unlikely fail for - # that reason. - if workflow_run_sha != pr_head_sha: - raise ValueError( - f"Workflow run SHA: {workflow_run_sha} does not match " - f"the head SHA: {pr_head_sha} of the pull request: {pr_number}.") - - gist_client = GithubClient(requester=APIRequester( - github_token=gist_bot_token)) - gist_url = gist_client.post_to_gist( - filename=f'iree-full-benchmark-results-{pr_number}.md', - content=comment_data.full_md, - verbose=args.verbose) - - previous_comment_id = pr_client.get_previous_comment_on_pr( - pr_number=pr_number, - comment_bot_user=comment_bot_user, - comment_type_id=comment_data.type_id, - verbose=args.verbose) - - abbr_md = comment_data.abbr_md.replace( - benchmark_comment.GIST_LINK_PLACEHORDER, gist_url) - if github_event is not None: - abbr_md += f'\n\n[Source Workflow Run]({github_event["workflow_run"]["html_url"]})' - if previous_comment_id is not None: - pr_client.update_comment_on_pr(comment_id=previous_comment_id, - content=abbr_md) - else: - pr_client.create_comment_on_pr(pr_number=pr_number, content=abbr_md) + github_token = os.environ.get("GITHUB_TOKEN") + if github_token is None: + raise ValueError("GITHUB_TOKEN must be set.") + + comment_bot_user = os.environ.get("COMMENT_BOT_USER") + if comment_bot_user is None: + raise ValueError("COMMENT_BOT_USER must be set.") + + gist_bot_token = os.environ.get("GIST_BOT_TOKEN") + if gist_bot_token is None: + raise ValueError("GIST_BOT_TOKEN must be set.") + + comment_data = benchmark_comment.CommentData( + **json.loads(args.comment_json.read_text()) + ) + # Sanitize the pr number to make sure it is an integer. + pr_number = int(comment_data.unverified_pr_number) + + pr_client = GithubClient(requester=APIRequester(github_token=github_token)) + if args.github_event_json is None: + github_event = None + else: + github_event = json.loads(args.github_event_json.read_text()) + workflow_run_sha = github_event["workflow_run"]["head_sha"] + pr_head_sha = pr_client.get_pull_request_head_commit(pr_number=pr_number) + # We can't get the trusted PR number of a workflow run from GitHub API. So we + # take the untrusted PR number from presubmit workflow and verify if the PR's + # current head SHA matches the commit SHA in the workflow run. It assumes + # that to generate the malicious comment data, attacker must modify the code + # and has a new commit SHA. So if the PR head commit matches the workflow + # run with attacker's commit, either the PR is created by the attacker or + # other's PR has the malicious commit. In both cases posting malicious + # comment is acceptable. + # + # Note that the collision of a target SHA1 is possible but GitHub has some + # protections (https://github.blog/2017-03-20-sha-1-collision-detection-on-github-com/). + # The assumption also only holds if files in GCS can't be overwritten (so the + # comment data can't be modified without changing the code). + # The check will also fail if the PR author pushes the new commit after the + # workflow is triggered. But pushing the new commit means to cancel the + # current CI run including the benchmarking. So it will unlikely fail for + # that reason. + if workflow_run_sha != pr_head_sha: + raise ValueError( + f"Workflow run SHA: {workflow_run_sha} does not match " + f"the head SHA: {pr_head_sha} of the pull request: {pr_number}." + ) + + gist_client = GithubClient(requester=APIRequester(github_token=gist_bot_token)) + gist_url = gist_client.post_to_gist( + filename=f"iree-full-benchmark-results-{pr_number}.md", + content=comment_data.full_md, + verbose=args.verbose, + ) + + previous_comment_id = pr_client.get_previous_comment_on_pr( + pr_number=pr_number, + comment_bot_user=comment_bot_user, + comment_type_id=comment_data.type_id, + verbose=args.verbose, + ) + + abbr_md = comment_data.abbr_md.replace( + benchmark_comment.GIST_LINK_PLACEHORDER, gist_url + ) + if github_event is not None: + abbr_md += ( + f'\n\n[Source Workflow Run]({github_event["workflow_run"]["html_url"]})' + ) + if previous_comment_id is not None: + pr_client.update_comment_on_pr(comment_id=previous_comment_id, content=abbr_md) + else: + pr_client.create_comment_on_pr(pr_number=pr_number, content=abbr_md) if __name__ == "__main__": - main(_parse_arguments()) + main(_parse_arguments()) diff --git a/build_tools/benchmarks/post_benchmark_comment_test.py b/build_tools/benchmarks/post_benchmark_comment_test.py index 6564b0db7d13..d54fb8de6def 100644 --- a/build_tools/benchmarks/post_benchmark_comment_test.py +++ b/build_tools/benchmarks/post_benchmark_comment_test.py @@ -15,174 +15,146 @@ class GithubClientTest(unittest.TestCase): - - def setUp(self): - self._mock_response = mock.create_autospec(requests.Response) - self._mock_requester = mock.create_autospec( - post_benchmark_comment.APIRequester) - self._mock_requester.get.return_value = self._mock_response - self._mock_requester.post.return_value = self._mock_response - self._mock_requester.patch.return_value = self._mock_response - - def test_post_to_gist(self): - gist_url = "https://example.com/123455/1234.md" - self._mock_response.status_code = http.client.CREATED - self._mock_response.json.return_value = { - "html_url": gist_url, - "truncated": False - } - client = post_benchmark_comment.GithubClient(self._mock_requester) - - url = client.post_to_gist(filename="1234.md", content="xyz") - - self.assertEqual(url, gist_url) - self._mock_requester.post.assert_called_once_with( - endpoint=post_benchmark_comment.GITHUB_GIST_API, - payload={ - "public": True, - "files": { - "1234.md": { - "content": "xyz" - } - } - }) - - def test_post_to_gist_truncated(self): - gist_url = "example.com/123455/1234.md" - self._mock_response.status_code = http.client.CREATED - self._mock_response.json.return_value = { - "html_url": gist_url, - "truncated": True - } - client = post_benchmark_comment.GithubClient(self._mock_requester) - - with self.assertRaises(RuntimeError) as _: - client.post_to_gist(filename="1234.md", content="xyz") - - def test_get_previous_comment_on_pr(self): - first_mock_response = mock.create_autospec(requests.Response) - first_mock_response.status_code = http.client.OK - first_mock_response.json.return_value = [{ - "id": 1, - "user": { - "login": "bot" - }, - "body": "comment id: abcd" - }, { - "id": 2, - "user": { - "login": "user" - }, - "body": "comment id: 1234" - }] - second_mock_response = mock.create_autospec(requests.Response) - second_mock_response.status_code = http.client.OK - second_mock_response.json.return_value = [{ - "id": 3, - "user": { - "login": "bot" - }, - "body": "comment id: 1234" - }] - mock_requester = mock.create_autospec(post_benchmark_comment.APIRequester) - mock_requester.get.side_effect = [first_mock_response, second_mock_response] - client = post_benchmark_comment.GithubClient(mock_requester) - - comment_id = client.get_previous_comment_on_pr(pr_number=23, - comment_bot_user="bot", - comment_type_id="1234", - query_comment_per_page=2, - max_pages_to_search=10) - - self.assertEqual(comment_id, 3) - self.assertEqual(mock_requester.get.call_count, 2) - endpoint_url = f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/23/comments" - mock_requester.get.assert_any_call(endpoint=endpoint_url, - payload={ - "per_page": 2, - "page": 1, - "sort": "updated", - "direction": "desc" - }) - mock_requester.get.assert_any_call(endpoint=endpoint_url, - payload={ - "per_page": 2, - "page": 2, - "sort": "updated", - "direction": "desc" - }) - - def test_get_previous_comment_on_pr_not_found(self): - mock_response = mock.create_autospec(requests.Response) - mock_response.status_code = http.client.OK - mock_response.json.return_value = [{ - "id": 1, - "user": { - "login": "bot" - }, - "body": "comment id: 5678" - }] - mock_requester = mock.create_autospec(post_benchmark_comment.APIRequester) - mock_requester.get.side_effect = [mock_response] * 10 - client = post_benchmark_comment.GithubClient(mock_requester) - - comment_id = client.get_previous_comment_on_pr(pr_number=23, - comment_bot_user="bot", - comment_type_id="1234", - query_comment_per_page=1, - max_pages_to_search=10) - - self.assertIsNone(comment_id) - self.assertEqual(mock_requester.get.call_count, 10) - endpoint_url = f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/23/comments" - mock_requester.get.assert_any_call(endpoint=endpoint_url, - payload={ - "per_page": 1, - "page": 1, - "sort": "updated", - "direction": "desc" - }) - mock_requester.get.assert_any_call(endpoint=endpoint_url, - payload={ - "per_page": 1, - "page": 10, - "sort": "updated", - "direction": "desc" - }) - - def test_update_comment_on_pr(self): - self._mock_response.status_code = http.client.OK - client = post_benchmark_comment.GithubClient(self._mock_requester) - - client.update_comment_on_pr(comment_id=123, content="xyz") - - self._mock_requester.patch.assert_called_once_with( - endpoint= - f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/comments/123", - payload={"body": "xyz"}) - - def test_create_comment_on_pr(self): - self._mock_response.status_code = http.client.CREATED - client = post_benchmark_comment.GithubClient(self._mock_requester) - - client.create_comment_on_pr(pr_number=1234, content="xyz") - - self._mock_requester.post.assert_called_once_with( - endpoint= - f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/1234/comments", - payload={"body": "xyz"}) - - def test_get_pull_request_head_commit(self): - self._mock_response.status_code = http.client.OK - self._mock_response.json.return_value = {"head": {"sha": "sha123"}} - client = post_benchmark_comment.GithubClient(self._mock_requester) - - commit_sha = client.get_pull_request_head_commit(pr_number=123) - - self.assertEqual(commit_sha, "sha123") - self._mock_requester.get.assert_called_once_with( - endpoint=f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/pulls/123") + def setUp(self): + self._mock_response = mock.create_autospec(requests.Response) + self._mock_requester = mock.create_autospec(post_benchmark_comment.APIRequester) + self._mock_requester.get.return_value = self._mock_response + self._mock_requester.post.return_value = self._mock_response + self._mock_requester.patch.return_value = self._mock_response + + def test_post_to_gist(self): + gist_url = "https://example.com/123455/1234.md" + self._mock_response.status_code = http.client.CREATED + self._mock_response.json.return_value = { + "html_url": gist_url, + "truncated": False, + } + client = post_benchmark_comment.GithubClient(self._mock_requester) + + url = client.post_to_gist(filename="1234.md", content="xyz") + + self.assertEqual(url, gist_url) + self._mock_requester.post.assert_called_once_with( + endpoint=post_benchmark_comment.GITHUB_GIST_API, + payload={"public": True, "files": {"1234.md": {"content": "xyz"}}}, + ) + + def test_post_to_gist_truncated(self): + gist_url = "example.com/123455/1234.md" + self._mock_response.status_code = http.client.CREATED + self._mock_response.json.return_value = { + "html_url": gist_url, + "truncated": True, + } + client = post_benchmark_comment.GithubClient(self._mock_requester) + + with self.assertRaises(RuntimeError) as _: + client.post_to_gist(filename="1234.md", content="xyz") + + def test_get_previous_comment_on_pr(self): + first_mock_response = mock.create_autospec(requests.Response) + first_mock_response.status_code = http.client.OK + first_mock_response.json.return_value = [ + {"id": 1, "user": {"login": "bot"}, "body": "comment id: abcd"}, + {"id": 2, "user": {"login": "user"}, "body": "comment id: 1234"}, + ] + second_mock_response = mock.create_autospec(requests.Response) + second_mock_response.status_code = http.client.OK + second_mock_response.json.return_value = [ + {"id": 3, "user": {"login": "bot"}, "body": "comment id: 1234"} + ] + mock_requester = mock.create_autospec(post_benchmark_comment.APIRequester) + mock_requester.get.side_effect = [first_mock_response, second_mock_response] + client = post_benchmark_comment.GithubClient(mock_requester) + + comment_id = client.get_previous_comment_on_pr( + pr_number=23, + comment_bot_user="bot", + comment_type_id="1234", + query_comment_per_page=2, + max_pages_to_search=10, + ) + + self.assertEqual(comment_id, 3) + self.assertEqual(mock_requester.get.call_count, 2) + endpoint_url = ( + f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/23/comments" + ) + mock_requester.get.assert_any_call( + endpoint=endpoint_url, + payload={"per_page": 2, "page": 1, "sort": "updated", "direction": "desc"}, + ) + mock_requester.get.assert_any_call( + endpoint=endpoint_url, + payload={"per_page": 2, "page": 2, "sort": "updated", "direction": "desc"}, + ) + + def test_get_previous_comment_on_pr_not_found(self): + mock_response = mock.create_autospec(requests.Response) + mock_response.status_code = http.client.OK + mock_response.json.return_value = [ + {"id": 1, "user": {"login": "bot"}, "body": "comment id: 5678"} + ] + mock_requester = mock.create_autospec(post_benchmark_comment.APIRequester) + mock_requester.get.side_effect = [mock_response] * 10 + client = post_benchmark_comment.GithubClient(mock_requester) + + comment_id = client.get_previous_comment_on_pr( + pr_number=23, + comment_bot_user="bot", + comment_type_id="1234", + query_comment_per_page=1, + max_pages_to_search=10, + ) + + self.assertIsNone(comment_id) + self.assertEqual(mock_requester.get.call_count, 10) + endpoint_url = ( + f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/23/comments" + ) + mock_requester.get.assert_any_call( + endpoint=endpoint_url, + payload={"per_page": 1, "page": 1, "sort": "updated", "direction": "desc"}, + ) + mock_requester.get.assert_any_call( + endpoint=endpoint_url, + payload={"per_page": 1, "page": 10, "sort": "updated", "direction": "desc"}, + ) + + def test_update_comment_on_pr(self): + self._mock_response.status_code = http.client.OK + client = post_benchmark_comment.GithubClient(self._mock_requester) + + client.update_comment_on_pr(comment_id=123, content="xyz") + + self._mock_requester.patch.assert_called_once_with( + endpoint=f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/comments/123", + payload={"body": "xyz"}, + ) + + def test_create_comment_on_pr(self): + self._mock_response.status_code = http.client.CREATED + client = post_benchmark_comment.GithubClient(self._mock_requester) + + client.create_comment_on_pr(pr_number=1234, content="xyz") + + self._mock_requester.post.assert_called_once_with( + endpoint=f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/issues/1234/comments", + payload={"body": "xyz"}, + ) + + def test_get_pull_request_head_commit(self): + self._mock_response.status_code = http.client.OK + self._mock_response.json.return_value = {"head": {"sha": "sha123"}} + client = post_benchmark_comment.GithubClient(self._mock_requester) + + commit_sha = client.get_pull_request_head_commit(pr_number=123) + + self.assertEqual(commit_sha, "sha123") + self._mock_requester.get.assert_called_once_with( + endpoint=f"{post_benchmark_comment.GITHUB_IREE_API_PREFIX}/pulls/123" + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/benchmarks/reporting/parse_shark_benchmarks.py b/build_tools/benchmarks/reporting/parse_shark_benchmarks.py index 2c85891ef557..97fd5cdffaae 100755 --- a/build_tools/benchmarks/reporting/parse_shark_benchmarks.py +++ b/build_tools/benchmarks/reporting/parse_shark_benchmarks.py @@ -51,287 +51,342 @@ def _generate_table(df_iree, df_shark, df_baseline, title): - """Generates a table comparing latencies between IREE, SHARK and a baseline.""" - summary = pd.DataFrame(columns=[ - _MODEL, _BASELINE, _DATA_TYPE, _DIALECT, _DEVICE, _BASELINE_LATENCY, - _IREE_LATENCY, _SHARK_LATENCY, _IREE_VS_BASELINE, _SHARK_VS_BASELINE, - _IREE_VS_SHARK, _BASELINE_MEMORY, _IREE_MEMORY, _SHARK_MEMORY - ]) - - models = df_iree.model.unique() - for model in models: - iree_results_per_model = df_iree.loc[df_iree.model == model] - dialects = iree_results_per_model.dialect.unique() - for dialect in dialects: - iree_results_per_dialect = iree_results_per_model.loc[ - iree_results_per_model.dialect == dialect] - data_types = iree_results_per_dialect.data_type.unique() - for data_type in data_types: - iree_results_per_datatype = iree_results_per_dialect.loc[ - iree_results_per_dialect.data_type == data_type] - device_types = iree_results_per_datatype.device.unique() - for device in device_types: - iree_results = iree_results_per_datatype.loc[ - iree_results_per_datatype.device == device] - if len(iree_results) != 3: - print(f"Warning! Expected number of results to be 3. Got" - f" {len(iree_results)}") - print(iree_results) - continue - - baseline_results = df_baseline.loc[(df_baseline.model == model) & - (df_baseline.dialect == dialect) & - (df_baseline.data_type - == data_type) & - (df_baseline.device == device)] - - if baseline_results.empty: - # We use snapshots of latencies for baseline. If it is a new - # benchmark that is not included in the snapshot yet, emit a - # warning. - print( - f"Warning: No baseline results found for {model}, {dialect}," - f" {data_type}, {device}. Using IREE version as baseline. Please" - f" update baseline csv.") - engine = iree_results.engine.iloc[0] - baseline_df = iree_results.loc[iree_results.engine == engine] - baseline_latency = baseline_df.iloc[0]["ms/iter"] - baseline_device_mb = baseline_df.iloc[0]["device_memory_mb"] - else: - engine = baseline_results.engine.iloc[0] - baseline_df = baseline_results.loc[baseline_results.engine == - engine] - baseline_latency = baseline_df.iloc[0]["ms/iter"] - baseline_device_mb = baseline_df.iloc[0]["device_memory_mb"] - - iree_df = iree_results.loc[iree_results.engine == "shark_iree_c"] - iree_latency = iree_df.iloc[0]["ms/iter"] - iree_device_mb = iree_df.iloc[0]["device_memory_mb"] - iree_vs_baseline = html_utils.format_latency_comparison( - iree_latency, baseline_latency) - - if df_shark is not None: - shark_results = df_shark.loc[(df_shark.model == model) & - (df_shark.dialect == dialect) & - (df_shark.data_type == data_type) & - (df_shark.device == device)] - if shark_results.empty: - print( - f"Warning: No SHARK results for {model}, {dialect}, {data_type}, {device}." - ) - continue - - shark_df = shark_results.loc[shark_results.engine == "shark_iree_c"] - shark_latency = shark_df.iloc[0]["ms/iter"] - shark_device_mb = shark_df.iloc[0]["device_memory_mb"] - shark_vs_baseline = html_utils.format_latency_comparison( - shark_latency, baseline_latency) - iree_vs_shark = html_utils.format_latency_comparison( - iree_latency, shark_latency) - else: - # If there are no SHARK benchmarks available, use default values. - # These columns will be hidden later. - shark_latency = 0 - shark_vs_baseline = "" - iree_vs_shark = "" - - summary.loc[len(summary)] = [ - model, - engine, - data_type, - dialect, - device, - f"{baseline_latency:.1f}", - f"{iree_latency:.1f}", - f"{shark_latency:.1f}", - iree_vs_baseline, - shark_vs_baseline, - iree_vs_shark, - f"{baseline_device_mb:.3f}", - f"{iree_device_mb:.3f}", - f"{shark_device_mb:.3f}", - ] - - summary = summary.round(2) - - st = summary.style.set_table_styles(html_utils.get_table_css()) - st = st.hide(axis="index") - if df_shark is None: - st = st.hide_columns( - subset=[_SHARK_LATENCY, _SHARK_VS_BASELINE, _IREE_VS_SHARK]) - st = st.set_caption(title) - st = st.applymap(html_utils.style_performance, subset=_PERF_COLUMNS) - st = st.set_properties(subset=[_MODEL], - **{ - "width": "300px", - "text-align": "left", - }) - st = st.set_properties(subset=[_BASELINE], - **{ - "width": "140", - "text-align": "center", - }) - st = st.set_properties(subset=[_DIALECT, _DATA_TYPE, _DEVICE], - **{ - "width": "100", - "text-align": "center", - }) - st = st.set_properties(subset=_LATENCY_COLUMNS, - **{ - "width": "100", - "text-align": "right", - }) - st = st.set_properties(subset=_PERF_COLUMNS, - **{ - "width": "150px", - "text-align": "right", - "color": "#ffffff" - }) - st = st.set_properties(subset=_MEMORY_COLUMNS, - **{ - "width": "100", - "text-align": "right", - }) - - return st.to_html() + "
" - - -def generate_table(iree_csv, - baseline_csv, - shark_csv=None, - shape_type="static", - device="cpu", - title="Benchmarks"): - """Generates a table comparing latencies between IREE, SHARK and a baseline. - - Args: - iree_csv: Path to the csv file containing IREE latencies. - baseline_csv: Path to the csv file containing baseline latencies. - shark_csv: Path to the csv file containing SHARK-Runtime latencies. This is optional. - shape_type: Currently either `static` or `dynamic`. - device: Device used to run the benchmarks. - title: The title of the generated table. - - Returns: - An HTML string containing the summarized report. - """ - shark_df = None - if shark_csv is not None: - shark_df = pd.read_csv(shark_csv) - shark_df = shark_df.loc[(shark_df.shape_type == shape_type) & - (shark_df.device == device)] - - iree_df = pd.read_csv(iree_csv) - iree_df = iree_df.loc[(iree_df.shape_type == shape_type) & - (iree_df.device == device)] - - baseline_df = pd.read_csv(baseline_csv) - baseline_df = baseline_df.loc[(baseline_df.shape_type == shape_type) & - (baseline_df.device == device)] - - return _generate_table(iree_df, shark_df, baseline_df, title) + """Generates a table comparing latencies between IREE, SHARK and a baseline.""" + summary = pd.DataFrame( + columns=[ + _MODEL, + _BASELINE, + _DATA_TYPE, + _DIALECT, + _DEVICE, + _BASELINE_LATENCY, + _IREE_LATENCY, + _SHARK_LATENCY, + _IREE_VS_BASELINE, + _SHARK_VS_BASELINE, + _IREE_VS_SHARK, + _BASELINE_MEMORY, + _IREE_MEMORY, + _SHARK_MEMORY, + ] + ) + + models = df_iree.model.unique() + for model in models: + iree_results_per_model = df_iree.loc[df_iree.model == model] + dialects = iree_results_per_model.dialect.unique() + for dialect in dialects: + iree_results_per_dialect = iree_results_per_model.loc[ + iree_results_per_model.dialect == dialect + ] + data_types = iree_results_per_dialect.data_type.unique() + for data_type in data_types: + iree_results_per_datatype = iree_results_per_dialect.loc[ + iree_results_per_dialect.data_type == data_type + ] + device_types = iree_results_per_datatype.device.unique() + for device in device_types: + iree_results = iree_results_per_datatype.loc[ + iree_results_per_datatype.device == device + ] + if len(iree_results) != 3: + print( + f"Warning! Expected number of results to be 3. Got" + f" {len(iree_results)}" + ) + print(iree_results) + continue + + baseline_results = df_baseline.loc[ + (df_baseline.model == model) + & (df_baseline.dialect == dialect) + & (df_baseline.data_type == data_type) + & (df_baseline.device == device) + ] + + if baseline_results.empty: + # We use snapshots of latencies for baseline. If it is a new + # benchmark that is not included in the snapshot yet, emit a + # warning. + print( + f"Warning: No baseline results found for {model}, {dialect}," + f" {data_type}, {device}. Using IREE version as baseline. Please" + f" update baseline csv." + ) + engine = iree_results.engine.iloc[0] + baseline_df = iree_results.loc[iree_results.engine == engine] + baseline_latency = baseline_df.iloc[0]["ms/iter"] + baseline_device_mb = baseline_df.iloc[0]["device_memory_mb"] + else: + engine = baseline_results.engine.iloc[0] + baseline_df = baseline_results.loc[ + baseline_results.engine == engine + ] + baseline_latency = baseline_df.iloc[0]["ms/iter"] + baseline_device_mb = baseline_df.iloc[0]["device_memory_mb"] + + iree_df = iree_results.loc[iree_results.engine == "shark_iree_c"] + iree_latency = iree_df.iloc[0]["ms/iter"] + iree_device_mb = iree_df.iloc[0]["device_memory_mb"] + iree_vs_baseline = html_utils.format_latency_comparison( + iree_latency, baseline_latency + ) + + if df_shark is not None: + shark_results = df_shark.loc[ + (df_shark.model == model) + & (df_shark.dialect == dialect) + & (df_shark.data_type == data_type) + & (df_shark.device == device) + ] + if shark_results.empty: + print( + f"Warning: No SHARK results for {model}, {dialect}, {data_type}, {device}." + ) + continue + + shark_df = shark_results.loc[ + shark_results.engine == "shark_iree_c" + ] + shark_latency = shark_df.iloc[0]["ms/iter"] + shark_device_mb = shark_df.iloc[0]["device_memory_mb"] + shark_vs_baseline = html_utils.format_latency_comparison( + shark_latency, baseline_latency + ) + iree_vs_shark = html_utils.format_latency_comparison( + iree_latency, shark_latency + ) + else: + # If there are no SHARK benchmarks available, use default values. + # These columns will be hidden later. + shark_latency = 0 + shark_vs_baseline = "" + iree_vs_shark = "" + + summary.loc[len(summary)] = [ + model, + engine, + data_type, + dialect, + device, + f"{baseline_latency:.1f}", + f"{iree_latency:.1f}", + f"{shark_latency:.1f}", + iree_vs_baseline, + shark_vs_baseline, + iree_vs_shark, + f"{baseline_device_mb:.3f}", + f"{iree_device_mb:.3f}", + f"{shark_device_mb:.3f}", + ] + + summary = summary.round(2) + + st = summary.style.set_table_styles(html_utils.get_table_css()) + st = st.hide(axis="index") + if df_shark is None: + st = st.hide_columns( + subset=[_SHARK_LATENCY, _SHARK_VS_BASELINE, _IREE_VS_SHARK] + ) + st = st.set_caption(title) + st = st.applymap(html_utils.style_performance, subset=_PERF_COLUMNS) + st = st.set_properties( + subset=[_MODEL], + **{ + "width": "300px", + "text-align": "left", + }, + ) + st = st.set_properties( + subset=[_BASELINE], + **{ + "width": "140", + "text-align": "center", + }, + ) + st = st.set_properties( + subset=[_DIALECT, _DATA_TYPE, _DEVICE], + **{ + "width": "100", + "text-align": "center", + }, + ) + st = st.set_properties( + subset=_LATENCY_COLUMNS, + **{ + "width": "100", + "text-align": "right", + }, + ) + st = st.set_properties( + subset=_PERF_COLUMNS, + **{"width": "150px", "text-align": "right", "color": "#ffffff"}, + ) + st = st.set_properties( + subset=_MEMORY_COLUMNS, + **{ + "width": "100", + "text-align": "right", + }, + ) + + return st.to_html() + "
" + + +def generate_table( + iree_csv, + baseline_csv, + shark_csv=None, + shape_type="static", + device="cpu", + title="Benchmarks", +): + """Generates a table comparing latencies between IREE, SHARK and a baseline. + + Args: + iree_csv: Path to the csv file containing IREE latencies. + baseline_csv: Path to the csv file containing baseline latencies. + shark_csv: Path to the csv file containing SHARK-Runtime latencies. This is optional. + shape_type: Currently either `static` or `dynamic`. + device: Device used to run the benchmarks. + title: The title of the generated table. + + Returns: + An HTML string containing the summarized report. + """ + shark_df = None + if shark_csv is not None: + shark_df = pd.read_csv(shark_csv) + shark_df = shark_df.loc[ + (shark_df.shape_type == shape_type) & (shark_df.device == device) + ] + + iree_df = pd.read_csv(iree_csv) + iree_df = iree_df.loc[ + (iree_df.shape_type == shape_type) & (iree_df.device == device) + ] + + baseline_df = pd.read_csv(baseline_csv) + baseline_df = baseline_df.loc[ + (baseline_df.shape_type == shape_type) & (baseline_df.device == device) + ] + + return _generate_table(iree_df, shark_df, baseline_df, title) def main(args): - """Summarizes benchmark results generated by the SHARK Tank.""" - version_html = f"last updated: {date.today().isoformat()}

" - version_html += "Version Info
" - with open(args.version_info) as f: - version_info = dict(l.strip().split("=", 1) for l in f) - for key, value in version_info.items(): - version_html += f"{key}: {value}
" - version_html += "
" - - html = html_utils.generate_header_and_legend(version_html) - - # Generate Server CPU Static. - if args.cpu_iree_csv is not None: - html += generate_table(args.cpu_iree_csv, - args.cpu_baseline_csv, - shark_csv=args.cpu_shark_csv, - shape_type="static", - device="cpu", - title="Server Intel Ice Lake CPU (Static Shapes)") - - # Generate Server GPU Static. - if args.gpu_iree_csv is not None: - html += generate_table(args.gpu_iree_csv, - args.gpu_baseline_csv, - shark_csv=args.gpu_shark_csv, - shape_type="static", - device="cuda", - title="Server NVIDIA Tesla A100 GPU (Static Shapes)") - - # Generate Server CPU Dynamic. - if args.cpu_iree_csv is not None: - html += generate_table(args.cpu_iree_csv, - args.cpu_baseline_csv, - shark_csv=args.cpu_shark_csv, - shape_type="dynamic", - device="cpu", - title="Server Intel Ice Lake CPU (Dynamic Shapes)") - - # Generate Server GPU Dynamic. - if args.gpu_iree_csv is not None: - html += generate_table( - args.gpu_iree_csv, - args.gpu_baseline_csv, - shark_csv=args.gpu_shark_csv, - shape_type="dynamic", - device="cuda", - title="Server NVIDIA Tesla A100 GPU (Dynamic Shapes)") - - args.output_path.write_text(html) + """Summarizes benchmark results generated by the SHARK Tank.""" + version_html = f"last updated: {date.today().isoformat()}

" + version_html += "Version Info
" + with open(args.version_info) as f: + version_info = dict(l.strip().split("=", 1) for l in f) + for key, value in version_info.items(): + version_html += f"{key}: {value}
" + version_html += "
" + + html = html_utils.generate_header_and_legend(version_html) + + # Generate Server CPU Static. + if args.cpu_iree_csv is not None: + html += generate_table( + args.cpu_iree_csv, + args.cpu_baseline_csv, + shark_csv=args.cpu_shark_csv, + shape_type="static", + device="cpu", + title="Server Intel Ice Lake CPU (Static Shapes)", + ) + + # Generate Server GPU Static. + if args.gpu_iree_csv is not None: + html += generate_table( + args.gpu_iree_csv, + args.gpu_baseline_csv, + shark_csv=args.gpu_shark_csv, + shape_type="static", + device="cuda", + title="Server NVIDIA Tesla A100 GPU (Static Shapes)", + ) + + # Generate Server CPU Dynamic. + if args.cpu_iree_csv is not None: + html += generate_table( + args.cpu_iree_csv, + args.cpu_baseline_csv, + shark_csv=args.cpu_shark_csv, + shape_type="dynamic", + device="cpu", + title="Server Intel Ice Lake CPU (Dynamic Shapes)", + ) + + # Generate Server GPU Dynamic. + if args.gpu_iree_csv is not None: + html += generate_table( + args.gpu_iree_csv, + args.gpu_baseline_csv, + shark_csv=args.gpu_shark_csv, + shape_type="dynamic", + device="cuda", + title="Server NVIDIA Tesla A100 GPU (Dynamic Shapes)", + ) + + args.output_path.write_text(html) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--cpu_shark_csv", - type=str, - default=None, - help="The path to the csv file with CPU benchmarking results from the " - "SHARK runtime.") - parser.add_argument( - "--cpu_iree_csv", - type=str, - default=None, - help="The path to the csv file with CPU benchmarking results from IREE.") - parser.add_argument( - "--cpu_baseline_csv", - type=str, - default="data/icelake_baseline_2022-09-19.csv", - help="The path to the csv file containing baseline CPU results.") - parser.add_argument( - "--gpu_shark_csv", - type=str, - default=None, - help="The path to the csv file with GPU benchmarking results from the " - "SHARK runtime.") - parser.add_argument( - "--gpu_iree_csv", - type=str, - default=None, - help="The path to the csv file with CPU benchmarking results from IREE.") - parser.add_argument( - "--gpu_baseline_csv", - type=str, - default="data/a100_baseline_2022-09-19.csv", - help="The path to the csv file containing baseline GPU results.") - parser.add_argument( - "--version_info", - type=pathlib.Path, - default=None, - help= - "The path to a text file containing version information of the frameworks tested." - ) - parser.add_argument( - "--output_path", - type=pathlib.Path, - default="/tmp/summary.html", - help="The path to the output html file that summarizes results.") - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "--cpu_shark_csv", + type=str, + default=None, + help="The path to the csv file with CPU benchmarking results from the " + "SHARK runtime.", + ) + parser.add_argument( + "--cpu_iree_csv", + type=str, + default=None, + help="The path to the csv file with CPU benchmarking results from IREE.", + ) + parser.add_argument( + "--cpu_baseline_csv", + type=str, + default="data/icelake_baseline_2022-09-19.csv", + help="The path to the csv file containing baseline CPU results.", + ) + parser.add_argument( + "--gpu_shark_csv", + type=str, + default=None, + help="The path to the csv file with GPU benchmarking results from the " + "SHARK runtime.", + ) + parser.add_argument( + "--gpu_iree_csv", + type=str, + default=None, + help="The path to the csv file with CPU benchmarking results from IREE.", + ) + parser.add_argument( + "--gpu_baseline_csv", + type=str, + default="data/a100_baseline_2022-09-19.csv", + help="The path to the csv file containing baseline GPU results.", + ) + parser.add_argument( + "--version_info", + type=pathlib.Path, + default=None, + help="The path to a text file containing version information of the frameworks tested.", + ) + parser.add_argument( + "--output_path", + type=pathlib.Path, + default="/tmp/summary.html", + help="The path to the output html file that summarizes results.", + ) + return parser.parse_args() if __name__ == "__main__": - main(parse_args()) + main(parse_args()) diff --git a/build_tools/benchmarks/reporting/parse_tflite_benchmarks.py b/build_tools/benchmarks/reporting/parse_tflite_benchmarks.py index d08af9797fbf..b9c67b0272c2 100755 --- a/build_tools/benchmarks/reporting/parse_tflite_benchmarks.py +++ b/build_tools/benchmarks/reporting/parse_tflite_benchmarks.py @@ -55,7 +55,7 @@ "ssd_mobilenet_v2_static_1.0_int8": "int8", "ssd_mobilenet_v2_fpnlite_fp32": "fp32", "ssd_mobilenet_v2_fpnlite_fp32_fp16": "fp16", - "ssd_mobilenet_v2_fpnlite_uint8": 'uint8', + "ssd_mobilenet_v2_fpnlite_uint8": "uint8", } # Column headers. @@ -84,398 +84,450 @@ def get_tflite_model_list(df): - """Retrieves the list of TFLite models, filtering out duplicates. + """Retrieves the list of TFLite models, filtering out duplicates. - The .csv file includes multiple entries of the same model but under a - different configuration (e.g. XNNPack enabled, XNNPack disabled). - """ - df = df.loc[df.runtime == "tflite"] - # Remove rows where the model name ends with `noxnn` since this is a duplicate. - df = df[~df.model.str.endswith("noxnn")] - return df.model.unique() + The .csv file includes multiple entries of the same model but under a + different configuration (e.g. XNNPack enabled, XNNPack disabled). + """ + df = df.loc[df.runtime == "tflite"] + # Remove rows where the model name ends with `noxnn` since this is a duplicate. + df = df[~df.model.str.endswith("noxnn")] + return df.model.unique() def get_fastest_result(model, df): - """Retrieves the lowest latency result from multiple configurations. + """Retrieves the lowest latency result from multiple configurations. - Benchmarks are run under different configurations (e.g. number of threads, - Big core, LITTLE core, etc). This method retrieves the fastest configuration - whilst ensuring apples to apples comparisons (e.g. FP16 results are not - considered when the model is FP32). + Benchmarks are run under different configurations (e.g. number of threads, + Big core, LITTLE core, etc). This method retrieves the fastest configuration + whilst ensuring apples to apples comparisons (e.g. FP16 results are not + considered when the model is FP32). - Args: - model: The model name. - df: The dataframe to filter through. + Args: + model: The model name. + df: The dataframe to filter through. - Returns: - A dataframe containing the lowest latency. - """ - df = df[df.model.str.startswith(model)] - if not model.endswith("fp16"): - df = df[~df[_MODEL].str.endswith("fp16")] - df = df[df[_LATENCY] != 0] - df = df[df[_LATENCY] == df[_LATENCY].min()] - return df.head(1) + Returns: + A dataframe containing the lowest latency. + """ + df = df[df.model.str.startswith(model)] + if not model.endswith("fp16"): + df = df[~df[_MODEL].str.endswith("fp16")] + df = df[df[_LATENCY] != 0] + df = df[df[_LATENCY] == df[_LATENCY].min()] + return df.head(1) def get_tflite_config(model, df): - """Generates a configuration string from TFLite config variables.""" - config = [] - if _TASKSET in df.columns: - taskset = df.taskset.iloc[0] - config.append(f"taskset {taskset}") - threads = df.threads.iloc[0] - config.append(f"{threads} threads" if threads > 1 else f"{threads} thread") - config.append("no xnnpack" if model.endswith("noxnn") else "xnnpack") - return ", ".join(config) + """Generates a configuration string from TFLite config variables.""" + config = [] + if _TASKSET in df.columns: + taskset = df.taskset.iloc[0] + config.append(f"taskset {taskset}") + threads = df.threads.iloc[0] + config.append(f"{threads} threads" if threads > 1 else f"{threads} thread") + config.append("no xnnpack" if model.endswith("noxnn") else "xnnpack") + return ", ".join(config) def generate_tflite_summary(dataframe): - """Generates a dataframe containing the fastest TFLite result for each model.""" - summary = pd.DataFrame(columns=[_MODEL, _LATENCY, _MEMORY, _CONFIG]) - tflite_df = dataframe[dataframe.runtime == "tflite"] - model_list = get_tflite_model_list(dataframe) - for model in model_list: - df = get_fastest_result(model, tflite_df) - if df.empty: - print(f"Warning: TFLite results invalid for {model}.") - continue - latency = df[_LATENCY].iloc[0] - full_model_name = df.model.iloc[0] - memory = df[_MEMORY].iloc[0] - config = get_tflite_config(full_model_name, df) - summary.loc[len(summary)] = [model, latency, memory, config] - return summary + """Generates a dataframe containing the fastest TFLite result for each model.""" + summary = pd.DataFrame(columns=[_MODEL, _LATENCY, _MEMORY, _CONFIG]) + tflite_df = dataframe[dataframe.runtime == "tflite"] + model_list = get_tflite_model_list(dataframe) + for model in model_list: + df = get_fastest_result(model, tflite_df) + if df.empty: + print(f"Warning: TFLite results invalid for {model}.") + continue + latency = df[_LATENCY].iloc[0] + full_model_name = df.model.iloc[0] + memory = df[_MEMORY].iloc[0] + config = get_tflite_config(full_model_name, df) + summary.loc[len(summary)] = [model, latency, memory, config] + return summary def get_iree_model_list(df): - """Retrieves the list of IREE models, filtering out duplicates. + """Retrieves the list of IREE models, filtering out duplicates. - The .csv file includes multiple entries of the same model but under a - different configuration (e.g. mmt4d). - """ - df = df.loc[df.runtime == "iree"] - df = df[~df.model.str.endswith("mmt4d")] - df = df[~df.model.str.endswith("padfuse")] - return df.model.unique() + The .csv file includes multiple entries of the same model but under a + different configuration (e.g. mmt4d). + """ + df = df.loc[df.runtime == "iree"] + df = df[~df.model.str.endswith("mmt4d")] + df = df[~df.model.str.endswith("padfuse")] + return df.model.unique() def get_iree_config(model, df): - """Generates a configuration string from IREE config variables. - - The configuration is embedded in the model name. - """ - config = [] - if _TASKSET in df.columns: - taskset = df.taskset.iloc[0] - config.append(f"taskset {taskset}") - threads = df.threads.iloc[0] - config.append(f"{threads} threads" if threads > 1 else f"{threads} thread") - if model.endswith("im2col_mmt4d"): - config.append("im2col") - config.append("mmt4d") - elif model.endswith("mmt4d"): - config.append("mmt4d") - elif model.endswith("padfuse"): - config.append("fused pad") - return ", ".join(config) + """Generates a configuration string from IREE config variables. + + The configuration is embedded in the model name. + """ + config = [] + if _TASKSET in df.columns: + taskset = df.taskset.iloc[0] + config.append(f"taskset {taskset}") + threads = df.threads.iloc[0] + config.append(f"{threads} threads" if threads > 1 else f"{threads} thread") + if model.endswith("im2col_mmt4d"): + config.append("im2col") + config.append("mmt4d") + elif model.endswith("mmt4d"): + config.append("mmt4d") + elif model.endswith("padfuse"): + config.append("fused pad") + return ", ".join(config) def generate_iree_summary(dataframe): - """Generates a dataframe containing the fastest IREE result for each model.""" - summary = pd.DataFrame(columns=[_MODEL, _LATENCY, _MEMORY, _CONFIG]) - iree_df = dataframe[dataframe.runtime == "iree"] - model_list = get_iree_model_list(dataframe) - for model in model_list: - df = get_fastest_result(model, iree_df) - if df.empty: - print(f"Warning: IREE results invalid for {model}.") - continue - latency = df[_LATENCY].iloc[0] - full_model_name = df.model.iloc[0] - memory = df[_MEMORY].iloc[0] - config = get_iree_config(full_model_name, df) - summary.loc[len(summary)] = [model, latency, memory, config] - return summary + """Generates a dataframe containing the fastest IREE result for each model.""" + summary = pd.DataFrame(columns=[_MODEL, _LATENCY, _MEMORY, _CONFIG]) + iree_df = dataframe[dataframe.runtime == "iree"] + model_list = get_iree_model_list(dataframe) + for model in model_list: + df = get_fastest_result(model, iree_df) + if df.empty: + print(f"Warning: IREE results invalid for {model}.") + continue + latency = df[_LATENCY].iloc[0] + full_model_name = df.model.iloc[0] + memory = df[_MEMORY].iloc[0] + config = get_iree_config(full_model_name, df) + summary.loc[len(summary)] = [model, latency, memory, config] + return summary def get_common_html_style(df, title): - """Returns HTML style attributes common to both server and mobile.""" - st = df.style.set_table_styles(html_utils.get_table_css()) - st = st.hide(axis="index") - st = st.set_caption(title) - st = st.set_properties(subset=[_MODEL], - **{ - "width": "300px", - "text-align": "left", - }) - st = st.set_properties(subset=[_DATA_TYPE], - **{ - "width": "100", - "text-align": "center", - }) - st = st.set_properties(subset=_NUMBER_COLUMNS, - **{ - "width": "100", - "text-align": "right", - }) - st = st.set_properties(subset=_PERF_COLUMNS, - **{ - "width": "150px", - "text-align": "right", - "color": "#ffffff" - }) - st = st.applymap(html_utils.style_latency, subset=[_IREE_VS_TFLITE_LATENCY]) - st = st.applymap(html_utils.style_memory, subset=[_IREE_VS_TFLITE_MEMORY]) - return st + """Returns HTML style attributes common to both server and mobile.""" + st = df.style.set_table_styles(html_utils.get_table_css()) + st = st.hide(axis="index") + st = st.set_caption(title) + st = st.set_properties( + subset=[_MODEL], + **{ + "width": "300px", + "text-align": "left", + }, + ) + st = st.set_properties( + subset=[_DATA_TYPE], + **{ + "width": "100", + "text-align": "center", + }, + ) + st = st.set_properties( + subset=_NUMBER_COLUMNS, + **{ + "width": "100", + "text-align": "right", + }, + ) + st = st.set_properties( + subset=_PERF_COLUMNS, + **{"width": "150px", "text-align": "right", "color": "#ffffff"}, + ) + st = st.applymap(html_utils.style_latency, subset=[_IREE_VS_TFLITE_LATENCY]) + st = st.applymap(html_utils.style_memory, subset=[_IREE_VS_TFLITE_MEMORY]) + return st def generate_summary(dataframe, title): - """Generates a table comparing latencies and memory usage between IREE and TFLite. - - For each model, retrieves the lowest latency configuration from both IREE and TFLite. - - Args: - dataframe: The raw data to summarize. - title: The title of the table. - - Returns: - An HTML string containing the summarized report. - """ - summary = pd.DataFrame(columns=[ - _MODEL, _DATA_TYPE, _TFLITE_CONFIG, _IREE_CONFIG, _TFLITE_LATENCY, - _IREE_LATENCY, _IREE_VS_TFLITE_LATENCY, _TFLITE_MEMORY, _IREE_MEMORY, - _IREE_VS_TFLITE_MEMORY - ]) - - tflite_df = generate_tflite_summary(dataframe) - iree_df = generate_iree_summary(dataframe) - model_list = tflite_df[_MODEL].unique() - - for model in model_list: - tflite_results = tflite_df[tflite_df.model == model] - iree_results = iree_df[iree_df.model == model] - - if tflite_results.empty: - print(f"Warning: No TFLite results found for model {model}") - continue - if iree_results.empty: - print(f"Warning: No IREE results found for model {model}") - continue - - iree_latency = iree_results[_LATENCY].iloc[0] - tflite_latency = tflite_results[_LATENCY].iloc[0] - latency_comparison = html_utils.format_latency_comparison( - iree_latency, tflite_latency) - - iree_memory = iree_results[_MEMORY].iloc[0] - tflite_memory = tflite_results[_MEMORY].iloc[0] - memory_comparison = html_utils.format_memory_comparison( - iree_memory, tflite_memory) - - iree_config = iree_results.config.iloc[0] - tflite_config = tflite_results.config.iloc[0] - summary.loc[len(summary)] = [ - model, - _MODEL_TO_DATA_TYPE[model], - tflite_config, - iree_config, - f"{tflite_latency:.1f}", - f"{iree_latency:.1f}", - latency_comparison, - f"{tflite_memory:,.0f}", - f"{iree_memory:,.0f}", - memory_comparison, - ] - - summary = summary.round(2) - st = get_common_html_style(summary, title) - st = st.set_properties(subset=_CONFIG_COLUMNS, - **{ - "width": "300px", - "text-align": "left", - }) - return st.to_html().replace("\\n", "
") + "
" - + """Generates a table comparing latencies and memory usage between IREE and TFLite. + + For each model, retrieves the lowest latency configuration from both IREE and TFLite. + + Args: + dataframe: The raw data to summarize. + title: The title of the table. + + Returns: + An HTML string containing the summarized report. + """ + summary = pd.DataFrame( + columns=[ + _MODEL, + _DATA_TYPE, + _TFLITE_CONFIG, + _IREE_CONFIG, + _TFLITE_LATENCY, + _IREE_LATENCY, + _IREE_VS_TFLITE_LATENCY, + _TFLITE_MEMORY, + _IREE_MEMORY, + _IREE_VS_TFLITE_MEMORY, + ] + ) + + tflite_df = generate_tflite_summary(dataframe) + iree_df = generate_iree_summary(dataframe) + model_list = tflite_df[_MODEL].unique() + + for model in model_list: + tflite_results = tflite_df[tflite_df.model == model] + iree_results = iree_df[iree_df.model == model] + + if tflite_results.empty: + print(f"Warning: No TFLite results found for model {model}") + continue + if iree_results.empty: + print(f"Warning: No IREE results found for model {model}") + continue -def generate_detail(dataframe, title, platform): - """Generates a table comparing latencies and memory usage between IREE and TFLite. - - The table generated is more detailed than `generate_summary`. It lists latencies - of all IREE configurations, using the fastest TFLite configuration as baseline. - - Args: - dataframe: The raw data to summarize. - title: The title of the table. - platform: Either `server` or `mobile`. - - Returns: - An HTML string containing the detailed report. - """ - summary = pd.DataFrame(columns=[ - _MODEL, _DATA_TYPE, _TFLITE_CONFIG, _IREE_CONFIG, _TASKSET, _THREADS, - _TFLITE_LATENCY, _IREE_LATENCY, _IREE_VS_TFLITE_LATENCY, _TFLITE_MEMORY, - _IREE_MEMORY, _IREE_VS_TFLITE_MEMORY - ]) - - model_list = get_tflite_model_list(dataframe) - for model in model_list: - df = dataframe[dataframe.model.str.startswith(model)] - # If result does not use FP16, remove FP16 results from dataframe to - # maintain apples-to-apples comparisons. - if not model.endswith("fp16"): - df = df[~df.model.str.endswith("fp16")] + iree_latency = iree_results[_LATENCY].iloc[0] + tflite_latency = tflite_results[_LATENCY].iloc[0] + latency_comparison = html_utils.format_latency_comparison( + iree_latency, tflite_latency + ) + + iree_memory = iree_results[_MEMORY].iloc[0] + tflite_memory = tflite_results[_MEMORY].iloc[0] + memory_comparison = html_utils.format_memory_comparison( + iree_memory, tflite_memory + ) + + iree_config = iree_results.config.iloc[0] + tflite_config = tflite_results.config.iloc[0] + summary.loc[len(summary)] = [ + model, + _MODEL_TO_DATA_TYPE[model], + tflite_config, + iree_config, + f"{tflite_latency:.1f}", + f"{iree_latency:.1f}", + latency_comparison, + f"{tflite_memory:,.0f}", + f"{iree_memory:,.0f}", + memory_comparison, + ] + + summary = summary.round(2) + st = get_common_html_style(summary, title) + st = st.set_properties( + subset=_CONFIG_COLUMNS, + **{ + "width": "300px", + "text-align": "left", + }, + ) + return st.to_html().replace("\\n", "
") + "
" - if _TASKSET in df.columns: - tasksets = df.taskset.unique() - else: - tasksets = ["none"] - - for taskset in tasksets: - per_taskset_df = df if taskset == "none" else df[df.taskset == taskset] - threads = per_taskset_df.threads.unique() - for thread in threads: - per_thread_df = per_taskset_df[per_taskset_df.threads == thread] - tflite_df = get_fastest_result( - model, per_thread_df[per_thread_df.runtime == "tflite"]) - if tflite_df.empty: - continue - - tflite_latency = tflite_df[_LATENCY].iloc[0] - tflite_memory = tflite_df[_MEMORY].iloc[0] - if tflite_latency == 0 or tflite_memory == 0: - continue - - full_model_name = tflite_df.model.iloc[0] - # For TFLite config, we only want to know if XNNPack was used. The other - # configuration settings are covered in other columns. - tflite_config = "no xnnpack" if full_model_name.endswith( - "noxnn") else "xnnpack" - - iree_df = per_thread_df[per_thread_df.runtime == "iree"] - for _, row in iree_df.iterrows(): - iree_config = row[_DRIVER] - model_name = row[_MODEL] - if model_name.endswith("im2col_mmt4d"): - iree_config += ", im2col, mmt4d" - elif model_name.endswith("mmt4d"): - iree_config += ", mmt4d" - elif model_name.endswith("padfuse"): - iree_config += ", fused pad" - - iree_latency = row[_LATENCY] - latency_comparison = html_utils.format_latency_comparison( - iree_latency, tflite_latency) - iree_memory = row[_MEMORY] - memory_comparison = html_utils.format_memory_comparison( - iree_memory, tflite_memory) - - if iree_latency == 0 or iree_memory == 0: - continue - summary.loc[len(summary)] = [ - model, _MODEL_TO_DATA_TYPE[model], tflite_config, iree_config, - taskset, thread, f"{tflite_latency:.1f}", f"{iree_latency:.1f}", - latency_comparison, f"{tflite_memory:,.0f}", - f"{iree_memory:,.0f}", memory_comparison - ] - - summary = summary.round(2) - st = get_common_html_style(summary, title) - st = st.set_properties(subset=[_TASKSET, _THREADS], - **{ - "width": "100", - "text-align": "center", - }) - st = st.set_properties(subset=[_TFLITE_CONFIG], - **{ - "width": "150px", - "text-align": "left", - }) - st = st.set_properties(subset=[_IREE_CONFIG], - **{ - "width": "300px", - "text-align": "left", - }) - if platform != "mobile": - st.hide_columns(subset=[_TASKSET]) - - return st.to_html().replace("\\n", "
") + "
" +def generate_detail(dataframe, title, platform): + """Generates a table comparing latencies and memory usage between IREE and TFLite. + + The table generated is more detailed than `generate_summary`. It lists latencies + of all IREE configurations, using the fastest TFLite configuration as baseline. + + Args: + dataframe: The raw data to summarize. + title: The title of the table. + platform: Either `server` or `mobile`. + + Returns: + An HTML string containing the detailed report. + """ + summary = pd.DataFrame( + columns=[ + _MODEL, + _DATA_TYPE, + _TFLITE_CONFIG, + _IREE_CONFIG, + _TASKSET, + _THREADS, + _TFLITE_LATENCY, + _IREE_LATENCY, + _IREE_VS_TFLITE_LATENCY, + _TFLITE_MEMORY, + _IREE_MEMORY, + _IREE_VS_TFLITE_MEMORY, + ] + ) + + model_list = get_tflite_model_list(dataframe) + for model in model_list: + df = dataframe[dataframe.model.str.startswith(model)] + # If result does not use FP16, remove FP16 results from dataframe to + # maintain apples-to-apples comparisons. + if not model.endswith("fp16"): + df = df[~df.model.str.endswith("fp16")] + + if _TASKSET in df.columns: + tasksets = df.taskset.unique() + else: + tasksets = ["none"] + + for taskset in tasksets: + per_taskset_df = df if taskset == "none" else df[df.taskset == taskset] + threads = per_taskset_df.threads.unique() + for thread in threads: + per_thread_df = per_taskset_df[per_taskset_df.threads == thread] + tflite_df = get_fastest_result( + model, per_thread_df[per_thread_df.runtime == "tflite"] + ) + if tflite_df.empty: + continue + + tflite_latency = tflite_df[_LATENCY].iloc[0] + tflite_memory = tflite_df[_MEMORY].iloc[0] + if tflite_latency == 0 or tflite_memory == 0: + continue + + full_model_name = tflite_df.model.iloc[0] + # For TFLite config, we only want to know if XNNPack was used. The other + # configuration settings are covered in other columns. + tflite_config = ( + "no xnnpack" if full_model_name.endswith("noxnn") else "xnnpack" + ) + + iree_df = per_thread_df[per_thread_df.runtime == "iree"] + for _, row in iree_df.iterrows(): + iree_config = row[_DRIVER] + model_name = row[_MODEL] + if model_name.endswith("im2col_mmt4d"): + iree_config += ", im2col, mmt4d" + elif model_name.endswith("mmt4d"): + iree_config += ", mmt4d" + elif model_name.endswith("padfuse"): + iree_config += ", fused pad" + + iree_latency = row[_LATENCY] + latency_comparison = html_utils.format_latency_comparison( + iree_latency, tflite_latency + ) + iree_memory = row[_MEMORY] + memory_comparison = html_utils.format_memory_comparison( + iree_memory, tflite_memory + ) + + if iree_latency == 0 or iree_memory == 0: + continue + + summary.loc[len(summary)] = [ + model, + _MODEL_TO_DATA_TYPE[model], + tflite_config, + iree_config, + taskset, + thread, + f"{tflite_latency:.1f}", + f"{iree_latency:.1f}", + latency_comparison, + f"{tflite_memory:,.0f}", + f"{iree_memory:,.0f}", + memory_comparison, + ] + + summary = summary.round(2) + st = get_common_html_style(summary, title) + st = st.set_properties( + subset=[_TASKSET, _THREADS], + **{ + "width": "100", + "text-align": "center", + }, + ) + st = st.set_properties( + subset=[_TFLITE_CONFIG], + **{ + "width": "150px", + "text-align": "left", + }, + ) + st = st.set_properties( + subset=[_IREE_CONFIG], + **{ + "width": "300px", + "text-align": "left", + }, + ) + if platform != "mobile": + st.hide_columns(subset=[_TASKSET]) + + return st.to_html().replace("\\n", "
") + "
" def main(args): - """Summarizes IREE vs TFLite benchmark results.""" - if args.platform == _PLATFORM_SERVER: - cpu_drivers = ["cpu", "local-task"] - gpu_drivers = ["gpu", "cuda"] - else: - cpu_drivers = ["cpu", "local-task"] - gpu_drivers = ["gpu", "vulkan", "adreno"] + """Summarizes IREE vs TFLite benchmark results.""" + if args.platform == _PLATFORM_SERVER: + cpu_drivers = ["cpu", "local-task"] + gpu_drivers = ["gpu", "cuda"] + else: + cpu_drivers = ["cpu", "local-task"] + gpu_drivers = ["gpu", "vulkan", "adreno"] - version_html = (f"IREE version: {args.iree_version}
" - f"TFlite version: {args.tflite_version}
" - f"last updated: {date.today().isoformat()}

") - html = html_utils.generate_header_and_legend(version_html) + version_html = ( + f"IREE version: {args.iree_version}
" + f"TFlite version: {args.tflite_version}
" + f"last updated: {date.today().isoformat()}

" + ) + html = html_utils.generate_header_and_legend(version_html) - df = pd.read_csv(args.input_csv) + df = pd.read_csv(args.input_csv) - # Generate CPU Summary. - results = df[df[_DRIVER].isin(cpu_drivers)] - html += generate_summary(results, args.platform.capitalize() + " CPU Summary") + # Generate CPU Summary. + results = df[df[_DRIVER].isin(cpu_drivers)] + html += generate_summary(results, args.platform.capitalize() + " CPU Summary") - # Generate GPU Summary. - results = df[df[_DRIVER].isin(gpu_drivers)] - html += generate_summary(results, args.platform.capitalize() + " GPU Summary") + # Generate GPU Summary. + results = df[df[_DRIVER].isin(gpu_drivers)] + html += generate_summary(results, args.platform.capitalize() + " GPU Summary") - # Generate CPU Detailed View. - results = df[df[_DRIVER].isin(cpu_drivers)] - html += generate_detail(results, - args.platform.capitalize() + " CPU Detailed", - args.platform) + # Generate CPU Detailed View. + results = df[df[_DRIVER].isin(cpu_drivers)] + html += generate_detail( + results, args.platform.capitalize() + " CPU Detailed", args.platform + ) - # Generate GPU Detailed View. - results = df[df[_DRIVER].isin(gpu_drivers)] - html += generate_detail(results, - args.platform.capitalize() + " GPU Detailed", - args.platform) + # Generate GPU Detailed View. + results = df[df[_DRIVER].isin(gpu_drivers)] + html += generate_detail( + results, args.platform.capitalize() + " GPU Detailed", args.platform + ) - args.output_path.write_text(html) + args.output_path.write_text(html) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--iree_version", - type=str, - default=None, - required=True, - help="The IREE version.") - parser.add_argument("--tflite_version", - type=str, - default=None, - required=True, - help="The TFLite version.") - parser.add_argument( - "--platform", - action="store", - type=str.lower, - help= - "The platform the models were benchmarked on. Either server or mobile.", - required=True, - choices=[_PLATFORM_SERVER, _PLATFORM_MOBILE]) - parser.add_argument( - "--input_csv", - type=str, - default=None, - help= - "The path to the csv file containing benchmark results for both IREE and TFLite." - ) - parser.add_argument( - "--output_path", - type=pathlib.Path, - default="/tmp/summary.html", - help="The path to the output html file that summarizes results.") - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "--iree_version", + type=str, + default=None, + required=True, + help="The IREE version.", + ) + parser.add_argument( + "--tflite_version", + type=str, + default=None, + required=True, + help="The TFLite version.", + ) + parser.add_argument( + "--platform", + action="store", + type=str.lower, + help="The platform the models were benchmarked on. Either server or mobile.", + required=True, + choices=[_PLATFORM_SERVER, _PLATFORM_MOBILE], + ) + parser.add_argument( + "--input_csv", + type=str, + default=None, + help="The path to the csv file containing benchmark results for both IREE and TFLite.", + ) + parser.add_argument( + "--output_path", + type=pathlib.Path, + default="/tmp/summary.html", + help="The path to the output html file that summarizes results.", + ) + return parser.parse_args() if __name__ == "__main__": - main(parse_args()) + main(parse_args()) diff --git a/build_tools/benchmarks/run_benchmarks_on_android.py b/build_tools/benchmarks/run_benchmarks_on_android.py index 5ffccd07f29f..995b80e2e336 100755 --- a/build_tools/benchmarks/run_benchmarks_on_android.py +++ b/build_tools/benchmarks/run_benchmarks_on_android.py @@ -45,15 +45,21 @@ from common.benchmark_config import BenchmarkConfig from common.benchmark_driver import BenchmarkDriver from common.benchmark_definition import ( - DriverInfo, execute_cmd, execute_cmd_and_get_stdout, - execute_cmd_and_get_output, get_git_commit_hash, - get_iree_benchmark_module_arguments, wait_for_iree_benchmark_module_start, - parse_iree_benchmark_metrics) -from common.benchmark_suite import (MODEL_FLAGFILE_NAME, BenchmarkCase, - BenchmarkSuite) -from common.android_device_utils import (get_android_device_model, - get_android_device_info, - get_android_gpu_name) + DriverInfo, + execute_cmd, + execute_cmd_and_get_stdout, + execute_cmd_and_get_output, + get_git_commit_hash, + get_iree_benchmark_module_arguments, + wait_for_iree_benchmark_module_start, + parse_iree_benchmark_metrics, +) +from common.benchmark_suite import MODEL_FLAGFILE_NAME, BenchmarkCase, BenchmarkSuite +from common.android_device_utils import ( + get_android_device_model, + get_android_device_info, + get_android_gpu_name, +) import common.common_arguments from e2e_test_artifacts import iree_artifacts from e2e_test_framework import serialization @@ -70,349 +76,395 @@ def adb_push_to_tmp_dir( content: pathlib.Path, relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), - verbose: bool = False) -> pathlib.PurePosixPath: - """Pushes content onto the Android device. - - Args: - content: the full path to the source file. - relative_dir: the directory to push to; relative to ANDROID_TMPDIR. - - Returns: - The full path to the content on the Android device. - """ - filename = content.name - android_path = ANDROID_TMPDIR / relative_dir / filename - # When the output is a TTY, keep the default progress info output. - # In other cases, redirect progress info to null to avoid bloating log files. - stdout_redirect = None if sys.stdout.isatty() else subprocess.DEVNULL - execute_cmd(["adb", "push", content.resolve(), android_path], - verbose=verbose, - stdout=stdout_redirect) - return android_path + verbose: bool = False, +) -> pathlib.PurePosixPath: + """Pushes content onto the Android device. + + Args: + content: the full path to the source file. + relative_dir: the directory to push to; relative to ANDROID_TMPDIR. + + Returns: + The full path to the content on the Android device. + """ + filename = content.name + android_path = ANDROID_TMPDIR / relative_dir / filename + # When the output is a TTY, keep the default progress info output. + # In other cases, redirect progress info to null to avoid bloating log files. + stdout_redirect = None if sys.stdout.isatty() else subprocess.DEVNULL + execute_cmd( + ["adb", "push", content.resolve(), android_path], + verbose=verbose, + stdout=stdout_redirect, + ) + return android_path def adb_execute_and_get_output( cmd_args: Sequence[str], relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), - verbose: bool = False) -> Tuple[str, str]: - """Executes command with adb shell. + verbose: bool = False, +) -> Tuple[str, str]: + """Executes command with adb shell. - Switches to `relative_dir` relative to the android tmp directory before - executing. Waits for completion and returns the command stdout. + Switches to `relative_dir` relative to the android tmp directory before + executing. Waits for completion and returns the command stdout. - Args: - cmd_args: a list containing the command to execute and its parameters - relative_dir: the directory to execute the command in; relative to - ANDROID_TMPDIR. + Args: + cmd_args: a list containing the command to execute and its parameters + relative_dir: the directory to execute the command in; relative to + ANDROID_TMPDIR. - Returns: - Strings for stdout and stderr. - """ - cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] - cmd.extend(cmd_args) - return execute_cmd_and_get_output(cmd, verbose=verbose) + Returns: + Strings for stdout and stderr. + """ + cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] + cmd.extend(cmd_args) + return execute_cmd_and_get_output(cmd, verbose=verbose) -def adb_execute(cmd_args: Sequence[str], - relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), - verbose: bool = False) -> subprocess.CompletedProcess: - """Executes command with adb shell. +def adb_execute( + cmd_args: Sequence[str], + relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), + verbose: bool = False, +) -> subprocess.CompletedProcess: + """Executes command with adb shell. - Switches to `relative_dir` relative to the android tmp directory before - executing. Waits for completion. Output is streamed to the terminal. + Switches to `relative_dir` relative to the android tmp directory before + executing. Waits for completion. Output is streamed to the terminal. - Args: - cmd_args: a list containing the command to execute and its parameters - relative_dir: the directory to execute the command in; relative to - ANDROID_TMPDIR. + Args: + cmd_args: a list containing the command to execute and its parameters + relative_dir: the directory to execute the command in; relative to + ANDROID_TMPDIR. - Returns: - The completed process. - """ - cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] - cmd.extend(cmd_args) - return execute_cmd(cmd, verbose=verbose) + Returns: + The completed process. + """ + cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] + cmd.extend(cmd_args) + return execute_cmd(cmd, verbose=verbose) def is_magisk_su(): - """Returns true if the Android device has a Magisk SU binary.""" - stdout, _ = adb_execute_and_get_output(["su", "--help"]) - return "MagiskSU" in stdout + """Returns true if the Android device has a Magisk SU binary.""" + stdout, _ = adb_execute_and_get_output(["su", "--help"]) + return "MagiskSU" in stdout def adb_execute_as_root(cmd_args: Sequence[Any]) -> subprocess.CompletedProcess: - """Executes the given command as root.""" - cmd = ["su", "-c" if is_magisk_su() else "root"] - cmd.extend(cmd_args) - return adb_execute(cmd) + """Executes the given command as root.""" + cmd = ["su", "-c" if is_magisk_su() else "root"] + cmd.extend(cmd_args) + return adb_execute(cmd) -def adb_start_cmd(cmd_args: Sequence[str], - relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), - verbose: bool = False) -> subprocess.Popen: - """Executes command with adb shell in a directory and returns the handle - without waiting for completion. +def adb_start_cmd( + cmd_args: Sequence[str], + relative_dir: pathlib.PurePosixPath = pathlib.PurePosixPath(), + verbose: bool = False, +) -> subprocess.Popen: + """Executes command with adb shell in a directory and returns the handle + without waiting for completion. - Args: - cmd_args: a list containing the command to execute and its parameters - relative_dir: the directory to execute the command in; relative to - ANDROID_TMPDIR. + Args: + cmd_args: a list containing the command to execute and its parameters + relative_dir: the directory to execute the command in; relative to + ANDROID_TMPDIR. - Returns: - A Popen object for the started command. - """ - cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] - cmd.extend(cmd_args) + Returns: + A Popen object for the started command. + """ + cmd = ["adb", "shell", "cd", ANDROID_TMPDIR / relative_dir, "&&"] + cmd.extend(cmd_args) - if verbose: - print(f"cmd: {cmd}") - return subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) + if verbose: + print(f"cmd: {cmd}") + return subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) def get_vmfb_full_path_for_benchmark_case( - benchmark_case_dir: pathlib.Path) -> pathlib.Path: - flagfile = benchmark_case_dir / MODEL_FLAGFILE_NAME - for line in flagfile.read_text().splitlines(): - flag_name, flag_value = line.strip().split("=") - if flag_name == "--module": - # Realpath canonicalization matters. The caller may rely on that to track - # which files it already pushed. - return (benchmark_case_dir / flag_value).resolve() - raise ValueError(f"{flagfile} does not contain a --module flag") + benchmark_case_dir: pathlib.Path, +) -> pathlib.Path: + flagfile = benchmark_case_dir / MODEL_FLAGFILE_NAME + for line in flagfile.read_text().splitlines(): + flag_name, flag_value = line.strip().split("=") + if flag_name == "--module": + # Realpath canonicalization matters. The caller may rely on that to track + # which files it already pushed. + return (benchmark_case_dir / flag_value).resolve() + raise ValueError(f"{flagfile} does not contain a --module flag") class AndroidBenchmarkDriver(BenchmarkDriver): - """Android benchmark driver.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.already_pushed_files = {} - - def run_benchmark_case(self, benchmark_case: BenchmarkCase, - benchmark_results_filename: Optional[pathlib.Path], - capture_filename: Optional[pathlib.Path]) -> None: - benchmark_case_dir = benchmark_case.benchmark_case_dir - android_case_dir = pathlib.PurePosixPath( - benchmark_case_dir.relative_to(self.config.root_benchmark_dir)) - - run_config = benchmark_case.run_config - self.__check_and_push_file( - benchmark_case_dir / iree_artifacts.MODULE_FILENAME, android_case_dir) - taskset = self.__deduce_taskset_from_run_config(run_config) - run_args = run_config.materialize_run_flags() - run_args.append(f"--module={iree_artifacts.MODULE_FILENAME}") - - if benchmark_results_filename is not None: - self.__run_benchmark(android_case_dir=android_case_dir, - tool_name=benchmark_case.benchmark_tool_name, - driver_info=benchmark_case.driver_info, - run_args=run_args, - results_filename=benchmark_results_filename, - taskset=taskset) - - if capture_filename is not None: - self.__run_capture(android_case_dir=android_case_dir, - tool_name=benchmark_case.benchmark_tool_name, - run_args=run_args, - capture_filename=capture_filename, - taskset=taskset) - - def __run_benchmark(self, android_case_dir: pathlib.PurePosixPath, - tool_name: str, driver_info: DriverInfo, - run_args: Sequence[str], results_filename: pathlib.Path, - taskset: str): - if self.config.normal_benchmark_tool_dir is None: - raise ValueError("normal_benchmark_tool_dir can't be None.") - - host_tool_path = self.config.normal_benchmark_tool_dir / tool_name - android_tool = self.__check_and_push_file(host_tool_path, - NORMAL_TOOL_REL_DIR) - cmd = ["taskset", taskset, android_tool] - cmd += run_args - if tool_name == "iree-benchmark-module": - cmd += get_iree_benchmark_module_arguments( - results_filename=f"'{results_filename.name}'", - driver_info=driver_info, - benchmark_min_time=self.config.benchmark_min_time) - - benchmark_stdout, benchmark_stderr = adb_execute_and_get_output( - cmd, android_case_dir, verbose=self.verbose) - benchmark_metrics = parse_iree_benchmark_metrics(benchmark_stdout, - benchmark_stderr) - if self.verbose: - print(benchmark_metrics) - results_filename.write_text(json.dumps(benchmark_metrics.to_json_object())) - - def __run_capture(self, android_case_dir: pathlib.PurePosixPath, - tool_name: str, capture_filename: pathlib.Path, - run_args: Sequence[str], taskset: str): - capture_config = self.config.trace_capture_config - if capture_config is None: - raise ValueError("capture_config can't be None.") - - host_tool_path = capture_config.traced_benchmark_tool_dir / tool_name - android_tool = self.__check_and_push_file(host_tool_path, - TRACED_TOOL_REL_DIR) - run_cmd = [ - "TRACY_NO_EXIT=1", f"IREE_PRESERVE_DYLIB_TEMP_FILES={ANDROID_TMPDIR}", - "taskset", taskset, android_tool - ] - run_cmd += run_args - - # Just launch the traced benchmark tool with TRACY_NO_EXIT=1 without - # waiting for the adb command to complete as that won't happen. - process = adb_start_cmd(run_cmd, android_case_dir, verbose=self.verbose) - - wait_for_iree_benchmark_module_start(process, self.verbose) - - # Now it's okay to collect the trace via the capture tool. This will - # send the signal to let the previously waiting benchmark tool to - # complete. - capture_cmd = [ - capture_config.trace_capture_tool, "-f", "-o", capture_filename - ] - # If verbose, just let the subprocess print its output. The subprocess - # may need to detect if the output is a TTY to decide whether to log - # verbose progress info and use ANSI colors, so it's better to use - # stdout redirection than to capture the output in a string. - stdout_redirect = None if self.verbose else subprocess.DEVNULL - execute_cmd(capture_cmd, verbose=self.verbose, stdout=stdout_redirect) - - # TODO(#13187): These logics are inherited from the legacy benchmark suites, - # which only work for a few specific phones. We should define the topology - # in their device specs. - def __deduce_taskset_from_run_config( - self, run_config: iree_definitions.E2EModelRunConfig) -> str: - """Deduces the CPU mask according to device and execution config.""" - - device_spec = run_config.target_device_spec - # For GPU benchmarks, use the most performant core. - if device_spec.architecture.type == common_definitions.ArchitectureType.GPU: - return "80" - - device_params = device_spec.device_parameters - single_thread = "1-thread" in run_config.module_execution_config.tags - if device_parameters.ARM_BIG_CORES in device_params: - return "80" if single_thread else "f0" - elif device_parameters.ARM_LITTLE_CORES in device_params: - return "08" if single_thread else "0f" - - raise ValueError(f"Unsupported config to deduce taskset: '{run_config}'.") - - def __check_and_push_file(self, host_path: pathlib.Path, - relative_dir: pathlib.PurePosixPath): - """Checks if the file has been pushed and pushes it if not.""" - android_path = self.already_pushed_files.get(host_path) - if android_path is not None: - return android_path - - android_path = adb_push_to_tmp_dir(host_path, - relative_dir=relative_dir, - verbose=self.verbose) - self.already_pushed_files[host_path] = android_path - return android_path + """Android benchmark driver.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.already_pushed_files = {} + + def run_benchmark_case( + self, + benchmark_case: BenchmarkCase, + benchmark_results_filename: Optional[pathlib.Path], + capture_filename: Optional[pathlib.Path], + ) -> None: + benchmark_case_dir = benchmark_case.benchmark_case_dir + android_case_dir = pathlib.PurePosixPath( + benchmark_case_dir.relative_to(self.config.root_benchmark_dir) + ) + + run_config = benchmark_case.run_config + self.__check_and_push_file( + benchmark_case_dir / iree_artifacts.MODULE_FILENAME, android_case_dir + ) + taskset = self.__deduce_taskset_from_run_config(run_config) + run_args = run_config.materialize_run_flags() + run_args.append(f"--module={iree_artifacts.MODULE_FILENAME}") + + if benchmark_results_filename is not None: + self.__run_benchmark( + android_case_dir=android_case_dir, + tool_name=benchmark_case.benchmark_tool_name, + driver_info=benchmark_case.driver_info, + run_args=run_args, + results_filename=benchmark_results_filename, + taskset=taskset, + ) + + if capture_filename is not None: + self.__run_capture( + android_case_dir=android_case_dir, + tool_name=benchmark_case.benchmark_tool_name, + run_args=run_args, + capture_filename=capture_filename, + taskset=taskset, + ) + + def __run_benchmark( + self, + android_case_dir: pathlib.PurePosixPath, + tool_name: str, + driver_info: DriverInfo, + run_args: Sequence[str], + results_filename: pathlib.Path, + taskset: str, + ): + if self.config.normal_benchmark_tool_dir is None: + raise ValueError("normal_benchmark_tool_dir can't be None.") + + host_tool_path = self.config.normal_benchmark_tool_dir / tool_name + android_tool = self.__check_and_push_file(host_tool_path, NORMAL_TOOL_REL_DIR) + cmd = ["taskset", taskset, android_tool] + cmd += run_args + if tool_name == "iree-benchmark-module": + cmd += get_iree_benchmark_module_arguments( + results_filename=f"'{results_filename.name}'", + driver_info=driver_info, + benchmark_min_time=self.config.benchmark_min_time, + ) + + benchmark_stdout, benchmark_stderr = adb_execute_and_get_output( + cmd, android_case_dir, verbose=self.verbose + ) + benchmark_metrics = parse_iree_benchmark_metrics( + benchmark_stdout, benchmark_stderr + ) + if self.verbose: + print(benchmark_metrics) + results_filename.write_text(json.dumps(benchmark_metrics.to_json_object())) + + def __run_capture( + self, + android_case_dir: pathlib.PurePosixPath, + tool_name: str, + capture_filename: pathlib.Path, + run_args: Sequence[str], + taskset: str, + ): + capture_config = self.config.trace_capture_config + if capture_config is None: + raise ValueError("capture_config can't be None.") + + host_tool_path = capture_config.traced_benchmark_tool_dir / tool_name + android_tool = self.__check_and_push_file(host_tool_path, TRACED_TOOL_REL_DIR) + run_cmd = [ + "TRACY_NO_EXIT=1", + f"IREE_PRESERVE_DYLIB_TEMP_FILES={ANDROID_TMPDIR}", + "taskset", + taskset, + android_tool, + ] + run_cmd += run_args + + # Just launch the traced benchmark tool with TRACY_NO_EXIT=1 without + # waiting for the adb command to complete as that won't happen. + process = adb_start_cmd(run_cmd, android_case_dir, verbose=self.verbose) + + wait_for_iree_benchmark_module_start(process, self.verbose) + + # Now it's okay to collect the trace via the capture tool. This will + # send the signal to let the previously waiting benchmark tool to + # complete. + capture_cmd = [capture_config.trace_capture_tool, "-f", "-o", capture_filename] + # If verbose, just let the subprocess print its output. The subprocess + # may need to detect if the output is a TTY to decide whether to log + # verbose progress info and use ANSI colors, so it's better to use + # stdout redirection than to capture the output in a string. + stdout_redirect = None if self.verbose else subprocess.DEVNULL + execute_cmd(capture_cmd, verbose=self.verbose, stdout=stdout_redirect) + + # TODO(#13187): These logics are inherited from the legacy benchmark suites, + # which only work for a few specific phones. We should define the topology + # in their device specs. + def __deduce_taskset_from_run_config( + self, run_config: iree_definitions.E2EModelRunConfig + ) -> str: + """Deduces the CPU mask according to device and execution config.""" + + device_spec = run_config.target_device_spec + # For GPU benchmarks, use the most performant core. + if device_spec.architecture.type == common_definitions.ArchitectureType.GPU: + return "80" + + device_params = device_spec.device_parameters + single_thread = "1-thread" in run_config.module_execution_config.tags + if device_parameters.ARM_BIG_CORES in device_params: + return "80" if single_thread else "f0" + elif device_parameters.ARM_LITTLE_CORES in device_params: + return "08" if single_thread else "0f" + + raise ValueError(f"Unsupported config to deduce taskset: '{run_config}'.") + + def __check_and_push_file( + self, host_path: pathlib.Path, relative_dir: pathlib.PurePosixPath + ): + """Checks if the file has been pushed and pushes it if not.""" + android_path = self.already_pushed_files.get(host_path) + if android_path is not None: + return android_path + + android_path = adb_push_to_tmp_dir( + host_path, relative_dir=relative_dir, verbose=self.verbose + ) + self.already_pushed_files[host_path] = android_path + return android_path def set_cpu_frequency_scaling_governor(governor: str): - git_root = execute_cmd_and_get_stdout(["git", "rev-parse", "--show-toplevel"]) - cpu_script = (pathlib.Path(git_root) / "build_tools" / "benchmarks" / - "set_android_scaling_governor.sh") - android_path = adb_push_to_tmp_dir(cpu_script) - adb_execute_as_root([android_path, governor]) + git_root = execute_cmd_and_get_stdout(["git", "rev-parse", "--show-toplevel"]) + cpu_script = ( + pathlib.Path(git_root) + / "build_tools" + / "benchmarks" + / "set_android_scaling_governor.sh" + ) + android_path = adb_push_to_tmp_dir(cpu_script) + adb_execute_as_root([android_path, governor]) def set_gpu_frequency_scaling_policy(policy: str): - git_root = execute_cmd_and_get_stdout(["git", "rev-parse", "--show-toplevel"]) - device_model = get_android_device_model() - gpu_name = get_android_gpu_name() - benchmarks_tool_dir = pathlib.Path(git_root) / "build_tools" / "benchmarks" - if device_model == "Pixel-6" or device_model == "Pixel-6-Pro": - gpu_script = benchmarks_tool_dir / "set_pixel6_gpu_scaling_policy.sh" - elif gpu_name.lower().startswith("adreno"): - gpu_script = benchmarks_tool_dir / "set_adreno_gpu_scaling_policy.sh" - else: - raise RuntimeError( - f"Unsupported device '{device_model}' for setting GPU scaling policy") - android_path = adb_push_to_tmp_dir(gpu_script) - adb_execute_as_root([android_path, policy]) + git_root = execute_cmd_and_get_stdout(["git", "rev-parse", "--show-toplevel"]) + device_model = get_android_device_model() + gpu_name = get_android_gpu_name() + benchmarks_tool_dir = pathlib.Path(git_root) / "build_tools" / "benchmarks" + if device_model == "Pixel-6" or device_model == "Pixel-6-Pro": + gpu_script = benchmarks_tool_dir / "set_pixel6_gpu_scaling_policy.sh" + elif gpu_name.lower().startswith("adreno"): + gpu_script = benchmarks_tool_dir / "set_adreno_gpu_scaling_policy.sh" + else: + raise RuntimeError( + f"Unsupported device '{device_model}' for setting GPU scaling policy" + ) + android_path = adb_push_to_tmp_dir(gpu_script) + adb_execute_as_root([android_path, policy]) def main(args): - device_info = get_android_device_info(args.verbose) - if args.verbose: - print(device_info) - - commit = get_git_commit_hash("HEAD") - benchmark_config = BenchmarkConfig.build_from_args(args, commit) - benchmark_groups = json.loads(args.execution_benchmark_config.read_text()) - benchmark_group = benchmark_groups.get(args.target_device_name) - if benchmark_group is None: - raise ValueError("Target device not found in the benchmark config.") - run_configs = serialization.unpack_and_deserialize( - data=benchmark_group["run_configs"], - root_type=List[iree_definitions.E2EModelRunConfig]) - benchmark_suite = BenchmarkSuite.load_from_run_configs( - run_configs=run_configs, - root_benchmark_dir=benchmark_config.root_benchmark_dir) - - benchmark_driver = AndroidBenchmarkDriver(device_info=device_info, - benchmark_config=benchmark_config, - benchmark_suite=benchmark_suite, - benchmark_grace_time=1.0, - verbose=args.verbose) - - if args.pin_cpu_freq: - set_cpu_frequency_scaling_governor("performance") - atexit.register(set_cpu_frequency_scaling_governor, "schedutil") - if args.pin_gpu_freq: - set_gpu_frequency_scaling_policy("performance") - atexit.register(set_gpu_frequency_scaling_policy, "default") - - # Clear the benchmark directory on the Android device first just in case - # there are leftovers from manual or failed runs. - execute_cmd_and_get_stdout(["adb", "shell", "rm", "-rf", ANDROID_TMPDIR], - verbose=args.verbose) - - if not args.no_clean: - # Clear the benchmark directory on the Android device. - atexit.register(execute_cmd_and_get_stdout, - ["adb", "shell", "rm", "-rf", ANDROID_TMPDIR], - verbose=args.verbose) - # Also clear temporary directory on the host device. - atexit.register(shutil.rmtree, args.tmp_dir) - - # Tracy client and server communicate over port 8086 by default. If we want - # to capture traces along the way, forward port via adb. - trace_capture_config = benchmark_config.trace_capture_config - if trace_capture_config: - execute_cmd_and_get_stdout(["adb", "forward", "tcp:8086", "tcp:8086"], - verbose=args.verbose) - atexit.register(execute_cmd_and_get_stdout, - ["adb", "forward", "--remove", "tcp:8086"], - verbose=args.verbose) - - benchmark_driver.run() - - benchmark_results = benchmark_driver.get_benchmark_results() - if args.output is not None: - with open(args.output, "w") as f: - f.write(benchmark_results.to_json_str()) - - if args.verbose: - print(benchmark_results.commit) - print(benchmark_results.benchmarks) - - if trace_capture_config: - # Put all captures in a tarball and remove the original files. - with tarfile.open(trace_capture_config.capture_tarball, "w:gz") as tar: - for capture_filename in benchmark_driver.get_capture_filenames(): - tar.add(capture_filename) - - benchmark_errors = benchmark_driver.get_benchmark_errors() - if benchmark_errors: - print("Benchmarking completed with errors", file=sys.stderr) - raise RuntimeError(benchmark_errors) + device_info = get_android_device_info(args.verbose) + if args.verbose: + print(device_info) + + commit = get_git_commit_hash("HEAD") + benchmark_config = BenchmarkConfig.build_from_args(args, commit) + benchmark_groups = json.loads(args.execution_benchmark_config.read_text()) + benchmark_group = benchmark_groups.get(args.target_device_name) + if benchmark_group is None: + raise ValueError("Target device not found in the benchmark config.") + run_configs = serialization.unpack_and_deserialize( + data=benchmark_group["run_configs"], + root_type=List[iree_definitions.E2EModelRunConfig], + ) + benchmark_suite = BenchmarkSuite.load_from_run_configs( + run_configs=run_configs, root_benchmark_dir=benchmark_config.root_benchmark_dir + ) + + benchmark_driver = AndroidBenchmarkDriver( + device_info=device_info, + benchmark_config=benchmark_config, + benchmark_suite=benchmark_suite, + benchmark_grace_time=1.0, + verbose=args.verbose, + ) + + if args.pin_cpu_freq: + set_cpu_frequency_scaling_governor("performance") + atexit.register(set_cpu_frequency_scaling_governor, "schedutil") + if args.pin_gpu_freq: + set_gpu_frequency_scaling_policy("performance") + atexit.register(set_gpu_frequency_scaling_policy, "default") + + # Clear the benchmark directory on the Android device first just in case + # there are leftovers from manual or failed runs. + execute_cmd_and_get_stdout( + ["adb", "shell", "rm", "-rf", ANDROID_TMPDIR], verbose=args.verbose + ) + + if not args.no_clean: + # Clear the benchmark directory on the Android device. + atexit.register( + execute_cmd_and_get_stdout, + ["adb", "shell", "rm", "-rf", ANDROID_TMPDIR], + verbose=args.verbose, + ) + # Also clear temporary directory on the host device. + atexit.register(shutil.rmtree, args.tmp_dir) + + # Tracy client and server communicate over port 8086 by default. If we want + # to capture traces along the way, forward port via adb. + trace_capture_config = benchmark_config.trace_capture_config + if trace_capture_config: + execute_cmd_and_get_stdout( + ["adb", "forward", "tcp:8086", "tcp:8086"], verbose=args.verbose + ) + atexit.register( + execute_cmd_and_get_stdout, + ["adb", "forward", "--remove", "tcp:8086"], + verbose=args.verbose, + ) + + benchmark_driver.run() + + benchmark_results = benchmark_driver.get_benchmark_results() + if args.output is not None: + with open(args.output, "w") as f: + f.write(benchmark_results.to_json_str()) + + if args.verbose: + print(benchmark_results.commit) + print(benchmark_results.benchmarks) + + if trace_capture_config: + # Put all captures in a tarball and remove the original files. + with tarfile.open(trace_capture_config.capture_tarball, "w:gz") as tar: + for capture_filename in benchmark_driver.get_capture_filenames(): + tar.add(capture_filename) + + benchmark_errors = benchmark_driver.get_benchmark_errors() + if benchmark_errors: + print("Benchmarking completed with errors", file=sys.stderr) + raise RuntimeError(benchmark_errors) if __name__ == "__main__": - main(common.common_arguments.Parser().parse_args()) + main(common.common_arguments.Parser().parse_args()) diff --git a/build_tools/benchmarks/run_benchmarks_on_linux.py b/build_tools/benchmarks/run_benchmarks_on_linux.py index 9554ccc737b2..6820e0f9ec83 100755 --- a/build_tools/benchmarks/run_benchmarks_on_linux.py +++ b/build_tools/benchmarks/run_benchmarks_on_linux.py @@ -23,12 +23,14 @@ from common.benchmark_driver import BenchmarkDriver from common.benchmark_suite import BenchmarkCase, BenchmarkSuite from common.benchmark_config import BenchmarkConfig -from common.benchmark_definition import (execute_cmd, - execute_cmd_and_get_output, - get_git_commit_hash, - get_iree_benchmark_module_arguments, - wait_for_iree_benchmark_module_start, - parse_iree_benchmark_metrics) +from common.benchmark_definition import ( + execute_cmd, + execute_cmd_and_get_output, + get_git_commit_hash, + get_iree_benchmark_module_arguments, + wait_for_iree_benchmark_module_start, + parse_iree_benchmark_metrics, +) from common.linux_device_utils import get_linux_device_info from e2e_test_framework.definitions import iree_definitions from e2e_test_framework import serialization @@ -38,161 +40,174 @@ class LinuxBenchmarkDriver(BenchmarkDriver): - """Linux benchmark driver.""" - - def __init__(self, gpu_id: str, *args, **kwargs): - self.gpu_id = gpu_id - super().__init__(*args, **kwargs) - - def run_benchmark_case(self, benchmark_case: BenchmarkCase, - benchmark_results_filename: Optional[pathlib.Path], - capture_filename: Optional[pathlib.Path]) -> None: - - if benchmark_results_filename: - self.__run_benchmark(benchmark_case=benchmark_case, - results_filename=benchmark_results_filename) - - if capture_filename: - self.__run_capture(benchmark_case=benchmark_case, - capture_filename=capture_filename) - - def __build_tool_cmds(self, benchmark_case: BenchmarkCase, - tool_path: pathlib.Path) -> List[Any]: - run_config = benchmark_case.run_config - cmds: List[Any] = run_module_utils.build_linux_wrapper_cmds_for_device_spec( - run_config.target_device_spec) - cmds.append(tool_path) - - module_dir_path = benchmark_case.benchmark_case_dir - cmds += [f"--module={module_dir_path / iree_artifacts.MODULE_FILENAME}"] - cmds += run_config.materialize_run_flags(gpu_id=self.gpu_id) - - return cmds - - def __run_benchmark(self, benchmark_case: BenchmarkCase, - results_filename: pathlib.Path): - if self.config.normal_benchmark_tool_dir is None: - raise ValueError("normal_benchmark_tool_dir can't be None.") - - tool_name = benchmark_case.benchmark_tool_name - tool_path = self.config.normal_benchmark_tool_dir / tool_name - cmd = self.__build_tool_cmds(benchmark_case=benchmark_case, - tool_path=tool_path) - - if tool_name == "iree-benchmark-module": - cmd.extend( - get_iree_benchmark_module_arguments( - results_filename=str(results_filename), - driver_info=benchmark_case.driver_info, - benchmark_min_time=self.config.benchmark_min_time)) - - benchmark_stdout, benchmark_stderr = execute_cmd_and_get_output( - cmd, verbose=self.verbose) - benchmark_metrics = parse_iree_benchmark_metrics(benchmark_stdout, - benchmark_stderr) - if self.verbose: - print(benchmark_metrics) - results_filename.write_text(json.dumps(benchmark_metrics.to_json_object())) - - def __run_capture(self, benchmark_case: BenchmarkCase, - capture_filename: pathlib.Path): - capture_config = self.config.trace_capture_config - if capture_config is None: - raise ValueError("capture_config can't be None.") - - tool_path = (capture_config.traced_benchmark_tool_dir / - benchmark_case.benchmark_tool_name) - cmd = self.__build_tool_cmds(benchmark_case=benchmark_case, - tool_path=tool_path) - - process = subprocess.Popen(cmd, - env={"TRACY_NO_EXIT": "1"}, - stdout=subprocess.PIPE, - text=True) - - wait_for_iree_benchmark_module_start(process, self.verbose) - - capture_cmd = [ - capture_config.trace_capture_tool, "-f", "-o", capture_filename - ] - stdout_redirect = None if self.verbose else subprocess.DEVNULL - execute_cmd(capture_cmd, verbose=self.verbose, stdout=stdout_redirect) + """Linux benchmark driver.""" + + def __init__(self, gpu_id: str, *args, **kwargs): + self.gpu_id = gpu_id + super().__init__(*args, **kwargs) + + def run_benchmark_case( + self, + benchmark_case: BenchmarkCase, + benchmark_results_filename: Optional[pathlib.Path], + capture_filename: Optional[pathlib.Path], + ) -> None: + if benchmark_results_filename: + self.__run_benchmark( + benchmark_case=benchmark_case, + results_filename=benchmark_results_filename, + ) + + if capture_filename: + self.__run_capture( + benchmark_case=benchmark_case, capture_filename=capture_filename + ) + + def __build_tool_cmds( + self, benchmark_case: BenchmarkCase, tool_path: pathlib.Path + ) -> List[Any]: + run_config = benchmark_case.run_config + cmds: List[Any] = run_module_utils.build_linux_wrapper_cmds_for_device_spec( + run_config.target_device_spec + ) + cmds.append(tool_path) + + module_dir_path = benchmark_case.benchmark_case_dir + cmds += [f"--module={module_dir_path / iree_artifacts.MODULE_FILENAME}"] + cmds += run_config.materialize_run_flags(gpu_id=self.gpu_id) + + return cmds + + def __run_benchmark( + self, benchmark_case: BenchmarkCase, results_filename: pathlib.Path + ): + if self.config.normal_benchmark_tool_dir is None: + raise ValueError("normal_benchmark_tool_dir can't be None.") + + tool_name = benchmark_case.benchmark_tool_name + tool_path = self.config.normal_benchmark_tool_dir / tool_name + cmd = self.__build_tool_cmds(benchmark_case=benchmark_case, tool_path=tool_path) + + if tool_name == "iree-benchmark-module": + cmd.extend( + get_iree_benchmark_module_arguments( + results_filename=str(results_filename), + driver_info=benchmark_case.driver_info, + benchmark_min_time=self.config.benchmark_min_time, + ) + ) + + benchmark_stdout, benchmark_stderr = execute_cmd_and_get_output( + cmd, verbose=self.verbose + ) + benchmark_metrics = parse_iree_benchmark_metrics( + benchmark_stdout, benchmark_stderr + ) + if self.verbose: + print(benchmark_metrics) + results_filename.write_text(json.dumps(benchmark_metrics.to_json_object())) + + def __run_capture( + self, benchmark_case: BenchmarkCase, capture_filename: pathlib.Path + ): + capture_config = self.config.trace_capture_config + if capture_config is None: + raise ValueError("capture_config can't be None.") + + tool_path = ( + capture_config.traced_benchmark_tool_dir + / benchmark_case.benchmark_tool_name + ) + cmd = self.__build_tool_cmds(benchmark_case=benchmark_case, tool_path=tool_path) + + process = subprocess.Popen( + cmd, env={"TRACY_NO_EXIT": "1"}, stdout=subprocess.PIPE, text=True + ) + + wait_for_iree_benchmark_module_start(process, self.verbose) + + capture_cmd = [capture_config.trace_capture_tool, "-f", "-o", capture_filename] + stdout_redirect = None if self.verbose else subprocess.DEVNULL + execute_cmd(capture_cmd, verbose=self.verbose, stdout=stdout_redirect) def main(args): - device_info = get_linux_device_info(args.device_model, args.cpu_uarch, - args.gpu_id, args.verbose) - if args.verbose: - print(device_info) - - commit = get_git_commit_hash("HEAD") - benchmark_config = BenchmarkConfig.build_from_args(args, commit) - - benchmark_groups = json.loads(args.execution_benchmark_config.read_text()) - benchmark_group = benchmark_groups.get(args.target_device_name) - if benchmark_group is None: - raise ValueError("Target device not found in the benchmark config.") - run_configs = serialization.unpack_and_deserialize( - data=benchmark_group["run_configs"], - root_type=typing.List[iree_definitions.E2EModelRunConfig]) - benchmark_suite = BenchmarkSuite.load_from_run_configs( - run_configs=run_configs, - root_benchmark_dir=benchmark_config.root_benchmark_dir) - - benchmark_driver = LinuxBenchmarkDriver(gpu_id=args.gpu_id, - device_info=device_info, - benchmark_config=benchmark_config, - benchmark_suite=benchmark_suite, - benchmark_grace_time=1.0, - verbose=args.verbose) - - if args.pin_cpu_freq: - raise NotImplementedError("CPU freq pinning is not supported yet.") - if args.pin_gpu_freq: - raise NotImplementedError("GPU freq pinning is not supported yet.") - if not args.no_clean: - atexit.register(shutil.rmtree, args.tmp_dir) - - benchmark_driver.run() - - benchmark_results = benchmark_driver.get_benchmark_results() - if args.output is not None: - with args.output.open("w") as f: - f.write(benchmark_results.to_json_str()) - - if args.verbose: - print(benchmark_results.commit) - print(benchmark_results.benchmarks) - - trace_capture_config = benchmark_config.trace_capture_config - if trace_capture_config: - # Put all captures in a tarball and remove the original files. - with tarfile.open(trace_capture_config.capture_tarball, "w:gz") as tar: - for capture_filename in benchmark_driver.get_capture_filenames(): - tar.add(capture_filename) - - benchmark_errors = benchmark_driver.get_benchmark_errors() - if benchmark_errors: - print("Benchmarking completed with errors", file=sys.stderr) - raise RuntimeError(benchmark_errors) + device_info = get_linux_device_info( + args.device_model, args.cpu_uarch, args.gpu_id, args.verbose + ) + if args.verbose: + print(device_info) + + commit = get_git_commit_hash("HEAD") + benchmark_config = BenchmarkConfig.build_from_args(args, commit) + + benchmark_groups = json.loads(args.execution_benchmark_config.read_text()) + benchmark_group = benchmark_groups.get(args.target_device_name) + if benchmark_group is None: + raise ValueError("Target device not found in the benchmark config.") + run_configs = serialization.unpack_and_deserialize( + data=benchmark_group["run_configs"], + root_type=typing.List[iree_definitions.E2EModelRunConfig], + ) + benchmark_suite = BenchmarkSuite.load_from_run_configs( + run_configs=run_configs, root_benchmark_dir=benchmark_config.root_benchmark_dir + ) + + benchmark_driver = LinuxBenchmarkDriver( + gpu_id=args.gpu_id, + device_info=device_info, + benchmark_config=benchmark_config, + benchmark_suite=benchmark_suite, + benchmark_grace_time=1.0, + verbose=args.verbose, + ) + + if args.pin_cpu_freq: + raise NotImplementedError("CPU freq pinning is not supported yet.") + if args.pin_gpu_freq: + raise NotImplementedError("GPU freq pinning is not supported yet.") + if not args.no_clean: + atexit.register(shutil.rmtree, args.tmp_dir) + + benchmark_driver.run() + + benchmark_results = benchmark_driver.get_benchmark_results() + if args.output is not None: + with args.output.open("w") as f: + f.write(benchmark_results.to_json_str()) + + if args.verbose: + print(benchmark_results.commit) + print(benchmark_results.benchmarks) + + trace_capture_config = benchmark_config.trace_capture_config + if trace_capture_config: + # Put all captures in a tarball and remove the original files. + with tarfile.open(trace_capture_config.capture_tarball, "w:gz") as tar: + for capture_filename in benchmark_driver.get_capture_filenames(): + tar.add(capture_filename) + + benchmark_errors = benchmark_driver.get_benchmark_errors() + if benchmark_errors: + print("Benchmarking completed with errors", file=sys.stderr) + raise RuntimeError(benchmark_errors) def parse_argument(): - arg_parser = common.common_arguments.Parser() - arg_parser.add_argument("--device_model", - default="Unknown", - help="Device model") - arg_parser.add_argument("--cpu_uarch", - default=None, - help="CPU microarchitecture, e.g., CascadeLake") - arg_parser.add_argument( - "--gpu_id", - type=str, - default="0", - help="GPU ID to run the benchmark, e.g., '0' or 'GPU-'") - - return arg_parser.parse_args() + arg_parser = common.common_arguments.Parser() + arg_parser.add_argument("--device_model", default="Unknown", help="Device model") + arg_parser.add_argument( + "--cpu_uarch", default=None, help="CPU microarchitecture, e.g., CascadeLake" + ) + arg_parser.add_argument( + "--gpu_id", + type=str, + default="0", + help="GPU ID to run the benchmark, e.g., '0' or 'GPU-'", + ) + + return arg_parser.parse_args() if __name__ == "__main__": - main(parse_argument()) + main(parse_argument()) diff --git a/build_tools/benchmarks/upload_benchmarks_to_dashboard.py b/build_tools/benchmarks/upload_benchmarks_to_dashboard.py index 89aaf2e14ca0..137b7c65e8af 100755 --- a/build_tools/benchmarks/upload_benchmarks_to_dashboard.py +++ b/build_tools/benchmarks/upload_benchmarks_to_dashboard.py @@ -32,8 +32,8 @@ from common import benchmark_definition, benchmark_presentation, benchmark_thresholds IREE_DASHBOARD_URL = "https://perf.iree.dev" -IREE_GITHUB_COMMIT_URL_PREFIX = 'https://github.com/openxla/iree/commit' -IREE_PROJECT_ID = 'IREE' +IREE_GITHUB_COMMIT_URL_PREFIX = "https://github.com/openxla/iree/commit" +IREE_PROJECT_ID = "IREE" THIS_DIRECTORY = pathlib.Path(__file__).resolve().parent COMMON_DESCRIPTION = """ @@ -50,353 +50,395 @@ # For models listed here we can provide a nicer description for them on # webpage. IREE_TF_MODEL_SOURCE_URL = { - 'MobileBertSquad': - 'https://github.com/google-research/google-research/tree/master/mobilebert', - 'MobileNetV2': - 'https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV2', - 'MobileNetV3Small': - 'https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small', + "MobileBertSquad": "https://github.com/google-research/google-research/tree/master/mobilebert", + "MobileNetV2": "https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV2", + "MobileNetV3Small": "https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small", } IREE_TFLITE_MODEL_SOURCE_URL = { - 'DeepLabV3': - 'https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1', - 'MobileSSD': - 'https://www.tensorflow.org/lite/performance/gpu#demo_app_tutorials', - 'PoseNet': - 'https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/default/1', + "DeepLabV3": "https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1", + "MobileSSD": "https://www.tensorflow.org/lite/performance/gpu#demo_app_tutorials", + "PoseNet": "https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/default/1", } def get_model_description(model_name: str, model_source: str) -> Optional[str]: - """Gets the model description for the given benchmark.""" - url = None - if model_source == "TensorFlow": - url = IREE_TF_MODEL_SOURCE_URL.get(model_name) - elif model_source == "TFLite": - url = IREE_TFLITE_MODEL_SOURCE_URL.get(model_name) - if url is not None: - description = f'{model_name} from {url}.' - return description - return None + """Gets the model description for the given benchmark.""" + url = None + if model_source == "TensorFlow": + url = IREE_TF_MODEL_SOURCE_URL.get(model_name) + elif model_source == "TFLite": + url = IREE_TFLITE_MODEL_SOURCE_URL.get(model_name) + if url is not None: + description = f'{model_name} from {url}.' + return description + return None def get_git_commit_hash(commit: str, verbose: bool = False) -> str: - """Gets the commit hash for the given commit.""" - return benchmark_definition.execute_cmd_and_get_stdout( - ['git', 'rev-parse', commit], cwd=THIS_DIRECTORY, verbose=verbose) + """Gets the commit hash for the given commit.""" + return benchmark_definition.execute_cmd_and_get_stdout( + ["git", "rev-parse", commit], cwd=THIS_DIRECTORY, verbose=verbose + ) def get_git_total_commit_count(commit: str, verbose: bool = False) -> int: - """Gets the total commit count in history ending with the given commit.""" - count = benchmark_definition.execute_cmd_and_get_stdout( - ['git', 'rev-list', '--count', commit], - cwd=THIS_DIRECTORY, - verbose=verbose) - return int(count) + """Gets the total commit count in history ending with the given commit.""" + count = benchmark_definition.execute_cmd_and_get_stdout( + ["git", "rev-list", "--count", commit], cwd=THIS_DIRECTORY, verbose=verbose + ) + return int(count) def get_git_commit_info(commit: str, verbose: bool = False) -> Dict[str, str]: - """Gets commit information dictionary for the given commit.""" - cmd = [ - 'git', 'show', '--format=%H:::%h:::%an:::%ae:::%s', '--no-patch', commit - ] - info = benchmark_definition.execute_cmd_and_get_stdout(cmd, - cwd=THIS_DIRECTORY, - verbose=verbose) - segments = info.split(':::') - return { - 'hash': segments[0], - 'abbrevHash': segments[1], - 'authorName': segments[2], - 'authorEmail': segments[3], - 'subject': segments[4], - } - - -def compose_series_payload(project_id: str, - series_id: str, - series_unit: str, - series_name: Optional[str] = None, - series_description: Optional[str] = None, - average_range: Union[int, str] = '5%', - average_min_count: int = 3, - better_criterion: str = 'smaller', - override: bool = False) -> Dict[str, Any]: - """Composes the payload dictionary for a series.""" - payload = { - 'projectId': project_id, - 'serieId': series_id, - 'serieUnit': series_unit, - 'serieName': series_name, - 'analyse': { - 'benchmark': { - 'range': average_range, - 'required': average_min_count, - 'trend': better_criterion, - } - }, - 'override': override, - } - if series_description is not None: - payload['description'] = series_description - return payload - - -def compose_build_payload(project_id: str, - project_github_commit_url: str, - build_id: int, - commit: str, - override: bool = False) -> Dict[str, Any]: - """Composes the payload dictionary for a build.""" - commit_info = get_git_commit_info(commit) - commit_info['url'] = f'{project_github_commit_url}/{commit_info["hash"]}' - return { - 'projectId': project_id, - 'build': { - 'buildId': build_id, - 'infos': commit_info, - }, - 'override': override, - } - - -def compose_sample_payload(project_id: str, - series_id: str, - build_id: int, - sample_unit: str, - sample_value: int, - override: bool = False) -> Dict[str, Any]: - """Composes the payload dictionary for a sample.""" - return { - 'projectId': project_id, - 'serieId': series_id, - 'sampleUnit': sample_unit, - 'sample': { - 'buildId': build_id, - 'value': sample_value - }, - 'override': override - } + """Gets commit information dictionary for the given commit.""" + cmd = ["git", "show", "--format=%H:::%h:::%an:::%ae:::%s", "--no-patch", commit] + info = benchmark_definition.execute_cmd_and_get_stdout( + cmd, cwd=THIS_DIRECTORY, verbose=verbose + ) + segments = info.split(":::") + return { + "hash": segments[0], + "abbrevHash": segments[1], + "authorName": segments[2], + "authorEmail": segments[3], + "subject": segments[4], + } + + +def compose_series_payload( + project_id: str, + series_id: str, + series_unit: str, + series_name: Optional[str] = None, + series_description: Optional[str] = None, + average_range: Union[int, str] = "5%", + average_min_count: int = 3, + better_criterion: str = "smaller", + override: bool = False, +) -> Dict[str, Any]: + """Composes the payload dictionary for a series.""" + payload = { + "projectId": project_id, + "serieId": series_id, + "serieUnit": series_unit, + "serieName": series_name, + "analyse": { + "benchmark": { + "range": average_range, + "required": average_min_count, + "trend": better_criterion, + } + }, + "override": override, + } + if series_description is not None: + payload["description"] = series_description + return payload + + +def compose_build_payload( + project_id: str, + project_github_commit_url: str, + build_id: int, + commit: str, + override: bool = False, +) -> Dict[str, Any]: + """Composes the payload dictionary for a build.""" + commit_info = get_git_commit_info(commit) + commit_info["url"] = f'{project_github_commit_url}/{commit_info["hash"]}' + return { + "projectId": project_id, + "build": { + "buildId": build_id, + "infos": commit_info, + }, + "override": override, + } + + +def compose_sample_payload( + project_id: str, + series_id: str, + build_id: int, + sample_unit: str, + sample_value: int, + override: bool = False, +) -> Dict[str, Any]: + """Composes the payload dictionary for a sample.""" + return { + "projectId": project_id, + "serieId": series_id, + "sampleUnit": sample_unit, + "sample": {"buildId": build_id, "value": sample_value}, + "override": override, + } def get_required_env_var(var: str) -> str: - """Gets the value for a required environment variable.""" - value = os.getenv(var) - if value is None: - raise RuntimeError(f'Missing environment variable "{var}"') - return value - - -def post_to_dashboard(url: str, - payload: Dict[str, Any], - dry_run: bool = False, - verbose: bool = False): - data = json.dumps(payload) - - if dry_run or verbose: - print(f'API request payload: {data}') - - if dry_run: - return - - api_token = get_required_env_var('IREE_DASHBOARD_API_TOKEN') - headers = { - 'Content-type': 'application/json', - 'Authorization': f'Bearer {api_token}', - } - - response = requests.post(url, data=data, headers=headers) - code = response.status_code - if code != 200: - raise requests.RequestException( - f'Failed to post to dashboard server with {code} - {response.text}') - - -def add_new_iree_series(series_id: str, - series_unit: str, - series_name: str, - series_description: Optional[str] = None, - average_range: Optional[Union[str, int]] = None, - override: bool = False, - dry_run: bool = False, - verbose: bool = False): - """Posts a new series to the dashboard.""" - if average_range is None: - raise ValueError(f"no matched threshold setting for benchmark: {series_id}") - - payload = compose_series_payload(IREE_PROJECT_ID, - series_id, - series_unit, - series_name, - series_description, - average_range=average_range, - override=override) - post_to_dashboard(f'{IREE_DASHBOARD_URL}/apis/v2/addSerie', - payload, - dry_run=dry_run, - verbose=verbose) - - -def add_new_iree_build(build_id: int, - commit: str, - override: bool = False, - dry_run: bool = False, - verbose: bool = False): - """Posts a new build to the dashboard.""" - payload = compose_build_payload(IREE_PROJECT_ID, - IREE_GITHUB_COMMIT_URL_PREFIX, build_id, - commit, override) - post_to_dashboard(f'{IREE_DASHBOARD_URL}/apis/addBuild', - payload, - dry_run=dry_run, - verbose=verbose) - - -def add_new_sample(series_id: str, - build_id: int, - sample_unit: str, - sample_value: int, - override: bool = False, - dry_run: bool = False, - verbose: bool = False): - """Posts a new sample to the dashboard.""" - payload = compose_sample_payload(IREE_PROJECT_ID, series_id, build_id, - sample_unit, sample_value, override) - post_to_dashboard(f'{IREE_DASHBOARD_URL}/apis/v2/addSample', - payload, - dry_run=dry_run, - verbose=verbose) + """Gets the value for a required environment variable.""" + value = os.getenv(var) + if value is None: + raise RuntimeError(f'Missing environment variable "{var}"') + return value + + +def post_to_dashboard( + url: str, payload: Dict[str, Any], dry_run: bool = False, verbose: bool = False +): + data = json.dumps(payload) + + if dry_run or verbose: + print(f"API request payload: {data}") + + if dry_run: + return + + api_token = get_required_env_var("IREE_DASHBOARD_API_TOKEN") + headers = { + "Content-type": "application/json", + "Authorization": f"Bearer {api_token}", + } + + response = requests.post(url, data=data, headers=headers) + code = response.status_code + if code != 200: + raise requests.RequestException( + f"Failed to post to dashboard server with {code} - {response.text}" + ) + + +def add_new_iree_series( + series_id: str, + series_unit: str, + series_name: str, + series_description: Optional[str] = None, + average_range: Optional[Union[str, int]] = None, + override: bool = False, + dry_run: bool = False, + verbose: bool = False, +): + """Posts a new series to the dashboard.""" + if average_range is None: + raise ValueError(f"no matched threshold setting for benchmark: {series_id}") + + payload = compose_series_payload( + IREE_PROJECT_ID, + series_id, + series_unit, + series_name, + series_description, + average_range=average_range, + override=override, + ) + post_to_dashboard( + f"{IREE_DASHBOARD_URL}/apis/v2/addSerie", + payload, + dry_run=dry_run, + verbose=verbose, + ) + + +def add_new_iree_build( + build_id: int, + commit: str, + override: bool = False, + dry_run: bool = False, + verbose: bool = False, +): + """Posts a new build to the dashboard.""" + payload = compose_build_payload( + IREE_PROJECT_ID, IREE_GITHUB_COMMIT_URL_PREFIX, build_id, commit, override + ) + post_to_dashboard( + f"{IREE_DASHBOARD_URL}/apis/addBuild", payload, dry_run=dry_run, verbose=verbose + ) + + +def add_new_sample( + series_id: str, + build_id: int, + sample_unit: str, + sample_value: int, + override: bool = False, + dry_run: bool = False, + verbose: bool = False, +): + """Posts a new sample to the dashboard.""" + payload = compose_sample_payload( + IREE_PROJECT_ID, series_id, build_id, sample_unit, sample_value, override + ) + post_to_dashboard( + f"{IREE_DASHBOARD_URL}/apis/v2/addSample", + payload, + dry_run=dry_run, + verbose=verbose, + ) def parse_arguments(): - """Parses command-line options.""" - - parser = argparse.ArgumentParser() - parser.add_argument( - '--benchmark_files', - metavar='', - default=[], - action="append", - help=("Paths to the JSON files containing benchmark results, " - "accepts wildcards")) - parser.add_argument( - "--compile_stats_files", - metavar="", - default=[], - action="append", - help=("Paths to the JSON files containing compilation statistics, " - "accepts wildcards")) - parser.add_argument("--dry-run", - action="store_true", - help="Print the comment instead of posting to dashboard") - parser.add_argument('--verbose', - action='store_true', - help='Print internal information during execution') - args = parser.parse_args() - - return args + """Parses command-line options.""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark_files", + metavar="", + default=[], + action="append", + help=( + "Paths to the JSON files containing benchmark results, " "accepts wildcards" + ), + ) + parser.add_argument( + "--compile_stats_files", + metavar="", + default=[], + action="append", + help=( + "Paths to the JSON files containing compilation statistics, " + "accepts wildcards" + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print the comment instead of posting to dashboard", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print internal information during execution", + ) + args = parser.parse_args() + + return args def main(args): - benchmark_files = expand_and_check_file_paths(args.benchmark_files) - compile_stats_files = expand_and_check_file_paths(args.compile_stats_files) - - if len(benchmark_files) > 0: - committish = benchmark_definition.BenchmarkResults.from_json_str( - benchmark_files[0].read_text()).commit - elif len(compile_stats_files) > 0: - committish = benchmark_definition.CompilationResults.from_json_object( - json.loads(compile_stats_files[0].read_text())).commit - else: - raise ValueError("No benchmark/compilation results.") - - # Register a new build for the current commit. - commit_hash = get_git_commit_hash(committish, verbose=args.verbose) - commit_count = get_git_total_commit_count(commit_hash, verbose=args.verbose) - - aggregate_results = benchmark_presentation.aggregate_all_benchmarks( - benchmark_files=benchmark_files, expected_pr_commit=commit_hash) - - all_compilation_metrics = benchmark_presentation.collect_all_compilation_metrics( - compile_stats_files=compile_stats_files, expected_pr_commit=commit_hash) - - # Allow override to support uploading data for the same build in - # different batches. - add_new_iree_build(commit_count, - commit_hash, - override=True, - dry_run=args.dry_run, - verbose=args.verbose) - - # Upload benchmark results to the dashboard. - for series_id, benchmark_latency in aggregate_results.items(): - series_name = benchmark_latency.name - benchmark_info = benchmark_latency.benchmark_info - description = get_model_description(benchmark_info.model_name, - benchmark_info.model_source) - if description is None: - description = "" - description += COMMON_DESCRIPTION - - threshold = next( - (threshold for threshold in benchmark_thresholds.BENCHMARK_THRESHOLDS - if threshold.regex.match(series_name)), None) - average_range = (threshold.get_threshold_str() - if threshold is not None else None) - - # Override by default to allow updates to the series. - add_new_iree_series(series_id=series_id, - series_unit="ns", - series_name=benchmark_latency.name, - series_description=description, - average_range=average_range, - override=True, - dry_run=args.dry_run, - verbose=args.verbose) - add_new_sample(series_id=series_id, - build_id=commit_count, - sample_unit="ns", - sample_value=benchmark_latency.mean_time, - dry_run=args.dry_run, - verbose=args.verbose) - - for target_id, compile_metrics in all_compilation_metrics.items(): - description = get_model_description( - compile_metrics.compilation_info.model_name, - compile_metrics.compilation_info.model_source) - if description is None: - description = "" - description += COMMON_DESCRIPTION - - for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: - sample_value, _ = mapper.get_current_and_base_value(compile_metrics) - series_unit = mapper.get_unit() - series_id = mapper.get_series_id(target_id) - series_name = mapper.get_series_name(compile_metrics.name) - - threshold = next( - (threshold for threshold in mapper.get_metric_thresholds() - if threshold.regex.match(series_name)), None) - average_range = (threshold.get_threshold_str() - if threshold is not None else None) - - # Override by default to allow updates to the series. - add_new_iree_series(series_id=series_id, - series_unit=series_unit, - series_name=series_name, - series_description=description, - average_range=average_range, - override=True, - dry_run=args.dry_run, - verbose=args.verbose) - add_new_sample(series_id=series_id, - build_id=commit_count, - sample_unit=series_unit, - sample_value=sample_value, - dry_run=args.dry_run, - verbose=args.verbose) + benchmark_files = expand_and_check_file_paths(args.benchmark_files) + compile_stats_files = expand_and_check_file_paths(args.compile_stats_files) + + if len(benchmark_files) > 0: + committish = benchmark_definition.BenchmarkResults.from_json_str( + benchmark_files[0].read_text() + ).commit + elif len(compile_stats_files) > 0: + committish = benchmark_definition.CompilationResults.from_json_object( + json.loads(compile_stats_files[0].read_text()) + ).commit + else: + raise ValueError("No benchmark/compilation results.") + + # Register a new build for the current commit. + commit_hash = get_git_commit_hash(committish, verbose=args.verbose) + commit_count = get_git_total_commit_count(commit_hash, verbose=args.verbose) + + aggregate_results = benchmark_presentation.aggregate_all_benchmarks( + benchmark_files=benchmark_files, expected_pr_commit=commit_hash + ) + + all_compilation_metrics = benchmark_presentation.collect_all_compilation_metrics( + compile_stats_files=compile_stats_files, expected_pr_commit=commit_hash + ) + + # Allow override to support uploading data for the same build in + # different batches. + add_new_iree_build( + commit_count, + commit_hash, + override=True, + dry_run=args.dry_run, + verbose=args.verbose, + ) + + # Upload benchmark results to the dashboard. + for series_id, benchmark_latency in aggregate_results.items(): + series_name = benchmark_latency.name + benchmark_info = benchmark_latency.benchmark_info + description = get_model_description( + benchmark_info.model_name, benchmark_info.model_source + ) + if description is None: + description = "" + description += COMMON_DESCRIPTION + + threshold = next( + ( + threshold + for threshold in benchmark_thresholds.BENCHMARK_THRESHOLDS + if threshold.regex.match(series_name) + ), + None, + ) + average_range = threshold.get_threshold_str() if threshold is not None else None + + # Override by default to allow updates to the series. + add_new_iree_series( + series_id=series_id, + series_unit="ns", + series_name=benchmark_latency.name, + series_description=description, + average_range=average_range, + override=True, + dry_run=args.dry_run, + verbose=args.verbose, + ) + add_new_sample( + series_id=series_id, + build_id=commit_count, + sample_unit="ns", + sample_value=benchmark_latency.mean_time, + dry_run=args.dry_run, + verbose=args.verbose, + ) + + for target_id, compile_metrics in all_compilation_metrics.items(): + description = get_model_description( + compile_metrics.compilation_info.model_name, + compile_metrics.compilation_info.model_source, + ) + if description is None: + description = "" + description += COMMON_DESCRIPTION + + for mapper in benchmark_presentation.COMPILATION_METRICS_TO_TABLE_MAPPERS: + sample_value, _ = mapper.get_current_and_base_value(compile_metrics) + series_unit = mapper.get_unit() + series_id = mapper.get_series_id(target_id) + series_name = mapper.get_series_name(compile_metrics.name) + + threshold = next( + ( + threshold + for threshold in mapper.get_metric_thresholds() + if threshold.regex.match(series_name) + ), + None, + ) + average_range = ( + threshold.get_threshold_str() if threshold is not None else None + ) + + # Override by default to allow updates to the series. + add_new_iree_series( + series_id=series_id, + series_unit=series_unit, + series_name=series_name, + series_description=description, + average_range=average_range, + override=True, + dry_run=args.dry_run, + verbose=args.verbose, + ) + add_new_sample( + series_id=series_id, + build_id=commit_count, + sample_unit=series_unit, + sample_value=sample_value, + dry_run=args.dry_run, + verbose=args.verbose, + ) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/docker/get_image_name.py b/build_tools/docker/get_image_name.py index ae8e18031a12..993e00c4c062 100755 --- a/build_tools/docker/get_image_name.py +++ b/build_tools/docker/get_image_name.py @@ -22,22 +22,23 @@ def find_image_by_name(img_name): - this_dir = Path(__file__).resolve().parent + this_dir = Path(__file__).resolve().parent - with open(this_dir / "prod_digests.txt", "rt") as f: - for line in f.readlines(): - line = line.strip() - if line.startswith(f"gcr.io/iree-oss/{img_name}@"): - return line - else: - raise ValueError( - f"ERROR: Image name {img_name} not found in prod_digests.txt") + with open(this_dir / "prod_digests.txt", "rt") as f: + for line in f.readlines(): + line = line.strip() + if line.startswith(f"gcr.io/iree-oss/{img_name}@"): + return line + else: + raise ValueError( + f"ERROR: Image name {img_name} not found in prod_digests.txt" + ) if __name__ == "__main__": - if len(sys.argv) != 2: - print("ERROR: Expected image short name", file=sys.stderr) - sys.exit(1) - short_name = sys.argv[1] - image_name = find_image_by_name(short_name) - print(image_name) + if len(sys.argv) != 2: + print("ERROR: Expected image short name", file=sys.stderr) + sys.exit(1) + short_name = sys.argv[1] + image_name = find_image_by_name(short_name) + print(image_name) diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py index f10b4331e12e..3bde67a99b58 100755 --- a/build_tools/docker/manage_images.py +++ b/build_tools/docker/manage_images.py @@ -67,193 +67,212 @@ IMAGES_TO_DEPENDENT_IMAGES = {k: [] for k in IMAGES_TO_DEPENDENCIES} for image, dependencies in IMAGES_TO_DEPENDENCIES.items(): - for dependency in dependencies: - IMAGES_TO_DEPENDENT_IMAGES[dependency].append(image) + for dependency in dependencies: + IMAGES_TO_DEPENDENT_IMAGES[dependency].append(image) IMAGES_HELP = [f"`{name}`" for name in IMAGES_TO_DEPENDENCIES] IMAGES_HELP = f"{', '.join(IMAGES_HELP)} or `all`" def parse_arguments(): - """Parses command-line options.""" - parser = argparse.ArgumentParser( - description="Build IREE's Docker images and optionally push them to GCR.") - parser.add_argument("--images", - "--image", - type=str, - required=True, - action="append", - help=f"Name of the image to build: {IMAGES_HELP}.") - parser.add_argument( - "--dry_run", - "--dry-run", - "-n", - action="store_true", - help="Print output without building or pushing any images.") - parser.add_argument( - "--only_references", - "--only-references", - action="store_true", - help= - "Just update references to images using the digests in prod_digests.txt") - - args = parser.parse_args() - for image in args.images: - if image == "all": - # Sort for a determinstic order - args.images = sorted(IMAGES_TO_DEPENDENCIES.keys()) - elif image not in IMAGES_TO_DEPENDENCIES: - raise parser.error("Expected --image to be one of:\n" - f" {IMAGES_HELP}\n" - f"but got `{image}`.") - return args - - -def _dag_dfs(input_nodes: Sequence[str], - node_to_child_nodes: Dict[str, Sequence[str]]) -> List[str]: - # Python doesn't have a builtin OrderedSet, but we don't have many images, so - # we just use a list. - ordered_nodes = [] - - def add_children(parent_node: str): - if parent_node not in ordered_nodes: - for child_node in node_to_child_nodes[parent_node]: - add_children(child_node) - ordered_nodes.append(parent_node) - - for node in input_nodes: - add_children(node) - return ordered_nodes + """Parses command-line options.""" + parser = argparse.ArgumentParser( + description="Build IREE's Docker images and optionally push them to GCR." + ) + parser.add_argument( + "--images", + "--image", + type=str, + required=True, + action="append", + help=f"Name of the image to build: {IMAGES_HELP}.", + ) + parser.add_argument( + "--dry_run", + "--dry-run", + "-n", + action="store_true", + help="Print output without building or pushing any images.", + ) + parser.add_argument( + "--only_references", + "--only-references", + action="store_true", + help="Just update references to images using the digests in prod_digests.txt", + ) + + args = parser.parse_args() + for image in args.images: + if image == "all": + # Sort for a determinstic order + args.images = sorted(IMAGES_TO_DEPENDENCIES.keys()) + elif image not in IMAGES_TO_DEPENDENCIES: + raise parser.error( + "Expected --image to be one of:\n" + f" {IMAGES_HELP}\n" + f"but got `{image}`." + ) + return args + + +def _dag_dfs( + input_nodes: Sequence[str], node_to_child_nodes: Dict[str, Sequence[str]] +) -> List[str]: + # Python doesn't have a builtin OrderedSet, but we don't have many images, so + # we just use a list. + ordered_nodes = [] + + def add_children(parent_node: str): + if parent_node not in ordered_nodes: + for child_node in node_to_child_nodes[parent_node]: + add_children(child_node) + ordered_nodes.append(parent_node) + + for node in input_nodes: + add_children(node) + return ordered_nodes def get_ordered_images_to_process(images: Sequence[str]) -> List[str]: - dependents = _dag_dfs(images, IMAGES_TO_DEPENDENT_IMAGES) - dependents.reverse() - return dependents + dependents = _dag_dfs(images, IMAGES_TO_DEPENDENT_IMAGES) + dependents.reverse() + return dependents def get_dependencies(images: Sequence[str]) -> List[str]: - return _dag_dfs(images, IMAGES_TO_DEPENDENCIES) + return _dag_dfs(images, IMAGES_TO_DEPENDENCIES) def get_repo_digest(tagged_image_url: str, dry_run: bool = False) -> str: - inspect_command = [ - "docker", - "image", - "inspect", - tagged_image_url, - "-f", - "{{index .RepoDigests 0}}", - ] - try: - completed_process = utils.run_command( - inspect_command, - dry_run=False, # Run even if --dry_run is True. - capture_output=True, - timeout=10) - except subprocess.CalledProcessError as error: - if dry_run: - return "" - else: - raise RuntimeError( - f"Computing the repository digest for {tagged_image_url} failed. Has " - "it been pushed to GCR?") from error - _, repo_digest = completed_process.stdout.strip().split("@") - return repo_digest + inspect_command = [ + "docker", + "image", + "inspect", + tagged_image_url, + "-f", + "{{index .RepoDigests 0}}", + ] + try: + completed_process = utils.run_command( + inspect_command, + dry_run=False, # Run even if --dry_run is True. + capture_output=True, + timeout=10, + ) + except subprocess.CalledProcessError as error: + if dry_run: + return "" + else: + raise RuntimeError( + f"Computing the repository digest for {tagged_image_url} failed. Has " + "it been pushed to GCR?" + ) from error + _, repo_digest = completed_process.stdout.strip().split("@") + return repo_digest def update_references(image_url: str, digest: str, dry_run: bool = False): - """Updates all references to "image_url" with a sha256 digest.""" - print(f"Updating references to {image_url}") - - grep_command = ["git", "grep", "-l", f"{image_url}@sha256"] - try: - completed_process = utils.run_command(grep_command, - capture_output=True, - timeout=5) - except subprocess.CalledProcessError as error: - if error.returncode == 1: - print(f"Found no references to {image_url}") - return - raise error - - # Update references in all grepped files. - files = completed_process.stdout.split() - print(f"Updating references in {len(files)} files: {files}") - if not dry_run: - for line in fileinput.input(files=files, inplace=True): - print(re.sub(f"{image_url}@{DIGEST_REGEX}", f"{image_url}@{digest}", - line), - end="") + """Updates all references to "image_url" with a sha256 digest.""" + print(f"Updating references to {image_url}") + + grep_command = ["git", "grep", "-l", f"{image_url}@sha256"] + try: + completed_process = utils.run_command( + grep_command, capture_output=True, timeout=5 + ) + except subprocess.CalledProcessError as error: + if error.returncode == 1: + print(f"Found no references to {image_url}") + return + raise error + + # Update references in all grepped files. + files = completed_process.stdout.split() + print(f"Updating references in {len(files)} files: {files}") + if not dry_run: + for line in fileinput.input(files=files, inplace=True): + print( + re.sub(f"{image_url}@{DIGEST_REGEX}", f"{image_url}@{digest}", line), + end="", + ) def parse_prod_digests() -> Dict[str, str]: - image_urls_to_prod_digests = {} - with open(utils.PROD_DIGESTS_PATH, "r") as f: - for line in f: - image_url, digest = line.strip().split("@") - image_urls_to_prod_digests[image_url] = digest - return image_urls_to_prod_digests + image_urls_to_prod_digests = {} + with open(utils.PROD_DIGESTS_PATH, "r") as f: + for line in f: + image_url, digest = line.strip().split("@") + image_urls_to_prod_digests[image_url] = digest + return image_urls_to_prod_digests if __name__ == "__main__": - args = parse_arguments() - image_urls_to_prod_digests = parse_prod_digests() - images_to_process = get_ordered_images_to_process(args.images) - print(f"Also processing dependent images. Will process: {images_to_process}") - - if not args.only_references: - # Ensure the user has the correct authorization to push to GCR. - utils.check_gcloud_auth(dry_run=args.dry_run) - - dependencies = get_dependencies(images_to_process) - print(f"Pulling image dependencies: {dependencies}") - for dependency in dependencies: - dependency_url = posixpath.join(IREE_GCR_URL, dependency) - # If `dependency` is a new image then it may not have a prod digest yet. - if dependency_url in image_urls_to_prod_digests: - digest = image_urls_to_prod_digests[dependency_url] - dependency_with_digest = f"{dependency_url}@{digest}" - utils.run_command(["docker", "pull", dependency_with_digest], - dry_run=args.dry_run) - - for image in images_to_process: - print("\n" * 5 + f"Processing image {image}") - image_url = posixpath.join(IREE_GCR_URL, image) - tagged_image_url = f"{image_url}" - image_path = os.path.join(DOCKER_DIR, "dockerfiles", f"{image}.Dockerfile") - - if args.only_references: - digest = image_urls_to_prod_digests[image_url] - else: - # We deliberately give the whole repository as context so we can reuse - # scripts and such. It would be nice if Docker gave us a way to make this - # more explicit, like symlinking files in the context, but they refuse to - # with the justification that it makes builds non-hermetic, a hilarious - # concern for something that allows and encourages arbitrary network - # access in builds. - # We're assuming this is being run from the root of the repository. - # FIXME: make this more robust to where it is run from. - utils.run_command([ - "docker", "build", "--file", image_path, "--tag", tagged_image_url, - "." - ], - dry_run=args.dry_run) - - utils.run_command(["docker", "push", tagged_image_url], - dry_run=args.dry_run) - - digest = get_repo_digest(tagged_image_url, args.dry_run) - - # Check that the image is in "prod_digests.txt" and append it to the list - # in the file if it isn't. - if image_url not in image_urls_to_prod_digests: - image_with_digest = f"{image_url}@{digest}" - print( - f"Adding new image {image_with_digest} to {utils.PROD_DIGESTS_PATH}" - ) - if not args.dry_run: - with open(utils.PROD_DIGESTS_PATH, "a") as f: - f.write(f"{image_with_digest}\n") - - update_references(image_url, digest, dry_run=args.dry_run) + args = parse_arguments() + image_urls_to_prod_digests = parse_prod_digests() + images_to_process = get_ordered_images_to_process(args.images) + print(f"Also processing dependent images. Will process: {images_to_process}") + + if not args.only_references: + # Ensure the user has the correct authorization to push to GCR. + utils.check_gcloud_auth(dry_run=args.dry_run) + + dependencies = get_dependencies(images_to_process) + print(f"Pulling image dependencies: {dependencies}") + for dependency in dependencies: + dependency_url = posixpath.join(IREE_GCR_URL, dependency) + # If `dependency` is a new image then it may not have a prod digest yet. + if dependency_url in image_urls_to_prod_digests: + digest = image_urls_to_prod_digests[dependency_url] + dependency_with_digest = f"{dependency_url}@{digest}" + utils.run_command( + ["docker", "pull", dependency_with_digest], dry_run=args.dry_run + ) + + for image in images_to_process: + print("\n" * 5 + f"Processing image {image}") + image_url = posixpath.join(IREE_GCR_URL, image) + tagged_image_url = f"{image_url}" + image_path = os.path.join(DOCKER_DIR, "dockerfiles", f"{image}.Dockerfile") + + if args.only_references: + digest = image_urls_to_prod_digests[image_url] + else: + # We deliberately give the whole repository as context so we can reuse + # scripts and such. It would be nice if Docker gave us a way to make this + # more explicit, like symlinking files in the context, but they refuse to + # with the justification that it makes builds non-hermetic, a hilarious + # concern for something that allows and encourages arbitrary network + # access in builds. + # We're assuming this is being run from the root of the repository. + # FIXME: make this more robust to where it is run from. + utils.run_command( + [ + "docker", + "build", + "--file", + image_path, + "--tag", + tagged_image_url, + ".", + ], + dry_run=args.dry_run, + ) + + utils.run_command( + ["docker", "push", tagged_image_url], dry_run=args.dry_run + ) + + digest = get_repo_digest(tagged_image_url, args.dry_run) + + # Check that the image is in "prod_digests.txt" and append it to the list + # in the file if it isn't. + if image_url not in image_urls_to_prod_digests: + image_with_digest = f"{image_url}@{digest}" + print( + f"Adding new image {image_with_digest} to {utils.PROD_DIGESTS_PATH}" + ) + if not args.dry_run: + with open(utils.PROD_DIGESTS_PATH, "a") as f: + f.write(f"{image_with_digest}\n") + + update_references(image_url, digest, dry_run=args.dry_run) diff --git a/build_tools/docker/utils.py b/build_tools/docker/utils.py index cec694f99915..cde92bcb1421 100644 --- a/build_tools/docker/utils.py +++ b/build_tools/docker/utils.py @@ -13,32 +13,33 @@ PROD_DIGESTS_PATH = "build_tools/docker/prod_digests.txt".replace("/", os.sep) -def run_command(command: Sequence[str], - dry_run: bool = False, - check: bool = True, - capture_output: bool = False, - text: bool = True, - **run_kwargs) -> subprocess.CompletedProcess: - """Thin wrapper around subprocess.run""" - print(f"Running: `{' '.join(command)}`") - if dry_run: - # Dummy CompletedProess with successful returncode. - return subprocess.CompletedProcess(command, returncode=0) - - completed_process = subprocess.run(command, - text=text, - check=check, - capture_output=capture_output, - **run_kwargs) - return completed_process +def run_command( + command: Sequence[str], + dry_run: bool = False, + check: bool = True, + capture_output: bool = False, + text: bool = True, + **run_kwargs, +) -> subprocess.CompletedProcess: + """Thin wrapper around subprocess.run""" + print(f"Running: `{' '.join(command)}`") + if dry_run: + # Dummy CompletedProess with successful returncode. + return subprocess.CompletedProcess(command, returncode=0) + + completed_process = subprocess.run( + command, text=text, check=check, capture_output=capture_output, **run_kwargs + ) + return completed_process def check_gcloud_auth(dry_run: bool = False): - # Ensure the user has the correct authorization if they try to push to GCR. - try: - run_command(['which', 'gcloud']) - except subprocess.CalledProcessError as error: - raise RuntimeError( - 'gcloud not found. See https://cloud.google.com/sdk/install for ' - 'installation.') from error - run_command(["gcloud", "auth", "configure-docker"], dry_run) + # Ensure the user has the correct authorization if they try to push to GCR. + try: + run_command(["which", "gcloud"]) + except subprocess.CalledProcessError as error: + raise RuntimeError( + "gcloud not found. See https://cloud.google.com/sdk/install for " + "installation." + ) from error + run_command(["gcloud", "auth", "configure-docker"], dry_run) diff --git a/build_tools/github_actions/build_dist.py b/build_tools/github_actions/build_dist.py index b067eba03738..19dbcc6d9932 100644 --- a/build_tools/github_actions/build_dist.py +++ b/build_tools/github_actions/build_dist.py @@ -62,142 +62,157 @@ TF_INTEGRATIONS_DIR = os.path.join(IREESRC_DIR, "integrations/tensorflow") BINDIST_DIR = os.environ.get("BINDIST_DIR") if BINDIST_DIR is None: - BINDIST_DIR = os.path.join(WORK_DIR, "bindist") + BINDIST_DIR = os.path.join(WORK_DIR, "bindist") THIS_DIR = os.path.realpath(os.path.dirname(__file__)) CMAKE_CI_SCRIPT = os.path.join(THIS_DIR, "cmake_ci.py") -BUILD_REQUIREMENTS_TXT = os.path.join(IREESRC_DIR, "runtime", "bindings", - "python", "iree", "runtime", - "build_requirements.txt") +BUILD_REQUIREMENTS_TXT = os.path.join( + IREESRC_DIR, + "runtime", + "bindings", + "python", + "iree", + "runtime", + "build_requirements.txt", +) CI_REQUIREMENTS_TXT = os.path.join(THIS_DIR, "ci_requirements.txt") CONFIGURE_BAZEL_PY = os.path.join(IREESRC_DIR, "configure_bazel.py") -INSTALL_TARGET = ("install" - if platform.system() == "Windows" else "install/strip") +INSTALL_TARGET = "install" if platform.system() == "Windows" else "install/strip" # Load version info. def load_version_info(): - with open(os.path.join(IREESRC_DIR, "version_info.json"), "rt") as f: - return json.load(f) + with open(os.path.join(IREESRC_DIR, "version_info.json"), "rt") as f: + return json.load(f) try: - version_info = load_version_info() + version_info = load_version_info() except FileNotFoundError: - print("version_info.json not found. Using defaults") - version_info = { - "package-version": "0.1dev1", - "package-suffix": "-dev", - } + print("version_info.json not found. Using defaults") + version_info = { + "package-version": "0.1dev1", + "package-suffix": "-dev", + } def remove_cmake_cache(): - cache_file = os.path.join(BUILD_DIR, "CMakeCache.txt") - if os.path.exists(cache_file): - print(f"Removing {cache_file}") - os.remove(cache_file) - else: - print(f"Not removing cache file (does not exist): {cache_file}") + cache_file = os.path.join(BUILD_DIR, "CMakeCache.txt") + if os.path.exists(cache_file): + print(f"Removing {cache_file}") + os.remove(cache_file) + else: + print(f"Not removing cache file (does not exist): {cache_file}") def install_python_requirements(): - print("Installing python requirements...") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "-r", BUILD_REQUIREMENTS_TXT]) - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "-r", CI_REQUIREMENTS_TXT]) + print("Installing python requirements...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "-r", BUILD_REQUIREMENTS_TXT] + ) + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "-r", CI_REQUIREMENTS_TXT] + ) def configure_bazel(): - print("Generating configured.bazelrc...") - subprocess.check_call([sys.executable, CONFIGURE_BAZEL_PY]) + print("Generating configured.bazelrc...") + subprocess.check_call([sys.executable, CONFIGURE_BAZEL_PY]) def build_main_dist(): - """Builds the main distribution binaries. - - Additional packages that are installable as part of a full build and do not - benefit from a more restricted build can be added here. - """ - install_python_requirements() - - # Clean up install and build trees. - shutil.rmtree(INSTALL_DIR, ignore_errors=True) - remove_cmake_cache() - - # CMake configure. - print("*** Configuring ***") - subprocess.run( - [ - sys.executable, - CMAKE_CI_SCRIPT, - f"-B{BUILD_DIR}", - "--log-level=VERBOSE", - f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}", - # On some distributions, this will install to lib64. We would like - # consistency in built packages, so hard-code it. - "-DCMAKE_INSTALL_LIBDIR=lib", - f"-DCMAKE_BUILD_TYPE=Release", - f"-DIREE_BUILD_COMPILER=ON", - f"-DIREE_BUILD_PYTHON_BINDINGS=OFF", - f"-DIREE_BUILD_SAMPLES=OFF", - ], - check=True) - - print("*** Building ***") - subprocess.run([ - sys.executable, - CMAKE_CI_SCRIPT, - "--build", - BUILD_DIR, - "--target", - INSTALL_TARGET, - ], - check=True) - - print("*** Packaging ***") - dist_entries = [ - "bin", - "lib", - ] - dist_archive = os.path.join( - BINDIST_DIR, f"iree-dist{version_info['package-suffix']}" - f"-{version_info['package-version']}" - f"-{sysconfig.get_platform()}.tar.xz") - print(f"Creating archive {dist_archive}") - os.makedirs(os.path.dirname(dist_archive), exist_ok=True) - with tarfile.open(dist_archive, mode="w:xz") as tf: - for entry in dist_entries: - print(f"Adding entry: {entry}") - tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True) + """Builds the main distribution binaries. + Additional packages that are installable as part of a full build and do not + benefit from a more restricted build can be added here. + """ + install_python_requirements() -def build_py_tf_compiler_tools_pkg(): - """Builds the iree-install/python_packages/iree_tools_tf package.""" - install_python_requirements() - configure_bazel() - - # Clean up install and build trees. - shutil.rmtree(INSTALL_DIR, ignore_errors=True) - remove_cmake_cache() + # Clean up install and build trees. + shutil.rmtree(INSTALL_DIR, ignore_errors=True) + remove_cmake_cache() - os.makedirs(BINDIST_DIR, exist_ok=True) + # CMake configure. + print("*** Configuring ***") + subprocess.run( + [ + sys.executable, + CMAKE_CI_SCRIPT, + f"-B{BUILD_DIR}", + "--log-level=VERBOSE", + f"-DCMAKE_INSTALL_PREFIX={INSTALL_DIR}", + # On some distributions, this will install to lib64. We would like + # consistency in built packages, so hard-code it. + "-DCMAKE_INSTALL_LIBDIR=lib", + f"-DCMAKE_BUILD_TYPE=Release", + f"-DIREE_BUILD_COMPILER=ON", + f"-DIREE_BUILD_PYTHON_BINDINGS=OFF", + f"-DIREE_BUILD_SAMPLES=OFF", + ], + check=True, + ) - for project in ["iree_tflite", "iree_tf"]: - print(f"*** Building wheel for {project} ***") + print("*** Building ***") subprocess.run( [ - sys.executable, "-m", "pip", "wheel", - os.path.join(TF_INTEGRATIONS_DIR, "python_projects", project) + sys.executable, + CMAKE_CI_SCRIPT, + "--build", + BUILD_DIR, + "--target", + INSTALL_TARGET, ], - cwd=BINDIST_DIR, check=True, ) + print("*** Packaging ***") + dist_entries = [ + "bin", + "lib", + ] + dist_archive = os.path.join( + BINDIST_DIR, + f"iree-dist{version_info['package-suffix']}" + f"-{version_info['package-version']}" + f"-{sysconfig.get_platform()}.tar.xz", + ) + print(f"Creating archive {dist_archive}") + os.makedirs(os.path.dirname(dist_archive), exist_ok=True) + with tarfile.open(dist_archive, mode="w:xz") as tf: + for entry in dist_entries: + print(f"Adding entry: {entry}") + tf.add(os.path.join(INSTALL_DIR, entry), arcname=entry, recursive=True) + + +def build_py_tf_compiler_tools_pkg(): + """Builds the iree-install/python_packages/iree_tools_tf package.""" + install_python_requirements() + configure_bazel() + + # Clean up install and build trees. + shutil.rmtree(INSTALL_DIR, ignore_errors=True) + remove_cmake_cache() + + os.makedirs(BINDIST_DIR, exist_ok=True) + + for project in ["iree_tflite", "iree_tf"]: + print(f"*** Building wheel for {project} ***") + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "wheel", + os.path.join(TF_INTEGRATIONS_DIR, "python_projects", project), + ], + cwd=BINDIST_DIR, + check=True, + ) + command = sys.argv[1] if command == "main-dist": - build_main_dist() + build_main_dist() elif command == "py-tf-compiler-tools-pkg": - build_py_tf_compiler_tools_pkg() + build_py_tf_compiler_tools_pkg() else: - print(f"Unrecognized command: {command}") + print(f"Unrecognized command: {command}") diff --git a/build_tools/github_actions/cmake_ci.py b/build_tools/github_actions/cmake_ci.py index 0e328bd6092a..b01a7ef020e5 100644 --- a/build_tools/github_actions/cmake_ci.py +++ b/build_tools/github_actions/cmake_ci.py @@ -10,14 +10,15 @@ # This future is needed to print Python2 EOL message from __future__ import print_function import sys + if sys.version_info < (3,): - print("Python 2 has reached end-of-life and is no longer supported.") - sys.exit(-1) -if sys.platform == 'win32' and sys.maxsize.bit_length() == 31: - print( - "32-bit Windows Python runtime is not supported. Please switch to 64-bit Python." - ) - sys.exit(-1) + print("Python 2 has reached end-of-life and is no longer supported.") + sys.exit(-1) +if sys.platform == "win32" and sys.maxsize.bit_length() == 31: + print( + "32-bit Windows Python runtime is not supported. Please switch to 64-bit Python." + ) + sys.exit(-1) import importlib import json @@ -27,175 +28,177 @@ import sysconfig import tempfile -is_windows = platform.system() == 'Windows' +is_windows = platform.system() == "Windows" def display_help(): - print('Syntax: python build_tools/cmake/cmake_ci.py [--install|--build] ...') - print('If neither --install or --build are the first argument, then it is ') - print('assumed to be a generate invocation') + print("Syntax: python build_tools/cmake/cmake_ci.py [--install|--build] ...") + print("If neither --install or --build are the first argument, then it is ") + print("assumed to be a generate invocation") -mode = 'generate' +mode = "generate" if len(sys.argv) < 2: - display_help() - sys.exit(1) -if sys.argv[1] == '--install': - mode = 'install' -elif sys.argv[1] == '--build': - mode = 'build' + display_help() + sys.exit(1) +if sys.argv[1] == "--install": + mode = "install" +elif sys.argv[1] == "--build": + mode = "build" def report(*args): - print('--', *args) + print("--", *args) def get_setting(varname, default_value): - value = os.environ.get(varname) - if value is None: - return default_value - return value + value = os.environ.get(varname) + if value is None: + return default_value + return value def get_bool_setting(varname, default_value): - value = get_setting(varname, default_value) - if value is True or value is False: - return value - return value == '' or value == 'ON' or value == '1' + value = get_setting(varname, default_value) + if value is True or value is False: + return value + return value == "" or value == "ON" or value == "1" def which(thefile): - path = os.environ.get("PATH", os.defpath).split(os.pathsep) - for d in path: - fname = os.path.join(d, thefile) - fnames = [fname] - if sys.platform == 'win32': - exts = os.environ.get('PATHEXT', '').split(os.pathsep) - fnames += [fname + ext for ext in exts] - for name in fnames: - if os.access(name, os.F_OK | os.X_OK) and not os.path.isdir(name): - return name - return None + path = os.environ.get("PATH", os.defpath).split(os.pathsep) + for d in path: + fname = os.path.join(d, thefile) + fnames = [fname] + if sys.platform == "win32": + exts = os.environ.get("PATHEXT", "").split(os.pathsep) + fnames += [fname + ext for ext in exts] + for name in fnames: + if os.access(name, os.F_OK | os.X_OK) and not os.path.isdir(name): + return name + return None def use_tool_path(toolname, varname=None): - if not varname: - varname = toolname.upper() - value = get_setting(f'USE_{varname}', 'ON') - if value.upper() == 'OFF': - return None - if value.upper() == 'ON' or value == '': - return which(toolname) - if os.access(value, os.F_OK | os.X_OK) and not os.path.isdir(value): - return value + if not varname: + varname = toolname.upper() + value = get_setting(f"USE_{varname}", "ON") + if value.upper() == "OFF": + return None + if value.upper() == "ON" or value == "": + return which(toolname) + if os.access(value, os.F_OK | os.X_OK) and not os.path.isdir(value): + return value ### Detect cmake. -use_cmake = use_tool_path('cmake') or 'cmake' +use_cmake = use_tool_path("cmake") or "cmake" cmake_command_prefix = [use_cmake] cmake_environ = os.environ def cmake_commandline(args): - return cmake_command_prefix + args + return cmake_command_prefix + args if is_windows: - # Bazel needs msys bash and TensorFlow will melt down and cry if it finds - # system bash. Because, of course it will. - # Note that we don't set this as a CMake option because it may have spaces - # in the path, use backslashes or various other things that get corrupted - # in the five or six layers of shoddy string transformations between here - # and where it gets used. - bash_exe = which('bash') - report('Found Windows bash:', bash_exe) - report('NOTE: If the above is system32 bash and you are using bazel to build ' - 'TensorFlow, you are going to have a bad time. Suggest being explicit ' - 'adding the correct directory to your path. I\'m really sorry. ' - 'I didn\'t make this mess... just the messenger') - report(f'Full path = {os.environ.get("PATH")}') + # Bazel needs msys bash and TensorFlow will melt down and cry if it finds + # system bash. Because, of course it will. + # Note that we don't set this as a CMake option because it may have spaces + # in the path, use backslashes or various other things that get corrupted + # in the five or six layers of shoddy string transformations between here + # and where it gets used. + bash_exe = which("bash") + report("Found Windows bash:", bash_exe) + report( + "NOTE: If the above is system32 bash and you are using bazel to build " + "TensorFlow, you are going to have a bad time. Suggest being explicit " + "adding the correct directory to your path. I'm really sorry. " + "I didn't make this mess... just the messenger" + ) + report(f'Full path = {os.environ.get("PATH")}') def invoke_generate(): - ############################################################################## - # Figure out where we are and where we are going. - ############################################################################## - repo_root = os.path.abspath( - get_setting('REPO_DIR', os.path.join(os.path.dirname(__file__), '..', - '..'))) - report(f'Using REPO_DIR = {repo_root}') - - ############################################################################## - # Load version_info.json - ############################################################################## - - def load_version_info(): - with open(os.path.join(repo_root, 'version_info.json'), 'rt') as f: - return json.load(f) - - try: - version_info = load_version_info() - except FileNotFoundError: - report('version_info.json found') - version_info = {} - - ############################################################################## - # CMake configure. - ############################################################################## - - cmake_args = [ - f'-S{repo_root}', - f'-DPython3_EXECUTABLE:FILEPATH={sys.executable}', - # The old python package settings should not be needed, but since there - # can be configuration races between packages that use both mechanisms, - # be explicit. - f'-DPYTHON_EXECUTABLE:FILEPATH={sys.executable}', - f'-DPython3_INCLUDE_DIR:PATH={sysconfig.get_path("include")}', - f'-DPYTHON_INCLUDE_DIR:PATH={sysconfig.get_path("include")}', - f'-DIREE_RELEASE_PACKAGE_SUFFIX:STRING={version_info.get("package-suffix") or ""}', - f'-DIREE_RELEASE_VERSION:STRING={version_info.get("package-version") or "0.0.1a1"}', - f'-DIREE_RELEASE_REVISION:STRING={version_info.get("iree-revision") or "HEAD"}', - ] - - ### Detect generator. - if use_tool_path('ninja'): - report('Using ninja') - cmake_args.append('-GNinja') - elif is_windows: - cmake_args.extend(['-G', 'NMake Makefiles']) - - # Detect other build tools. - use_ccache = use_tool_path('ccache') - if not is_windows and use_ccache: - report(f'Using ccache {use_ccache}') - cmake_args.append(f'-DCMAKE_CXX_COMPILER_LAUNCHER={use_ccache}') - - # Clang - use_clang = use_tool_path('clang') - if not is_windows and use_clang: - report(f'Using clang {use_clang}') - cmake_args.append(f'-DCMAKE_C_COMPILER={use_clang}') - use_clangcpp = use_tool_path('clang++', 'CLANGCPP') - if not is_windows and use_clangcpp: - report(f'Using clang++ {use_clangcpp}') - cmake_args.append(f'-DCMAKE_CXX_COMPILER={use_clangcpp}') - - # LLD - use_lld = use_tool_path('lld') - if not is_windows and use_lld: - report(f'Using linker {use_lld}') - cmake_args.append('-DIREE_ENABLE_LLD=ON') - - cmake_args.extend(sys.argv[1:]) - report(f'Running cmake (generate): {" ".join(cmake_args)}') - subprocess.check_call(cmake_commandline(cmake_args), env=cmake_environ) + ############################################################################## + # Figure out where we are and where we are going. + ############################################################################## + repo_root = os.path.abspath( + get_setting("REPO_DIR", os.path.join(os.path.dirname(__file__), "..", "..")) + ) + report(f"Using REPO_DIR = {repo_root}") + + ############################################################################## + # Load version_info.json + ############################################################################## + + def load_version_info(): + with open(os.path.join(repo_root, "version_info.json"), "rt") as f: + return json.load(f) + + try: + version_info = load_version_info() + except FileNotFoundError: + report("version_info.json found") + version_info = {} + + ############################################################################## + # CMake configure. + ############################################################################## + + cmake_args = [ + f"-S{repo_root}", + f"-DPython3_EXECUTABLE:FILEPATH={sys.executable}", + # The old python package settings should not be needed, but since there + # can be configuration races between packages that use both mechanisms, + # be explicit. + f"-DPYTHON_EXECUTABLE:FILEPATH={sys.executable}", + f'-DPython3_INCLUDE_DIR:PATH={sysconfig.get_path("include")}', + f'-DPYTHON_INCLUDE_DIR:PATH={sysconfig.get_path("include")}', + f'-DIREE_RELEASE_PACKAGE_SUFFIX:STRING={version_info.get("package-suffix") or ""}', + f'-DIREE_RELEASE_VERSION:STRING={version_info.get("package-version") or "0.0.1a1"}', + f'-DIREE_RELEASE_REVISION:STRING={version_info.get("iree-revision") or "HEAD"}', + ] + + ### Detect generator. + if use_tool_path("ninja"): + report("Using ninja") + cmake_args.append("-GNinja") + elif is_windows: + cmake_args.extend(["-G", "NMake Makefiles"]) + + # Detect other build tools. + use_ccache = use_tool_path("ccache") + if not is_windows and use_ccache: + report(f"Using ccache {use_ccache}") + cmake_args.append(f"-DCMAKE_CXX_COMPILER_LAUNCHER={use_ccache}") + + # Clang + use_clang = use_tool_path("clang") + if not is_windows and use_clang: + report(f"Using clang {use_clang}") + cmake_args.append(f"-DCMAKE_C_COMPILER={use_clang}") + use_clangcpp = use_tool_path("clang++", "CLANGCPP") + if not is_windows and use_clangcpp: + report(f"Using clang++ {use_clangcpp}") + cmake_args.append(f"-DCMAKE_CXX_COMPILER={use_clangcpp}") + + # LLD + use_lld = use_tool_path("lld") + if not is_windows and use_lld: + report(f"Using linker {use_lld}") + cmake_args.append("-DIREE_ENABLE_LLD=ON") + + cmake_args.extend(sys.argv[1:]) + report(f'Running cmake (generate): {" ".join(cmake_args)}') + subprocess.check_call(cmake_commandline(cmake_args), env=cmake_environ) # Select which mode. -if mode == 'generate': - invoke_generate() +if mode == "generate": + invoke_generate() else: - # Just pass-through. - cmake_args = cmake_commandline(sys.argv[1:]) - report('Invoke CMake:', ' '.join(cmake_args)) - subprocess.check_call(cmake_args, env=cmake_environ) + # Just pass-through. + cmake_args = cmake_commandline(sys.argv[1:]) + report("Invoke CMake:", " ".join(cmake_args)) + subprocess.check_call(cmake_args, env=cmake_environ) diff --git a/build_tools/github_actions/configure_ci.py b/build_tools/github_actions/configure_ci.py index 8c8011057a12..0e009c205e0c 100755 --- a/build_tools/github_actions/configure_ci.py +++ b/build_tools/github_actions/configure_ci.py @@ -78,14 +78,17 @@ RUNNER_ENV_OPTIONS = [RUNNER_ENV_DEFAULT, "testing"] DEFAULT_BENCHMARK_PRESET_GROUP = [ - "cuda", "x86_64", "android-cpu", "android-gpu", "vulkan-nvidia", - "comp-stats" + "cuda", + "x86_64", + "android-cpu", + "android-gpu", + "vulkan-nvidia", + "comp-stats", ] DEFAULT_BENCHMARK_PRESET = "default" LARGE_BENCHMARK_PRESET_GROUP = ["cuda-large", "x86_64-large"] # All available benchmark preset options including experimental presets. -BENCHMARK_PRESET_OPTIONS = (DEFAULT_BENCHMARK_PRESET_GROUP + - LARGE_BENCHMARK_PRESET_GROUP) +BENCHMARK_PRESET_OPTIONS = DEFAULT_BENCHMARK_PRESET_GROUP + LARGE_BENCHMARK_PRESET_GROUP BENCHMARK_LABEL_PREFIX = "benchmarks" PR_DESCRIPTION_TEMPLATE = "{title}" "\n\n" "{body}" @@ -95,67 +98,76 @@ # intended to be merged and should exclude test/draft PRs as well as # PRs that include temporary patches to the submodule during review. # See also: https://github.com/openxla/iree/issues/12268 -LLVM_INTEGRATE_TITLE_PATTERN = re.compile("^integrate.+llvm-project", - re.IGNORECASE) +LLVM_INTEGRATE_TITLE_PATTERN = re.compile("^integrate.+llvm-project", re.IGNORECASE) LLVM_INTEGRATE_BRANCH_PATTERN = re.compile("bump-llvm|llvm-bump", re.IGNORECASE) LLVM_INTEGRATE_LABEL = "llvm-integrate" def skip_path(path: str) -> bool: - return any(fnmatch.fnmatch(path, pattern) for pattern in SKIP_PATH_PATTERNS) + return any(fnmatch.fnmatch(path, pattern) for pattern in SKIP_PATH_PATTERNS) def set_output(d: Mapping[str, str]): - print(f"Setting outputs: {d}") - step_output_file = os.environ["GITHUB_OUTPUT"] - with open(step_output_file, "a") as f: - f.writelines(f"{k}={v}" + "\n" for k, v in d.items()) + print(f"Setting outputs: {d}") + step_output_file = os.environ["GITHUB_OUTPUT"] + with open(step_output_file, "a") as f: + f.writelines(f"{k}={v}" + "\n" for k, v in d.items()) def write_job_summary(summary: str): - """Write markdown messages on Github workflow UI. - See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-job-summary - """ - step_summary_file = os.environ["GITHUB_STEP_SUMMARY"] - with open(step_summary_file, "a") as f: - # Use double newlines to split sections in markdown. - f.write(summary + "\n\n") - - -def check_description_and_show_diff(original_description: str, - original_labels: Sequence[str], - current_description: str, - current_labels: Sequence[str]): - original_labels = sorted(original_labels) - current_labels = sorted(current_labels) - if (original_description == current_description and - original_labels == current_labels): - return - - description_diffs = difflib.unified_diff( - original_description.splitlines(keepends=True), - current_description.splitlines(keepends=True)) - description_diffs = "".join(description_diffs) - - if description_diffs != "": - description_diffs = textwrap.dedent("""\ + """Write markdown messages on Github workflow UI. + See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-job-summary + """ + step_summary_file = os.environ["GITHUB_STEP_SUMMARY"] + with open(step_summary_file, "a") as f: + # Use double newlines to split sections in markdown. + f.write(summary + "\n\n") + + +def check_description_and_show_diff( + original_description: str, + original_labels: Sequence[str], + current_description: str, + current_labels: Sequence[str], +): + original_labels = sorted(original_labels) + current_labels = sorted(current_labels) + if ( + original_description == current_description + and original_labels == current_labels + ): + return + + description_diffs = difflib.unified_diff( + original_description.splitlines(keepends=True), + current_description.splitlines(keepends=True), + ) + description_diffs = "".join(description_diffs) + + if description_diffs != "": + description_diffs = textwrap.dedent( + """\ ```diff {} ``` - """).format(description_diffs) - - if original_labels == current_labels: - label_diffs = "" - else: - label_diffs = textwrap.dedent("""\ + """ + ).format(description_diffs) + + if original_labels == current_labels: + label_diffs = "" + else: + label_diffs = textwrap.dedent( + """\ ``` Original labels: {original_labels} Current labels: {current_labels} ``` - """).format(original_labels=original_labels, current_labels=current_labels) + """ + ).format(original_labels=original_labels, current_labels=current_labels) - write_job_summary( - textwrap.dedent("""\ + write_job_summary( + textwrap.dedent( + """\ :pushpin: Using the PR description and labels different from the original PR event that started this workflow.
@@ -164,184 +176,203 @@ def check_description_and_show_diff(original_description: str, {description_diffs} {label_diffs} -
""").format(description_diffs=description_diffs, - label_diffs=label_diffs)) + """ + ).format(description_diffs=description_diffs, label_diffs=label_diffs) + ) def get_trailers_and_labels(is_pr: bool) -> Tuple[Mapping[str, str], List[str]]: - if not is_pr: - return ({}, []) - - title = os.environ["PR_TITLE"] - body = os.environ.get("PR_BODY", "") - labels = json.loads(os.environ.get("PR_LABELS", "[]")) - original_title = os.environ.get("ORIGINAL_PR_TITLE") - original_body = os.environ.get("ORIGINAL_PR_BODY", "") - original_labels = json.loads(os.environ.get("ORIGINAL_PR_LABELS", "[]")) - - description = PR_DESCRIPTION_TEMPLATE.format(title=title, body=body) - - # PR information can be fetched from API for the latest updates. If - # ORIGINAL_PR_TITLE is set, compare the current and original description and - # show a notice if they are different. This is mostly to inform users that the - # workflow might not parse the PR description they expect. - if original_title is not None: - original_description = PR_DESCRIPTION_TEMPLATE.format(title=original_title, - body=original_body) - print("Original PR description and labels:", - original_description, - original_labels, - sep="\n") - check_description_and_show_diff(original_description=original_description, - original_labels=original_labels, - current_description=description, - current_labels=labels) - - print("Parsing PR description and labels:", description, labels, sep="\n") - - trailer_lines = subprocess.run( - ["git", "interpret-trailers", "--parse", "--no-divider"], - input=description, - stdout=subprocess.PIPE, - check=True, - text=True, - timeout=60).stdout.splitlines() - trailer_map = { - k.lower().strip(): v.strip() - for k, v in (line.split(":", maxsplit=1) for line in trailer_lines) - } - return (trailer_map, labels) + if not is_pr: + return ({}, []) + + title = os.environ["PR_TITLE"] + body = os.environ.get("PR_BODY", "") + labels = json.loads(os.environ.get("PR_LABELS", "[]")) + original_title = os.environ.get("ORIGINAL_PR_TITLE") + original_body = os.environ.get("ORIGINAL_PR_BODY", "") + original_labels = json.loads(os.environ.get("ORIGINAL_PR_LABELS", "[]")) + + description = PR_DESCRIPTION_TEMPLATE.format(title=title, body=body) + + # PR information can be fetched from API for the latest updates. If + # ORIGINAL_PR_TITLE is set, compare the current and original description and + # show a notice if they are different. This is mostly to inform users that the + # workflow might not parse the PR description they expect. + if original_title is not None: + original_description = PR_DESCRIPTION_TEMPLATE.format( + title=original_title, body=original_body + ) + print( + "Original PR description and labels:", + original_description, + original_labels, + sep="\n", + ) + check_description_and_show_diff( + original_description=original_description, + original_labels=original_labels, + current_description=description, + current_labels=labels, + ) + + print("Parsing PR description and labels:", description, labels, sep="\n") + + trailer_lines = subprocess.run( + ["git", "interpret-trailers", "--parse", "--no-divider"], + input=description, + stdout=subprocess.PIPE, + check=True, + text=True, + timeout=60, + ).stdout.splitlines() + trailer_map = { + k.lower().strip(): v.strip() + for k, v in (line.split(":", maxsplit=1) for line in trailer_lines) + } + return (trailer_map, labels) def get_modified_paths(base_ref: str) -> Iterable[str]: - return subprocess.run(["git", "diff", "--name-only", base_ref], - stdout=subprocess.PIPE, - check=True, - text=True, - timeout=60).stdout.splitlines() + return subprocess.run( + ["git", "diff", "--name-only", base_ref], + stdout=subprocess.PIPE, + check=True, + text=True, + timeout=60, + ).stdout.splitlines() def modifies_included_path(base_ref: str) -> bool: - return any(not skip_path(p) for p in get_modified_paths(base_ref)) + return any(not skip_path(p) for p in get_modified_paths(base_ref)) def should_run_ci(is_pr: bool, trailers: Mapping[str, str]) -> bool: - if not is_pr: - print("Running CI independent of diff because run was not triggered by a" - " pull request event.") - return True - - if SKIP_CI_KEY in trailers: - print(f"Not running CI because PR description has '{SKIP_CI_KEY}' trailer.") - return False - - base_ref = os.environ["BASE_REF"] - try: - modifies = modifies_included_path(base_ref) - except TimeoutError as e: - print("Computing modified files timed out. Running the CI") + if not is_pr: + print( + "Running CI independent of diff because run was not triggered by a" + " pull request event." + ) + return True + + if SKIP_CI_KEY in trailers: + print(f"Not running CI because PR description has '{SKIP_CI_KEY}' trailer.") + return False + + base_ref = os.environ["BASE_REF"] + try: + modifies = modifies_included_path(base_ref) + except TimeoutError as e: + print("Computing modified files timed out. Running the CI") + return True + + if not modifies: + print("Skipping CI because all modified files are marked as excluded.") + return False + + print("CI should run") return True - if not modifies: - print("Skipping CI because all modified files are marked as excluded.") - return False - - print("CI should run") - return True - def get_runner_env(trailers: Mapping[str, str]) -> str: - runner_env = trailers.get(RUNNER_ENV_KEY) - if runner_env is None: - print(f"Using '{RUNNER_ENV_DEFAULT}' runners because '{RUNNER_ENV_KEY}'" - f" not found in {trailers}") - runner_env = RUNNER_ENV_DEFAULT - else: - print( - f"Using runner environment '{runner_env}' from PR description trailers") - return runner_env - - -def get_benchmark_presets(trailers: Mapping[str, str], labels: Sequence[str], - is_pr: bool, is_llvm_integrate_pr: bool) -> str: - """Parses and validates the benchmark presets from trailers. - - Args: - trailers: trailers from PR description. - labels: list of PR labels. - is_pr: is pull request event. - is_llvm_integrate_pr: is LLVM integration PR. - - Returns: - A comma separated preset string, which later will be parsed by - build_tools/benchmarks/export_benchmark_config.py. - """ - - skip_llvm_integrate_benchmark = SKIP_LLVM_INTEGRATE_BENCHMARK_KEY in trailers - if skip_llvm_integrate_benchmark: - print("Skipping default benchmarking on LLVM integration because PR " - f"description has '{SKIP_LLVM_INTEGRATE_BENCHMARK_KEY}' trailer.") - - if not is_pr: - preset_options = {DEFAULT_BENCHMARK_PRESET} - print(f"Using benchmark presets '{preset_options}' for non-PR run") - elif is_llvm_integrate_pr and not skip_llvm_integrate_benchmark: - # Run all benchmark presets for LLVM integration PRs. - preset_options = {DEFAULT_BENCHMARK_PRESET} - print(f"Using benchmark preset '{preset_options}' for LLVM integration PR") - else: - preset_options = set( - label.split(":", maxsplit=1)[1] - for label in labels - if label.startswith(BENCHMARK_LABEL_PREFIX + ":")) - trailer = trailers.get(BENCHMARK_EXTRA_KEY) - if trailer is not None: - preset_options = preset_options.union( - option.strip() for option in trailer.split(",")) - print(f"Using benchmark preset '{preset_options}' from trailers and labels") - - if DEFAULT_BENCHMARK_PRESET in preset_options: - preset_options.remove(DEFAULT_BENCHMARK_PRESET) - preset_options.update(DEFAULT_BENCHMARK_PRESET_GROUP) - - if preset_options.intersection(DEFAULT_BENCHMARK_PRESET_GROUP): - # The is a sugar to run the compilation benchmarks when any default - # benchmark preset is present. - preset_options.add("comp-stats") - - preset_options = sorted(preset_options) - for preset_option in preset_options: - if preset_option not in BENCHMARK_PRESET_OPTIONS: - raise ValueError(f"Unknown benchmark preset option: '{preset_option}'.\n" - f"Available options: '{BENCHMARK_PRESET_OPTIONS}'.") - - return ",".join(preset_options) + runner_env = trailers.get(RUNNER_ENV_KEY) + if runner_env is None: + print( + f"Using '{RUNNER_ENV_DEFAULT}' runners because '{RUNNER_ENV_KEY}'" + f" not found in {trailers}" + ) + runner_env = RUNNER_ENV_DEFAULT + else: + print(f"Using runner environment '{runner_env}' from PR description trailers") + return runner_env + + +def get_benchmark_presets( + trailers: Mapping[str, str], + labels: Sequence[str], + is_pr: bool, + is_llvm_integrate_pr: bool, +) -> str: + """Parses and validates the benchmark presets from trailers. + + Args: + trailers: trailers from PR description. + labels: list of PR labels. + is_pr: is pull request event. + is_llvm_integrate_pr: is LLVM integration PR. + + Returns: + A comma separated preset string, which later will be parsed by + build_tools/benchmarks/export_benchmark_config.py. + """ + + skip_llvm_integrate_benchmark = SKIP_LLVM_INTEGRATE_BENCHMARK_KEY in trailers + if skip_llvm_integrate_benchmark: + print( + "Skipping default benchmarking on LLVM integration because PR " + f"description has '{SKIP_LLVM_INTEGRATE_BENCHMARK_KEY}' trailer." + ) + + if not is_pr: + preset_options = {DEFAULT_BENCHMARK_PRESET} + print(f"Using benchmark presets '{preset_options}' for non-PR run") + elif is_llvm_integrate_pr and not skip_llvm_integrate_benchmark: + # Run all benchmark presets for LLVM integration PRs. + preset_options = {DEFAULT_BENCHMARK_PRESET} + print(f"Using benchmark preset '{preset_options}' for LLVM integration PR") + else: + preset_options = set( + label.split(":", maxsplit=1)[1] + for label in labels + if label.startswith(BENCHMARK_LABEL_PREFIX + ":") + ) + trailer = trailers.get(BENCHMARK_EXTRA_KEY) + if trailer is not None: + preset_options = preset_options.union( + option.strip() for option in trailer.split(",") + ) + print(f"Using benchmark preset '{preset_options}' from trailers and labels") + + if DEFAULT_BENCHMARK_PRESET in preset_options: + preset_options.remove(DEFAULT_BENCHMARK_PRESET) + preset_options.update(DEFAULT_BENCHMARK_PRESET_GROUP) + + if preset_options.intersection(DEFAULT_BENCHMARK_PRESET_GROUP): + # The is a sugar to run the compilation benchmarks when any default + # benchmark preset is present. + preset_options.add("comp-stats") + + preset_options = sorted(preset_options) + for preset_option in preset_options: + if preset_option not in BENCHMARK_PRESET_OPTIONS: + raise ValueError( + f"Unknown benchmark preset option: '{preset_option}'.\n" + f"Available options: '{BENCHMARK_PRESET_OPTIONS}'." + ) + + return ",".join(preset_options) def main(): - is_pr = os.environ["GITHUB_EVENT_NAME"] == "pull_request" - trailers, labels = get_trailers_and_labels(is_pr) - is_llvm_integrate_pr = bool( - LLVM_INTEGRATE_TITLE_PATTERN.search(os.environ.get("PR_TITLE", "")) or - LLVM_INTEGRATE_BRANCH_PATTERN.search(os.environ.get("PR_BRANCH", "")) or - LLVM_INTEGRATE_LABEL in labels) - output = { - "should-run": - json.dumps(should_run_ci(is_pr, trailers)), - "is-pr": - json.dumps(is_pr), - "runner-env": - get_runner_env(trailers), - "runner-group": - "presubmit" if is_pr else "postsubmit", - "write-caches": - "0" if is_pr else "1", - "benchmark-presets": - get_benchmark_presets(trailers, labels, is_pr, is_llvm_integrate_pr), - } - - set_output(output) + is_pr = os.environ["GITHUB_EVENT_NAME"] == "pull_request" + trailers, labels = get_trailers_and_labels(is_pr) + is_llvm_integrate_pr = bool( + LLVM_INTEGRATE_TITLE_PATTERN.search(os.environ.get("PR_TITLE", "")) + or LLVM_INTEGRATE_BRANCH_PATTERN.search(os.environ.get("PR_BRANCH", "")) + or LLVM_INTEGRATE_LABEL in labels + ) + output = { + "should-run": json.dumps(should_run_ci(is_pr, trailers)), + "is-pr": json.dumps(is_pr), + "runner-env": get_runner_env(trailers), + "runner-group": "presubmit" if is_pr else "postsubmit", + "write-caches": "0" if is_pr else "1", + "benchmark-presets": get_benchmark_presets( + trailers, labels, is_pr, is_llvm_integrate_pr + ), + } + + set_output(output) if __name__ == "__main__": - main() + main() diff --git a/build_tools/github_actions/configure_ci_test.py b/build_tools/github_actions/configure_ci_test.py index 04b6d4fc00ce..0640ec23a25d 100644 --- a/build_tools/github_actions/configure_ci_test.py +++ b/build_tools/github_actions/configure_ci_test.py @@ -11,93 +11,99 @@ import configure_ci SORTED_DEFAULT_BENCHMARK_PRESETS_STR = ",".join( - sorted(configure_ci.DEFAULT_BENCHMARK_PRESET_GROUP)) + sorted(configure_ci.DEFAULT_BENCHMARK_PRESET_GROUP) +) class GetBenchmarkPresetsTest(unittest.TestCase): + def test_get_benchmark_presets_no_preset(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={}, + labels=["unrelated-labels"], + is_pr=True, + is_llvm_integrate_pr=False, + ) + + self.assertEqual(presets_str, "") + + def test_get_benchmark_presets_from_pr_labels(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={}, + labels=["benchmarks:x86_64", "benchmarks:cuda"], + is_pr=True, + is_llvm_integrate_pr=False, + ) + + self.assertEqual(presets_str, "comp-stats,cuda,x86_64") + + def test_get_benchmark_presets_from_trailers_and_labels(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={"benchmark-extra": "android-cpu,cuda-large,x86_64-large"}, + labels=["benchmarks:vulkan-nvidia"], + is_pr=True, + is_llvm_integrate_pr=False, + ) + + self.assertEqual( + presets_str, "android-cpu,comp-stats,cuda-large,vulkan-nvidia,x86_64-large" + ) - def test_get_benchmark_presets_no_preset(self): - presets_str = configure_ci.get_benchmark_presets( - trailers={}, - labels=["unrelated-labels"], - is_pr=True, - is_llvm_integrate_pr=False) - - self.assertEqual(presets_str, "") - - def test_get_benchmark_presets_from_pr_labels(self): - presets_str = configure_ci.get_benchmark_presets( - trailers={}, - labels=["benchmarks:x86_64", "benchmarks:cuda"], - is_pr=True, - is_llvm_integrate_pr=False) - - self.assertEqual(presets_str, "comp-stats,cuda,x86_64") - - def test_get_benchmark_presets_from_trailers_and_labels(self): - presets_str = configure_ci.get_benchmark_presets( - trailers={"benchmark-extra": "android-cpu,cuda-large,x86_64-large"}, - labels=["benchmarks:vulkan-nvidia"], - is_pr=True, - is_llvm_integrate_pr=False) - - self.assertEqual( - presets_str, - "android-cpu,comp-stats,cuda-large,vulkan-nvidia,x86_64-large") - - def test_get_benchmark_presets_from_default_group(self): - presets_str = configure_ci.get_benchmark_presets( - trailers={"benchmark-extra": "default"}, - labels=[], - is_pr=True, - is_llvm_integrate_pr=False) - - self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) - # Sanity check to ensure no `*-large` preset in the default group. - self.assertNotIn("-large", presets_str) - - def test_get_benchmark_presets_for_non_pr(self): - presets_str = configure_ci.get_benchmark_presets(trailers={}, - labels=[], - is_pr=False, - is_llvm_integrate_pr=False) - - self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) - - def test_get_benchmark_presets_for_llvm_integrate_pr(self): - presets_str = configure_ci.get_benchmark_presets(trailers={}, - labels=[], - is_pr=True, - is_llvm_integrate_pr=True) - - self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) - - # Sample PR description: - # ``` - # PR Title - # - # PR body... - # - # skip-llvm-integrate-benchmark: some good reasons - # ``` - # Result: No benchmark is automatically enabled on the LLVM integrate PR. - def test_get_benchmark_presets_skip_llvm_integrate_benchmark(self): - presets_str = configure_ci.get_benchmark_presets( - trailers={"skip-llvm-integrate-benchmark": "some good reasons"}, - labels=[], - is_pr=True, - is_llvm_integrate_pr=True) - - self.assertEqual(presets_str, "") - - def test_get_benchmark_presets_unknown_preset(self): - self.assertRaises( - ValueError, lambda: configure_ci.get_benchmark_presets( - trailers={"benchmark-extra": "unknown"}, + def test_get_benchmark_presets_from_default_group(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={"benchmark-extra": "default"}, + labels=[], + is_pr=True, + is_llvm_integrate_pr=False, + ) + + self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) + # Sanity check to ensure no `*-large` preset in the default group. + self.assertNotIn("-large", presets_str) + + def test_get_benchmark_presets_for_non_pr(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={}, labels=[], is_pr=False, is_llvm_integrate_pr=False + ) + + self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) + + def test_get_benchmark_presets_for_llvm_integrate_pr(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={}, labels=[], is_pr=True, is_llvm_integrate_pr=True + ) + + self.assertEqual(presets_str, SORTED_DEFAULT_BENCHMARK_PRESETS_STR) + + # Sample PR description: + # ``` + # PR Title + # + # PR body... + # + # skip-llvm-integrate-benchmark: some good reasons + # ``` + # Result: No benchmark is automatically enabled on the LLVM integrate PR. + def test_get_benchmark_presets_skip_llvm_integrate_benchmark(self): + presets_str = configure_ci.get_benchmark_presets( + trailers={"skip-llvm-integrate-benchmark": "some good reasons"}, labels=[], is_pr=True, - is_llvm_integrate_pr=False)) + is_llvm_integrate_pr=True, + ) + + self.assertEqual(presets_str, "") + + def test_get_benchmark_presets_unknown_preset(self): + self.assertRaises( + ValueError, + lambda: configure_ci.get_benchmark_presets( + trailers={"benchmark-extra": "unknown"}, + labels=[], + is_pr=True, + is_llvm_integrate_pr=False, + ), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/github_actions/runner/config/health_server/health_server.py b/build_tools/github_actions/runner/config/health_server/health_server.py index d62df745be39..626bb20b4880 100755 --- a/build_tools/github_actions/runner/config/health_server/health_server.py +++ b/build_tools/github_actions/runner/config/health_server/health_server.py @@ -31,62 +31,62 @@ class HealthCheckHandler(http.server.BaseHTTPRequestHandler): - - def send_success(self, - *, - msg: Optional[str] = None, - body: Optional[str] = None): - self.send_response(OK) - self.send_header("Content-type", "text/html") - self.end_headers() - if body is not None: - self.wfile.write(bytes(body, encoding="utf-8")) - - def do_GET(self): - try: - subprocess.run(CHECK_SERVICE_CMD, - check=True, - text=True, - stdout=subprocess.PIPE, - timeout=CHECK_SERVICE_TIMEOUT) - except subprocess.TimeoutExpired as e: - msg = f"'{' '.join(e.cmd)}' timed out: {e.stdout}" - return self.send_error(INTERNAL_SERVER_ERROR, msg) - except subprocess.CalledProcessError as e: - return self.send_error( - NOT_FOUND, f"Runner service not found: '{' '.join(e.cmd)}' returned" - f" '{e.stdout.strip()}' (exit code {e.returncode})") - - # The runner writes a log file for each job it runs. In our case it only - # runs one, so we glob for anything matching that pattern. Yes that is an - # absolutely ludicrous way to get the runner's status. GitHub should really - # implement a proper health check so we don't have to hack around like this. - if glob.glob(RUNNER_WORK_LOG_PATTERN): - return self.send_success(body="active") - - return self.send_success(body="idle") + def send_success(self, *, msg: Optional[str] = None, body: Optional[str] = None): + self.send_response(OK) + self.send_header("Content-type", "text/html") + self.end_headers() + if body is not None: + self.wfile.write(bytes(body, encoding="utf-8")) + + def do_GET(self): + try: + subprocess.run( + CHECK_SERVICE_CMD, + check=True, + text=True, + stdout=subprocess.PIPE, + timeout=CHECK_SERVICE_TIMEOUT, + ) + except subprocess.TimeoutExpired as e: + msg = f"'{' '.join(e.cmd)}' timed out: {e.stdout}" + return self.send_error(INTERNAL_SERVER_ERROR, msg) + except subprocess.CalledProcessError as e: + return self.send_error( + NOT_FOUND, + f"Runner service not found: '{' '.join(e.cmd)}' returned" + f" '{e.stdout.strip()}' (exit code {e.returncode})", + ) + + # The runner writes a log file for each job it runs. In our case it only + # runs one, so we glob for anything matching that pattern. Yes that is an + # absolutely ludicrous way to get the runner's status. GitHub should really + # implement a proper health check so we don't have to hack around like this. + if glob.glob(RUNNER_WORK_LOG_PATTERN): + return self.send_success(body="active") + + return self.send_success(body="idle") def main(args: argparse.Namespace): - webServer = http.server.HTTPServer(("", args.port), HealthCheckHandler) - print(f"Server started on port {args.port}. Ctrl+C to stop.") + webServer = http.server.HTTPServer(("", args.port), HealthCheckHandler) + print(f"Server started on port {args.port}. Ctrl+C to stop.") - try: - webServer.serve_forever() - except KeyboardInterrupt: - # Don't print an exception on interrupt. Add a newline to handle printing of - # "^C" - print() + try: + webServer.serve_forever() + except KeyboardInterrupt: + # Don't print an exception on interrupt. Add a newline to handle printing of + # "^C" + print() - webServer.server_close() - print("Server stopped.") + webServer.server_close() + print("Server stopped.") def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8080) - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8080) + return parser.parse_args() if __name__ == "__main__": - main(parse_args()) + main(parse_args()) diff --git a/build_tools/github_actions/runner/gcp/update_instance_groups.py b/build_tools/github_actions/runner/gcp/update_instance_groups.py index 88f7effe3d92..6835b5f89227 100755 --- a/build_tools/github_actions/runner/gcp/update_instance_groups.py +++ b/build_tools/github_actions/runner/gcp/update_instance_groups.py @@ -25,345 +25,415 @@ def resource_basename(resource): - return os.path.basename(urllib.parse.urlparse(resource).path) + return os.path.basename(urllib.parse.urlparse(resource).path) def error(msg): - print("ERROR: ", msg, file=sys.stderr) - sys.exit(1) + print("ERROR: ", msg, file=sys.stderr) + sys.exit(1) def confirm(msg): - user_input = "" - while user_input.lower() not in ["yes", "no", "y", "n"]: - user_input = input(f"{msg} [y/n] ") - if user_input.lower() in ["n", "no"]: - print("Aborting") - sys.exit(1) + user_input = "" + while user_input.lower() not in ["yes", "no", "y", "n"]: + user_input = input(f"{msg} [y/n] ") + if user_input.lower() in ["n", "no"]: + print("Aborting") + sys.exit(1) def check_scary_action(action, skip_confirmation): - if skip_confirmation: - print(f"WARNING: Performing {action}.\n" - f"Proceeding because '--skip-confirmation' is set.") - else: - confirm(f"You are about to perform {action}.\n" - f" Are you sure you want to proceed?") + if skip_confirmation: + print( + f"WARNING: Performing {action}.\n" + f"Proceeding because '--skip-confirmation' is set." + ) + else: + confirm( + f"You are about to perform {action}.\n" + f" Are you sure you want to proceed?" + ) def summarize_versions(versions): - return {v.name: resource_basename(v.instance_template) for v in versions} + return {v.name: resource_basename(v.instance_template) for v in versions} -class MigFetcher(): +class MigFetcher: + def __init__(self, *, migs_client, regions_client, project): + self._migs_client = migs_client + self._regions_client = regions_client + self._project = project - def __init__(self, *, migs_client, regions_client, project): - self._migs_client = migs_client - self._regions_client = regions_client - self._project = project + def get_migs(self, *, region, type, group, prefix, modifier=None): + print("Finding matching MIGs") + migs = [] - def get_migs(self, *, region, type, group, prefix, modifier=None): - print("Finding matching MIGs") - migs = [] + request = compute.ListRegionsRequest(project=self._project) + if region != "all": + request.filter = f"name eq {region}" + regions = [r.name for r in self._regions_client.list(request)] - request = compute.ListRegionsRequest(project=self._project) - if region != "all": - request.filter = f"name eq {region}" - regions = [r.name for r in self._regions_client.list(request)] + if type == "all": + type = r"\w+" - if type == "all": - type = r"\w+" + if group == "all": + group = r"\w+" - if group == "all": - group = r"\w+" - - for region in regions: - filter_parts = [p for p in [prefix, modifier, group, type, region] if p] - filter = f"name eq '{'-'.join(filter_parts)}'" - list_mig_request = compute.ListRegionInstanceGroupManagersRequest( - project=self._project, - region=region, - filter=filter, - ) - region_migs = self._migs_client.list(list_mig_request) - migs.extend([mig for mig in region_migs]) - return migs + for region in regions: + filter_parts = [p for p in [prefix, modifier, group, type, region] if p] + filter = f"name eq '{'-'.join(filter_parts)}'" + list_mig_request = compute.ListRegionInstanceGroupManagersRequest( + project=self._project, + region=region, + filter=filter, + ) + region_migs = self._migs_client.list(list_mig_request) + migs.extend([mig for mig in region_migs]) + return migs def main(args): - templates_client = compute.InstanceTemplatesClient() - migs_client = compute.RegionInstanceGroupManagersClient() - updater = MigFetcher( - migs_client=migs_client, - regions_client=compute.RegionsClient(), - project=args.project, - ) - - # Prod instances just have the bare name - modifier = None if args.env == PROD_ENV_NAME else args.env - migs = updater.get_migs(region=args.region, - type=args.type, - group=args.group, - prefix=args.name_prefix, - modifier=modifier) - if len(migs) == 0: - error("arguments matched no instance groups") - sys.exit(1) - - print(f"Found:\n ", "\n ".join([m.name for m in migs]), sep="") - if args.skip_confirmation: - print("Proceeding with update as --skip-confirmation is set") - else: - confirm("Proceed with updating these MIGs?") - - if args.mode == "proactive" and args.action != "refresh": - mig_desc = f"'{migs[0].name}'" if len(migs) == 1 else f"{len(migs)} groups" - scary_action = ( - f"an update on {mig_desc} that will shut down instances even if" - f" they're in the middle of running a job") - check_scary_action(scary_action, args.skip_confirmation) - - for mig in migs: - region = resource_basename(mig.region) - if args.command in [DIRECT_UPDATE_COMMAND_NAME, CANARY_COMMAND_NAME]: - if "testing" in args.version and args.env != TESTING_ENV_NAME: - scary_action = (f"using testing template version '{args.version}' in" - f" environment '{args.env}'") - check_scary_action(scary_action, args.skip_confirmation) - - strip = f"-{region}" - if not mig.name.endswith(strip): - raise ValueError(f"MIG name does not end with '{strip}' as expected") - template_name = f"{mig.name[:-len(strip)]}-{args.version}" - - # TODO(gcmn): Make template naming consistent (ran into length limits) - template_name = template_name.replace(f"-{args.env}-", "-") - template_url = templates_client.get( - project=args.project, instance_template=template_name).self_link - - current_templates = {v.name: v.instance_template for v in mig.versions} - - if not current_templates: - error(f"Found no template versions for '{mig.name}'." - f" This shouldn't be possible.") - - # TODO(gcmn): These should probably be factored into functions - if args.command == CANARY_COMMAND_NAME: - if len(current_templates) > 1: - error(f"Instance group '{mig.name}' has multiple versions, but canary" - f" requires it start with exactly one. Current versions:" - f" {summarize_versions(mig.versions)}") - - base_template = current_templates.get(args.base_version_name) - if not base_template: - error(f"Instance group '{mig.name}' does not have a current version" - f" named '{args.base_version_name}', which is required for an" - f" automatic canary. Current versions:" - f" {summarize_versions(mig.versions)}") - - if base_template == template_url: - error(f"Instance group '{mig.name}' already has the requested canary" - f" version '{template_name}' as its base version. Current" - " versions:" - f" {summarize_versions(mig.versions)}") - new_versions = [ - compute.InstanceGroupManagerVersion(name=args.base_version_name, - instance_template=base_template), - compute.InstanceGroupManagerVersion(name=args.canary_version_name, - instance_template=template_url, - target_size=CANARY_SIZE) - ] - elif args.command == DIRECT_UPDATE_COMMAND_NAME: - scary_action = (f"an update of all instances in '{mig.name}' directly" - f" without doing a canary") - check_scary_action(scary_action, args.skip_confirmation) - - new_versions = [ - compute.InstanceGroupManagerVersion(name=args.base_version_name, - instance_template=template_url) - ] - elif args.command == PROMOTE_CANARY_COMMAND_NAME: - new_base_template = current_templates.get(args.canary_version_name) - if new_base_template is None: - error(f"Instance group '{mig.name}' does not have a current version" - f" named '{args.canary_version_name}', which is required for an" - f" automatic canary promotion. Current versions:" - f" {summarize_versions(mig.versions)}") - new_versions = [ - compute.InstanceGroupManagerVersion( - name=args.base_version_name, instance_template=new_base_template) - ] - elif args.command == ROLLBACK_CANARY_COMMAND_NAME: - base_template = current_templates.get(args.base_version_name) - if base_template is None: - error(f"Instance group '{mig.name}' does not have a current version" - f" named '{args.base_version_name}', which is required for an" - f" automatic canary rollback. Current versions:" - f" {summarize_versions(mig.versions)}") - new_versions = [ - compute.InstanceGroupManagerVersion(name=args.base_version_name, - instance_template=base_template) - ] - else: - error(f"Unrecognized command '{args.command}'") - - update_policy = compute.InstanceGroupManagerUpdatePolicy( - type_=args.mode, - minimal_action=args.action, - most_disruptive_allowed_action=args.action) - - print(f"Updating {mig.name} to new versions:" - f" {summarize_versions(new_versions)}") - - request = compute.PatchRegionInstanceGroupManagerRequest( + templates_client = compute.InstanceTemplatesClient() + migs_client = compute.RegionInstanceGroupManagersClient() + updater = MigFetcher( + migs_client=migs_client, + regions_client=compute.RegionsClient(), project=args.project, - region=region, - instance_group_manager=mig.name, - instance_group_manager_resource=compute.InstanceGroupManager( - versions=new_versions, update_policy=update_policy)) - - if not args.dry_run: - migs_client.patch(request) + ) + + # Prod instances just have the bare name + modifier = None if args.env == PROD_ENV_NAME else args.env + migs = updater.get_migs( + region=args.region, + type=args.type, + group=args.group, + prefix=args.name_prefix, + modifier=modifier, + ) + if len(migs) == 0: + error("arguments matched no instance groups") + sys.exit(1) + + print(f"Found:\n ", "\n ".join([m.name for m in migs]), sep="") + if args.skip_confirmation: + print("Proceeding with update as --skip-confirmation is set") else: - print(f"Dry run, so not sending this patch request:\n```\n{request}```") - print(f"Successfully updated {mig.name}") + confirm("Proceed with updating these MIGs?") + + if args.mode == "proactive" and args.action != "refresh": + mig_desc = f"'{migs[0].name}'" if len(migs) == 1 else f"{len(migs)} groups" + scary_action = ( + f"an update on {mig_desc} that will shut down instances even if" + f" they're in the middle of running a job" + ) + check_scary_action(scary_action, args.skip_confirmation) + + for mig in migs: + region = resource_basename(mig.region) + if args.command in [DIRECT_UPDATE_COMMAND_NAME, CANARY_COMMAND_NAME]: + if "testing" in args.version and args.env != TESTING_ENV_NAME: + scary_action = ( + f"using testing template version '{args.version}' in" + f" environment '{args.env}'" + ) + check_scary_action(scary_action, args.skip_confirmation) + + strip = f"-{region}" + if not mig.name.endswith(strip): + raise ValueError(f"MIG name does not end with '{strip}' as expected") + template_name = f"{mig.name[:-len(strip)]}-{args.version}" + + # TODO(gcmn): Make template naming consistent (ran into length limits) + template_name = template_name.replace(f"-{args.env}-", "-") + template_url = templates_client.get( + project=args.project, instance_template=template_name + ).self_link + + current_templates = {v.name: v.instance_template for v in mig.versions} + + if not current_templates: + error( + f"Found no template versions for '{mig.name}'." + f" This shouldn't be possible." + ) + + # TODO(gcmn): These should probably be factored into functions + if args.command == CANARY_COMMAND_NAME: + if len(current_templates) > 1: + error( + f"Instance group '{mig.name}' has multiple versions, but canary" + f" requires it start with exactly one. Current versions:" + f" {summarize_versions(mig.versions)}" + ) + + base_template = current_templates.get(args.base_version_name) + if not base_template: + error( + f"Instance group '{mig.name}' does not have a current version" + f" named '{args.base_version_name}', which is required for an" + f" automatic canary. Current versions:" + f" {summarize_versions(mig.versions)}" + ) + + if base_template == template_url: + error( + f"Instance group '{mig.name}' already has the requested canary" + f" version '{template_name}' as its base version. Current" + " versions:" + f" {summarize_versions(mig.versions)}" + ) + new_versions = [ + compute.InstanceGroupManagerVersion( + name=args.base_version_name, instance_template=base_template + ), + compute.InstanceGroupManagerVersion( + name=args.canary_version_name, + instance_template=template_url, + target_size=CANARY_SIZE, + ), + ] + elif args.command == DIRECT_UPDATE_COMMAND_NAME: + scary_action = ( + f"an update of all instances in '{mig.name}' directly" + f" without doing a canary" + ) + check_scary_action(scary_action, args.skip_confirmation) + + new_versions = [ + compute.InstanceGroupManagerVersion( + name=args.base_version_name, instance_template=template_url + ) + ] + elif args.command == PROMOTE_CANARY_COMMAND_NAME: + new_base_template = current_templates.get(args.canary_version_name) + if new_base_template is None: + error( + f"Instance group '{mig.name}' does not have a current version" + f" named '{args.canary_version_name}', which is required for an" + f" automatic canary promotion. Current versions:" + f" {summarize_versions(mig.versions)}" + ) + new_versions = [ + compute.InstanceGroupManagerVersion( + name=args.base_version_name, instance_template=new_base_template + ) + ] + elif args.command == ROLLBACK_CANARY_COMMAND_NAME: + base_template = current_templates.get(args.base_version_name) + if base_template is None: + error( + f"Instance group '{mig.name}' does not have a current version" + f" named '{args.base_version_name}', which is required for an" + f" automatic canary rollback. Current versions:" + f" {summarize_versions(mig.versions)}" + ) + new_versions = [ + compute.InstanceGroupManagerVersion( + name=args.base_version_name, instance_template=base_template + ) + ] + else: + error(f"Unrecognized command '{args.command}'") + + update_policy = compute.InstanceGroupManagerUpdatePolicy( + type_=args.mode, + minimal_action=args.action, + most_disruptive_allowed_action=args.action, + ) + + print( + f"Updating {mig.name} to new versions:" + f" {summarize_versions(new_versions)}" + ) + + request = compute.PatchRegionInstanceGroupManagerRequest( + project=args.project, + region=region, + instance_group_manager=mig.name, + instance_group_manager_resource=compute.InstanceGroupManager( + versions=new_versions, update_policy=update_policy + ), + ) + + if not args.dry_run: + migs_client.patch(request) + else: + print(f"Dry run, so not sending this patch request:\n```\n{request}```") + print(f"Successfully updated {mig.name}") def parse_args(): - parser = argparse.ArgumentParser(description=( - "Updates one or more GCP Managed Instance Groups (MIGs) to new" - " instance template versions. Wraps the GCP API with shortcuts for the" - " patterns we have in our MIGs. See the README and" - " https://cloud.google.com/compute/docs/instance-groups/updating-migs for" - " more details.")) - - # Makes global options come *after* command. - # See https://stackoverflow.com/q/23296695 - subparser_base = argparse.ArgumentParser(add_help=False) - subparser_base.add_argument("--project", - default="iree-oss", - help="The cloud project for the MIGs.") - subparser_base.add_argument( - "--region", - "--regions", - required=True, - help=("The cloud region (e.g. 'us-west1') of the MIG to update, an RE2" + parser = argparse.ArgumentParser( + description=( + "Updates one or more GCP Managed Instance Groups (MIGs) to new" + " instance template versions. Wraps the GCP API with shortcuts for the" + " patterns we have in our MIGs. See the README and" + " https://cloud.google.com/compute/docs/instance-groups/updating-migs for" + " more details." + ) + ) + + # Makes global options come *after* command. + # See https://stackoverflow.com/q/23296695 + subparser_base = argparse.ArgumentParser(add_help=False) + subparser_base.add_argument( + "--project", default="iree-oss", help="The cloud project for the MIGs." + ) + subparser_base.add_argument( + "--region", + "--regions", + required=True, + help=( + "The cloud region (e.g. 'us-west1') of the MIG to update, an RE2" " regex for matching region names (e.g. 'us-.*'), or 'all' to" - " search for MIGs in all regions.")) - subparser_base.add_argument( - "--group", - "--groups", - required=True, - help=("The runner group of the MIGs to update, an RE2 regex for matching" + " search for MIGs in all regions." + ), + ) + subparser_base.add_argument( + "--group", + "--groups", + required=True, + help=( + "The runner group of the MIGs to update, an RE2 regex for matching" " the group (e.g. 'cpu|gpu'), or 'all' to search for MIGs for all" - " groups."), - ) - subparser_base.add_argument( - "--type", - "--types", - required=True, - help=("The runner type of the MIGs to update, an RE2 regex for matching" + " groups." + ), + ) + subparser_base.add_argument( + "--type", + "--types", + required=True, + help=( + "The runner type of the MIGs to update, an RE2 regex for matching" " the type (e.g. 'presubmit|postsubmit'), or 'all' to search for" - " MIGs for all types."), - ) - subparser_base.add_argument( - "--mode", - default="opportunistic", - choices=["opportunistic", "proactive"], - help=( - "The mode in which to update instances. See README and" - " https://cloud.google.com/compute/docs/instance-groups/updating-migs." - )) - subparser_base.add_argument( - "--action", - choices=["refresh", "restart", "replace"], - help=( - "What action to take when updating an instance. See README and" - " https://cloud.google.com/compute/docs/instance-groups/updating-migs." - )) - subparser_base.add_argument("--env", - "--environment", - default=TESTING_ENV_NAME, - help="The environment for the MIGs.", - choices=[PROD_ENV_NAME, TESTING_ENV_NAME]) - subparser_base.add_argument( - "--dry-run", - action="store_true", - default=False, - help="Print all output but don't actually send the update request.") - - # Defaulting to true for testing environment avoids people getting in the - # habit of routinely passing --force. - skip_confirmation = subparser_base.add_mutually_exclusive_group() - skip_confirmation.add_argument( - "--skip-confirmation", - "--force", - action="store_true", - default=None, - help=("Skip all confirmation prompts. Be careful." - " Defaults to True for testing environment")) - skip_confirmation.add_argument("--noskip-confirmation", - "--noforce", - action="store_false", - default=None, - dest="skip_confirmation") - - # These shouldn't be set very often, but it's just as easy to make them flags - # as it is to make them global constants. - subparser_base.add_argument("--name-prefix", - default="gh-runner", - help="The first part of MIG and template names.") - subparser_base.add_argument( - "--base-version-name", - default="base", - help="The name given to the MIG instance version that isn't in canary.") - subparser_base.add_argument( - "--canary-version-name", - default="canary", - help="The name given to the MIG instance version that is being canaried.") - - subparsers = parser.add_subparsers(required=True, dest="command") - - canary_sp = subparsers.add_parser(CANARY_COMMAND_NAME, - parents=[subparser_base], - help="Canary a new template version.") - rollback_sp = subparsers.add_parser( - ROLLBACK_CANARY_COMMAND_NAME, - parents=[subparser_base], - help=("Rollback a previous canary, restoring all instances to the base" - " version.")) - promote_sp = subparsers.add_parser( - PROMOTE_CANARY_COMMAND_NAME, - parents=[subparser_base], - help="Promote the current canary version to be the base version.") - direct_sp = subparsers.add_parser( - DIRECT_UPDATE_COMMAND_NAME, - parents=[subparser_base], - help=("Update all instances in the MIG to a new version. Generally should" - " not be used for prod.")) - - for sp in [canary_sp, direct_sp]: - sp.add_argument( - "--version", - help=("The new instance template version. Usually git hash +" - " 3-character uid, e.g. 56e40f6505-9lp")) - - # TODO: Add this argument with a custom parser - # canary_sp.add_argument("--canary-size", type=int, default=1) - - args = parser.parse_args() - - if args.skip_confirmation is None: - args.skip_confirmation = args.env == TESTING_ENV_NAME - - if args.action is None: - if args.mode == "proactive": - args.action = "refresh" - else: - args.action = "replace" - - return args + " MIGs for all types." + ), + ) + subparser_base.add_argument( + "--mode", + default="opportunistic", + choices=["opportunistic", "proactive"], + help=( + "The mode in which to update instances. See README and" + " https://cloud.google.com/compute/docs/instance-groups/updating-migs." + ), + ) + subparser_base.add_argument( + "--action", + choices=["refresh", "restart", "replace"], + help=( + "What action to take when updating an instance. See README and" + " https://cloud.google.com/compute/docs/instance-groups/updating-migs." + ), + ) + subparser_base.add_argument( + "--env", + "--environment", + default=TESTING_ENV_NAME, + help="The environment for the MIGs.", + choices=[PROD_ENV_NAME, TESTING_ENV_NAME], + ) + subparser_base.add_argument( + "--dry-run", + action="store_true", + default=False, + help="Print all output but don't actually send the update request.", + ) + + # Defaulting to true for testing environment avoids people getting in the + # habit of routinely passing --force. + skip_confirmation = subparser_base.add_mutually_exclusive_group() + skip_confirmation.add_argument( + "--skip-confirmation", + "--force", + action="store_true", + default=None, + help=( + "Skip all confirmation prompts. Be careful." + " Defaults to True for testing environment" + ), + ) + skip_confirmation.add_argument( + "--noskip-confirmation", + "--noforce", + action="store_false", + default=None, + dest="skip_confirmation", + ) + + # These shouldn't be set very often, but it's just as easy to make them flags + # as it is to make them global constants. + subparser_base.add_argument( + "--name-prefix", + default="gh-runner", + help="The first part of MIG and template names.", + ) + subparser_base.add_argument( + "--base-version-name", + default="base", + help="The name given to the MIG instance version that isn't in canary.", + ) + subparser_base.add_argument( + "--canary-version-name", + default="canary", + help="The name given to the MIG instance version that is being canaried.", + ) + + subparsers = parser.add_subparsers(required=True, dest="command") + + canary_sp = subparsers.add_parser( + CANARY_COMMAND_NAME, + parents=[subparser_base], + help="Canary a new template version.", + ) + rollback_sp = subparsers.add_parser( + ROLLBACK_CANARY_COMMAND_NAME, + parents=[subparser_base], + help=( + "Rollback a previous canary, restoring all instances to the base" + " version." + ), + ) + promote_sp = subparsers.add_parser( + PROMOTE_CANARY_COMMAND_NAME, + parents=[subparser_base], + help="Promote the current canary version to be the base version.", + ) + direct_sp = subparsers.add_parser( + DIRECT_UPDATE_COMMAND_NAME, + parents=[subparser_base], + help=( + "Update all instances in the MIG to a new version. Generally should" + " not be used for prod." + ), + ) + + for sp in [canary_sp, direct_sp]: + sp.add_argument( + "--version", + help=( + "The new instance template version. Usually git hash +" + " 3-character uid, e.g. 56e40f6505-9lp" + ), + ) + + # TODO: Add this argument with a custom parser + # canary_sp.add_argument("--canary-size", type=int, default=1) + + args = parser.parse_args() + + if args.skip_confirmation is None: + args.skip_confirmation = args.env == TESTING_ENV_NAME + + if args.action is None: + if args.mode == "proactive": + args.action = "refresh" + else: + args.action = "replace" + + return args if __name__ == "__main__": - main(parse_args()) + main(parse_args()) diff --git a/build_tools/github_actions/runner/gcp/update_runner_version.py b/build_tools/github_actions/runner/gcp/update_runner_version.py index f2e2ad431f44..4124288d09b0 100755 --- a/build_tools/github_actions/runner/gcp/update_runner_version.py +++ b/build_tools/github_actions/runner/gcp/update_runner_version.py @@ -27,12 +27,15 @@ # This is using the old printf-style string formatting because we're creating # lines that have Bash substitutions using braces VERSION_LINE_FORMAT_STRING = 'GITHUB_RUNNER_VERSION="${GITHUB_RUNNER_VERSION:-%s}"' -DIGEST_LINE_FORMAT_STRING = 'GITHUB_RUNNER_ARCHIVE_DIGEST="${GITHUB_RUNNER_ARCHIVE_DIGEST:-%s}"' +DIGEST_LINE_FORMAT_STRING = ( + 'GITHUB_RUNNER_ARCHIVE_DIGEST="${GITHUB_RUNNER_ARCHIVE_DIGEST:-%s}"' +) -DIGEST_SEARCH_PATTERN = r"^.*\bBEGIN.SHA linux-x64\b.*\b([a-fA-F0-9]{64})\b.*END.SHA linux-x64\b.*$" +DIGEST_SEARCH_PATTERN = ( + r"^.*\bBEGIN.SHA linux-x64\b.*\b([a-fA-F0-9]{64})\b.*END.SHA linux-x64\b.*$" +) -RUNNER_ARCHIVE_TEMPLATE = string.Template( - "actions-runner-linux-x64-${version}.tar.gz") +RUNNER_ARCHIVE_TEMPLATE = string.Template("actions-runner-linux-x64-${version}.tar.gz") ASSET_URL_TEMPLATE = string.Template( "https://github.com/actions/runner/releases/download/v${version}/${archive}" ) @@ -43,57 +46,61 @@ def error(*msg): - print(*msg, file=sys.stderr) - sys.exit(1) + print(*msg, file=sys.stderr) + sys.exit(1) if __name__ == "__main__": - release = json.loads( - subprocess.run(["gh", "api", "/repos/actions/runner/releases?per_page=1"], - check=True, - text=True, - stdout=subprocess.PIPE).stdout.strip())[0] - - if not release["tag_name"].startswith("v"): - error( - f"ERROR: Release tag name '{release.tag_name}' does not start with 'v' as expected" - ) - - version = release["tag_name"][1:] - digest = None - - sha_pattern = re.compile(DIGEST_SEARCH_PATTERN, flags=re.MULTILINE) - matches = sha_pattern.findall(release["body"]) - - if not matches: - error( - f"ERROR: No lines match digest search regex: '{DIGEST_SEARCH_PATTERN}'") - - if len(matches) > 1: - error(f"ERROR: Multiple lines match digest search regex:", matches) - - digest = matches[0] - - archive = RUNNER_ARCHIVE_TEMPLATE.substitute(version=version) - asset_url = ASSET_URL_TEMPLATE.substitute(version=version, archive=archive) - - # With Python 3.11 we could use hashlib.file_digest - hash = hashlib.sha256() - with urllib.request.urlopen(asset_url) as f: - hash.update(f.read()) - - actual_digest = hash.hexdigest() - - if digest != actual_digest: - error(f"Digest extracted from release notes ('{digest}') does not match" - f" digest obtained from fetching '{asset_url}' ('{actual_digest}')") - - for line in fileinput.input(files=[TARGET_SCRIPT], inplace=True): - if line.startswith("GITHUB_RUNNER_VERSION"): - print(VERSION_LINE_FORMAT_STRING % (version,)) - elif line.startswith("GITHUB_RUNNER_ARCHIVE_DIGEST"): - print(DIGEST_LINE_FORMAT_STRING % (digest,)) - else: - print(line, end="") - - print(f"Successfully updated {TARGET_SCRIPT}") + release = json.loads( + subprocess.run( + ["gh", "api", "/repos/actions/runner/releases?per_page=1"], + check=True, + text=True, + stdout=subprocess.PIPE, + ).stdout.strip() + )[0] + + if not release["tag_name"].startswith("v"): + error( + f"ERROR: Release tag name '{release.tag_name}' does not start with 'v' as expected" + ) + + version = release["tag_name"][1:] + digest = None + + sha_pattern = re.compile(DIGEST_SEARCH_PATTERN, flags=re.MULTILINE) + matches = sha_pattern.findall(release["body"]) + + if not matches: + error(f"ERROR: No lines match digest search regex: '{DIGEST_SEARCH_PATTERN}'") + + if len(matches) > 1: + error(f"ERROR: Multiple lines match digest search regex:", matches) + + digest = matches[0] + + archive = RUNNER_ARCHIVE_TEMPLATE.substitute(version=version) + asset_url = ASSET_URL_TEMPLATE.substitute(version=version, archive=archive) + + # With Python 3.11 we could use hashlib.file_digest + hash = hashlib.sha256() + with urllib.request.urlopen(asset_url) as f: + hash.update(f.read()) + + actual_digest = hash.hexdigest() + + if digest != actual_digest: + error( + f"Digest extracted from release notes ('{digest}') does not match" + f" digest obtained from fetching '{asset_url}' ('{actual_digest}')" + ) + + for line in fileinput.input(files=[TARGET_SCRIPT], inplace=True): + if line.startswith("GITHUB_RUNNER_VERSION"): + print(VERSION_LINE_FORMAT_STRING % (version,)) + elif line.startswith("GITHUB_RUNNER_ARCHIVE_DIGEST"): + print(DIGEST_LINE_FORMAT_STRING % (digest,)) + else: + print(line, end="") + + print(f"Successfully updated {TARGET_SCRIPT}") diff --git a/build_tools/github_actions/runner/instance_deleter/main.py b/build_tools/github_actions/runner/instance_deleter/main.py index 930787931331..5c05a0c1d764 100644 --- a/build_tools/github_actions/runner/instance_deleter/main.py +++ b/build_tools/github_actions/runner/instance_deleter/main.py @@ -64,8 +64,14 @@ import random import re import time -from http.client import (BAD_REQUEST, FORBIDDEN, GATEWAY_TIMEOUT, - INTERNAL_SERVER_ERROR, NOT_FOUND, UNAUTHORIZED) +from http.client import ( + BAD_REQUEST, + FORBIDDEN, + GATEWAY_TIMEOUT, + INTERNAL_SERVER_ERROR, + NOT_FOUND, + UNAUTHORIZED, +) import flask import functions_framework @@ -92,119 +98,133 @@ def _verify_token(token: str) -> dict: - """Verify token signature and return the token payload""" - request = transport.requests.Request(session) - payload = id_token.verify_oauth2_token(token, request=request) - return payload + """Verify token signature and return the token payload""" + request = transport.requests.Request(session) + payload = id_token.verify_oauth2_token(token, request=request) + return payload def _get_region(zone: str) -> str: - """Extract region name from zone name""" - # Drop the trailing zone identifier to get the region. Yeah it kinda does seem - # like there should be a better way to do this... - region, _ = zone.rsplit("-", maxsplit=1) - return region + """Extract region name from zone name""" + # Drop the trailing zone identifier to get the region. Yeah it kinda does seem + # like there should be a better way to do this... + region, _ = zone.rsplit("-", maxsplit=1) + return region def _get_name_from_resource(resource: str) -> str: - """Extract just the final name component from a fully scoped resource name.""" - _, name = resource.rsplit("/", maxsplit=1) - return name + """Extract just the final name component from a fully scoped resource name.""" + _, name = resource.rsplit("/", maxsplit=1) + return name def _get_from_items(items: compute.Items, key: str): - # Why would the GCP Python API return something as silly as a dictionary? - return next((item.value for item in items if item.key == key), None) + # Why would the GCP Python API return something as silly as a dictionary? + return next((item.value for item in items if item.key == key), None) -def delete_instance_from_mig(mig_name: str, project: str, region: str, - instance: compute.Instance): - try: - operation = migs_client.delete_instances( - instance_group_manager=mig_name, - project=project, - region=region, - # For some reason we can't just use a list of instance names and need to - # build this RhymingRythmicJavaClasses proto. Also, unlike all the other - # parameters, the instance has to be a fully-specified URL for the - # instance, not just its name. - region_instance_group_managers_delete_instances_request_resource=( - compute.RegionInstanceGroupManagersDeleteInstancesRequest( - instances=[instance.self_link]))) - except (google.api_core.exceptions.Forbidden, - google.api_core.exceptions.Unauthorized, - google.api_core.exceptions.NotFound) as e: - print(e) - return flask.abort( - e.code, f"Error requesting that {mig_name} delete {instance.name}.") - except Exception as e: - # We'll call any other error here a server error. - print(e) - return flask.abort( - INTERNAL_SERVER_ERROR, - f"Error requesting that {mig_name} delete {instance.name}.") - - try: - # This is actually an extended operation that you have to poll to get its - # status, but we just check the status once because it appears that errors - # always show up here and all we just want to return success in marking for - # deletion. We don't need to wait for the deletion to actually take place. - operation.result() - except google.api_core.exceptions.ClientError as e: - print(e) - # Unpack the actual usable error message - msg = ( - f"Error requesting that {mig_name} delete {instance.name}:" - "\n" + "\n".join( - [f"{err.code}: {err.message}" for err in e.response.error.errors])) - print(msg) - # We're not actually totally sure whether this is a client or server error - # for the overall request, but let's call it a client error (the only client - # here is our VM instances, so I think we can be a bit loose). - return flask.abort(BAD_REQUEST, msg) - - success_msg = f"{instance.name} has been marked for deletion by {mig_name}." - print(success_msg) - return success_msg +def delete_instance_from_mig( + mig_name: str, project: str, region: str, instance: compute.Instance +): + try: + operation = migs_client.delete_instances( + instance_group_manager=mig_name, + project=project, + region=region, + # For some reason we can't just use a list of instance names and need to + # build this RhymingRythmicJavaClasses proto. Also, unlike all the other + # parameters, the instance has to be a fully-specified URL for the + # instance, not just its name. + region_instance_group_managers_delete_instances_request_resource=( + compute.RegionInstanceGroupManagersDeleteInstancesRequest( + instances=[instance.self_link] + ) + ), + ) + except ( + google.api_core.exceptions.Forbidden, + google.api_core.exceptions.Unauthorized, + google.api_core.exceptions.NotFound, + ) as e: + print(e) + return flask.abort( + e.code, f"Error requesting that {mig_name} delete {instance.name}." + ) + except Exception as e: + # We'll call any other error here a server error. + print(e) + return flask.abort( + INTERNAL_SERVER_ERROR, + f"Error requesting that {mig_name} delete {instance.name}.", + ) + + try: + # This is actually an extended operation that you have to poll to get its + # status, but we just check the status once because it appears that errors + # always show up here and all we just want to return success in marking for + # deletion. We don't need to wait for the deletion to actually take place. + operation.result() + except google.api_core.exceptions.ClientError as e: + print(e) + # Unpack the actual usable error message + msg = ( + f"Error requesting that {mig_name} delete {instance.name}:" + "\n" + + "\n".join( + [f"{err.code}: {err.message}" for err in e.response.error.errors] + ) + ) + print(msg) + # We're not actually totally sure whether this is a client or server error + # for the overall request, but let's call it a client error (the only client + # here is our VM instances, so I think we can be a bit loose). + return flask.abort(BAD_REQUEST, msg) + + success_msg = f"{instance.name} has been marked for deletion by {mig_name}." + print(success_msg) + return success_msg def should_scale_down(mig_name: str, project: str, region: str): - start = time.time() - print(f"Polling {mig_name} for stability") - while time.time() - start < STABILIZE_TIMEOUT_SECONDS: - try: - mig = migs_client.get(project=project, - region=region, - instance_group_manager=mig_name) - except google.api_core.exceptions.NotFound as e: - print(e) - return flask.abort( - e.code, - f"Cannot find {mig_name} in region={region}, project={project}") - if mig.status.is_stable: - break - # We sleep for a random amount of time here to avoid synchronizing callers - # waiting for the MIG to be stable. - sleep_secs = random.randint(1, 15) - print(f"{mig_name} is not stable. Retrying in {sleep_secs} seconds") - time.sleep(sleep_secs) - else: - return flask.abort(GATEWAY_TIMEOUT, - "Timed out waiting for the MIG to become stable") - autoscaler = autoscalers_client.get(project=project, - region=region, - autoscaler=_get_name_from_resource( - mig.status.autoscaler)) - response = "true" if autoscaler.recommended_size < mig.target_size else "false" - print( - f"Autoscaler recommends size {autoscaler.recommended_size} and" - f" {mig_name} is targetting size {mig.target_size}. Sending: {response}") - return response + start = time.time() + print(f"Polling {mig_name} for stability") + while time.time() - start < STABILIZE_TIMEOUT_SECONDS: + try: + mig = migs_client.get( + project=project, region=region, instance_group_manager=mig_name + ) + except google.api_core.exceptions.NotFound as e: + print(e) + return flask.abort( + e.code, f"Cannot find {mig_name} in region={region}, project={project}" + ) + if mig.status.is_stable: + break + # We sleep for a random amount of time here to avoid synchronizing callers + # waiting for the MIG to be stable. + sleep_secs = random.randint(1, 15) + print(f"{mig_name} is not stable. Retrying in {sleep_secs} seconds") + time.sleep(sleep_secs) + else: + return flask.abort( + GATEWAY_TIMEOUT, "Timed out waiting for the MIG to become stable" + ) + autoscaler = autoscalers_client.get( + project=project, + region=region, + autoscaler=_get_name_from_resource(mig.status.autoscaler), + ) + response = "true" if autoscaler.recommended_size < mig.target_size else "false" + print( + f"Autoscaler recommends size {autoscaler.recommended_size} and" + f" {mig_name} is targetting size {mig.target_size}. Sending: {response}" + ) + return response @functions_framework.http def delete_self(request: flask.Request): - """HTTP Cloud Function to delete the instance group making the request. + """HTTP Cloud Function to delete the instance group making the request. Args: request: The request object. https://flask.palletsprojects.com/en/1.1.x/api/#incoming-request-data @@ -216,104 +236,117 @@ def delete_self(request: flask.Request): For more information on how Flask integrates with Cloud Functions, see the `Writing HTTP functions` page. https://cloud.google.com/functions/docs/writing/http#http_frameworks - """ - if request.method not in ALLOWED_HTTP_METHODS: - return flask.abort( - BAD_REQUEST, f"Invalid method {request.method}." - f" Allowed methods: {ALLOWED_HTTP_METHODS}") - - # No path is needed, since the token and method contain all the information we - # need. Maybe that design was a mistake, but since the resource being operated - # on is always the instance making the call, it seemed handy. - if request.path != "/": - return flask.abort( - BAD_REQUEST, - f"Invalid request path {request.path}. Only root path is valid).") - - auth_header = request.headers.get("Authorization") - if auth_header is None: - return flask.abort(UNAUTHORIZED, "Authorization header is missing") - if not auth_header.startswith(AUTH_HEADER_PREFIX): - return flask.abort( - UNAUTHORIZED, - f"Authorization header does not start with expected string" - f" {AUTH_HEADER_PREFIX}.") - - token = auth_header[len(AUTH_HEADER_PREFIX):] - - try: - # We don't verify audience here because Cloud IAM will have already done so - # and jwt's matching of audiences is exact, which means trailing slashes or - # http vs https matters and that's pretty brittle. - token_payload = _verify_token(token) - except (ValueError, google.auth.exceptions.GoogleAuthError) as e: - print(e) - return flask.abort(UNAUTHORIZED, "Decoding bearer token failed.") - - print(f"Token payload: {token_payload}") - - try: - compute_info = token_payload["google"]["compute_engine"] - except KeyError: - return flask.abort( - UNAUTHORIZED, - "Bearer token payload does not have expected field google.compute") - - project = compute_info["project_id"] - zone = compute_info["zone"] - region = _get_region(zone) - instance_name = compute_info["instance_name"] - - if request.method == "DELETE": - print(f"Received request to delete {instance_name}") - else: + """ + if request.method not in ALLOWED_HTTP_METHODS: + return flask.abort( + BAD_REQUEST, + f"Invalid method {request.method}." + f" Allowed methods: {ALLOWED_HTTP_METHODS}", + ) + + # No path is needed, since the token and method contain all the information we + # need. Maybe that design was a mistake, but since the resource being operated + # on is always the instance making the call, it seemed handy. + if request.path != "/": + return flask.abort( + BAD_REQUEST, + f"Invalid request path {request.path}. Only root path is valid).", + ) + + auth_header = request.headers.get("Authorization") + if auth_header is None: + return flask.abort(UNAUTHORIZED, "Authorization header is missing") + if not auth_header.startswith(AUTH_HEADER_PREFIX): + return flask.abort( + UNAUTHORIZED, + f"Authorization header does not start with expected string" + f" {AUTH_HEADER_PREFIX}.", + ) + + token = auth_header[len(AUTH_HEADER_PREFIX) :] + + try: + # We don't verify audience here because Cloud IAM will have already done so + # and jwt's matching of audiences is exact, which means trailing slashes or + # http vs https matters and that's pretty brittle. + token_payload = _verify_token(token) + except (ValueError, google.auth.exceptions.GoogleAuthError) as e: + print(e) + return flask.abort(UNAUTHORIZED, "Decoding bearer token failed.") + + print(f"Token payload: {token_payload}") + + try: + compute_info = token_payload["google"]["compute_engine"] + except KeyError: + return flask.abort( + UNAUTHORIZED, + "Bearer token payload does not have expected field google.compute", + ) + + project = compute_info["project_id"] + zone = compute_info["zone"] + region = _get_region(zone) + instance_name = compute_info["instance_name"] + + if request.method == "DELETE": + print(f"Received request to delete {instance_name}") + else: + assert request.method == "GET" + print(f"Received inquiry whether to delete {instance_name}") + try: + instance = instances_client.get( + instance=instance_name, project=project, zone=zone + ) + except ( + google.api_core.exceptions.NotFound, + google.api_core.exceptions.Forbidden, + ) as e: + print(e) + return flask.abort( + e.code, f"Cannot view {instance_name} in zone={zone}, project={project}" + ) + + instance_id = int(compute_info["instance_id"]) + # Verify it's *actually* the same instance. Names get reused, but IDs don't. + # For some reason you can't reference anything by ID in the API. + if instance.id != instance_id: + return flask.abort( + BAD_REQUEST, + f"Existing instance of the same name {instance.name} has a different" + f" ID {instance.id} than token specifies {instance_id}.", + ) + + mig_name = _get_from_items(instance.metadata.items, MIG_METADATA_KEY) + + if mig_name is None: + return flask.abort( + BAD_REQUEST, + ( + f"Instance is not part of a managed instance group." + f" Did not find {MIG_METADATA_KEY} in metadata." + ), + ) + mig_name = _get_name_from_resource(mig_name) + + # General good practice would be to compile the regex once, but the only way + # to do that is to make it a global, which makes this difficult to test and + # compiling this regex should not be expensive. + allowed_mig_pattern = os.environ.get(ALLOWED_MIG_PATTERN_ENV_VARIABLE) + if allowed_mig_pattern is None: + flask.abort( + INTERNAL_SERVER_ERROR, + f"Missing required environment variable" + f" {ALLOWED_MIG_PATTERN_ENV_VARIABLE}", + ) + + if not re.fullmatch(allowed_mig_pattern, mig_name): + return flask.abort(FORBIDDEN, f"No access to MIG {mig_name}") + + if request.method == "DELETE": + return delete_instance_from_mig( + mig_name=mig_name, project=project, region=region, instance=instance + ) + assert request.method == "GET" - print(f"Received inquiry whether to delete {instance_name}") - try: - instance = instances_client.get(instance=instance_name, - project=project, - zone=zone) - except (google.api_core.exceptions.NotFound, - google.api_core.exceptions.Forbidden) as e: - print(e) - return flask.abort( - e.code, - f"Cannot view {instance_name} in zone={zone}, project={project}") - - instance_id = int(compute_info["instance_id"]) - # Verify it's *actually* the same instance. Names get reused, but IDs don't. - # For some reason you can't reference anything by ID in the API. - if instance.id != instance_id: - return flask.abort( - BAD_REQUEST, - f"Existing instance of the same name {instance.name} has a different" - f" ID {instance.id} than token specifies {instance_id}.") - - mig_name = _get_from_items(instance.metadata.items, MIG_METADATA_KEY) - - if mig_name is None: - return flask.abort(BAD_REQUEST, - (f"Instance is not part of a managed instance group." - f" Did not find {MIG_METADATA_KEY} in metadata.")) - mig_name = _get_name_from_resource(mig_name) - - # General good practice would be to compile the regex once, but the only way - # to do that is to make it a global, which makes this difficult to test and - # compiling this regex should not be expensive. - allowed_mig_pattern = os.environ.get(ALLOWED_MIG_PATTERN_ENV_VARIABLE) - if allowed_mig_pattern is None: - flask.abort( - INTERNAL_SERVER_ERROR, f"Missing required environment variable" - f" {ALLOWED_MIG_PATTERN_ENV_VARIABLE}") - - if not re.fullmatch(allowed_mig_pattern, mig_name): - return flask.abort(FORBIDDEN, f"No access to MIG {mig_name}") - - if request.method == "DELETE": - return delete_instance_from_mig(mig_name=mig_name, - project=project, - region=region, - instance=instance) - - assert request.method == "GET" - return should_scale_down(mig_name=mig_name, project=project, region=region) + return should_scale_down(mig_name=mig_name, project=project, region=region) diff --git a/build_tools/github_actions/runner/instance_deleter/main_test.py b/build_tools/github_actions/runner/instance_deleter/main_test.py index 43397d6b076a..535a59bb579f 100644 --- a/build_tools/github_actions/runner/instance_deleter/main_test.py +++ b/build_tools/github_actions/runner/instance_deleter/main_test.py @@ -34,506 +34,564 @@ def get_message(ctx): - return ctx.exception.get_response().get_data(as_text=True) + return ctx.exception.get_response().get_data(as_text=True) # A fake for oauth2 token verification that pretends the encoding scheme is just # JSON. def fake_verify_oauth2_token(token, request): - del request - return json.loads(token) + del request + return json.loads(token) def make_token(payload: dict): - return json.dumps(payload) + return json.dumps(payload) -@mock.patch("google.oauth2.id_token.verify_oauth2_token", - fake_verify_oauth2_token) +@mock.patch("google.oauth2.id_token.verify_oauth2_token", fake_verify_oauth2_token) class InstanceDeleterTest(unittest.TestCase): - - def setUp(self): - self.addCleanup(mock.patch.stopall) - instances_client_patcher = mock.patch("main.instances_client", - autospec=True) - self.instances_client = instances_client_patcher.start() - migs_client_patcher = mock.patch("main.migs_client", autospec=True) - self.migs_client = migs_client_patcher.start() - os_environ_patcher = mock.patch.dict( - "os.environ", {main.ALLOWED_MIG_PATTERN_ENV_VARIABLE: ".*"}) - self.environ = os_environ_patcher.start() - autoscalers_client_patcher = mock.patch("main.autoscalers_client", - autospec=True) - self.autoscalers_client = autoscalers_client_patcher.start() - time_patcher = mock.patch("time.time", autospec=True) - self.time = time_patcher.start() - self.time.return_value = 0 - # Just noop sleep - mock.patch("time.sleep", autospec=True).start() - - def test_delete_happy_path(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + def setUp(self): + self.addCleanup(mock.patch.stopall) + instances_client_patcher = mock.patch("main.instances_client", autospec=True) + self.instances_client = instances_client_patcher.start() + migs_client_patcher = mock.patch("main.migs_client", autospec=True) + self.migs_client = migs_client_patcher.start() + os_environ_patcher = mock.patch.dict( + "os.environ", {main.ALLOWED_MIG_PATTERN_ENV_VARIABLE: ".*"} + ) + self.environ = os_environ_patcher.start() + autoscalers_client_patcher = mock.patch( + "main.autoscalers_client", autospec=True + ) + self.autoscalers_client = autoscalers_client_patcher.start() + time_patcher = mock.patch("time.time", autospec=True) + self.time = time_patcher.start() + self.time.return_value = 0 + # Just noop sleep + mock.patch("time.sleep", autospec=True).start() + + def test_delete_happy_path(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=self_link, - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) - self.instances_client.get.return_value = instance - - response = main.delete_self(req) - - self.assertIn(MIG_NAME, response) - self.assertIn(INSTANCE_NAME, response) - - self.migs_client.delete_instances.assert_called_once_with( - instance_group_manager=MIG_NAME, - project=PROJECT, - region=REGION, - region_instance_group_managers_delete_instances_request_resource=compute - .RegionInstanceGroupManagersDeleteInstancesRequest( - instances=[instance.self_link])) - - def test_get_happy_path(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "GET" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=self_link, + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + response = main.delete_self(req) + + self.assertIn(MIG_NAME, response) + self.assertIn(INSTANCE_NAME, response) + + self.migs_client.delete_instances.assert_called_once_with( + instance_group_manager=MIG_NAME, + project=PROJECT, + region=REGION, + region_instance_group_managers_delete_instances_request_resource=compute.RegionInstanceGroupManagersDeleteInstancesRequest( + instances=[instance.self_link] + ), + ) + + def test_get_happy_path(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "GET" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=self_link, - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) - self.instances_client.get.return_value = instance - - mig = compute.InstanceGroupManager( - target_size=5, - status={ - "is_stable": True, - "autoscaler": "autoscaler_link/autoscaler_name" - }) - self.migs_client.get.return_value = mig - - autoscaler = compute.Autoscaler(recommended_size=3) - self.autoscalers_client.get.return_value = autoscaler - - response = main.delete_self(req) - - self.assertEqual(response, "true") - - def test_get_timeout(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "GET" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=self_link, + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + mig = compute.InstanceGroupManager( + target_size=5, + status={"is_stable": True, "autoscaler": "autoscaler_link/autoscaler_name"}, + ) + self.migs_client.get.return_value = mig + + autoscaler = compute.Autoscaler(recommended_size=3) + self.autoscalers_client.get.return_value = autoscaler + + response = main.delete_self(req) + + self.assertEqual(response, "true") + + def test_get_timeout(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "GET" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=self_link, - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) - self.instances_client.get.return_value = instance - - mig = compute.InstanceGroupManager( - target_size=5, - status={ - "is_stable": False, - "autoscaler": "autoscaler_link/autoscaler_name" - }) - self.migs_client.get.return_value = mig - self.time.side_effect = [0, main.STABILIZE_TIMEOUT_SECONDS + 1] - - with self.assertRaises(werkzeug.exceptions.GatewayTimeout): - main.delete_self(req) - - def test_narrow_allowed_migs(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=self_link, + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + mig = compute.InstanceGroupManager( + target_size=5, + status={ + "is_stable": False, + "autoscaler": "autoscaler_link/autoscaler_name", + }, + ) + self.migs_client.get.return_value = mig + self.time.side_effect = [0, main.STABILIZE_TIMEOUT_SECONDS + 1] + + with self.assertRaises(werkzeug.exceptions.GatewayTimeout): + main.delete_self(req) + + def test_narrow_allowed_migs(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - mig_name = "github-runner-foo-bar" - self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "github-runner-.*" - self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=self_link, - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{mig_name}") - ])) - self.instances_client.get.return_value = instance - - ext_operation = mock.MagicMock( - google.api_core.extended_operation.ExtendedOperation) - ext_operation.result.return_value = None - - response = main.delete_self(req) - - self.assertIn(mig_name, response) - self.assertIn(INSTANCE_NAME, response) - - self.migs_client.delete_instances.assert_called_once_with( - instance_group_manager=mig_name, - project=PROJECT, - region=REGION, - region_instance_group_managers_delete_instances_request_resource=compute - .RegionInstanceGroupManagersDeleteInstancesRequest( - instances=[instance.self_link])) - - def test_bad_method(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "POST" - - with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: - main.delete_self(req) - - self.assertIn("Invalid method", get_message(ctx)) - - def test_bad_path(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - req.path = "/foo/bar" - - with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: - main.delete_self(req) - - self.assertIn("Invalid request path", get_message(ctx)) - - def test_missing_header(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: - main.delete_self(req) - - self.assertIn("Authorization header is missing", get_message(ctx)) - - def test_malformed_header(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - req.headers = {"Authorization": "UnknownScheme token"} + ) - with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: - main.delete_self(req) + req.headers = {"Authorization": f"Bearer {token}"} - self.assertIn("Authorization header does not start", get_message(ctx)) + mig_name = "github-runner-foo-bar" + self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "github-runner-.*" + self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=self_link, + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{mig_name}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + ext_operation = mock.MagicMock( + google.api_core.extended_operation.ExtendedOperation + ) + ext_operation.result.return_value = None + + response = main.delete_self(req) + + self.assertIn(mig_name, response) + self.assertIn(INSTANCE_NAME, response) + + self.migs_client.delete_instances.assert_called_once_with( + instance_group_manager=mig_name, + project=PROJECT, + region=REGION, + region_instance_group_managers_delete_instances_request_resource=compute.RegionInstanceGroupManagersDeleteInstancesRequest( + instances=[instance.self_link] + ), + ) + + def test_bad_method(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "POST" + + with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: + main.delete_self(req) + + self.assertIn("Invalid method", get_message(ctx)) + + def test_bad_path(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + req.path = "/foo/bar" + + with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: + main.delete_self(req) + + self.assertIn("Invalid request path", get_message(ctx)) + + def test_missing_header(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: + main.delete_self(req) + + self.assertIn("Authorization header is missing", get_message(ctx)) + + def test_malformed_header(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + req.headers = {"Authorization": "UnknownScheme token"} + + with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: + main.delete_self(req) + + self.assertIn("Authorization header does not start", get_message(ctx)) + + def test_invalid_token(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + req.headers = {"Authorization": f"Bearer {INVALID_TOKEN}"} + + with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: + main.delete_self(req) + + self.assertIn("token", get_message(ctx)) + + def test_bad_token_payload(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token({"aud": "localhost"}) + + req.headers = {"Authorization": f"Bearer {token}"} + + with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: + main.delete_self(req) + + self.assertIn("token", get_message(ctx)) - def test_invalid_token(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - req.headers = {"Authorization": f"Bearer {INVALID_TOKEN}"} - - with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: - main.delete_self(req) - - self.assertIn("token", get_message(ctx)) - - def test_bad_token_payload(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({"aud": "localhost"}) - - req.headers = {"Authorization": f"Bearer {token}"} - - with self.assertRaises(werkzeug.exceptions.Unauthorized) as ctx: - main.delete_self(req) - - self.assertIn("token", get_message(ctx)) - - def test_nonexistent_instance(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": ZONE, - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + def test_nonexistent_instance(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": ZONE, + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) + ) - req.headers = {"Authorization": f"Bearer {token}"} + req.headers = {"Authorization": f"Bearer {token}"} - self.instances_client.get.side_effect = google.api_core.exceptions.NotFound( - "Instance not found") + self.instances_client.get.side_effect = google.api_core.exceptions.NotFound( + "Instance not found" + ) - with self.assertRaises(werkzeug.exceptions.NotFound) as ctx: - main.delete_self(req) + with self.assertRaises(werkzeug.exceptions.NotFound) as ctx: + main.delete_self(req) - self.assertIn(INSTANCE_NAME, get_message(ctx)) + self.assertIn(INSTANCE_NAME, get_message(ctx)) - def test_id_mismatch(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" + def test_id_mismatch(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": ZONE, - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": ZONE, + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) + ) - req.headers = {"Authorization": f"Bearer {token}"} + req.headers = {"Authorization": f"Bearer {token}"} - instance = compute.Instance(id=ID2, name=INSTANCE_NAME) + instance = compute.Instance(id=ID2, name=INSTANCE_NAME) - self.instances_client.get.return_value = instance + self.instances_client.get.return_value = instance - with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: - main.delete_self(req) + with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: + main.delete_self(req) - msg = get_message(ctx) - self.assertIn(str(ID1), msg) - self.assertIn(str(ID2), msg) + msg = get_message(ctx) + self.assertIn(str(ID1), msg) + self.assertIn(str(ID2), msg) - def test_missing_mig_metadata(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" + def test_missing_mig_metadata(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": ZONE, - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": ZONE, + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - instance = compute.Instance(id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=f"http://foo/bar/{INSTANCE_NAME}") - - self.instances_client.get.return_value = instance - - with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: - main.delete_self(req) - - self.assertIn(main.MIG_METADATA_KEY, get_message(ctx)) - - def test_mig_pattern_unset(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=f"http://foo/bar/{INSTANCE_NAME}", + ) + + self.instances_client.get.return_value = instance + + with self.assertRaises(werkzeug.exceptions.BadRequest) as ctx: + main.delete_self(req) + + self.assertIn(main.MIG_METADATA_KEY, get_message(ctx)) + + def test_mig_pattern_unset(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=self_link, - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) - self.instances_client.get.return_value = instance - - del self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] - - with self.assertRaises(werkzeug.exceptions.InternalServerError) as ctx: - main.delete_self(req) - - self.assertIn(main.ALLOWED_MIG_PATTERN_ENV_VARIABLE, get_message(ctx)) - - def test_no_migs_allowed(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + self_link = f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=self_link, + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + del self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] + + with self.assertRaises(werkzeug.exceptions.InternalServerError) as ctx: + main.delete_self(req) + + self.assertIn(main.ALLOWED_MIG_PATTERN_ENV_VARIABLE, get_message(ctx)) + + def test_no_migs_allowed(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) - self.instances_client.get.return_value = instance - - self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "" - - with self.assertRaises(werkzeug.exceptions.Forbidden) as ctx: - main.delete_self(req) - - self.assertIn(MIG_NAME, get_message((ctx))) - - def test_mig_not_allowed(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": f"{REGION}-a", - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "" + + with self.assertRaises(werkzeug.exceptions.Forbidden) as ctx: + main.delete_self(req) + + self.assertIn(MIG_NAME, get_message((ctx))) + + def test_mig_not_allowed(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": f"{REGION}-a", + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) - - req.headers = {"Authorization": f"Bearer {token}"} - - mig_name = "not-github-runner" - self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "github-runner-.*" - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{mig_name}") - ])) - self.instances_client.get.return_value = instance - - with self.assertRaises(werkzeug.exceptions.Forbidden) as ctx: - main.delete_self(req) - - self.assertIn(mig_name, get_message((ctx))) - - def test_bad_deletion_request_server(self): - req = Request({}, populate_request=False, shallow=True) - req.method = "DELETE" - - token = make_token({ - "google": { - "compute_engine": { - "project_id": PROJECT, - "zone": ZONE, - "instance_name": INSTANCE_NAME, - "instance_id": str(ID1), + ) + + req.headers = {"Authorization": f"Bearer {token}"} + + mig_name = "not-github-runner" + self.environ[main.ALLOWED_MIG_PATTERN_ENV_VARIABLE] = "github-runner-.*" + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{mig_name}" + ) + ] + ), + ) + self.instances_client.get.return_value = instance + + with self.assertRaises(werkzeug.exceptions.Forbidden) as ctx: + main.delete_self(req) + + self.assertIn(mig_name, get_message((ctx))) + + def test_bad_deletion_request_server(self): + req = Request({}, populate_request=False, shallow=True) + req.method = "DELETE" + + token = make_token( + { + "google": { + "compute_engine": { + "project_id": PROJECT, + "zone": ZONE, + "instance_name": INSTANCE_NAME, + "instance_id": str(ID1), + } + } } - } - }) + ) - req.headers = {"Authorization": f"Bearer {token}"} + req.headers = {"Authorization": f"Bearer {token}"} - instance = compute.Instance( - id=ID1, - name=INSTANCE_NAME, - zone=ZONE, - self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", - metadata=compute.Metadata(items=[ - compute.Items(key=main.MIG_METADATA_KEY, - value=f"{MIG_PATH_PREFIX}{MIG_NAME}") - ])) + instance = compute.Instance( + id=ID1, + name=INSTANCE_NAME, + zone=ZONE, + self_link=f"{INSTANCE_LINK_PREFIX}{INSTANCE_NAME}", + metadata=compute.Metadata( + items=[ + compute.Items( + key=main.MIG_METADATA_KEY, value=f"{MIG_PATH_PREFIX}{MIG_NAME}" + ) + ] + ), + ) - self.instances_client.get.return_value = instance - self.migs_client.delete_instances.side_effect = ValueError("Bad request") + self.instances_client.get.return_value = instance + self.migs_client.delete_instances.side_effect = ValueError("Bad request") - with self.assertRaises(werkzeug.exceptions.InternalServerError) as ctx: - main.delete_self(req) + with self.assertRaises(werkzeug.exceptions.InternalServerError) as ctx: + main.delete_self(req) - self.assertIn(MIG_NAME, get_message(ctx)) + self.assertIn(MIG_NAME, get_message(ctx)) - # Testing of server errors is unimplemented. ExtendedOperation is not - # documented well enough for me to produce a reasonable fake and a bad fake is - # worse than nothing. + # Testing of server errors is unimplemented. ExtendedOperation is not + # documented well enough for me to produce a reasonable fake and a bad fake is + # worse than nothing. if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/benchmark_suites/iree/adreno_benchmarks.py b/build_tools/python/benchmark_suites/iree/adreno_benchmarks.py index 9e4f6ae0d581..07a75880d427 100644 --- a/build_tools/python/benchmark_suites/iree/adreno_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/adreno_benchmarks.py @@ -15,83 +15,100 @@ class Android_Adreno_Benchmarks(object): - """Benchmarks on Android devices with Adreno GPU.""" + """Benchmarks on Android devices with Adreno GPU.""" - ADRENO_GPU_COMPILE_TARGET = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, - target_architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - target_abi=iree_definitions.TargetABI.VULKAN_ANDROID31) - DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_DEFAULTS, - tags=["default-flags"], - compile_targets=[ADRENO_GPU_COMPILE_TARGET]) - FUSE_PADDING_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING, - tags=["experimental-flags", "fuse-padding"], - compile_targets=[ADRENO_GPU_COMPILE_TARGET], - extra_flags=["--iree-flow-enable-fuse-padding-into-linalg-consumer-ops"]) - FUSE_PADDING_REPEATED_KERNEL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids. - IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING_REPEATED_KERNEL, - tags=["experimental-flags", "fuse-padding", "repeated-kernel"], - compile_targets=[ADRENO_GPU_COMPILE_TARGET], - extra_flags=FUSE_PADDING_COMPILE_CONFIG.extra_flags + - ["--iree-hal-benchmark-dispatch-repeat-count=16"]) + ADRENO_GPU_COMPILE_TARGET = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, + target_architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + target_abi=iree_definitions.TargetABI.VULKAN_ANDROID31, + ) + DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_DEFAULTS, + tags=["default-flags"], + compile_targets=[ADRENO_GPU_COMPILE_TARGET], + ) + FUSE_PADDING_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING, + tags=["experimental-flags", "fuse-padding"], + compile_targets=[ADRENO_GPU_COMPILE_TARGET], + extra_flags=["--iree-flow-enable-fuse-padding-into-linalg-consumer-ops"], + ) + FUSE_PADDING_REPEATED_KERNEL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING_REPEATED_KERNEL, + tags=["experimental-flags", "fuse-padding", "repeated-kernel"], + compile_targets=[ADRENO_GPU_COMPILE_TARGET], + extra_flags=FUSE_PADDING_COMPILE_CONFIG.extra_flags + + ["--iree-hal-benchmark-dispatch-repeat-count=16"], + ) - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - default_models = [ - tflite_models.DEEPLABV3_FP32, - tflite_models.MOBILESSD_FP32, - tflite_models.POSENET_FP32, - tflite_models.MOBILEBERT_FP32, - tflite_models.MOBILENET_V2, - tflite_models.MOBILENET_V3SMALL, - ] - default_gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.DEFAULT_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in default_models - ] - fuse_padding_gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.FUSE_PADDING_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in default_models - ] - fuse_padding_repeated_kernel_gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.FUSE_PADDING_REPEATED_KERNEL_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in [ + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + default_models = [ + tflite_models.DEEPLABV3_FP32, tflite_models.MOBILESSD_FP32, tflite_models.POSENET_FP32, + tflite_models.MOBILEBERT_FP32, tflite_models.MOBILENET_V2, tflite_models.MOBILENET_V3SMALL, ] - ] + default_gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.DEFAULT_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in default_models + ] + fuse_padding_gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.FUSE_PADDING_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in default_models + ] + fuse_padding_repeated_kernel_gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.FUSE_PADDING_REPEATED_KERNEL_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in [ + tflite_models.MOBILESSD_FP32, + tflite_models.POSENET_FP32, + tflite_models.MOBILENET_V2, + tflite_models.MOBILENET_V3SMALL, + ] + ] - adreno_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A) - run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=default_gen_configs, - module_execution_configs=[module_execution_configs.VULKAN_CONFIG], - device_specs=adreno_devices) - run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=fuse_padding_gen_configs, - module_execution_configs=[module_execution_configs.VULKAN_CONFIG], - device_specs=adreno_devices) - run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=fuse_padding_repeated_kernel_gen_configs, - module_execution_configs=[ - module_execution_configs.VULKAN_BATCH_SIZE_16_CONFIG - ], - device_specs=adreno_devices) + adreno_devices = ( + device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + ) + ) + run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=default_gen_configs, + module_execution_configs=[module_execution_configs.VULKAN_CONFIG], + device_specs=adreno_devices, + ) + run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=fuse_padding_gen_configs, + module_execution_configs=[module_execution_configs.VULKAN_CONFIG], + device_specs=adreno_devices, + ) + run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=fuse_padding_repeated_kernel_gen_configs, + module_execution_configs=[ + module_execution_configs.VULKAN_BATCH_SIZE_16_CONFIG + ], + device_specs=adreno_devices, + ) - gen_configs = (default_gen_configs + fuse_padding_gen_configs + - fuse_padding_repeated_kernel_gen_configs) - return (gen_configs, run_configs) + gen_configs = ( + default_gen_configs + + fuse_padding_gen_configs + + fuse_padding_repeated_kernel_gen_configs + ) + return (gen_configs, run_configs) diff --git a/build_tools/python/benchmark_suites/iree/armv8_a_benchmarks.py b/build_tools/python/benchmark_suites/iree/armv8_a_benchmarks.py index 4d6279e0eb4c..48d49d81c63f 100644 --- a/build_tools/python/benchmark_suites/iree/armv8_a_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/armv8_a_benchmarks.py @@ -15,99 +15,115 @@ class Android_ARMv8_A_Benchmarks(object): - """Benchmarks on ARMv8-A Android devices.""" - NONQUANT_MODELS = [ - tflite_models.DEEPLABV3_FP32, - tflite_models.MOBILESSD_FP32, - tflite_models.POSENET_FP32, - tflite_models.MOBILEBERT_FP32, - tflite_models.MOBILENET_V2, - tflite_models.MOBILENET_V3SMALL, - ] - QUANT_MODELS = [tflite_models.MOBILEBERT_INT8] + """Benchmarks on ARMv8-A Android devices.""" - ARMV8_A_CPU_TARGET = iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - ARMV8_2_A_GENERIC, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_ANDROID29) + NONQUANT_MODELS = [ + tflite_models.DEEPLABV3_FP32, + tflite_models.MOBILESSD_FP32, + tflite_models.POSENET_FP32, + tflite_models.MOBILEBERT_FP32, + tflite_models.MOBILENET_V2, + tflite_models.MOBILENET_V3SMALL, + ] + QUANT_MODELS = [tflite_models.MOBILEBERT_INT8] - DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_DEFAULTS, - tags=["default-flags"], - compile_targets=[ARMV8_A_CPU_TARGET]) - MMT4D_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D, - tags=["experimental-flags", "mmt4d"], - compile_targets=[ARMV8_A_CPU_TARGET], - extra_flags=[ - "--iree-flow-enable-data-tiling", - "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", - "--iree-llvmcpu-enable-pad-consumer-fusion" - ]) - MMT4D_AND_DOTPROD_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D_DOTPROD, - tags=["experimental-flags", "mmt4d", "dotprod"], - compile_targets=[ARMV8_A_CPU_TARGET], - extra_flags=[ - "--iree-flow-enable-data-tiling", - "--iree-llvmcpu-target-cpu-features=+dotprod", - "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", - "--iree-llvmcpu-enable-pad-consumer-fusion" - ]) + ARMV8_A_CPU_TARGET = iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_ANDROID29, + ) - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" + DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_DEFAULTS, + tags=["default-flags"], + compile_targets=[ARMV8_A_CPU_TARGET], + ) + MMT4D_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D, + tags=["experimental-flags", "mmt4d"], + compile_targets=[ARMV8_A_CPU_TARGET], + extra_flags=[ + "--iree-flow-enable-data-tiling", + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + "--iree-llvmcpu-enable-pad-consumer-fusion", + ], + ) + MMT4D_AND_DOTPROD_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D_DOTPROD, + tags=["experimental-flags", "mmt4d", "dotprod"], + compile_targets=[ARMV8_A_CPU_TARGET], + extra_flags=[ + "--iree-flow-enable-data-tiling", + "--iree-llvmcpu-target-cpu-features=+dotprod", + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + "--iree-llvmcpu-enable-pad-consumer-fusion", + ], + ) - local_sync_execution_configs = [ - module_execution_configs.ELF_LOCAL_SYNC_CONFIG - ] - local_task_execution_configs = [ - module_execution_configs.get_elf_system_scheduling_local_task_config( - thread_num) for thread_num in [1, 4] - ] + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" - default_gen_confings = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.DEFAULT_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in self.NONQUANT_MODELS + self.QUANT_MODELS - ] - experimental_gen_confings = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.MMT4D_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in self.NONQUANT_MODELS - ] + [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.MMT4D_AND_DOTPROD_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in self.QUANT_MODELS - ] + local_sync_execution_configs = [module_execution_configs.ELF_LOCAL_SYNC_CONFIG] + local_task_execution_configs = [ + module_execution_configs.get_elf_system_scheduling_local_task_config( + thread_num + ) + for thread_num in [1, 4] + ] + + default_gen_confings = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.DEFAULT_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in self.NONQUANT_MODELS + self.QUANT_MODELS + ] + experimental_gen_confings = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.MMT4D_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in self.NONQUANT_MODELS + ] + [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.MMT4D_AND_DOTPROD_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in self.QUANT_MODELS + ] - all_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A) - big_cores_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters={"big-cores"}) - run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=default_gen_confings, - module_execution_configs=local_sync_execution_configs + - local_task_execution_configs, - device_specs=all_devices) - run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=experimental_gen_confings, - module_execution_configs=local_sync_execution_configs, - device_specs=all_devices) - run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=experimental_gen_confings, - module_execution_configs=local_task_execution_configs, - device_specs=big_cores_devices) + all_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + ) + big_cores_devices = ( + device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters={"big-cores"}, + ) + ) + run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=default_gen_confings, + module_execution_configs=local_sync_execution_configs + + local_task_execution_configs, + device_specs=all_devices, + ) + run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=experimental_gen_confings, + module_execution_configs=local_sync_execution_configs, + device_specs=all_devices, + ) + run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=experimental_gen_confings, + module_execution_configs=local_task_execution_configs, + device_specs=big_cores_devices, + ) - gen_confings = (default_gen_confings + experimental_gen_confings) - return (gen_confings, run_configs) + gen_confings = default_gen_confings + experimental_gen_confings + return (gen_confings, run_configs) diff --git a/build_tools/python/benchmark_suites/iree/benchmark_collections.py b/build_tools/python/benchmark_suites/iree/benchmark_collections.py index a2f3a14db20e..d2e1cc2471fd 100644 --- a/build_tools/python/benchmark_suites/iree/benchmark_collections.py +++ b/build_tools/python/benchmark_suites/iree/benchmark_collections.py @@ -9,60 +9,73 @@ from e2e_test_artifacts import iree_artifacts from e2e_test_framework.definitions import iree_definitions -from benchmark_suites.iree import (benchmark_tags, riscv_benchmarks, - x86_64_benchmarks, adreno_benchmarks, - armv8_a_benchmarks, cuda_benchmarks, - mali_benchmarks, vulkan_nvidia_benchmarks, - vmvx_benchmarks) +from benchmark_suites.iree import ( + benchmark_tags, + riscv_benchmarks, + x86_64_benchmarks, + adreno_benchmarks, + armv8_a_benchmarks, + cuda_benchmarks, + mali_benchmarks, + vulkan_nvidia_benchmarks, + vmvx_benchmarks, +) COMPILE_STATS_ID_SUFFIX = "-compile-stats" -def generate_benchmarks( -) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - benchmarks = [ - x86_64_benchmarks.Linux_x86_64_Benchmarks(), - cuda_benchmarks.Linux_CUDA_Benchmarks(), - riscv_benchmarks.Linux_RV64_Benchmarks(), - riscv_benchmarks.Linux_RV32_Benchmarks(), - armv8_a_benchmarks.Android_ARMv8_A_Benchmarks(), - adreno_benchmarks.Android_Adreno_Benchmarks(), - mali_benchmarks.Android_Mali_Benchmarks(), - vulkan_nvidia_benchmarks.Linux_Vulkan_NVIDIA_Benchmarks(), - vmvx_benchmarks.Android_VMVX_Benchmarks() - ] - all_gen_configs: List[iree_definitions.ModuleGenerationConfig] = [] - all_run_configs: List[iree_definitions.E2EModelRunConfig] = [] - for benchmark in benchmarks: - module_generation_configs, run_configs = benchmark.generate() - all_gen_configs += module_generation_configs - all_run_configs += run_configs +def generate_benchmarks() -> ( + Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ] +): + benchmarks = [ + x86_64_benchmarks.Linux_x86_64_Benchmarks(), + cuda_benchmarks.Linux_CUDA_Benchmarks(), + riscv_benchmarks.Linux_RV64_Benchmarks(), + riscv_benchmarks.Linux_RV32_Benchmarks(), + armv8_a_benchmarks.Android_ARMv8_A_Benchmarks(), + adreno_benchmarks.Android_Adreno_Benchmarks(), + mali_benchmarks.Android_Mali_Benchmarks(), + vulkan_nvidia_benchmarks.Linux_Vulkan_NVIDIA_Benchmarks(), + vmvx_benchmarks.Android_VMVX_Benchmarks(), + ] + all_gen_configs: List[iree_definitions.ModuleGenerationConfig] = [] + all_run_configs: List[iree_definitions.E2EModelRunConfig] = [] + for benchmark in benchmarks: + module_generation_configs, run_configs = benchmark.generate() + all_gen_configs += module_generation_configs + all_run_configs += run_configs - compile_stats_gen_configs = [] - # For now we simply track compilation statistics of all modules. - for gen_config in all_gen_configs: - compile_config = gen_config.compile_config - # Use POSIX path, see the comment of iree_definitions.MODULE_DIR_VARIABLE. - scheduling_stats_path = f"{iree_definitions.MODULE_DIR_VARIABLE}/{iree_artifacts.SCHEDULING_STATS_FILENAME}" - compile_stats_config = iree_definitions.CompileConfig.build( - id=compile_config.id + COMPILE_STATS_ID_SUFFIX, - tags=compile_config.tags + [benchmark_tags.COMPILE_STATS], - compile_targets=compile_config.compile_targets, - extra_flags=compile_config.extra_flags + [ - # Enable zip polyglot to provide component sizes. - "--iree-vm-emit-polyglot-zip=true", - # Disable debug symbols to provide correct component sizes. - "--iree-llvmcpu-debug-symbols=false", - # Dump scheduling statistics - "--iree-scheduling-dump-statistics-format=json", - f"--iree-scheduling-dump-statistics-file={scheduling_stats_path}" - ]) - compile_stats_gen_configs.append( - iree_definitions.ModuleGenerationConfig.build( - imported_model=gen_config.imported_model, - compile_config=compile_stats_config, - tags=gen_config.tags)) - all_gen_configs += compile_stats_gen_configs + compile_stats_gen_configs = [] + # For now we simply track compilation statistics of all modules. + for gen_config in all_gen_configs: + compile_config = gen_config.compile_config + # Use POSIX path, see the comment of iree_definitions.MODULE_DIR_VARIABLE. + scheduling_stats_path = f"{iree_definitions.MODULE_DIR_VARIABLE}/{iree_artifacts.SCHEDULING_STATS_FILENAME}" + compile_stats_config = iree_definitions.CompileConfig.build( + id=compile_config.id + COMPILE_STATS_ID_SUFFIX, + tags=compile_config.tags + [benchmark_tags.COMPILE_STATS], + compile_targets=compile_config.compile_targets, + extra_flags=compile_config.extra_flags + + [ + # Enable zip polyglot to provide component sizes. + "--iree-vm-emit-polyglot-zip=true", + # Disable debug symbols to provide correct component sizes. + "--iree-llvmcpu-debug-symbols=false", + # Dump scheduling statistics + "--iree-scheduling-dump-statistics-format=json", + f"--iree-scheduling-dump-statistics-file={scheduling_stats_path}", + ], + ) + compile_stats_gen_configs.append( + iree_definitions.ModuleGenerationConfig.build( + imported_model=gen_config.imported_model, + compile_config=compile_stats_config, + tags=gen_config.tags, + ) + ) + all_gen_configs += compile_stats_gen_configs - return (all_gen_configs, all_run_configs) + return (all_gen_configs, all_run_configs) diff --git a/build_tools/python/benchmark_suites/iree/cuda_benchmarks.py b/build_tools/python/benchmark_suites/iree/cuda_benchmarks.py index 1174235d1edf..fd62c1a89a6c 100644 --- a/build_tools/python/benchmark_suites/iree/cuda_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/cuda_benchmarks.py @@ -15,81 +15,103 @@ class Linux_CUDA_Benchmarks(object): - """Benchmarks on CUDA Linux devices.""" + """Benchmarks on CUDA Linux devices.""" - SM_80_GPU_TARGET = iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture.CUDA_SM80, - target_backend=iree_definitions.TargetBackend.CUDA, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - SM_80_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_DEFAULTS, - tags=["default-flags"], - compile_targets=[SM_80_GPU_TARGET]) - SM_80_UBENCH_MATMUL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_UBENCH, - tags=["ukernel", "matmul"], - compile_targets=[SM_80_GPU_TARGET], - extra_flags=["--iree-hal-benchmark-dispatch-repeat-count=100"]) - SM_80_UBENCH_MATMUL_SPLITK_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_SPLITK_UBENCH, - tags=["ukernel", "matmul", "splitk"], - compile_targets=[SM_80_GPU_TARGET], - extra_flags=[ - "--iree-hal-benchmark-dispatch-repeat-count=100", - "--iree-flow-split-matmul-reduction=4", - "--iree-codegen-llvmgpu-use-wmma" - ]) + SM_80_GPU_TARGET = iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.CUDA_SM80, + target_backend=iree_definitions.TargetBackend.CUDA, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + SM_80_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_DEFAULTS, + tags=["default-flags"], + compile_targets=[SM_80_GPU_TARGET], + ) + SM_80_UBENCH_MATMUL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_UBENCH, + tags=["ukernel", "matmul"], + compile_targets=[SM_80_GPU_TARGET], + extra_flags=["--iree-hal-benchmark-dispatch-repeat-count=100"], + ) + SM_80_UBENCH_MATMUL_SPLITK_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_SPLITK_UBENCH, + tags=["ukernel", "matmul", "splitk"], + compile_targets=[SM_80_GPU_TARGET], + extra_flags=[ + "--iree-hal-benchmark-dispatch-repeat-count=100", + "--iree-flow-split-matmul-reduction=4", + "--iree-codegen-llvmgpu-use-wmma", + ], + ) - def _generate_configs( - self, - models: Sequence[common_definitions.Model], - compile_config: iree_definitions.CompileConfig, - execution_config: iree_definitions. - ModuleExecutionConfig = module_execution_configs.CUDA_CONFIG, - tags: Sequence[str] = (), - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=compile_config, - imported_model=iree_definitions.ImportedModel.from_model(model), - tags=tags) for model in models - ] - sm80_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64) - run_module_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=gen_configs, - module_execution_configs=[execution_config], - device_specs=sm80_devices, - tags=tags) + def _generate_configs( + self, + models: Sequence[common_definitions.Model], + compile_config: iree_definitions.CompileConfig, + execution_config: iree_definitions.ModuleExecutionConfig = module_execution_configs.CUDA_CONFIG, + tags: Sequence[str] = (), + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=compile_config, + imported_model=iree_definitions.ImportedModel.from_model(model), + tags=tags, + ) + for model in models + ] + sm80_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + ) + run_module_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=gen_configs, + module_execution_configs=[execution_config], + device_specs=sm80_devices, + tags=tags, + ) - return (gen_configs, run_module_configs) + return (gen_configs, run_module_configs) - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" - # The CUDA tag is required to put them into the CUDA benchmark preset. - gen_configs, run_configs = self._generate_configs( - model_groups.CUDA_MODELS, - self.SM_80_COMPILE_CONFIG, - tags=[benchmark_tags.CUDA]) - ubench_gen_configs, ubench_run_configs = self._generate_configs( - model_groups.MICRO_MATMUL, - self.SM_80_UBENCH_MATMUL_COMPILE_CONFIG, - execution_config=module_execution_configs.CUDA_BATCH_SIZE_100_CONFIG, - tags=[benchmark_tags.CUDA]) - ubench_splitk_gen_configs, ubench_splitk_run_configs = self._generate_configs( - model_groups.MICRO_MATMUL_SPLITK, - self.SM_80_UBENCH_MATMUL_SPLITK_COMPILE_CONFIG, - execution_config=module_execution_configs.CUDA_BATCH_SIZE_100_CONFIG, - tags=[benchmark_tags.CUDA]) - large_gen_configs, large_module_configs = self._generate_configs( - model_groups.CUDA_MODELS_LONG, - self.SM_80_COMPILE_CONFIG, - tags=[benchmark_tags.CUDA, benchmark_tags.LARGE]) - return (gen_configs + ubench_gen_configs + ubench_splitk_gen_configs + - large_gen_configs, run_configs + ubench_run_configs + - ubench_splitk_run_configs + large_module_configs) + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" + # The CUDA tag is required to put them into the CUDA benchmark preset. + gen_configs, run_configs = self._generate_configs( + model_groups.CUDA_MODELS, + self.SM_80_COMPILE_CONFIG, + tags=[benchmark_tags.CUDA], + ) + ubench_gen_configs, ubench_run_configs = self._generate_configs( + model_groups.MICRO_MATMUL, + self.SM_80_UBENCH_MATMUL_COMPILE_CONFIG, + execution_config=module_execution_configs.CUDA_BATCH_SIZE_100_CONFIG, + tags=[benchmark_tags.CUDA], + ) + ubench_splitk_gen_configs, ubench_splitk_run_configs = self._generate_configs( + model_groups.MICRO_MATMUL_SPLITK, + self.SM_80_UBENCH_MATMUL_SPLITK_COMPILE_CONFIG, + execution_config=module_execution_configs.CUDA_BATCH_SIZE_100_CONFIG, + tags=[benchmark_tags.CUDA], + ) + large_gen_configs, large_module_configs = self._generate_configs( + model_groups.CUDA_MODELS_LONG, + self.SM_80_COMPILE_CONFIG, + tags=[benchmark_tags.CUDA, benchmark_tags.LARGE], + ) + return ( + gen_configs + + ubench_gen_configs + + ubench_splitk_gen_configs + + large_gen_configs, + run_configs + + ubench_run_configs + + ubench_splitk_run_configs + + large_module_configs, + ) diff --git a/build_tools/python/benchmark_suites/iree/mali_benchmarks.py b/build_tools/python/benchmark_suites/iree/mali_benchmarks.py index 2cd371bf1a93..dd5e044983c8 100644 --- a/build_tools/python/benchmark_suites/iree/mali_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/mali_benchmarks.py @@ -15,123 +15,146 @@ class Android_Mali_Benchmarks(object): - """Benchmarks on Android devices with Mali GPU.""" + """Benchmarks on Android devices with Mali GPU.""" - ARM_VALHALL_GPU_TARGET = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, - target_architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - target_abi=iree_definitions.TargetABI.VULKAN_ANDROID31) - DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_DEFAULTS, - tags=["default-flags"], - compile_targets=[ARM_VALHALL_GPU_TARGET]) - EXPERIMENTAL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL, - tags=["experimental-flags", "fuse-padding", "max-concurrency"], - compile_targets=[ARM_VALHALL_GPU_TARGET], - extra_flags=[ - "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", - "--iree-stream-partitioning-favor=max-concurrency" - ]) - # Kernel execution - # Note that for kernel-execution benchmarks batch_size/repeat-count need to be - # low enough that the whole dispatch completes within an OS-specific timeout. - # Otherwise you'll get error like: - # ``` - # INTERNAL; VK_ERROR_DEVICE_LOST; vkQueueSubmit; while invoking native function - # hal.fence.await; while calling import; - # ``` - EXPERIMENTAL_REPEATED_KERNEL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids. - IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL_REPEATED_KERNEL, - tags=[ - "experimental-flags", "fuse-padding", "max-concurrency", - "repeated-kernel" - ], - compile_targets=[ARM_VALHALL_GPU_TARGET], - extra_flags=EXPERIMENTAL_COMPILE_CONFIG.extra_flags + - ["--iree-hal-benchmark-dispatch-repeat-count=32"]) - EXPERIMENTAL_REPEATED_KERNEL_RUN_FLAGS = ["--batch_size=32"] + ARM_VALHALL_GPU_TARGET = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, + target_architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + target_abi=iree_definitions.TargetABI.VULKAN_ANDROID31, + ) + DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_DEFAULTS, + tags=["default-flags"], + compile_targets=[ARM_VALHALL_GPU_TARGET], + ) + EXPERIMENTAL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL, + tags=["experimental-flags", "fuse-padding", "max-concurrency"], + compile_targets=[ARM_VALHALL_GPU_TARGET], + extra_flags=[ + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + "--iree-stream-partitioning-favor=max-concurrency", + ], + ) + # Kernel execution + # Note that for kernel-execution benchmarks batch_size/repeat-count need to be + # low enough that the whole dispatch completes within an OS-specific timeout. + # Otherwise you'll get error like: + # ``` + # INTERNAL; VK_ERROR_DEVICE_LOST; vkQueueSubmit; while invoking native function + # hal.fence.await; while calling import; + # ``` + EXPERIMENTAL_REPEATED_KERNEL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL_REPEATED_KERNEL, + tags=[ + "experimental-flags", + "fuse-padding", + "max-concurrency", + "repeated-kernel", + ], + compile_targets=[ARM_VALHALL_GPU_TARGET], + extra_flags=EXPERIMENTAL_COMPILE_CONFIG.extra_flags + + ["--iree-hal-benchmark-dispatch-repeat-count=32"], + ) + EXPERIMENTAL_REPEATED_KERNEL_RUN_FLAGS = ["--batch_size=32"] - FP32_MODELS = [ - tflite_models.DEEPLABV3_FP32, - tflite_models.MOBILESSD_FP32, - tflite_models.POSENET_FP32, - tflite_models.MOBILEBERT_FP32, - tflite_models.MOBILENET_V2, - tflite_models.MOBILENET_V3SMALL, - ] - FP16_MODELS = [tflite_models.MOBILEBERT_FP16] - QUANT_MODELS = [ - tflite_models.MOBILEBERT_INT8, - tflite_models.EFFICIENTNET_INT8, - tflite_models.PERSON_DETECT_INT8, - ] + FP32_MODELS = [ + tflite_models.DEEPLABV3_FP32, + tflite_models.MOBILESSD_FP32, + tflite_models.POSENET_FP32, + tflite_models.MOBILEBERT_FP32, + tflite_models.MOBILENET_V2, + tflite_models.MOBILENET_V3SMALL, + ] + FP16_MODELS = [tflite_models.MOBILEBERT_FP16] + QUANT_MODELS = [ + tflite_models.MOBILEBERT_INT8, + tflite_models.EFFICIENTNET_INT8, + tflite_models.PERSON_DETECT_INT8, + ] - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - default_gen_configs = self._get_module_generation_configs( - compile_config=self.DEFAULT_COMPILE_CONFIG, - fp32_models=self.FP32_MODELS, - fp16_models=self.FP16_MODELS, - quant_models=self.QUANT_MODELS) - experimental_gen_configs = self._get_module_generation_configs( - compile_config=self.EXPERIMENTAL_COMPILE_CONFIG, - fp32_models=self.FP32_MODELS, - fp16_models=self.FP16_MODELS, - quant_models=self.QUANT_MODELS) - experimental_repeated_kernel_gen_configs = self._get_module_generation_configs( - compile_config=self.EXPERIMENTAL_REPEATED_KERNEL_COMPILE_CONFIG, - fp32_models=self.FP32_MODELS, - fp16_models=self.FP16_MODELS, - quant_models=self.QUANT_MODELS) + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + default_gen_configs = self._get_module_generation_configs( + compile_config=self.DEFAULT_COMPILE_CONFIG, + fp32_models=self.FP32_MODELS, + fp16_models=self.FP16_MODELS, + quant_models=self.QUANT_MODELS, + ) + experimental_gen_configs = self._get_module_generation_configs( + compile_config=self.EXPERIMENTAL_COMPILE_CONFIG, + fp32_models=self.FP32_MODELS, + fp16_models=self.FP16_MODELS, + quant_models=self.QUANT_MODELS, + ) + experimental_repeated_kernel_gen_configs = self._get_module_generation_configs( + compile_config=self.EXPERIMENTAL_REPEATED_KERNEL_COMPILE_CONFIG, + fp32_models=self.FP32_MODELS, + fp16_models=self.FP16_MODELS, + quant_models=self.QUANT_MODELS, + ) - mali_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A) - run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=default_gen_configs + - experimental_gen_configs, - module_execution_configs=[module_execution_configs.VULKAN_CONFIG], - device_specs=mali_devices) - run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=experimental_repeated_kernel_gen_configs, - module_execution_configs=[ - module_execution_configs.VULKAN_BATCH_SIZE_32_CONFIG - ], - device_specs=mali_devices) + mali_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + ) + run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=default_gen_configs + experimental_gen_configs, + module_execution_configs=[module_execution_configs.VULKAN_CONFIG], + device_specs=mali_devices, + ) + run_configs += benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=experimental_repeated_kernel_gen_configs, + module_execution_configs=[ + module_execution_configs.VULKAN_BATCH_SIZE_32_CONFIG + ], + device_specs=mali_devices, + ) - gen_configs = (default_gen_configs + experimental_gen_configs + - experimental_repeated_kernel_gen_configs) - return (gen_configs, run_configs) + gen_configs = ( + default_gen_configs + + experimental_gen_configs + + experimental_repeated_kernel_gen_configs + ) + return (gen_configs, run_configs) - def _get_module_generation_configs( - self, compile_config: iree_definitions.CompileConfig, - fp32_models: Sequence[common_definitions.Model], - fp16_models: Sequence[common_definitions.Model], - quant_models: Sequence[common_definitions.Model] - ) -> List[iree_definitions.ModuleGenerationConfig]: - demote_compile_config = iree_definitions.CompileConfig.build( - id=compile_config.id + "-demote-f32-to-16", - tags=compile_config.tags + ["demote-f32-to-f16"], - compile_targets=compile_config.compile_targets, - extra_flags=compile_config.extra_flags + - ["--iree-flow-demote-f32-to-f16"]) - return [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=compile_config, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in fp32_models - ] + [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=demote_compile_config, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in fp16_models - ] + [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=compile_config, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in quant_models - ] + def _get_module_generation_configs( + self, + compile_config: iree_definitions.CompileConfig, + fp32_models: Sequence[common_definitions.Model], + fp16_models: Sequence[common_definitions.Model], + quant_models: Sequence[common_definitions.Model], + ) -> List[iree_definitions.ModuleGenerationConfig]: + demote_compile_config = iree_definitions.CompileConfig.build( + id=compile_config.id + "-demote-f32-to-16", + tags=compile_config.tags + ["demote-f32-to-f16"], + compile_targets=compile_config.compile_targets, + extra_flags=compile_config.extra_flags + ["--iree-flow-demote-f32-to-f16"], + ) + return ( + [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=compile_config, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in fp32_models + ] + + [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=demote_compile_config, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in fp16_models + ] + + [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=compile_config, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in quant_models + ] + ) diff --git a/build_tools/python/benchmark_suites/iree/module_execution_configs.py b/build_tools/python/benchmark_suites/iree/module_execution_configs.py index a8ffcf7654bf..3a0dc038dd29 100644 --- a/build_tools/python/benchmark_suites/iree/module_execution_configs.py +++ b/build_tools/python/benchmark_suites/iree/module_execution_configs.py @@ -16,92 +16,107 @@ def _with_caching_allocator( tags: List[str], loader: iree_definitions.RuntimeLoader, driver: iree_definitions.RuntimeDriver, - extra_flags: Optional[Sequence[str]] = None + extra_flags: Optional[Sequence[str]] = None, ) -> iree_definitions.ModuleExecutionConfig: - extra_flags = [] if extra_flags is None else list(extra_flags) - return iree_definitions.ModuleExecutionConfig.build( - id=id, - tags=tags, - loader=loader, - driver=driver, - extra_flags=["--device_allocator=caching"] + extra_flags) + extra_flags = [] if extra_flags is None else list(extra_flags) + return iree_definitions.ModuleExecutionConfig.build( + id=id, + tags=tags, + loader=loader, + driver=driver, + extra_flags=["--device_allocator=caching"] + extra_flags, + ) ELF_LOCAL_SYNC_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_LOCAL_SYNC, tags=["full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_SYNC) + driver=iree_definitions.RuntimeDriver.LOCAL_SYNC, +) CUDA_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_CUDA, tags=["full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.NONE, - driver=iree_definitions.RuntimeDriver.CUDA) + driver=iree_definitions.RuntimeDriver.CUDA, +) CUDA_BATCH_SIZE_100_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_CUDA, tags=["full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.NONE, driver=iree_definitions.RuntimeDriver.CUDA, - extra_flags=["--batch_size=100"]) + extra_flags=["--batch_size=100"], +) VULKAN_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN, tags=["full-inference", "default-flags"], loader=iree_definitions.RuntimeLoader.NONE, - driver=iree_definitions.RuntimeDriver.VULKAN) + driver=iree_definitions.RuntimeDriver.VULKAN, +) VULKAN_BATCH_SIZE_16_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_16, tags=["full-inference", "experimental-flags"], loader=iree_definitions.RuntimeLoader.NONE, driver=iree_definitions.RuntimeDriver.VULKAN, - extra_flags=["--batch_size=16"]) + extra_flags=["--batch_size=16"], +) VULKAN_BATCH_SIZE_32_CONFIG = _with_caching_allocator( id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_32, tags=["full-inference", "experimental-flags"], loader=iree_definitions.RuntimeLoader.NONE, driver=iree_definitions.RuntimeDriver.VULKAN, - extra_flags=["--batch_size=32"]) + extra_flags=["--batch_size=32"], +) def get_elf_local_task_config(thread_num: int): - config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_LOCAL_TASK_BASE}-{thread_num}" - return _with_caching_allocator( - id=config_id, - tags=[f"{thread_num}-thread", "full-inference", "default-flags"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=[f"--task_topology_max_group_count={thread_num}"]) + config_id = ( + f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_LOCAL_TASK_BASE}-{thread_num}" + ) + return _with_caching_allocator( + id=config_id, + tags=[f"{thread_num}-thread", "full-inference", "default-flags"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=[f"--task_topology_max_group_count={thread_num}"], + ) def get_vmvx_local_task_config(thread_num: int): - config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_VMVX_LOCAL_TASK_BASE}-{thread_num}" - return _with_caching_allocator( - id=config_id, - tags=[f"{thread_num}-thread", "full-inference", "default-flags"], - loader=iree_definitions.RuntimeLoader.VMVX_MODULE, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=[f"--task_topology_max_group_count={thread_num}"]) + config_id = ( + f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_VMVX_LOCAL_TASK_BASE}-{thread_num}" + ) + return _with_caching_allocator( + id=config_id, + tags=[f"{thread_num}-thread", "full-inference", "default-flags"], + loader=iree_definitions.RuntimeLoader.VMVX_MODULE, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=[f"--task_topology_max_group_count={thread_num}"], + ) def get_elf_system_scheduling_local_task_config(thread_num: int): - config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_SYS_SCHED_LOCAL_TASK_BASE}-{thread_num}" - return _with_caching_allocator( - id=config_id, - tags=[f"{thread_num}-thread", "full-inference", "system-scheduling"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=[f"--task_topology_group_count={thread_num}"]) + config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_SYS_SCHED_LOCAL_TASK_BASE}-{thread_num}" + return _with_caching_allocator( + id=config_id, + tags=[f"{thread_num}-thread", "full-inference", "system-scheduling"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=[f"--task_topology_group_count={thread_num}"], + ) def get_vmvx_system_scheduling_local_task_config(thread_num: int): - config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_VMVX_SYS_SCHED_LOCAL_TASK_BASE}-{thread_num}" - return _with_caching_allocator( - id=config_id, - tags=[f"{thread_num}-thread", "full-inference", "system-scheduling"], - loader=iree_definitions.RuntimeLoader.VMVX_MODULE, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=[f"--task_topology_group_count={thread_num}"]) + config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_VMVX_SYS_SCHED_LOCAL_TASK_BASE}-{thread_num}" + return _with_caching_allocator( + id=config_id, + tags=[f"{thread_num}-thread", "full-inference", "system-scheduling"], + loader=iree_definitions.RuntimeLoader.VMVX_MODULE, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=[f"--task_topology_group_count={thread_num}"], + ) diff --git a/build_tools/python/benchmark_suites/iree/riscv_benchmarks.py b/build_tools/python/benchmark_suites/iree/riscv_benchmarks.py index 477fa46a17db..c34339ee6762 100644 --- a/build_tools/python/benchmark_suites/iree/riscv_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/riscv_benchmarks.py @@ -12,67 +12,77 @@ class Linux_RV64_Benchmarks(object): - """Benchmarks RV64 on Linux devices.""" + """Benchmarks RV64 on Linux devices.""" - RV64_CPU_TARGET = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_RV64_GENERIC_DEFAULTS, - tags=["default-flags"], - compile_targets=[RV64_CPU_TARGET]) - MODELS = [ - tflite_models.DEEPLABV3_FP32, - tflite_models.MOBILEBERT_FP32, - tflite_models.MOBILENET_V1, - tflite_models.MOBILEBERT_INT8, - tflite_models.PERSON_DETECT_INT8, - tflite_models.EFFICIENTNET_INT8, - tflite_models.MOBILENET_V2_INT8, - ] - - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" - gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.DEFAULT_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in self.MODELS + RV64_CPU_TARGET = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_RV64_GENERIC_DEFAULTS, + tags=["default-flags"], + compile_targets=[RV64_CPU_TARGET], + ) + MODELS = [ + tflite_models.DEEPLABV3_FP32, + tflite_models.MOBILEBERT_FP32, + tflite_models.MOBILENET_V1, + tflite_models.MOBILEBERT_INT8, + tflite_models.PERSON_DETECT_INT8, + tflite_models.EFFICIENTNET_INT8, + tflite_models.MOBILENET_V2_INT8, ] - return (gen_configs, []) + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" + gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.DEFAULT_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in self.MODELS + ] + return (gen_configs, []) -class Linux_RV32_Benchmarks(object): - """Benchmarks RV32 on Linux devices.""" - RV32_CPU_TARGET = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV32_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_RV32_GENERIC_DEFAULTS, - tags=["default-flags"], - compile_targets=[RV32_CPU_TARGET]) - MODELS = [ - tflite_models.EFFICIENTNET_INT8, - tflite_models.MOBILEBERT_INT8, - tflite_models.PERSON_DETECT_INT8, - tflite_models.MOBILENET_V2_INT8, - ] +class Linux_RV32_Benchmarks(object): + """Benchmarks RV32 on Linux devices.""" - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" - gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.DEFAULT_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) - for model in self.MODELS + RV32_CPU_TARGET = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV32_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + DEFAULT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_RV32_GENERIC_DEFAULTS, + tags=["default-flags"], + compile_targets=[RV32_CPU_TARGET], + ) + MODELS = [ + tflite_models.EFFICIENTNET_INT8, + tflite_models.MOBILEBERT_INT8, + tflite_models.PERSON_DETECT_INT8, + tflite_models.MOBILENET_V2_INT8, ] - return (gen_configs, []) + + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" + gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.DEFAULT_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in self.MODELS + ] + return (gen_configs, []) diff --git a/build_tools/python/benchmark_suites/iree/utils.py b/build_tools/python/benchmark_suites/iree/utils.py index 0b614d1e4fca..c800cf0e1d46 100644 --- a/build_tools/python/benchmark_suites/iree/utils.py +++ b/build_tools/python/benchmark_suites/iree/utils.py @@ -11,24 +11,23 @@ def generate_e2e_model_run_configs( - module_generation_configs: Sequence[ - iree_definitions.ModuleGenerationConfig], + module_generation_configs: Sequence[iree_definitions.ModuleGenerationConfig], module_execution_configs: Sequence[iree_definitions.ModuleExecutionConfig], device_specs: Sequence[common_definitions.DeviceSpec], tags: Optional[Sequence[str]] = None, - tool: iree_definitions.E2EModelRunTool = iree_definitions.E2EModelRunTool. - IREE_BENCHMARK_MODULE + tool: iree_definitions.E2EModelRunTool = iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE, ) -> List[iree_definitions.E2EModelRunConfig]: - """Generates the run specs from the product of compile specs and run configs. - """ - return [ - iree_definitions.E2EModelRunConfig.build( - module_generation_config=module_generation_config, - module_execution_config=module_execution_config, - target_device_spec=device_spec, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - tool=tool, - tags=tags) for module_generation_config, - module_execution_config, device_spec in itertools.product( - module_generation_configs, module_execution_configs, device_specs) - ] + """Generates the run specs from the product of compile specs and run configs.""" + return [ + iree_definitions.E2EModelRunConfig.build( + module_generation_config=module_generation_config, + module_execution_config=module_execution_config, + target_device_spec=device_spec, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + tool=tool, + tags=tags, + ) + for module_generation_config, module_execution_config, device_spec in itertools.product( + module_generation_configs, module_execution_configs, device_specs + ) + ] diff --git a/build_tools/python/benchmark_suites/iree/vmvx_benchmarks.py b/build_tools/python/benchmark_suites/iree/vmvx_benchmarks.py index 0bfd7396067e..6461a5d9824f 100644 --- a/build_tools/python/benchmark_suites/iree/vmvx_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/vmvx_benchmarks.py @@ -15,40 +15,50 @@ class Android_VMVX_Benchmarks(object): - """Benchmarks VMVX on Android devices.""" + """Benchmarks VMVX on Android devices.""" - VMVX_CPU_TARGET = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.VMVX, - target_architecture=common_definitions.DeviceArchitecture.VMVX_GENERIC, - target_abi=iree_definitions.TargetABI.VMVX) - EXPERIMENTAL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_VMVX_GENERIC_EXPERIMENTAL, - tags=["experimental-flags"], - compile_targets=[VMVX_CPU_TARGET]) + VMVX_CPU_TARGET = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.VMVX, + target_architecture=common_definitions.DeviceArchitecture.VMVX_GENERIC, + target_abi=iree_definitions.TargetABI.VMVX, + ) + EXPERIMENTAL_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_VMVX_GENERIC_EXPERIMENTAL, + tags=["experimental-flags"], + compile_targets=[VMVX_CPU_TARGET], + ) - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" - gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=self.EXPERIMENTAL_COMPILE_CONFIG, - imported_model=iree_definitions.ImportedModel.from_model(model)) for - model in [tflite_models.MOBILENET_V2, tflite_models.MOBILENET_V3SMALL] - ] - default_execution_configs = [ - benchmark_suites.iree.module_execution_configs. - get_vmvx_system_scheduling_local_task_config(thread_num=4) - ] - big_cores_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters={"big-cores"}) - run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=gen_configs, - module_execution_configs=default_execution_configs, - device_specs=big_cores_devices) + gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=self.EXPERIMENTAL_COMPILE_CONFIG, + imported_model=iree_definitions.ImportedModel.from_model(model), + ) + for model in [tflite_models.MOBILENET_V2, tflite_models.MOBILENET_V3SMALL] + ] + default_execution_configs = [ + benchmark_suites.iree.module_execution_configs.get_vmvx_system_scheduling_local_task_config( + thread_num=4 + ) + ] + big_cores_devices = ( + device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters={"big-cores"}, + ) + ) + run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=gen_configs, + module_execution_configs=default_execution_configs, + device_specs=big_cores_devices, + ) - return (gen_configs, run_configs) + return (gen_configs, run_configs) diff --git a/build_tools/python/benchmark_suites/iree/vulkan_nvidia_benchmarks.py b/build_tools/python/benchmark_suites/iree/vulkan_nvidia_benchmarks.py index 1bf07942066d..80fa8be9974f 100644 --- a/build_tools/python/benchmark_suites/iree/vulkan_nvidia_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/vulkan_nvidia_benchmarks.py @@ -15,88 +15,107 @@ def _get_compile_flag(): - preprocess_passes = [ - "iree-flow-detach-elementwise-from-named-ops", - "iree-preprocessing-convert-conv2d-to-img2col", - "iree-flow-convert-1x1-filter-conv2d-to-matmul", - "iree-preprocessing-pad-linalg-ops{pad-size=32}", - ] - preprocess_flag_template = \ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func({}))" - return [ - "--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64", - preprocess_flag_template.format(",".join(preprocess_passes)) - ] + preprocess_passes = [ + "iree-flow-detach-elementwise-from-named-ops", + "iree-preprocessing-convert-conv2d-to-img2col", + "iree-flow-convert-1x1-filter-conv2d-to-matmul", + "iree-preprocessing-pad-linalg-ops{pad-size=32}", + ] + preprocess_flag_template = ( + "--iree-preprocessing-pass-pipeline=builtin.module(func.func({}))" + ) + return [ + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + preprocess_flag_template.format(",".join(preprocess_passes)), + ] class Linux_Vulkan_NVIDIA_Benchmarks(object): - """Benchmarks on Linux Vulkan NVIDIA devices.""" + """Benchmarks on Linux Vulkan NVIDIA devices.""" - AMPERE_TARGET = iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, - target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, - target_abi=iree_definitions.TargetABI.VULKAN_LINUX) - PASCAL_TARGET = iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture.NVIDIA_PASCAL, - target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, - target_abi=iree_definitions.TargetABI.VULKAN_LINUX) + AMPERE_TARGET = iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, + target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, + target_abi=iree_definitions.TargetABI.VULKAN_LINUX, + ) + PASCAL_TARGET = iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.NVIDIA_PASCAL, + target_backend=iree_definitions.TargetBackend.VULKAN_SPIRV, + target_abi=iree_definitions.TargetABI.VULKAN_LINUX, + ) - SIMT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_SIMT, - tags=["experimental-flags", "simt"], - compile_targets=[PASCAL_TARGET], - extra_flags=_get_compile_flag()) - TENSORCORE_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_TENSORCORE, - tags=["experimental-flags", "tensorcore"], - compile_targets=[AMPERE_TARGET], - extra_flags=_get_compile_flag()) + SIMT_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_SIMT, + tags=["experimental-flags", "simt"], + compile_targets=[PASCAL_TARGET], + extra_flags=_get_compile_flag(), + ) + TENSORCORE_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_TENSORCORE, + tags=["experimental-flags", "tensorcore"], + compile_targets=[AMPERE_TARGET], + extra_flags=_get_compile_flag(), + ) - def _generate_configs( - self, - models: Sequence[common_definitions.Model], - compile_config: iree_definitions.CompileConfig, - execution_config: iree_definitions. - ModuleExecutionConfig = module_execution_configs.VULKAN_CONFIG, - tags: Sequence[str] = [], - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - gen_configs = [ - iree_definitions.ModuleGenerationConfig.build( - compile_config=compile_config, - imported_model=iree_definitions.ImportedModel.from_model(model), - tags=tags) for model in models - ] - # We use the same NVIDIA Ampere GPU for benchmarking code generated for - # both Pascal and Ampere architectures. What we care is not exactly these - # two architectures per se; they represent SIMT and tensorcore CodeGen - # paths that we would want both to work. Ampere is able to run both SIMT - # and tensorcore cases. - ampere_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64) - run_module_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=gen_configs, - module_execution_configs=[execution_config], - device_specs=ampere_devices, - tags=tags) + def _generate_configs( + self, + models: Sequence[common_definitions.Model], + compile_config: iree_definitions.CompileConfig, + execution_config: iree_definitions.ModuleExecutionConfig = module_execution_configs.VULKAN_CONFIG, + tags: Sequence[str] = [], + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + gen_configs = [ + iree_definitions.ModuleGenerationConfig.build( + compile_config=compile_config, + imported_model=iree_definitions.ImportedModel.from_model(model), + tags=tags, + ) + for model in models + ] + # We use the same NVIDIA Ampere GPU for benchmarking code generated for + # both Pascal and Ampere architectures. What we care is not exactly these + # two architectures per se; they represent SIMT and tensorcore CodeGen + # paths that we would want both to work. Ampere is able to run both SIMT + # and tensorcore cases. + ampere_devices = ( + device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + ) + ) + run_module_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=gen_configs, + module_execution_configs=[execution_config], + device_specs=ampere_devices, + tags=tags, + ) - return (gen_configs, run_module_configs) + return (gen_configs, run_module_configs) - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" - # The `vulkan-nvidia`` tag is required to put them into the Vulkan NVIDIA - # benchmark preset. - tensorcore_gen_configs, tensorcore_run_configs = self._generate_configs( - model_groups.VULKAN_MODELS, - self.TENSORCORE_COMPILE_CONFIG, - tags=[benchmark_tags.VULKAN_NVIDIA]) - simt_gen_configs, simt_run_configs = self._generate_configs( - model_groups.VULKAN_MODELS, - self.SIMT_COMPILE_CONFIG, - tags=[benchmark_tags.VULKAN_NVIDIA]) - return (tensorcore_gen_configs + simt_gen_configs, - tensorcore_run_configs + simt_run_configs) + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" + # The `vulkan-nvidia`` tag is required to put them into the Vulkan NVIDIA + # benchmark preset. + tensorcore_gen_configs, tensorcore_run_configs = self._generate_configs( + model_groups.VULKAN_MODELS, + self.TENSORCORE_COMPILE_CONFIG, + tags=[benchmark_tags.VULKAN_NVIDIA], + ) + simt_gen_configs, simt_run_configs = self._generate_configs( + model_groups.VULKAN_MODELS, + self.SIMT_COMPILE_CONFIG, + tags=[benchmark_tags.VULKAN_NVIDIA], + ) + return ( + tensorcore_gen_configs + simt_gen_configs, + tensorcore_run_configs + simt_run_configs, + ) diff --git a/build_tools/python/benchmark_suites/iree/x86_64_benchmarks.py b/build_tools/python/benchmark_suites/iree/x86_64_benchmarks.py index 8bc7ebf5f43a..96be90e5de7f 100644 --- a/build_tools/python/benchmark_suites/iree/x86_64_benchmarks.py +++ b/build_tools/python/benchmark_suites/iree/x86_64_benchmarks.py @@ -16,101 +16,122 @@ class Linux_x86_64_Benchmarks(object): - """Benchmarks on x86_64 linux devices.""" - - CASCADELAKE_CPU_TARGET = iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - - CASCADELAKE_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CASCADELAKE, - tags=["default-flags"], - compile_targets=[CASCADELAKE_CPU_TARGET]) - CASCADELAKE_FUSE_PADDING_COMPILE_CONFIG = iree_definitions.CompileConfig.build( - id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CASCADELAKE_FUSE_PADDING, - tags=["experimental-flags", "fuse-padding"], - compile_targets=[CASCADELAKE_CPU_TARGET], - extra_flags=[ - "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", - "--iree-llvmcpu-enable-pad-consumer-fusion" - ]) - - def _generate( - self, - benchmark_configs: List[common_definitions.CpuBenchmarkConfig], - compile_config: iree_definitions.CompileConfig, - device_specs: List[common_definitions.DeviceSpec], - tags: Sequence[str] = [], - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - gen_configs_all = [] - run_configs_all = [] - - # We avoid the full combinatorial explosion of testing all models with all - # thread counts and instead test each model with a number of threads - # appropriate for its size and configurations we're interested in. - for config in benchmark_configs: - gen_config = iree_definitions.ModuleGenerationConfig.build( - compile_config=compile_config, - imported_model=iree_definitions.ImportedModel.from_model( - config.model), - tags=tags) - - execution_configs = [] - for thread in config.threads: - if thread == 0: - execution_configs.append( - module_execution_configs.ELF_LOCAL_SYNC_CONFIG) - else: - execution_configs.append( - module_execution_configs.get_elf_local_task_config(thread)) - - run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( - module_generation_configs=[gen_config], - module_execution_configs=execution_configs, - device_specs=device_specs, - tags=tags) - - gen_configs_all.append(gen_config) - run_configs_all.extend(run_configs) - - return (gen_configs_all, run_configs_all) - - def generate( - self - ) -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates IREE compile and run configs.""" - - cascadelake_devices = device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64) - - # The X86_64 tag is required to put them into the X86_64 benchmark preset. - default_gen_configs, default_run_configs = self._generate( - model_groups.X86_64_BENCHMARK_CONFIG, - self.CASCADELAKE_COMPILE_CONFIG, - cascadelake_devices, - tags=[benchmark_tags.X86_64]) - experimental_gen_configs, experimental_run_configs = self._generate( - model_groups.X86_64_BENCHMARK_CONFIG_EXPERIMENTAL, - self.CASCADELAKE_FUSE_PADDING_COMPILE_CONFIG, - cascadelake_devices, - tags=[benchmark_tags.X86_64]) - - large_gen_configs, large_run_configs = self._generate( - model_groups.X86_64_BENCHMARK_CONFIG_LONG, - self.CASCADELAKE_COMPILE_CONFIG, - cascadelake_devices, - tags=[benchmark_tags.X86_64, benchmark_tags.LARGE]) - - return (default_gen_configs + experimental_gen_configs + large_gen_configs, - default_run_configs + experimental_run_configs + large_run_configs) - - -def generate() -> Tuple[List[iree_definitions.ModuleGenerationConfig], - List[iree_definitions.E2EModelRunConfig]]: - """Generates all compile and run configs for IREE benchmarks.""" - return Linux_x86_64_Benchmarks().generate() + """Benchmarks on x86_64 linux devices.""" + + CASCADELAKE_CPU_TARGET = iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + + CASCADELAKE_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CASCADELAKE, + tags=["default-flags"], + compile_targets=[CASCADELAKE_CPU_TARGET], + ) + CASCADELAKE_FUSE_PADDING_COMPILE_CONFIG = iree_definitions.CompileConfig.build( + id=unique_ids.IREE_COMPILE_CONFIG_LINUX_CASCADELAKE_FUSE_PADDING, + tags=["experimental-flags", "fuse-padding"], + compile_targets=[CASCADELAKE_CPU_TARGET], + extra_flags=[ + "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", + "--iree-llvmcpu-enable-pad-consumer-fusion", + ], + ) + + def _generate( + self, + benchmark_configs: List[common_definitions.CpuBenchmarkConfig], + compile_config: iree_definitions.CompileConfig, + device_specs: List[common_definitions.DeviceSpec], + tags: Sequence[str] = [], + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + gen_configs_all = [] + run_configs_all = [] + + # We avoid the full combinatorial explosion of testing all models with all + # thread counts and instead test each model with a number of threads + # appropriate for its size and configurations we're interested in. + for config in benchmark_configs: + gen_config = iree_definitions.ModuleGenerationConfig.build( + compile_config=compile_config, + imported_model=iree_definitions.ImportedModel.from_model(config.model), + tags=tags, + ) + + execution_configs = [] + for thread in config.threads: + if thread == 0: + execution_configs.append( + module_execution_configs.ELF_LOCAL_SYNC_CONFIG + ) + else: + execution_configs.append( + module_execution_configs.get_elf_local_task_config(thread) + ) + + run_configs = benchmark_suites.iree.utils.generate_e2e_model_run_configs( + module_generation_configs=[gen_config], + module_execution_configs=execution_configs, + device_specs=device_specs, + tags=tags, + ) + + gen_configs_all.append(gen_config) + run_configs_all.extend(run_configs) + + return (gen_configs_all, run_configs_all) + + def generate( + self, + ) -> Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ]: + """Generates IREE compile and run configs.""" + + cascadelake_devices = ( + device_collections.DEFAULT_DEVICE_COLLECTION.query_device_specs( + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + ) + ) + + # The X86_64 tag is required to put them into the X86_64 benchmark preset. + default_gen_configs, default_run_configs = self._generate( + model_groups.X86_64_BENCHMARK_CONFIG, + self.CASCADELAKE_COMPILE_CONFIG, + cascadelake_devices, + tags=[benchmark_tags.X86_64], + ) + experimental_gen_configs, experimental_run_configs = self._generate( + model_groups.X86_64_BENCHMARK_CONFIG_EXPERIMENTAL, + self.CASCADELAKE_FUSE_PADDING_COMPILE_CONFIG, + cascadelake_devices, + tags=[benchmark_tags.X86_64], + ) + + large_gen_configs, large_run_configs = self._generate( + model_groups.X86_64_BENCHMARK_CONFIG_LONG, + self.CASCADELAKE_COMPILE_CONFIG, + cascadelake_devices, + tags=[benchmark_tags.X86_64, benchmark_tags.LARGE], + ) + + return ( + default_gen_configs + experimental_gen_configs + large_gen_configs, + default_run_configs + experimental_run_configs + large_run_configs, + ) + + +def generate() -> ( + Tuple[ + List[iree_definitions.ModuleGenerationConfig], + List[iree_definitions.E2EModelRunConfig], + ] +): + """Generates all compile and run configs for IREE benchmarks.""" + return Linux_x86_64_Benchmarks().generate() diff --git a/build_tools/python/cmake_builder/rules.py b/build_tools/python/cmake_builder/rules.py index ece750de996d..33d152414e9b 100644 --- a/build_tools/python/cmake_builder/rules.py +++ b/build_tools/python/cmake_builder/rules.py @@ -36,134 +36,157 @@ def _get_string_list(values: Sequence[str], quote: bool = True) -> List[str]: - if quote: - return [f'"{value}"' for value in values] - return list(values) + if quote: + return [f'"{value}"' for value in values] + return list(values) def _get_block_body(body: List[str]) -> List[str]: - return [INDENT_SPACES + line for line in body] + return [INDENT_SPACES + line for line in body] -def _get_string_arg_block(keyword: str, - value: Optional[str], - quote: bool = True) -> List[str]: - if value is None: - return [] - if quote: - value = f'"{value}"' - return [f"{keyword} {value}"] +def _get_string_arg_block( + keyword: str, value: Optional[str], quote: bool = True +) -> List[str]: + if value is None: + return [] + if quote: + value = f'"{value}"' + return [f"{keyword} {value}"] -def _get_string_list_arg_block(keyword: str, - values: Sequence[str], - quote: bool = True) -> List[str]: - if len(values) == 0: - return [] - body = _get_string_list(values, quote) - return [keyword] + _get_block_body(body) +def _get_string_list_arg_block( + keyword: str, values: Sequence[str], quote: bool = True +) -> List[str]: + if len(values) == 0: + return [] + body = _get_string_list(values, quote) + return [keyword] + _get_block_body(body) def _get_option_arg_block(keyword: str, value: Optional[bool]) -> List[str]: - if value is True: - return [keyword] - return [] + if value is True: + return [keyword] + return [] -def _build_call_rule(rule_name: str, - parameter_blocks: Sequence[List[str]]) -> List[str]: - output = [f"{rule_name}("] - for block in parameter_blocks: - if len(block) == 0: - continue - output.extend(_get_block_body(block)) - output.append(")") - return output +def _build_call_rule( + rule_name: str, parameter_blocks: Sequence[List[str]] +) -> List[str]: + output = [f"{rule_name}("] + for block in parameter_blocks: + if len(block) == 0: + continue + output.extend(_get_block_body(block)) + output.append(")") + return output def _convert_block_to_string(block: List[str]) -> str: - # Hack to append the terminating newline and only copies the list instead of - # the whole string. - return "\n".join(block + [""]) - - -def build_iree_bytecode_module(target_name: str, - src: str, - module_name: str, - flags: List[str] = [], - compile_tool_target: Optional[str] = None, - c_identifier: Optional[str] = None, - static_lib_path: Optional[str] = None, - deps: List[str] = [], - friendly_name: Optional[str] = None, - testonly: bool = False, - public: bool = True) -> str: - name_block = _get_string_arg_block("NAME", target_name) - src_block = _get_string_arg_block("SRC", src) - module_name_block = _get_string_arg_block("MODULE_FILE_NAME", module_name) - c_identifier_block = _get_string_arg_block("C_IDENTIFIER", c_identifier) - static_lib_block = _get_string_arg_block("STATIC_LIB_PATH", static_lib_path) - compile_tool_target_block = _get_string_arg_block("COMPILE_TOOL", - compile_tool_target) - flags_block = _get_string_list_arg_block("FLAGS", flags) - deps_block = _get_string_list_arg_block("DEPS", deps) - friendly_name_block = _get_string_arg_block("FRIENDLY_NAME", friendly_name) - testonly_block = _get_option_arg_block("TESTONLY", testonly) - public_block = _get_option_arg_block("PUBLIC", public) - return _convert_block_to_string( - _build_call_rule(rule_name="iree_bytecode_module", - parameter_blocks=[ - name_block, src_block, module_name_block, - c_identifier_block, compile_tool_target_block, - static_lib_block, flags_block, friendly_name_block, - deps_block, testonly_block, public_block - ])) - - -def build_iree_fetch_artifact(target_name: str, source_url: str, output: str, - unpack: bool) -> str: - name_block = _get_string_arg_block("NAME", target_name) - source_url_block = _get_string_arg_block("SOURCE_URL", source_url) - output_block = _get_string_arg_block("OUTPUT", output) - unpack_block = _get_option_arg_block("UNPACK", unpack) - return _convert_block_to_string( - _build_call_rule(rule_name="iree_fetch_artifact", - parameter_blocks=[ - name_block, source_url_block, output_block, - unpack_block - ])) - - -def build_iree_import_tf_model(target_path: str, source: str, - import_flags: List[str], - output_mlir_file: str) -> str: - target_name_block = _get_string_arg_block("TARGET_NAME", target_path) - source_block = _get_string_arg_block("SOURCE", source) - import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags) - output_mlir_file_block = _get_string_arg_block("OUTPUT_MLIR_FILE", - output_mlir_file) - return _convert_block_to_string( - _build_call_rule(rule_name="iree_import_tf_model", - parameter_blocks=[ - target_name_block, source_block, import_flags_block, - output_mlir_file_block - ])) - - -def build_iree_import_tflite_model(target_path: str, source: str, - import_flags: List[str], - output_mlir_file: str) -> str: - target_name_block = _get_string_arg_block("TARGET_NAME", target_path) - source_block = _get_string_arg_block("SOURCE", source) - import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags) - output_mlir_file_block = _get_string_arg_block("OUTPUT_MLIR_FILE", - output_mlir_file) - return _convert_block_to_string( - _build_call_rule(rule_name="iree_import_tflite_model", - parameter_blocks=[ - target_name_block, source_block, import_flags_block, - output_mlir_file_block - ])) + # Hack to append the terminating newline and only copies the list instead of + # the whole string. + return "\n".join(block + [""]) + + +def build_iree_bytecode_module( + target_name: str, + src: str, + module_name: str, + flags: List[str] = [], + compile_tool_target: Optional[str] = None, + c_identifier: Optional[str] = None, + static_lib_path: Optional[str] = None, + deps: List[str] = [], + friendly_name: Optional[str] = None, + testonly: bool = False, + public: bool = True, +) -> str: + name_block = _get_string_arg_block("NAME", target_name) + src_block = _get_string_arg_block("SRC", src) + module_name_block = _get_string_arg_block("MODULE_FILE_NAME", module_name) + c_identifier_block = _get_string_arg_block("C_IDENTIFIER", c_identifier) + static_lib_block = _get_string_arg_block("STATIC_LIB_PATH", static_lib_path) + compile_tool_target_block = _get_string_arg_block( + "COMPILE_TOOL", compile_tool_target + ) + flags_block = _get_string_list_arg_block("FLAGS", flags) + deps_block = _get_string_list_arg_block("DEPS", deps) + friendly_name_block = _get_string_arg_block("FRIENDLY_NAME", friendly_name) + testonly_block = _get_option_arg_block("TESTONLY", testonly) + public_block = _get_option_arg_block("PUBLIC", public) + return _convert_block_to_string( + _build_call_rule( + rule_name="iree_bytecode_module", + parameter_blocks=[ + name_block, + src_block, + module_name_block, + c_identifier_block, + compile_tool_target_block, + static_lib_block, + flags_block, + friendly_name_block, + deps_block, + testonly_block, + public_block, + ], + ) + ) + + +def build_iree_fetch_artifact( + target_name: str, source_url: str, output: str, unpack: bool +) -> str: + name_block = _get_string_arg_block("NAME", target_name) + source_url_block = _get_string_arg_block("SOURCE_URL", source_url) + output_block = _get_string_arg_block("OUTPUT", output) + unpack_block = _get_option_arg_block("UNPACK", unpack) + return _convert_block_to_string( + _build_call_rule( + rule_name="iree_fetch_artifact", + parameter_blocks=[name_block, source_url_block, output_block, unpack_block], + ) + ) + + +def build_iree_import_tf_model( + target_path: str, source: str, import_flags: List[str], output_mlir_file: str +) -> str: + target_name_block = _get_string_arg_block("TARGET_NAME", target_path) + source_block = _get_string_arg_block("SOURCE", source) + import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags) + output_mlir_file_block = _get_string_arg_block("OUTPUT_MLIR_FILE", output_mlir_file) + return _convert_block_to_string( + _build_call_rule( + rule_name="iree_import_tf_model", + parameter_blocks=[ + target_name_block, + source_block, + import_flags_block, + output_mlir_file_block, + ], + ) + ) + + +def build_iree_import_tflite_model( + target_path: str, source: str, import_flags: List[str], output_mlir_file: str +) -> str: + target_name_block = _get_string_arg_block("TARGET_NAME", target_path) + source_block = _get_string_arg_block("SOURCE", source) + import_flags_block = _get_string_list_arg_block("IMPORT_FLAGS", import_flags) + output_mlir_file_block = _get_string_arg_block("OUTPUT_MLIR_FILE", output_mlir_file) + return _convert_block_to_string( + _build_call_rule( + rule_name="iree_import_tflite_model", + parameter_blocks=[ + target_name_block, + source_block, + import_flags_block, + output_mlir_file_block, + ], + ) + ) def build_iree_benchmark_suite_module_test( @@ -174,38 +197,50 @@ def build_iree_benchmark_suite_module_test( runner_args: Sequence[str], timeout_secs: Optional[int] = None, labels: Sequence[str] = [], - xfail_platforms: Sequence[str] = []) -> str: - name_block = _get_string_arg_block("NAME", target_name) - driver_block = _get_string_arg_block("DRIVER", driver) - expected_output_block = _get_string_arg_block("EXPECTED_OUTPUT", - expected_output) - modules_block = _get_string_list_arg_block( - "MODULES", - [f"{platform}={path}" for platform, path in platform_module_map.items()]) - timeout_block = _get_string_arg_block( - "TIMEOUT", - str(timeout_secs) if timeout_secs is not None else None) - runner_args_block = _get_string_list_arg_block("RUNNER_ARGS", runner_args) - labels_block = _get_string_list_arg_block("LABELS", labels) - xfail_platforms_block = _get_string_list_arg_block("XFAIL_PLATFORMS", - xfail_platforms) - return _convert_block_to_string( - _build_call_rule(rule_name="iree_benchmark_suite_module_test", - parameter_blocks=[ - name_block, driver_block, expected_output_block, - timeout_block, modules_block, runner_args_block, - labels_block, xfail_platforms_block - ])) + xfail_platforms: Sequence[str] = [], +) -> str: + name_block = _get_string_arg_block("NAME", target_name) + driver_block = _get_string_arg_block("DRIVER", driver) + expected_output_block = _get_string_arg_block("EXPECTED_OUTPUT", expected_output) + modules_block = _get_string_list_arg_block( + "MODULES", + [f"{platform}={path}" for platform, path in platform_module_map.items()], + ) + timeout_block = _get_string_arg_block( + "TIMEOUT", str(timeout_secs) if timeout_secs is not None else None + ) + runner_args_block = _get_string_list_arg_block("RUNNER_ARGS", runner_args) + labels_block = _get_string_list_arg_block("LABELS", labels) + xfail_platforms_block = _get_string_list_arg_block( + "XFAIL_PLATFORMS", xfail_platforms + ) + return _convert_block_to_string( + _build_call_rule( + rule_name="iree_benchmark_suite_module_test", + parameter_blocks=[ + name_block, + driver_block, + expected_output_block, + timeout_block, + modules_block, + runner_args_block, + labels_block, + xfail_platforms_block, + ], + ) + ) def build_add_dependencies(target: str, deps: List[str]) -> str: - if len(deps) == 0: - raise ValueError("Target dependencies can't be empty.") - deps_list = _get_string_list(deps, quote=False) - return _convert_block_to_string([f"add_dependencies({target}"] + - _get_block_body(deps_list) + [")"]) + if len(deps) == 0: + raise ValueError("Target dependencies can't be empty.") + deps_list = _get_string_list(deps, quote=False) + return _convert_block_to_string( + [f"add_dependencies({target}"] + _get_block_body(deps_list) + [")"] + ) def build_set(variable_name: str, value: str) -> str: - return _convert_block_to_string([f"set({variable_name}"] + - _get_block_body([value]) + [")"]) + return _convert_block_to_string( + [f"set({variable_name}"] + _get_block_body([value]) + [")"] + ) diff --git a/build_tools/python/cmake_builder/rules_test.py b/build_tools/python/cmake_builder/rules_test.py index be3f25f8e104..ad44e7e77be6 100644 --- a/build_tools/python/cmake_builder/rules_test.py +++ b/build_tools/python/cmake_builder/rules_test.py @@ -11,23 +11,24 @@ class RulesTest(unittest.TestCase): + def test_build_iree_bytecode_module(self): + rule = cmake_builder.rules.build_iree_bytecode_module( + target_name="abcd", + src="abcd.mlir", + module_name="abcd.vmfb", + flags=["--backend=cpu", "--opt=3"], + compile_tool_target="iree_iree-compile2", + c_identifier="abcd.c", + static_lib_path="libx.a", + deps=["iree_libx", "iree_liby"], + testonly=True, + public=False, + ) - def test_build_iree_bytecode_module(self): - rule = cmake_builder.rules.build_iree_bytecode_module( - target_name="abcd", - src="abcd.mlir", - module_name="abcd.vmfb", - flags=["--backend=cpu", "--opt=3"], - compile_tool_target="iree_iree-compile2", - c_identifier="abcd.c", - static_lib_path="libx.a", - deps=["iree_libx", "iree_liby"], - testonly=True, - public=False) - - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_bytecode_module( NAME "abcd" SRC "abcd.mlir" @@ -43,18 +44,22 @@ def test_build_iree_bytecode_module(self): "iree_liby" TESTONLY ) - """)) + """ + ), + ) - def test_build_iree_bytecode_module_with_defaults(self): - rule = cmake_builder.rules.build_iree_bytecode_module( - target_name="abcd", - src="abcd.mlir", - module_name="abcd.vmfb", - flags=["--backend=cpu", "--opt=3"]) + def test_build_iree_bytecode_module_with_defaults(self): + rule = cmake_builder.rules.build_iree_bytecode_module( + target_name="abcd", + src="abcd.mlir", + module_name="abcd.vmfb", + flags=["--backend=cpu", "--opt=3"], + ) - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_bytecode_module( NAME "abcd" SRC "abcd.mlir" @@ -64,39 +69,47 @@ def test_build_iree_bytecode_module_with_defaults(self): "--opt=3" PUBLIC ) - """)) + """ + ), + ) - def test_build_iree_fetch_artifact(self): - rule = cmake_builder.rules.build_iree_fetch_artifact( - target_name="abcd", - source_url="https://example.com/abcd.tflite", - output="./abcd.tflite", - unpack=True) + def test_build_iree_fetch_artifact(self): + rule = cmake_builder.rules.build_iree_fetch_artifact( + target_name="abcd", + source_url="https://example.com/abcd.tflite", + output="./abcd.tflite", + unpack=True, + ) - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_fetch_artifact( NAME "abcd" SOURCE_URL "https://example.com/abcd.tflite" OUTPUT "./abcd.tflite" UNPACK ) - """)) + """ + ), + ) - def test_build_iree_import_tf_model(self): - rule = cmake_builder.rules.build_iree_import_tf_model( - target_path="pkg_abcd", - source="abcd/model", - import_flags=[ - "--tf-savedmodel-exported-names=main", - "--tf-import-type=savedmodel_v1" - ], - output_mlir_file="abcd.mlir") + def test_build_iree_import_tf_model(self): + rule = cmake_builder.rules.build_iree_import_tf_model( + target_path="pkg_abcd", + source="abcd/model", + import_flags=[ + "--tf-savedmodel-exported-names=main", + "--tf-import-type=savedmodel_v1", + ], + output_mlir_file="abcd.mlir", + ) - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_import_tf_model( TARGET_NAME "pkg_abcd" SOURCE "abcd/model" @@ -105,18 +118,22 @@ def test_build_iree_import_tf_model(self): "--tf-import-type=savedmodel_v1" OUTPUT_MLIR_FILE "abcd.mlir" ) - """)) + """ + ), + ) - def test_build_iree_import_tflite_model(self): - rule = cmake_builder.rules.build_iree_import_tflite_model( - target_path="pkg_abcd", - source="abcd.tflite", - import_flags=["--fake-flag=abcd"], - output_mlir_file="abcd.mlir") + def test_build_iree_import_tflite_model(self): + rule = cmake_builder.rules.build_iree_import_tflite_model( + target_path="pkg_abcd", + source="abcd.tflite", + import_flags=["--fake-flag=abcd"], + output_mlir_file="abcd.mlir", + ) - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_import_tflite_model( TARGET_NAME "pkg_abcd" SOURCE "abcd.tflite" @@ -124,25 +141,26 @@ def test_build_iree_import_tflite_model(self): "--fake-flag=abcd" OUTPUT_MLIR_FILE "abcd.mlir" ) - """)) - - def test_build_iree_benchmark_suite_module_test(self): - rule = cmake_builder.rules.build_iree_benchmark_suite_module_test( - target_name="model_test", - driver="LOCAL_TASK", - expected_output="xyz", - platform_module_map={ - "x86_64": "a.vmfb", - "arm": "b.vmfb" - }, - runner_args=["--x=0", "--y=1"], - timeout_secs=10, - labels=["defaults", "e2e"], - xfail_platforms=["arm_64-Android", "riscv_32-Linux"]) - - self.assertEqual( - rule, - textwrap.dedent("""\ + """ + ), + ) + + def test_build_iree_benchmark_suite_module_test(self): + rule = cmake_builder.rules.build_iree_benchmark_suite_module_test( + target_name="model_test", + driver="LOCAL_TASK", + expected_output="xyz", + platform_module_map={"x86_64": "a.vmfb", "arm": "b.vmfb"}, + runner_args=["--x=0", "--y=1"], + timeout_secs=10, + labels=["defaults", "e2e"], + xfail_platforms=["arm_64-Android", "riscv_32-Linux"], + ) + + self.assertEqual( + rule, + textwrap.dedent( + """\ iree_benchmark_suite_module_test( NAME "model_test" DRIVER "LOCAL_TASK" @@ -161,32 +179,41 @@ def test_build_iree_benchmark_suite_module_test(self): "arm_64-Android" "riscv_32-Linux" ) - """)) + """ + ), + ) - def test_build_add_dependencies(self): - rule = cmake_builder.rules.build_add_dependencies( - target="iree_mlir_suites", deps=["pkg_abcd", "pkg_efgh"]) + def test_build_add_dependencies(self): + rule = cmake_builder.rules.build_add_dependencies( + target="iree_mlir_suites", deps=["pkg_abcd", "pkg_efgh"] + ) - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ add_dependencies(iree_mlir_suites pkg_abcd pkg_efgh ) - """)) + """ + ), + ) - def test_build_set(self): - rule = cmake_builder.rules.build_set(variable_name="_ABC", value="123") + def test_build_set(self): + rule = cmake_builder.rules.build_set(variable_name="_ABC", value="123") - self.assertEqual( - rule, - textwrap.dedent("""\ + self.assertEqual( + rule, + textwrap.dedent( + """\ set(_ABC 123 ) - """)) + """ + ), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_model_tests/cmake_generator.py b/build_tools/python/e2e_model_tests/cmake_generator.py index e890f442554d..5b0bb1a7eae5 100644 --- a/build_tools/python/e2e_model_tests/cmake_generator.py +++ b/build_tools/python/e2e_model_tests/cmake_generator.py @@ -14,52 +14,62 @@ def generate_rules( - module_generation_configs: List[iree_definitions.ModuleGenerationConfig] + module_generation_configs: List[iree_definitions.ModuleGenerationConfig], ) -> List[str]: - """Generates CMake rules for e2e model tests.""" + """Generates CMake rules for e2e model tests.""" - # ModelTestConfig uses (imported_model, compile_config (mapped from platform)) - # to define the required module. Collect module paths indexed by the pair. - all_module_path_map = {} - for gen_config in module_generation_configs: - module_path = iree_artifacts.get_module_dir_path( - gen_config) / iree_artifacts.MODULE_FILENAME - all_module_path_map[(gen_config.imported_model.composite_id, - gen_config.compile_config.id)] = module_path + # ModelTestConfig uses (imported_model, compile_config (mapped from platform)) + # to define the required module. Collect module paths indexed by the pair. + all_module_path_map = {} + for gen_config in module_generation_configs: + module_path = ( + iree_artifacts.get_module_dir_path(gen_config) + / iree_artifacts.MODULE_FILENAME + ) + all_module_path_map[ + (gen_config.imported_model.composite_id, gen_config.compile_config.id) + ] = module_path - cmake_rules = [] - for test_config in test_definitions.TEST_CONFIGS: - imported_model = test_config.imported_model - platform_module_map = {} - for platform in test_definitions.CMakePlatform: - if platform in test_config.unsupported_platforms: - continue + cmake_rules = [] + for test_config in test_definitions.TEST_CONFIGS: + imported_model = test_config.imported_model + platform_module_map = {} + for platform in test_definitions.CMakePlatform: + if platform in test_config.unsupported_platforms: + continue - compile_config = test_definitions.PLATFORM_COMPILE_CONFIG_MAP[platform] - module_path = all_module_path_map.get( - (imported_model.composite_id, compile_config.id)) - if module_path is None: - raise ValueError( - f"Module for {test_config.name} on {platform} not found.") - platform_module_map[platform.value] = module_path + compile_config = test_definitions.PLATFORM_COMPILE_CONFIG_MAP[platform] + module_path = all_module_path_map.get( + (imported_model.composite_id, compile_config.id) + ) + if module_path is None: + raise ValueError( + f"Module for {test_config.name} on {platform} not found." + ) + platform_module_map[platform.value] = module_path - # TODO(#11136): Currently the DRIVER is a separate field in the CMake rule ( - # and has effect on test labels). Rules should be generated in another way - # to avoid that. Generates the flags without the driver for now. - runner_args = iree_definitions.generate_run_flags( - imported_model=imported_model, - input_data=test_config.input_data, - module_execution_config=test_config.execution_config, - with_driver=False) + test_config.extra_test_flags - cmake_rule = cmake_builder.rules.build_iree_benchmark_suite_module_test( - target_name=test_config.name, - driver=test_config.execution_config.driver.value, - expected_output=test_config.expected_output, - platform_module_map=platform_module_map, - runner_args=runner_args, - xfail_platforms=[ - platform.value for platform in test_config.xfail_platforms - ]) - cmake_rules.append(cmake_rule) + # TODO(#11136): Currently the DRIVER is a separate field in the CMake rule ( + # and has effect on test labels). Rules should be generated in another way + # to avoid that. Generates the flags without the driver for now. + runner_args = ( + iree_definitions.generate_run_flags( + imported_model=imported_model, + input_data=test_config.input_data, + module_execution_config=test_config.execution_config, + with_driver=False, + ) + + test_config.extra_test_flags + ) + cmake_rule = cmake_builder.rules.build_iree_benchmark_suite_module_test( + target_name=test_config.name, + driver=test_config.execution_config.driver.value, + expected_output=test_config.expected_output, + platform_module_map=platform_module_map, + runner_args=runner_args, + xfail_platforms=[ + platform.value for platform in test_config.xfail_platforms + ], + ) + cmake_rules.append(cmake_rule) - return cmake_rules + return cmake_rules diff --git a/build_tools/python/e2e_model_tests/run_module_utils.py b/build_tools/python/e2e_model_tests/run_module_utils.py index 3842a11ddca7..eac9ad39c31c 100644 --- a/build_tools/python/e2e_model_tests/run_module_utils.py +++ b/build_tools/python/e2e_model_tests/run_module_utils.py @@ -12,16 +12,17 @@ def build_linux_wrapper_cmds_for_device_spec( - device_spec: common_definitions.DeviceSpec) -> List[str]: - """Builds the commands with tools to create the execution environment.""" + device_spec: common_definitions.DeviceSpec, +) -> List[str]: + """Builds the commands with tools to create the execution environment.""" - affinity_mask = None - for param in device_spec.device_parameters: - if param != device_parameters.ALL_CORES: - raise ValueError(f"Unsupported device parameter: {param}.") + affinity_mask = None + for param in device_spec.device_parameters: + if param != device_parameters.ALL_CORES: + raise ValueError(f"Unsupported device parameter: {param}.") - cmds = [] - if affinity_mask is not None: - cmds += ["taskset", affinity_mask] + cmds = [] + if affinity_mask is not None: + cmds += ["taskset", affinity_mask] - return cmds + return cmds diff --git a/build_tools/python/e2e_model_tests/run_module_utils_test.py b/build_tools/python/e2e_model_tests/run_module_utils_test.py index 2bbebf4a3979..f86ddacdab1b 100644 --- a/build_tools/python/e2e_model_tests/run_module_utils_test.py +++ b/build_tools/python/e2e_model_tests/run_module_utils_test.py @@ -12,21 +12,20 @@ class RunModuleUtilsTest(unittest.TestCase): + def test_build_linux_wrapper_cmds_for_device_spec(self): + device_spec = common_definitions.DeviceSpec.build( + id="abc", + device_name="test-device", + architecture=common_definitions.DeviceArchitecture.VMVX_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + device_parameters=[device_parameters.ALL_CORES], + tags=[], + ) - def test_build_linux_wrapper_cmds_for_device_spec(self): - device_spec = common_definitions.DeviceSpec.build( - id="abc", - device_name="test-device", - architecture=common_definitions.DeviceArchitecture.VMVX_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - device_parameters=[device_parameters.ALL_CORES], - tags=[]) + flags = run_module_utils.build_linux_wrapper_cmds_for_device_spec(device_spec) - flags = run_module_utils.build_linux_wrapper_cmds_for_device_spec( - device_spec) - - self.assertEqual(flags, []) + self.assertEqual(flags, []) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_model_tests/test_definitions.py b/build_tools/python/e2e_model_tests/test_definitions.py index 599de4f54ef4..def7dd3a8b63 100644 --- a/build_tools/python/e2e_model_tests/test_definitions.py +++ b/build_tools/python/e2e_model_tests/test_definitions.py @@ -12,50 +12,53 @@ from e2e_test_framework.definitions import common_definitions, iree_definitions from e2e_test_framework.models import tflite_models -from benchmark_suites.iree import (riscv_benchmarks, x86_64_benchmarks, - armv8_a_benchmarks, module_execution_configs) +from benchmark_suites.iree import ( + riscv_benchmarks, + x86_64_benchmarks, + armv8_a_benchmarks, + module_execution_configs, +) class CMakePlatform(enum.Enum): - """Enum of CMake system platform string.""" - ANDROID_ARMV8_A = "arm_64-Android" - LINUX_RISCV32 = "riscv_32-Linux" - LINUX_RISCV64 = "riscv_64-Linux" - LINUX_X86_64 = "x86_64-Linux" + """Enum of CMake system platform string.""" + + ANDROID_ARMV8_A = "arm_64-Android" + LINUX_RISCV32 = "riscv_32-Linux" + LINUX_RISCV64 = "riscv_64-Linux" + LINUX_X86_64 = "x86_64-Linux" # Compile config used for each CMake system platform. PLATFORM_COMPILE_CONFIG_MAP = { - CMakePlatform.ANDROID_ARMV8_A: - armv8_a_benchmarks.Android_ARMv8_A_Benchmarks.DEFAULT_COMPILE_CONFIG, - CMakePlatform.LINUX_RISCV32: - riscv_benchmarks.Linux_RV32_Benchmarks.DEFAULT_COMPILE_CONFIG, - CMakePlatform.LINUX_RISCV64: - riscv_benchmarks.Linux_RV64_Benchmarks.DEFAULT_COMPILE_CONFIG, - CMakePlatform.LINUX_X86_64: - x86_64_benchmarks.Linux_x86_64_Benchmarks.CASCADELAKE_COMPILE_CONFIG + CMakePlatform.ANDROID_ARMV8_A: armv8_a_benchmarks.Android_ARMv8_A_Benchmarks.DEFAULT_COMPILE_CONFIG, + CMakePlatform.LINUX_RISCV32: riscv_benchmarks.Linux_RV32_Benchmarks.DEFAULT_COMPILE_CONFIG, + CMakePlatform.LINUX_RISCV64: riscv_benchmarks.Linux_RV64_Benchmarks.DEFAULT_COMPILE_CONFIG, + CMakePlatform.LINUX_X86_64: x86_64_benchmarks.Linux_x86_64_Benchmarks.CASCADELAKE_COMPILE_CONFIG, } @dataclass(frozen=True) class ModelTestConfig(object): - """Defines an e2e model test to run by iree-run-module.""" - # Test name shown in the test rule. - name: str - imported_model: iree_definitions.ImportedModel - execution_config: iree_definitions.ModuleExecutionConfig + """Defines an e2e model test to run by iree-run-module.""" + + # Test name shown in the test rule. + name: str + imported_model: iree_definitions.ImportedModel + execution_config: iree_definitions.ModuleExecutionConfig - # Either a string literal or a file path. - expected_output: str - input_data: common_definitions.ModelInputData = common_definitions.ZEROS_MODEL_INPUT_DATA + # Either a string literal or a file path. + expected_output: str + input_data: common_definitions.ModelInputData = ( + common_definitions.ZEROS_MODEL_INPUT_DATA + ) - # Platforms to ignore this test. - unsupported_platforms: List[CMakePlatform] = dataclasses.field( - default_factory=list) - # Platforms to expect this test failed. - xfail_platforms: List[CMakePlatform] = dataclasses.field(default_factory=list) - # Extra flags for `iree-run-module`. - extra_test_flags: List[str] = dataclasses.field(default_factory=list) + # Platforms to ignore this test. + unsupported_platforms: List[CMakePlatform] = dataclasses.field(default_factory=list) + # Platforms to expect this test failed. + xfail_platforms: List[CMakePlatform] = dataclasses.field(default_factory=list) + # Extra flags for `iree-run-module`. + extra_test_flags: List[str] = dataclasses.field(default_factory=list) TEST_CONFIGS = [ @@ -63,40 +66,51 @@ class ModelTestConfig(object): ModelTestConfig( name="mobilenet_v1_fp32_correctness_test", imported_model=iree_definitions.ImportedModel.from_model( - tflite_models.MOBILENET_V1), + tflite_models.MOBILENET_V1 + ), execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG, expected_output="mobilenet_v1_fp32_expected_output.txt", unsupported_platforms=[ - CMakePlatform.LINUX_RISCV32, CMakePlatform.ANDROID_ARMV8_A - ]), + CMakePlatform.LINUX_RISCV32, + CMakePlatform.ANDROID_ARMV8_A, + ], + ), # efficientnet_int8_correctness_test ModelTestConfig( name="efficientnet_int8_correctness_test", imported_model=iree_definitions.ImportedModel.from_model( - tflite_models.EFFICIENTNET_INT8), + tflite_models.EFFICIENTNET_INT8 + ), execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG, expected_output="efficientnet_int8_expected_output.txt", unsupported_platforms=[ - CMakePlatform.ANDROID_ARMV8_A, CMakePlatform.LINUX_RISCV32, - CMakePlatform.LINUX_RISCV64 - ]), + CMakePlatform.ANDROID_ARMV8_A, + CMakePlatform.LINUX_RISCV32, + CMakePlatform.LINUX_RISCV64, + ], + ), # deeplab_v3_fp32_correctness_test ModelTestConfig( name="deeplab_v3_fp32_correctness_test", imported_model=iree_definitions.ImportedModel.from_model( - tflite_models.DEEPLABV3_FP32), + tflite_models.DEEPLABV3_FP32 + ), execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG, expected_output="deeplab_v3_fp32_input_0_expected_output.npy", extra_test_flags=["--expected_f32_threshold=0.001"], unsupported_platforms=[ - CMakePlatform.LINUX_RISCV32, CMakePlatform.LINUX_RISCV64 - ]), + CMakePlatform.LINUX_RISCV32, + CMakePlatform.LINUX_RISCV64, + ], + ), # person_detect_int8_correctness_test ModelTestConfig( name="person_detect_int8_correctness_test", imported_model=iree_definitions.ImportedModel.from_model( - tflite_models.PERSON_DETECT_INT8), + tflite_models.PERSON_DETECT_INT8 + ), execution_config=module_execution_configs.ELF_LOCAL_SYNC_CONFIG, expected_output="1x2xi8=[72 -72]", - unsupported_platforms=[CMakePlatform.ANDROID_ARMV8_A]) + unsupported_platforms=[CMakePlatform.ANDROID_ARMV8_A], + ), ] diff --git a/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator.py b/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator.py index 1610a24a6c90..3e060f180e7d 100644 --- a/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator.py +++ b/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator.py @@ -32,181 +32,204 @@ @dataclass(frozen=True) class IreeModelImportRule(object): - target_name: str - output_file_path: pathlib.PurePath - cmake_rules: List[str] + target_name: str + output_file_path: pathlib.PurePath + cmake_rules: List[str] @dataclass(frozen=True) class IreeModuleCompileRule(object): - target_name: str - output_module_path: pathlib.PurePath - cmake_rules: List[str] + target_name: str + output_module_path: pathlib.PurePath + cmake_rules: List[str] class IreeRuleBuilder(object): - """Builder to generate IREE CMake rules.""" - - _package_name: str - - def __init__(self, package_name: str): - self._package_name = package_name - - def build_model_import_rule( - self, source_model_rule: model_rule_generator.ModelRule, - imported_model: iree_definitions.ImportedModel, - output_file_path: pathlib.PurePath) -> IreeModelImportRule: - - model = imported_model.model - import_config = imported_model.import_config - if import_config.tool == iree_definitions.ImportTool.NONE: - if source_model_rule.file_path != output_file_path: - raise ValueError( - f"Separate path for MLIR model isn't supported yet: " - f"('{source_model_rule.file_path }' != '{output_file_path}')") - return IreeModelImportRule(target_name=source_model_rule.target_name, - output_file_path=output_file_path, - cmake_rules=[]) - - # Import target name: iree-imported-model- - target_name = f"iree-imported-model-{imported_model.composite_id}" - - import_flags = import_config.materialize_import_flags(model) - if import_config.tool == iree_definitions.ImportTool.TFLITE_IMPORTER: - cmake_rules = [ - cmake_builder.rules.build_iree_import_tflite_model( - target_path=self.build_target_path(target_name), - source=str(source_model_rule.file_path), - import_flags=import_flags, - output_mlir_file=str(output_file_path)) - ] - elif import_config.tool == iree_definitions.ImportTool.TF_IMPORTER: - cmake_rules = [ - cmake_builder.rules.build_iree_import_tf_model( - target_path=self.build_target_path(target_name), - source=str(source_model_rule.file_path), - import_flags=import_flags, - output_mlir_file=str(output_file_path)) - ] - else: - raise ValueError( - f"Unsupported import tool '{import_config.tool}' of the model '{model.id}'." - ) - - return IreeModelImportRule(target_name=target_name, - output_file_path=output_file_path, - cmake_rules=cmake_rules) - - def build_module_compile_rule( - self, model_import_rule: IreeModelImportRule, - module_generation_config: iree_definitions.ModuleGenerationConfig, - output_file_path: pathlib.PurePath) -> IreeModuleCompileRule: - - compile_flags = module_generation_config.materialize_compile_flags( - module_dir_path=output_file_path.parent) - - # Module target name: iree-module- - target_name = f"iree-module-{module_generation_config.composite_id}" - - cmake_rules = [ - cmake_builder.rules.build_iree_bytecode_module( + """Builder to generate IREE CMake rules.""" + + _package_name: str + + def __init__(self, package_name: str): + self._package_name = package_name + + def build_model_import_rule( + self, + source_model_rule: model_rule_generator.ModelRule, + imported_model: iree_definitions.ImportedModel, + output_file_path: pathlib.PurePath, + ) -> IreeModelImportRule: + model = imported_model.model + import_config = imported_model.import_config + if import_config.tool == iree_definitions.ImportTool.NONE: + if source_model_rule.file_path != output_file_path: + raise ValueError( + f"Separate path for MLIR model isn't supported yet: " + f"('{source_model_rule.file_path }' != '{output_file_path}')" + ) + return IreeModelImportRule( + target_name=source_model_rule.target_name, + output_file_path=output_file_path, + cmake_rules=[], + ) + + # Import target name: iree-imported-model- + target_name = f"iree-imported-model-{imported_model.composite_id}" + + import_flags = import_config.materialize_import_flags(model) + if import_config.tool == iree_definitions.ImportTool.TFLITE_IMPORTER: + cmake_rules = [ + cmake_builder.rules.build_iree_import_tflite_model( + target_path=self.build_target_path(target_name), + source=str(source_model_rule.file_path), + import_flags=import_flags, + output_mlir_file=str(output_file_path), + ) + ] + elif import_config.tool == iree_definitions.ImportTool.TF_IMPORTER: + cmake_rules = [ + cmake_builder.rules.build_iree_import_tf_model( + target_path=self.build_target_path(target_name), + source=str(source_model_rule.file_path), + import_flags=import_flags, + output_mlir_file=str(output_file_path), + ) + ] + else: + raise ValueError( + f"Unsupported import tool '{import_config.tool}' of the model '{model.id}'." + ) + + return IreeModelImportRule( target_name=target_name, - src=str(model_import_rule.output_file_path), - module_name=str(output_file_path), - flags=compile_flags, - friendly_name=str(module_generation_config)) - ] + output_file_path=output_file_path, + cmake_rules=cmake_rules, + ) + + def build_module_compile_rule( + self, + model_import_rule: IreeModelImportRule, + module_generation_config: iree_definitions.ModuleGenerationConfig, + output_file_path: pathlib.PurePath, + ) -> IreeModuleCompileRule: + compile_flags = module_generation_config.materialize_compile_flags( + module_dir_path=output_file_path.parent + ) + + # Module target name: iree-module- + target_name = f"iree-module-{module_generation_config.composite_id}" + + cmake_rules = [ + cmake_builder.rules.build_iree_bytecode_module( + target_name=target_name, + src=str(model_import_rule.output_file_path), + module_name=str(output_file_path), + flags=compile_flags, + friendly_name=str(module_generation_config), + ) + ] + + # TODO(#10155): Dump the compile flags from iree_bytecode_module into a flagfile. + + return IreeModuleCompileRule( + target_name=target_name, + output_module_path=output_file_path, + cmake_rules=cmake_rules, + ) - # TODO(#10155): Dump the compile flags from iree_bytecode_module into a flagfile. + def build_target_path(self, target_name: str): + """Returns the full target path by combining the package name and the target + name. + """ + return f"{self._package_name}_{target_name}" - return IreeModuleCompileRule(target_name=target_name, - output_module_path=output_file_path, - cmake_rules=cmake_rules) - def build_target_path(self, target_name: str): - """Returns the full target path by combining the package name and the target - name. +def generate_rules( + package_name: str, + root_path: pathlib.PurePath, + module_generation_configs: Sequence[iree_definitions.ModuleGenerationConfig], + model_rule_map: Dict[str, model_rule_generator.ModelRule], +) -> List[str]: + """Generates all rules to build IREE artifacts. + + Args: + package_name: CMake package name for rules. + root_path: path of the root artifact directory. + module_generation_configs: list of IREE module generation configs. + model_rule_map: map of generated model rules keyed by model id, it must + cover all model referenced in module_generation_configs. + Returns: + List of cmake rules. """ - return f"{self._package_name}_{target_name}" - -def generate_rules( - package_name: str, root_path: pathlib.PurePath, - module_generation_configs: Sequence[ - iree_definitions.ModuleGenerationConfig], - model_rule_map: Dict[str, model_rule_generator.ModelRule]) -> List[str]: - """Generates all rules to build IREE artifacts. - - Args: - package_name: CMake package name for rules. - root_path: path of the root artifact directory. - module_generation_configs: list of IREE module generation configs. - model_rule_map: map of generated model rules keyed by model id, it must - cover all model referenced in module_generation_configs. - Returns: - List of cmake rules. - """ - - rule_builder = IreeRuleBuilder(package_name=package_name) - - all_imported_models = dict( - (config.imported_model.composite_id, config.imported_model) - for config in module_generation_configs) - - cmake_rules = [] - model_import_rule_map = {} - for imported_model_id, imported_model in all_imported_models.items(): - model_rule = model_rule_map.get(imported_model.model.id) - if model_rule is None: - raise ValueError(f"Model rule not found for {imported_model.model.id}.") - - imported_model_path = iree_artifacts.get_imported_model_path( - imported_model=imported_model, root_path=root_path) - model_import_rule = rule_builder.build_model_import_rule( - source_model_rule=model_rule, - imported_model=imported_model, - output_file_path=imported_model_path) - model_import_rule_map[imported_model_id] = model_import_rule - cmake_rules.extend(model_import_rule.cmake_rules) - - cmake_target_names = collections.defaultdict(set) - for gen_config in module_generation_configs: - model_import_rule = model_import_rule_map[ - gen_config.imported_model.composite_id] - module_dir_path = iree_artifacts.get_module_dir_path( - module_generation_config=gen_config, root_path=root_path) - module_compile_rule = rule_builder.build_module_compile_rule( - model_import_rule=model_import_rule, - module_generation_config=gen_config, - output_file_path=module_dir_path / iree_artifacts.MODULE_FILENAME) - - is_compile_stats = (benchmark_tags.COMPILE_STATS - in gen_config.compile_config.tags) - if benchmark_tags.LARGE in gen_config.tags: - import_target = LARGE_BENCHMARK_IMPORT_MODELS_CMAKE_TARGET - if is_compile_stats: - suite_target = LARGE_E2E_COMPILE_STATS_SUITES_CMAKE_TARGET - else: - suite_target = LARGE_BENCHMARK_SUITES_CMAKE_TARGET - else: - import_target = BENCHMARK_IMPORT_MODELS_CMAKE_TARGET - if is_compile_stats: - suite_target = E2E_COMPILE_STATS_SUITES - else: - suite_target = BENCHMARK_SUITES_CMAKE_TARGET - - cmake_target_names[import_target].add(model_import_rule.target_name) - cmake_target_names[suite_target].add(module_compile_rule.target_name) - cmake_rules.extend(module_compile_rule.cmake_rules) - - for cmake_target, module_target_names in cmake_target_names.items(): - module_target_names = sorted(module_target_names) - cmake_rules.append( - cmake_builder.rules.build_add_dependencies( - target=cmake_target, - deps=[ - rule_builder.build_target_path(target_name) - for target_name in module_target_names - ])) - - return cmake_rules + rule_builder = IreeRuleBuilder(package_name=package_name) + + all_imported_models = dict( + (config.imported_model.composite_id, config.imported_model) + for config in module_generation_configs + ) + + cmake_rules = [] + model_import_rule_map = {} + for imported_model_id, imported_model in all_imported_models.items(): + model_rule = model_rule_map.get(imported_model.model.id) + if model_rule is None: + raise ValueError(f"Model rule not found for {imported_model.model.id}.") + + imported_model_path = iree_artifacts.get_imported_model_path( + imported_model=imported_model, root_path=root_path + ) + model_import_rule = rule_builder.build_model_import_rule( + source_model_rule=model_rule, + imported_model=imported_model, + output_file_path=imported_model_path, + ) + model_import_rule_map[imported_model_id] = model_import_rule + cmake_rules.extend(model_import_rule.cmake_rules) + + cmake_target_names = collections.defaultdict(set) + for gen_config in module_generation_configs: + model_import_rule = model_import_rule_map[ + gen_config.imported_model.composite_id + ] + module_dir_path = iree_artifacts.get_module_dir_path( + module_generation_config=gen_config, root_path=root_path + ) + module_compile_rule = rule_builder.build_module_compile_rule( + model_import_rule=model_import_rule, + module_generation_config=gen_config, + output_file_path=module_dir_path / iree_artifacts.MODULE_FILENAME, + ) + + is_compile_stats = ( + benchmark_tags.COMPILE_STATS in gen_config.compile_config.tags + ) + if benchmark_tags.LARGE in gen_config.tags: + import_target = LARGE_BENCHMARK_IMPORT_MODELS_CMAKE_TARGET + if is_compile_stats: + suite_target = LARGE_E2E_COMPILE_STATS_SUITES_CMAKE_TARGET + else: + suite_target = LARGE_BENCHMARK_SUITES_CMAKE_TARGET + else: + import_target = BENCHMARK_IMPORT_MODELS_CMAKE_TARGET + if is_compile_stats: + suite_target = E2E_COMPILE_STATS_SUITES + else: + suite_target = BENCHMARK_SUITES_CMAKE_TARGET + + cmake_target_names[import_target].add(model_import_rule.target_name) + cmake_target_names[suite_target].add(module_compile_rule.target_name) + cmake_rules.extend(module_compile_rule.cmake_rules) + + for cmake_target, module_target_names in cmake_target_names.items(): + module_target_names = sorted(module_target_names) + cmake_rules.append( + cmake_builder.rules.build_add_dependencies( + target=cmake_target, + deps=[ + rule_builder.build_target_path(target_name) + for target_name in module_target_names + ], + ) + ) + + return cmake_rules diff --git a/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator_test.py b/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator_test.py index ea9d23a43cf8..44fb1e88c596 100644 --- a/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator_test.py +++ b/build_tools/python/e2e_test_artifacts/cmake_generator/iree_rule_generator_test.py @@ -12,184 +12,206 @@ class IreeRuleBuilderTest(unittest.TestCase): - - def setUp(self): - self._builder = iree_rule_generator.IreeRuleBuilder( - package_name="${package}") - - def test_build_model_import_rule_tflite(self): - tflite_model = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - tflite_imported_model = iree_definitions.ImportedModel.from_model( - tflite_model) - model_rule = model_rule_generator.ModelRule( - target_name="model-1234", - file_path=pathlib.PurePath("root/models/x.tflite"), - cmake_rules=["abc"]) - output_file_path = pathlib.PurePath("root", "iree", tflite_model.id, - f"{tflite_model.name}.mlir") - - rule = self._builder.build_model_import_rule( - source_model_rule=model_rule, - imported_model=tflite_imported_model, - output_file_path=output_file_path) - - self.assertEqual( - rule.target_name, - f"iree-imported-model-{tflite_imported_model.composite_id}") - self.assertEqual(rule.output_file_path, output_file_path) - - def test_build_model_import_rule_linalg(self): - linalg_model = common_definitions.Model( - id="9012", - name="linalg_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url="https://example.com/xyz.mlir", - entry_function="main", - input_types=["3xf32"]) - linalg_imported_model = iree_definitions.ImportedModel.from_model( - linalg_model) - model_rule = model_rule_generator.ModelRule( - target_name="model-5678", - file_path=pathlib.PurePath("root/models/y.mlir"), - cmake_rules=["abc"]) - - rule = self._builder.build_model_import_rule( - source_model_rule=model_rule, - imported_model=linalg_imported_model, - output_file_path=pathlib.PurePath(model_rule.file_path)) - - self.assertEqual(rule.target_name, model_rule.target_name) - self.assertEqual(pathlib.PurePath(rule.output_file_path), - pathlib.PurePath(model_rule.file_path)) - - def test_build_module_compile_rule(self): - model = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - imported_model = iree_definitions.ImportedModel.from_model(model) - compile_config = iree_definitions.CompileConfig.build( - id="config_a", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model, compile_config=compile_config) - model_import_rule = iree_rule_generator.IreeModelImportRule( - target_name=f"iree-import-model-abcd", - output_file_path=pathlib.PurePath("root/iree/abcd/1234.mlir"), - cmake_rules=["abc"]) - output_file_path = pathlib.PurePath("root/iree/test_output") - - rule = self._builder.build_module_compile_rule( - model_import_rule=model_import_rule, - module_generation_config=gen_config, - output_file_path=output_file_path) - - self.assertEqual(rule.target_name, f"iree-module-{gen_config.composite_id}") - self.assertEqual(rule.output_module_path, output_file_path) - - def test_build_target_path(self): - builder = iree_rule_generator.IreeRuleBuilder(package_name="xyz") - - path = builder.build_target_path("target-abc") - - self.assertEqual(path, f"xyz_target-abc") + def setUp(self): + self._builder = iree_rule_generator.IreeRuleBuilder(package_name="${package}") + + def test_build_model_import_rule_tflite(self): + tflite_model = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + tflite_imported_model = iree_definitions.ImportedModel.from_model(tflite_model) + model_rule = model_rule_generator.ModelRule( + target_name="model-1234", + file_path=pathlib.PurePath("root/models/x.tflite"), + cmake_rules=["abc"], + ) + output_file_path = pathlib.PurePath( + "root", "iree", tflite_model.id, f"{tflite_model.name}.mlir" + ) + + rule = self._builder.build_model_import_rule( + source_model_rule=model_rule, + imported_model=tflite_imported_model, + output_file_path=output_file_path, + ) + + self.assertEqual( + rule.target_name, + f"iree-imported-model-{tflite_imported_model.composite_id}", + ) + self.assertEqual(rule.output_file_path, output_file_path) + + def test_build_model_import_rule_linalg(self): + linalg_model = common_definitions.Model( + id="9012", + name="linalg_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, + source_url="https://example.com/xyz.mlir", + entry_function="main", + input_types=["3xf32"], + ) + linalg_imported_model = iree_definitions.ImportedModel.from_model(linalg_model) + model_rule = model_rule_generator.ModelRule( + target_name="model-5678", + file_path=pathlib.PurePath("root/models/y.mlir"), + cmake_rules=["abc"], + ) + + rule = self._builder.build_model_import_rule( + source_model_rule=model_rule, + imported_model=linalg_imported_model, + output_file_path=pathlib.PurePath(model_rule.file_path), + ) + + self.assertEqual(rule.target_name, model_rule.target_name) + self.assertEqual( + pathlib.PurePath(rule.output_file_path), + pathlib.PurePath(model_rule.file_path), + ) + + def test_build_module_compile_rule(self): + model = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + imported_model = iree_definitions.ImportedModel.from_model(model) + compile_config = iree_definitions.CompileConfig.build( + id="config_a", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model, compile_config=compile_config + ) + model_import_rule = iree_rule_generator.IreeModelImportRule( + target_name=f"iree-import-model-abcd", + output_file_path=pathlib.PurePath("root/iree/abcd/1234.mlir"), + cmake_rules=["abc"], + ) + output_file_path = pathlib.PurePath("root/iree/test_output") + + rule = self._builder.build_module_compile_rule( + model_import_rule=model_import_rule, + module_generation_config=gen_config, + output_file_path=output_file_path, + ) + + self.assertEqual(rule.target_name, f"iree-module-{gen_config.composite_id}") + self.assertEqual(rule.output_module_path, output_file_path) + + def test_build_target_path(self): + builder = iree_rule_generator.IreeRuleBuilder(package_name="xyz") + + path = builder.build_target_path("target-abc") + + self.assertEqual(path, f"xyz_target-abc") class IreeGeneratorTest(unittest.TestCase): - - def test_generate_rules(self): - model_a = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - model_b = common_definitions.Model( - id="5678", - name="stablehlo_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - source_url="https://example.com/xyz_stablehlo.mlir", - entry_function="predict", - input_types=["2xf32"]) - imported_model_a = iree_definitions.ImportedModel.from_model(model_a) - imported_model_b = iree_definitions.ImportedModel.from_model(model_b) - compile_config_a = iree_definitions.CompileConfig.build( - id="config_a", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - compile_config_b = iree_definitions.CompileConfig.build( - id="config_b", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - RV64_GENERIC, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - gen_config_a = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_a, compile_config=compile_config_a) - gen_config_b = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_b, compile_config=compile_config_a) - gen_config_c = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_b, compile_config=compile_config_b) - model_rule_map = { - model_a.id: - model_rule_generator.ModelRule( + def test_generate_rules(self): + model_a = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + model_b = common_definitions.Model( + id="5678", + name="stablehlo_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="https://example.com/xyz_stablehlo.mlir", + entry_function="predict", + input_types=["2xf32"], + ) + imported_model_a = iree_definitions.ImportedModel.from_model(model_a) + imported_model_b = iree_definitions.ImportedModel.from_model(model_b) + compile_config_a = iree_definitions.CompileConfig.build( + id="config_a", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + compile_config_b = iree_definitions.CompileConfig.build( + id="config_b", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + gen_config_a = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_a, compile_config=compile_config_a + ) + gen_config_b = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_b, compile_config=compile_config_a + ) + gen_config_c = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_b, compile_config=compile_config_b + ) + model_rule_map = { + model_a.id: model_rule_generator.ModelRule( target_name=f"model-x", file_path=pathlib.PurePath("x.tflite"), - cmake_rules=["abc"]), - model_b.id: - model_rule_generator.ModelRule( + cmake_rules=["abc"], + ), + model_b.id: model_rule_generator.ModelRule( target_name=f"model-y", file_path=pathlib.PurePath("root/model_5678_stablehlo_m.mlir"), - cmake_rules=["efg"]), - } - - cmake_rules = iree_rule_generator.generate_rules( - package_name="${package}", - root_path=pathlib.PurePath("root"), - module_generation_configs=[gen_config_a, gen_config_b, gen_config_c], - model_rule_map=model_rule_map) - - concated_cmake_rules = "\n".join(cmake_rules) - self.assertRegex(concated_cmake_rules, - f"iree-imported-model-{imported_model_a.composite_id}") - self.assertRegex(concated_cmake_rules, - f"iree-module-{gen_config_a.composite_id}") - self.assertRegex(concated_cmake_rules, - f"iree-module-{gen_config_b.composite_id}") - self.assertRegex(concated_cmake_rules, - f"iree-module-{gen_config_c.composite_id}") + cmake_rules=["efg"], + ), + } + + cmake_rules = iree_rule_generator.generate_rules( + package_name="${package}", + root_path=pathlib.PurePath("root"), + module_generation_configs=[gen_config_a, gen_config_b, gen_config_c], + model_rule_map=model_rule_map, + ) + + concated_cmake_rules = "\n".join(cmake_rules) + self.assertRegex( + concated_cmake_rules, f"iree-imported-model-{imported_model_a.composite_id}" + ) + self.assertRegex( + concated_cmake_rules, f"iree-module-{gen_config_a.composite_id}" + ) + self.assertRegex( + concated_cmake_rules, f"iree-module-{gen_config_b.composite_id}" + ) + self.assertRegex( + concated_cmake_rules, f"iree-module-{gen_config_c.composite_id}" + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator.py b/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator.py index 08e429784ddd..aedd18c4670d 100644 --- a/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator.py +++ b/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator.py @@ -17,37 +17,37 @@ @dataclass class ModelRule(object): - target_name: str - file_path: pathlib.PurePath - cmake_rules: List[str] + target_name: str + file_path: pathlib.PurePath + cmake_rules: List[str] def generate_model_rule_map( - root_path: pathlib.PurePath, - models: Iterable[common_definitions.Model]) -> Dict[str, ModelRule]: - """Returns the model rules keyed by model id in an ordered map.""" - - model_rules = {} - for model in models: - # Model target: -model- - target_name = f"model-{model.id}" - model_path = model_artifacts.get_model_path(model=model, - root_path=root_path) - - model_url = urllib.parse.urlparse(model.source_url) - if model_url.scheme == "https": - cmake_rules = [ - cmake_builder.rules.build_iree_fetch_artifact( - target_name=target_name, - source_url=model.source_url, - output=str(model_path), - unpack=True) - ] - else: - raise ValueError("Unsupported model url: {model.source_url}.") - - model_rules[model.id] = ModelRule(target_name=target_name, - file_path=model_path, - cmake_rules=cmake_rules) - - return model_rules + root_path: pathlib.PurePath, models: Iterable[common_definitions.Model] +) -> Dict[str, ModelRule]: + """Returns the model rules keyed by model id in an ordered map.""" + + model_rules = {} + for model in models: + # Model target: -model- + target_name = f"model-{model.id}" + model_path = model_artifacts.get_model_path(model=model, root_path=root_path) + + model_url = urllib.parse.urlparse(model.source_url) + if model_url.scheme == "https": + cmake_rules = [ + cmake_builder.rules.build_iree_fetch_artifact( + target_name=target_name, + source_url=model.source_url, + output=str(model_path), + unpack=True, + ) + ] + else: + raise ValueError("Unsupported model url: {model.source_url}.") + + model_rules[model.id] = ModelRule( + target_name=target_name, file_path=model_path, cmake_rules=cmake_rules + ) + + return model_rules diff --git a/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator_test.py b/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator_test.py index 71958efc1a28..1d43a723357f 100644 --- a/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator_test.py +++ b/build_tools/python/e2e_test_artifacts/cmake_generator/model_rule_generator_test.py @@ -13,39 +13,43 @@ class CommonGeneratorsTest(unittest.TestCase): - - def test_generate_model_rule_map(self): - model_a = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - model_b = common_definitions.Model( - id="5678", - name="tf_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - source_url="https://example.com/xyz_mlir", - entry_function="predict", - input_types=["2xf32"]) - root_path = pathlib.PurePath("model_root") - - rule_map = model_rule_generator.generate_model_rule_map( - root_path=root_path, models=[model_a, model_b]) - - self.assertEqual(list(rule_map.keys()), [model_a.id, model_b.id]) - self.assertEqual(rule_map[model_a.id].target_name, f"model-{model_a.id}") - self.assertEqual( - rule_map[model_a.id].file_path, - model_artifacts.get_model_path(model=model_a, root_path=root_path)) - self.assertEqual(rule_map[model_b.id].target_name, f"model-{model_b.id}") - self.assertEqual( - rule_map[model_b.id].file_path, - model_artifacts.get_model_path(model=model_b, root_path=root_path)) + def test_generate_model_rule_map(self): + model_a = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + model_b = common_definitions.Model( + id="5678", + name="tf_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="https://example.com/xyz_mlir", + entry_function="predict", + input_types=["2xf32"], + ) + root_path = pathlib.PurePath("model_root") + + rule_map = model_rule_generator.generate_model_rule_map( + root_path=root_path, models=[model_a, model_b] + ) + + self.assertEqual(list(rule_map.keys()), [model_a.id, model_b.id]) + self.assertEqual(rule_map[model_a.id].target_name, f"model-{model_a.id}") + self.assertEqual( + rule_map[model_a.id].file_path, + model_artifacts.get_model_path(model=model_a, root_path=root_path), + ) + self.assertEqual(rule_map[model_b.id].target_name, f"model-{model_b.id}") + self.assertEqual( + rule_map[model_b.id].file_path, + model_artifacts.get_model_path(model=model_b, root_path=root_path), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_artifacts/iree_artifacts.py b/build_tools/python/e2e_test_artifacts/iree_artifacts.py index 939c5cd25c0a..2d3cdc5ef01c 100644 --- a/build_tools/python/e2e_test_artifacts/iree_artifacts.py +++ b/build_tools/python/e2e_test_artifacts/iree_artifacts.py @@ -17,64 +17,65 @@ def _get_model_prefix(imported_model: iree_definitions.ImportedModel) -> str: - """Returns the model prefix for IREE artifacts. The common prefix helps group - artifacts from the same model together for easier navigation. - """ - model = imported_model.model - # IREE model prefix: _ - return f"{IREE_ARTIFACT_PREFIX}_{model.name}" + """Returns the model prefix for IREE artifacts. The common prefix helps group + artifacts from the same model together for easier navigation. + """ + model = imported_model.model + # IREE model prefix: _ + return f"{IREE_ARTIFACT_PREFIX}_{model.name}" def get_imported_model_path( imported_model: iree_definitions.ImportedModel, - root_path: pathlib.PurePath = pathlib.PurePath() + root_path: pathlib.PurePath = pathlib.PurePath(), ) -> pathlib.PurePath: - """Returns the path of an IREE imported MLIR model. If the source model is - in MLIR format, returns the path of source model. - - Args: - imported_model: IREE model importing config. - root_path: path of the root artifact directory, on which the returned path - will base. - Returns: - Path of the imported model file. - """ - model = imported_model.model - if model.source_type in [ - common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - ]: - # Uses the MLIR model directly. - return model_artifacts.get_model_path(model=model, root_path=root_path) + """Returns the path of an IREE imported MLIR model. If the source model is + in MLIR format, returns the path of source model. - model_prefix = _get_model_prefix(imported_model) - # Imported model path: /_.mlir - return (root_path / f"{model_prefix}_{imported_model.composite_id}.mlir") + Args: + imported_model: IREE model importing config. + root_path: path of the root artifact directory, on which the returned path + will base. + Returns: + Path of the imported model file. + """ + model = imported_model.model + if model.source_type in [ + common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, + common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + ]: + # Uses the MLIR model directly. + return model_artifacts.get_model_path(model=model, root_path=root_path) + + model_prefix = _get_model_prefix(imported_model) + # Imported model path: /_.mlir + return root_path / f"{model_prefix}_{imported_model.composite_id}.mlir" def get_module_dir_path( module_generation_config: iree_definitions.ModuleGenerationConfig, - root_path: pathlib.PurePath = pathlib.PurePath() + root_path: pathlib.PurePath = pathlib.PurePath(), ) -> pathlib.PurePath: - """Returns the path of an IREE module directory, which contains the compiled - module and related flag files. - - Args: - module_generation_config: IREE module generation config. - root_path: path of the root artifact directory, on which the returned path - will base. - Returns: - Path of the module directory. - """ - model_prefix = _get_model_prefix(module_generation_config.imported_model) - # Module dir path: /_module_ - return (root_path / - f"{model_prefix}_module_{module_generation_config.composite_id}") + """Returns the path of an IREE module directory, which contains the compiled + module and related flag files. + + Args: + module_generation_config: IREE module generation config. + root_path: path of the root artifact directory, on which the returned path + will base. + Returns: + Path of the module directory. + """ + model_prefix = _get_model_prefix(module_generation_config.imported_model) + # Module dir path: /_module_ + return root_path / f"{model_prefix}_module_{module_generation_config.composite_id}" def get_dependent_model_map( - module_generation_configs: Iterable[iree_definitions.ModuleGenerationConfig] + module_generation_configs: Iterable[iree_definitions.ModuleGenerationConfig], ) -> Dict[str, common_definitions.Model]: - """Returns an ordered map of the dependent models keyed by model id.""" - return dict((config.imported_model.model.id, config.imported_model.model) - for config in module_generation_configs) + """Returns an ordered map of the dependent models keyed by model id.""" + return dict( + (config.imported_model.model.id, config.imported_model.model) + for config in module_generation_configs + ) diff --git a/build_tools/python/e2e_test_artifacts/iree_artifacts_test.py b/build_tools/python/e2e_test_artifacts/iree_artifacts_test.py index 0ac6ec5c72d2..cc254acd15bd 100644 --- a/build_tools/python/e2e_test_artifacts/iree_artifacts_test.py +++ b/build_tools/python/e2e_test_artifacts/iree_artifacts_test.py @@ -12,126 +12,146 @@ class IreeArtifactsTest(unittest.TestCase): - - def test_get_imported_model_path(self): - model = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - imported_model = iree_definitions.ImportedModel.from_model(model) - root_path = pathlib.PurePath("root") - - path = iree_artifacts.get_imported_model_path(imported_model=imported_model, - root_path=root_path) - - self.assertEqual( - path, root_path / f"{iree_artifacts.IREE_ARTIFACT_PREFIX}_{model.name}_" - f"{imported_model.composite_id}.mlir") - - def test_get_imported_model_path_with_mlir_model(self): - model = common_definitions.Model( - id="9012", - name="linalg_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url="https://example.com/xyz.mlir", - entry_function="main", - input_types=["3xf32"]) - imported_model = iree_definitions.ImportedModel.from_model(model) - root_path = pathlib.PurePath("root") - - path = iree_artifacts.get_imported_model_path(imported_model=imported_model, - root_path=root_path) - - self.assertEqual( - path, model_artifacts.get_model_path(model=model, root_path=root_path)) - - def test_get_module_dir_path(self): - model = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - imported_model = iree_definitions.ImportedModel.from_model(model) - compile_config = iree_definitions.CompileConfig.build( - id="config_a", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model, compile_config=compile_config) - root_path = pathlib.PurePath("root") - - path = iree_artifacts.get_module_dir_path( - module_generation_config=gen_config, root_path=root_path) - - self.assertEqual( - path, root_path / f"{iree_artifacts.IREE_ARTIFACT_PREFIX}_{model.name}_" - f"module_{gen_config.composite_id}") - - def test_get_dependent_model_map(self): - model_a = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - model_b = common_definitions.Model( - id="9012", - name="linalg_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url="https://example.com/xyz.mlir", - entry_function="main", - input_types=["3xf32"]) - imported_model_a = iree_definitions.ImportedModel.from_model(model_a) - imported_model_b = iree_definitions.ImportedModel.from_model(model_b) - compile_config_a = iree_definitions.CompileConfig.build( - id="config_a", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - X86_64_CASCADELAKE, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - compile_config_b = iree_definitions.CompileConfig.build( - id="config_b", - tags=["defaults"], - compile_targets=[ - iree_definitions.CompileTarget( - target_architecture=common_definitions.DeviceArchitecture. - RV64_GENERIC, - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - ]) - gen_config_a = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_a, compile_config=compile_config_a) - gen_config_b = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_b, compile_config=compile_config_a) - gen_config_c = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model_b, compile_config=compile_config_b) - - models = iree_artifacts.get_dependent_model_map( - module_generation_configs=[gen_config_a, gen_config_b, gen_config_c]) - - self.assertEqual(models, {model_a.id: model_a, model_b.id: model_b}) + def test_get_imported_model_path(self): + model = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + imported_model = iree_definitions.ImportedModel.from_model(model) + root_path = pathlib.PurePath("root") + + path = iree_artifacts.get_imported_model_path( + imported_model=imported_model, root_path=root_path + ) + + self.assertEqual( + path, + root_path / f"{iree_artifacts.IREE_ARTIFACT_PREFIX}_{model.name}_" + f"{imported_model.composite_id}.mlir", + ) + + def test_get_imported_model_path_with_mlir_model(self): + model = common_definitions.Model( + id="9012", + name="linalg_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, + source_url="https://example.com/xyz.mlir", + entry_function="main", + input_types=["3xf32"], + ) + imported_model = iree_definitions.ImportedModel.from_model(model) + root_path = pathlib.PurePath("root") + + path = iree_artifacts.get_imported_model_path( + imported_model=imported_model, root_path=root_path + ) + + self.assertEqual( + path, model_artifacts.get_model_path(model=model, root_path=root_path) + ) + + def test_get_module_dir_path(self): + model = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + imported_model = iree_definitions.ImportedModel.from_model(model) + compile_config = iree_definitions.CompileConfig.build( + id="config_a", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model, compile_config=compile_config + ) + root_path = pathlib.PurePath("root") + + path = iree_artifacts.get_module_dir_path( + module_generation_config=gen_config, root_path=root_path + ) + + self.assertEqual( + path, + root_path / f"{iree_artifacts.IREE_ARTIFACT_PREFIX}_{model.name}_" + f"module_{gen_config.composite_id}", + ) + + def test_get_dependent_model_map(self): + model_a = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + model_b = common_definitions.Model( + id="9012", + name="linalg_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, + source_url="https://example.com/xyz.mlir", + entry_function="main", + input_types=["3xf32"], + ) + imported_model_a = iree_definitions.ImportedModel.from_model(model_a) + imported_model_b = iree_definitions.ImportedModel.from_model(model_b) + compile_config_a = iree_definitions.CompileConfig.build( + id="config_a", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + compile_config_b = iree_definitions.CompileConfig.build( + id="config_b", + tags=["defaults"], + compile_targets=[ + iree_definitions.CompileTarget( + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + ], + ) + gen_config_a = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_a, compile_config=compile_config_a + ) + gen_config_b = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_b, compile_config=compile_config_a + ) + gen_config_c = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model_b, compile_config=compile_config_b + ) + + models = iree_artifacts.get_dependent_model_map( + module_generation_configs=[gen_config_a, gen_config_b, gen_config_c] + ) + + self.assertEqual(models, {model_a.id: model_a, model_b.id: model_b}) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_artifacts/model_artifacts.py b/build_tools/python/e2e_test_artifacts/model_artifacts.py index 686646154fa9..cbbaa4eecf59 100644 --- a/build_tools/python/e2e_test_artifacts/model_artifacts.py +++ b/build_tools/python/e2e_test_artifacts/model_artifacts.py @@ -16,25 +16,23 @@ def get_model_path( - model: common_definitions.Model, - root_path: pathlib.PurePath = pathlib.PurePath() + model: common_definitions.Model, root_path: pathlib.PurePath = pathlib.PurePath() ) -> pathlib.PurePath: - """Returns the path of an model artifact file or directory. - - Args: - model: source model. - root_path: path of the root artifact directory, on which the returned path - will base. - Returns: - Path of the model artifact. - """ - model_url = urllib.parse.urlparse(model.source_url) - # Drop the archive extensions. - file_exts = pathlib.PurePath(model_url.path).suffixes - while len(file_exts) > 0 and file_exts[-1] in ARCHIVE_FILE_EXTENSIONS: - file_exts.pop() - model_ext = "".join(file_exts) + """Returns the path of an model artifact file or directory. - # Model path: /__ - return (root_path / - f"{MODEL_ARTIFACT_PREFIX}_{model.id}_{model.name}{model_ext}") + Args: + model: source model. + root_path: path of the root artifact directory, on which the returned path + will base. + Returns: + Path of the model artifact. + """ + model_url = urllib.parse.urlparse(model.source_url) + # Drop the archive extensions. + file_exts = pathlib.PurePath(model_url.path).suffixes + while len(file_exts) > 0 and file_exts[-1] in ARCHIVE_FILE_EXTENSIONS: + file_exts.pop() + model_ext = "".join(file_exts) + + # Model path: /__ + return root_path / f"{MODEL_ARTIFACT_PREFIX}_{model.id}_{model.name}{model_ext}" diff --git a/build_tools/python/e2e_test_artifacts/model_artifacts_test.py b/build_tools/python/e2e_test_artifacts/model_artifacts_test.py index 796ed336d2da..3103233762c3 100644 --- a/build_tools/python/e2e_test_artifacts/model_artifacts_test.py +++ b/build_tools/python/e2e_test_artifacts/model_artifacts_test.py @@ -12,44 +12,46 @@ class ModelArtifactsTest(unittest.TestCase): - - def test_get_model_path_with_tflite_model(self): - tflite_model = common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"]) - root_path = pathlib.PurePath("root") - - path = model_artifacts.get_model_path(model=tflite_model, - root_path=root_path) - - self.assertEqual( - path, root_path / - f"{model_artifacts.MODEL_ARTIFACT_PREFIX}_{tflite_model.id}_{tflite_model.name}.tflite" - ) - - def test_get_model_path_with_tf_model(self): - tf_model = common_definitions.Model( - id="5678", - name="tf_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - source_url="https://example.com/xyz_mlir", - entry_function="predict", - input_types=["2xf32"]) - root_path = pathlib.PurePath("root") - - path = model_artifacts.get_model_path(model=tf_model, root_path=root_path) - - self.assertEqual( - path, root_path / - f"{model_artifacts.MODEL_ARTIFACT_PREFIX}_{tf_model.id}_{tf_model.name}" - ) + def test_get_model_path_with_tflite_model(self): + tflite_model = common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + root_path = pathlib.PurePath("root") + + path = model_artifacts.get_model_path(model=tflite_model, root_path=root_path) + + self.assertEqual( + path, + root_path + / f"{model_artifacts.MODEL_ARTIFACT_PREFIX}_{tflite_model.id}_{tflite_model.name}.tflite", + ) + + def test_get_model_path_with_tf_model(self): + tf_model = common_definitions.Model( + id="5678", + name="tf_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="https://example.com/xyz_mlir", + entry_function="predict", + input_types=["2xf32"], + ) + root_path = pathlib.PurePath("root") + + path = model_artifacts.get_model_path(model=tf_model, root_path=root_path) + + self.assertEqual( + path, + root_path + / f"{model_artifacts.MODEL_ARTIFACT_PREFIX}_{tf_model.id}_{tf_model.name}", + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/definitions/common_definitions.py b/build_tools/python/e2e_test_framework/definitions/common_definitions.py index 87d10094cb74..f74aef26d701 100644 --- a/build_tools/python/e2e_test_framework/definitions/common_definitions.py +++ b/build_tools/python/e2e_test_framework/definitions/common_definitions.py @@ -13,203 +13,217 @@ class ArchitectureType(Enum): - """Type of architecture.""" - CPU = "cpu" - GPU = "gpu" + """Type of architecture.""" + + CPU = "cpu" + GPU = "gpu" @dataclass(frozen=True) class _ArchitectureInfo(object): - """Architecture information.""" - type: ArchitectureType - architecture: str - microarchitecture: str = "" - vendor: str = "" + """Architecture information.""" + + type: ArchitectureType + architecture: str + microarchitecture: str = "" + vendor: str = "" class DeviceArchitecture(_ArchitectureInfo, Enum): - """Predefined architecture/microarchitecture.""" + """Predefined architecture/microarchitecture.""" - # VMVX virtual machine - VMVX_GENERIC = (ArchitectureType.CPU, "vmvx", "generic") + # VMVX virtual machine + VMVX_GENERIC = (ArchitectureType.CPU, "vmvx", "generic") - # x86_64 CPUs - X86_64_CASCADELAKE = (ArchitectureType.CPU, "x86_64", "cascadelake") + # x86_64 CPUs + X86_64_CASCADELAKE = (ArchitectureType.CPU, "x86_64", "cascadelake") - # ARM CPUs - ARMV8_2_A_GENERIC = (ArchitectureType.CPU, "armv8.2-a", "generic") - ARMV9_A_GENERIC = (ArchitectureType.CPU, "armv9-a", "generic") + # ARM CPUs + ARMV8_2_A_GENERIC = (ArchitectureType.CPU, "armv8.2-a", "generic") + ARMV9_A_GENERIC = (ArchitectureType.CPU, "armv9-a", "generic") - # RISC-V CPUs - RV64_GENERIC = (ArchitectureType.CPU, "riscv_64", "generic") - RV32_GENERIC = (ArchitectureType.CPU, "riscv_32", "generic") + # RISC-V CPUs + RV64_GENERIC = (ArchitectureType.CPU, "riscv_64", "generic") + RV32_GENERIC = (ArchitectureType.CPU, "riscv_32", "generic") - # Vulkan GPUs - QUALCOMM_ADRENO = (ArchitectureType.GPU, "adreno", "", "qualcomm") - ARM_VALHALL = (ArchitectureType.GPU, "valhall", "", "arm") - NVIDIA_AMPERE = (ArchitectureType.GPU, "ampere", "", "nvidia") - NVIDIA_PASCAL = (ArchitectureType.GPU, "pascal", "", "nvidia") + # Vulkan GPUs + QUALCOMM_ADRENO = (ArchitectureType.GPU, "adreno", "", "qualcomm") + ARM_VALHALL = (ArchitectureType.GPU, "valhall", "", "arm") + NVIDIA_AMPERE = (ArchitectureType.GPU, "ampere", "", "nvidia") + NVIDIA_PASCAL = (ArchitectureType.GPU, "pascal", "", "nvidia") - # CUDA GPUs - CUDA_SM70 = (ArchitectureType.GPU, "cuda", "sm_70") - CUDA_SM80 = (ArchitectureType.GPU, "cuda", "sm_80") + # CUDA GPUs + CUDA_SM70 = (ArchitectureType.GPU, "cuda", "sm_70") + CUDA_SM80 = (ArchitectureType.GPU, "cuda", "sm_80") - # Starting from 3.11, enum members are defined before the subclasses (don't - # follow MRO, see https://docs.python.org/3/whatsnew/3.11.html#enum). - # Therefore __str__ is defined here instead of in _ArchitectureInfo to - # override the default one. - def __str__(self): - parts = [self.vendor, self.architecture, self.microarchitecture] - return "-".join(part for part in parts if part != "") + # Starting from 3.11, enum members are defined before the subclasses (don't + # follow MRO, see https://docs.python.org/3/whatsnew/3.11.html#enum). + # Therefore __str__ is defined here instead of in _ArchitectureInfo to + # override the default one. + def __str__(self): + parts = [self.vendor, self.architecture, self.microarchitecture] + return "-".join(part for part in parts if part != "") @dataclass(frozen=True) class _HostEnvironmentInfo(object): - """Environment information of a host. - - The definitions and terms here matches the macros in - `runtime/src/iree/base/target_platform.h`. - - Note that this is the environment where the runtime "runs". For example: - ``` - { - "platform": "linux", - "architecture": "x86_64" - } - ``` - means the runtime will run on a Linux x86_64 host. The runtime might dispatch - the workloads on GPU or it can be a VM to run workloads compiled in another - ISA, but those are irrelevant to the information here. - """ - platform: str - architecture: str + """Environment information of a host. + + The definitions and terms here matches the macros in + `runtime/src/iree/base/target_platform.h`. + + Note that this is the environment where the runtime "runs". For example: + ``` + { + "platform": "linux", + "architecture": "x86_64" + } + ``` + means the runtime will run on a Linux x86_64 host. The runtime might dispatch + the workloads on GPU or it can be a VM to run workloads compiled in another + ISA, but those are irrelevant to the information here. + """ + + platform: str + architecture: str class HostEnvironment(_HostEnvironmentInfo, Enum): - """Predefined host environment.""" + """Predefined host environment.""" - LINUX_X86_64 = ("linux", "x86_64") - ANDROID_ARMV8_2_A = ("android", "armv8.2-a") + LINUX_X86_64 = ("linux", "x86_64") + ANDROID_ARMV8_2_A = ("android", "armv8.2-a") class ModelSourceType(Enum): - """Type of model source.""" - # Exported Linalg MLIR file. - EXPORTED_LINALG_MLIR = "exported_linalg_mlir" - # Exported Stable HLO file. - EXPORTED_STABLEHLO_MLIR = "exported_stablehlo_mlir" - # Exported TFLite model file. - EXPORTED_TFLITE = "exported_tflite" + """Type of model source.""" + + # Exported Linalg MLIR file. + EXPORTED_LINALG_MLIR = "exported_linalg_mlir" + # Exported Stable HLO file. + EXPORTED_STABLEHLO_MLIR = "exported_stablehlo_mlir" + # Exported TFLite model file. + EXPORTED_TFLITE = "exported_tflite" class InputDataFormat(Enum): - """Model input data format.""" - ZEROS = "zeros" - NUMPY_NPY = "numpy_npy" + """Model input data format.""" + + ZEROS = "zeros" + NUMPY_NPY = "numpy_npy" @serialization.serializable(type_key="device_specs") @dataclass(frozen=True) class DeviceSpec(object): - """Benchmark device specification.""" - id: str - - # Unique name of the device spec. - name: str - - # Device name. E.g., Pixel-6. - device_name: str - - # Tags to describe the device spec. - tags: List[str] - - # Host environment where the IREE runtime is running. For CPU device type, - # this is usually the same as the device that workloads are dispatched to. - # With a separate device, such as a GPU, however, the runtime and dispatched - # workloads will run on different platforms. - host_environment: HostEnvironment - - # Architecture of the target device. - architecture: DeviceArchitecture - - # Device-specific parameters. E.g., 2-big-cores, 4-little-cores. - # This is for modeling the spec of a heterogeneous processor. Depending on - # which cores you run, the device has a different spec. Benchmark machines use - # these parameters to set up the devices. E.g. set CPU mask. - device_parameters: List[str] = dataclasses.field(default_factory=list) - - def __str__(self): - return self.name - - @classmethod - def build(cls, - id: str, - device_name: str, - tags: Sequence[str], - host_environment: HostEnvironment, - architecture: DeviceArchitecture, - device_parameters: Optional[Sequence[str]] = None): - tag_part = ",".join(tags) - # Format: [,...] - name = f"{device_name}[{tag_part}]" - device_parameters = device_parameters or [] - return cls(id=id, - name=name, - tags=list(tags), - device_name=device_name, - host_environment=host_environment, - architecture=architecture, - device_parameters=list(device_parameters)) + """Benchmark device specification.""" + + id: str + + # Unique name of the device spec. + name: str + + # Device name. E.g., Pixel-6. + device_name: str + + # Tags to describe the device spec. + tags: List[str] + + # Host environment where the IREE runtime is running. For CPU device type, + # this is usually the same as the device that workloads are dispatched to. + # With a separate device, such as a GPU, however, the runtime and dispatched + # workloads will run on different platforms. + host_environment: HostEnvironment + + # Architecture of the target device. + architecture: DeviceArchitecture + + # Device-specific parameters. E.g., 2-big-cores, 4-little-cores. + # This is for modeling the spec of a heterogeneous processor. Depending on + # which cores you run, the device has a different spec. Benchmark machines use + # these parameters to set up the devices. E.g. set CPU mask. + device_parameters: List[str] = dataclasses.field(default_factory=list) + + def __str__(self): + return self.name + + @classmethod + def build( + cls, + id: str, + device_name: str, + tags: Sequence[str], + host_environment: HostEnvironment, + architecture: DeviceArchitecture, + device_parameters: Optional[Sequence[str]] = None, + ): + tag_part = ",".join(tags) + # Format: [,...] + name = f"{device_name}[{tag_part}]" + device_parameters = device_parameters or [] + return cls( + id=id, + name=name, + tags=list(tags), + device_name=device_name, + host_environment=host_environment, + architecture=architecture, + device_parameters=list(device_parameters), + ) @serialization.serializable(type_key="models") @dataclass(frozen=True) class Model(object): - """Model to be benchmarked.""" - id: str - # Friendly unique name. - name: str - # Tags that describe the model characteristics. - tags: List[str] - source_type: ModelSourceType - source_url: str - entry_function: str - # Input types. E.g., ["100x100xf32", "200x200x5xf32"]. - input_types: List[str] - - def __str__(self): - return self.name + """Model to be benchmarked.""" + + id: str + # Friendly unique name. + name: str + # Tags that describe the model characteristics. + tags: List[str] + source_type: ModelSourceType + source_url: str + entry_function: str + # Input types. E.g., ["100x100xf32", "200x200x5xf32"]. + input_types: List[str] + + def __str__(self): + return self.name @serialization.serializable(type_key="model_input_data") @dataclass(frozen=True) class ModelInputData(object): - """Input data to benchmark the model.""" - id: str - # Associated model. - model_id: str - # Friendly name. - name: str - # Tags that describe the data characteristics. - tags: List[str] - data_format: InputDataFormat - source_url: str + """Input data to benchmark the model.""" + + id: str + # Associated model. + model_id: str + # Friendly name. + name: str + # Tags that describe the data characteristics. + tags: List[str] + data_format: InputDataFormat + source_url: str - def __str__(self): - return self.name + def __str__(self): + return self.name # All-zeros dummy input data. Runners will generate the zeros input with proper # shapes. -ZEROS_MODEL_INPUT_DATA = ModelInputData(id=unique_ids.MODEL_INPUT_DATA_ZEROS, - model_id="", - name="zeros", - tags=[], - data_format=InputDataFormat.ZEROS, - source_url="") +ZEROS_MODEL_INPUT_DATA = ModelInputData( + id=unique_ids.MODEL_INPUT_DATA_ZEROS, + model_id="", + name="zeros", + tags=[], + data_format=InputDataFormat.ZEROS, + source_url="", +) @dataclass(frozen=True) class CpuBenchmarkConfig(object): - model: Model - threads: List[int] + model: Model + threads: List[int] diff --git a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py index a10ffa4029e3..52301647eb3e 100644 --- a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py +++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py @@ -16,211 +16,229 @@ class TargetBackend(Enum): - """IREE target backend.""" - LLVM_CPU = "llvm-cpu" - CUDA = "cuda" - ROCM = "rocm" - VMVX = "vmvx" - METAL_SPIRV = "metal-spirv" - VULKAN_SPIRV = "vulkan-spirv" + """IREE target backend.""" + + LLVM_CPU = "llvm-cpu" + CUDA = "cuda" + ROCM = "rocm" + VMVX = "vmvx" + METAL_SPIRV = "metal-spirv" + VULKAN_SPIRV = "vulkan-spirv" class TargetABI(Enum): - VMVX = "vmvx" - LINUX_GNU = "linux-gnu" - LINUX_ANDROID29 = "linux-android29" - # IREE defined OS name for vulkan target. See: - # compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td - VULKAN_ANDROID30 = "android30" - VULKAN_ANDROID31 = "android31" - VULKAN_LINUX = "linux" + VMVX = "vmvx" + LINUX_GNU = "linux-gnu" + LINUX_ANDROID29 = "linux-android29" + # IREE defined OS name for vulkan target. See: + # compiler/src/iree/compiler/Dialect/Vulkan/IR/VulkanBase.td + VULKAN_ANDROID30 = "android30" + VULKAN_ANDROID31 = "android31" + VULKAN_LINUX = "linux" class RuntimeLoader(Enum): - """IREE runtime loader.""" - # For target that doesn't support loader configuration. - NONE = "none" - EMBEDDED_ELF = "embedded-elf" - VMVX_MODULE = "vmvx-module" - SYSTEM_LIBRARY = "system-library" + """IREE runtime loader.""" + + # For target that doesn't support loader configuration. + NONE = "none" + EMBEDDED_ELF = "embedded-elf" + VMVX_MODULE = "vmvx-module" + SYSTEM_LIBRARY = "system-library" class RuntimeDriver(Enum): - """IREE runtime driver.""" - LOCAL_SYNC = "local-sync" - LOCAL_TASK = "local-task" - CUDA = "cuda" - VULKAN = "vulkan" + """IREE runtime driver.""" + + LOCAL_SYNC = "local-sync" + LOCAL_TASK = "local-task" + CUDA = "cuda" + VULKAN = "vulkan" @serialization.serializable @dataclass(frozen=True) class CompileTarget(object): - """Describes a target device to build for.""" - target_backend: TargetBackend - target_architecture: common_definitions.DeviceArchitecture - target_abi: TargetABI + """Describes a target device to build for.""" - def __str__(self): - return (f"{self.target_architecture}-" + target_backend: TargetBackend + target_architecture: common_definitions.DeviceArchitecture + target_abi: TargetABI + + def __str__(self): + return ( + f"{self.target_architecture}-" f"{self.target_abi.name}-" - f"{self.target_backend.name}").lower() + f"{self.target_backend.name}" + ).lower() @serialization.serializable(type_key="iree_compile_configs") @dataclass(frozen=True) class CompileConfig(object): - """Describes the options to build a module.""" - id: str - name: str - tags: List[str] - compile_targets: List[CompileTarget] - extra_flags: List[str] = dataclasses.field(default_factory=list) - - def __str__(self): - return self.name - - @classmethod - def build(cls, - id: str, - tags: Sequence[str], - compile_targets: Sequence[CompileTarget], - extra_flags: Optional[Sequence[str]] = None): - target_part = ",".join(str(target) for target in compile_targets) - tag_part = ",".join(tags) - # Format: [,...][,...] - name = f"[{target_part}][{tag_part}]" - extra_flags = extra_flags or [] - return cls(id=id, - name=name, - tags=list(tags), - compile_targets=list(compile_targets), - extra_flags=list(extra_flags)) + """Describes the options to build a module.""" + + id: str + name: str + tags: List[str] + compile_targets: List[CompileTarget] + extra_flags: List[str] = dataclasses.field(default_factory=list) + + def __str__(self): + return self.name + + @classmethod + def build( + cls, + id: str, + tags: Sequence[str], + compile_targets: Sequence[CompileTarget], + extra_flags: Optional[Sequence[str]] = None, + ): + target_part = ",".join(str(target) for target in compile_targets) + tag_part = ",".join(tags) + # Format: [,...][,...] + name = f"[{target_part}][{tag_part}]" + extra_flags = extra_flags or [] + return cls( + id=id, + name=name, + tags=list(tags), + compile_targets=list(compile_targets), + extra_flags=list(extra_flags), + ) @serialization.serializable(type_key="iree_module_execution_configs") @dataclass(frozen=True) class ModuleExecutionConfig(object): - """Describes the options to run a module.""" - id: str - name: str - tags: List[str] - loader: RuntimeLoader - driver: RuntimeDriver - extra_flags: List[str] = dataclasses.field(default_factory=list) - - def __str__(self): - return self.name - - @classmethod - def build(cls, - id: str, - tags: Sequence[str], - loader: RuntimeLoader, - driver: RuntimeDriver, - extra_flags: Optional[Sequence[str]] = None): - runtime_part = f"{driver.name}({loader.name})".lower() - tag_part = ",".join(tags) - # Format: ()[,...] - name = f"{runtime_part}[{tag_part}]" - extra_flags = extra_flags or [] - return cls(id=id, - name=name, - tags=list(tags), - loader=loader, - driver=driver, - extra_flags=list(extra_flags)) + """Describes the options to run a module.""" + + id: str + name: str + tags: List[str] + loader: RuntimeLoader + driver: RuntimeDriver + extra_flags: List[str] = dataclasses.field(default_factory=list) + + def __str__(self): + return self.name + + @classmethod + def build( + cls, + id: str, + tags: Sequence[str], + loader: RuntimeLoader, + driver: RuntimeDriver, + extra_flags: Optional[Sequence[str]] = None, + ): + runtime_part = f"{driver.name}({loader.name})".lower() + tag_part = ",".join(tags) + # Format: ()[,...] + name = f"{runtime_part}[{tag_part}]" + extra_flags = extra_flags or [] + return cls( + id=id, + name=name, + tags=list(tags), + loader=loader, + driver=driver, + extra_flags=list(extra_flags), + ) class ImportTool(Enum): - """Iree model import tool.""" - NONE = "none" - TF_IMPORTER = "iree-import-tf" - TFLITE_IMPORTER = "iree-import-tflite" + """Iree model import tool.""" + + NONE = "none" + TF_IMPORTER = "iree-import-tf" + TFLITE_IMPORTER = "iree-import-tflite" # Value should be the name of an IREE supported input type (--iree-input-type). class MLIRDialectType(Enum): - """Imported MLIR dialect type.""" - NONE = "none" - TOSA = "tosa" - STABLEHLO = "stablehlo" + """Imported MLIR dialect type.""" + + NONE = "none" + TOSA = "tosa" + STABLEHLO = "stablehlo" @serialization.serializable(type_key="iree_import_configs") @dataclass(frozen=True) class ImportConfig(object): - """Config to import the model.""" - id: str - name: str - tool: ImportTool - dialect_type: MLIRDialectType - import_flags: List[str] = dataclasses.field(default_factory=list) + """Config to import the model.""" + + id: str + name: str + tool: ImportTool + dialect_type: MLIRDialectType + import_flags: List[str] = dataclasses.field(default_factory=list) - def __str__(self): - return self.name + def __str__(self): + return self.name - def materialize_import_flags(self, - model: common_definitions.Model) -> List[str]: - """Materialize flags with dependent values.""" - return utils.substitute_flag_vars(flags=self.import_flags, - ENTRY_FUNCTION=model.entry_function) + def materialize_import_flags(self, model: common_definitions.Model) -> List[str]: + """Materialize flags with dependent values.""" + return utils.substitute_flag_vars( + flags=self.import_flags, ENTRY_FUNCTION=model.entry_function + ) DEFAULT_TFLITE_IMPORT_CONFIG = ImportConfig( id=unique_ids.IREE_MODEL_IMPORT_TFLITE_DEFAULT, name="tflite", tool=ImportTool.TFLITE_IMPORTER, - dialect_type=MLIRDialectType.TOSA) + dialect_type=MLIRDialectType.TOSA, +) DEFAULT_LINALG_MLIR_IMPORT_CONFIG = ImportConfig( id=unique_ids.IREE_MODEL_IMPORT_LINALG_MLIR_DEFAULT, name="linalg", tool=ImportTool.NONE, - dialect_type=MLIRDialectType.NONE) + dialect_type=MLIRDialectType.NONE, +) DEFAULT_STABLEHLO_MLIR_IMPORT_CONFIG = ImportConfig( id=unique_ids.IREE_MODEL_IMPORT_STABLEHLO_MLIR_DEFAULT, name="stablehlo", tool=ImportTool.NONE, - dialect_type=MLIRDialectType.STABLEHLO) + dialect_type=MLIRDialectType.STABLEHLO, +) MODEL_SOURCE_TO_DEFAULT_IMPORT_CONFIG_MAP = { - common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR: - DEFAULT_LINALG_MLIR_IMPORT_CONFIG, - common_definitions.ModelSourceType.EXPORTED_TFLITE: - DEFAULT_TFLITE_IMPORT_CONFIG, - common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR: - DEFAULT_STABLEHLO_MLIR_IMPORT_CONFIG, + common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR: DEFAULT_LINALG_MLIR_IMPORT_CONFIG, + common_definitions.ModelSourceType.EXPORTED_TFLITE: DEFAULT_TFLITE_IMPORT_CONFIG, + common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR: DEFAULT_STABLEHLO_MLIR_IMPORT_CONFIG, } -@serialization.serializable(type_key="iree_imported_models", - id_field="composite_id") +@serialization.serializable(type_key="iree_imported_models", id_field="composite_id") @dataclass(frozen=True) class ImportedModel(object): - """Describes an imported MLIR model.""" - composite_id: str - name: str - model: common_definitions.Model - import_config: ImportConfig - - def __str__(self): - return self.name - - @classmethod - def from_model(cls, model: common_definitions.Model): - config = MODEL_SOURCE_TO_DEFAULT_IMPORT_CONFIG_MAP.get(model.source_type) - if config is None: - raise ValueError(f"Unsupported model source type: {model.source_type}.") - - composite_id = unique_ids.hash_composite_id([model.id, config.id]) - # Format: () - name = f"{model}({config})" - return cls(composite_id=composite_id, - name=name, - model=model, - import_config=config) + """Describes an imported MLIR model.""" + + composite_id: str + name: str + model: common_definitions.Model + import_config: ImportConfig + + def __str__(self): + return self.name + + @classmethod + def from_model(cls, model: common_definitions.Model): + config = MODEL_SOURCE_TO_DEFAULT_IMPORT_CONFIG_MAP.get(model.source_type) + if config is None: + raise ValueError(f"Unsupported model source type: {model.source_type}.") + + composite_id = unique_ids.hash_composite_id([model.id, config.id]) + # Format: () + name = f"{model}({config})" + return cls( + composite_id=composite_id, name=name, model=model, import_config=config + ) # Variable in flags to be replaced with module dir path. The whole path should @@ -228,221 +246,252 @@ def from_model(cls, model: common_definitions.Model): MODULE_DIR_VARIABLE = r"${MODULE_DIR}" -@serialization.serializable(type_key="iree_module_generation_configs", - id_field="composite_id") +@serialization.serializable( + type_key="iree_module_generation_configs", id_field="composite_id" +) @dataclass(frozen=True) class ModuleGenerationConfig(object): - """Describes a compile target to generate the module.""" - composite_id: str - name: str - tags: List[str] - imported_model: ImportedModel - compile_config: CompileConfig - # Full list of flags to compile with, derived from sub-components, with - # unmaterialized placeholders. Allows the compile flags to be persisted and - # decouple from the generation code. Also serves as useful information in the - # serialized JSON. - compile_flags: List[str] - - def __str__(self): - return self.name - - def materialize_compile_flags(self, module_dir_path: pathlib.PurePath): - """Materialize flags with dependent values.""" - - def _replace_module_dir_placeholder(value: str) -> str: - """Replaces ${MODULE_DIR} in a POSIX path and returns the - platform-dependent path string. - """ - parts = pathlib.PurePosixPath(value).parts - if MODULE_DIR_VARIABLE not in parts: - return value - if parts[0] != MODULE_DIR_VARIABLE: - raise ValueError( - f"'{MODULE_DIR_VARIABLE}' needs to be the head of flag value" - f" if present, but got '{value}'.") - # Properly construct the platform-dependent path. - return str(module_dir_path.joinpath(*parts[1:])) - - return utils.transform_flags(flags=self.compile_flags, - map_funcs=[_replace_module_dir_placeholder]) - - @classmethod - def build(cls, - imported_model: ImportedModel, - compile_config: CompileConfig, - tags: Sequence[str] = ()): - composite_id = unique_ids.hash_composite_id( - [imported_model.composite_id, compile_config.id]) - # Format: - name = f"{imported_model} {compile_config}" - compile_flags = _generate_compile_flags( - compile_config, imported_model.import_config.dialect_type) - return cls(composite_id=composite_id, - name=name, - tags=list(tags), - imported_model=imported_model, - compile_config=compile_config, - compile_flags=compile_flags) + """Describes a compile target to generate the module.""" + + composite_id: str + name: str + tags: List[str] + imported_model: ImportedModel + compile_config: CompileConfig + # Full list of flags to compile with, derived from sub-components, with + # unmaterialized placeholders. Allows the compile flags to be persisted and + # decouple from the generation code. Also serves as useful information in the + # serialized JSON. + compile_flags: List[str] + + def __str__(self): + return self.name + + def materialize_compile_flags(self, module_dir_path: pathlib.PurePath): + """Materialize flags with dependent values.""" + + def _replace_module_dir_placeholder(value: str) -> str: + """Replaces ${MODULE_DIR} in a POSIX path and returns the + platform-dependent path string. + """ + parts = pathlib.PurePosixPath(value).parts + if MODULE_DIR_VARIABLE not in parts: + return value + if parts[0] != MODULE_DIR_VARIABLE: + raise ValueError( + f"'{MODULE_DIR_VARIABLE}' needs to be the head of flag value" + f" if present, but got '{value}'." + ) + # Properly construct the platform-dependent path. + return str(module_dir_path.joinpath(*parts[1:])) + + return utils.transform_flags( + flags=self.compile_flags, map_funcs=[_replace_module_dir_placeholder] + ) + + @classmethod + def build( + cls, + imported_model: ImportedModel, + compile_config: CompileConfig, + tags: Sequence[str] = (), + ): + composite_id = unique_ids.hash_composite_id( + [imported_model.composite_id, compile_config.id] + ) + # Format: + name = f"{imported_model} {compile_config}" + compile_flags = _generate_compile_flags( + compile_config, imported_model.import_config.dialect_type + ) + return cls( + composite_id=composite_id, + name=name, + tags=list(tags), + imported_model=imported_model, + compile_config=compile_config, + compile_flags=compile_flags, + ) class E2EModelRunTool(Enum): - """Tool to run a module.""" - IREE_BENCHMARK_MODULE = "iree-benchmark-module" + """Tool to run a module.""" + + IREE_BENCHMARK_MODULE = "iree-benchmark-module" -@serialization.serializable(type_key="iree_e2e_model_run_configs", - id_field="composite_id") +@serialization.serializable( + type_key="iree_e2e_model_run_configs", id_field="composite_id" +) @dataclass(frozen=True) class E2EModelRunConfig(object): - """Describes an e2e run.""" - composite_id: str - name: str - tags: List[str] - module_generation_config: ModuleGenerationConfig - module_execution_config: ModuleExecutionConfig - target_device_spec: common_definitions.DeviceSpec - input_data: common_definitions.ModelInputData - # Full list of flags to run with, derived from sub-components, with - # unmaterialized placeholders. Allows the run flags to be persisted and - # decouple from the generation code. Also serves as useful information in the - # serialized JSON. - run_flags: List[str] - tool: E2EModelRunTool - - def __str__(self): - return self.name - - def materialize_run_flags(self, gpu_id: str = "0"): - """Materialize flags with dependent values.""" - return utils.substitute_flag_vars(flags=self.run_flags, GPU_ID=gpu_id) - - @classmethod - def build(cls, - module_generation_config: ModuleGenerationConfig, - module_execution_config: ModuleExecutionConfig, - target_device_spec: common_definitions.DeviceSpec, - input_data: common_definitions.ModelInputData, - tool: E2EModelRunTool, - tags: Optional[Sequence[str]] = None): - composite_id = unique_ids.hash_composite_id([ - module_generation_config.composite_id, module_execution_config.id, - target_device_spec.id, input_data.id - ]) - # Format: with @ - name = f"{module_generation_config} {module_execution_config} with {input_data} @ {target_device_spec}" - run_flags = generate_run_flags( - imported_model=module_generation_config.imported_model, - input_data=input_data, - module_execution_config=module_execution_config, - gpu_id=r"${GPU_ID}") - tags_list = [] if tags is None else list(tags) - return cls(composite_id=composite_id, - name=name, - tags=tags_list, - module_generation_config=module_generation_config, - module_execution_config=module_execution_config, - target_device_spec=target_device_spec, - input_data=input_data, - run_flags=run_flags, - tool=tool) - - -def generate_run_flags(imported_model: ImportedModel, - input_data: common_definitions.ModelInputData, - module_execution_config: ModuleExecutionConfig, - gpu_id: str = "0", - with_driver: bool = True) -> List[str]: - """Returns the IREE run module flags of the input model and execution config. - Args: - model: source model. - input_data: model input data. - module_execution_config: execution config. - gpu_id: target gpu id, if runs on GPUs. - with_driver: populate the driver flags if true. False can be used for - generating flags for some CMake rules with a separate DRIVER arg. - Returns: - List of flags. - """ - - model = imported_model.model - run_flags = [f"--function={model.entry_function}"] - if input_data != common_definitions.ZEROS_MODEL_INPUT_DATA: - raise ValueError("Currently only support all-zeros data.") - run_flags += [f"--input={input_type}=0" for input_type in model.input_types] - - exec_config = module_execution_config - run_flags += exec_config.extra_flags.copy() - if with_driver: - driver = exec_config.driver - if driver == RuntimeDriver.CUDA: - run_flags.append(f"--device=cuda://{gpu_id}") - else: - run_flags.append(f"--device={driver.value}") - - return run_flags - - -def _generate_compile_flags(compile_config: CompileConfig, - dialect_type: MLIRDialectType) -> List[str]: - if len(compile_config.compile_targets) != 1: - raise ValueError(f"Only one compile target is supported. Got:" - f" {compile_config.compile_targets}") + """Describes an e2e run.""" + + composite_id: str + name: str + tags: List[str] + module_generation_config: ModuleGenerationConfig + module_execution_config: ModuleExecutionConfig + target_device_spec: common_definitions.DeviceSpec + input_data: common_definitions.ModelInputData + # Full list of flags to run with, derived from sub-components, with + # unmaterialized placeholders. Allows the run flags to be persisted and + # decouple from the generation code. Also serves as useful information in the + # serialized JSON. + run_flags: List[str] + tool: E2EModelRunTool + + def __str__(self): + return self.name + + def materialize_run_flags(self, gpu_id: str = "0"): + """Materialize flags with dependent values.""" + return utils.substitute_flag_vars(flags=self.run_flags, GPU_ID=gpu_id) + + @classmethod + def build( + cls, + module_generation_config: ModuleGenerationConfig, + module_execution_config: ModuleExecutionConfig, + target_device_spec: common_definitions.DeviceSpec, + input_data: common_definitions.ModelInputData, + tool: E2EModelRunTool, + tags: Optional[Sequence[str]] = None, + ): + composite_id = unique_ids.hash_composite_id( + [ + module_generation_config.composite_id, + module_execution_config.id, + target_device_spec.id, + input_data.id, + ] + ) + # Format: with @ + name = f"{module_generation_config} {module_execution_config} with {input_data} @ {target_device_spec}" + run_flags = generate_run_flags( + imported_model=module_generation_config.imported_model, + input_data=input_data, + module_execution_config=module_execution_config, + gpu_id=r"${GPU_ID}", + ) + tags_list = [] if tags is None else list(tags) + return cls( + composite_id=composite_id, + name=name, + tags=tags_list, + module_generation_config=module_generation_config, + module_execution_config=module_execution_config, + target_device_spec=target_device_spec, + input_data=input_data, + run_flags=run_flags, + tool=tool, + ) + + +def generate_run_flags( + imported_model: ImportedModel, + input_data: common_definitions.ModelInputData, + module_execution_config: ModuleExecutionConfig, + gpu_id: str = "0", + with_driver: bool = True, +) -> List[str]: + """Returns the IREE run module flags of the input model and execution config. + Args: + model: source model. + input_data: model input data. + module_execution_config: execution config. + gpu_id: target gpu id, if runs on GPUs. + with_driver: populate the driver flags if true. False can be used for + generating flags for some CMake rules with a separate DRIVER arg. + Returns: + List of flags. + """ + + model = imported_model.model + run_flags = [f"--function={model.entry_function}"] + if input_data != common_definitions.ZEROS_MODEL_INPUT_DATA: + raise ValueError("Currently only support all-zeros data.") + run_flags += [f"--input={input_type}=0" for input_type in model.input_types] + + exec_config = module_execution_config + run_flags += exec_config.extra_flags.copy() + if with_driver: + driver = exec_config.driver + if driver == RuntimeDriver.CUDA: + run_flags.append(f"--device=cuda://{gpu_id}") + else: + run_flags.append(f"--device={driver.value}") + + return run_flags + + +def _generate_compile_flags( + compile_config: CompileConfig, dialect_type: MLIRDialectType +) -> List[str]: + if len(compile_config.compile_targets) != 1: + raise ValueError( + f"Only one compile target is supported. Got:" + f" {compile_config.compile_targets}" + ) - compile_target = compile_config.compile_targets[0] - flags = [ - f"--iree-hal-target-backends={compile_target.target_backend.value}", - f"--iree-input-type={dialect_type.value}" - ] - flags += _generate_compile_target_flags(compile_target) - flags += compile_config.extra_flags - return flags + compile_target = compile_config.compile_targets[0] + flags = [ + f"--iree-hal-target-backends={compile_target.target_backend.value}", + f"--iree-input-type={dialect_type.value}", + ] + flags += _generate_compile_target_flags(compile_target) + flags += compile_config.extra_flags + return flags def _generate_compile_target_flags(target: CompileTarget) -> List[str]: - arch_info = target.target_architecture - if target.target_backend == TargetBackend.VULKAN_SPIRV: - gpu_arch = arch_info.microarchitecture if len( - arch_info.microarchitecture) != 0 else arch_info.architecture - return [ - f"--iree-vulkan-target-triple={gpu_arch}-unknown-{target.target_abi.value}", - ] - - if arch_info.architecture == "x86_64": - flags = [ - f"--iree-llvmcpu-target-triple=x86_64-unknown-{target.target_abi.value}", - f"--iree-llvmcpu-target-cpu={arch_info.microarchitecture.lower()}" - ] - elif arch_info.architecture == "riscv_64": - flags = [ - f"--iree-llvmcpu-target-triple=riscv64-pc-{target.target_abi.value}", - "--iree-llvmcpu-target-cpu=generic-rv64", - "--iree-llvmcpu-target-abi=lp64d", - "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+zvl512b,+v", - "--riscv-v-fixed-length-vector-lmul-max=8" - ] - elif arch_info.architecture == "riscv_32": - # TODO(llvm-project/60463): Replace 'zve32f' with 'zve32x'. - flags = [ - f"--iree-llvmcpu-target-triple=riscv32-pc-{target.target_abi.value}", - "--iree-llvmcpu-target-cpu=generic-rv32", - "--iree-llvmcpu-target-abi=ilp32", - "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+zvl512b,+zve32f", - "--riscv-v-fixed-length-vector-lmul-max=8" - ] - elif arch_info.architecture == "armv8.2-a": - flags = [ - f"--iree-llvmcpu-target-triple=aarch64-none-{target.target_abi.value}", - ] - elif arch_info.architecture == "cuda": - if target.target_abi != TargetABI.LINUX_GNU: - raise ValueError( - f"Unsupported target ABI for CUDA backend: `{target.target_abi}`") - flags = [ - f"--iree-hal-cuda-llvm-target-arch={arch_info.microarchitecture}", - ] - elif arch_info.architecture == "vmvx": - flags = [] - else: - raise ValueError(f"Unsupported architecture: '{arch_info.architecture}'") - return flags + arch_info = target.target_architecture + if target.target_backend == TargetBackend.VULKAN_SPIRV: + gpu_arch = ( + arch_info.microarchitecture + if len(arch_info.microarchitecture) != 0 + else arch_info.architecture + ) + return [ + f"--iree-vulkan-target-triple={gpu_arch}-unknown-{target.target_abi.value}", + ] + + if arch_info.architecture == "x86_64": + flags = [ + f"--iree-llvmcpu-target-triple=x86_64-unknown-{target.target_abi.value}", + f"--iree-llvmcpu-target-cpu={arch_info.microarchitecture.lower()}", + ] + elif arch_info.architecture == "riscv_64": + flags = [ + f"--iree-llvmcpu-target-triple=riscv64-pc-{target.target_abi.value}", + "--iree-llvmcpu-target-cpu=generic-rv64", + "--iree-llvmcpu-target-abi=lp64d", + "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+zvl512b,+v", + "--riscv-v-fixed-length-vector-lmul-max=8", + ] + elif arch_info.architecture == "riscv_32": + # TODO(llvm-project/60463): Replace 'zve32f' with 'zve32x'. + flags = [ + f"--iree-llvmcpu-target-triple=riscv32-pc-{target.target_abi.value}", + "--iree-llvmcpu-target-cpu=generic-rv32", + "--iree-llvmcpu-target-abi=ilp32", + "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+zvl512b,+zve32f", + "--riscv-v-fixed-length-vector-lmul-max=8", + ] + elif arch_info.architecture == "armv8.2-a": + flags = [ + f"--iree-llvmcpu-target-triple=aarch64-none-{target.target_abi.value}", + ] + elif arch_info.architecture == "cuda": + if target.target_abi != TargetABI.LINUX_GNU: + raise ValueError( + f"Unsupported target ABI for CUDA backend: `{target.target_abi}`" + ) + flags = [ + f"--iree-hal-cuda-llvm-target-arch={arch_info.microarchitecture}", + ] + elif arch_info.architecture == "vmvx": + flags = [] + else: + raise ValueError(f"Unsupported architecture: '{arch_info.architecture}'") + return flags diff --git a/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py index 02d19876b1a0..585c7e57a4f7 100644 --- a/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py +++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py @@ -11,147 +11,180 @@ class IreeDefinitionsTest(unittest.TestCase): - - def test_generate_run_flags(self): - imported_model = iree_definitions.ImportedModel.from_model( - common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32", "2x2xf32"])) - execution_config = iree_definitions.ModuleExecutionConfig.build( - id="123", - tags=["test"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=["--task=10"]) - - flags = iree_definitions.generate_run_flags( - imported_model=imported_model, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - module_execution_config=execution_config) - - self.assertEqual(flags, [ - "--function=main", "--input=1xf32=0", "--input=2x2xf32=0", "--task=10", - "--device=local-task" - ]) - - def test_generate_run_flags_with_cuda(self): - imported_model = iree_definitions.ImportedModel.from_model( - common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"])) - execution_config = iree_definitions.ModuleExecutionConfig.build( - id="123", - tags=["test"], - loader=iree_definitions.RuntimeLoader.NONE, - driver=iree_definitions.RuntimeDriver.CUDA, - extra_flags=[]) - - flags = iree_definitions.generate_run_flags( - imported_model=imported_model, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - module_execution_config=execution_config, - gpu_id="3") - - self.assertEqual( - flags, ["--function=main", "--input=1xf32=0", "--device=cuda://3"]) - - def test_generate_run_flags_without_driver(self): - imported_model = iree_definitions.ImportedModel.from_model( - common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"])) - execution_config = iree_definitions.ModuleExecutionConfig.build( - id="123", - tags=["test"], - loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, - driver=iree_definitions.RuntimeDriver.LOCAL_TASK, - extra_flags=["--task=10"]) - - flags = iree_definitions.generate_run_flags( - imported_model=imported_model, - input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, - module_execution_config=execution_config, - with_driver=False) - - self.assertEqual(flags, ["--function=main", "--input=1xf32=0", "--task=10"]) + def test_generate_run_flags(self): + imported_model = iree_definitions.ImportedModel.from_model( + common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32", "2x2xf32"], + ) + ) + execution_config = iree_definitions.ModuleExecutionConfig.build( + id="123", + tags=["test"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=["--task=10"], + ) + + flags = iree_definitions.generate_run_flags( + imported_model=imported_model, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + module_execution_config=execution_config, + ) + + self.assertEqual( + flags, + [ + "--function=main", + "--input=1xf32=0", + "--input=2x2xf32=0", + "--task=10", + "--device=local-task", + ], + ) + + def test_generate_run_flags_with_cuda(self): + imported_model = iree_definitions.ImportedModel.from_model( + common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + ) + execution_config = iree_definitions.ModuleExecutionConfig.build( + id="123", + tags=["test"], + loader=iree_definitions.RuntimeLoader.NONE, + driver=iree_definitions.RuntimeDriver.CUDA, + extra_flags=[], + ) + + flags = iree_definitions.generate_run_flags( + imported_model=imported_model, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + module_execution_config=execution_config, + gpu_id="3", + ) + + self.assertEqual( + flags, ["--function=main", "--input=1xf32=0", "--device=cuda://3"] + ) + + def test_generate_run_flags_without_driver(self): + imported_model = iree_definitions.ImportedModel.from_model( + common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + ) + execution_config = iree_definitions.ModuleExecutionConfig.build( + id="123", + tags=["test"], + loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF, + driver=iree_definitions.RuntimeDriver.LOCAL_TASK, + extra_flags=["--task=10"], + ) + + flags = iree_definitions.generate_run_flags( + imported_model=imported_model, + input_data=common_definitions.ZEROS_MODEL_INPUT_DATA, + module_execution_config=execution_config, + with_driver=False, + ) + + self.assertEqual(flags, ["--function=main", "--input=1xf32=0", "--task=10"]) class ModuleGenerationConfigTest(unittest.TestCase): - - def test_materialize_compile_flags(self): - imported_model = iree_definitions.ImportedModel.from_model( - common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"])) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - compile_config = iree_definitions.CompileConfig( - id="compile_config_a", - name="compile_config_a", - tags=["test"], - compile_targets=[compile_target], - extra_flags=[r"--test=${MODULE_DIR}/test.json"]) - gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model, compile_config=compile_config) - - flags = gen_config.materialize_compile_flags( - module_dir_path=pathlib.Path("abc")) - - expected_path = pathlib.Path("abc", "test.json") - self.assertIn(f"--test={expected_path}", flags) - - def test_materialize_compile_flags_invalid_module_dir_position(self): - imported_model = iree_definitions.ImportedModel.from_model( - common_definitions.Model( - id="1234", - name="tflite_m", - tags=[], - source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, - source_url="https://example.com/xyz.tflite", - entry_function="main", - input_types=["1xf32"])) - compile_target = iree_definitions.CompileTarget( - target_backend=iree_definitions.TargetBackend.LLVM_CPU, - target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, - target_abi=iree_definitions.TargetABI.LINUX_GNU) - compile_config = iree_definitions.CompileConfig( - id="compile_config_a", - name="compile_config_a", - tags=["test"], - compile_targets=[compile_target], - extra_flags=[r"--test=prefix/${MODULE_DIR}/test.json"]) - gen_config = iree_definitions.ModuleGenerationConfig.build( - imported_model=imported_model, compile_config=compile_config) - expected_error = ( - r"^'\${MODULE_DIR}' needs to be the head of flag value if present," - r" but got 'prefix/\${MODULE_DIR}/test.json'.$") - - self.assertRaisesRegex( - ValueError, expected_error, lambda: gen_config. - materialize_compile_flags(module_dir_path=pathlib.Path("abc"))) + def test_materialize_compile_flags(self): + imported_model = iree_definitions.ImportedModel.from_model( + common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + compile_config = iree_definitions.CompileConfig( + id="compile_config_a", + name="compile_config_a", + tags=["test"], + compile_targets=[compile_target], + extra_flags=[r"--test=${MODULE_DIR}/test.json"], + ) + gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model, compile_config=compile_config + ) + + flags = gen_config.materialize_compile_flags( + module_dir_path=pathlib.Path("abc") + ) + + expected_path = pathlib.Path("abc", "test.json") + self.assertIn(f"--test={expected_path}", flags) + + def test_materialize_compile_flags_invalid_module_dir_position(self): + imported_model = iree_definitions.ImportedModel.from_model( + common_definitions.Model( + id="1234", + name="tflite_m", + tags=[], + source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, + source_url="https://example.com/xyz.tflite", + entry_function="main", + input_types=["1xf32"], + ) + ) + compile_target = iree_definitions.CompileTarget( + target_backend=iree_definitions.TargetBackend.LLVM_CPU, + target_architecture=common_definitions.DeviceArchitecture.RV64_GENERIC, + target_abi=iree_definitions.TargetABI.LINUX_GNU, + ) + compile_config = iree_definitions.CompileConfig( + id="compile_config_a", + name="compile_config_a", + tags=["test"], + compile_targets=[compile_target], + extra_flags=[r"--test=prefix/${MODULE_DIR}/test.json"], + ) + gen_config = iree_definitions.ModuleGenerationConfig.build( + imported_model=imported_model, compile_config=compile_config + ) + expected_error = ( + r"^'\${MODULE_DIR}' needs to be the head of flag value if present," + r" but got 'prefix/\${MODULE_DIR}/test.json'.$" + ) + + self.assertRaisesRegex( + ValueError, + expected_error, + lambda: gen_config.materialize_compile_flags( + module_dir_path=pathlib.Path("abc") + ), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/definitions/utils.py b/build_tools/python/e2e_test_framework/definitions/utils.py index a37e56130d7f..5079cc6679b1 100644 --- a/build_tools/python/e2e_test_framework/definitions/utils.py +++ b/build_tools/python/e2e_test_framework/definitions/utils.py @@ -11,41 +11,42 @@ MAX_SUBSTITUTION_ITERATIONS = 10 -def transform_flags(flags: Sequence[str], - map_funcs: Sequence[Callable[[str], str]]) -> List[str]: - """Call map functions to transform flag values, e.g., replace placeholders - that were unknown when the flag was constructed. +def transform_flags( + flags: Sequence[str], map_funcs: Sequence[Callable[[str], str]] +) -> List[str]: + """Call map functions to transform flag values, e.g., replace placeholders + that were unknown when the flag was constructed. - It parses and extracts the flag values from both keyword and positional flags, - transforms them, and returns the updated flags with transformed values. + It parses and extracts the flag values from both keyword and positional flags, + transforms them, and returns the updated flags with transformed values. - Each flag value is transformed only once by each map function in order. - - Args: - flags: list of flags. - map_funcs: list of map functions to map flag value. - Returns: - List of transformed flags. - """ + Each flag value is transformed only once by each map function in order. - transformed_flags = [] - for flag in flags: - keyword, separator, value = ("", "", flag) - if flag.startswith("-"): - keyword, separator, value = flag.partition("=") + Args: + flags: list of flags. + map_funcs: list of map functions to map flag value. + Returns: + List of transformed flags. + """ - if value: - for map_func in map_funcs: - value = map_func(value) + transformed_flags = [] + for flag in flags: + keyword, separator, value = ("", "", flag) + if flag.startswith("-"): + keyword, separator, value = flag.partition("=") - transformed_flags.append(f"{keyword}{separator}{value}") + if value: + for map_func in map_funcs: + value = map_func(value) - return transformed_flags + transformed_flags.append(f"{keyword}{separator}{value}") + + return transformed_flags def substitute_flag_vars(flags: Sequence[str], **mapping: Any) -> List[str]: - """Sugar of transform_flags to substitute variables in string.Template format. - """ - return transform_flags( - flags=flags, - map_funcs=[lambda value: string.Template(value).substitute(mapping)]) + """Sugar of transform_flags to substitute variables in string.Template format.""" + return transform_flags( + flags=flags, + map_funcs=[lambda value: string.Template(value).substitute(mapping)], + ) diff --git a/build_tools/python/e2e_test_framework/definitions/utils_test.py b/build_tools/python/e2e_test_framework/definitions/utils_test.py index 700687e9408f..7beae7e8cce1 100644 --- a/build_tools/python/e2e_test_framework/definitions/utils_test.py +++ b/build_tools/python/e2e_test_framework/definitions/utils_test.py @@ -10,44 +10,44 @@ class UtilsTest(unittest.TestCase): - - def test_transform_flags(self): - flags = utils.transform_flags( - flags=[ - r"${HOLDER_A} ${HOLDER_B}", r"--key=${HOLDER_A}", "--no-value-key", - r"--filter=x=${HOLDER_A}" - ], - map_funcs=[ - lambda value: value.replace(r"${HOLDER_A}", "val_a"), - lambda value: value.replace(r"${HOLDER_B}", "val_b") - ]) - - self.assertEquals( - flags, - ["val_a val_b", "--key=val_a", "--no-value-key", "--filter=x=val_a"]) - - def test_substitute_flag_vars(self): - raw_flags = [ - r"${HOLDER_A}", - r"--key=${HOLDER_B}", - ] - - flags = utils.substitute_flag_vars(flags=raw_flags, - HOLDER_A=1, - HOLDER_B="b") - - self.assertEquals(flags, ["1", "--key=b"]) - - def test_substitute_flag_vars_missing_variable(self): - raw_flags = [ - r"${HOLDER_A}", - r"--key=${HOLDER_B}", - ] - - self.assertRaises( - KeyError, - lambda: utils.substitute_flag_vars(flags=raw_flags, HOLDER_A=1)) + def test_transform_flags(self): + flags = utils.transform_flags( + flags=[ + r"${HOLDER_A} ${HOLDER_B}", + r"--key=${HOLDER_A}", + "--no-value-key", + r"--filter=x=${HOLDER_A}", + ], + map_funcs=[ + lambda value: value.replace(r"${HOLDER_A}", "val_a"), + lambda value: value.replace(r"${HOLDER_B}", "val_b"), + ], + ) + + self.assertEquals( + flags, ["val_a val_b", "--key=val_a", "--no-value-key", "--filter=x=val_a"] + ) + + def test_substitute_flag_vars(self): + raw_flags = [ + r"${HOLDER_A}", + r"--key=${HOLDER_B}", + ] + + flags = utils.substitute_flag_vars(flags=raw_flags, HOLDER_A=1, HOLDER_B="b") + + self.assertEquals(flags, ["1", "--key=b"]) + + def test_substitute_flag_vars_missing_variable(self): + raw_flags = [ + r"${HOLDER_A}", + r"--key=${HOLDER_B}", + ] + + self.assertRaises( + KeyError, lambda: utils.substitute_flag_vars(flags=raw_flags, HOLDER_A=1) + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/device_specs/device_collections.py b/build_tools/python/e2e_test_framework/device_specs/device_collections.py index 9ed4629d01fb..250884be0891 100644 --- a/build_tools/python/e2e_test_framework/device_specs/device_collections.py +++ b/build_tools/python/e2e_test_framework/device_specs/device_collections.py @@ -7,42 +7,47 @@ from typing import List, Sequence, Set from e2e_test_framework.definitions import common_definitions -from e2e_test_framework.device_specs import gcp_specs, pixel_4_specs, pixel_6_pro_specs, moto_edge_x30_specs +from e2e_test_framework.device_specs import ( + gcp_specs, + pixel_4_specs, + pixel_6_pro_specs, + moto_edge_x30_specs, +) class DeviceCollection(object): - """Class to collect and query device specs.""" - - def __init__(self, device_specs: Sequence[common_definitions.DeviceSpec]): - self.device_specs = device_specs - - def query_device_specs( - self, - architecture: common_definitions.DeviceArchitecture, - host_environment: common_definitions.HostEnvironment, - device_parameters: Set[str] = set() - ) -> List[common_definitions.DeviceSpec]: - """Query the device specs. - - Args: - architecture: device architecture to match. - platform: device platform to match. - device_parameters: parameters that devices need to have. - Returns: - List of matched device specs. - """ - - matched_device_specs = [] - for device_spec in self.device_specs: - if device_spec.architecture != architecture: - continue - if device_spec.host_environment != host_environment: - continue - if not device_parameters.issubset(device_spec.device_parameters): - continue - matched_device_specs.append(device_spec) - - return matched_device_specs + """Class to collect and query device specs.""" + + def __init__(self, device_specs: Sequence[common_definitions.DeviceSpec]): + self.device_specs = device_specs + + def query_device_specs( + self, + architecture: common_definitions.DeviceArchitecture, + host_environment: common_definitions.HostEnvironment, + device_parameters: Set[str] = set(), + ) -> List[common_definitions.DeviceSpec]: + """Query the device specs. + + Args: + architecture: device architecture to match. + platform: device platform to match. + device_parameters: parameters that devices need to have. + Returns: + List of matched device specs. + """ + + matched_device_specs = [] + for device_spec in self.device_specs: + if device_spec.architecture != architecture: + continue + if device_spec.host_environment != host_environment: + continue + if not device_parameters.issubset(device_spec.device_parameters): + continue + matched_device_specs.append(device_spec) + + return matched_device_specs ALL_DEVICE_SPECS = [ diff --git a/build_tools/python/e2e_test_framework/device_specs/device_collections_test.py b/build_tools/python/e2e_test_framework/device_specs/device_collections_test.py index f441a3252a04..f7154be6dc63 100644 --- a/build_tools/python/e2e_test_framework/device_specs/device_collections_test.py +++ b/build_tools/python/e2e_test_framework/device_specs/device_collections_test.py @@ -10,68 +10,82 @@ class DeviceCollectionTest(unittest.TestCase): + def test_query_device_specs(self): + linux_x86_device_spec = common_definitions.DeviceSpec.build( + id="linux_x86", + device_name="a", + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + tags=[], + ) + android_x86_device_spec = common_definitions.DeviceSpec.build( + id="android_x86", + device_name="b", + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + tags=[], + ) + little_cores_device_spec = common_definitions.DeviceSpec.build( + id="android_little", + device_name="c", + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters=["little-cores"], + tags=[], + ) + big_cores_device_spec = common_definitions.DeviceSpec.build( + id="android_big", + device_name="d", + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters=["big-cores"], + tags=[], + ) + devices = device_collections.DeviceCollection( + device_specs=[ + linux_x86_device_spec, + android_x86_device_spec, + little_cores_device_spec, + big_cores_device_spec, + ] + ) - def test_query_device_specs(self): - linux_x86_device_spec = common_definitions.DeviceSpec.build( - id="linux_x86", - device_name="a", - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64, - tags=[]) - android_x86_device_spec = common_definitions.DeviceSpec.build( - id="android_x86", - device_name="b", - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=[]) - little_cores_device_spec = common_definitions.DeviceSpec.build( - id="android_little", - device_name="c", - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters=["little-cores"], - tags=[]) - big_cores_device_spec = common_definitions.DeviceSpec.build( - id="android_big", - device_name="d", - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters=["big-cores"], - tags=[]) - devices = device_collections.DeviceCollection(device_specs=[ - linux_x86_device_spec, android_x86_device_spec, - little_cores_device_spec, big_cores_device_spec - ]) + linux_x86_devices = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + ) + android_x86_devices = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + ) + little_cores_devices = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters={"little-cores"}, + ) + big_cores_devices = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + device_parameters={"big-cores"}, + ) + all_arm_devices = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, + ) + no_matched_device = devices.query_device_specs( + architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, + host_environment=common_definitions.HostEnvironment.LINUX_X86_64, + ) - linux_x86_devices = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64) - android_x86_devices = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A) - little_cores_devices = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters={"little-cores"}) - big_cores_devices = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - device_parameters={"big-cores"}) - all_arm_devices = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A) - no_matched_device = devices.query_device_specs( - architecture=common_definitions.DeviceArchitecture.ARMV9_A_GENERIC, - host_environment=common_definitions.HostEnvironment.LINUX_X86_64) - - self.assertEqual(linux_x86_devices, [linux_x86_device_spec]) - self.assertEqual(android_x86_devices, [android_x86_device_spec]) - self.assertEqual(little_cores_devices, [little_cores_device_spec]) - self.assertEqual(big_cores_devices, [big_cores_device_spec]) - self.assertEqual(all_arm_devices, - [little_cores_device_spec, big_cores_device_spec]) - self.assertEqual(no_matched_device, []) + self.assertEqual(linux_x86_devices, [linux_x86_device_spec]) + self.assertEqual(android_x86_devices, [android_x86_device_spec]) + self.assertEqual(little_cores_devices, [little_cores_device_spec]) + self.assertEqual(big_cores_devices, [big_cores_device_spec]) + self.assertEqual( + all_arm_devices, [little_cores_device_spec, big_cores_device_spec] + ) + self.assertEqual(no_matched_device, []) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/device_specs/gcp_specs.py b/build_tools/python/e2e_test_framework/device_specs/gcp_specs.py index 82b48ec6aff4..9ca92f8e899f 100644 --- a/build_tools/python/e2e_test_framework/device_specs/gcp_specs.py +++ b/build_tools/python/e2e_test_framework/device_specs/gcp_specs.py @@ -15,11 +15,13 @@ host_environment=common_definitions.HostEnvironment.LINUX_X86_64, architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE, device_parameters=[device_parameters.ALL_CORES], - tags=["cpu"]) + tags=["cpu"], +) GCP_A2_HIGHGPU_1G = common_definitions.DeviceSpec.build( id=unique_ids.DEVICE_SPEC_GCP_A2_HIGHGPU_1G, device_name="a2-highgpu-1g", host_environment=common_definitions.HostEnvironment.LINUX_X86_64, architecture=common_definitions.DeviceArchitecture.NVIDIA_AMPERE, - tags=["gpu"]) + tags=["gpu"], +) diff --git a/build_tools/python/e2e_test_framework/device_specs/moto_edge_x30_specs.py b/build_tools/python/e2e_test_framework/device_specs/moto_edge_x30_specs.py index 68337f9dbb00..4a1f2d1ccb7f 100644 --- a/build_tools/python/e2e_test_framework/device_specs/moto_edge_x30_specs.py +++ b/build_tools/python/e2e_test_framework/device_specs/moto_edge_x30_specs.py @@ -15,4 +15,5 @@ device_name=DEVICE_NAME, architecture=common_definitions.DeviceArchitecture.QUALCOMM_ADRENO, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=["gpu"]) + tags=["gpu"], +) diff --git a/build_tools/python/e2e_test_framework/device_specs/pixel_4_specs.py b/build_tools/python/e2e_test_framework/device_specs/pixel_4_specs.py index c9e6dc6cafab..9cac9d23f15b 100644 --- a/build_tools/python/e2e_test_framework/device_specs/pixel_4_specs.py +++ b/build_tools/python/e2e_test_framework/device_specs/pixel_4_specs.py @@ -17,11 +17,13 @@ architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, device_parameters=[device_parameters.ARM_BIG_CORES], - tags=["big-core"]) + tags=["big-core"], +) LITTLE_CORES = common_definitions.DeviceSpec.build( id=unique_ids.DEVICE_SPEC_MOBILE_PIXEL_4 + "-little-core", device_name=DEVICE_NAME, architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, device_parameters=[device_parameters.ARM_LITTLE_CORES], - tags=["little-core"]) + tags=["little-core"], +) diff --git a/build_tools/python/e2e_test_framework/device_specs/pixel_6_pro_specs.py b/build_tools/python/e2e_test_framework/device_specs/pixel_6_pro_specs.py index bd1f17608c3b..26cba5147146 100644 --- a/build_tools/python/e2e_test_framework/device_specs/pixel_6_pro_specs.py +++ b/build_tools/python/e2e_test_framework/device_specs/pixel_6_pro_specs.py @@ -17,17 +17,20 @@ architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, device_parameters=[device_parameters.ARM_BIG_CORES], - tags=["big-core"]) + tags=["big-core"], +) LITTLE_CORES = common_definitions.DeviceSpec.build( id=unique_ids.DEVICE_SPEC_MOBILE_PIXEL_6_PRO + "-little-core", device_name=DEVICE_NAME, architecture=common_definitions.DeviceArchitecture.ARMV8_2_A_GENERIC, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, device_parameters=[device_parameters.ARM_LITTLE_CORES], - tags=["little-core"]) + tags=["little-core"], +) GPU = common_definitions.DeviceSpec.build( id=unique_ids.DEVICE_SPEC_MOBILE_PIXEL_6_PRO + "-gpu", device_name=DEVICE_NAME, architecture=common_definitions.DeviceArchitecture.ARM_VALHALL, host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A, - tags=["gpu"]) + tags=["gpu"], +) diff --git a/build_tools/python/e2e_test_framework/models/jax_models.py b/build_tools/python/e2e_test_framework/models/jax_models.py index 84b5982d81a1..af861f693748 100644 --- a/build_tools/python/e2e_test_framework/models/jax_models.py +++ b/build_tools/python/e2e_test_framework/models/jax_models.py @@ -16,59 +16,71 @@ ID_FORMAT = string.Template("${model_id}-batch${batch_size}") NAME_FORMAT = string.Template("${name}_BATCH${batch_size}") SOURCE_URL_FORMAT = string.Template( - GCS_ARTIFACT_ROOT_DIR + - "/${directory}/batch_${batch_size}/stablehlo.mlirbc") + GCS_ARTIFACT_ROOT_DIR + "/${directory}/batch_${batch_size}/stablehlo.mlirbc" +) # Derived from https://huggingface.co/docs/transformers/model_doc/resnet#transformers.FlaxResNetModel. RESNET50_TAGS = ["fp32", "cnn", "resnet"] RESNET50_FP32_JAX_3X224X224XF32_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_FP32_JAX_3X224X224XF32), + ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_FP32_JAX_3X224X224XF32 + ), name_template=model_utils.partial_template_substitute( - NAME_FORMAT, name="RESNET50_FP32_JAX_3X224X224XF32"), + NAME_FORMAT, name="RESNET50_FP32_JAX_3X224X224XF32" + ), tags=RESNET50_TAGS, source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="RESNET50"), + SOURCE_URL_FORMAT, directory="RESNET50" + ), entry_function="main", input_type_templates=[string.Template("${batch_size}x3x224x224xf32")], - batch_sizes=[1, 8, 64, 128, 256, 2048]) + batch_sizes=[1, 8, 64, 128, 256, 2048], +) # Derived from https://huggingface.co/docs/transformers/model_doc/bert#transformers.FlaxBertModel. BERT_LARGE_TAGS = ["fp32", "seqlen384", "jax", "bert-variant"] BERT_LARGE_FP32_JAX_384XI32_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_FP32_JAX_384XI32), + ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_FP32_JAX_384XI32 + ), name_template=model_utils.partial_template_substitute( - NAME_FORMAT, name="BERT_LARGE_JAX_384XI32"), + NAME_FORMAT, name="BERT_LARGE_JAX_384XI32" + ), tags=BERT_LARGE_TAGS, source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="BERT_LARGE"), + SOURCE_URL_FORMAT, directory="BERT_LARGE" + ), entry_function="main", input_type_templates=[ string.Template("${batch_size}x384xi32"), - string.Template("${batch_size}x384xi32") + string.Template("${batch_size}x384xi32"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280], +) # Derived from https://huggingface.co/docs/transformers/model_doc/t5#transformers.FlaxT5Model T5_TAGS = ["fp32", "transformer-encoder", "transformer-decoder", "t5"] T5_LARGE_FP32_JAX_512XI32_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_T5_LARGE_FP32_JAX_512XI32), + ID_FORMAT, model_id=unique_ids.MODEL_T5_LARGE_FP32_JAX_512XI32 + ), name_template=model_utils.partial_template_substitute( - NAME_FORMAT, name="T5_LARGE_FP32_JAX_512XI32"), + NAME_FORMAT, name="T5_LARGE_FP32_JAX_512XI32" + ), tags=T5_TAGS, source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="T5_LARGE"), + SOURCE_URL_FORMAT, directory="T5_LARGE" + ), entry_function="main", input_type_templates=[ string.Template("${batch_size}x512xi32"), - string.Template("${batch_size}x512xi32") + string.Template("${batch_size}x512xi32"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512], +) diff --git a/build_tools/python/e2e_test_framework/models/matmul.py b/build_tools/python/e2e_test_framework/models/matmul.py index 16a635fe5905..d751528fbf3e 100644 --- a/build_tools/python/e2e_test_framework/models/matmul.py +++ b/build_tools/python/e2e_test_framework/models/matmul.py @@ -13,97 +13,97 @@ name="matmul_3456x1024x2048_f16t_tile_config_default", tags=["fp16", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_default.mlirbc", entry_function="matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_default", - input_types=["3456x2048xf16", "2048x1024xf16"]) + input_types=["3456x2048xf16", "2048x1024xf16"], +) MATMUL_3456X1024X2048_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_3456X1024X2048_FP32_MLIR, name="matmul_3456x1024x2048_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_3456x1024x2048_f32t_f32t_f32t_tile_config_default", - input_types=["3456x2048xf32", "2048x1024xf32"]) + input_types=["3456x2048xf32", "2048x1024xf32"], +) MATMUL_2560X2560X2560_FP16_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_2560X2560X2560_FP16_MLIR, name="matmul_2560x2560x2560_f16t_tile_config_default", tags=["fp16", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_2560x2560x2560_f16t_f16t_f16t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_2560x2560x2560_f16t_f16t_f16t_tile_config_default.mlirbc", entry_function="matmul_2560x2560x2560_f16t_f16t_f16t_tile_config_default", - input_types=["2560x2560xf16", "2560x2560xf16"]) + input_types=["2560x2560xf16", "2560x2560xf16"], +) MATMUL_2560X2560X2560_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_2560X2560X2560_FP32_MLIR, name="matmul_2560x2560x2560_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_2560x2560x2560_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_2560x2560x2560_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_2560x2560x2560_f32t_f32t_f32t_tile_config_default", - input_types=["2560x2560xf32", "2560x2560xf32"]) + input_types=["2560x2560xf32", "2560x2560xf32"], +) MATMUL_128X256X8192_FP16_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_128X256X8192_FP16_MLIR, name="matmul_128x256x8192_f16t_tile_config_default", tags=["fp16", "ubench", "matmul", "splitk"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_128x256x8192_f16t_f16t_f16t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_128x256x8192_f16t_f16t_f16t_tile_config_default.mlirbc", entry_function="matmul_128x256x8192_f16t_f16t_f16t_tile_config_default", - input_types=["128x8192xf16", "8192x256xf16"]) + input_types=["128x8192xf16", "8192x256xf16"], +) MATMUL_128X256X8192_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_128X256X8192_FP32_MLIR, name="matmul_128x256x8192_f32t_tile_config_default", tags=["fp32", "ubench", "matmul", "splitk"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_128x256x8192_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_128x256x8192_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_128x256x8192_f32t_f32t_f32t_tile_config_default", - input_types=["128x8192xf32", "8192x256xf32"]) + input_types=["128x8192xf32", "8192x256xf32"], +) MATMUL_2564x2564x2564_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_2564x2564x2564_FP32_MLIR, name="matmul_2564x2564x2564_f32t_f32t_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2564x2564x2564_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2564x2564x2564_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_2564x2564x2564_f32t_f32t_f32t_tile_config_default", - input_types=["2564x2564xf32", "2564x2564xf32", "2564x2564xf32"]) + input_types=["2564x2564xf32", "2564x2564xf32", "2564x2564xf32"], +) MATMUL_2562x2564x2562_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_2562x2564x2562_FP32_MLIR, name="matmul_2562x2564x2562_f32t_f32t_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2562x2564x2562_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2562x2564x2562_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_2562x2564x2562_f32t_f32t_f32t_tile_config_default", - input_types=["2562x2562xf32", "2562x2564xf32", "2562x2564xf32"]) + input_types=["2562x2562xf32", "2562x2564xf32", "2562x2564xf32"], +) MATMUL_2562x2561x2561_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_2562x2561x2561_FP32_MLIR, name="matmul_2562x2561x2561_f32t_f32t_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2562x2561x2561_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230525_1685058259/matmul_2562x2561x2561_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_2562x2561x2561_f32t_f32t_f32t_tile_config_default", - input_types=["2562x2561xf32", "2561x2561xf32", "2562x2561xf32"]) + input_types=["2562x2561xf32", "2561x2561xf32", "2562x2561xf32"], +) MATMUL_123x2561x2561_FP32_MLIR = common_definitions.Model( id=unique_ids.MICRO_MATMUL_123x2561x2561_FP32_MLIR, name="matmul_123x2561x2561_f32t_f32t_f32t_tile_config_default", tags=["fp32", "ubench", "matmul"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230612_1686563210/matmul_123x2561x2561_f32t_f32t_f32t_tile_config_default.mlirbc", + source_url="https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230612_1686563210/matmul_123x2561x2561_f32t_f32t_f32t_tile_config_default.mlirbc", entry_function="matmul_123x2561x2561_f32t_f32t_f32t_tile_config_default", - input_types=["123x2561xf32", "2561x2561xf32", "123x2561xf32"]) + input_types=["123x2561xf32", "2561x2561xf32", "123x2561xf32"], +) diff --git a/build_tools/python/e2e_test_framework/models/model_groups.py b/build_tools/python/e2e_test_framework/models/model_groups.py index d3a0b2a82bb4..197bea637f97 100644 --- a/build_tools/python/e2e_test_framework/models/model_groups.py +++ b/build_tools/python/e2e_test_framework/models/model_groups.py @@ -6,7 +6,13 @@ """Defines the groups of models.""" from e2e_test_framework.definitions import common_definitions -from e2e_test_framework.models import matmul, tflite_models, torch_models, tf_models, jax_models +from e2e_test_framework.models import ( + matmul, + tflite_models, + torch_models, + tf_models, + jax_models, +) # x86 models, single batch. @@ -15,106 +21,145 @@ X86_64_BENCHMARK_CONFIG = [ # Tiny models. common_definitions.CpuBenchmarkConfig( - model=tflite_models.PERSON_DETECT_INT8, threads=[0, 1]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V3SMALL, - threads=[0, 1]), + model=tflite_models.PERSON_DETECT_INT8, threads=[0, 1] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V3SMALL, threads=[0, 1] + ), # Small models. - common_definitions.CpuBenchmarkConfig(model=tflite_models.DEEPLABV3_FP32, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.EFFICIENTNET_INT8, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V1, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V2, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V2_INT8, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILESSD_FP32, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.POSENET_FP32, - threads=[1, 8]), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.DEEPLABV3_FP32, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.EFFICIENTNET_INT8, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V1, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V2, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V2_INT8, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILESSD_FP32, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.POSENET_FP32, threads=[1, 8] + ), # Medium models. # TODO: Add 13 threads once we move to new hardware. - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILEBERT_FP16, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILEBERT_FP32, - threads=[1, 8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILEBERT_INT8, - threads=[1, 8]), common_definitions.CpuBenchmarkConfig( - model=tf_models.EFFICIENTNET_V2_S_FP32, threads=[1, 8]), + model=tflite_models.MOBILEBERT_FP16, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILEBERT_FP32, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILEBERT_INT8, threads=[1, 8] + ), + common_definitions.CpuBenchmarkConfig( + model=tf_models.EFFICIENTNET_V2_S_FP32, threads=[1, 8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.MINILM_L12_H384_UNCASED_INT32_SEQLEN128, threads=[1, - 8]), + model=tf_models.MINILM_L12_H384_UNCASED_INT32_SEQLEN128, threads=[1, 8] + ), common_definitions.CpuBenchmarkConfig( - model=torch_models.EFFICIENTNET_V2_S_FP32_TORCH, threads=[1, 8]), + model=torch_models.EFFICIENTNET_V2_S_FP32_TORCH, threads=[1, 8] + ), # Large models. # TODO: These models should be running at 8, 13, 28 threads but we use 8 for now until new hardware becomes available. common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_FOR_MASKED_LM_FP32_SEQLEN512, threads=[8]), + model=tf_models.BERT_FOR_MASKED_LM_FP32_SEQLEN512, threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_LARGE_TF_FP32_SEQLEN384, threads=[8]), + model=tf_models.BERT_LARGE_TF_FP32_SEQLEN384, threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=torch_models.EFFICIENTNET_B7_FP32_TORCH, threads=[8]), + model=torch_models.EFFICIENTNET_B7_FP32_TORCH, threads=[8] + ), ] # A subset of `x86_64_MODELS_AND_THREADS`. X86_64_BENCHMARK_CONFIG_EXPERIMENTAL = [ # Tiny models. common_definitions.CpuBenchmarkConfig( - model=tflite_models.PERSON_DETECT_INT8, threads=[1]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V3SMALL, - threads=[1]), + model=tflite_models.PERSON_DETECT_INT8, threads=[1] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V3SMALL, threads=[1] + ), # Small models. - common_definitions.CpuBenchmarkConfig(model=tflite_models.DEEPLABV3_FP32, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.EFFICIENTNET_INT8, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V2, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILENET_V2_INT8, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILESSD_FP32, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.POSENET_FP32, - threads=[8]), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.DEEPLABV3_FP32, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.EFFICIENTNET_INT8, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V2, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILENET_V2_INT8, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILESSD_FP32, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.POSENET_FP32, threads=[8] + ), # Medium models. - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILEBERT_FP32, - threads=[8]), - common_definitions.CpuBenchmarkConfig(model=tflite_models.MOBILEBERT_INT8, - threads=[8]), common_definitions.CpuBenchmarkConfig( - model=tf_models.EFFICIENTNET_V2_S_FP32, threads=[8]), + model=tflite_models.MOBILEBERT_FP32, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tflite_models.MOBILEBERT_INT8, threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.MINILM_L12_H384_UNCASED_INT32_SEQLEN128, threads=[8]), + model=tf_models.EFFICIENTNET_V2_S_FP32, threads=[8] + ), + common_definitions.CpuBenchmarkConfig( + model=tf_models.MINILM_L12_H384_UNCASED_INT32_SEQLEN128, threads=[8] + ), # Disabled due to https://github.com/openxla/iree/issues/12772. # common_definitions.CpuBenchmarkConfig(model=torch_models.EFFICIENTNET_V2_S_FP32_TORCH, threads=[8]), # Large models. common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_LARGE_TF_FP32_SEQLEN384, threads=[8]), + model=tf_models.BERT_LARGE_TF_FP32_SEQLEN384, threads=[8] + ), # Disabled due to https://github.com/openxla/iree/issues/12772. # common_definitions.CpuBenchmarkConfig(model=torch_models.EFFICIENTNET_B7_FP32_TORCH, threads=[8]), ] X86_64_BENCHMARK_CONFIG_LONG = [ common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[1], threads=[8]), + model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[1], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[32], threads=[8]), + model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[32], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[64], threads=[8]), + model=tf_models.BERT_LARGE_384_FP32_TF_BATCHES[64], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[1], threads=[8]), + model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[1], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[64], threads=[8]), + model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[64], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[128], threads=[8]), + model=tf_models.RESNET50_3X224X224_FP32_TF_BATCHES[128], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[1], threads=[8]), + model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[1], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[16], threads=[8]), + model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[16], threads=[8] + ), common_definitions.CpuBenchmarkConfig( - model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[32], threads=[8]), + model=tf_models.T5_LARGE_512_FP32_TF_BATCHES[32], threads=[8] + ), ] # Microkernels. @@ -137,28 +182,28 @@ # Batched Torch models. -BERT_LARGE_TORCH_BATCHES = list( - torch_models.BERT_LARGE_384_FP32_TORCH_BATCHES.values()) +BERT_LARGE_TORCH_BATCHES = list(torch_models.BERT_LARGE_384_FP32_TORCH_BATCHES.values()) BERT_LARGE_FP16_TORCH_BATCHES = [ - model for batch_size, model in - torch_models.BERT_LARGE_384_FP16_TORCH_BATCHES.items() + model + for batch_size, model in torch_models.BERT_LARGE_384_FP16_TORCH_BATCHES.items() # Batchsize 1 is included seperately in CUDA_MODELS if batch_size != 1 ] RESNET50_TORCH_BATCHES = list( - torch_models.RESNET50_3X224X224_FP32_TORCH_BATCHES.values()) + torch_models.RESNET50_3X224X224_FP32_TORCH_BATCHES.values() +) RESNET50_FP16_TORCH_BATCHES = list( - torch_models.RESNET50_3X224X224_FP16_TORCH_BATCHES.values()) + torch_models.RESNET50_3X224X224_FP16_TORCH_BATCHES.values() +) # Batched Tensorflow models. BERT_LARGE_TF_BATCHES = list(tf_models.BERT_LARGE_384_FP32_TF_BATCHES.values()) -RESNET50_TF_BATCHES = list( - tf_models.RESNET50_3X224X224_FP32_TF_BATCHES.values()) +RESNET50_TF_BATCHES = list(tf_models.RESNET50_3X224X224_FP32_TF_BATCHES.values()) T5_LARGE_TF_BATCHES = [ model @@ -169,14 +214,11 @@ # Batched JAX models. -RESNET50_JAX_BATCHES = list( - jax_models.RESNET50_FP32_JAX_3X224X224XF32_BATCHES.values()) +RESNET50_JAX_BATCHES = list(jax_models.RESNET50_FP32_JAX_3X224X224XF32_BATCHES.values()) -BERT_LARGE_JAX_BATCHES = list( - jax_models.BERT_LARGE_FP32_JAX_384XI32_BATCHES.values()) +BERT_LARGE_JAX_BATCHES = list(jax_models.BERT_LARGE_FP32_JAX_384XI32_BATCHES.values()) -T5_LARGE_JAX_BATCHES = list( - jax_models.T5_LARGE_FP32_JAX_512XI32_BATCHES.values()) +T5_LARGE_JAX_BATCHES = list(jax_models.T5_LARGE_FP32_JAX_512XI32_BATCHES.values()) # GPU model groups. @@ -192,11 +234,18 @@ torch_models.EFFICIENTNET_V2_S_FP16_TORCH, ] -CUDA_MODELS_LONG = (RESNET50_TF_BATCHES + BERT_LARGE_TF_BATCHES + - T5_LARGE_TF_BATCHES + BERT_LARGE_TORCH_BATCHES + - RESNET50_TORCH_BATCHES + RESNET50_FP16_TORCH_BATCHES + - BERT_LARGE_FP16_TORCH_BATCHES + BERT_LARGE_JAX_BATCHES + - RESNET50_JAX_BATCHES + T5_LARGE_JAX_BATCHES) +CUDA_MODELS_LONG = ( + RESNET50_TF_BATCHES + + BERT_LARGE_TF_BATCHES + + T5_LARGE_TF_BATCHES + + BERT_LARGE_TORCH_BATCHES + + RESNET50_TORCH_BATCHES + + RESNET50_FP16_TORCH_BATCHES + + BERT_LARGE_FP16_TORCH_BATCHES + + BERT_LARGE_JAX_BATCHES + + RESNET50_JAX_BATCHES + + T5_LARGE_JAX_BATCHES +) VULKAN_MODELS = [ torch_models.MODEL_CLIP_TEXT_SEQLEN64_FP32_TORCH, diff --git a/build_tools/python/e2e_test_framework/models/tf_models.py b/build_tools/python/e2e_test_framework/models/tf_models.py index 58ac9cc29498..3cf7f2c8ddb3 100644 --- a/build_tools/python/e2e_test_framework/models/tf_models.py +++ b/build_tools/python/e2e_test_framework/models/tf_models.py @@ -11,7 +11,9 @@ from e2e_test_framework.definitions import common_definitions import e2e_test_framework.models.utils as model_utils -TF_MODELS_MANUAL_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual" +TF_MODELS_MANUAL_ROOT_DIR = ( + "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual" +) MINILM_L12_H384_UNCASED_INT32_SEQLEN128 = common_definitions.Model( id=unique_ids.MODEL_MINILM_L12_H384_UNCASED_INT32_SEQLEN128, @@ -19,10 +21,10 @@ tags=["int32", "seqlen128"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec - source_url= - f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc", entry_function="predict", - input_types=["1x128xi32", "1x128xi32", "1x128xi32"]) + input_types=["1x128xi32", "1x128xi32", "1x128xi32"], +) BERT_FOR_MASKED_LM_FP32_SEQLEN512 = common_definitions.Model( id=unique_ids.MODEL_BERT_FOR_MASKED_LM_FP32_SEQLEN512_TF, @@ -30,10 +32,10 @@ tags=["fp32", "seqlen512", "tensorflow"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#tfbertformaskedlm - source_url= - f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc", entry_function="forward", - input_types=["1x512xi32", "1x512xi32"]) + input_types=["1x512xi32", "1x512xi32"], +) EFFICIENTNET_V2_S_FP32 = common_definitions.Model( id=unique_ids.MODEL_EFFICIENTNET_V2_S_FP32_TF, @@ -41,10 +43,10 @@ tags=["fp32", "cnn", "tensorflow"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/efficientnet_v2.py - source_url= - f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc", entry_function="forward", - input_types=["1x384x384x3xf32"]) + input_types=["1x384x384x3xf32"], +) # This is the model used in the MLPerf Inference Suite. BERT_LARGE_TF_FP32_SEQLEN384 = common_definitions.Model( @@ -54,62 +56,75 @@ source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Derived from https://github.com/mlcommons/inference/tree/master/language/bert # Instructions on how to regenerate the model: https://gist.github.com/mariecwhite/e61ccebd979d98d097946ac7725bcc29 - source_url= - f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc", entry_function="serving_default", - input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + input_types=["1x384xi32", "1x384xi32", "1x384xi32"], +) TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.12.0_1683544084" ID_FORMAT = string.Template("${model_id}-batch-${batch_size}") NAME_FORMAT = string.Template("${name}Batch${batch_size}") SOURCE_URL_FORMAT = string.Template( - TF_MODELS_ROOT_DIR + "/${directory}/batch_${batch_size}/hlo.mlirbc") + TF_MODELS_ROOT_DIR + "/${directory}/batch_${batch_size}/hlo.mlirbc" +) # Derived from https://huggingface.co/docs/transformers/model_doc/bert#transformers.TFBertModel. BERT_LARGE_384_FP32_TF_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP32_TF), - name_template=model_utils.partial_template_substitute(NAME_FORMAT, - name="BertLargeTF"), + ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP32_TF + ), + name_template=model_utils.partial_template_substitute( + NAME_FORMAT, name="BertLargeTF" + ), tags=["fp32", "seqlen384", "tensorflow", "bert-variant"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="BERT_LARGE"), + SOURCE_URL_FORMAT, directory="BERT_LARGE" + ), entry_function="forward", input_type_templates=[ string.Template("${batch_size}x384xi32"), - string.Template("${batch_size}x384xi32") + string.Template("${batch_size}x384xi32"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280], +) # Converted from https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet50/ResNet50 RESNET50_3X224X224_FP32_TF_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP32_TF), - name_template=model_utils.partial_template_substitute(NAME_FORMAT, - name="Resnet50TF"), + ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP32_TF + ), + name_template=model_utils.partial_template_substitute( + NAME_FORMAT, name="Resnet50TF" + ), tags=["fp32", "cnn"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="RESNET50"), + SOURCE_URL_FORMAT, directory="RESNET50" + ), entry_function="forward", input_type_templates=[string.Template("${batch_size}x224x224x3xf32")], - batch_sizes=[1, 8, 64, 128, 256, 2048]) + batch_sizes=[1, 8, 64, 128, 256, 2048], +) # Derived from https://huggingface.co/transformers/v3.0.2/model_doc/t5.html#tft5model. T5_LARGE_512_FP32_TF_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_T5_LARGE_512_FP32_TF), - name_template=model_utils.partial_template_substitute(NAME_FORMAT, - name="T5LargeTF"), + ID_FORMAT, model_id=unique_ids.MODEL_T5_LARGE_512_FP32_TF + ), + name_template=model_utils.partial_template_substitute( + NAME_FORMAT, name="T5LargeTF" + ), tags=["fp32", "seqlen512", "tensorflow"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, source_url_template=model_utils.partial_template_substitute( - SOURCE_URL_FORMAT, directory="T5_LARGE"), + SOURCE_URL_FORMAT, directory="T5_LARGE" + ), entry_function="forward", input_type_templates=[ string.Template("${batch_size}x512xi32"), - string.Template("${batch_size}x512xi32") + string.Template("${batch_size}x512xi32"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512], +) diff --git a/build_tools/python/e2e_test_framework/models/tflite_models.py b/build_tools/python/e2e_test_framework/models/tflite_models.py index 4909c3b29b12..6fdcd79d3e43 100644 --- a/build_tools/python/e2e_test_framework/models/tflite_models.py +++ b/build_tools/python/e2e_test_framework/models/tflite_models.py @@ -14,10 +14,10 @@ tags=["fp32"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/deeplabv3.tflite", entry_function="main", - input_types=["1x257x257x3xf32"]) + input_types=["1x257x257x3xf32"], +) MOBILESSD_FP32 = common_definitions.Model( id=unique_ids.MODEL_MOBILESSD_FP32, @@ -25,10 +25,10 @@ tags=["fp32"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://storage.googleapis.com/download.tensorflow.org/models/tflite/gpu/mobile_ssd_v2_float_coco.tflite - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobile_ssd_v2_float_coco.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobile_ssd_v2_float_coco.tflite", entry_function="main", - input_types=["1x320x320x3xf32"]) + input_types=["1x320x320x3xf32"], +) POSENET_FP32 = common_definitions.Model( id=unique_ids.MODEL_POSENET_FP32, @@ -36,10 +36,10 @@ tags=["fp32"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/default/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/posenet.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/posenet.tflite", entry_function="main", - input_types=["1x353x257x3xf32"]) + input_types=["1x353x257x3xf32"], +) MOBILEBERT_FP32 = common_definitions.Model( id=unique_ids.MODEL_MOBILEBERT_FP32, @@ -47,10 +47,10 @@ tags=["fp32"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/iree/lite-model/mobilebert/fp32/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-float.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-float.tflite", entry_function="main", - input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + input_types=["1x384xi32", "1x384xi32", "1x384xi32"], +) MOBILEBERT_INT8 = common_definitions.Model( id=unique_ids.MODEL_MOBILEBERT_INT8, @@ -58,10 +58,10 @@ tags=["int8"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/iree/lite-model/mobilebert/int8/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilebert-baseline-tf2-quant.tflite", entry_function="main", - input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + input_types=["1x384xi32", "1x384xi32", "1x384xi32"], +) MOBILEBERT_FP16 = common_definitions.Model( id=unique_ids.MODEL_MOBILEBERT_FP16, @@ -69,10 +69,10 @@ tags=["fp16"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilebertsquad.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilebertsquad.tflite", entry_function="main", - input_types=["1x384xi32", "1x384xi32", "1x384xi32"]) + input_types=["1x384xi32", "1x384xi32", "1x384xi32"], +) MOBILENET_V1 = common_definitions.Model( id=unique_ids.MODEL_MOBILENET_V1, @@ -80,10 +80,10 @@ tags=["fp32", "imagenet"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/iree/lite-model/mobilenet_v1_100_224/fp32/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_224_1.0_float.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_224_1.0_float.tflite", entry_function="main", - input_types=["1x224x224x3xf32"]) + input_types=["1x224x224x3xf32"], +) MOBILENET_V2 = common_definitions.Model( id=unique_ids.MODEL_MOBILENET_V2, @@ -91,10 +91,10 @@ tags=["fp32", "imagenet"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/mobilenet_v2_1.0_224.tflite - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224.tflite", entry_function="main", - input_types=["1x224x224x3xf32"]) + input_types=["1x224x224x3xf32"], +) MOBILENET_V3SMALL = common_definitions.Model( id=unique_ids.MODEL_MOBILENET_V3SMALL, @@ -103,10 +103,10 @@ source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/classification/5 # Manually exported to tflite with static batch dimension - source_url= - "https://storage.googleapis.com/iree-model-artifacts/MobileNetV3SmallStaticBatch.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/MobileNetV3SmallStaticBatch.tflite", entry_function="main", - input_types=["1x224x224x3xf32"]) + input_types=["1x224x224x3xf32"], +) PERSON_DETECT_INT8 = common_definitions.Model( id=unique_ids.MODEL_PERSON_DETECT_INT8, @@ -114,10 +114,10 @@ tags=["int8"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/models/person_detect.tflite - source_url= - "https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/person_detect.tflite", entry_function="main", - input_types=["1x96x96x1xi8"]) + input_types=["1x96x96x1xi8"], +) EFFICIENTNET_INT8 = common_definitions.Model( id=unique_ids.MODEL_EFFICIENTNET_INT8, @@ -125,10 +125,10 @@ tags=["int8"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/tensorflow/lite-model/efficientnet/lite0/int8/2 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/efficientnet_lite0_int8_2.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/efficientnet_lite0_int8_2.tflite", entry_function="main", - input_types=["1x224x224x3xui8"]) + input_types=["1x224x224x3xui8"], +) MOBILENET_V2_INT8 = common_definitions.Model( name="MobileNetV2_int8", @@ -136,7 +136,7 @@ tags=["int8", "imagenet"], source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE, # Mirror of https://tfhub.dev/tensorflow/lite-model/mobilenet_v2_1.0_224_quantized/1/default/1 - source_url= - "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224_quantized.tflite", + source_url="https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1.0_224_quantized.tflite", entry_function="main", - input_types=["1x224x224x3xui8"]) + input_types=["1x224x224x3xui8"], +) diff --git a/build_tools/python/e2e_test_framework/models/torch_models.py b/build_tools/python/e2e_test_framework/models/torch_models.py index ce174386c9db..aae2f6d58550 100644 --- a/build_tools/python/e2e_test_framework/models/torch_models.py +++ b/build_tools/python/e2e_test_framework/models/torch_models.py @@ -28,10 +28,10 @@ name="ClipTextSeqLen64PT", tags=["fp32", "seqlen64"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230307.103_1678163233/SD_CLIP_TEXT_MODEL_SEQLEN64/linalg.mlir", + source_url="https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230307.103_1678163233/SD_CLIP_TEXT_MODEL_SEQLEN64/linalg.mlir", entry_function="forward", - input_types=["1x77xi64", "1x77xi64"]) + input_types=["1x77xi64", "1x77xi64"], +) # `Unet2d` consists of `ResNet` encoder and decoder blocks with cross-attention layers. # @@ -51,10 +51,10 @@ name="Unet2dPT", tags=["fp32"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230307.103_1678163233/SD_UNET_MODEL/linalg.mlir", + source_url="https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230307.103_1678163233/SD_UNET_MODEL/linalg.mlir", entry_function="forward", - input_types=["1x4x64x64xf32", "1x77x768xf32"]) + input_types=["1x4x64x64xf32", "1x77x768xf32"], +) # Converted from https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_v2_s.html#torchvision.models.efficientnet_v2_s EFFICIENTNET_V2_S_FP32_TORCH = common_definitions.Model( @@ -62,10 +62,10 @@ name="EfficientNetV2SPT", tags=["fp32", "cnn", "depthwise-conv"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230321.784_1679461251/EFFICIENTNET_V2_S/batch_1/linalg.mlir", + source_url="https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230321.784_1679461251/EFFICIENTNET_V2_S/batch_1/linalg.mlir", entry_function="forward", - input_types=["1x3x384x384xf32"]) + input_types=["1x3x384x384xf32"], +) # FP16 EFFICIENTNET_V2_S_FP16_TORCH = common_definitions.Model( @@ -73,10 +73,10 @@ name="EfficientNetV2Sfp16PT", tags=["fp16", "cnn", "depthwise-conv"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230522.846_1684831160/EFFICIENTNET_V2_S_FP16/batch_1/linalg.mlir", + source_url="https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230522.846_1684831160/EFFICIENTNET_V2_S_FP16/batch_1/linalg.mlir", entry_function="forward", - input_types=["1x3x384x384xf16"]) + input_types=["1x3x384x384xf16"], +) # Converted from https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b7.html#torchvision.models.efficientnet_b7 EFFICIENTNET_B7_FP32_TORCH = common_definitions.Model( @@ -84,10 +84,10 @@ name="EfficientNetB7PT", tags=["fp32", "cnn", "depthwise-conv"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, - source_url= - "https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230321.784_1679461251/EFFICIENTNET_B7/batch_1/linalg.mlir", + source_url="https://storage.googleapis.com/iree-model-artifacts/pytorch/torch_models_20230321.784_1679461251/EFFICIENTNET_B7/batch_1/linalg.mlir", entry_function="forward", - input_types=["1x3x600x600xf32"]) + input_types=["1x3x600x600xf32"], +) ID_FORMAT = string.Template("${model_id}-batch-${batch_size}") NAME_FORMAT = string.Template("${name}Batch${batch_size}") @@ -107,57 +107,69 @@ # Converted from https://huggingface.co/docs/transformers/v4.27.2/en/model_doc/bert#transformers.BertModel BERT_LARGE_384_FP32_TORCH_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP32_TORCH), - name_template=model_utils.partial_template_substitute(NAME_FORMAT, - name="BertLargePT"), + ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP32_TORCH + ), + name_template=model_utils.partial_template_substitute( + NAME_FORMAT, name="BertLargePT" + ), tags=["fp32", "transformer", "seqlen384"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, source_url_template=BERT_LARGE_FP32_URL, entry_function="forward", input_type_templates=[ string.Template("${batch_size}x384xi64"), - string.Template("${batch_size}x384xi64") + string.Template("${batch_size}x384xi64"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280], +) # FP16 Versions BERT_LARGE_384_FP16_TORCH_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP16_TORCH), + ID_FORMAT, model_id=unique_ids.MODEL_BERT_LARGE_384_FP16_TORCH + ), name_template=model_utils.partial_template_substitute( - NAME_FORMAT, name="BertLargefp16PT"), + NAME_FORMAT, name="BertLargefp16PT" + ), tags=["fp16", "transformer", "seqlen384"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, source_url_template=BERT_LARGE_FP16_URL, entry_function="forward", input_type_templates=[ string.Template("${batch_size}x384xi64"), - string.Template("${batch_size}x384xi64") + string.Template("${batch_size}x384xi64"), ], - batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280]) + batch_sizes=[1, 16, 24, 32, 48, 64, 512, 1024, 1280], +) # Converted from https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html RESNET50_3X224X224_FP32_TORCH_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP32_TORCH), - name_template=model_utils.partial_template_substitute(NAME_FORMAT, - name="Resnet50PT"), + ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP32_TORCH + ), + name_template=model_utils.partial_template_substitute( + NAME_FORMAT, name="Resnet50PT" + ), tags=["fp32", "cnn"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, source_url_template=RESNET50_FP32_URL, entry_function="forward", input_type_templates=[string.Template("${batch_size}x3x224x224xf32")], - batch_sizes=[1, 8, 64, 128, 256, 2048]) + batch_sizes=[1, 8, 64, 128, 256, 2048], +) # FP16 Versions RESNET50_3X224X224_FP16_TORCH_BATCHES = model_utils.generate_batch_models( id_template=model_utils.partial_template_substitute( - ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP16_TORCH), + ID_FORMAT, model_id=unique_ids.MODEL_RESNET50_3X224X224_FP16_TORCH + ), name_template=model_utils.partial_template_substitute( - NAME_FORMAT, name="Resnet50fp16PT"), + NAME_FORMAT, name="Resnet50fp16PT" + ), tags=["fp32", "cnn"], source_type=common_definitions.ModelSourceType.EXPORTED_LINALG_MLIR, source_url_template=RESNET50_FP16_URL, entry_function="forward", input_type_templates=[string.Template("${batch_size}x3x224x224xf16")], - batch_sizes=[1, 8, 64, 128, 256, 2048]) + batch_sizes=[1, 8, 64, 128, 256, 2048], +) diff --git a/build_tools/python/e2e_test_framework/models/utils.py b/build_tools/python/e2e_test_framework/models/utils.py index 654b2100fae4..53f96b38d890 100644 --- a/build_tools/python/e2e_test_framework/models/utils.py +++ b/build_tools/python/e2e_test_framework/models/utils.py @@ -11,10 +11,11 @@ from e2e_test_framework.definitions import common_definitions -def partial_template_substitute(template: string.Template, - **substitutions) -> string.Template: - """Partially substitutes keywords in the template and returns a template.""" - return string.Template(template.safe_substitute(**substitutions)) +def partial_template_substitute( + template: string.Template, **substitutions +) -> string.Template: + """Partially substitutes keywords in the template and returns a template.""" + return string.Template(template.safe_substitute(**substitutions)) def generate_batch_models( @@ -27,28 +28,29 @@ def generate_batch_models( input_type_templates: Sequence[string.Template], batch_sizes: Sequence[int], ) -> Dict[int, common_definitions.Model]: - """Generate model definitions for different batch sizes by substituting - ${batch_size}` in the template strings. - - Only `*_template` parameters will be treated as templates and substituted. A - `batch-` tag will be appended to the tags in each returned model. - - Returns: - Map of batch size to model. - """ - model_map = {} - for batch_size in batch_sizes: - substituted_input_types = [ - input_type.substitute(batch_size=batch_size) - for input_type in input_type_templates - ] - model_map[batch_size] = common_definitions.Model( - id=id_template.substitute(batch_size=batch_size), - name=name_template.substitute(batch_size=batch_size), - tags=list(tags) + [f"batch-{batch_size}"], - source_type=source_type, - source_url=source_url_template.substitute(batch_size=batch_size), - entry_function=entry_function, - input_types=substituted_input_types) - - return model_map + """Generate model definitions for different batch sizes by substituting + ${batch_size}` in the template strings. + + Only `*_template` parameters will be treated as templates and substituted. A + `batch-` tag will be appended to the tags in each returned model. + + Returns: + Map of batch size to model. + """ + model_map = {} + for batch_size in batch_sizes: + substituted_input_types = [ + input_type.substitute(batch_size=batch_size) + for input_type in input_type_templates + ] + model_map[batch_size] = common_definitions.Model( + id=id_template.substitute(batch_size=batch_size), + name=name_template.substitute(batch_size=batch_size), + tags=list(tags) + [f"batch-{batch_size}"], + source_type=source_type, + source_url=source_url_template.substitute(batch_size=batch_size), + entry_function=entry_function, + input_types=substituted_input_types, + ) + + return model_map diff --git a/build_tools/python/e2e_test_framework/models/utils_test.py b/build_tools/python/e2e_test_framework/models/utils_test.py index bf4011d14c48..6d6034d4a80a 100644 --- a/build_tools/python/e2e_test_framework/models/utils_test.py +++ b/build_tools/python/e2e_test_framework/models/utils_test.py @@ -12,71 +12,75 @@ class UtilsTest(unittest.TestCase): + def test_partial_template_substitute(self): + template = string.Template("${name}-${batch_size}") - def test_partial_template_substitute(self): - template = string.Template("${name}-${batch_size}") + result = model_utils.partial_template_substitute(template, name="xyz") - result = model_utils.partial_template_substitute(template, name="xyz") + self.assertEqual(result.substitute(batch_size=10), "xyz-10") - self.assertEqual(result.substitute(batch_size=10), "xyz-10") - - def test_generate_batch_models(self): - models = model_utils.generate_batch_models( - id_template=string.Template("1234-${batch_size}"), - name_template=string.Template("model-batch-${batch_size}"), - tags=["abc"], - source_url_template=string.Template( - "https://example.com/x/${batch_size}.mlir"), - source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, - entry_function="forward", - input_type_templates=[ - string.Template("${batch_size}x128"), - string.Template("${batch_size}x256") - ], - batch_sizes=[1, 4]) + def test_generate_batch_models(self): + models = model_utils.generate_batch_models( + id_template=string.Template("1234-${batch_size}"), + name_template=string.Template("model-batch-${batch_size}"), + tags=["abc"], + source_url_template=string.Template( + "https://example.com/x/${batch_size}.mlir" + ), + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + entry_function="forward", + input_type_templates=[ + string.Template("${batch_size}x128"), + string.Template("${batch_size}x256"), + ], + batch_sizes=[1, 4], + ) - self.assertEqual( - models, { - 1: - common_definitions.Model( + self.assertEqual( + models, + { + 1: common_definitions.Model( id="1234-1", name="model-batch-1", tags=["abc", "batch-1"], source_url="https://example.com/x/1.mlir", - source_type=common_definitions.ModelSourceType. - EXPORTED_STABLEHLO_MLIR, + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, entry_function="forward", - input_types=["1x128", "1x256"]), - 4: - common_definitions.Model( + input_types=["1x128", "1x256"], + ), + 4: common_definitions.Model( id="1234-4", name="model-batch-4", tags=["abc", "batch-4"], source_url="https://example.com/x/4.mlir", - source_type=common_definitions.ModelSourceType. - EXPORTED_STABLEHLO_MLIR, + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, entry_function="forward", - input_types=["4x128", "4x256"]) - }) + input_types=["4x128", "4x256"], + ), + }, + ) - def test_generate_batch_models_missing_substitution(self): - id_template_with_unknown = string.Template("1234-${unknown}-${batch_size}") + def test_generate_batch_models_missing_substitution(self): + id_template_with_unknown = string.Template("1234-${unknown}-${batch_size}") - self.assertRaises( - KeyError, lambda: model_utils.generate_batch_models( - id_template=id_template_with_unknown, - name_template=string.Template("model-batch-${batch_size}"), - tags=["abc"], - source_url_template=string.Template( - "https://example.com/x/${batch_size}.mlir"), - source_type=common_definitions.ModelSourceType. - EXPORTED_STABLEHLO_MLIR, - entry_function="forward", - input_type_templates=[ - string.Template("${batch_size}x128"), - ], - batch_sizes=[1, 4])) + self.assertRaises( + KeyError, + lambda: model_utils.generate_batch_models( + id_template=id_template_with_unknown, + name_template=string.Template("model-batch-${batch_size}"), + tags=["abc"], + source_url_template=string.Template( + "https://example.com/x/${batch_size}.mlir" + ), + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + entry_function="forward", + input_type_templates=[ + string.Template("${batch_size}x128"), + ], + batch_sizes=[1, 4], + ), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/serialization.py b/build_tools/python/e2e_test_framework/serialization.py index bc430c72e9eb..5cd3d358f583 100644 --- a/build_tools/python/e2e_test_framework/serialization.py +++ b/build_tools/python/e2e_test_framework/serialization.py @@ -18,227 +18,234 @@ SUPPORTED_PRIMITIVE_TYPES = {str, int, float, bool, NONE_TYPE} -def serialize_and_pack(obj, - root_obj_field_name="root_obj", - keyed_obj_map_field_name="keyed_obj_map"): - """Converts and packs the object into a serializable object. - - Args: - obj: object to be serialized. - root_obj_field_name: field name of the top-level object in the return dict. - keyed_obj_map_field_name: field name of the keyed object map in the return - dict. - Returns - A serializable dict. - """ - - if root_obj_field_name == keyed_obj_map_field_name: - raise ValueError( - f"root_obj and keyed_obj_map can't have the same field name.") - - keyed_obj_map = {} - root_obj = _serialize(obj=obj, keyed_obj_map=keyed_obj_map) - return { - root_obj_field_name: root_obj, - keyed_obj_map_field_name: keyed_obj_map - } - - -T = TypeVar('T') - - -def unpack_and_deserialize(data, - root_type: Type[T], - root_obj_field_name="root_obj", - keyed_obj_map_field_name="keyed_obj_map") -> T: - """Unpacks and deserializes the data back to the typed object. - - Args: - data: serialized data dict. - root_type: top-level object type of the data. - root_obj_field_name: field name of the top-level object in the dict. - keyed_obj_map_field_name: field name of the keyed object map in the dict. - Returns: - A deserialized object. - """ - obj = _deserialize(data=data[root_obj_field_name], - obj_type=root_type, - keyed_obj_map=data[keyed_obj_map_field_name]) - return typing.cast(root_type, obj) +def serialize_and_pack( + obj, root_obj_field_name="root_obj", keyed_obj_map_field_name="keyed_obj_map" +): + """Converts and packs the object into a serializable object. + + Args: + obj: object to be serialized. + root_obj_field_name: field name of the top-level object in the return dict. + keyed_obj_map_field_name: field name of the keyed object map in the return + dict. + Returns + A serializable dict. + """ + + if root_obj_field_name == keyed_obj_map_field_name: + raise ValueError(f"root_obj and keyed_obj_map can't have the same field name.") + + keyed_obj_map = {} + root_obj = _serialize(obj=obj, keyed_obj_map=keyed_obj_map) + return {root_obj_field_name: root_obj, keyed_obj_map_field_name: keyed_obj_map} + + +T = TypeVar("T") + + +def unpack_and_deserialize( + data, + root_type: Type[T], + root_obj_field_name="root_obj", + keyed_obj_map_field_name="keyed_obj_map", +) -> T: + """Unpacks and deserializes the data back to the typed object. + + Args: + data: serialized data dict. + root_type: top-level object type of the data. + root_obj_field_name: field name of the top-level object in the dict. + keyed_obj_map_field_name: field name of the keyed object map in the dict. + Returns: + A deserialized object. + """ + obj = _deserialize( + data=data[root_obj_field_name], + obj_type=root_type, + keyed_obj_map=data[keyed_obj_map_field_name], + ) + return typing.cast(root_type, obj) def _serialize(obj, keyed_obj_map: Dict[str, Any]): - """Converts the object into a serializable object. - - Args: - obj: object to be serialized. - keyed_obj_map: mutable container to store the keyed serializable object. - Returns - A serializable object. - """ - - serialize_func = getattr(obj, SERIALIZE_FUNC_NAME, None) - if serialize_func is not None: - return serialize_func(keyed_obj_map) - - elif isinstance(obj, list): - return [_serialize(value, keyed_obj_map) for value in obj] - - elif isinstance(obj, Enum): - return obj.name - - elif isinstance(obj, dict): - result_dict = {} - for key, value in obj.items(): - if type(key) not in SUPPORTED_DICT_KEY_TYPES: - raise ValueError(f"Unsupported key {key} in the dict {obj}.") - result_dict[key] = _serialize(value, keyed_obj_map) - return result_dict - - elif type(obj) in SUPPORTED_PRIMITIVE_TYPES: - return obj - - raise ValueError(f"Unsupported object: {obj}.") - - -def _deserialize(data, - obj_type: Type, - keyed_obj_map: Dict[str, Any], - obj_cache: Dict[str, Any] = {}): - """Deserializes the data back to the typed object. - - Args: - data: serialized data. - obj_type: type of the data. - keyed_obj_map: container of the keyed serializable object. - Returns: - A deserialized object. - """ - - deserialize_func = getattr(obj_type, DESERIALIZE_FUNC_NAME, None) - if deserialize_func is not None: - return deserialize_func(data, keyed_obj_map, obj_cache) - - elif typing.get_origin(obj_type) == list: - subtype, = typing.get_args(obj_type) - return [ - _deserialize(item, subtype, keyed_obj_map, obj_cache) for item in data - ] - - elif typing.get_origin(obj_type) == dict: - _, value_type = typing.get_args(obj_type) - return dict((key, _deserialize(value, value_type, keyed_obj_map, obj_cache)) - for key, value in data.items()) - - elif typing.get_origin(obj_type) == Union: - subtypes = typing.get_args(obj_type) - if len(subtypes) != 2 or NONE_TYPE not in subtypes: - raise ValueError(f"Unsupported union type: {obj_type}.") - subtype = subtypes[0] if subtypes[1] == NONE_TYPE else subtypes[1] - return _deserialize(data, subtype, keyed_obj_map, obj_cache) - - elif issubclass(obj_type, Enum): - for member in obj_type: - if data == member.name: - return member - raise ValueError(f"Member {data} not found in the enum {obj_type}.") - - return data - - -def serializable(cls=None, - type_key: Optional[str] = None, - id_field: str = "id"): - """Decorator to make a dataclass serializable. - - Args: - type_key: string defines the object type and indeicates that the class is a - keyed object, which is unique per id and will only have one copy in the - serialization per id. - id_field: field name of the id field of a keyed object. - - Example: - @serializable - @dataclass - class A(object): - ... - - @serialzable(type_key="obj_b") - @dataclass - class B(object): - id: str - """ - - if type_key is not None and ":" in type_key: - raise ValueError("':' is the reserved character in type_key.") - - def wrap(cls): - if not dataclasses.is_dataclass(cls): - raise ValueError(f"{cls} is not a dataclass.") - - fields = dataclasses.fields(cls) - if type_key is not None and all(field.name != id_field for field in fields): - raise ValueError(f'Id field "{id_field}" not found in the class {cls}.') - - def serialize(self, keyed_obj_map: Dict[str, Any]): - if type_key is None: - return _fields_to_dict(self, fields, keyed_obj_map) - - obj_id = getattr(self, id_field) - obj_key = f"{type_key}:{obj_id}" - if obj_key in keyed_obj_map: - # If the value in the map is None, it means we have visited this object - # before but not yet finished serializing it. This will only happen if - # there is a circular reference. - if keyed_obj_map[obj_key] is None: - raise ValueError(f"Circular reference is not supported: {obj_key}.") - return obj_id - - # Populate the keyed_obj_map with None first to detect circular reference. - keyed_obj_map[obj_key] = None - obj_dict = _fields_to_dict(self, fields, keyed_obj_map) - keyed_obj_map[obj_key] = obj_dict - return obj_id - - def deserialize(data, keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, - Any]): - if type_key is None: - field_value_map = _dict_to_fields(data, fields, keyed_obj_map, - obj_cache) - return cls(**field_value_map) - - obj_id = data - obj_key = f"{type_key}:{obj_id}" - if obj_key in obj_cache: - return obj_cache[obj_key] - - field_value_map = _dict_to_fields(keyed_obj_map[obj_key], fields, - keyed_obj_map, obj_cache) - derialized_obj = cls(**field_value_map) - obj_cache[obj_key] = derialized_obj - return derialized_obj - - setattr(cls, SERIALIZE_FUNC_NAME, serialize) - setattr(cls, DESERIALIZE_FUNC_NAME, deserialize) - return cls - - # Trick to allow the decoration with `@serializable(...)`. In that case, - # `serializable` is called without cls and should return a decorator. - if cls is None: - return wrap - return wrap(cls) - - -def _fields_to_dict(obj, fields: Sequence[dataclasses.Field], - keyed_obj_map: Dict[str, Any]) -> Dict[str, Any]: - return dict((field.name, _serialize(getattr(obj, field.name), keyed_obj_map)) - for field in fields) - - -def _dict_to_fields(obj_dict, fields: Sequence[dataclasses.Field], - keyed_obj_map: Dict[str, Any], - obj_cache: Dict[str, Any]) -> Dict[str, Any]: - return dict( - (field.name, - _deserialize(obj_dict[field.name], field.type, keyed_obj_map, obj_cache)) - for field in fields) + """Converts the object into a serializable object. + + Args: + obj: object to be serialized. + keyed_obj_map: mutable container to store the keyed serializable object. + Returns + A serializable object. + """ + + serialize_func = getattr(obj, SERIALIZE_FUNC_NAME, None) + if serialize_func is not None: + return serialize_func(keyed_obj_map) + + elif isinstance(obj, list): + return [_serialize(value, keyed_obj_map) for value in obj] + + elif isinstance(obj, Enum): + return obj.name + + elif isinstance(obj, dict): + result_dict = {} + for key, value in obj.items(): + if type(key) not in SUPPORTED_DICT_KEY_TYPES: + raise ValueError(f"Unsupported key {key} in the dict {obj}.") + result_dict[key] = _serialize(value, keyed_obj_map) + return result_dict + + elif type(obj) in SUPPORTED_PRIMITIVE_TYPES: + return obj + + raise ValueError(f"Unsupported object: {obj}.") + + +def _deserialize( + data, obj_type: Type, keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, Any] = {} +): + """Deserializes the data back to the typed object. + + Args: + data: serialized data. + obj_type: type of the data. + keyed_obj_map: container of the keyed serializable object. + Returns: + A deserialized object. + """ + + deserialize_func = getattr(obj_type, DESERIALIZE_FUNC_NAME, None) + if deserialize_func is not None: + return deserialize_func(data, keyed_obj_map, obj_cache) + + elif typing.get_origin(obj_type) == list: + (subtype,) = typing.get_args(obj_type) + return [_deserialize(item, subtype, keyed_obj_map, obj_cache) for item in data] + + elif typing.get_origin(obj_type) == dict: + _, value_type = typing.get_args(obj_type) + return dict( + (key, _deserialize(value, value_type, keyed_obj_map, obj_cache)) + for key, value in data.items() + ) + + elif typing.get_origin(obj_type) == Union: + subtypes = typing.get_args(obj_type) + if len(subtypes) != 2 or NONE_TYPE not in subtypes: + raise ValueError(f"Unsupported union type: {obj_type}.") + subtype = subtypes[0] if subtypes[1] == NONE_TYPE else subtypes[1] + return _deserialize(data, subtype, keyed_obj_map, obj_cache) + + elif issubclass(obj_type, Enum): + for member in obj_type: + if data == member.name: + return member + raise ValueError(f"Member {data} not found in the enum {obj_type}.") + + return data + + +def serializable(cls=None, type_key: Optional[str] = None, id_field: str = "id"): + """Decorator to make a dataclass serializable. + + Args: + type_key: string defines the object type and indeicates that the class is a + keyed object, which is unique per id and will only have one copy in the + serialization per id. + id_field: field name of the id field of a keyed object. + + Example: + @serializable + @dataclass + class A(object): + ... + + @serialzable(type_key="obj_b") + @dataclass + class B(object): + id: str + """ + + if type_key is not None and ":" in type_key: + raise ValueError("':' is the reserved character in type_key.") + + def wrap(cls): + if not dataclasses.is_dataclass(cls): + raise ValueError(f"{cls} is not a dataclass.") + + fields = dataclasses.fields(cls) + if type_key is not None and all(field.name != id_field for field in fields): + raise ValueError(f'Id field "{id_field}" not found in the class {cls}.') + + def serialize(self, keyed_obj_map: Dict[str, Any]): + if type_key is None: + return _fields_to_dict(self, fields, keyed_obj_map) + + obj_id = getattr(self, id_field) + obj_key = f"{type_key}:{obj_id}" + if obj_key in keyed_obj_map: + # If the value in the map is None, it means we have visited this object + # before but not yet finished serializing it. This will only happen if + # there is a circular reference. + if keyed_obj_map[obj_key] is None: + raise ValueError(f"Circular reference is not supported: {obj_key}.") + return obj_id + + # Populate the keyed_obj_map with None first to detect circular reference. + keyed_obj_map[obj_key] = None + obj_dict = _fields_to_dict(self, fields, keyed_obj_map) + keyed_obj_map[obj_key] = obj_dict + return obj_id + + def deserialize(data, keyed_obj_map: Dict[str, Any], obj_cache: Dict[str, Any]): + if type_key is None: + field_value_map = _dict_to_fields( + data, fields, keyed_obj_map, obj_cache + ) + return cls(**field_value_map) + + obj_id = data + obj_key = f"{type_key}:{obj_id}" + if obj_key in obj_cache: + return obj_cache[obj_key] + + field_value_map = _dict_to_fields( + keyed_obj_map[obj_key], fields, keyed_obj_map, obj_cache + ) + derialized_obj = cls(**field_value_map) + obj_cache[obj_key] = derialized_obj + return derialized_obj + + setattr(cls, SERIALIZE_FUNC_NAME, serialize) + setattr(cls, DESERIALIZE_FUNC_NAME, deserialize) + return cls + + # Trick to allow the decoration with `@serializable(...)`. In that case, + # `serializable` is called without cls and should return a decorator. + if cls is None: + return wrap + return wrap(cls) + + +def _fields_to_dict( + obj, fields: Sequence[dataclasses.Field], keyed_obj_map: Dict[str, Any] +) -> Dict[str, Any]: + return dict( + (field.name, _serialize(getattr(obj, field.name), keyed_obj_map)) + for field in fields + ) + + +def _dict_to_fields( + obj_dict, + fields: Sequence[dataclasses.Field], + keyed_obj_map: Dict[str, Any], + obj_cache: Dict[str, Any], +) -> Dict[str, Any]: + return dict( + ( + field.name, + _deserialize(obj_dict[field.name], field.type, keyed_obj_map, obj_cache), + ) + for field in fields + ) diff --git a/build_tools/python/e2e_test_framework/serialization_test.py b/build_tools/python/e2e_test_framework/serialization_test.py index 2d1077586585..e4fcca41fb00 100644 --- a/build_tools/python/e2e_test_framework/serialization_test.py +++ b/build_tools/python/e2e_test_framework/serialization_test.py @@ -16,141 +16,162 @@ class EnumX(enum.Enum): - OPTION_A = "a" - OPTION_B = "b" - OPTION_C = "c" + OPTION_A = "a" + OPTION_B = "b" + OPTION_C = "c" @serialization.serializable @dataclass class TestC(object): - float_val: float + float_val: float @serialization.serializable(type_key="test_b", id_field="key") @dataclass class TestB(object): - key: str - int_val: int + key: str + int_val: int @serialization.serializable @dataclass class TestA(object): - b_list: List[TestB] - c_obj: TestC - str_val: Optional[str] - enum_val: EnumX + b_list: List[TestB] + c_obj: TestC + str_val: Optional[str] + enum_val: EnumX @serialization.serializable @dataclass class TestUnsupported(object): - path: pathlib.PurePath + path: pathlib.PurePath @serialization.serializable(type_key="test_circular") @dataclass class TestCircularReference(object): - id: str - child: Optional["TestCircularReference"] + id: str + child: Optional["TestCircularReference"] class SerializationTest(unittest.TestCase): - - def test_serialize_and_pack(self): - b_obj_a = TestB(key="id_a", int_val=10) - b_obj_b = TestB(key="id_b", int_val=20) - test_objs = [ - TestA(b_list=[b_obj_a, b_obj_b], - c_obj=TestC(float_val=0.1), - str_val="test1", - enum_val=EnumX.OPTION_B), - TestA(b_list=[b_obj_a], - c_obj=TestC(float_val=0.2), - str_val=None, - enum_val=EnumX.OPTION_C) - ] - - results = serialization.serialize_and_pack( - test_objs, - root_obj_field_name="main_obj", - keyed_obj_map_field_name="obj_map") - - self.maxDiff = None - self.assertEqual( - results, { - "main_obj": [ - dict(b_list=["id_a", "id_b"], - c_obj=dict(float_val=0.1), - str_val="test1", - enum_val="OPTION_B"), - dict(b_list=["id_a"], - c_obj=dict(float_val=0.2), - str_val=None, - enum_val="OPTION_C") - ], - "obj_map": { - "test_b:id_a": dict(key="id_a", int_val=10), - "test_b:id_b": dict(key="id_b", int_val=20) - } - }) - - def test_serialize_and_pack_with_unsupported_type(self): - self.assertRaises( - ValueError, lambda: serialization.serialize_and_pack( - TestUnsupported(path=pathlib.PurePath("abc")))) - - def test_serialize_and_pack_with_unsupported_dict_key(self): - self.assertRaises( - ValueError, lambda: serialization.serialize_and_pack({(0, 0): "test"})) - - def test_serialize_and_pack_with_circular_reference(self): - obj_a = TestCircularReference(id="0", child=None) - obj_b = TestCircularReference(id="1", child=obj_a) - obj_a.child = obj_b - - self.assertRaises(ValueError, - lambda: serialization.serialize_and_pack(obj_a)) - - def test_roundtrip(self): - b_obj_a = TestB(key="id_a", int_val=10) - b_obj_b = TestB(key="id_b", int_val=20) - test_objs = [ - TestA(b_list=[b_obj_a, b_obj_b], - c_obj=TestC(float_val=0.1), - str_val="test1", - enum_val=EnumX.OPTION_B), - TestA(b_list=[b_obj_a], - c_obj=TestC(float_val=0.2), - str_val=None, - enum_val=EnumX.OPTION_C), - TestA(b_list=[b_obj_b], - c_obj=TestC(float_val=0.3), - str_val="test3", - enum_val=EnumX.OPTION_A), - ] - - results = serialization.unpack_and_deserialize( - serialization.serialize_and_pack(test_objs), typing.List[TestA]) - - self.assertEqual(results, test_objs) - - def test_roundtrip_with_json(self): - b_obj_a = TestB(key="id_a", int_val=10) - b_obj_b = TestB(key="id_b", int_val=20) - - objs = { - "x": b_obj_a, - "y": b_obj_b, - } - - json_str = json.dumps(serialization.serialize_and_pack(objs)) - results = serialization.unpack_and_deserialize(json.loads(json_str), - typing.Dict[str, TestB]) - - self.assertEqual(results, objs) + def test_serialize_and_pack(self): + b_obj_a = TestB(key="id_a", int_val=10) + b_obj_b = TestB(key="id_b", int_val=20) + test_objs = [ + TestA( + b_list=[b_obj_a, b_obj_b], + c_obj=TestC(float_val=0.1), + str_val="test1", + enum_val=EnumX.OPTION_B, + ), + TestA( + b_list=[b_obj_a], + c_obj=TestC(float_val=0.2), + str_val=None, + enum_val=EnumX.OPTION_C, + ), + ] + + results = serialization.serialize_and_pack( + test_objs, + root_obj_field_name="main_obj", + keyed_obj_map_field_name="obj_map", + ) + + self.maxDiff = None + self.assertEqual( + results, + { + "main_obj": [ + dict( + b_list=["id_a", "id_b"], + c_obj=dict(float_val=0.1), + str_val="test1", + enum_val="OPTION_B", + ), + dict( + b_list=["id_a"], + c_obj=dict(float_val=0.2), + str_val=None, + enum_val="OPTION_C", + ), + ], + "obj_map": { + "test_b:id_a": dict(key="id_a", int_val=10), + "test_b:id_b": dict(key="id_b", int_val=20), + }, + }, + ) + + def test_serialize_and_pack_with_unsupported_type(self): + self.assertRaises( + ValueError, + lambda: serialization.serialize_and_pack( + TestUnsupported(path=pathlib.PurePath("abc")) + ), + ) + + def test_serialize_and_pack_with_unsupported_dict_key(self): + self.assertRaises( + ValueError, lambda: serialization.serialize_and_pack({(0, 0): "test"}) + ) + + def test_serialize_and_pack_with_circular_reference(self): + obj_a = TestCircularReference(id="0", child=None) + obj_b = TestCircularReference(id="1", child=obj_a) + obj_a.child = obj_b + + self.assertRaises(ValueError, lambda: serialization.serialize_and_pack(obj_a)) + + def test_roundtrip(self): + b_obj_a = TestB(key="id_a", int_val=10) + b_obj_b = TestB(key="id_b", int_val=20) + test_objs = [ + TestA( + b_list=[b_obj_a, b_obj_b], + c_obj=TestC(float_val=0.1), + str_val="test1", + enum_val=EnumX.OPTION_B, + ), + TestA( + b_list=[b_obj_a], + c_obj=TestC(float_val=0.2), + str_val=None, + enum_val=EnumX.OPTION_C, + ), + TestA( + b_list=[b_obj_b], + c_obj=TestC(float_val=0.3), + str_val="test3", + enum_val=EnumX.OPTION_A, + ), + ] + + results = serialization.unpack_and_deserialize( + serialization.serialize_and_pack(test_objs), typing.List[TestA] + ) + + self.assertEqual(results, test_objs) + + def test_roundtrip_with_json(self): + b_obj_a = TestB(key="id_a", int_val=10) + b_obj_b = TestB(key="id_b", int_val=20) + + objs = { + "x": b_obj_a, + "y": b_obj_b, + } + + json_str = json.dumps(serialization.serialize_and_pack(objs)) + results = serialization.unpack_and_deserialize( + json.loads(json_str), typing.Dict[str, TestB] + ) + + self.assertEqual(results, objs) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/e2e_test_framework/unique_ids.py b/build_tools/python/e2e_test_framework/unique_ids.py index 7589f4fb8893..ef3778730401 100644 --- a/build_tools/python/e2e_test_framework/unique_ids.py +++ b/build_tools/python/e2e_test_framework/unique_ids.py @@ -24,40 +24,38 @@ def hash_composite_id(keys: Sequence[str]) -> str: - """Computes the composite hash id from string keys. - - String keys are the component ids that compose this composite object. We hash - the composite id since the id isn't designed to be inspected and insufficient - to reconstruct the original composite object. - - Note that the output is sensitive to the order of the keys, and any key == - TRANSPARENT_ID will be skipped. When adding a new key to the keys, the new key - should be always appended to the end. In this way, the composite id can be - unchanged for the existing composite object if they use TRANSPARENT_ID on the - new keyed field. - - The composite id is computed in the following steps: - 1. Index each key with its position in the list from 0. - 2. Remove any key == TRANSPARENT_ID - 3. Get the SHA256 hex digest of "0-key_0:1-key_1:..." - - Step 1 is needed to avoid the ambiguity between: - ["key_abc", TRANSPARENT_ID] and [TRANSPARENT_ID, "key_abc"] - since after removing TRANSPARENT_ID, they both become ["key_abc"] without the - position index. - - Args: - keys: list of string keys. - - Returns: - Unique composite id. - """ - trimmed_indexed_key = [ - f"{index}-{key}" for index, key in enumerate(keys) - if key != TRANSPARENT_ID - ] - return hashlib.sha256( - ":".join(trimmed_indexed_key).encode("utf-8")).hexdigest() + """Computes the composite hash id from string keys. + + String keys are the component ids that compose this composite object. We hash + the composite id since the id isn't designed to be inspected and insufficient + to reconstruct the original composite object. + + Note that the output is sensitive to the order of the keys, and any key == + TRANSPARENT_ID will be skipped. When adding a new key to the keys, the new key + should be always appended to the end. In this way, the composite id can be + unchanged for the existing composite object if they use TRANSPARENT_ID on the + new keyed field. + + The composite id is computed in the following steps: + 1. Index each key with its position in the list from 0. + 2. Remove any key == TRANSPARENT_ID + 3. Get the SHA256 hex digest of "0-key_0:1-key_1:..." + + Step 1 is needed to avoid the ambiguity between: + ["key_abc", TRANSPARENT_ID] and [TRANSPARENT_ID, "key_abc"] + since after removing TRANSPARENT_ID, they both become ["key_abc"] without the + position index. + + Args: + keys: list of string keys. + + Returns: + Unique composite id. + """ + trimmed_indexed_key = [ + f"{index}-{key}" for index, key in enumerate(keys) if key != TRANSPARENT_ID + ] + return hashlib.sha256(":".join(trimmed_indexed_key).encode("utf-8")).hexdigest() # To generate an id, run `uuid.uuid4()`. @@ -141,33 +139,69 @@ def hash_composite_id(keys: Sequence[str]) -> str: # IREE benchmarks IREE_COMPILE_CONFIG_VMVX_GENERIC_EXPERIMENTAL = "75336abd-8108-462c-9ce3-15443e3f32f4" IREE_COMPILE_CONFIG_LINUX_CASCADELAKE = "e7e18b0f-c72d-4f1c-89b1-5afee70df6e9" -IREE_COMPILE_CONFIG_LINUX_CASCADELAKE_FUSE_PADDING = "6d0d5716-5525-44ad-b71d-8075ee1583a6" +IREE_COMPILE_CONFIG_LINUX_CASCADELAKE_FUSE_PADDING = ( + "6d0d5716-5525-44ad-b71d-8075ee1583a6" +) IREE_COMPILE_CONFIG_LINUX_RV64_GENERIC_DEFAULTS = "cdf579a9-5446-403b-a991-802a6c702e65" IREE_COMPILE_CONFIG_LINUX_RV32_GENERIC_DEFAULTS = "6d9ce240-ec14-4d8f-a8e4-1b20aa17b4e4" IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_DEFAULTS = "09cb5300-7f73-45cf-9f68-e114c77ca030" -IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_UBENCH = "3f66ba98-5716-4d30-9a87-50bc78e5f714" -IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_SPLITK_UBENCH = "54cf2ec3-d073-4281-9561-b6c1280bd0eb" +IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_UBENCH = ( + "3f66ba98-5716-4d30-9a87-50bc78e5f714" +) +IREE_COMPILE_CONFIG_LINUX_CUDA_SM80_MATMUL_SPLITK_UBENCH = ( + "54cf2ec3-d073-4281-9561-b6c1280bd0eb" +) IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_SIMT = "da0ea6e6-719b-43ee-bfec-72eb3b1173bf" IREE_COMPILE_CONFIG_LINUX_VULKAN_SD_TENSORCORE = "97790694-4f0f-4d83-bc52-d74e019c1df9" -IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_DEFAULTS = "8da35f2b-a042-4b7d-9dcf-5ebbc1728765" -IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL = "32a56c8d-cc6c-41b8-8620-1f8eda0b8223" -IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL_REPEATED_KERNEL = "6b601a8d-4824-42e0-bcc6-500c0c3fa346" -IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_DEFAULTS = "1f2adf49-282e-4aff-9d4f-e63b1621f1e8" -IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D = "d463322c-24e6-4685-85ca-d541b41a405f" -IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D_DOTPROD = "f672a6b9-99fc-47ce-8b1b-8e5f44a541a1" -IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_DEFAULTS = "c7eea358-d8d2-4199-9d75-bb741c399b1b" -IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING = "d3038b95-c889-456a-bff6-5cbabd10f1ad" -IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING_REPEATED_KERNEL = "70b823ca-2807-4531-8c00-e02af7d70466" +IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_DEFAULTS = ( + "8da35f2b-a042-4b7d-9dcf-5ebbc1728765" +) +IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL = ( + "32a56c8d-cc6c-41b8-8620-1f8eda0b8223" +) +IREE_COMPILE_CONFIG_ANDROID_ARM_VALHALL_EXPERIMENTAL_REPEATED_KERNEL = ( + "6b601a8d-4824-42e0-bcc6-500c0c3fa346" +) +IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_DEFAULTS = ( + "1f2adf49-282e-4aff-9d4f-e63b1621f1e8" +) +IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D = ( + "d463322c-24e6-4685-85ca-d541b41a405f" +) +IREE_COMPILE_CONFIG_ANDROID_ARMV8_2_A_GENERIC_MMT4D_DOTPROD = ( + "f672a6b9-99fc-47ce-8b1b-8e5f44a541a1" +) +IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_DEFAULTS = ( + "c7eea358-d8d2-4199-9d75-bb741c399b1b" +) +IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING = ( + "d3038b95-c889-456a-bff6-5cbabd10f1ad" +) +IREE_COMPILE_CONFIG_ANDROID_QUALCOMM_ADRENO_FUSE_PADDING_REPEATED_KERNEL = ( + "70b823ca-2807-4531-8c00-e02af7d70466" +) IREE_MODULE_EXECUTION_CONFIG_LOCAL_SYNC = "13fc65a9-e5dc-4cbb-9c09-25b0b08f4c03" IREE_MODULE_EXECUTION_CONFIG_LOCAL_TASK_BASE = "c7c4a15e-b20c-4898-bb4a-864f34ff34b2" -IREE_MODULE_EXECUTION_CONFIG_SYS_SCHED_LOCAL_TASK_BASE = "0dfb6b03-bd15-45a9-b82a-345c03f1fea6" +IREE_MODULE_EXECUTION_CONFIG_SYS_SCHED_LOCAL_TASK_BASE = ( + "0dfb6b03-bd15-45a9-b82a-345c03f1fea6" +) IREE_MODULE_EXECUTION_CONFIG_CUDA = "f7c0ec98-f028-436a-b05a-7d35cf18ce2d" -IREE_MODULE_EXECUTION_CONFIG_CUDA_BATCH_SIZE_100 = "ce15c338-b1d1-4ee3-b876-22d3cc5a831d" +IREE_MODULE_EXECUTION_CONFIG_CUDA_BATCH_SIZE_100 = ( + "ce15c338-b1d1-4ee3-b876-22d3cc5a831d" +) IREE_MODULE_EXECUTION_CONFIG_VULKAN = "34ae13f0-d6d9-43f7-befb-15d024e88e89" -IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_16 = "b10737a8-5da4-4052-9b7a-5b07f21e02d0" -IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_32 = "c59f6ed8-ef78-4ddd-93ea-f173c5e4d6b8" -IREE_MODULE_EXECUTION_CONFIG_VMVX_LOCAL_TASK_BASE = "953183e2-1e84-4a51-a43c-9b869bdc2218" -IREE_MODULE_EXECUTION_CONFIG_VMVX_SYS_SCHED_LOCAL_TASK_BASE = "a1a9795e-2fc5-4d95-abc0-b0fb41b07557" +IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_16 = ( + "b10737a8-5da4-4052-9b7a-5b07f21e02d0" +) +IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_32 = ( + "c59f6ed8-ef78-4ddd-93ea-f173c5e4d6b8" +) +IREE_MODULE_EXECUTION_CONFIG_VMVX_LOCAL_TASK_BASE = ( + "953183e2-1e84-4a51-a43c-9b869bdc2218" +) +IREE_MODULE_EXECUTION_CONFIG_VMVX_SYS_SCHED_LOCAL_TASK_BASE = ( + "a1a9795e-2fc5-4d95-abc0-b0fb41b07557" +) IREE_MODEL_IMPORT_STABLEHLO_MLIR_DEFAULT = "8b2df698-f3ba-4207-8696-6c909776eac4" IREE_MODEL_IMPORT_TFLITE_DEFAULT = "16280d67-7ce0-4807-ab4b-0cb3c771d206" IREE_MODEL_IMPORT_LINALG_MLIR_DEFAULT = "8afc4561-e84d-4a91-af55-2b1917465fcc" diff --git a/build_tools/python/e2e_test_framework/unique_ids_test.py b/build_tools/python/e2e_test_framework/unique_ids_test.py index 2fdf8fc01cff..85c9f569eab3 100644 --- a/build_tools/python/e2e_test_framework/unique_ids_test.py +++ b/build_tools/python/e2e_test_framework/unique_ids_test.py @@ -11,42 +11,42 @@ class UniqueIdsTest(unittest.TestCase): + def test_hash_composite_id(self): + output = unique_ids.hash_composite_id(["abc", "123"]) - def test_hash_composite_id(self): - output = unique_ids.hash_composite_id(["abc", "123"]) + self.assertEquals( + output, hashlib.sha256(f"0-abc:1-123".encode("utf-8")).hexdigest() + ) - self.assertEquals( - output, - hashlib.sha256(f"0-abc:1-123".encode("utf-8")).hexdigest()) + def test_hash_composite_id_diff_keys(self): + ids = [ + unique_ids.hash_composite_id([]), + unique_ids.hash_composite_id(["abc", "123"]), + unique_ids.hash_composite_id(["123", "abc"]), + unique_ids.hash_composite_id(["123", unique_ids.TRANSPARENT_ID]), + unique_ids.hash_composite_id(["123", "abc", "xyz"]), + unique_ids.hash_composite_id(["123", unique_ids.TRANSPARENT_ID, "xyz"]), + ] - def test_hash_composite_id_diff_keys(self): - ids = [ - unique_ids.hash_composite_id([]), - unique_ids.hash_composite_id(["abc", "123"]), - unique_ids.hash_composite_id(["123", "abc"]), - unique_ids.hash_composite_id(["123", unique_ids.TRANSPARENT_ID]), - unique_ids.hash_composite_id(["123", "abc", "xyz"]), - unique_ids.hash_composite_id(["123", unique_ids.TRANSPARENT_ID, "xyz"]) - ] + # Check if they are all distinct. + self.assertCountEqual(set(ids), ids) - # Check if they are all distinct. - self.assertCountEqual(set(ids), ids) + def test_hash_composite_id_unchanged_with_transparent_id(self): + existing_id = unique_ids.hash_composite_id(["abc"]) + new_id_a = unique_ids.hash_composite_id(["abc", unique_ids.TRANSPARENT_ID]) + new_id_b = unique_ids.hash_composite_id( + ["abc", unique_ids.TRANSPARENT_ID, unique_ids.TRANSPARENT_ID] + ) - def test_hash_composite_id_unchanged_with_transparent_id(self): - existing_id = unique_ids.hash_composite_id(["abc"]) - new_id_a = unique_ids.hash_composite_id(["abc", unique_ids.TRANSPARENT_ID]) - new_id_b = unique_ids.hash_composite_id( - ["abc", unique_ids.TRANSPARENT_ID, unique_ids.TRANSPARENT_ID]) + self.assertEquals(existing_id, new_id_a) + self.assertEquals(existing_id, new_id_b) - self.assertEquals(existing_id, new_id_a) - self.assertEquals(existing_id, new_id_b) + def test_hash_composite_id_with_transparent_ids_in_diff_pos(self): + id_a = unique_ids.hash_composite_id([unique_ids.TRANSPARENT_ID, "abc"]) + id_b = unique_ids.hash_composite_id(["abc", unique_ids.TRANSPARENT_ID]) - def test_hash_composite_id_with_transparent_ids_in_diff_pos(self): - id_a = unique_ids.hash_composite_id([unique_ids.TRANSPARENT_ID, "abc"]) - id_b = unique_ids.hash_composite_id(["abc", unique_ids.TRANSPARENT_ID]) - - self.assertNotEquals(id_a, id_b) + self.assertNotEquals(id_a, id_b) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/build_tools/python/reporting/benchmark_comment.py b/build_tools/python/reporting/benchmark_comment.py index 7f67e22681e8..ce6135d3a149 100644 --- a/build_tools/python/reporting/benchmark_comment.py +++ b/build_tools/python/reporting/benchmark_comment.py @@ -12,12 +12,13 @@ @dataclass(frozen=True) class CommentData(object): - """Benchmark comment data.""" - # Unique id to identify the same kind of comment. - type_id: str - # Abbreviated markdown to post as a comment. - abbr_md: str - # Abbreviated markdown to post on gist. - full_md: str - # Unverified PR number. - unverified_pr_number: int + """Benchmark comment data.""" + + # Unique id to identify the same kind of comment. + type_id: str + # Abbreviated markdown to post as a comment. + abbr_md: str + # Abbreviated markdown to post on gist. + full_md: str + # Unverified PR number. + unverified_pr_number: int diff --git a/build_tools/python/reporting/common/html_utils.py b/build_tools/python/reporting/common/html_utils.py index 96b970beba29..871953bddd41 100644 --- a/build_tools/python/reporting/common/html_utils.py +++ b/build_tools/python/reporting/common/html_utils.py @@ -16,138 +16,158 @@ def get_table_css(): - styles = [ - dict(selector="tr:hover", props=[("background", "#f4f4f4")]), - dict(selector="tbody tr", props=[("background-color", "#ffffff")]), - dict(selector="tbody td", props=[("border", "1px solid #dddfe1")]), - dict(selector="th", - props=[("background-color", "#54585d"), ("color", "#ffffff"), - ("font-weight", "bold"), ("border", "1px solid #54585d"), - ("padding", "10px")]), - dict(selector="td", props=[("padding", "10px")]), - dict(selector="", - props=[("border-collapse", "collapse"), - ("font-family", "Tahoma, Geneva, sans-serif")]), - dict(selector="caption", - props=[("text-align", "center"), ("padding", "10px"), - ("font-weight", "bold"), ("font-size", "1.2em"), - ("color", "#636363")]), - ] - return styles + styles = [ + dict(selector="tr:hover", props=[("background", "#f4f4f4")]), + dict(selector="tbody tr", props=[("background-color", "#ffffff")]), + dict(selector="tbody td", props=[("border", "1px solid #dddfe1")]), + dict( + selector="th", + props=[ + ("background-color", "#54585d"), + ("color", "#ffffff"), + ("font-weight", "bold"), + ("border", "1px solid #54585d"), + ("padding", "10px"), + ], + ), + dict(selector="td", props=[("padding", "10px")]), + dict( + selector="", + props=[ + ("border-collapse", "collapse"), + ("font-family", "Tahoma, Geneva, sans-serif"), + ], + ), + dict( + selector="caption", + props=[ + ("text-align", "center"), + ("padding", "10px"), + ("font-weight", "bold"), + ("font-size", "1.2em"), + ("color", "#636363"), + ], + ), + ] + return styles def style_legend(v): - if _LEGEND_0 in v: - props = "background-color: #0277BD;" - elif _LEGEND_1 in v: - props = "background-color: #2E7D32;" - elif _LEGEND_2 in v: - props = "background-color: #66BB6A;" - elif _LEGEND_3 in v: - props = "background-color: #FBC02D;" - elif _LEGEND_4 in v: - props = "background-color: #E57373;" - elif _LEGEND_5 in v: - props = "background-color: #C62828;" - else: - props = "background-color: #880E4F" - return props + if _LEGEND_0 in v: + props = "background-color: #0277BD;" + elif _LEGEND_1 in v: + props = "background-color: #2E7D32;" + elif _LEGEND_2 in v: + props = "background-color: #66BB6A;" + elif _LEGEND_3 in v: + props = "background-color: #FBC02D;" + elif _LEGEND_4 in v: + props = "background-color: #E57373;" + elif _LEGEND_5 in v: + props = "background-color: #C62828;" + else: + props = "background-color: #880E4F" + return props def generate_header_and_legend(version_html): - html = "" - html = html + version_html - - legend = pd.DataFrame(columns=[""]) - legend.loc[len(legend)] = [_LEGEND_0] - legend.loc[len(legend)] = [_LEGEND_1] - legend.loc[len(legend)] = [_LEGEND_2] - legend.loc[len(legend)] = [_LEGEND_3] - legend.loc[len(legend)] = [_LEGEND_4] - legend.loc[len(legend)] = [_LEGEND_5] - legend.loc[len(legend)] = [_LEGEND_6] - - styled_legend = legend.style.set_table_styles(get_table_css()) - styled_legend.set_caption("Legend") - styled_legend = styled_legend.set_properties(**{"color": "#ffffff"}) - styled_legend = styled_legend.set_properties(**{"width": "200px"}) - styled_legend = styled_legend.applymap(style_legend) - styled_legend = styled_legend.hide(axis="index") - styled_legend = styled_legend.hide(axis="columns") - html = html + styled_legend.to_html() + "
" - return html + html = "" + html = html + version_html + + legend = pd.DataFrame(columns=[""]) + legend.loc[len(legend)] = [_LEGEND_0] + legend.loc[len(legend)] = [_LEGEND_1] + legend.loc[len(legend)] = [_LEGEND_2] + legend.loc[len(legend)] = [_LEGEND_3] + legend.loc[len(legend)] = [_LEGEND_4] + legend.loc[len(legend)] = [_LEGEND_5] + legend.loc[len(legend)] = [_LEGEND_6] + + styled_legend = legend.style.set_table_styles(get_table_css()) + styled_legend.set_caption("Legend") + styled_legend = styled_legend.set_properties(**{"color": "#ffffff"}) + styled_legend = styled_legend.set_properties(**{"width": "200px"}) + styled_legend = styled_legend.applymap(style_legend) + styled_legend = styled_legend.hide(axis="index") + styled_legend = styled_legend.hide(axis="columns") + html = html + styled_legend.to_html() + "
" + return html def style_speedup(v): - if v > 10.0: - props = "background-color: #0277BD;" - elif v > 2.0: - props = "background-color: #2E7D32;" - elif v >= 1.0: - props = "background-color: #66BB6A;" - else: - props = "background-color: #FBC02D;" - return props + if v > 10.0: + props = "background-color: #0277BD;" + elif v > 2.0: + props = "background-color: #2E7D32;" + elif v >= 1.0: + props = "background-color: #66BB6A;" + else: + props = "background-color: #FBC02D;" + return props def style_slowdown(v): - if v >= 10.0: - props = "background-color: #880E4F" - elif v >= 2.0: - props = "background-color: #C62828;" - elif v > 1.15: - props = "background-color: #E57373;" - else: - props = "background-color: #FBC02D;" - return props + if v >= 10.0: + props = "background-color: #880E4F" + elif v >= 2.0: + props = "background-color: #C62828;" + elif v > 1.15: + props = "background-color: #E57373;" + else: + props = "background-color: #FBC02D;" + return props def style_performance(v): - if "faster" in v: - return style_speedup(float(v.split("x")[0])) - else: - return style_slowdown(float(v.split("x")[0])) + if "faster" in v: + return style_speedup(float(v.split("x")[0])) + else: + return style_slowdown(float(v.split("x")[0])) def style_latency(v): - if v == "nan": - return "color: #636363" - if "faster" in v: - return style_speedup(float(v.split("x")[0])) - else: - return style_slowdown(float(v.split("x")[0])) + if v == "nan": + return "color: #636363" + if "faster" in v: + return style_speedup(float(v.split("x")[0])) + else: + return style_slowdown(float(v.split("x")[0])) def style_memory(v): - if v == "nan": - return "color: #636363" - if "smaller" in v: - return style_speedup(float(v.split("x")[0])) - else: - return style_slowdown(float(v.split("x")[0])) + if v == "nan": + return "color: #636363" + if "smaller" in v: + return style_speedup(float(v.split("x")[0])) + else: + return style_slowdown(float(v.split("x")[0])) def format_latency_comparison(iree_latency, baseline_latency): - if iree_latency == 0 or baseline_latency == 0: - return "nan" + if iree_latency == 0 or baseline_latency == 0: + return "nan" - speedup = baseline_latency / iree_latency - slowdown = iree_latency / baseline_latency - faster_label = "{:.2f}x faster" - slower_label = "{:.2f}x slower" - latency = faster_label.format( - speedup) if speedup >= 1.0 else slower_label.format(slowdown) - return latency + speedup = baseline_latency / iree_latency + slowdown = iree_latency / baseline_latency + faster_label = "{:.2f}x faster" + slower_label = "{:.2f}x slower" + latency = ( + faster_label.format(speedup) + if speedup >= 1.0 + else slower_label.format(slowdown) + ) + return latency def format_memory_comparison(iree_memory, baseline_memory): - if iree_memory == 0 or baseline_memory == 0: - return "nan" - - smaller = baseline_memory / iree_memory - larger = iree_memory / baseline_memory - smaller_label = "{:.2f}x smaller" - larger_label = "{:0.2f}x larger" - memory = smaller_label.format( - smaller) if smaller >= 1.0 else larger_label.format(larger) - return memory + if iree_memory == 0 or baseline_memory == 0: + return "nan" + + smaller = baseline_memory / iree_memory + larger = iree_memory / baseline_memory + smaller_label = "{:.2f}x smaller" + larger_label = "{:0.2f}x larger" + memory = ( + smaller_label.format(smaller) if smaller >= 1.0 else larger_label.format(larger) + ) + return memory diff --git a/build_tools/scripts/add_license_header.py b/build_tools/scripts/add_license_header.py index eebc17e1cba4..305e3c598011 100755 --- a/build_tools/scripts/add_license_header.py +++ b/build_tools/scripts/add_license_header.py @@ -27,163 +27,191 @@ """ -class CommentSyntax(object): - def __init__(self, start_comment, middle_comment=None, end_comment=""): - self.start_comment = start_comment - self.middle_comment = middle_comment if middle_comment else start_comment - self.end_comment = end_comment +class CommentSyntax(object): + def __init__(self, start_comment, middle_comment=None, end_comment=""): + self.start_comment = start_comment + self.middle_comment = middle_comment if middle_comment else start_comment + self.end_comment = end_comment def comment_arg_parser(v): - """Can be used to parse a comment syntax triple.""" - if v is None: - return None - if not isinstance(v, str): - raise argparse.ArgumentTypeError("String expected") - return CommentSyntax(*v.split(",")) + """Can be used to parse a comment syntax triple.""" + if v is None: + return None + if not isinstance(v, str): + raise argparse.ArgumentTypeError("String expected") + return CommentSyntax(*v.split(",")) def create_multikey(d): - # pylint: disable=g-complex-comprehension - return {k: v for keys, v in d.items() for k in keys} + # pylint: disable=g-complex-comprehension + return {k: v for keys, v in d.items() for k in keys} -filename_to_comment = create_multikey({ - ("BUILD", "CMakeLists.txt"): CommentSyntax("#"), -}) +filename_to_comment = create_multikey( + { + ("BUILD", "CMakeLists.txt"): CommentSyntax("#"), + } +) -ext_to_comment = create_multikey({ - (".bzl", ".cfg", ".cmake", ".overlay", ".py", ".sh", ".yml"): - CommentSyntax("#"), - (".cc", ".cpp", ".comp", ".fbs", ".h", ".hpp", ".inc", ".td"): - CommentSyntax("//"), - (".def",): - CommentSyntax(";;"), -}) +ext_to_comment = create_multikey( + { + (".bzl", ".cfg", ".cmake", ".overlay", ".py", ".sh", ".yml"): CommentSyntax( + "#" + ), + (".cc", ".cpp", ".comp", ".fbs", ".h", ".hpp", ".inc", ".td"): CommentSyntax( + "//" + ), + (".def",): CommentSyntax(";;"), + } +) def get_comment_syntax(args): - """Deterime the comment syntax to use.""" - if args.comment: - return args.comment - basename = os.path.basename(args.filename) - from_filename = filename_to_comment.get(basename) - if from_filename: - return from_filename - _, ext = os.path.splitext(args.filename) - return ext_to_comment.get(ext, args.default_comment) + """Deterime the comment syntax to use.""" + if args.comment: + return args.comment + basename = os.path.basename(args.filename) + from_filename = filename_to_comment.get(basename) + if from_filename: + return from_filename + _, ext = os.path.splitext(args.filename) + return ext_to_comment.get(ext, args.default_comment) def parse_arguments(): - """Parses command line arguments.""" - current_year = datetime.date.today().year - parser = argparse.ArgumentParser() - input_group = parser.add_mutually_exclusive_group() - input_group.add_argument("infile", - nargs="?", - type=argparse.FileType("r", encoding="UTF-8"), - help="Input file to format. Default: stdin", - default=sys.stdin) - parser.add_argument( - "--filename", - "--assume-filename", - type=str, - default=None, - help=( - "Filename to use for determining comment syntax. Default: actual name" - "of input file.")) - parser.add_argument( - "--year", - "-y", - help="Year to add copyright. Default: the current year ({})".format( - current_year), - default=current_year) - parser.add_argument("--holder", - help="Copyright holder. Default: The IREE Authors", - default="The IREE Authors") - parser.add_argument( - "--quiet", - help=("Don't raise a runtime error on encountering an unhandled filetype." - "Useful for running across many files at once. Default: False"), - action="store_true", - default=False) - output_group = parser.add_mutually_exclusive_group() - output_group.add_argument("-o", - "--outfile", - "--output", - help="File to send output. Default: stdout", - type=argparse.FileType("w", encoding="UTF-8"), - default=sys.stdout) - output_group.add_argument("--in_place", - "-i", - action="store_true", - help="Run formatting in place. Default: False", - default=False) - comment_group = parser.add_mutually_exclusive_group() - comment_group.add_argument("--comment", - "-c", - type=comment_arg_parser, - help="Override comment syntax.", - default=None) - comment_group.add_argument( - "--default_comment", - type=comment_arg_parser, - help="Fallback comment syntax if filename is unknown. Default: None", - default=None) - args = parser.parse_args() - - if args.in_place and args.infile == sys.stdin: - raise parser.error("Cannot format stdin in place") - - if not args.filename and args.infile != sys.stdin: - args.filename = args.infile.name - - return args + """Parses command line arguments.""" + current_year = datetime.date.today().year + parser = argparse.ArgumentParser() + input_group = parser.add_mutually_exclusive_group() + input_group.add_argument( + "infile", + nargs="?", + type=argparse.FileType("r", encoding="UTF-8"), + help="Input file to format. Default: stdin", + default=sys.stdin, + ) + parser.add_argument( + "--filename", + "--assume-filename", + type=str, + default=None, + help=( + "Filename to use for determining comment syntax. Default: actual name" + "of input file." + ), + ) + parser.add_argument( + "--year", + "-y", + help="Year to add copyright. Default: the current year ({})".format( + current_year + ), + default=current_year, + ) + parser.add_argument( + "--holder", + help="Copyright holder. Default: The IREE Authors", + default="The IREE Authors", + ) + parser.add_argument( + "--quiet", + help=( + "Don't raise a runtime error on encountering an unhandled filetype." + "Useful for running across many files at once. Default: False" + ), + action="store_true", + default=False, + ) + output_group = parser.add_mutually_exclusive_group() + output_group.add_argument( + "-o", + "--outfile", + "--output", + help="File to send output. Default: stdout", + type=argparse.FileType("w", encoding="UTF-8"), + default=sys.stdout, + ) + output_group.add_argument( + "--in_place", + "-i", + action="store_true", + help="Run formatting in place. Default: False", + default=False, + ) + comment_group = parser.add_mutually_exclusive_group() + comment_group.add_argument( + "--comment", + "-c", + type=comment_arg_parser, + help="Override comment syntax.", + default=None, + ) + comment_group.add_argument( + "--default_comment", + type=comment_arg_parser, + help="Fallback comment syntax if filename is unknown. Default: None", + default=None, + ) + args = parser.parse_args() + + if args.in_place and args.infile == sys.stdin: + raise parser.error("Cannot format stdin in place") + + if not args.filename and args.infile != sys.stdin: + args.filename = args.infile.name + + return args def main(args): - first_line = args.infile.readline() - already_has_license = False - shebang = "" - content_lines = [] - if first_line.startswith("#!"): - shebang = first_line - else: - content_lines = [first_line] - content_lines.extend(args.infile.readlines()) - for line in content_lines: - if COPYRIGHT_PATTERN.search(line): - already_has_license = True - break - if already_has_license: - header = shebang - else: - comment_syntax = get_comment_syntax(args) - if not comment_syntax: - if args.quiet: + first_line = args.infile.readline() + already_has_license = False + shebang = "" + content_lines = [] + if first_line.startswith("#!"): + shebang = first_line + else: + content_lines = [first_line] + content_lines.extend(args.infile.readlines()) + for line in content_lines: + if COPYRIGHT_PATTERN.search(line): + already_has_license = True + break + if already_has_license: header = shebang - else: - raise ValueError("Could not determine comment syntax for " + - args.filename) else: - header = LICENSE_HEADER_FORMATTER.format( - # Add a blank line between shebang and license. - shebang=(shebang + "\n" if shebang else ""), - start_comment=comment_syntax.start_comment, - middle_comment=comment_syntax.middle_comment, - # Add a blank line before the end comment. - end_comment=("\n" + comment_syntax.end_comment - if comment_syntax.end_comment else ""), - year=args.year, - holder=args.holder) - - # Have to open for write after we're done reading. - if args.in_place: - args.outfile = open(args.filename, "w", encoding="UTF-8") - args.outfile.write(header) - args.outfile.writelines(content_lines) + comment_syntax = get_comment_syntax(args) + if not comment_syntax: + if args.quiet: + header = shebang + else: + raise ValueError( + "Could not determine comment syntax for " + args.filename + ) + else: + header = LICENSE_HEADER_FORMATTER.format( + # Add a blank line between shebang and license. + shebang=(shebang + "\n" if shebang else ""), + start_comment=comment_syntax.start_comment, + middle_comment=comment_syntax.middle_comment, + # Add a blank line before the end comment. + end_comment=( + "\n" + comment_syntax.end_comment + if comment_syntax.end_comment + else "" + ), + year=args.year, + holder=args.holder, + ) + + # Have to open for write after we're done reading. + if args.in_place: + args.outfile = open(args.filename, "w", encoding="UTF-8") + args.outfile.write(header) + args.outfile.writelines(content_lines) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/check_path_lengths.py b/build_tools/scripts/check_path_lengths.py index 645ba7d4a286..42d95c2317a7 100755 --- a/build_tools/scripts/check_path_lengths.py +++ b/build_tools/scripts/check_path_lengths.py @@ -30,70 +30,71 @@ def parse_arguments(): - parser = argparse.ArgumentParser(description="Path length checker") - # The default limit was selected based on repository state when this script - # was added. If the max path length decreases, consider lowering this too. - parser.add_argument("--limit", - help="Path length limit (inclusive)", - type=int, - default=75) - parser.add_argument( - "--include_tests", - help= - "Includes /test directories. False by default as these don't usually generate problematic files during the build", - action="store_true", - default=False) - parser.add_argument("--verbose", - help="Outputs detailed information about path lengths", - action="store_true", - default=False) - args = parser.parse_args() - return args + parser = argparse.ArgumentParser(description="Path length checker") + # The default limit was selected based on repository state when this script + # was added. If the max path length decreases, consider lowering this too. + parser.add_argument( + "--limit", help="Path length limit (inclusive)", type=int, default=75 + ) + parser.add_argument( + "--include_tests", + help="Includes /test directories. False by default as these don't usually generate problematic files during the build", + action="store_true", + default=False, + ) + parser.add_argument( + "--verbose", + help="Outputs detailed information about path lengths", + action="store_true", + default=False, + ) + args = parser.parse_args() + return args def main(args): - repo_root = pathlib.Path(__file__).parent.parent.parent + repo_root = pathlib.Path(__file__).parent.parent.parent - # Just look at the compiler directory for now, since it has historically had - # by far the longest paths. - walk_root = os.path.join(repo_root, "compiler") + # Just look at the compiler directory for now, since it has historically had + # by far the longest paths. + walk_root = os.path.join(repo_root, "compiler") - longest_path_length = -1 - long_paths = [] - short_paths = [] - for dirpath, dirnames, _ in os.walk(walk_root): - # Don't descend into test directories, since they typically don't generate - # object files or binaries that could trip up the build system. - if not args.include_tests and "test" in dirnames: - dirnames.remove("test") + longest_path_length = -1 + long_paths = [] + short_paths = [] + for dirpath, dirnames, _ in os.walk(walk_root): + # Don't descend into test directories, since they typically don't generate + # object files or binaries that could trip up the build system. + if not args.include_tests and "test" in dirnames: + dirnames.remove("test") - path = pathlib.Path(dirpath).relative_to(repo_root).as_posix() - if len(path) > args.limit: - long_paths.append(path) - else: - short_paths.append(path) - longest_path_length = max(longest_path_length, len(path)) - long_paths.sort(key=len) - short_paths.sort(key=len) + path = pathlib.Path(dirpath).relative_to(repo_root).as_posix() + if len(path) > args.limit: + long_paths.append(path) + else: + short_paths.append(path) + longest_path_length = max(longest_path_length, len(path)) + long_paths.sort(key=len) + short_paths.sort(key=len) - if args.verbose and short_paths: - print(f"These paths are shorter than the limit of {args.limit} characters:") - for path in short_paths: - print("{:3d}, {}".format(len(path), path)) + if args.verbose and short_paths: + print(f"These paths are shorter than the limit of {args.limit} characters:") + for path in short_paths: + print("{:3d}, {}".format(len(path), path)) - if long_paths: - print(f"These paths are longer than the limit of {args.limit} characters:") - for path in long_paths: - print("{:3d}, {}".format(len(path), path)) - print( - f"Error: {len(long_paths)} source paths are longer than {args.limit} characters." - ) - print(" Long paths can be problematic when building on Windows.") - print(" Please look at the output above and trim the paths.") - sys.exit(1) - else: - print(f"All path lengths are under the limit of {args.limit} characters.") + if long_paths: + print(f"These paths are longer than the limit of {args.limit} characters:") + for path in long_paths: + print("{:3d}, {}".format(len(path), path)) + print( + f"Error: {len(long_paths)} source paths are longer than {args.limit} characters." + ) + print(" Long paths can be problematic when building on Windows.") + print(" Please look at the output above and trim the paths.") + sys.exit(1) + else: + print(f"All path lengths are under the limit of {args.limit} characters.") if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/download_file.py b/build_tools/scripts/download_file.py index da1a1d31af1e..ffa422077528 100755 --- a/build_tools/scripts/download_file.py +++ b/build_tools/scripts/download_file.py @@ -25,84 +25,91 @@ def parse_arguments(): - """Parses command line arguments.""" - parser = argparse.ArgumentParser( - description="Downloads a file from the web " - "and decompresses it if necessary. NEVER Use this tool to download from " - "untrusted sources, it doesn't unpack the file safely.") - parser.add_argument("source_url", - type=str, - metavar="", - help="Source URL to download") - parser.add_argument("-o", - "--output", - type=str, - required=True, - metavar="", - help="Output file path") - parser.add_argument("--unpack", - action='store_true', - default=False, - help="Unpack the downloaded file if it's an archive") - parser.add_argument("--max-tries", - metavar="", - type=int, - default=DEFAULT_MAX_TRIES, - help="Number of tries before giving up") - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser( + description="Downloads a file from the web " + "and decompresses it if necessary. NEVER Use this tool to download from " + "untrusted sources, it doesn't unpack the file safely." + ) + parser.add_argument( + "source_url", type=str, metavar="", help="Source URL to download" + ) + parser.add_argument( + "-o", + "--output", + type=str, + required=True, + metavar="", + help="Output file path", + ) + parser.add_argument( + "--unpack", + action="store_true", + default=False, + help="Unpack the downloaded file if it's an archive", + ) + parser.add_argument( + "--max-tries", + metavar="", + type=int, + default=DEFAULT_MAX_TRIES, + help="Number of tries before giving up", + ) + return parser.parse_args() def download_and_extract(source_url: str, output: str, unpack: bool): - # Open the URL and get the file-like streaming object. - with urllib.request.urlopen(source_url) as response: - if response.status != 200: - raise RuntimeError( - f"Failed to download file with status {response.status} {response.msg}" - ) + # Open the URL and get the file-like streaming object. + with urllib.request.urlopen(source_url) as response: + if response.status != 200: + raise RuntimeError( + f"Failed to download file with status {response.status} {response.msg}" + ) - if unpack: - if source_url.endswith(".tar.gz"): - # Open tar.gz in the streaming mode. - with tarfile.open(fileobj=response, mode="r|*") as tar_file: - if os.path.exists(output): - shutil.rmtree(output) - os.makedirs(output) - tar_file.extractall(output) - return - elif source_url.endswith(".gz"): - # Open gzip from a file-like object, which will be in the streaming mode. - with gzip.open(filename=response, mode="rb") as input_file: - with open(output, "wb") as output_file: - shutil.copyfileobj(input_file, output_file) - return + if unpack: + if source_url.endswith(".tar.gz"): + # Open tar.gz in the streaming mode. + with tarfile.open(fileobj=response, mode="r|*") as tar_file: + if os.path.exists(output): + shutil.rmtree(output) + os.makedirs(output) + tar_file.extractall(output) + return + elif source_url.endswith(".gz"): + # Open gzip from a file-like object, which will be in the streaming mode. + with gzip.open(filename=response, mode="rb") as input_file: + with open(output, "wb") as output_file: + shutil.copyfileobj(input_file, output_file) + return - # Fallback to download the file only. - with open(output, "wb") as output_file: - # Streaming copy. - shutil.copyfileobj(response, output_file) + # Fallback to download the file only. + with open(output, "wb") as output_file: + # Streaming copy. + shutil.copyfileobj(response, output_file) def main(args): - output_dir = os.path.dirname(args.output) + output_dir = os.path.dirname(args.output) - if not os.path.isdir(output_dir): - os.makedirs(output_dir) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) - remaining_tries = args.max_tries - while remaining_tries > 0: - try: - download_and_extract(args.source_url, args.output, args.unpack) - break - except (ConnectionResetError, ConnectionRefusedError, - urllib.error.URLError): - remaining_tries -= 1 - if remaining_tries == 0: - raise - else: - logging.warning(f"Connection error, remaining {remaining_tries} tries", - exc_info=True) - time.sleep(RETRY_COOLDOWN_TIME) + remaining_tries = args.max_tries + while remaining_tries > 0: + try: + download_and_extract(args.source_url, args.output, args.unpack) + break + except (ConnectionResetError, ConnectionRefusedError, urllib.error.URLError): + remaining_tries -= 1 + if remaining_tries == 0: + raise + else: + logging.warning( + f"Connection error, remaining {remaining_tries} tries", + exc_info=True, + ) + time.sleep(RETRY_COOLDOWN_TIME) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/generate_compilation_flagfile.py b/build_tools/scripts/generate_compilation_flagfile.py index cf0cb13656f1..adda56ed2beb 100755 --- a/build_tools/scripts/generate_compilation_flagfile.py +++ b/build_tools/scripts/generate_compilation_flagfile.py @@ -16,23 +16,24 @@ def parse_arguments(): - """Parses command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--output", - type=str, - required=True, - help="output file to write to") - parser.add_argument("compilation_flags", - metavar="", - nargs="*", - help="list of compilation flags") - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--output", type=str, required=True, help="output file to write to" + ) + parser.add_argument( + "compilation_flags", + metavar="", + nargs="*", + help="list of compilation flags", + ) + return parser.parse_args() def main(args): - with open(args.output, "w") as f: - f.write("\n".join(args.compilation_flags) + "\n") + with open(args.output, "w") as f: + f.write("\n".join(args.compilation_flags) + "\n") if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/generate_flagfile.py b/build_tools/scripts/generate_flagfile.py index f0330e042a75..fb1effd12686 100755 --- a/build_tools/scripts/generate_flagfile.py +++ b/build_tools/scripts/generate_flagfile.py @@ -12,54 +12,67 @@ def parse_arguments(): - """Parses command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--module", - type=str, - required=True, - metavar="", - help="The name of the module file") - parser.add_argument("--device", - type=str, - required=True, - metavar="", - help="The name of the HAL device") - parser.add_argument("--function", - type=str, - required=True, - metavar="", - help="The name of the entry function") - parser.add_argument("--inputs", - type=str, - required=True, - metavar="", - help="A list of comma-separated function inputs") - parser.add_argument("--additional_args", - type=str, - required=True, - metavar="", - help="Additional command-line arguments") - parser.add_argument("-o", - "--output", - type=str, - required=True, - metavar="", - help="Output file to write to") - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--module", + type=str, + required=True, + metavar="", + help="The name of the module file", + ) + parser.add_argument( + "--device", + type=str, + required=True, + metavar="", + help="The name of the HAL device", + ) + parser.add_argument( + "--function", + type=str, + required=True, + metavar="", + help="The name of the entry function", + ) + parser.add_argument( + "--inputs", + type=str, + required=True, + metavar="", + help="A list of comma-separated function inputs", + ) + parser.add_argument( + "--additional_args", + type=str, + required=True, + metavar="", + help="Additional command-line arguments", + ) + parser.add_argument( + "-o", + "--output", + type=str, + required=True, + metavar="", + help="Output file to write to", + ) + return parser.parse_args() def main(args): - lines = [ - f"--device={args.device}", f"--module={args.module}", - f"--function={args.function}" - ] - lines.extend([("--input=" + e) for e in args.inputs.split(",")]) - lines.extend(args.additional_args.split(";")) - content = "\n".join(lines) + "\n" + lines = [ + f"--device={args.device}", + f"--module={args.module}", + f"--function={args.function}", + ] + lines.extend([("--input=" + e) for e in args.inputs.split(",")]) + lines.extend(args.additional_args.split(";")) + content = "\n".join(lines) + "\n" - with open(args.output, "w") as f: - f.writelines(content) + with open(args.output, "w") as f: + f.writelines(content) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/generate_release_index.py b/build_tools/scripts/generate_release_index.py index 0e7ea941c88b..70a4eeb57fde 100755 --- a/build_tools/scripts/generate_release_index.py +++ b/build_tools/scripts/generate_release_index.py @@ -19,63 +19,74 @@ def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument("--repo", - "--repository", - default="openxla/iree", - help="The GitHub repository to fetch releases from.") - parser.add_argument( - "--output", - default="-", - help="The file to write the HTML to or '-' for stdout (the default)") - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo", + "--repository", + default="openxla/iree", + help="The GitHub repository to fetch releases from.", + ) + parser.add_argument( + "--output", + default="-", + help="The file to write the HTML to or '-' for stdout (the default)", + ) + return parser.parse_args() class ReleaseFetcher: + def __init__(self, repo, per_page=100): + self._session = requests.Session() + self._repo = repo + self._per_page = per_page - def __init__(self, repo, per_page=100): - self._session = requests.Session() - self._repo = repo - self._per_page = per_page + def get_all(self): + url = f"https://api.github.com/repos/{self._repo}/releases" + page = 1 - def get_all(self): - url = f"https://api.github.com/repos/{self._repo}/releases" - page = 1 - - while True: - response = self._session.get(url, - params={ - "page": page, - "per_page": self._per_page, - }) - for release in response.json(): - yield release - if "next" not in response.links: - break - page += 1 + while True: + response = self._session.get( + url, + params={ + "page": page, + "per_page": self._per_page, + }, + ) + for release in response.json(): + yield release + if "next" not in response.links: + break + page += 1 def main(args): - fetcher = ReleaseFetcher(repo=args.repo) - with (sys.stdout if args.output == "-" else open(args.output, "w")) as f: - f.write( - textwrap.dedent("""\ + fetcher = ReleaseFetcher(repo=args.repo) + with sys.stdout if args.output == "-" else open(args.output, "w") as f: + f.write( + textwrap.dedent( + """\ - """)) - for release in fetcher.get_all(): - if release["draft"]: - continue - for asset in release["assets"]: - url = html.escape(asset['browser_download_url']) - name = html.escape(asset['name']) - f.write(f" {name}
\n") - f.write(textwrap.dedent("""\ + """ + ) + ) + for release in fetcher.get_all(): + if release["draft"]: + continue + for asset in release["assets"]: + url = html.escape(asset["browser_download_url"]) + name = html.escape(asset["name"]) + f.write(f" {name}
\n") + f.write( + textwrap.dedent( + """\ - """)) + """ + ) + ) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/scripts/get_e2e_artifacts.py b/build_tools/scripts/get_e2e_artifacts.py index 634ee310ba2c..88754389a7df 100755 --- a/build_tools/scripts/get_e2e_artifacts.py +++ b/build_tools/scripts/get_e2e_artifacts.py @@ -29,153 +29,156 @@ from absl import flags SUITE_NAME_TO_TARGET = { - 'e2e_tests': - '//integrations/tensorflow/e2e:e2e_tests', - 'mobile_bert_squad_tests': - '//integrations/tensorflow/e2e:mobile_bert_squad_tests', - 'layers_tests': - '//integrations/tensorflow/e2e/keras/layers:layers_tests', - 'layers_dynamic_batch_tests': - '//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests', - 'layers_training_tests': - '//integrations/tensorflow/e2e/keras/layers:layers_training_tests', - 'keyword_spotting_tests': - '//integrations/tensorflow/e2e/keras:keyword_spotting_tests', - 'keyword_spotting_internal_streaming_tests': - '//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests', - 'imagenet_non_hermetic_tests': - '//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests', - 'slim_vision_tests': - '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests', + "e2e_tests": "//integrations/tensorflow/e2e:e2e_tests", + "mobile_bert_squad_tests": "//integrations/tensorflow/e2e:mobile_bert_squad_tests", + "layers_tests": "//integrations/tensorflow/e2e/keras/layers:layers_tests", + "layers_dynamic_batch_tests": "//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests", + "layers_training_tests": "//integrations/tensorflow/e2e/keras/layers:layers_training_tests", + "keyword_spotting_tests": "//integrations/tensorflow/e2e/keras:keyword_spotting_tests", + "keyword_spotting_internal_streaming_tests": "//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests", + "imagenet_non_hermetic_tests": "//integrations/tensorflow/e2e/keras/applications:imagenet_non_hermetic_tests", + "slim_vision_tests": "//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests", } -SUITES_HELP = [f'`{name}`' for name in SUITE_NAME_TO_TARGET] +SUITES_HELP = [f"`{name}`" for name in SUITE_NAME_TO_TARGET] SUITES_HELP = f'{", ".join(SUITES_HELP[:-1])} and {SUITES_HELP[-1]}' FLAGS = flags.FLAGS flags.DEFINE_bool( - 'dry_run', False, - 'Run without extracting files. Useful for quickly checking for artifact ' - 'collisions.') + "dry_run", + False, + "Run without extracting files. Useful for quickly checking for artifact " + "collisions.", +) flags.DEFINE_string( - 'artifacts_dir', os.path.join(tempfile.gettempdir(), 'iree', 'modules'), - 'Directory to transfer the benchmarking artifacts to. Defaults to ' - '/tmp/iree/modules/') -flags.DEFINE_bool('run_test_suites', True, 'Run any specified test suites.') -flags.DEFINE_list('test_suites', list(SUITE_NAME_TO_TARGET.keys()), - f'Any combination of {SUITES_HELP}.') + "artifacts_dir", + os.path.join(tempfile.gettempdir(), "iree", "modules"), + "Directory to transfer the benchmarking artifacts to. Defaults to " + "/tmp/iree/modules/", +) +flags.DEFINE_bool("run_test_suites", True, "Run any specified test suites.") +flags.DEFINE_list( + "test_suites", + list(SUITE_NAME_TO_TARGET.keys()), + f"Any combination of {SUITES_HELP}.", +) -EXPECTED_COLLISIONS = [ - '/tf_ref/', 'tf_input.mlir', 'iree_input.mlir', '/saved_model/' -] +EXPECTED_COLLISIONS = ["/tf_ref/", "tf_input.mlir", "iree_input.mlir", "/saved_model/"] def _target_to_testlogs_path(target: str) -> str: - """Convert target into the path where Bazel stores the artifacts we want.""" - return os.path.join('bazel-testlogs', - target.replace('//', '').replace(':', os.sep)) + """Convert target into the path where Bazel stores the artifacts we want.""" + return os.path.join("bazel-testlogs", target.replace("//", "").replace(":", os.sep)) def _target_to_test_name(target: str, test_suite_path: str) -> str: - """Get test_name from `suite_name_test_name__tf__backend_name`.""" - return target.split('__')[0].replace(f'{test_suite_path}_', '') + """Get test_name from `suite_name_test_name__tf__backend_name`.""" + return target.split("__")[0].replace(f"{test_suite_path}_", "") def get_test_paths_and_names(test_suite_path: str): - """Get the paths Bazel stores test outputs in and the matching test names.""" - targets = utils.get_test_targets(test_suite_path) - test_paths = [_target_to_testlogs_path(target) for target in targets] - test_names = [ - _target_to_test_name(target, test_suite_path) for target in targets - ] - return test_paths, test_names - - -def check_collision(filename: str, test_name: str, written_paths: Set[str], - paths_to_tests: Dict[str, str]): - """Check that we aren't overwriting files unless we expect to.""" - # Note: We can't use a check that the files have identical contents because - # tf_input.mlir can have random numbers appended to its function names. - # See https://github.com/openxla/iree/issues/3375 - - expected_collision = any([name in filename for name in EXPECTED_COLLISIONS]) - if filename in written_paths and not expected_collision: - raise ValueError(f'Collision found on {filename} between {test_name}.py ' - f'and {paths_to_tests[filename]}.py') - else: - written_paths.add(filename) - paths_to_tests[filename] = test_name - - -def update_path(archive_path: str): - """Update the --module flag with the new location of the compiled.vmfb""" - backend_path = archive_path.split('traces')[0] # 'ModuleName/backend_name'. - compiled_path = os.path.join(FLAGS.artifacts_dir, backend_path, - 'compiled.vmfb') - flagfile_path = os.path.join(FLAGS.artifacts_dir, archive_path) - for line in fileinput.input(files=[flagfile_path], inplace=True): - if line.strip().startswith('--module'): - print(f'--module={compiled_path}\n', end='') + """Get the paths Bazel stores test outputs in and the matching test names.""" + targets = utils.get_test_targets(test_suite_path) + test_paths = [_target_to_testlogs_path(target) for target in targets] + test_names = [_target_to_test_name(target, test_suite_path) for target in targets] + return test_paths, test_names + + +def check_collision( + filename: str, + test_name: str, + written_paths: Set[str], + paths_to_tests: Dict[str, str], +): + """Check that we aren't overwriting files unless we expect to.""" + # Note: We can't use a check that the files have identical contents because + # tf_input.mlir can have random numbers appended to its function names. + # See https://github.com/openxla/iree/issues/3375 + + expected_collision = any([name in filename for name in EXPECTED_COLLISIONS]) + if filename in written_paths and not expected_collision: + raise ValueError( + f"Collision found on {filename} between {test_name}.py " + f"and {paths_to_tests[filename]}.py" + ) else: - print(line, end='') - + written_paths.add(filename) + paths_to_tests[filename] = test_name -def extract_artifacts(test_path: str, test_name: str, written_paths: Set[str], - paths_to_tests: Dict[str, str]): - """Unzips all of the benchmarking artifacts for a given test and backend.""" - outputs = os.path.join(test_path, 'test.outputs', 'outputs.zip') - if FLAGS.dry_run and not os.path.exists(outputs): - # The artifacts may or may not be present on disk during a dry run. If they - # are then we want to collision check them, but if they aren't that's fine. - return - archive = zipfile.ZipFile(outputs) - # Filter out directory names. - filenames = [name for name in archive.namelist() if name[-1] != os.sep] - - for filename in filenames: - # Check for collisions. - check_collision(filename, test_name, written_paths, paths_to_tests) - - # Extract and update flagfile path. - if not FLAGS.dry_run: - archive.extract(filename, FLAGS.artifacts_dir) - if filename.endswith('flagfile'): - update_path(filename) +def update_path(archive_path: str): + """Update the --module flag with the new location of the compiled.vmfb""" + backend_path = archive_path.split("traces")[0] # 'ModuleName/backend_name'. + compiled_path = os.path.join(FLAGS.artifacts_dir, backend_path, "compiled.vmfb") + flagfile_path = os.path.join(FLAGS.artifacts_dir, archive_path) + for line in fileinput.input(files=[flagfile_path], inplace=True): + if line.strip().startswith("--module"): + print(f"--module={compiled_path}\n", end="") + else: + print(line, end="") + + +def extract_artifacts( + test_path: str, + test_name: str, + written_paths: Set[str], + paths_to_tests: Dict[str, str], +): + """Unzips all of the benchmarking artifacts for a given test and backend.""" + outputs = os.path.join(test_path, "test.outputs", "outputs.zip") + if FLAGS.dry_run and not os.path.exists(outputs): + # The artifacts may or may not be present on disk during a dry run. If they + # are then we want to collision check them, but if they aren't that's fine. + return + + archive = zipfile.ZipFile(outputs) + # Filter out directory names. + filenames = [name for name in archive.namelist() if name[-1] != os.sep] + + for filename in filenames: + # Check for collisions. + check_collision(filename, test_name, written_paths, paths_to_tests) + + # Extract and update flagfile path. + if not FLAGS.dry_run: + archive.extract(filename, FLAGS.artifacts_dir) + if filename.endswith("flagfile"): + update_path(filename) def main(argv): - del argv # Unused. - - print( - "The bazel integrations build and tests are deprecated. This script " - "may be reworked in the future. For the time being refer to " - "https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/e2e_benchmarking.md " - "for information on how to run TensorFlow benchmarks.") - exit(1) - - # Convert test suite shorthands to full test suite targets. - test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites] - - if FLAGS.run_test_suites: - # Use bazel test to execute all of the test suites in parallel. - command = ['bazel', 'test', *test_suites, '--color=yes'] - print(f'Running: `{" ".join(command)}`') - if not FLAGS.dry_run: - subprocess.run(command, check=True) - print() - - written_paths = set() - paths_to_tests = dict() - - for test_suite in test_suites: - # Extract all of the artifacts for this test suite. - test_paths, test_names = get_test_paths_and_names(test_suite) - for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)): - print(f'\rTransfering {test_suite} {i + 1}/{len(test_paths)}', end='') - extract_artifacts(test_path, test_name, written_paths, paths_to_tests) - print('\n') - - -if __name__ == '__main__': - app.run(main) + del argv # Unused. + + print( + "The bazel integrations build and tests are deprecated. This script " + "may be reworked in the future. For the time being refer to " + "https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/e2e_benchmarking.md " + "for information on how to run TensorFlow benchmarks." + ) + exit(1) + + # Convert test suite shorthands to full test suite targets. + test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites] + + if FLAGS.run_test_suites: + # Use bazel test to execute all of the test suites in parallel. + command = ["bazel", "test", *test_suites, "--color=yes"] + print(f'Running: `{" ".join(command)}`') + if not FLAGS.dry_run: + subprocess.run(command, check=True) + print() + + written_paths = set() + paths_to_tests = dict() + + for test_suite in test_suites: + # Extract all of the artifacts for this test suite. + test_paths, test_names = get_test_paths_and_names(test_suite) + for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)): + print(f"\rTransfering {test_suite} {i + 1}/{len(test_paths)}", end="") + extract_artifacts(test_path, test_name, written_paths, paths_to_tests) + print("\n") + + +if __name__ == "__main__": + app.run(main) diff --git a/build_tools/scripts/git/check_submodule_init.py b/build_tools/scripts/git/check_submodule_init.py index 611c32f982cb..b878ef356f18 100644 --- a/build_tools/scripts/git/check_submodule_init.py +++ b/build_tools/scripts/git/check_submodule_init.py @@ -12,37 +12,47 @@ def run(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--runtime_only", - help=("Only check the initialization of the submodules for the" - "runtime-dependent submodules. Default: False"), - action="store_true", - default=False) - args = parser.parse_args() - # No-op if we're not in a git repository. - try: - subprocess.check_call(['git', 'rev-parse', '--is-inside-work-tree'], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - except: - return - - output = os.popen("git submodule status") - submodules = output.readlines() - - runtime_submodules = pathlib.Path(__file__).with_name( - "runtime_submodules.txt").read_text().split("\n") - - for submodule in submodules: - prefix = submodule.strip()[0] - name = submodule.split()[1] - if prefix == "-" and (not args.runtime_only or name in runtime_submodules): - print( - "The git submodule '%s' is not initialized. Please run `git submodule update --init`" - % (name)) - sys.exit(1) + parser = argparse.ArgumentParser() + parser.add_argument( + "--runtime_only", + help=( + "Only check the initialization of the submodules for the" + "runtime-dependent submodules. Default: False" + ), + action="store_true", + default=False, + ) + args = parser.parse_args() + # No-op if we're not in a git repository. + try: + subprocess.check_call( + ["git", "rev-parse", "--is-inside-work-tree"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except: + return + + output = os.popen("git submodule status") + submodules = output.readlines() + + runtime_submodules = ( + pathlib.Path(__file__) + .with_name("runtime_submodules.txt") + .read_text() + .split("\n") + ) + + for submodule in submodules: + prefix = submodule.strip()[0] + name = submodule.split()[1] + if prefix == "-" and (not args.runtime_only or name in runtime_submodules): + print( + "The git submodule '%s' is not initialized. Please run `git submodule update --init`" + % (name) + ) + sys.exit(1) if __name__ == "__main__": - run() + run() diff --git a/build_tools/scripts/integrate/bump_llvm.py b/build_tools/scripts/integrate/bump_llvm.py index 7dd770b2cda4..459da101a7bc 100755 --- a/build_tools/scripts/integrate/bump_llvm.py +++ b/build_tools/scripts/integrate/bump_llvm.py @@ -44,92 +44,103 @@ def main(args): - if not args.disable_setup_remote: - iree_utils.git_setup_remote(args.upstream_remote, args.upstream_repository) - - iree_utils.git_check_porcelain() - print(f"Fetching remote repository: {args.upstream_remote}") - iree_utils.git_fetch(repository=args.upstream_remote) - - # If re-using a branch, make sure we are not on that branch. - if args.reuse_branch: - iree_utils.git_checkout("main") - - # Create branch. - branch_name = args.branch_name - if not branch_name: - branch_name = f"bump-llvm-{date.today().strftime('%Y%m%d')}" - print(f"Creating branch {branch_name} (override with --branch-name=)") - iree_utils.git_create_branch(branch_name, - checkout=True, - ref=f"{args.upstream_remote}/main", - force=args.reuse_branch) - - # Reset the llvm-project submodule to track upstream. - # This will discard any cherrypicks that may have been committed locally, - # but the assumption is that if doing a main llvm version bump, the - # cherrypicks will be incorporated at the new commit. If not, well, ymmv - # and you will find out. - iree_utils.git_submodule_set_origin( - "third_party/llvm-project", - url="https://github.com/iree-org/iree-llvm-fork.git", - branch="--default") - - # Remove the branch pin file, reverting us to pure upstream. - branch_pin_file = os.path.join( - iree_utils.get_repo_root(), - iree_modules.MODULE_INFOS["llvm-project"].branch_pin_file) - if os.path.exists(branch_pin_file): - os.remove(branch_pin_file) - - # Update the LLVM submodule. - llvm_commit = args.llvm_commit - print(f"Updating LLVM submodule to {llvm_commit}") - llvm_root = iree_utils.get_submodule_root("llvm-project") - iree_utils.git_fetch(repository="origin", - ref="refs/heads/main", - repo_dir=llvm_root) - if llvm_commit == "HEAD": - llvm_commit = "origin/main" - iree_utils.git_reset(llvm_commit, repo_dir=llvm_root) - llvm_commit, llvm_summary = iree_utils.git_current_commit(repo_dir=llvm_root) - print(f"LLVM submodule reset to:\n {llvm_summary}\n") - - # Create a commit. - print("Create commit...") - iree_utils.git_create_commit( - message=(f"Integrate llvm-project at {llvm_commit}\n\n" - f"* Reset third_party/llvm-project: {llvm_summary}"), - add_all=True) - - # Push. - print("Pushing...") - iree_utils.git_push_branch(args.upstream_remote, branch_name) + if not args.disable_setup_remote: + iree_utils.git_setup_remote(args.upstream_remote, args.upstream_repository) + + iree_utils.git_check_porcelain() + print(f"Fetching remote repository: {args.upstream_remote}") + iree_utils.git_fetch(repository=args.upstream_remote) + + # If re-using a branch, make sure we are not on that branch. + if args.reuse_branch: + iree_utils.git_checkout("main") + + # Create branch. + branch_name = args.branch_name + if not branch_name: + branch_name = f"bump-llvm-{date.today().strftime('%Y%m%d')}" + print(f"Creating branch {branch_name} (override with --branch-name=)") + iree_utils.git_create_branch( + branch_name, + checkout=True, + ref=f"{args.upstream_remote}/main", + force=args.reuse_branch, + ) + + # Reset the llvm-project submodule to track upstream. + # This will discard any cherrypicks that may have been committed locally, + # but the assumption is that if doing a main llvm version bump, the + # cherrypicks will be incorporated at the new commit. If not, well, ymmv + # and you will find out. + iree_utils.git_submodule_set_origin( + "third_party/llvm-project", + url="https://github.com/iree-org/iree-llvm-fork.git", + branch="--default", + ) + + # Remove the branch pin file, reverting us to pure upstream. + branch_pin_file = os.path.join( + iree_utils.get_repo_root(), + iree_modules.MODULE_INFOS["llvm-project"].branch_pin_file, + ) + if os.path.exists(branch_pin_file): + os.remove(branch_pin_file) + + # Update the LLVM submodule. + llvm_commit = args.llvm_commit + print(f"Updating LLVM submodule to {llvm_commit}") + llvm_root = iree_utils.get_submodule_root("llvm-project") + iree_utils.git_fetch(repository="origin", ref="refs/heads/main", repo_dir=llvm_root) + if llvm_commit == "HEAD": + llvm_commit = "origin/main" + iree_utils.git_reset(llvm_commit, repo_dir=llvm_root) + llvm_commit, llvm_summary = iree_utils.git_current_commit(repo_dir=llvm_root) + print(f"LLVM submodule reset to:\n {llvm_summary}\n") + + # Create a commit. + print("Create commit...") + iree_utils.git_create_commit( + message=( + f"Integrate llvm-project at {llvm_commit}\n\n" + f"* Reset third_party/llvm-project: {llvm_summary}" + ), + add_all=True, + ) + + # Push. + print("Pushing...") + iree_utils.git_push_branch(args.upstream_remote, branch_name) def parse_arguments(argv): - parser = argparse.ArgumentParser(description="IREE LLVM-bump-inator") - parser.add_argument("--upstream-remote", - help="Upstream remote", - default="UPSTREAM_AUTOMATION") - parser.add_argument("--upstream-repository", - help="Upstream repository URL", - default="git@github.com:openxla/iree.git") - parser.add_argument("--disable-setup-remote", - help="Disable remote setup", - action="store_true", - default=False) - parser.add_argument("--llvm-commit", help="LLVM commit sha", default="HEAD") - parser.add_argument("--branch-name", - help="Integrate branch to create", - default=None) - parser.add_argument("--reuse-branch", - help="Allow re-use of an existing branch", - action="store_true", - default=False) - args = parser.parse_args(argv) - return args + parser = argparse.ArgumentParser(description="IREE LLVM-bump-inator") + parser.add_argument( + "--upstream-remote", help="Upstream remote", default="UPSTREAM_AUTOMATION" + ) + parser.add_argument( + "--upstream-repository", + help="Upstream repository URL", + default="git@github.com:openxla/iree.git", + ) + parser.add_argument( + "--disable-setup-remote", + help="Disable remote setup", + action="store_true", + default=False, + ) + parser.add_argument("--llvm-commit", help="LLVM commit sha", default="HEAD") + parser.add_argument( + "--branch-name", help="Integrate branch to create", default=None + ) + parser.add_argument( + "--reuse-branch", + help="Allow re-use of an existing branch", + action="store_true", + default=False, + ) + args = parser.parse_args(argv) + return args if __name__ == "__main__": - main(parse_arguments(sys.argv[1:])) + main(parse_arguments(sys.argv[1:])) diff --git a/build_tools/scripts/integrate/iree_modules.py b/build_tools/scripts/integrate/iree_modules.py index fec8ff38411b..5587333ac852 100644 --- a/build_tools/scripts/integrate/iree_modules.py +++ b/build_tools/scripts/integrate/iree_modules.py @@ -6,40 +6,43 @@ class ModuleInfo: - - def __init__(self, *, name: str, path: str, branch_pin_file: str, - default_repository_url: str, fork_repository_push: str, - fork_repository_pull: str, branch_prefix: str): - self.name = name - self.path = path - self.branch_pin_file = branch_pin_file - self.default_repository_url = default_repository_url - self.fork_repository_push = fork_repository_push - self.fork_repository_pull = fork_repository_pull - self.branch_prefix = branch_prefix + def __init__( + self, + *, + name: str, + path: str, + branch_pin_file: str, + default_repository_url: str, + fork_repository_push: str, + fork_repository_pull: str, + branch_prefix: str + ): + self.name = name + self.path = path + self.branch_pin_file = branch_pin_file + self.default_repository_url = default_repository_url + self.fork_repository_push = fork_repository_push + self.fork_repository_pull = fork_repository_pull + self.branch_prefix = branch_prefix MODULE_INFOS = { - "llvm-project": - ModuleInfo( - name="llvm-project", - path="third_party/llvm-project", - branch_pin_file="third_party/llvm-project.branch-pin", - default_repository_url= - "https://github.com/iree-org/iree-llvm-fork.git", - fork_repository_push="git@github.com:iree-org/iree-llvm-fork.git", - fork_repository_pull= - "https://github.com/iree-org/iree-llvm-fork.git", - branch_prefix="patched-llvm-project-", - ), - "stablehlo": - ModuleInfo( - name="stablehlo", - path="third_party/stablehlo", - branch_pin_file="third_party/stablehlo.branch-pin", - default_repository_url="https://github.com/iree-org/stablehlo.git", - fork_repository_push="git@github.com:iree-org/stablehlo.git", - fork_repository_pull="https://github.com/iree-org/stablehlo.git", - branch_prefix="patched-stablehlo-", - ) + "llvm-project": ModuleInfo( + name="llvm-project", + path="third_party/llvm-project", + branch_pin_file="third_party/llvm-project.branch-pin", + default_repository_url="https://github.com/iree-org/iree-llvm-fork.git", + fork_repository_push="git@github.com:iree-org/iree-llvm-fork.git", + fork_repository_pull="https://github.com/iree-org/iree-llvm-fork.git", + branch_prefix="patched-llvm-project-", + ), + "stablehlo": ModuleInfo( + name="stablehlo", + path="third_party/stablehlo", + branch_pin_file="third_party/stablehlo.branch-pin", + default_repository_url="https://github.com/iree-org/stablehlo.git", + fork_repository_push="git@github.com:iree-org/stablehlo.git", + fork_repository_pull="https://github.com/iree-org/stablehlo.git", + branch_prefix="patched-stablehlo-", + ), } diff --git a/build_tools/scripts/integrate/iree_utils.py b/build_tools/scripts/integrate/iree_utils.py index 3a81ba8e6e45..21e54541a669 100644 --- a/build_tools/scripts/integrate/iree_utils.py +++ b/build_tools/scripts/integrate/iree_utils.py @@ -15,186 +15,200 @@ def get_repo_root() -> str: - global _repo_root - if _repo_root is None: - _repo_root = os.getcwd() - _validate_repo_root() - return _repo_root + global _repo_root + if _repo_root is None: + _repo_root = os.getcwd() + _validate_repo_root() + return _repo_root def get_submodule_root(submodule) -> str: - path = os.path.join(get_repo_root(), "third_party", submodule) - if not os.path.isdir(path): - raise SystemExit(f"Could not find submodule: {path}") - return path + path = os.path.join(get_repo_root(), "third_party", submodule) + if not os.path.isdir(path): + raise SystemExit(f"Could not find submodule: {path}") + return path def _validate_repo_root(): - # Look for something we know is there. - known_dir = os.path.join(_repo_root, "compiler") - if not os.path.isdir(known_dir): - raise SystemExit(f"ERROR: Must run from the iree repository root. " - f"Actually in: {_repo_root}") + # Look for something we know is there. + known_dir = os.path.join(_repo_root, "compiler") + if not os.path.isdir(known_dir): + raise SystemExit( + f"ERROR: Must run from the iree repository root. " + f"Actually in: {_repo_root}" + ) def git_setup_remote(remote_alias, url, *, repo_dir=None): - needs_create = False - try: - existing_url = git_exec(["remote", "get-url", remote_alias], - capture_output=True, - repo_dir=repo_dir, - quiet=True) - existing_url = existing_url.strip() - if existing_url == url: - return - except subprocess.CalledProcessError: - # Does not exist. - needs_create = True - - if needs_create: - git_exec(["remote", "add", "--no-tags", remote_alias, url], - repo_dir=repo_dir) - else: - git_exec(["remote", "set-url", remote_alias, url], repo_dir=repo_dir) + needs_create = False + try: + existing_url = git_exec( + ["remote", "get-url", remote_alias], + capture_output=True, + repo_dir=repo_dir, + quiet=True, + ) + existing_url = existing_url.strip() + if existing_url == url: + return + except subprocess.CalledProcessError: + # Does not exist. + needs_create = True + + if needs_create: + git_exec(["remote", "add", "--no-tags", remote_alias, url], repo_dir=repo_dir) + else: + git_exec(["remote", "set-url", remote_alias, url], repo_dir=repo_dir) def git_is_porcelain(*, repo_dir=None): - output = git_exec(["status", "--porcelain", "--untracked-files=no"], - capture_output=True, - quiet=True, - repo_dir=repo_dir).strip() - return not bool(output) + output = git_exec( + ["status", "--porcelain", "--untracked-files=no"], + capture_output=True, + quiet=True, + repo_dir=repo_dir, + ).strip() + return not bool(output) def git_check_porcelain(*, repo_dir=None): - output = git_exec(["status", "--porcelain", "--untracked-files=no"], - capture_output=True, - quiet=True, - repo_dir=repo_dir).strip() - if output: - actual_repo_dir = get_repo_root() if repo_dir is None else repo_dir - raise SystemExit(f"ERROR: git directory {actual_repo_dir} is not clean. " - f"Please stash changes:\n{output}") + output = git_exec( + ["status", "--porcelain", "--untracked-files=no"], + capture_output=True, + quiet=True, + repo_dir=repo_dir, + ).strip() + if output: + actual_repo_dir = get_repo_root() if repo_dir is None else repo_dir + raise SystemExit( + f"ERROR: git directory {actual_repo_dir} is not clean. " + f"Please stash changes:\n{output}" + ) def git_fetch(*, repository=None, ref=None, repo_dir=None): - args = ["fetch"] - if repository: - args.append(repository) - if ref is not None: - args.append(ref) - git_exec(args, repo_dir=repo_dir) + args = ["fetch"] + if repository: + args.append(repository) + if ref is not None: + args.append(ref) + git_exec(args, repo_dir=repo_dir) def git_checkout(ref, *, repo_dir=None): - git_exec(["checkout", ref], repo_dir=repo_dir) + git_exec(["checkout", ref], repo_dir=repo_dir) -def git_create_branch(branch_name, - *, - checkout=True, - ref=None, - force=False, - repo_dir=None): - branch_args = ["branch"] - if force: - branch_args.append("-f") - branch_args.append(branch_name) - if ref is not None: - branch_args.append(ref) - git_exec(branch_args, repo_dir=repo_dir) +def git_create_branch( + branch_name, *, checkout=True, ref=None, force=False, repo_dir=None +): + branch_args = ["branch"] + if force: + branch_args.append("-f") + branch_args.append(branch_name) + if ref is not None: + branch_args.append(ref) + git_exec(branch_args, repo_dir=repo_dir) - if checkout: - git_exec(["checkout", branch_name], repo_dir=repo_dir) + if checkout: + git_exec(["checkout", branch_name], repo_dir=repo_dir) def git_push_branch(repository, branch_name, *, force=False, repo_dir=None): - push_args = ["push", "--set-upstream"] - if force: - push_args.append("-f") - push_args.append(repository) - push_args.append(f"{branch_name}:{branch_name}") - git_exec(push_args, repo_dir=repo_dir) + push_args = ["push", "--set-upstream"] + if force: + push_args.append("-f") + push_args.append(repository) + push_args.append(f"{branch_name}:{branch_name}") + git_exec(push_args, repo_dir=repo_dir) def git_branch_exists(branch_name, *, repo_dir=None): - output = git_exec(["branch", "-l", branch_name], - repo_dir=repo_dir, - quiet=True, - capture_output=True).strip() - return bool(output) + output = git_exec( + ["branch", "-l", branch_name], + repo_dir=repo_dir, + quiet=True, + capture_output=True, + ).strip() + return bool(output) def git_submodule_set_origin(path, *, url=None, branch=None, repo_dir=None): - if url is not None: - git_exec(["submodule", "set-url", "--", path, url], repo_dir=repo_dir) - - if branch is not None: - try: - if branch == "--default": - git_exec(["submodule", "set-branch", "--default", "--", path], - repo_dir=repo_dir) - else: - git_exec(["submodule", "set-branch", "--branch", branch, "--", path], - repo_dir=repo_dir) - except subprocess.CalledProcessError: - # The set-branch command returns 0 on change and !0 on no change. - # This is a bit unfortunate. - ... + if url is not None: + git_exec(["submodule", "set-url", "--", path, url], repo_dir=repo_dir) + + if branch is not None: + try: + if branch == "--default": + git_exec( + ["submodule", "set-branch", "--default", "--", path], + repo_dir=repo_dir, + ) + else: + git_exec( + ["submodule", "set-branch", "--branch", branch, "--", path], + repo_dir=repo_dir, + ) + except subprocess.CalledProcessError: + # The set-branch command returns 0 on change and !0 on no change. + # This is a bit unfortunate. + ... def git_reset(ref, *, hard=True, repo_dir=None): - args = ["reset"] - if hard: - args.append("--hard") - args.append(ref) - git_exec(args, repo_dir=repo_dir) + args = ["reset"] + if hard: + args.append("--hard") + args.append(ref) + git_exec(args, repo_dir=repo_dir) def git_current_commit(*, repo_dir=None) -> Tuple[str, str]: - output = git_exec(["log", "-n", "1", "--pretty=format:%H (%ci): %s"], - capture_output=True, - repo_dir=repo_dir, - quiet=True) - output = output.strip() - parts = output.split(" ") - # Return commit, full_summary - return parts[0], output + output = git_exec( + ["log", "-n", "1", "--pretty=format:%H (%ci): %s"], + capture_output=True, + repo_dir=repo_dir, + quiet=True, + ) + output = output.strip() + parts = output.split(" ") + # Return commit, full_summary + return parts[0], output def git_create_commit(*, message, add_all=False, repo_dir=None): - if add_all: - git_exec(["add", "-A"], repo_dir=repo_dir) - git_exec(["commit", "-m", message]) + if add_all: + git_exec(["add", "-A"], repo_dir=repo_dir) + git_exec(["commit", "-m", message]) def git_ls_remote_branches(repository_url, *, filter=None, repo_dir=None): - args = ["ls-remote", "-h", repository_url] - if filter: - args.extend(filter) - output = git_exec(args, quiet=True, capture_output=True) - lines = output.strip().splitlines(keepends=False) + args = ["ls-remote", "-h", repository_url] + if filter: + args.extend(filter) + output = git_exec(args, quiet=True, capture_output=True) + lines = output.strip().splitlines(keepends=False) - # Format is refs/heads/branch_name - def extract_branch(line): - parts = re.split("\\s+", line) - ref = parts[1] - prefix = "refs/heads/" - if ref.startswith(prefix): - ref = ref[len(prefix):] - return ref + # Format is refs/heads/branch_name + def extract_branch(line): + parts = re.split("\\s+", line) + ref = parts[1] + prefix = "refs/heads/" + if ref.startswith(prefix): + ref = ref[len(prefix) :] + return ref - return [extract_branch(l) for l in lines] + return [extract_branch(l) for l in lines] def git_exec(args, *, repo_dir=None, quiet=False, capture_output=False): - full_args = ["git"] + args - full_args_quoted = [shlex.quote(a) for a in full_args] - if not repo_dir: - repo_dir = get_repo_root() - if not quiet: - print(f" ++ EXEC: (cd {repo_dir} && {' '.join(full_args_quoted)})") - if capture_output: - return subprocess.check_output(full_args, cwd=repo_dir).decode("utf-8") - else: - subprocess.check_call(full_args, cwd=repo_dir) + full_args = ["git"] + args + full_args_quoted = [shlex.quote(a) for a in full_args] + if not repo_dir: + repo_dir = get_repo_root() + if not quiet: + print(f" ++ EXEC: (cd {repo_dir} && {' '.join(full_args_quoted)})") + if capture_output: + return subprocess.check_output(full_args, cwd=repo_dir).decode("utf-8") + else: + subprocess.check_call(full_args, cwd=repo_dir) diff --git a/build_tools/scripts/integrate/patch_module.py b/build_tools/scripts/integrate/patch_module.py index 2184bfe9afce..fbe2230544a3 100755 --- a/build_tools/scripts/integrate/patch_module.py +++ b/build_tools/scripts/integrate/patch_module.py @@ -32,78 +32,77 @@ def main(args): - module_info = iree_modules.MODULE_INFOS.get(args.module) - if not module_info: - raise SystemExit(f"ERROR: Bad value for --module. Must be one of: " - f"{', '.join(iree_modules.MODULE_INFOS.keys())}") + module_info = iree_modules.MODULE_INFOS.get(args.module) + if not module_info: + raise SystemExit( + f"ERROR: Bad value for --module. Must be one of: " + f"{', '.join(iree_modules.MODULE_INFOS.keys())}" + ) - if args.command == "patch": - main_patch(args, module_info) - else: - raise SystemExit( - f"ERROR: Unrecognized --command. Must be one of: patch, unpatch") + if args.command == "patch": + main_patch(args, module_info) + else: + raise SystemExit( + f"ERROR: Unrecognized --command. Must be one of: patch, unpatch" + ) def main_patch(args, module_info: iree_modules.ModuleInfo): - module_root = os.path.join(iree_utils.get_repo_root(), module_info.path) - setup_module_remotes(module_root, module_info) - - branch_name = find_unused_branch_name(module_info) - print(f"Allocated branch: {branch_name}") - current_commit, summary = iree_utils.git_current_commit(repo_dir=module_root) - print(f"Module is currently at: {summary}") - print( - f"*** Pushing branch {branch_name} to {module_info.fork_repository_push} ***" - ) - print(f"(Please ignore any messages below about creating a PR)\n") - iree_utils.git_exec([ - "push", PATCH_REMOTE_ALIAS, f"{current_commit}:refs/heads/{branch_name}" - ], - repo_dir=module_root) - print(f"*** Branch {branch_name} pushed ***") - - print(f"******* Congratulations *******") - print( - f"You have pushed your commits to {branch_name} on {module_info.fork_repository_push}." - ) - print( - f"Your main repository should now show that the submodule has been edited." - ) - print(f"Make a commit, referencing the above branch cherry-picks and ") - print(f"land the resulting PR.") - print(f"You can push more commits to this module's patch branch via:") - print( - f" (cd {module_info.path} && git push {PATCH_REMOTE_ALIAS} HEAD:{branch_name})" - ) - - -def setup_module_remotes(module_root: str, - module_info: iree_modules.ModuleInfo): - iree_utils.git_setup_remote(PATCH_REMOTE_ALIAS, - url=module_info.fork_repository_push, - repo_dir=module_root) + module_root = os.path.join(iree_utils.get_repo_root(), module_info.path) + setup_module_remotes(module_root, module_info) + + branch_name = find_unused_branch_name(module_info) + print(f"Allocated branch: {branch_name}") + current_commit, summary = iree_utils.git_current_commit(repo_dir=module_root) + print(f"Module is currently at: {summary}") + print(f"*** Pushing branch {branch_name} to {module_info.fork_repository_push} ***") + print(f"(Please ignore any messages below about creating a PR)\n") + iree_utils.git_exec( + ["push", PATCH_REMOTE_ALIAS, f"{current_commit}:refs/heads/{branch_name}"], + repo_dir=module_root, + ) + print(f"*** Branch {branch_name} pushed ***") + + print(f"******* Congratulations *******") + print( + f"You have pushed your commits to {branch_name} on {module_info.fork_repository_push}." + ) + print(f"Your main repository should now show that the submodule has been edited.") + print(f"Make a commit, referencing the above branch cherry-picks and ") + print(f"land the resulting PR.") + print(f"You can push more commits to this module's patch branch via:") + print( + f" (cd {module_info.path} && git push {PATCH_REMOTE_ALIAS} HEAD:{branch_name})" + ) + + +def setup_module_remotes(module_root: str, module_info: iree_modules.ModuleInfo): + iree_utils.git_setup_remote( + PATCH_REMOTE_ALIAS, url=module_info.fork_repository_push, repo_dir=module_root + ) def find_unused_branch_name(module_info: iree_modules.ModuleInfo): - branch_base = f"{module_info.branch_prefix}{date.today().strftime('%Y%m%d')}" - branch_name = branch_base - existing_branches = iree_utils.git_ls_remote_branches( - module_info.fork_repository_pull, - filter=[f"refs/heads/{module_info.branch_prefix}*"]) - i = 1 - while branch_name in existing_branches: - branch_name = f"{branch_base}.{i}" - i += 1 - return branch_name + branch_base = f"{module_info.branch_prefix}{date.today().strftime('%Y%m%d')}" + branch_name = branch_base + existing_branches = iree_utils.git_ls_remote_branches( + module_info.fork_repository_pull, + filter=[f"refs/heads/{module_info.branch_prefix}*"], + ) + i = 1 + while branch_name in existing_branches: + branch_name = f"{branch_base}.{i}" + i += 1 + return branch_name def parse_arguments(argv): - parser = argparse.ArgumentParser(description="IREE Submodule Patcher") - parser.add_argument("--module", help="Submodule to operate on", default=None) - parser.add_argument("--command", help="Command to execute", default="patch") - args = parser.parse_args(argv) - return args + parser = argparse.ArgumentParser(description="IREE Submodule Patcher") + parser.add_argument("--module", help="Submodule to operate on", default=None) + parser.add_argument("--command", help="Command to execute", default="patch") + args = parser.parse_args(argv) + return args if __name__ == "__main__": - main(parse_arguments(sys.argv[1:])) + main(parse_arguments(sys.argv[1:])) diff --git a/build_tools/scripts/ir_to_markdown.py b/build_tools/scripts/ir_to_markdown.py index 2642f42e3741..476dff3c4db2 100644 --- a/build_tools/scripts/ir_to_markdown.py +++ b/build_tools/scripts/ir_to_markdown.py @@ -34,71 +34,74 @@ def parse_arguments(): - """Parses command line arguments.""" - - parser = argparse.ArgumentParser() - parser.add_argument( - 'input_file_path', - type=str, - nargs='?', - metavar="", - help='Input IR dump (.mlir from -mlir-print-ir-after-all)') - parser.add_argument('-o,', - '--output', - type=str, - required=True, - metavar="", - help='Output file path (e.g. translation_ir.md)') - # TODO(scotttodd): flags for original IR path and compilation command line - # .md could then show original IR + flags -> output - # TODO(scotttodd): flag for markdown flavor (mkdocs, github, etc.) - # TODO(scotttodd): flag for diff view (correlate IR before and IR after)? - - return parser.parse_args() + """Parses command line arguments.""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "input_file_path", + type=str, + nargs="?", + metavar="", + help="Input IR dump (.mlir from -mlir-print-ir-after-all)", + ) + parser.add_argument( + "-o,", + "--output", + type=str, + required=True, + metavar="", + help="Output file path (e.g. translation_ir.md)", + ) + # TODO(scotttodd): flags for original IR path and compilation command line + # .md could then show original IR + flags -> output + # TODO(scotttodd): flag for markdown flavor (mkdocs, github, etc.) + # TODO(scotttodd): flag for diff view (correlate IR before and IR after)? + + return parser.parse_args() def main(args): - input_file_path = args.input_file_path - output_file_path = args.output - print("Converting input file '%s'" % (input_file_path)) - print(" into output file '%s'" % (output_file_path)) - - with open(input_file_path, "r") as input_file: - with open(output_file_path, "w") as output_file: - - # Iterate line by line through the input file, collecting text into - # blocks and writing them into the output file with markdown formatting - # as we go. - # - # Note: we could parse through and find/replace within the file using - # regex (or sed), but iterating this way is easier to understand and - # uses a predictable amount of memory. - - current_block_lines = [] - dump_after_regex = re.compile(MLIR_START_SEQUENCE + "\s(.*)\s" + - MLIR_END_SEQUENCE) - - def finish_block(): - nonlocal current_block_lines - if len(current_block_lines) != 0: - current_block_lines.append("```\n\n") - output_file.writelines(current_block_lines) - current_block_lines = [] - - for input_line in input_file: - if input_line == "\n": - continue - - if input_line.startswith(MLIR_START_SEQUENCE): - finish_block() - header_text = dump_after_regex.match(input_line).group(1) - current_block_lines.append("### " + header_text + "\n\n") - current_block_lines.append("```mlir\n") - else: - current_block_lines.append(input_line) - - finish_block() - - -if __name__ == '__main__': - main(parse_arguments()) + input_file_path = args.input_file_path + output_file_path = args.output + print("Converting input file '%s'" % (input_file_path)) + print(" into output file '%s'" % (output_file_path)) + + with open(input_file_path, "r") as input_file: + with open(output_file_path, "w") as output_file: + # Iterate line by line through the input file, collecting text into + # blocks and writing them into the output file with markdown formatting + # as we go. + # + # Note: we could parse through and find/replace within the file using + # regex (or sed), but iterating this way is easier to understand and + # uses a predictable amount of memory. + + current_block_lines = [] + dump_after_regex = re.compile( + MLIR_START_SEQUENCE + "\s(.*)\s" + MLIR_END_SEQUENCE + ) + + def finish_block(): + nonlocal current_block_lines + if len(current_block_lines) != 0: + current_block_lines.append("```\n\n") + output_file.writelines(current_block_lines) + current_block_lines = [] + + for input_line in input_file: + if input_line == "\n": + continue + + if input_line.startswith(MLIR_START_SEQUENCE): + finish_block() + header_text = dump_after_regex.match(input_line).group(1) + current_block_lines.append("### " + header_text + "\n\n") + current_block_lines.append("```mlir\n") + else: + current_block_lines.append(input_line) + + finish_block() + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/build_tools/scripts/local_web_server.py b/build_tools/scripts/local_web_server.py index 835a7600d021..a0732735ed71 100644 --- a/build_tools/scripts/local_web_server.py +++ b/build_tools/scripts/local_web_server.py @@ -20,47 +20,53 @@ class CORSHTTPRequestHandler(server.SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + # Include MIME types for files we expect to be serving. + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types + self.extensions_map.update( + { + ".js": "application/javascript", + ".wasm": "application/wasm", + } + ) + super().__init__(*args, **kwargs) - def __init__(self, *args, **kwargs): - # Include MIME types for files we expect to be serving. - # https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types/Common_types - self.extensions_map.update({ - ".js": "application/javascript", - ".wasm": "application/wasm", - }) - super().__init__(*args, **kwargs) + # Inspiration for this hack: https://stackoverflow.com/a/13354482 + def end_headers(self): + self.send_cors_headers() - # Inspiration for this hack: https://stackoverflow.com/a/13354482 - def end_headers(self): - self.send_cors_headers() + server.SimpleHTTPRequestHandler.end_headers(self) - server.SimpleHTTPRequestHandler.end_headers(self) + def send_cors_headers(self): + # Emscripten uses SharedArrayBuffer for its multithreading, which requires + # Cross Origin Opener Policy and Cross Origin Embedder Policy headers: + # * https://emscripten.org/docs/porting/pthreads.html + # * https://developer.chrome.com/blog/enabling-shared-array-buffer/ + self.send_header("Cross-Origin-Embedder-Policy", "require-corp") + self.send_header("Cross-Origin-Opener-Policy", "same-origin") - def send_cors_headers(self): - # Emscripten uses SharedArrayBuffer for its multithreading, which requires - # Cross Origin Opener Policy and Cross Origin Embedder Policy headers: - # * https://emscripten.org/docs/porting/pthreads.html - # * https://developer.chrome.com/blog/enabling-shared-array-buffer/ - self.send_header("Cross-Origin-Embedder-Policy", "require-corp") - self.send_header("Cross-Origin-Opener-Policy", "same-origin") +if __name__ == "__main__": + import argparse -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--directory', - '-d', - default=os.getcwd(), - help='Specify alternative directory ' - '[default:current directory]') - parser.add_argument('port', - action='store', - default=8000, - type=int, - nargs='?', - help='Specify alternate port [default: 8000]') - args = parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "--directory", + "-d", + default=os.getcwd(), + help="Specify alternative directory " "[default:current directory]", + ) + parser.add_argument( + "port", + action="store", + default=8000, + type=int, + nargs="?", + help="Specify alternate port [default: 8000]", + ) + args = parser.parse_args() - server.test(HandlerClass=partial(CORSHTTPRequestHandler, - directory=args.directory), - port=args.port) + server.test( + HandlerClass=partial(CORSHTTPRequestHandler, directory=args.directory), + port=args.port, + ) diff --git a/build_tools/scripts/update_tflite_models.py b/build_tools/scripts/update_tflite_models.py index e2ea88727bc4..2af25fb0c409 100644 --- a/build_tools/scripts/update_tflite_models.py +++ b/build_tools/scripts/update_tflite_models.py @@ -26,45 +26,44 @@ import urllib FLAGS = flags.FLAGS -flags.DEFINE_string('file', '', 'file to update') +flags.DEFINE_string("file", "", "file to update") -file_dict = dict({ - "mobilenet_v1.tflite": - "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite", - "posenet_i8.tflite": - "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite", - "posenet_i8_input.jpg": - "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg" -}) +file_dict = dict( + { + "mobilenet_v1.tflite": "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite", + "posenet_i8.tflite": "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite", + "posenet_i8_input.jpg": "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg", + } +) BUCKET_NAME = "iree-model-artifacts" FOLDER_NAME = "tflite-integration-tests" def upload_model(source, destination, tmpfile): - """Uploads a file to the bucket.""" - urllib.request.urlretrieve(source, tmpfile) + """Uploads a file to the bucket.""" + urllib.request.urlretrieve(source, tmpfile) - storage_client = storage.Client() - bucket = storage_client.get_bucket(BUCKET_NAME) - blob = bucket.blob("/".join([FOLDER_NAME, destination])) - blob.upload_from_filename(tmpfile) + storage_client = storage.Client() + bucket = storage_client.get_bucket(BUCKET_NAME) + blob = bucket.blob("/".join([FOLDER_NAME, destination])) + blob.upload_from_filename(tmpfile) def main(argv): - tf = tempfile.NamedTemporaryFile() + tf = tempfile.NamedTemporaryFile() - items = file_dict.items() + items = file_dict.items() - if FLAGS.file in file_dict: - items = [(FLAGS.file, file_dict[FLAGS.file])] - elif FLAGS.file != "all": - print('Unknown file to upload: ', "\"" + FLAGS.file + "\"") - exit() + if FLAGS.file in file_dict: + items = [(FLAGS.file, file_dict[FLAGS.file])] + elif FLAGS.file != "all": + print("Unknown file to upload: ", '"' + FLAGS.file + '"') + exit() - for dst, src in items: - upload_model(src, dst, tf.name) + for dst, src in items: + upload_model(src, dst, tf.name) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/build_tools/scripts/utils.py b/build_tools/scripts/utils.py index 7713f5d8a9a6..36b31ae3335a 100644 --- a/build_tools/scripts/utils.py +++ b/build_tools/scripts/utils.py @@ -14,38 +14,46 @@ def create_markdown_table(rows: Sequence[Sequence[str]]): - """Converts a 2D array to a Markdown table.""" - return '\n'.join([' | '.join(row) for row in rows]) + """Converts a 2D array to a Markdown table.""" + return "\n".join([" | ".join(row) for row in rows]) -def check_and_get_output_lines(command: Sequence[str], - dry_run: bool = False, - log_stderr: bool = True): - print(f'Running: `{" ".join(command)}`') - if dry_run: - return None, None - return subprocess.run(command, stdout=subprocess.PIPE, text=true, - check=True).stdout.splitlines() +def check_and_get_output_lines( + command: Sequence[str], dry_run: bool = False, log_stderr: bool = True +): + print(f'Running: `{" ".join(command)}`') + if dry_run: + return None, None + return subprocess.run( + command, stdout=subprocess.PIPE, text=true, check=True + ).stdout.splitlines() def get_test_targets(test_suite_path: str): - """Returns a list of test targets for the given test suite.""" - # Check if the suite exists (which may not be true for failing suites). - # We use two queries here because the return code for a failed query is - # unfortunately the same as the return code for a bazel configuration error. - target_dir = test_suite_path.split(':')[0] - query = [ - 'bazel', 'query', '--ui_event_filters=-DEBUG', - '--noshow_loading_progress', '--noshow_progress', f'{target_dir}/...' - ] - targets = check_and_get_output_lines(query) - if test_suite_path not in targets: - return [] - - query = [ - 'bazel', 'query', '--ui_event_filters=-DEBUG', - '--noshow_loading_progress', '--noshow_progress', - f'tests({test_suite_path})' - ] - tests = check_and_get_output_lines(query) - return tests + """Returns a list of test targets for the given test suite.""" + # Check if the suite exists (which may not be true for failing suites). + # We use two queries here because the return code for a failed query is + # unfortunately the same as the return code for a bazel configuration error. + target_dir = test_suite_path.split(":")[0] + query = [ + "bazel", + "query", + "--ui_event_filters=-DEBUG", + "--noshow_loading_progress", + "--noshow_progress", + f"{target_dir}/...", + ] + targets = check_and_get_output_lines(query) + if test_suite_path not in targets: + return [] + + query = [ + "bazel", + "query", + "--ui_event_filters=-DEBUG", + "--noshow_loading_progress", + "--noshow_progress", + f"tests({test_suite_path})", + ] + tests = check_and_get_output_lines(query) + return tests diff --git a/build_tools/testing/gen_test_matrix.py b/build_tools/testing/gen_test_matrix.py index 6b9001e01313..f91548b75aba 100755 --- a/build_tools/testing/gen_test_matrix.py +++ b/build_tools/testing/gen_test_matrix.py @@ -86,11 +86,11 @@ import shutil try: - import yaml + import yaml except ModuleNotFoundError as e: - raise RuntimeError( - f"PyYAML is not installed. Typically: 'python -m pip install PyYAML" - ) from e + raise RuntimeError( + f"PyYAML is not installed. Typically: 'python -m pip install PyYAML" + ) from e ################################################################################ # Base classes and types @@ -98,74 +98,75 @@ class Environment: - """Runtime environment for processing a directory.""" - - def __init__(self, args, root_dir: str, output_dir: str): - self.args = args - self.root_dir = root_dir - self.output_dir = output_dir - # Set of directories containing purely generated files. - self.gen_dirs = set() # type: Set[str] - # Set of (gen_dir, file_name) for all files in a given directory that have - # been generated. - self.gen_files = set() # type: Set[Tuple[str, str]] - - def remember_gen_file(self, gen_file_path: str): - gen_dir = os.path.dirname(gen_file_path) - gen_file = os.path.basename(gen_file_path) - self.gen_dirs.add(gen_dir) - self.gen_files.add((gen_dir, gen_file)) - - def prune_gen_files(self): - found_gen_files = set() - for gen_dir in self.gen_dirs: - dir_listing = os.listdir(gen_dir) - for fname in dir_listing: - found_gen_files.add((gen_dir, fname)) - obsolete_gen_files = found_gen_files - self.gen_files - if obsolete_gen_files: - for gen_dir, fname in obsolete_gen_files: - obsolete_path = os.path.join(gen_dir, fname) - log(f"Removing obsolete file {obsolete_path}") - if os.path.isdir(obsolete_path): - shutil.rmtree(obsolete_path) - else: - os.remove(obsolete_path) + """Runtime environment for processing a directory.""" + + def __init__(self, args, root_dir: str, output_dir: str): + self.args = args + self.root_dir = root_dir + self.output_dir = output_dir + # Set of directories containing purely generated files. + self.gen_dirs = set() # type: Set[str] + # Set of (gen_dir, file_name) for all files in a given directory that have + # been generated. + self.gen_files = set() # type: Set[Tuple[str, str]] + + def remember_gen_file(self, gen_file_path: str): + gen_dir = os.path.dirname(gen_file_path) + gen_file = os.path.basename(gen_file_path) + self.gen_dirs.add(gen_dir) + self.gen_files.add((gen_dir, gen_file)) + + def prune_gen_files(self): + found_gen_files = set() + for gen_dir in self.gen_dirs: + dir_listing = os.listdir(gen_dir) + for fname in dir_listing: + found_gen_files.add((gen_dir, fname)) + obsolete_gen_files = found_gen_files - self.gen_files + if obsolete_gen_files: + for gen_dir, fname in obsolete_gen_files: + obsolete_path = os.path.join(gen_dir, fname) + log(f"Removing obsolete file {obsolete_path}") + if os.path.isdir(obsolete_path): + shutil.rmtree(obsolete_path) + else: + os.remove(obsolete_path) class Runner: - """Base class for a runner.""" - RUNNER_IDENT = None - - def __init__(self, env: Environment, test_id: str): - self.env = env - self.test_id = test_id - self.gen_dir = os.path.join(self.env.output_dir, "generated") - self.xfail = False - - @property - def runner_ident(self) -> str: - assert self.RUNNER_IDENT, "Must define RUNNER_IDENT" - return self.RUNNER_IDENT - - def create_gen_file(self, file_name: str, mode: str = "wt"): - os.makedirs(self.gen_dir, exist_ok=True) - full_path = os.path.join(self.gen_dir, file_name) - handle = open(full_path, mode) - self.env.remember_gen_file(full_path) - return handle - - def link_file(self, from_path: str, to_path: str): - if from_path == to_path: - return - from_path = os.path.realpath(from_path) - os.makedirs(os.path.dirname(to_path), exist_ok=True) - if os.path.exists(to_path): - os.remove(to_path) - os.symlink(from_path, to_path) - - def generate(self): - raise NotImplementedError(f"Generate not implemented for {self.__class__}") + """Base class for a runner.""" + + RUNNER_IDENT = None + + def __init__(self, env: Environment, test_id: str): + self.env = env + self.test_id = test_id + self.gen_dir = os.path.join(self.env.output_dir, "generated") + self.xfail = False + + @property + def runner_ident(self) -> str: + assert self.RUNNER_IDENT, "Must define RUNNER_IDENT" + return self.RUNNER_IDENT + + def create_gen_file(self, file_name: str, mode: str = "wt"): + os.makedirs(self.gen_dir, exist_ok=True) + full_path = os.path.join(self.gen_dir, file_name) + handle = open(full_path, mode) + self.env.remember_gen_file(full_path) + return handle + + def link_file(self, from_path: str, to_path: str): + if from_path == to_path: + return + from_path = os.path.realpath(from_path) + os.makedirs(os.path.dirname(to_path), exist_ok=True) + if os.path.exists(to_path): + os.remove(to_path) + os.symlink(from_path, to_path) + + def generate(self): + raise NotImplementedError(f"Generate not implemented for {self.__class__}") ################################################################################ @@ -174,105 +175,103 @@ def generate(self): def parse_arguments(): - parser = argparse.ArgumentParser(description="Test matrix generator") - parser.add_argument("--dir", - required=True, - type=str, - help="Directory to process") - parser.add_argument("--output_dir", - required=True, - type=str, - help="Output directory") - args = parser.parse_args() - return args + parser = argparse.ArgumentParser(description="Test matrix generator") + parser.add_argument("--dir", required=True, type=str, help="Directory to process") + parser.add_argument( + "--output_dir", required=True, type=str, help="Output directory" + ) + args = parser.parse_args() + return args def main(args): - env = Environment(args, args.dir, args.output_dir) - process_directory(env) + env = Environment(args, args.dir, args.output_dir) + process_directory(env) def process_directory(env: Environment): - dir = os.path.realpath(env.root_dir) - try: - config_sections = read_directory_config(dir) - except Exception as e: - raise RuntimeError(f"Could not read configuration from {dir}") from e - for section in config_sections: - require_mapping(section) - for config_key, config_value in section.items(): - if config_key == "lists": - # Ignore: a place to stash anchors and references. - pass - elif config_key == "test_groups": - require_list(config_value) - for test_group in config_value: - require_mapping(test_group) - process_test_group(env, test_group) - else: - raise ValueError(f"Unexpected top-level section {config_key}") - - env.prune_gen_files() + dir = os.path.realpath(env.root_dir) + try: + config_sections = read_directory_config(dir) + except Exception as e: + raise RuntimeError(f"Could not read configuration from {dir}") from e + for section in config_sections: + require_mapping(section) + for config_key, config_value in section.items(): + if config_key == "lists": + # Ignore: a place to stash anchors and references. + pass + elif config_key == "test_groups": + require_list(config_value) + for test_group in config_value: + require_mapping(test_group) + process_test_group(env, test_group) + else: + raise ValueError(f"Unexpected top-level section {config_key}") + + env.prune_gen_files() def process_test_group(env: Environment, test_group): - group_id = get_mapping_key(test_group, "id", require_str) - matrix = generate_matrix( - get_mapping_key(test_group, "matrix", require_mapping)) - matrix_id_map = {group_id.format(**m): m for m in matrix} - for runner_map in get_mapping_key(test_group, "runner", require_list): - for matrix_id, matrix_map in matrix_id_map.items(): - runner = create_runner(env, matrix_id, runner_map, matrix_map) - runner.xfail = (evaluate_xfail(test_group, matrix_map) and - not evaluate_xpass(test_group, matrix_map)) - runner.generate() + group_id = get_mapping_key(test_group, "id", require_str) + matrix = generate_matrix(get_mapping_key(test_group, "matrix", require_mapping)) + matrix_id_map = {group_id.format(**m): m for m in matrix} + for runner_map in get_mapping_key(test_group, "runner", require_list): + for matrix_id, matrix_map in matrix_id_map.items(): + runner = create_runner(env, matrix_id, runner_map, matrix_map) + runner.xfail = evaluate_xfail( + test_group, matrix_map + ) and not evaluate_xpass(test_group, matrix_map) + runner.generate() def evaluate_xfail(test_group, matrix_map) -> bool: - try: - xfail_list = flatten_lists(require_list(test_group["xfail"])) - except KeyError: + try: + xfail_list = flatten_lists(require_list(test_group["xfail"])) + except KeyError: + return False + for xfail_group in xfail_list: + if evaluate_matrix_map_predicate(matrix_map, xfail_group): + return True return False - for xfail_group in xfail_list: - if evaluate_matrix_map_predicate(matrix_map, xfail_group): - return True - return False def evaluate_xpass(test_group, matrix_map) -> bool: - try: - xpass_list = flatten_lists(require_list(test_group["xpass"])) - except KeyError: + try: + xpass_list = flatten_lists(require_list(test_group["xpass"])) + except KeyError: + return False + for xpass_group in xpass_list: + if evaluate_matrix_map_predicate(matrix_map, xpass_group): + return True return False - for xpass_group in xpass_list: - if evaluate_matrix_map_predicate(matrix_map, xpass_group): - return True - return False def evaluate_matrix_map_predicate(matrix_map, predicate_group) -> bool: - # Each key is something like 'matrix.' which are and'ed - # together. Each value is either a literal or a list that is - # or'd together. - for pred_key, pred_value in predicate_group.items(): - match_value = None - if pred_key.startswith("matrix."): - try: - match_value = matrix_map[pred_key[len("matrix."):]] - except KeyError: - raise ValueError( - f"Could not match matrix predicate to matrix value: {pred_key}") - else: - raise ValueError( - f"Expected a matrix predicate (i.e. matrix.) but got {pred_key}") - # Match list (OR) or literal (==) - if isinstance(pred_value, list): - if match_value not in flatten_lists(pred_value): - return False - else: - if pred_value != match_value: - return False - return True + # Each key is something like 'matrix.' which are and'ed + # together. Each value is either a literal or a list that is + # or'd together. + for pred_key, pred_value in predicate_group.items(): + match_value = None + if pred_key.startswith("matrix."): + try: + match_value = matrix_map[pred_key[len("matrix.") :]] + except KeyError: + raise ValueError( + f"Could not match matrix predicate to matrix value: {pred_key}" + ) + else: + raise ValueError( + f"Expected a matrix predicate (i.e. matrix.) but got {pred_key}" + ) + # Match list (OR) or literal (==) + if isinstance(pred_value, list): + if match_value not in flatten_lists(pred_value): + return False + else: + if pred_value != match_value: + return False + return True ################################################################################ @@ -281,84 +280,84 @@ def evaluate_matrix_map_predicate(matrix_map, predicate_group) -> bool: def generate_matrix(matrix_map): - # List of (key, [value, value, ...]) - matrix_entries = [(k, flatten_lists(v)) for k, v in matrix_map.items()] - # Permute. - permuted = [] + # List of (key, [value, value, ...]) + matrix_entries = [(k, flatten_lists(v)) for k, v in matrix_map.items()] + # Permute. + permuted = [] - def accumulate(prior: dict, i: int): - if i == len(matrix_entries): - permuted.append(prior) - return - next_key, next_values = matrix_entries[i] - for next_value in next_values: - current = dict(prior) - current[next_key] = next_value - accumulate(current, i + 1) + def accumulate(prior: dict, i: int): + if i == len(matrix_entries): + permuted.append(prior) + return + next_key, next_values = matrix_entries[i] + for next_value in next_values: + current = dict(prior) + current[next_key] = next_value + accumulate(current, i + 1) - accumulate({}, 0) - return permuted + accumulate({}, 0) + return permuted def read_directory_config(dir: str) -> list: - sections = [] - matrix_path = os.path.join(dir, "test_matrix.yaml") - with open(matrix_path, "r") as stream: - for section in yaml.safe_load_all(stream): - sections.append(section) - return sections + sections = [] + matrix_path = os.path.join(dir, "test_matrix.yaml") + with open(matrix_path, "r") as stream: + for section in yaml.safe_load_all(stream): + sections.append(section) + return sections INDENT = 0 def log(msg: str): - print(" " * INDENT + msg) + print(" " * INDENT + msg) @contextmanager def indent(): - global INDENT - INDENT += 1 - yield - INDENT -= 1 + global INDENT + INDENT += 1 + yield + INDENT -= 1 def flatten_lists(l): - results = list() - for item in l: - if isinstance(item, list): - results.extend(flatten_lists(item)) - else: - results.append(item) - return results + results = list() + for item in l: + if isinstance(item, list): + results.extend(flatten_lists(item)) + else: + results.append(item) + return results def require_mapping(v): - if isinstance(v, dict): - return v - raise ValueError(f"Expected a YAML mapping for {v}") + if isinstance(v, dict): + return v + raise ValueError(f"Expected a YAML mapping for {v}") def require_list(v): - if isinstance(v, list): - return v - raise ValueError(f"Expected YAML list for {v}") + if isinstance(v, list): + return v + raise ValueError(f"Expected YAML list for {v}") def require_str(v): - if isinstance(v, str): - return v - raise ValueError(f"Expected str for {v}") + if isinstance(v, str): + return v + raise ValueError(f"Expected str for {v}") def get_mapping_key(mapping, key: str, checker=None): - if key not in mapping: - raise ValueError(f"Expected key '{key}' in {mapping}") - value = mapping[key] - if checker: - checker(value) - return value + if key not in mapping: + raise ValueError(f"Expected key '{key}' in {mapping}") + value = mapping[key] + if checker: + checker(value) + return value ################################################################################ @@ -417,42 +416,45 @@ def get_mapping_key(mapping, key: str, checker=None): class TfHostRunner(Runner): - """Runner for tf e2e host tests.""" - RUNNER_IDENT = "tfhost" - - def __init__(self, env: Environment, test_id: str, runner_map: dict, - matrix_map: dict): - super().__init__(env=env, test_id=test_id) - self.main_file = get_mapping_key(runner_map, "main", require_str) - raw_arg_list = get_mapping_key(runner_map, "args", require_list) - self.args = [ - require_str(raw_arg).format(**matrix_map) for raw_arg in raw_arg_list - ] - - def generate(self): - # Generate the runner script. - file_name = ( - f"{'XFAIL_' if self.xfail else ''}{self.test_id}_{self.runner_ident}.py" - ) - with self.create_gen_file(file_name) as f: - parts = [ - "import os", - "import sys", - "REQUIRE_IMPORTS = ['iree.tf.support.tf_utils', 'iree.tf.support.tf_test_utils']", - f"ARGS = {repr(self.args)}", - f"MAIN = os.path.join(os.path.dirname(__file__), '..', {repr(self.main_file)})", - f"XFAIL = {self.xfail}", - PYRUNNER_STUB, - ] - f.write("\n".join(parts)) - - # Copy/link the main file. - main_file_src_path = os.path.join(self.env.root_dir, self.main_file) - main_file_dst_path = os.path.join(self.env.output_dir, self.main_file) - if not os.path.exists(main_file_src_path): - raise RuntimeError( - f"Referenced main file '{main_file_src_path}' does not exist") - self.link_file(main_file_src_path, main_file_dst_path) + """Runner for tf e2e host tests.""" + + RUNNER_IDENT = "tfhost" + + def __init__( + self, env: Environment, test_id: str, runner_map: dict, matrix_map: dict + ): + super().__init__(env=env, test_id=test_id) + self.main_file = get_mapping_key(runner_map, "main", require_str) + raw_arg_list = get_mapping_key(runner_map, "args", require_list) + self.args = [ + require_str(raw_arg).format(**matrix_map) for raw_arg in raw_arg_list + ] + + def generate(self): + # Generate the runner script. + file_name = ( + f"{'XFAIL_' if self.xfail else ''}{self.test_id}_{self.runner_ident}.py" + ) + with self.create_gen_file(file_name) as f: + parts = [ + "import os", + "import sys", + "REQUIRE_IMPORTS = ['iree.tf.support.tf_utils', 'iree.tf.support.tf_test_utils']", + f"ARGS = {repr(self.args)}", + f"MAIN = os.path.join(os.path.dirname(__file__), '..', {repr(self.main_file)})", + f"XFAIL = {self.xfail}", + PYRUNNER_STUB, + ] + f.write("\n".join(parts)) + + # Copy/link the main file. + main_file_src_path = os.path.join(self.env.root_dir, self.main_file) + main_file_dst_path = os.path.join(self.env.output_dir, self.main_file) + if not os.path.exists(main_file_src_path): + raise RuntimeError( + f"Referenced main file '{main_file_src_path}' does not exist" + ) + self.link_file(main_file_src_path, main_file_dst_path) RUNNER_CLASSES = { @@ -460,18 +462,16 @@ def generate(self): } -def create_runner(env: Environment, test_id: str, runner_map: dict, - matrix_map: dict): - runner_type = get_mapping_key(runner_map, "type", require_str) - try: - runner_class = RUNNER_CLASSES[runner_type] - except KeyError: - raise ValueError(f"Unknown runner type '{runner_type}'") - return runner_class(env=env, - test_id=test_id, - runner_map=runner_map, - matrix_map=matrix_map) +def create_runner(env: Environment, test_id: str, runner_map: dict, matrix_map: dict): + runner_type = get_mapping_key(runner_map, "type", require_str) + try: + runner_class = RUNNER_CLASSES[runner_type] + except KeyError: + raise ValueError(f"Unknown runner type '{runner_type}'") + return runner_class( + env=env, test_id=test_id, runner_map=runner_map, matrix_map=matrix_map + ) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/testing/generate_cmake_e2e_model_tests.py b/build_tools/testing/generate_cmake_e2e_model_tests.py index 44f3a1a80702..1cdafe71d6fc 100755 --- a/build_tools/testing/generate_cmake_e2e_model_tests.py +++ b/build_tools/testing/generate_cmake_e2e_model_tests.py @@ -19,30 +19,32 @@ TEMPLATE_DIR = pathlib.Path(__file__).parent GENERATED_E2E_MODEL_TESTS_CMAKE_TEMPLATE = string.Template( - (TEMPLATE_DIR / "generated_e2e_model_tests_template.cmake").read_text()) + (TEMPLATE_DIR / "generated_e2e_model_tests_template.cmake").read_text() +) def parse_arguments(): - """Parses command-line options.""" + """Parses command-line options.""" - parser = argparse.ArgumentParser() - parser.add_argument("--output", - required=True, - help="Path to write the generated cmake file.") + parser = argparse.ArgumentParser() + parser.add_argument( + "--output", required=True, help="Path to write the generated cmake file." + ) - return parser.parse_args() + return parser.parse_args() def main(args: argparse.Namespace): - (gen_configs, - _) = benchmark_suites.iree.benchmark_collections.generate_benchmarks() - cmake_rules = e2e_model_tests.cmake_generator.generate_rules( - module_generation_configs=gen_configs) - output = GENERATED_E2E_MODEL_TESTS_CMAKE_TEMPLATE.substitute( - __TEST_RULES="\n".join(cmake_rules)) - with open(args.output, "w") as output_file: - output_file.write(output) + (gen_configs, _) = benchmark_suites.iree.benchmark_collections.generate_benchmarks() + cmake_rules = e2e_model_tests.cmake_generator.generate_rules( + module_generation_configs=gen_configs + ) + output = GENERATED_E2E_MODEL_TESTS_CMAKE_TEMPLATE.substitute( + __TEST_RULES="\n".join(cmake_rules) + ) + with open(args.output, "w") as output_file: + output_file.write(output) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/build_tools/testing/generate_cmake_e2e_test_artifacts_suite.py b/build_tools/testing/generate_cmake_e2e_test_artifacts_suite.py index 7b906488b980..85ab921d2cdb 100755 --- a/build_tools/testing/generate_cmake_e2e_test_artifacts_suite.py +++ b/build_tools/testing/generate_cmake_e2e_test_artifacts_suite.py @@ -23,49 +23,64 @@ PACKAGE_NAME_CMAKE_VARIABLE = "PACKAGE_NAME" ROOT_ARTIFACTS_DIR_CMAKE_VARIABLE = "ROOT_ARTIFACTS_DIR" -GENERATED_E2E_TEST_FETCH_MODELS_CMAKE_FILENAMAE = "generated_e2e_test_fetch_models.cmake" -GENERATED_E2E_TEST_IREE_ARTIFACTS_CMAKE_FILENAME = "generated_e2e_test_iree_artifacts.cmake" +GENERATED_E2E_TEST_FETCH_MODELS_CMAKE_FILENAMAE = ( + "generated_e2e_test_fetch_models.cmake" +) +GENERATED_E2E_TEST_IREE_ARTIFACTS_CMAKE_FILENAME = ( + "generated_e2e_test_iree_artifacts.cmake" +) def parse_arguments(): - """Parses command-line options.""" + """Parses command-line options.""" - parser = argparse.ArgumentParser() - parser.add_argument("--output_dir", - required=True, - help="Dir path to write the generated cmake files.") + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + required=True, + help="Dir path to write the generated cmake files.", + ) - return parser.parse_args() + return parser.parse_args() def main(args: argparse.Namespace): - # Currently benchmark is the only source of module generation configs. - (iree_module_generation_configs, - _) = benchmark_suites.iree.benchmark_collections.generate_benchmarks() - - dependent_model_map = iree_artifacts.get_dependent_model_map( - iree_module_generation_configs) - - root_path = pathlib.PurePath("${%s}" % ROOT_ARTIFACTS_DIR_CMAKE_VARIABLE) - model_rule_map = model_rule_generator.generate_model_rule_map( - root_path=root_path, models=dependent_model_map.values()) - - output_dir = pathlib.Path(args.output_dir) - fetch_models_cmake_file = output_dir / GENERATED_E2E_TEST_FETCH_MODELS_CMAKE_FILENAMAE - model_cmake_rules = itertools.chain.from_iterable( - rule.cmake_rules for rule in model_rule_map.values()) - fetch_models_cmake_file.write_text("\n".join(model_cmake_rules)) - - package_name = "${%s}" % PACKAGE_NAME_CMAKE_VARIABLE - iree_cmake_rules = iree_rule_generator.generate_rules( - package_name=package_name, - root_path=root_path, - module_generation_configs=iree_module_generation_configs, - model_rule_map=model_rule_map) - - (output_dir / GENERATED_E2E_TEST_IREE_ARTIFACTS_CMAKE_FILENAME).write_text( - "\n".join(iree_cmake_rules)) + # Currently benchmark is the only source of module generation configs. + ( + iree_module_generation_configs, + _, + ) = benchmark_suites.iree.benchmark_collections.generate_benchmarks() + + dependent_model_map = iree_artifacts.get_dependent_model_map( + iree_module_generation_configs + ) + + root_path = pathlib.PurePath("${%s}" % ROOT_ARTIFACTS_DIR_CMAKE_VARIABLE) + model_rule_map = model_rule_generator.generate_model_rule_map( + root_path=root_path, models=dependent_model_map.values() + ) + + output_dir = pathlib.Path(args.output_dir) + fetch_models_cmake_file = ( + output_dir / GENERATED_E2E_TEST_FETCH_MODELS_CMAKE_FILENAMAE + ) + model_cmake_rules = itertools.chain.from_iterable( + rule.cmake_rules for rule in model_rule_map.values() + ) + fetch_models_cmake_file.write_text("\n".join(model_cmake_rules)) + + package_name = "${%s}" % PACKAGE_NAME_CMAKE_VARIABLE + iree_cmake_rules = iree_rule_generator.generate_rules( + package_name=package_name, + root_path=root_path, + module_generation_configs=iree_module_generation_configs, + model_rule_map=model_rule_map, + ) + + (output_dir / GENERATED_E2E_TEST_IREE_ARTIFACTS_CMAKE_FILENAME).write_text( + "\n".join(iree_cmake_rules) + ) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/compiler/bindings/python/iree/compiler/tools/binaries.py b/compiler/bindings/python/iree/compiler/tools/binaries.py index 920852aa3b13..e2d51c733181 100644 --- a/compiler/bindings/python/iree/compiler/tools/binaries.py +++ b/compiler/bindings/python/iree/compiler/tools/binaries.py @@ -63,253 +63,256 @@ class CompilerToolError(Exception): - """Compiler exception that preserves the command line and error output.""" + """Compiler exception that preserves the command line and error output.""" - def __init__(self, process: subprocess.CompletedProcess): - try: - errs = process.stderr.decode("utf-8") - except: - errs = str(process.stderr) # Decode error or other: best we can do. + def __init__(self, process: subprocess.CompletedProcess): + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) # Decode error or other: best we can do. - tool_name = os.path.basename(process.args[0]) - super().__init__(f"Error invoking IREE compiler tool {tool_name}\n" - f"Diagnostics:\n{errs}\n\n" - f"Invoked with:\n {tool_name} {' '.join(process.args)}\n\n" - f"Need more information? Set IREE_SAVE_TEMPS=/some/dir " - f"in your environment to save all artifacts and " - f"reproducers.\n") + tool_name = os.path.basename(process.args[0]) + super().__init__( + f"Error invoking IREE compiler tool {tool_name}\n" + f"Diagnostics:\n{errs}\n\n" + f"Invoked with:\n {tool_name} {' '.join(process.args)}\n\n" + f"Need more information? Set IREE_SAVE_TEMPS=/some/dir " + f"in your environment to save all artifacts and " + f"reproducers.\n" + ) def get_tool_path() -> List[str]: - """Returns list of paths to search for tools.""" - list_str = os.environ.get(_TOOL_PATH_ENVVAR) - if not list_str: - return [] - return list_str.split(os.pathsep) + """Returns list of paths to search for tools.""" + list_str = os.environ.get(_TOOL_PATH_ENVVAR) + if not list_str: + return [] + return list_str.split(os.pathsep) def find_tool(exe_name: str) -> str: - """Finds a tool by its (extension-less) executable name. - - Args: - exe_name: The name of the executable (extension-less). - Returns: - An absolute path to the tool. - Raises: - ValueError: If the tool is not known or not found. - """ - is_builtin = exe_name in _BUILTIN_TOOLS - if not is_builtin and exe_name not in _TOOL_MODULE_MAP: - raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool") - - # First search an explicit tool path (from environment). - tool_path = get_tool_path() - for path_entry in tool_path: - if not path_entry: - continue - candidate_exe = os.path.join(path_entry, exe_name) - if _is_executable(candidate_exe): - return candidate_exe - - if is_builtin: - # Get builtin tool. - candidate_exe = _get_builtin_tool(exe_name) - if _is_executable(candidate_exe): - return candidate_exe - - # Fall-through and attempt to find it via a tools module. - # Attempt to load the tool module. - tool_module_name = _TOOL_MODULE_MAP[exe_name] - tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name] - try: - tool_module = importlib.import_module(tool_module_name) - except ModuleNotFoundError: - raise ValueError( - f"IREE compiler tool '{exe_name}' is not installed (it should have been " - f"found in the python module '{tool_module_name}', typically installed " - f"via the package {tool_module_package}).\n\n" - f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment " - f"variable to contain the path of the tool executable " - f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)}).") from None - - # Ask the module for its tool. - candidate_exe = tool_module.get_tool(exe_name) - - if (not _is_executable(candidate_exe)): - raise ValueError( - f"IREE compiler tool '{exe_name}' was located in module " - f"'{tool_module_name}' but the file was not found or not executable: " - f"{candidate_exe}") - return candidate_exe + """Finds a tool by its (extension-less) executable name. + + Args: + exe_name: The name of the executable (extension-less). + Returns: + An absolute path to the tool. + Raises: + ValueError: If the tool is not known or not found. + """ + is_builtin = exe_name in _BUILTIN_TOOLS + if not is_builtin and exe_name not in _TOOL_MODULE_MAP: + raise ValueError(f"IREE compiler tool '{exe_name}' is not a known tool") + + # First search an explicit tool path (from environment). + tool_path = get_tool_path() + for path_entry in tool_path: + if not path_entry: + continue + candidate_exe = os.path.join(path_entry, exe_name) + if _is_executable(candidate_exe): + return candidate_exe + + if is_builtin: + # Get builtin tool. + candidate_exe = _get_builtin_tool(exe_name) + if _is_executable(candidate_exe): + return candidate_exe + + # Fall-through and attempt to find it via a tools module. + # Attempt to load the tool module. + tool_module_name = _TOOL_MODULE_MAP[exe_name] + tool_module_package = _TOOL_MODULE_PACKAGES[tool_module_name] + try: + tool_module = importlib.import_module(tool_module_name) + except ModuleNotFoundError: + raise ValueError( + f"IREE compiler tool '{exe_name}' is not installed (it should have been " + f"found in the python module '{tool_module_name}', typically installed " + f"via the package {tool_module_package}).\n\n" + f"Either install the package or set the {_TOOL_PATH_ENVVAR} environment " + f"variable to contain the path of the tool executable " + f"(current {_TOOL_PATH_ENVVAR} = {repr(tool_path)})." + ) from None + + # Ask the module for its tool. + candidate_exe = tool_module.get_tool(exe_name) + + if not _is_executable(candidate_exe): + raise ValueError( + f"IREE compiler tool '{exe_name}' was located in module " + f"'{tool_module_name}' but the file was not found or not executable: " + f"{candidate_exe}" + ) + return candidate_exe def _get_builtin_tool(exe_name: str) -> Optional[str]: - if platform.system() == "Windows": - exe_name = exe_name + ".exe" - this_path = os.path.dirname(__file__) - tool_path = os.path.join(this_path, "..", "_mlir_libs", exe_name) - return tool_path + if platform.system() == "Windows": + exe_name = exe_name + ".exe" + this_path = os.path.dirname(__file__) + tool_path = os.path.join(this_path, "..", "_mlir_libs", exe_name) + return tool_path def _is_executable(candidate_exe: str) -> bool: - if not candidate_exe: - return False - if not os.path.isfile(candidate_exe): - return False - if not os.access(candidate_exe, os.X_OK): - return False - return True - - -def invoke_immediate(command_line: List[str], - *, - input_file: Optional[bytes] = None, - immediate_input=None): - """Invokes an immediate command. - - This is separate from invoke_pipeline as it is simpler and supports more - complex input redirection, using recommended facilities for sub-processes - (less magic). - - Note that this differs from the usual way of using subprocess.run or - subprocess.Popen().communicate() because we need to pump all of the error - streams individually and only pump pipes not connected to a different stage. - Uses threads to pump everything that is required. - """ - if logger.isEnabledFor(logging.INFO): - logging.info("Invoke IREE Tool: %s", _quote_command_line(command_line)) - run_args = {} - input_file_handle = None - stderr_handle = sys.stderr - try: - # Redirect input. - if input_file is not None: - input_file_handle = open(input_file, "rb") - run_args["stdin"] = input_file_handle - elif immediate_input is not None: - run_args["input"] = immediate_input - - process = subprocess.run(command_line, capture_output=True, **run_args) - if process.returncode != 0: - raise CompilerToolError(process) - # Emit stderr contents. - _write_binary_stderr(stderr_handle, process.stderr) - return process.stdout - finally: - if input_file_handle: - input_file_handle.close() - - -def invoke_pipeline(command_lines: List[List[str]], immediate_input=None): - """Invoke a pipeline of commands. - - The first stage of the pipeline will have its stdin set to DEVNULL and each - subsequent stdin will derive from the prior stdout. The final stdout will - be accumulated and returned. All stderr contents are accumulated and printed - to stderr on completion or the first failing stage of the pipeline will have - an exception raised with its stderr output. - """ - logging.info( - "Invoke IREE Pipeline:\n %s", - "\n ".join([_quote_command_line(line) for line in command_lines])) - - stages = [] - pipeline_input = (subprocess.DEVNULL - if immediate_input is None else subprocess.PIPE) - prev_out = pipeline_input - stderr_handle = sys.stderr - - # Create all stages. - for i in range(len(command_lines)): - command_line = command_lines[i] - popen_args = { - "stdin": prev_out, - "stdout": subprocess.PIPE, - "stderr": subprocess.PIPE, - } - process = subprocess.Popen(command_line, **popen_args) - prev_out = process.stdout - capture_output = (i == (len(command_lines) - 1)) - stages.append(_PipelineStage(process, capture_output)) - - # Start stages. - for stage in stages: - stage.start() - - # Pump input. - pipe_success = True - if immediate_input is not None: + if not candidate_exe: + return False + if not os.path.isfile(candidate_exe): + return False + if not os.access(candidate_exe, os.X_OK): + return False + return True + + +def invoke_immediate( + command_line: List[str], *, input_file: Optional[bytes] = None, immediate_input=None +): + """Invokes an immediate command. + + This is separate from invoke_pipeline as it is simpler and supports more + complex input redirection, using recommended facilities for sub-processes + (less magic). + + Note that this differs from the usual way of using subprocess.run or + subprocess.Popen().communicate() because we need to pump all of the error + streams individually and only pump pipes not connected to a different stage. + Uses threads to pump everything that is required. + """ + if logger.isEnabledFor(logging.INFO): + logging.info("Invoke IREE Tool: %s", _quote_command_line(command_line)) + run_args = {} + input_file_handle = None + stderr_handle = sys.stderr try: - pipe_success = False - stages[0].process.stdin.write(immediate_input) - pipe_success = True + # Redirect input. + if input_file is not None: + input_file_handle = open(input_file, "rb") + run_args["stdin"] = input_file_handle + elif immediate_input is not None: + run_args["input"] = immediate_input + + process = subprocess.run(command_line, capture_output=True, **run_args) + if process.returncode != 0: + raise CompilerToolError(process) + # Emit stderr contents. + _write_binary_stderr(stderr_handle, process.stderr) + return process.stdout finally: - stages[0].process.stdin.close() - - # Join. - for stage in stages: - stage.join() + if input_file_handle: + input_file_handle.close() - # Check for errors. - for stage in stages: - assert stage.completed - if stage.completed.returncode != 0: - raise CompilerToolError(stage.completed) - # Broken pipe. - if not pipe_success: - raise CompilerToolError(stages[0].completed) - - # Print any stderr output. - for stage in stages: - _write_binary_stderr(stderr_handle, stage.errs) - return stages[-1].outs +def invoke_pipeline(command_lines: List[List[str]], immediate_input=None): + """Invoke a pipeline of commands. + + The first stage of the pipeline will have its stdin set to DEVNULL and each + subsequent stdin will derive from the prior stdout. The final stdout will + be accumulated and returned. All stderr contents are accumulated and printed + to stderr on completion or the first failing stage of the pipeline will have + an exception raised with its stderr output. + """ + logging.info( + "Invoke IREE Pipeline:\n %s", + "\n ".join([_quote_command_line(line) for line in command_lines]), + ) + + stages = [] + pipeline_input = subprocess.DEVNULL if immediate_input is None else subprocess.PIPE + prev_out = pipeline_input + stderr_handle = sys.stderr + + # Create all stages. + for i in range(len(command_lines)): + command_line = command_lines[i] + popen_args = { + "stdin": prev_out, + "stdout": subprocess.PIPE, + "stderr": subprocess.PIPE, + } + process = subprocess.Popen(command_line, **popen_args) + prev_out = process.stdout + capture_output = i == (len(command_lines) - 1) + stages.append(_PipelineStage(process, capture_output)) + + # Start stages. + for stage in stages: + stage.start() + + # Pump input. + pipe_success = True + if immediate_input is not None: + try: + pipe_success = False + stages[0].process.stdin.write(immediate_input) + pipe_success = True + finally: + stages[0].process.stdin.close() + + # Join. + for stage in stages: + stage.join() + + # Check for errors. + for stage in stages: + assert stage.completed + if stage.completed.returncode != 0: + raise CompilerToolError(stage.completed) + + # Broken pipe. + if not pipe_success: + raise CompilerToolError(stages[0].completed) + + # Print any stderr output. + for stage in stages: + _write_binary_stderr(stderr_handle, stage.errs) + return stages[-1].outs class _PipelineStage(threading.Thread): - """Wraps a process and pumps its handles, waiting for completion.""" - - def __init__(self, process, capture_output): - super().__init__() - self.process = process - self.capture_output = capture_output - self.completed: Optional[subprocess.CompletedProcess] = None - self.outs = None - self.errs = None - - def pump_stderr(self): - self.errs = self.process.stderr.read() - - def pump_stdout(self): - self.outs = self.process.stdout.read() - - def run(self): - stderr_thread = threading.Thread(target=self.pump_stderr) - stderr_thread.start() - if self.capture_output: - stdout_thread = threading.Thread(target=self.pump_stdout) - stdout_thread.start() - self.process.wait() - stderr_thread.join() - if self.capture_output: - stdout_thread.join() - self.completed = subprocess.CompletedProcess(self.process.args, - self.process.returncode, - self.outs, self.errs) - self.process.stderr.close() - self.process.stdout.close() + """Wraps a process and pumps its handles, waiting for completion.""" + + def __init__(self, process, capture_output): + super().__init__() + self.process = process + self.capture_output = capture_output + self.completed: Optional[subprocess.CompletedProcess] = None + self.outs = None + self.errs = None + + def pump_stderr(self): + self.errs = self.process.stderr.read() + + def pump_stdout(self): + self.outs = self.process.stdout.read() + + def run(self): + stderr_thread = threading.Thread(target=self.pump_stderr) + stderr_thread.start() + if self.capture_output: + stdout_thread = threading.Thread(target=self.pump_stdout) + stdout_thread.start() + self.process.wait() + stderr_thread.join() + if self.capture_output: + stdout_thread.join() + self.completed = subprocess.CompletedProcess( + self.process.args, self.process.returncode, self.outs, self.errs + ) + self.process.stderr.close() + self.process.stdout.close() def _write_binary_stderr(out_handle, contents): - # Fast-paths buffered text-io (which stderr is by default) while allowing - # full decode for non buffered and binary io. - if hasattr(out_handle, "buffer"): - out_handle.buffer.write(contents) - elif isinstance(out_handle, io.TextIOBase): - out_handle.write(contents.decode("utf-8")) - else: - out_handle.write(contents) + # Fast-paths buffered text-io (which stderr is by default) while allowing + # full decode for non buffered and binary io. + if hasattr(out_handle, "buffer"): + out_handle.buffer.write(contents) + elif isinstance(out_handle, io.TextIOBase): + out_handle.write(contents.decode("utf-8")) + else: + out_handle.write(contents) def _quote_command_line(command_line: List[str]) -> str: - return " ".join([shlex.quote(token) for token in command_line]) + return " ".join([shlex.quote(token) for token in command_line]) diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py index 21fd10793cf1..46ab6fbb30f5 100644 --- a/compiler/bindings/python/iree/compiler/tools/core.py +++ b/compiler/bindings/python/iree/compiler/tools/core.py @@ -35,273 +35,285 @@ class InputType(Enum): - """Enumeration of allowable input types to the compiler. - - An instance of this enum or the string form can be passed to - `CompilerOptions.input_type`. - """ - NONE = "none" - AUTO = "auto" - STABLEHLO = "stablehlo" - STABLEHLO_XLA = "stablehlo_xla" - TOSA = "tosa" - TM_TENSOR = "tm_tensor" - - @staticmethod - def parse(spec: Union[str, InputType]) -> InputType: - """Parses or returns an InputType. + """Enumeration of allowable input types to the compiler. - Args: - spec: An InputType instance or the case-insensitive name of one of the - enum values. - Returns: - An InputType instance. + An instance of this enum or the string form can be passed to + `CompilerOptions.input_type`. """ - if isinstance(spec, InputType): - return spec - spec = spec.upper().replace("-", "_") - if spec not in InputType.__members__: - raise ValueError(f"For input_type= argument, expected one of: " - f"{', '.join(InputType.__members__.keys())}") - return InputType[spec] - def __str__(self): - return self.value + NONE = "none" + AUTO = "auto" + STABLEHLO = "stablehlo" + STABLEHLO_XLA = "stablehlo_xla" + TOSA = "tosa" + TM_TENSOR = "tm_tensor" + + @staticmethod + def parse(spec: Union[str, InputType]) -> InputType: + """Parses or returns an InputType. + + Args: + spec: An InputType instance or the case-insensitive name of one of the + enum values. + Returns: + An InputType instance. + """ + if isinstance(spec, InputType): + return spec + spec = spec.upper().replace("-", "_") + if spec not in InputType.__members__: + raise ValueError( + f"For input_type= argument, expected one of: " + f"{', '.join(InputType.__members__.keys())}" + ) + return InputType[spec] + + def __str__(self): + return self.value class OutputFormat(Enum): - """The output format of the compiler.""" - FLATBUFFER_BINARY = "flatbuffer-binary" - FLATBUFFER_TEXT = "flatbuffer-text" - MLIR_TEXT = "mlir-text" - - @staticmethod - def parse(spec: Union[str, OutputFormat]) -> OutputFormat: - """Parses or returns an OutputFormat. - - Args: - spec: An OutputFormat instance or the case-insensitive name of one of - the enum values. - Returns: - An OutputFormat instance. - """ - if isinstance(spec, OutputFormat): - return spec - spec = spec.upper().replace("-", "_") - if spec not in OutputFormat.__members__: - raise ValueError(f"For output_format= argument, expected one of: " - f"{', '.join(OutputFormat.__members__.keys())}") - return OutputFormat[spec] - - def __str__(self): - return self.value + """The output format of the compiler.""" + + FLATBUFFER_BINARY = "flatbuffer-binary" + FLATBUFFER_TEXT = "flatbuffer-text" + MLIR_TEXT = "mlir-text" + + @staticmethod + def parse(spec: Union[str, OutputFormat]) -> OutputFormat: + """Parses or returns an OutputFormat. + + Args: + spec: An OutputFormat instance or the case-insensitive name of one of + the enum values. + Returns: + An OutputFormat instance. + """ + if isinstance(spec, OutputFormat): + return spec + spec = spec.upper().replace("-", "_") + if spec not in OutputFormat.__members__: + raise ValueError( + f"For output_format= argument, expected one of: " + f"{', '.join(OutputFormat.__members__.keys())}" + ) + return OutputFormat[spec] + + def __str__(self): + return self.value @dataclass class CompilerOptions: - """Options to the compiler backend. - - Arguments: - output_file: Optionally save the compiled binary to a file instead of - returning it. - target_backends: List of str names of target backends to compile into - the binary. The resulting binary will run on targets that match one - or more of the compiled backends. - input_type: The type of input legalization to perform prior to full - compilation. Values can either be an `InputType` enum value or a - case-insensitive name. Defaults to `InputType.AUTO`. - output_format: Override the output format. See the `OutputFormat` enum. - Values can either be an enum value or a case-insensitive name of - the option. Typically used for debugging Defaults to - `OutputFormat.FLATBUFFER_BINARY`. - extra_args: Optional list of additional arguments to pass to the compiler. - Example: ["--mlir-print-ir-after-all", "--some-other-arg"]. Individual - arguments must be separate items in the list. - optimize: Whether to apply some default high level optimizations (default - True). - output_mlir_debuginfo: Include debuginfo (including paths) in any saved or - returned MLIR. - output_generic_mlir: Use the generic (and more portable) MLIR formatting for - any saved or returned MLIR instead of the per-dialect custom assembly. - extended_diagnostics: Outputs extended information on diagnostics, - potentially outputting very verbosely (defaults to False). - strip_debug_ops: Whether to strip high level operations used to aid - debugging. - strip_source_map: Whether to strip source map information (used to generate - better errors). - crash_reproducer_path: File name to output an MLIR crash dump to if there - is a compiler failure. - enable_tflite_bindings: Support the IREE TFLite runtime bindings API shim. - enable_benchmark: Whether to generate instrumented binaries suitable - for benchmarking. - """ - - output_file: Optional[str] = None - target_backends: Sequence[str] = () - input_type: Union[InputType, str] = InputType.AUTO - output_format: Union[OutputFormat, str] = OutputFormat.FLATBUFFER_BINARY - extra_args: Sequence[str] = () - optimize: bool = True - output_mlir_debuginfo: bool = True - output_generic_mlir: bool = False - extended_diagnostics: bool = False - strip_debug_ops: bool = False - strip_source_map: bool = False - crash_reproducer_path: Optional[str] = None - enable_tflite_bindings: bool = False - enable_benchmark: bool = False - - def __post_init__(self): - self.input_type = InputType.parse(self.input_type) - self.output_format = OutputFormat.parse(self.output_format) - - -def build_compile_command_line(input_file: str, tfs: TempFileSaver, - options: CompilerOptions) -> List[str]: - """Builds a command line for invoking the compiler. - - Args: - input_file: The input file name. - tfs: TempFileSaver. - options: Compiler options. - Returns: - List of strings of command line. - """ - iree_compile = find_tool("iree-compile") - if not options.target_backends: - raise ValueError("Expected a non-empty list for 'target_backends'") - - cl = [ - iree_compile, - input_file, - f"--iree-input-type={options.input_type!s}", - f"--iree-vm-bytecode-module-output-format={options.output_format!s}", - ] - for target_backend in options.target_backends: - cl.append(f"--iree-hal-target-backends={target_backend}") - - # Output file. - if options.output_file: - cl.append(f"-o={options.output_file}") - - # Tool paths. - if "llvm-cpu" in options.target_backends: - lld_path = find_tool("iree-lld") - cl.append(f"--iree-llvmcpu-embedded-linker-path={lld_path}") - - # MLIR flags. - if options.output_mlir_debuginfo: - cl.append("--mlir-print-debuginfo") - if options.output_generic_mlir: - cl.append("--mlir-print-op-generic") - if options.extended_diagnostics: - # Note that different tools have different defaults, so be explicit. - cl.append("--mlir-print-op-on-diagnostic=true") - else: - cl.append("--mlir-print-op-on-diagnostic=false") - - # Other options to set if specified. - if options.strip_debug_ops: - cl.append("--iree-vm-bytecode-module-strip-debug-ops") - if options.strip_source_map: - cl.append("--iree-vm-bytecode-module-strip-source-map") - crash_reproducer_path = tfs.alloc_optional( - "core-reproducer.mlir", export_as=options.crash_reproducer_path) - if crash_reproducer_path: - cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") - if options.enable_tflite_bindings: - cl.append("--iree-tflite-bindings-support") - if options.enable_benchmark: - cl.append("--iree-flow-export-benchmark-funcs") - - cl.extend(options.extra_args) - return cl + """Options to the compiler backend. + + Arguments: + output_file: Optionally save the compiled binary to a file instead of + returning it. + target_backends: List of str names of target backends to compile into + the binary. The resulting binary will run on targets that match one + or more of the compiled backends. + input_type: The type of input legalization to perform prior to full + compilation. Values can either be an `InputType` enum value or a + case-insensitive name. Defaults to `InputType.AUTO`. + output_format: Override the output format. See the `OutputFormat` enum. + Values can either be an enum value or a case-insensitive name of + the option. Typically used for debugging Defaults to + `OutputFormat.FLATBUFFER_BINARY`. + extra_args: Optional list of additional arguments to pass to the compiler. + Example: ["--mlir-print-ir-after-all", "--some-other-arg"]. Individual + arguments must be separate items in the list. + optimize: Whether to apply some default high level optimizations (default + True). + output_mlir_debuginfo: Include debuginfo (including paths) in any saved or + returned MLIR. + output_generic_mlir: Use the generic (and more portable) MLIR formatting for + any saved or returned MLIR instead of the per-dialect custom assembly. + extended_diagnostics: Outputs extended information on diagnostics, + potentially outputting very verbosely (defaults to False). + strip_debug_ops: Whether to strip high level operations used to aid + debugging. + strip_source_map: Whether to strip source map information (used to generate + better errors). + crash_reproducer_path: File name to output an MLIR crash dump to if there + is a compiler failure. + enable_tflite_bindings: Support the IREE TFLite runtime bindings API shim. + enable_benchmark: Whether to generate instrumented binaries suitable + for benchmarking. + """ + output_file: Optional[str] = None + target_backends: Sequence[str] = () + input_type: Union[InputType, str] = InputType.AUTO + output_format: Union[OutputFormat, str] = OutputFormat.FLATBUFFER_BINARY + extra_args: Sequence[str] = () + optimize: bool = True + output_mlir_debuginfo: bool = True + output_generic_mlir: bool = False + extended_diagnostics: bool = False + strip_debug_ops: bool = False + strip_source_map: bool = False + crash_reproducer_path: Optional[str] = None + enable_tflite_bindings: bool = False + enable_benchmark: bool = False + + def __post_init__(self): + self.input_type = InputType.parse(self.input_type) + self.output_format = OutputFormat.parse(self.output_format) + + +def build_compile_command_line( + input_file: str, tfs: TempFileSaver, options: CompilerOptions +) -> List[str]: + """Builds a command line for invoking the compiler. -def compile_file(input_file: str, **kwargs): - """Invokes the IREE compiler on an input file. - - Args: - input_file: File containing MLIR assembly to compile. - **kwargs: Keyword arguments corresponding to CompilerOptions. - Returns: - Either a byte buffer of the compiled content or None if output_file - was specified in the options. - """ - with TempFileSaver.implicit() as tfs: - options = CompilerOptions(**kwargs) - retained_output_file = tfs.alloc_optional("core-output.bin", - export_as=options.output_file) + Args: + input_file: The input file name. + tfs: TempFileSaver. + options: Compiler options. + Returns: + List of strings of command line. + """ + iree_compile = find_tool("iree-compile") + if not options.target_backends: + raise ValueError("Expected a non-empty list for 'target_backends'") + + cl = [ + iree_compile, + input_file, + f"--iree-input-type={options.input_type!s}", + f"--iree-vm-bytecode-module-output-format={options.output_format!s}", + ] + for target_backend in options.target_backends: + cl.append(f"--iree-hal-target-backends={target_backend}") + + # Output file. if options.output_file: - options.output_file = retained_output_file - cl = build_compile_command_line(input_file, tfs, options) + cl.append(f"-o={options.output_file}") + + # Tool paths. + if "llvm-cpu" in options.target_backends: + lld_path = find_tool("iree-lld") + cl.append(f"--iree-llvmcpu-embedded-linker-path={lld_path}") + + # MLIR flags. + if options.output_mlir_debuginfo: + cl.append("--mlir-print-debuginfo") + if options.output_generic_mlir: + cl.append("--mlir-print-op-generic") + if options.extended_diagnostics: + # Note that different tools have different defaults, so be explicit. + cl.append("--mlir-print-op-on-diagnostic=true") + else: + cl.append("--mlir-print-op-on-diagnostic=false") + + # Other options to set if specified. + if options.strip_debug_ops: + cl.append("--iree-vm-bytecode-module-strip-debug-ops") + if options.strip_source_map: + cl.append("--iree-vm-bytecode-module-strip-source-map") + crash_reproducer_path = tfs.alloc_optional( + "core-reproducer.mlir", export_as=options.crash_reproducer_path + ) + if crash_reproducer_path: + cl.append(f"--mlir-pass-pipeline-crash-reproducer={crash_reproducer_path}") + if options.enable_tflite_bindings: + cl.append("--iree-tflite-bindings-support") + if options.enable_benchmark: + cl.append("--iree-flow-export-benchmark-funcs") + + cl.extend(options.extra_args) + return cl - # Save a temp file with the command line. - retained_cl = tfs.alloc_optional("core-command-line.txt") - if retained_cl: - with open(retained_cl, "wt") as f: - f.write(" ".join(cl)) - result = invoke_immediate(cl) - if options.output_file: - return None - # Output as string needs to write to the retained output file itself. - if retained_output_file: - with open(retained_output_file, "wb") as f: - f.write(result) - return result +def compile_file(input_file: str, **kwargs): + """Invokes the IREE compiler on an input file. + + Args: + input_file: File containing MLIR assembly to compile. + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_compile_command_line(input_file, tfs, options) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl) + if options.output_file: + return None + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result def compile_str(input_str: Union[str, bytes], **kwargs): - """Invokes the IREE compiler with an input string. - - Args: - input_str: MLIR assembly to parse/compile (str or bytes). - **kwargs: Keyword arguments corresponding to CompilerOptions. - Returns: - Either a byte buffer of the compiled content or None if output_file - was specified in the options. - """ - with TempFileSaver.implicit() as tfs: - retained_input_file = tfs.alloc_optional("core-input.mlir") - if retained_input_file: - with open(retained_input_file, - "wt" if isinstance(input_str, str) else "wb") as f: - f.write(input_str) - options = CompilerOptions(**kwargs) - retained_output_file = tfs.alloc_optional("core-output.bin", - export_as=options.output_file) - if options.output_file: - options.output_file = retained_output_file - cl = build_compile_command_line("-", tfs, options) - input_bytes = input_str.encode("utf-8") if isinstance(input_str, - str) else input_str - - # Save a temp file with the command line. - retained_cl = tfs.alloc_optional("core-command-line.txt") - if retained_cl: - with open(retained_cl, "wt") as f: - f.write(" ".join(cl)) - - result = invoke_immediate(cl, immediate_input=input_bytes) - if options.output_file: - return None + """Invokes the IREE compiler with an input string. - # Output as string needs to write to the retained output file itself. - if retained_output_file: - with open(retained_output_file, "wb") as f: - f.write(result) - return result + Args: + input_str: MLIR assembly to parse/compile (str or bytes). + **kwargs: Keyword arguments corresponding to CompilerOptions. + Returns: + Either a byte buffer of the compiled content or None if output_file + was specified in the options. + """ + with TempFileSaver.implicit() as tfs: + retained_input_file = tfs.alloc_optional("core-input.mlir") + if retained_input_file: + with open( + retained_input_file, "wt" if isinstance(input_str, str) else "wb" + ) as f: + f.write(input_str) + options = CompilerOptions(**kwargs) + retained_output_file = tfs.alloc_optional( + "core-output.bin", export_as=options.output_file + ) + if options.output_file: + options.output_file = retained_output_file + cl = build_compile_command_line("-", tfs, options) + input_bytes = ( + input_str.encode("utf-8") if isinstance(input_str, str) else input_str + ) + + # Save a temp file with the command line. + retained_cl = tfs.alloc_optional("core-command-line.txt") + if retained_cl: + with open(retained_cl, "wt") as f: + f.write(" ".join(cl)) + + result = invoke_immediate(cl, immediate_input=input_bytes) + if options.output_file: + return None + + # Output as string needs to write to the retained output file itself. + if retained_output_file: + with open(retained_output_file, "wb") as f: + f.write(result) + return result def query_available_targets(): - """Returns a collection of target names that are registered.""" - iree_compile = find_tool("iree-compile") - cl = [iree_compile, "--iree-hal-list-target-backends"] - result = invoke_immediate(cl).decode("utf-8") + """Returns a collection of target names that are registered.""" + iree_compile = find_tool("iree-compile") + cl = [iree_compile, "--iree-hal-list-target-backends"] + result = invoke_immediate(cl).decode("utf-8") - target_backends = result.split("\n")[1:] - target_backends = [target.strip() for target in target_backends] - target_backends = [target for target in target_backends if target] + target_backends = result.split("\n")[1:] + target_backends = [target.strip() for target in target_backends] + target_backends = [target for target in target_backends if target] - return target_backends + return target_backends diff --git a/compiler/bindings/python/iree/compiler/tools/debugging.py b/compiler/bindings/python/iree/compiler/tools/debugging.py index 903d8d7d7b7d..c73cda9c99df 100644 --- a/compiler/bindings/python/iree/compiler/tools/debugging.py +++ b/compiler/bindings/python/iree/compiler/tools/debugging.py @@ -3,7 +3,7 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -r''' +r""" A number of optional arguments to the compiler can be useful for debugging: * `extended_diagnostics=True` - Outputs verbose attached operations to \ @@ -42,7 +42,7 @@ For the context manager based API, refer to the `iree.compiler.tools.debugging.TempFileSaver` class. -''' +""" from typing import Optional @@ -57,164 +57,172 @@ def _get_temp_file_saver_stack(): - try: - return _thread_locals.temp_file_saver_stack - except AttributeError: - stack = [] - _thread_locals.temp_file_saver_stack = stack - return stack + try: + return _thread_locals.temp_file_saver_stack + except AttributeError: + stack = [] + _thread_locals.temp_file_saver_stack = stack + return stack def _interpolate_path_pattern(path_pattern: str, *, invocation_id: str): - # We do not use str.format() because we do not know the providence of - # path_pattern. Instead, handle a fixed set of replacements. - path_pattern = path_pattern.replace("{id}", str(invocation_id)) - path_pattern = path_pattern.replace("{pid}", str(os.getpid())) - path_pattern = path_pattern.replace("{main}", os.path.basename(sys.argv[0])) - return path_pattern + # We do not use str.format() because we do not know the providence of + # path_pattern. Instead, handle a fixed set of replacements. + path_pattern = path_pattern.replace("{id}", str(invocation_id)) + path_pattern = path_pattern.replace("{pid}", str(os.getpid())) + path_pattern = path_pattern.replace("{main}", os.path.basename(sys.argv[0])) + return path_pattern class TempFileSaver: - """Manages the saving of temp files resulting from tool invocations. - - The TempFileSaver is a thread-local context bound object. An attempt to - create a new one will return the most recent instance created and entered - as a context manager. This allows up-stack callers to establish the - policy for saving temporaries and deep implementations will inherit it. - - Proper usage from users wishing to establish a saver context: - - .. code-block:: python - - with TempFileSaver(): - # Do things with temp files. - - Proper usage for implementors wishing to use an established saver context - or set up a new one: - - .. code-block:: python - - with TempFileSaver.implicit() as tfs: - # Do things with temp files. - - The outer-most creator can customize it with explicit arguments to __init__ - but these will be ignored if an instance is already thread bound. - """ - TEMP_PATH_ENV_KEY = "IREE_SAVE_TEMPS" - - @staticmethod - def implicit(): - stack = _get_temp_file_saver_stack() - if stack: - return stack[-1] - return TempFileSaver() - - def __init__(self, - temp_path_pattern: Optional[str] = None, - *, - invocation_id: Optional[str] = None): - self.retained = False - self._refcount = 0 - if temp_path_pattern is None: - temp_path_pattern = os.environ.get(TempFileSaver.TEMP_PATH_ENV_KEY) - if temp_path_pattern is None: - return - - global _invocation_id - if invocation_id is not None: - self.invocation_id = invocation_id - else: - self.invocation_id = _invocation_id - _invocation_id += 1 - - self.retained_path = _interpolate_path_pattern( - temp_path_pattern, invocation_id=self.invocation_id) - self.retained = True - self._retained_file_names = set() - self._copy_on_finalize = list() # Of (source_path, target_path) - - def __enter__(self): - _get_temp_file_saver_stack().append(self) - self._refcount += 1 - return self - - def __exit__(self, exc_type, exc_value, traceback): - del _get_temp_file_saver_stack()[-1] - self._refcount -= 1 - if self._refcount == 0: - self._finalize() - - @staticmethod - def current(): - try: - return _get_temp_file_saver_stack()[-1] - except KeyError: - raise RuntimeError("No current TempFileSaver") - - def alloc_optional(self, - file_name: str, - *, - export_as: Optional[str] = None) -> Optional[str]: - """Allocates an optional temporary file. - - - When in non-retained mode, the return value is 'export_as', meaning that the - file is just some user specified output file. - - When in retained mode, the output file will be an index-mangled variant - of 'file_name' under the temp_path. In addition, a mapping will be added - so that upon finalization, the file is also exported to 'export_as' if - specified. - - Returns None if neither a user-specified 'export_as' is specified nor in - retained mode. - - The distinction between retained temporaries and exports is to help in - cases for when the caller has requested that an artifact be written to - a specific place (i.e. an output file) but for debuggability, we also - want to save it as a temporary. In this case, we save it to the temporary - location and then conclude by moving artifacts to their final location - once the saver goes out of scope. + """Manages the saving of temp files resulting from tool invocations. + + The TempFileSaver is a thread-local context bound object. An attempt to + create a new one will return the most recent instance created and entered + as a context manager. This allows up-stack callers to establish the + policy for saving temporaries and deep implementations will inherit it. + + Proper usage from users wishing to establish a saver context: + + .. code-block:: python + + with TempFileSaver(): + # Do things with temp files. + + Proper usage for implementors wishing to use an established saver context + or set up a new one: + + .. code-block:: python + + with TempFileSaver.implicit() as tfs: + # Do things with temp files. + + The outer-most creator can customize it with explicit arguments to __init__ + but these will be ignored if an instance is already thread bound. """ - if not self.retained: - return export_as - alloced_path = self._alloc_retained_path(file_name) - if export_as: - self._copy_on_finalize.append((alloced_path, export_as)) - return alloced_path - - def _alloc_retained_path(self, file_name: str) -> str: - assert self.retained - index = 0 - original_file_name = file_name - while True: - if file_name not in self._retained_file_names: - # First use of this name. - self._retained_file_names.add(file_name) - os.makedirs(self.retained_path, exist_ok=True) - return os.path.join(self.retained_path, file_name) - index += 1 - stem, ext = os.path.splitext(original_file_name) - file_name = f"{stem}_{index}{ext}" - - def _finalize(self): - if not self.retained: - return - # See which files were materialized. - was_materialized = [] - for file_name in self._retained_file_names: - file_path = os.path.join(self.retained_path, file_name) - if os.path.exists(file_path): - was_materialized.append((file_name, file_path)) - if was_materialized: - logging.info( - "**** IREE Compiler retained temporary files (%s)***:\n%s", - self.invocation_id, "\n".join([ - f" * {file_name} : {file_path}" - for file_name, file_path in was_materialized - ])) - for source_path, target_path in self._copy_on_finalize: - if os.path.exists(source_path): - logging.info("Copy retained file to output: %s -> %s", source_path, - target_path) - shutil.copyfile(source_path, target_path) + + TEMP_PATH_ENV_KEY = "IREE_SAVE_TEMPS" + + @staticmethod + def implicit(): + stack = _get_temp_file_saver_stack() + if stack: + return stack[-1] + return TempFileSaver() + + def __init__( + self, + temp_path_pattern: Optional[str] = None, + *, + invocation_id: Optional[str] = None, + ): + self.retained = False + self._refcount = 0 + if temp_path_pattern is None: + temp_path_pattern = os.environ.get(TempFileSaver.TEMP_PATH_ENV_KEY) + if temp_path_pattern is None: + return + + global _invocation_id + if invocation_id is not None: + self.invocation_id = invocation_id + else: + self.invocation_id = _invocation_id + _invocation_id += 1 + + self.retained_path = _interpolate_path_pattern( + temp_path_pattern, invocation_id=self.invocation_id + ) + self.retained = True + self._retained_file_names = set() + self._copy_on_finalize = list() # Of (source_path, target_path) + + def __enter__(self): + _get_temp_file_saver_stack().append(self) + self._refcount += 1 + return self + + def __exit__(self, exc_type, exc_value, traceback): + del _get_temp_file_saver_stack()[-1] + self._refcount -= 1 + if self._refcount == 0: + self._finalize() + + @staticmethod + def current(): + try: + return _get_temp_file_saver_stack()[-1] + except KeyError: + raise RuntimeError("No current TempFileSaver") + + def alloc_optional( + self, file_name: str, *, export_as: Optional[str] = None + ) -> Optional[str]: + """Allocates an optional temporary file. + + + When in non-retained mode, the return value is 'export_as', meaning that the + file is just some user specified output file. + + When in retained mode, the output file will be an index-mangled variant + of 'file_name' under the temp_path. In addition, a mapping will be added + so that upon finalization, the file is also exported to 'export_as' if + specified. + + Returns None if neither a user-specified 'export_as' is specified nor in + retained mode. + + The distinction between retained temporaries and exports is to help in + cases for when the caller has requested that an artifact be written to + a specific place (i.e. an output file) but for debuggability, we also + want to save it as a temporary. In this case, we save it to the temporary + location and then conclude by moving artifacts to their final location + once the saver goes out of scope. + """ + if not self.retained: + return export_as + alloced_path = self._alloc_retained_path(file_name) + if export_as: + self._copy_on_finalize.append((alloced_path, export_as)) + return alloced_path + + def _alloc_retained_path(self, file_name: str) -> str: + assert self.retained + index = 0 + original_file_name = file_name + while True: + if file_name not in self._retained_file_names: + # First use of this name. + self._retained_file_names.add(file_name) + os.makedirs(self.retained_path, exist_ok=True) + return os.path.join(self.retained_path, file_name) + index += 1 + stem, ext = os.path.splitext(original_file_name) + file_name = f"{stem}_{index}{ext}" + + def _finalize(self): + if not self.retained: + return + # See which files were materialized. + was_materialized = [] + for file_name in self._retained_file_names: + file_path = os.path.join(self.retained_path, file_name) + if os.path.exists(file_path): + was_materialized.append((file_name, file_path)) + if was_materialized: + logging.info( + "**** IREE Compiler retained temporary files (%s)***:\n%s", + self.invocation_id, + "\n".join( + [ + f" * {file_name} : {file_path}" + for file_name, file_path in was_materialized + ] + ), + ) + for source_path, target_path in self._copy_on_finalize: + if os.path.exists(source_path): + logging.info( + "Copy retained file to output: %s -> %s", source_path, target_path + ) + shutil.copyfile(source_path, target_path) diff --git a/compiler/bindings/python/iree/compiler/tools/scripts/ireec/__main__.py b/compiler/bindings/python/iree/compiler/tools/scripts/ireec/__main__.py index bffb7868b431..0be252c621f3 100644 --- a/compiler/bindings/python/iree/compiler/tools/scripts/ireec/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/scripts/ireec/__main__.py @@ -11,11 +11,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = binaries.find_tool("iree-compile") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = binaries.find_tool("iree-compile") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/compiler/bindings/python/iree/compiler/tools/tf.py b/compiler/bindings/python/iree/compiler/tools/tf.py index 00c157349b39..729cb1518f28 100644 --- a/compiler/bindings/python/iree/compiler/tools/tf.py +++ b/compiler/bindings/python/iree/compiler/tools/tf.py @@ -34,144 +34,151 @@ def is_available(): - """Determine if TensorFlow and the compiler are available.""" - try: - import tensorflow as tf - except ModuleNotFoundError: - logging.warn("Unable to import tensorflow") - return False - try: - import iree.tools.tf.scripts.iree_import_tf.__main__ - except ModuleNotFoundError: - logging.warning("Unable to find iree-import-tf") - return False - return True + """Determine if TensorFlow and the compiler are available.""" + try: + import tensorflow as tf + except ModuleNotFoundError: + logging.warn("Unable to import tensorflow") + return False + try: + import iree.tools.tf.scripts.iree_import_tf.__main__ + except ModuleNotFoundError: + logging.warning("Unable to find iree-import-tf") + return False + return True class ImportType(Enum): - """Import type of the model.""" - OBJECT_GRAPH = "savedmodel_v2" - V2 = "savedmodel_v2" - SIGNATURE_DEF = "savedmodel_v1" - V1 = "savedmodel_v1" + """Import type of the model.""" + + OBJECT_GRAPH = "savedmodel_v2" + V2 = "savedmodel_v2" + SIGNATURE_DEF = "savedmodel_v1" + V1 = "savedmodel_v1" + + @staticmethod + def parse(spec: Union[str, ImportType]) -> ImportType: + """Parses or returns an ImportType. + + Args: + spec: An ImportType instance or the case-insensitive name of one of + the enum values. + Returns: + An ImportType instance. + """ + if isinstance(spec, ImportType): + return spec + spec = spec.upper() + if spec not in ImportType.__members__: + raise ValueError( + f"For import_type= argument, expected one of: " + f"{', '.join(ImportType.__members__.keys())}" + ) + return ImportType[spec] - @staticmethod - def parse(spec: Union[str, ImportType]) -> ImportType: - """Parses or returns an ImportType. + +@dataclass +class ImportOptions(CompilerOptions): + """Import options layer on top of the backend compiler options. Args: - spec: An ImportType instance or the case-insensitive name of one of - the enum values. - Returns: - An ImportType instance. + exported_names: Optional sequence representing the exported names to + keep (object graph/v2 models only). + import_only: Only import the module. If True, the result will be textual + MLIR that can be further fed to the IREE compiler. If False (default), + the result will be the fully compiled IREE binary. In both cases, + bytes-like output is returned. Note that if the output_file= is + specified and import_only=True, then the MLIR form will be written to + the output file. + import_type: Type of import to perform. See ImportType enum. + saved_model_tags: Set of tags to export (signature def/v1 saved models + only). + save_temp_iree_input: Optionally save the IR that is the result of the + import (ready to be passed to IREE). """ - if isinstance(spec, ImportType): - return spec - spec = spec.upper() - if spec not in ImportType.__members__: - raise ValueError(f"For import_type= argument, expected one of: " - f"{', '.join(ImportType.__members__.keys())}") - return ImportType[spec] + exported_names: Sequence[str] = () + import_only: bool = False + import_type: ImportType = ImportType.OBJECT_GRAPH + input_type: Union[InputType, str] = InputType.STABLEHLO_XLA + saved_model_tags: Set[str] = field(default_factory=set) + save_temp_iree_input: Optional[str] = None -@dataclass -class ImportOptions(CompilerOptions): - """Import options layer on top of the backend compiler options. - - Args: - exported_names: Optional sequence representing the exported names to - keep (object graph/v2 models only). - import_only: Only import the module. If True, the result will be textual - MLIR that can be further fed to the IREE compiler. If False (default), - the result will be the fully compiled IREE binary. In both cases, - bytes-like output is returned. Note that if the output_file= is - specified and import_only=True, then the MLIR form will be written to - the output file. - import_type: Type of import to perform. See ImportType enum. - saved_model_tags: Set of tags to export (signature def/v1 saved models - only). - save_temp_iree_input: Optionally save the IR that is the result of the - import (ready to be passed to IREE). - """ - - exported_names: Sequence[str] = () - import_only: bool = False - import_type: ImportType = ImportType.OBJECT_GRAPH - input_type: Union[InputType, str] = InputType.STABLEHLO_XLA - saved_model_tags: Set[str] = field(default_factory=set) - save_temp_iree_input: Optional[str] = None - - def __post_init__(self): - self.import_type = ImportType.parse(self.import_type) + def __post_init__(self): + self.import_type = ImportType.parse(self.import_type) def compile_saved_model(saved_model_dir: str, **kwargs): - """Compiles an on-disk saved model to an IREE binary. - - Args: - saved_model_dir: Path to directory where the model was saved. - **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. - Returns: - A bytes-like object with the compiled output or None if output_file= - was specified. - """ - from iree.tools.tf.scripts.iree_import_tf import __main__ - - with TempFileSaver.implicit() as tfs, tempfile.TemporaryDirectory() as tmpdir: - options = ImportOptions(**kwargs) - - if options.import_only and options.output_file: - # Importing to a file and stopping, write to that file directly. - tf_iree_input = options.output_file - elif options.save_temp_iree_input: - # Saving the file, use tfs. - tf_iree_input = tfs.alloc_optional("tf-iree-input.mlir", - export_as=options.save_temp_iree_input) - else: - # Not saving the file, so generate a loose temp file without tfs. - tf_iree_input = os.path.join(tmpdir, 'tf-iree-input.mlir') - - __main__.import_saved_model(output_path=tf_iree_input, - saved_model_dir=saved_model_dir, - exported_names=",".join(options.exported_names), - import_type=options.import_type.value, - tags=",".join(options.saved_model_tags)) - - if options.import_only: - if options.output_file: - return None - with open(tf_iree_input, "r") as f: - return f.read() - - # Run IREE compilation pipeline - compile_cl = build_compile_command_line(tf_iree_input, tfs, options) - result = invoke_pipeline([compile_cl]) - if options.output_file: - return None - return result + """Compiles an on-disk saved model to an IREE binary. + + Args: + saved_model_dir: Path to directory where the model was saved. + **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. + Returns: + A bytes-like object with the compiled output or None if output_file= + was specified. + """ + from iree.tools.tf.scripts.iree_import_tf import __main__ + + with TempFileSaver.implicit() as tfs, tempfile.TemporaryDirectory() as tmpdir: + options = ImportOptions(**kwargs) + + if options.import_only and options.output_file: + # Importing to a file and stopping, write to that file directly. + tf_iree_input = options.output_file + elif options.save_temp_iree_input: + # Saving the file, use tfs. + tf_iree_input = tfs.alloc_optional( + "tf-iree-input.mlir", export_as=options.save_temp_iree_input + ) + else: + # Not saving the file, so generate a loose temp file without tfs. + tf_iree_input = os.path.join(tmpdir, "tf-iree-input.mlir") + + __main__.import_saved_model( + output_path=tf_iree_input, + saved_model_dir=saved_model_dir, + exported_names=",".join(options.exported_names), + import_type=options.import_type.value, + tags=",".join(options.saved_model_tags), + ) + + if options.import_only: + if options.output_file: + return None + with open(tf_iree_input, "r") as f: + return f.read() + + # Run IREE compilation pipeline + compile_cl = build_compile_command_line(tf_iree_input, tfs, options) + result = invoke_pipeline([compile_cl]) + if options.output_file: + return None + return result def compile_module(module, saved_model_dir: Optional[str] = None, **kwargs): - """Compiles a tf.Module to an IREE binary (by saving to disk). - - Args: - module: The tf.Module instance to convert to MLIR - saved_model_dir: Optional path to save the tf.Module to. The module will not - be persisted on disk outside of this call if this is not provided. - **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. - Returns: - Same as compile_saved_model(). - """ - with TempFileSaver.implicit() as tfs: - - def do_it(saved_model_dir): - import tensorflow as tf - options = tf.saved_model.SaveOptions(save_debug_info=True) - tf.saved_model.save(module, saved_model_dir, options=options) - return compile_saved_model(saved_model_dir, **kwargs) - - if saved_model_dir: - return do_it(saved_model_dir) - else: - with tempfile.TemporaryDirectory(suffix=".sm") as td: - return do_it(td) + """Compiles a tf.Module to an IREE binary (by saving to disk). + + Args: + module: The tf.Module instance to convert to MLIR + saved_model_dir: Optional path to save the tf.Module to. The module will not + be persisted on disk outside of this call if this is not provided. + **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. + Returns: + Same as compile_saved_model(). + """ + with TempFileSaver.implicit() as tfs: + + def do_it(saved_model_dir): + import tensorflow as tf + + options = tf.saved_model.SaveOptions(save_debug_info=True) + tf.saved_model.save(module, saved_model_dir, options=options) + return compile_saved_model(saved_model_dir, **kwargs) + + if saved_model_dir: + return do_it(saved_model_dir) + else: + with tempfile.TemporaryDirectory(suffix=".sm") as td: + return do_it(td) diff --git a/compiler/bindings/python/iree/compiler/tools/tflite.py b/compiler/bindings/python/iree/compiler/tools/tflite.py index 72333fd7fc88..84443adff3b8 100644 --- a/compiler/bindings/python/iree/compiler/tools/tflite.py +++ b/compiler/bindings/python/iree/compiler/tools/tflite.py @@ -29,104 +29,109 @@ def is_available(): - """Determine if the TFLite frontend is available.""" - try: - import iree.tools.tflite.scripts.iree_import_tflite.__main__ - except ModuleNotFoundError: - logging.warning("Unable to find IREE tool iree-import-tflite") - return False - return True + """Determine if the TFLite frontend is available.""" + try: + import iree.tools.tflite.scripts.iree_import_tflite.__main__ + except ModuleNotFoundError: + logging.warning("Unable to find IREE tool iree-import-tflite") + return False + return True @dataclass class ImportOptions(CompilerOptions): - """Import options layer on top of the backend compiler options. - - Args: - input_arrays: Sequence of input array node names (if different from - default). - output_arrays: Sequence of output array node names (if different from - default). - import_only: Only import the module. If True, the result will be textual - MLIR that can be further fed to the IREE compiler. If False (default), - the result will be the fully compiled IREE binary. In both cases, - bytes-like output is returned. Note that if the output_file= is - specified and import_only=True, then the MLIR form will be written to - the output file. - import_extra_args: Extra arguments to pass to the iree-import-tf tool. - save_temp_tfl_input: Optionally save the IR that results from importing - the flatbuffer (prior to any further transformations). - save_temp_iree_input: Optionally save the IR that is the result of the - import (ready to be passed to IREE). - """ - - input_arrays: Sequence[str] = () - output_arrays: Sequence[str] = () - import_only: bool = False - import_extra_args: Sequence[str] = () - save_temp_tfl_input: Optional[str] = None - save_temp_iree_input: Optional[str] = None - input_type: Optional[str] = "tosa" + """Import options layer on top of the backend compiler options. + + Args: + input_arrays: Sequence of input array node names (if different from + default). + output_arrays: Sequence of output array node names (if different from + default). + import_only: Only import the module. If True, the result will be textual + MLIR that can be further fed to the IREE compiler. If False (default), + the result will be the fully compiled IREE binary. In both cases, + bytes-like output is returned. Note that if the output_file= is + specified and import_only=True, then the MLIR form will be written to + the output file. + import_extra_args: Extra arguments to pass to the iree-import-tf tool. + save_temp_tfl_input: Optionally save the IR that results from importing + the flatbuffer (prior to any further transformations). + save_temp_iree_input: Optionally save the IR that is the result of the + import (ready to be passed to IREE). + """ + + input_arrays: Sequence[str] = () + output_arrays: Sequence[str] = () + import_only: bool = False + import_extra_args: Sequence[str] = () + save_temp_tfl_input: Optional[str] = None + save_temp_iree_input: Optional[str] = None + input_type: Optional[str] = "tosa" def compile_file(fb_path: str, **kwargs): - """Compiles a TFLite FlatBuffer file to an IREE binary. - - Args: - fb_path: Path to the FlatBuffer. - **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. - Returns: - A bytes-like object with the compiled output or None if output_file= - was specified. - """ - from iree.tools.tflite.scripts.iree_import_tflite import __main__ - with TempFileSaver.implicit() as tfs: - options = ImportOptions(**kwargs) - - with TempFileSaver.implicit() as tfs, tempfile.TemporaryDirectory() as tmpdir: - if options.import_only and options.output_file: - # Importing to a file and stopping, write to that file directly. - tfl_iree_input = options.output_file - elif options.save_temp_iree_input: - # Saving the file, use tfs. - tfl_iree_input = tfs.alloc_optional( - "tfl-iree-input.mlir", export_as=options.save_temp_iree_input) - else: - # Not saving the file, so generate a loose temp file without tfs. - tfl_iree_input = os.path.join(tmpdir, 'tfl-iree-input.mlir') - - __main__.tflite_to_tosa(flatbuffer=fb_path, - bytecode=tfl_iree_input, - ordered_input_arrays=options.input_arrays, - ordered_output_arrays=options.output_arrays) - - if options.import_only: - if options.output_file: - return None - with open(tfl_iree_input, "r") as f: - return f.read() - - # Run IREE compilation pipeline - compile_cl = build_compile_command_line(tfl_iree_input, tfs, options) - result = invoke_pipeline([compile_cl]) - if options.output_file: - return None - return result + """Compiles a TFLite FlatBuffer file to an IREE binary. + + Args: + fb_path: Path to the FlatBuffer. + **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. + Returns: + A bytes-like object with the compiled output or None if output_file= + was specified. + """ + from iree.tools.tflite.scripts.iree_import_tflite import __main__ + + with TempFileSaver.implicit() as tfs: + options = ImportOptions(**kwargs) + + with TempFileSaver.implicit() as tfs, tempfile.TemporaryDirectory() as tmpdir: + if options.import_only and options.output_file: + # Importing to a file and stopping, write to that file directly. + tfl_iree_input = options.output_file + elif options.save_temp_iree_input: + # Saving the file, use tfs. + tfl_iree_input = tfs.alloc_optional( + "tfl-iree-input.mlir", export_as=options.save_temp_iree_input + ) + else: + # Not saving the file, so generate a loose temp file without tfs. + tfl_iree_input = os.path.join(tmpdir, "tfl-iree-input.mlir") + + __main__.tflite_to_tosa( + flatbuffer=fb_path, + bytecode=tfl_iree_input, + ordered_input_arrays=options.input_arrays, + ordered_output_arrays=options.output_arrays, + ) + + if options.import_only: + if options.output_file: + return None + with open(tfl_iree_input, "r") as f: + return f.read() + + # Run IREE compilation pipeline + compile_cl = build_compile_command_line(tfl_iree_input, tfs, options) + result = invoke_pipeline([compile_cl]) + if options.output_file: + return None + return result def compile_str(input_bytes: bytes, **kwargs): - """Compiles in-memory TFLite FlatBuffer to an IREE binary. - - Args: - input_bytes: Flatbuffer content as bytes or IR string. - **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. - Returns: - A bytes-like object with the compiled output or None if output_file= - was specified. - """ - input_bytes = input_bytes.encode("utf-8") if isinstance(input_bytes, - str) else input_bytes - with tempfile.NamedTemporaryFile(mode="w") as temp_file: - tempfile.write(input_bytes) - tempfile.close() - return compile_file(tempfile.name, **kwargs) + """Compiles in-memory TFLite FlatBuffer to an IREE binary. + + Args: + input_bytes: Flatbuffer content as bytes or IR string. + **kwargs: Keyword args corresponding to ImportOptions or CompilerOptions. + Returns: + A bytes-like object with the compiled output or None if output_file= + was specified. + """ + input_bytes = ( + input_bytes.encode("utf-8") if isinstance(input_bytes, str) else input_bytes + ) + with tempfile.NamedTemporaryFile(mode="w") as temp_file: + tempfile.write(input_bytes) + tempfile.close() + return compile_file(tempfile.name, **kwargs) diff --git a/compiler/bindings/python/test/ir/registration_test.py b/compiler/bindings/python/test/ir/registration_test.py index d2bf6939d215..33c8f0036b7e 100644 --- a/compiler/bindings/python/test/ir/registration_test.py +++ b/compiler/bindings/python/test/ir/registration_test.py @@ -9,12 +9,14 @@ # Just a simple test that dialects have been registered properly on the # context. with ir.Context() as ctx: - input_module = ir.Module.parse(r""" + input_module = ir.Module.parse( + r""" builtin.module { func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> return %0 : tensor<4xf32> } } - """) - print(input_module) + """ + ) + print(input_module) diff --git a/compiler/bindings/python/test/tools/compiler_core_test.py b/compiler/bindings/python/test/tools/compiler_core_test.py index e4be98f7ffe9..0f3ca3e8d966 100644 --- a/compiler/bindings/python/test/tools/compiler_core_test.py +++ b/compiler/bindings/python/test/tools/compiler_core_test.py @@ -22,233 +22,249 @@ class CompilerTest(unittest.TestCase): - - def setUp(self): - if "IREE_SAVE_TEMPS" in os.environ: - del os.environ["IREE_SAVE_TEMPS"] - - def testQueryTargets(self): - target_names = iree.compiler.query_available_targets() - logging.info("Targets = %s", target_names) - # The VMVX target is always enabled. - self.assertIn("vmvx", target_names) - - def testNoTargetBackends(self): - with self.assertRaisesRegex( - ValueError, "Expected a non-empty list for 'target_backends'"): - binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM) - - def testCompileStr(self): - binary = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - logging.info("Flatbuffer size = %d", len(binary)) - self.assertTrue(binary) - - # Compiling the string form means that the compiler does not have a valid - # source file name, which can cause issues. Verify specifically. - # See: https://github.com/openxla/iree/issues/4439 - def testCompileStrLLVMCPU(self): - binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM, - target_backends=["llvm-cpu"]) - logging.info("Flatbuffer size = %d", len(binary)) - self.assertTrue(binary) - - # Verifies that multiple target_backends are accepted. Which two are not - # load bearing. - # See: https://github.com/openxla/iree/issues/4436 - def testCompileMultipleBackends(self): - binary = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, target_backends=["llvm-cpu", "vulkan-spirv"]) - logging.info("Flatbuffer size = %d", len(binary)) - self.assertTrue(binary) - - def testCompileInputFile(self): - with tempfile.NamedTemporaryFile("wt", delete=False) as f: - try: - f.write(SIMPLE_MUL_ASM) - f.close() - binary = iree.compiler.tools.compile_file( - f.name, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - finally: - os.remove(f.name) - logging.info("Flatbuffer size = %d", len(binary)) - self.assertIn(b"simple_mul", binary) - - def testCompileOutputFile(self): - with tempfile.NamedTemporaryFile("wt", delete=False) as f: - try: - f.close() - output = iree.compiler.tools.compile_str( + def setUp(self): + if "IREE_SAVE_TEMPS" in os.environ: + del os.environ["IREE_SAVE_TEMPS"] + + def testQueryTargets(self): + target_names = iree.compiler.query_available_targets() + logging.info("Targets = %s", target_names) + # The VMVX target is always enabled. + self.assertIn("vmvx", target_names) + + def testNoTargetBackends(self): + with self.assertRaisesRegex( + ValueError, "Expected a non-empty list for 'target_backends'" + ): + binary = iree.compiler.tools.compile_str(SIMPLE_MUL_ASM) + + def testCompileStr(self): + binary = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS + ) + logging.info("Flatbuffer size = %d", len(binary)) + self.assertTrue(binary) + + # Compiling the string form means that the compiler does not have a valid + # source file name, which can cause issues. Verify specifically. + # See: https://github.com/openxla/iree/issues/4439 + def testCompileStrLLVMCPU(self): + binary = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, target_backends=["llvm-cpu"] + ) + logging.info("Flatbuffer size = %d", len(binary)) + self.assertTrue(binary) + + # Verifies that multiple target_backends are accepted. Which two are not + # load bearing. + # See: https://github.com/openxla/iree/issues/4436 + def testCompileMultipleBackends(self): + binary = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, target_backends=["llvm-cpu", "vulkan-spirv"] + ) + logging.info("Flatbuffer size = %d", len(binary)) + self.assertTrue(binary) + + def testCompileInputFile(self): + with tempfile.NamedTemporaryFile("wt", delete=False) as f: + try: + f.write(SIMPLE_MUL_ASM) + f.close() + binary = iree.compiler.tools.compile_file( + f.name, target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS + ) + finally: + os.remove(f.name) + logging.info("Flatbuffer size = %d", len(binary)) + self.assertIn(b"simple_mul", binary) + + def testCompileOutputFile(self): + with tempfile.NamedTemporaryFile("wt", delete=False) as f: + try: + f.close() + output = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + output_file=f.name, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + self.assertIsNone(output) + + with open(f.name, "rb") as f_read: + binary = f_read.read() + finally: + os.remove(f.name) + logging.info("Flatbuffer size = %d", len(binary)) + self.assertIn(b"simple_mul", binary) + + def testOutputFbText(self): + text = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + output_format=iree.compiler.tools.OutputFormat.FLATBUFFER_TEXT, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ).decode("utf-8") + # Just check for an arbitrary JSON-tag. + self.assertIn('"exported_functions"', text) + + def testBadInputType(self): + with self.assertRaisesRegex( + ValueError, + "For input_type= argument, expected one of: " + "NONE, AUTO, STABLEHLO, STABLEHLO_XLA, TOSA", + ): + _ = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + input_type="not-existing", + output_format="foobar", + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + + def testBadOutputFormat(self): + with self.assertRaisesRegex( + ValueError, + "For output_format= argument, expected one of: " + "FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT", + ): + _ = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + output_format="foobar", + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + + def testOutputFbTextParsed(self): + text = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + input_type="stablehlo", + output_format="flatbuffer_text", + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ).decode("utf-8") + # Just check for an arbitrary JSON-tag. + self.assertIn('"exported_functions"', text) + + def testOutputMlirText(self): + text = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + output_format=iree.compiler.tools.OutputFormat.MLIR_TEXT, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ).decode("utf-8") + # Just check for a textual op name. + self.assertIn("vm.module", text) + + def testExtraArgsStderr(self): + # mlir-timing is not special: it just does something and emits to stderr. + with io.StringIO() as buf, contextlib.redirect_stderr(buf): + iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + extra_args=["--mlir-timing"], + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + stderr = buf.getvalue() + self.assertIn("Execution time report", stderr) + + def testAllOptions(self): + binary = iree.compiler.tools.compile_str( SIMPLE_MUL_ASM, - output_file=f.name, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - self.assertIsNone(output) - - with open(f.name, "rb") as f_read: - binary = f_read.read() - finally: - os.remove(f.name) - logging.info("Flatbuffer size = %d", len(binary)) - self.assertIn(b"simple_mul", binary) - - def testOutputFbText(self): - text = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - output_format=iree.compiler.tools.OutputFormat.FLATBUFFER_TEXT, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode( - "utf-8") - # Just check for an arbitrary JSON-tag. - self.assertIn('"exported_functions"', text) - - def testBadInputType(self): - with self.assertRaisesRegex( - ValueError, "For input_type= argument, expected one of: " - "NONE, AUTO, STABLEHLO, STABLEHLO_XLA, TOSA"): - _ = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - input_type="not-existing", - output_format="foobar", - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - - def testBadOutputFormat(self): - with self.assertRaisesRegex( - ValueError, "For output_format= argument, expected one of: " - "FLATBUFFER_BINARY, FLATBUFFER_TEXT, MLIR_TEXT"): - _ = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - output_format="foobar", - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - - def testOutputFbTextParsed(self): - text = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - input_type='stablehlo', - output_format='flatbuffer_text', - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode( - "utf-8") - # Just check for an arbitrary JSON-tag. - self.assertIn('"exported_functions"', text) - - def testOutputMlirText(self): - text = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - output_format=iree.compiler.tools.OutputFormat.MLIR_TEXT, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS).decode( - "utf-8") - # Just check for a textual op name. - self.assertIn("vm.module", text) - - def testExtraArgsStderr(self): - # mlir-timing is not special: it just does something and emits to stderr. - with io.StringIO() as buf, contextlib.redirect_stderr(buf): - iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - extra_args=["--mlir-timing"], - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - stderr = buf.getvalue() - self.assertIn("Execution time report", stderr) - - def testAllOptions(self): - binary = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - optimize=False, - strip_debug_ops=True, - strip_source_map=True, - crash_reproducer_path="foobar.txt", - enable_benchmark=True, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - - def testException(self): - with self.assertRaisesRegex(iree.compiler.tools.CompilerToolError, - "Invoked with"): - _ = iree.compiler.tools.compile_str( - "I'm a little teapot but not a valid program", - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - - def testExplicitTempFileSaver(self): - temp_dir = tempfile.TemporaryDirectory() - output_file = tempfile.NamedTemporaryFile("wt") - output_file.close() - with iree.compiler.tools.TempFileSaver(temp_dir.name): - output = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - output_file=output_file.name, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - self.assertIsNone(output) - - # There should be an output file and a core-output.bin in the temp dir. - self.assertTrue(os.path.exists(output_file.name)) - expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") - self.assertTrue(os.path.exists(expected_temp_file)) - - # And they should have the same contents. - with open(output_file.name, "rb") as f: - output_contents = f.read() - with open(expected_temp_file, "rb") as f: - temp_contents = f.read() - self.assertEqual(temp_contents, output_contents) - temp_dir.cleanup() - - def testExplicitTempFileSaverCompileToStrTextInput(self): - temp_dir = tempfile.TemporaryDirectory() - with iree.compiler.tools.TempFileSaver(temp_dir.name): - output = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - self.assertIsNotNone(output) - self.assertGreater(len(output), 0) - - # There should be a core-input.mlir and core-output.bin in the temp dir. - expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") - self.assertTrue(os.path.exists(expected_temp_file)) - with open(expected_temp_file, "rb") as f: - temp_output = f.read() - self.assertEqual(output, temp_output) - - expected_temp_file = os.path.join(temp_dir.name, "core-input.mlir") - self.assertTrue(os.path.exists(expected_temp_file)) - with open(expected_temp_file, "rt") as f: - input_contents = f.read() - self.assertEqual(SIMPLE_MUL_ASM, input_contents) - temp_dir.cleanup() - - def testExplicitTempFileSaverBinaryInput(self): - temp_dir = tempfile.TemporaryDirectory() - with iree.compiler.tools.TempFileSaver(temp_dir.name): - output = iree.compiler.tools.compile_str( - SIMPLE_MUL_ASM, - target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS) - self.assertIsNotNone(output) - self.assertGreater(len(output), 0) - - # There should be a core-input.mlir and core-output.bin in the temp dir. - expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") - self.assertTrue(os.path.exists(expected_temp_file)) - with open(expected_temp_file, "rb") as f: - temp_output = f.read() - self.assertEqual(output, temp_output) - - expected_temp_file = os.path.join(temp_dir.name, "core-input.mlir") - self.assertTrue(os.path.exists(expected_temp_file)) - with open(expected_temp_file, "rt") as f: - input_contents = f.read() - self.assertEqual(SIMPLE_MUL_ASM, input_contents) - temp_dir.cleanup() - - def testEnvTempFileSaver(self): - temp_dir = tempfile.TemporaryDirectory() - os.environ["IREE_SAVE_TEMPS"] = temp_dir.name - with iree.compiler.tools.TempFileSaver() as tfs: - self.assertTrue(tfs.retained) - self.assertEqual(tfs.retained_path, temp_dir.name) - - def testTempFileSaverDisabled(self): - with iree.compiler.tools.TempFileSaver() as tfs: - self.assertFalse(tfs.retained) + optimize=False, + strip_debug_ops=True, + strip_source_map=True, + crash_reproducer_path="foobar.txt", + enable_benchmark=True, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + + def testException(self): + with self.assertRaisesRegex( + iree.compiler.tools.CompilerToolError, "Invoked with" + ): + _ = iree.compiler.tools.compile_str( + "I'm a little teapot but not a valid program", + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + + def testExplicitTempFileSaver(self): + temp_dir = tempfile.TemporaryDirectory() + output_file = tempfile.NamedTemporaryFile("wt") + output_file.close() + with iree.compiler.tools.TempFileSaver(temp_dir.name): + output = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + output_file=output_file.name, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + self.assertIsNone(output) + + # There should be an output file and a core-output.bin in the temp dir. + self.assertTrue(os.path.exists(output_file.name)) + expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") + self.assertTrue(os.path.exists(expected_temp_file)) + + # And they should have the same contents. + with open(output_file.name, "rb") as f: + output_contents = f.read() + with open(expected_temp_file, "rb") as f: + temp_contents = f.read() + self.assertEqual(temp_contents, output_contents) + temp_dir.cleanup() + + def testExplicitTempFileSaverCompileToStrTextInput(self): + temp_dir = tempfile.TemporaryDirectory() + with iree.compiler.tools.TempFileSaver(temp_dir.name): + output = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + self.assertIsNotNone(output) + self.assertGreater(len(output), 0) + + # There should be a core-input.mlir and core-output.bin in the temp dir. + expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") + self.assertTrue(os.path.exists(expected_temp_file)) + with open(expected_temp_file, "rb") as f: + temp_output = f.read() + self.assertEqual(output, temp_output) + + expected_temp_file = os.path.join(temp_dir.name, "core-input.mlir") + self.assertTrue(os.path.exists(expected_temp_file)) + with open(expected_temp_file, "rt") as f: + input_contents = f.read() + self.assertEqual(SIMPLE_MUL_ASM, input_contents) + temp_dir.cleanup() + + def testExplicitTempFileSaverBinaryInput(self): + temp_dir = tempfile.TemporaryDirectory() + with iree.compiler.tools.TempFileSaver(temp_dir.name): + output = iree.compiler.tools.compile_str( + SIMPLE_MUL_ASM, + target_backends=iree.compiler.tools.DEFAULT_TESTING_BACKENDS, + ) + self.assertIsNotNone(output) + self.assertGreater(len(output), 0) + + # There should be a core-input.mlir and core-output.bin in the temp dir. + expected_temp_file = os.path.join(temp_dir.name, "core-output.bin") + self.assertTrue(os.path.exists(expected_temp_file)) + with open(expected_temp_file, "rb") as f: + temp_output = f.read() + self.assertEqual(output, temp_output) + + expected_temp_file = os.path.join(temp_dir.name, "core-input.mlir") + self.assertTrue(os.path.exists(expected_temp_file)) + with open(expected_temp_file, "rt") as f: + input_contents = f.read() + self.assertEqual(SIMPLE_MUL_ASM, input_contents) + temp_dir.cleanup() + + def testEnvTempFileSaver(self): + temp_dir = tempfile.TemporaryDirectory() + os.environ["IREE_SAVE_TEMPS"] = temp_dir.name + with iree.compiler.tools.TempFileSaver() as tfs: + self.assertTrue(tfs.retained) + self.assertEqual(tfs.retained_path, temp_dir.name) + + def testTempFileSaverDisabled(self): + with iree.compiler.tools.TempFileSaver() as tfs: + self.assertFalse(tfs.retained) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/compiler/bindings/python/test/tools/compiler_tf_test.py b/compiler/bindings/python/test/tools/compiler_tf_test.py index fc1ee7166039..da98e088b6e2 100644 --- a/compiler/bindings/python/test/tools/compiler_tf_test.py +++ b/compiler/bindings/python/test/tools/compiler_tf_test.py @@ -15,69 +15,71 @@ import iree.compiler.tools.tf if not iree.compiler.tools.tf.is_available(): - print(f"Skipping test {__file__} because the IREE TensorFlow compiler " - f"is not installed") - sys.exit(0) + print( + f"Skipping test {__file__} because the IREE TensorFlow compiler " + f"is not installed" + ) + sys.exit(0) import tensorflow as tf class SimpleArithmeticModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([4], tf.float32), - tf.TensorSpec([4], tf.float32) - ]) - def simple_mul(self, a, b): - return a * b - - @tf.function(input_signature=[ - tf.TensorSpec([128, 3072], tf.float32), - tf.TensorSpec([3072, 256], tf.float32), - ]) - def simple_matmul(self, a, b): - return tf.matmul(a, b) + @tf.function( + input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)] + ) + def simple_mul(self, a, b): + return a * b + + @tf.function( + input_signature=[ + tf.TensorSpec([128, 3072], tf.float32), + tf.TensorSpec([3072, 256], tf.float32), + ] + ) + def simple_matmul(self, a, b): + return tf.matmul(a, b) # TODO(laurenzo): More test cases needed (may need additional files). # Specifically, figure out how to test v1 models. class TfCompilerTest(tf.test.TestCase): - - def testImportSavedModel(self): - import_mlir = iree.compiler.tools.tf.compile_saved_model( - self.smdir, import_only=True, output_generic_mlir=True).decode("utf-8") - self.assertIn("sym_name = \"simple_matmul\"", import_mlir) - - def testCompileSavedModel(self): - binary = iree.compiler.tools.tf.compile_saved_model( - self.smdir, - target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS) - logging.info("Compiled len: %d", len(binary)) - self.assertIn(b"simple_matmul", binary) - self.assertIn(b"simple_mul", binary) - - def testCompileModule(self): - binary = iree.compiler.tools.tf.compile_module( - self.m, target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS) - logging.info("Compiled len: %d", len(binary)) - self.assertIn(b"simple_matmul", binary) - self.assertIn(b"simple_mul", binary) - - @classmethod - def setUpClass(cls): - cls.m = SimpleArithmeticModule() - cls.tempdir = tempfile.TemporaryDirectory() - cls.smdir = os.path.join(cls.tempdir.name, "arith.sm") - tf.saved_model.save( - cls.m, - cls.smdir, - options=tf.saved_model.SaveOptions(save_debug_info=True)) - - @classmethod - def tearDownClass(cls): - cls.tempdir.cleanup() + def testImportSavedModel(self): + import_mlir = iree.compiler.tools.tf.compile_saved_model( + self.smdir, import_only=True, output_generic_mlir=True + ).decode("utf-8") + self.assertIn('sym_name = "simple_matmul"', import_mlir) + + def testCompileSavedModel(self): + binary = iree.compiler.tools.tf.compile_saved_model( + self.smdir, target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS + ) + logging.info("Compiled len: %d", len(binary)) + self.assertIn(b"simple_matmul", binary) + self.assertIn(b"simple_mul", binary) + + def testCompileModule(self): + binary = iree.compiler.tools.tf.compile_module( + self.m, target_backends=iree.compiler.tools.tf.DEFAULT_TESTING_BACKENDS + ) + logging.info("Compiled len: %d", len(binary)) + self.assertIn(b"simple_matmul", binary) + self.assertIn(b"simple_mul", binary) + + @classmethod + def setUpClass(cls): + cls.m = SimpleArithmeticModule() + cls.tempdir = tempfile.TemporaryDirectory() + cls.smdir = os.path.join(cls.tempdir.name, "arith.sm") + tf.saved_model.save( + cls.m, cls.smdir, options=tf.saved_model.SaveOptions(save_debug_info=True) + ) + + @classmethod + def tearDownClass(cls): + cls.tempdir.cleanup() if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - tf.test.main() + logging.basicConfig(level=logging.DEBUG) + tf.test.main() diff --git a/compiler/bindings/python/test/tools/compiler_tflite_test.py b/compiler/bindings/python/test/tools/compiler_tflite_test.py index 17aac7e38189..8a99e32c024a 100644 --- a/compiler/bindings/python/test/tools/compiler_tflite_test.py +++ b/compiler/bindings/python/test/tools/compiler_tflite_test.py @@ -15,87 +15,85 @@ import iree.compiler.tools.tflite if not iree.compiler.tools.tflite.is_available(): - print(f"Skipping test {__file__} because the IREE TFLite compiler " - f"is not installed") - sys.exit(0) + print( + f"Skipping test {__file__} because the IREE TFLite compiler " + f"is not installed" + ) + sys.exit(0) class CompilerTest(unittest.TestCase): + def testImportBinaryPbFile(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + text = iree.compiler.tools.tflite.compile_file(path, import_only=True).decode( + "utf-8" + ) + logging.info("%s", text) + self.assertIn("tosa.mul", text) - def testImportBinaryPbFile(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - text = iree.compiler.tools.tflite.compile_file( - path, import_only=True).decode("utf-8") - logging.info("%s", text) - self.assertIn("tosa.mul", text) + def testCompileBinaryPbFile(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + binary = iree.compiler.tools.tflite.compile_file( + path, target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS + ) + logging.info("Binary length = %d", len(binary)) + self.assertIn(b"main", binary) - def testCompileBinaryPbFile(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - binary = iree.compiler.tools.tflite.compile_file( - path, - target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS) - logging.info("Binary length = %d", len(binary)) - self.assertIn(b"main", binary) + def testImportBinaryPbFileOutputFile(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + with tempfile.NamedTemporaryFile("wt", delete=False) as f: + try: + f.close() + output = iree.compiler.tools.tflite.compile_file( + path, import_only=True, output_file=f.name + ) + self.assertIsNone(output) + with open(f.name, "rt") as f_read: + text = f_read.read() + finally: + os.remove(f.name) + logging.info("%s", text) + self.assertIn("tosa.mul", text) - def testImportBinaryPbFileOutputFile(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - with tempfile.NamedTemporaryFile("wt", delete=False) as f: - try: - f.close() - output = iree.compiler.tools.tflite.compile_file(path, - import_only=True, - output_file=f.name) - self.assertIsNone(output) - with open(f.name, "rt") as f_read: - text = f_read.read() - finally: - os.remove(f.name) - logging.info("%s", text) - self.assertIn("tosa.mul", text) + def testCompileBinaryPbFileOutputFile(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + with tempfile.NamedTemporaryFile("wt", delete=False) as f: + try: + f.close() + output = iree.compiler.tools.tflite.compile_file( + path, + output_file=f.name, + target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS, + ) + self.assertIsNone(output) + with open(f.name, "rb") as f_read: + binary = f_read.read() + finally: + os.remove(f.name) + logging.info("Binary length = %d", len(binary)) + self.assertIn(b"main", binary) - def testCompileBinaryPbFileOutputFile(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - with tempfile.NamedTemporaryFile("wt", delete=False) as f: - try: - f.close() - output = iree.compiler.tools.tflite.compile_file( - path, - output_file=f.name, - target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS) - self.assertIsNone(output) - with open(f.name, "rb") as f_read: - binary = f_read.read() - finally: - os.remove(f.name) - logging.info("Binary length = %d", len(binary)) - self.assertIn(b"main", binary) + def testImportBinaryPbBytes(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + with open(path, "rb") as f: + content = f.read() + text = iree.compiler.tools.tflite.compile_str(content, import_only=True).decode( + "utf-8" + ) + logging.info("%s", text) + self.assertIn("tosa.mul", text) - def testImportBinaryPbBytes(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - with open(path, "rb") as f: - content = f.read() - text = iree.compiler.tools.tflite.compile_str( - content, import_only=True).decode("utf-8") - logging.info("%s", text) - self.assertIn("tosa.mul", text) - - def testCompileBinaryPbBytes(self): - path = os.path.join(os.path.dirname(__file__), "testdata", - "tflite_sample.fb") - with open(path, "rb") as f: - content = f.read() - binary = iree.compiler.tools.tflite.compile_str( - content, - target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS) - logging.info("Binary length = %d", len(binary)) - self.assertIn(b"main", binary) + def testCompileBinaryPbBytes(self): + path = os.path.join(os.path.dirname(__file__), "testdata", "tflite_sample.fb") + with open(path, "rb") as f: + content = f.read() + binary = iree.compiler.tools.tflite.compile_str( + content, target_backends=iree.compiler.tools.tflite.DEFAULT_TESTING_BACKENDS + ) + logging.info("Binary length = %d", len(binary)) + self.assertIn(b"main", binary) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/compiler/bindings/python/test/tools/testdata/generate_tflite.py b/compiler/bindings/python/test/tools/testdata/generate_tflite.py index a325089ca5e1..7fb8047c0429 100644 --- a/compiler/bindings/python/test/tools/testdata/generate_tflite.py +++ b/compiler/bindings/python/test/tools/testdata/generate_tflite.py @@ -10,20 +10,19 @@ class Squared(tf.Module): - - @tf.function - def __call__(self, x): - return tf.square(x) + @tf.function + def __call__(self, x): + return tf.square(x) model = Squared() concrete_func = model.__call__.get_concrete_function( - tf.TensorSpec(shape=[4, 3], dtype=tf.float32)) + tf.TensorSpec(shape=[4, 3], dtype=tf.float32) +) -converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func], - model) +converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func], model) tflite_model = converter.convert() this_dir = os.path.dirname(__file__) with open(os.path.join(this_dir, "tflite_sample.fb"), "wb") as f: - f.write(tflite_model) + f.write(tflite_model) diff --git a/compiler/bindings/python/test/tools/testdata/generate_xla.py b/compiler/bindings/python/test/tools/testdata/generate_xla.py index c58329110247..b2a6c259641b 100644 --- a/compiler/bindings/python/test/tools/testdata/generate_xla.py +++ b/compiler/bindings/python/test/tools/testdata/generate_xla.py @@ -23,6 +23,6 @@ this_dir = os.path.dirname(__file__) with open(os.path.join(this_dir, "xla_sample.pb"), "wb") as f: - f.write(xla_computation.as_serialized_hlo_module_proto()) + f.write(xla_computation.as_serialized_hlo_module_proto()) with open(os.path.join(this_dir, "xla_sample.hlo"), "wt") as f: - f.write(xla_computation.as_hlo_text()) + f.write(xla_computation.as_hlo_text()) diff --git a/compiler/lit.cfg.py b/compiler/lit.cfg.py index 5ed367afa616..e6d8b7c45e0d 100644 --- a/compiler/lit.cfg.py +++ b/compiler/lit.cfg.py @@ -28,15 +28,19 @@ # WindowsLinkerTool uses these from vcvarsall "VCTOOLSINSTALLDIR", "UNIVERSALCRTSDKDIR", - "UCRTVERSION" + "UCRTVERSION", ] -config.environment.update({ - k: v - for k, v in os.environ.items() - if k.startswith("IREE_") or k in passthrough_env_vars -}) +config.environment.update( + { + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars + } +) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) diff --git a/compiler/setup.py b/compiler/setup.py index 562d194f0aab..4451bdea41b2 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -41,22 +41,23 @@ def check_pip_version(): - from packaging import version - # Pip versions < 22.0.3 default to out of tree builds, which is quite - # incompatible with what we do (and has other issues). Pip >= 22.0.4 - # removed this option entirely and are only in-tree builds. Since the - # old behavior can silently produce unworking installations, we aggressively - # suppress it. - try: - import pip - except ModuleNotFoundError: - # If pip not installed, we are obviously not trying to package via pip. - pass - else: - if (version.parse(pip.__version__) < version.parse("21.3")): - print("ERROR: pip version >= 21.3 required") - print("Upgrade: pip install pip --upgrade") - sys.exit(2) + from packaging import version + + # Pip versions < 22.0.3 default to out of tree builds, which is quite + # incompatible with what we do (and has other issues). Pip >= 22.0.4 + # removed this option entirely and are only in-tree builds. Since the + # old behavior can silently produce unworking installations, we aggressively + # suppress it. + try: + import pip + except ModuleNotFoundError: + # If pip not installed, we are obviously not trying to package via pip. + pass + else: + if version.parse(pip.__version__) < version.parse("21.3"): + print("ERROR: pip version >= 21.3 required") + print("Upgrade: pip install pip --upgrade") + sys.exit(2) check_pip_version() @@ -80,265 +81,279 @@ def check_pip_version(): IS_CONFIGURED = CONFIGURED_SOURCE_DIR[0] != "@" if IS_CONFIGURED: - IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR - IREE_BINARY_DIR = CONFIGURED_BINARY_DIR - print( - f"Running setup.py from build tree: " - f"SOURCE_DIR = {IREE_SOURCE_DIR} " - f"BINARY_DIR = {IREE_BINARY_DIR}", - file=sys.stderr) + IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR + IREE_BINARY_DIR = CONFIGURED_BINARY_DIR + print( + f"Running setup.py from build tree: " + f"SOURCE_DIR = {IREE_SOURCE_DIR} " + f"BINARY_DIR = {IREE_BINARY_DIR}", + file=sys.stderr, + ) else: - IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..") - IREE_BINARY_DIR = os.getenv("IREE_COMPILER_API_CMAKE_BUILD_DIR") - if not IREE_BINARY_DIR: - # Note that setuptools always builds into a "build" directory that - # is a sibling of setup.py, so we just colonize a sub-directory of that - # by default. - IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build") - print( - f"Running setup.py from source tree: " - f"SOURCE_DIR = {IREE_SOURCE_DIR} " - f"BINARY_DIR = {IREE_BINARY_DIR}", - file=sys.stderr) + IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..") + IREE_BINARY_DIR = os.getenv("IREE_COMPILER_API_CMAKE_BUILD_DIR") + if not IREE_BINARY_DIR: + # Note that setuptools always builds into a "build" directory that + # is a sibling of setup.py, so we just colonize a sub-directory of that + # by default. + IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build") + print( + f"Running setup.py from source tree: " + f"SOURCE_DIR = {IREE_SOURCE_DIR} " + f"BINARY_DIR = {IREE_BINARY_DIR}", + file=sys.stderr, + ) # Setup and get version information. VERSION_INFO_FILE = os.path.join(IREE_SOURCE_DIR, "version_info.json") def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) + with open(VERSION_INFO_FILE, "rt") as f: + return json.load(f) def find_git_versions(): - revisions = {} - try: - revisions["IREE"] = subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - except subprocess.SubprocessError as e: - print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) - return revisions + revisions = {} + try: + revisions["IREE"] = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=IREE_SOURCE_DIR) + .decode("utf-8") + .strip() + ) + except subprocess.SubprocessError as e: + print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) + return revisions def find_git_submodule_revision(submodule_path): - try: - data = subprocess.check_output(["git", "ls-tree", "HEAD", submodule_path], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - columns = re.split("\\s+", data) - return columns[2] - except Exception as e: - print( - f"ERROR: Could not get submodule revision for {submodule_path}" - f" ({e})", - file=sys.stderr) - return "" + try: + data = ( + subprocess.check_output( + ["git", "ls-tree", "HEAD", submodule_path], cwd=IREE_SOURCE_DIR + ) + .decode("utf-8") + .strip() + ) + columns = re.split("\\s+", data) + return columns[2] + except Exception as e: + print( + f"ERROR: Could not get submodule revision for {submodule_path}" f" ({e})", + file=sys.stderr, + ) + return "" try: - version_info = load_version_info() + version_info = load_version_info() except FileNotFoundError: - print("version_info.json not found. Using defaults", file=sys.stderr) - version_info = {} + print("version_info.json not found. Using defaults", file=sys.stderr) + version_info = {} git_versions = find_git_versions() PACKAGE_SUFFIX = version_info.get("package-suffix") or "" PACKAGE_VERSION = version_info.get("package-version") if not PACKAGE_VERSION: - PACKAGE_VERSION = f"0.dev0+{git_versions.get('IREE') or '0'}" + PACKAGE_VERSION = f"0.dev0+{git_versions.get('IREE') or '0'}" def get_cmake_version_info_args(): - version_info_args = [ - f"-DIREE_RELEASE_VERSION:STRING={PACKAGE_VERSION}", - f"-DIREE_RELEASE_REVISION:STRING={git_versions.get('IREE') or '0'}", - ] - if version_info: - version_info_args.append("-DIREE_EMBEDDED_RELEASE_INFO=ON") - return version_info_args + version_info_args = [ + f"-DIREE_RELEASE_VERSION:STRING={PACKAGE_VERSION}", + f"-DIREE_RELEASE_REVISION:STRING={git_versions.get('IREE') or '0'}", + ] + if version_info: + version_info_args.append("-DIREE_EMBEDDED_RELEASE_INFO=ON") + return version_info_args def maybe_nuke_cmake_cache(): - # From run to run under pip, we can end up with different paths to ninja, - # which isn't great and will confuse cmake. Detect if the location of - # ninja changes and force a cache flush. - ninja_path = "" - try: - import ninja - except ModuleNotFoundError: - pass - else: - ninja_path = ninja.__file__ - expected_stamp_contents = f"{sys.executable}\n{ninja_path}" - - # In order to speed things up on CI and not rebuild everything, we nuke - # the CMakeCache.txt file if the path to the Python interpreter changed. - # Ideally, CMake would let us reconfigure this dynamically... but it does - # not (and gets very confused). - # We only do this because the compiler is so expensive to build and very - # little of it depends on the Python version. This is a hack. - PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt") - if os.path.exists(PYTHON_STAMP_FILE): - with open(PYTHON_STAMP_FILE, "rt") as f: - actual_stamp_contents = f.read() - if actual_stamp_contents == expected_stamp_contents: - # All good. - return - - # Mismatch or not found. Clean it. - cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") - if os.path.exists(cmake_cache_file): - print("Removing CMakeCache.txt because Python version changed", - file=sys.stderr) - os.remove(cmake_cache_file) - - # Also clean the install directory. This avoids version specific pileups - # of binaries that can occur with repeated builds against different - # Python versions. - if os.path.exists(CMAKE_INSTALL_DIR_ABS): - print( - f"Removing CMake install dir because Python version changed: " - f"{CMAKE_INSTALL_DIR_ABS}", - file=sys.stderr) - shutil.rmtree(CMAKE_INSTALL_DIR_ABS) - - # And write. - with open(PYTHON_STAMP_FILE, "wt") as f: - f.write(expected_stamp_contents) + # From run to run under pip, we can end up with different paths to ninja, + # which isn't great and will confuse cmake. Detect if the location of + # ninja changes and force a cache flush. + ninja_path = "" + try: + import ninja + except ModuleNotFoundError: + pass + else: + ninja_path = ninja.__file__ + expected_stamp_contents = f"{sys.executable}\n{ninja_path}" + + # In order to speed things up on CI and not rebuild everything, we nuke + # the CMakeCache.txt file if the path to the Python interpreter changed. + # Ideally, CMake would let us reconfigure this dynamically... but it does + # not (and gets very confused). + # We only do this because the compiler is so expensive to build and very + # little of it depends on the Python version. This is a hack. + PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt") + if os.path.exists(PYTHON_STAMP_FILE): + with open(PYTHON_STAMP_FILE, "rt") as f: + actual_stamp_contents = f.read() + if actual_stamp_contents == expected_stamp_contents: + # All good. + return + + # Mismatch or not found. Clean it. + cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") + if os.path.exists(cmake_cache_file): + print("Removing CMakeCache.txt because Python version changed", file=sys.stderr) + os.remove(cmake_cache_file) + + # Also clean the install directory. This avoids version specific pileups + # of binaries that can occur with repeated builds against different + # Python versions. + if os.path.exists(CMAKE_INSTALL_DIR_ABS): + print( + f"Removing CMake install dir because Python version changed: " + f"{CMAKE_INSTALL_DIR_ABS}", + file=sys.stderr, + ) + shutil.rmtree(CMAKE_INSTALL_DIR_ABS) + + # And write. + with open(PYTHON_STAMP_FILE, "wt") as f: + f.write(expected_stamp_contents) def get_env_cmake_option(name: str, default_value: bool = False) -> str: - svalue = os.getenv(name) - if not svalue: - svalue = "ON" if default_value else "OFF" - return f"-D{name}={svalue}" + svalue = os.getenv(name) + if not svalue: + svalue = "ON" if default_value else "OFF" + return f"-D{name}={svalue}" def add_env_cmake_setting(args, env_name: str, cmake_name=None) -> str: - svalue = os.getenv(env_name) - if svalue is not None: - if not cmake_name: - cmake_name = env_name - args.append(f"-D{cmake_name}={svalue}") + svalue = os.getenv(env_name) + if svalue is not None: + if not cmake_name: + cmake_name = env_name + args.append(f"-D{cmake_name}={svalue}") def prepare_installation(): - version_py_content = generate_version_py() - print(f"Generating version.py:\n{version_py_content}", file=sys.stderr) - - if not IS_CONFIGURED: - # Build from source tree. - subprocess.check_call(["cmake", "--version"]) - os.makedirs(IREE_BINARY_DIR, exist_ok=True) - maybe_nuke_cmake_cache() - print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr) - print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) - cfg = "Release" - cmake_args = [ - "-GNinja", - "--log-level=VERBOSE", - "-DIREE_BUILD_PYTHON_BINDINGS=ON", - # Disable .so.0 style symlinking. Python wheels don't preserve links, - # so this ~doubles the binary size if not disabled (yikes!). - "-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON", - "-DPython3_EXECUTABLE={}".format(sys.executable), - "-DCMAKE_BUILD_TYPE={}".format(cfg), - get_env_cmake_option("IREE_TARGET_BACKEND_CUDA"), - # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) - get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), + version_py_content = generate_version_py() + print(f"Generating version.py:\n{version_py_content}", file=sys.stderr) + + if not IS_CONFIGURED: + # Build from source tree. + subprocess.check_call(["cmake", "--version"]) + os.makedirs(IREE_BINARY_DIR, exist_ok=True) + maybe_nuke_cmake_cache() + print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr) + print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) + cfg = "Release" + cmake_args = [ + "-GNinja", + "--log-level=VERBOSE", + "-DIREE_BUILD_PYTHON_BINDINGS=ON", + # Disable .so.0 style symlinking. Python wheels don't preserve links, + # so this ~doubles the binary size if not disabled (yikes!). + "-DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON", + "-DPython3_EXECUTABLE={}".format(sys.executable), + "-DCMAKE_BUILD_TYPE={}".format(cfg), + get_env_cmake_option("IREE_TARGET_BACKEND_CUDA"), + # TODO(scotttodd): include IREE_TARGET_BACKEND_WEBGPU here (and in env) + get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), + ] + cmake_args.extend(get_cmake_version_info_args()) + + # These usually flow through the environment, but we add them explicitly + # so that they show clearly in logs (getting them wrong can have bad + # outcomes). + add_env_cmake_setting(cmake_args, "CMAKE_OSX_ARCHITECTURES") + add_env_cmake_setting( + cmake_args, "MACOSX_DEPLOYMENT_TARGET", "CMAKE_OSX_DEPLOYMENT_TARGET" + ) + + # Only do a from-scratch configure if not already configured. + cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") + if not os.path.exists(cmake_cache_file): + print(f"Configuring with: {cmake_args}", file=sys.stderr) + subprocess.check_call( + ["cmake", IREE_SOURCE_DIR] + cmake_args, cwd=IREE_BINARY_DIR + ) + else: + print(f"Not re-configuring (already configured)", file=sys.stderr) + + # Build. + subprocess.check_call( + ["cmake", "--build", ".", "--target", "compiler/all"], cwd=IREE_BINARY_DIR + ) + print("Build complete.", file=sys.stderr) + + # Perform installation on the entire compiler/ tree as this is guaranteed + # to have all of our installation targets. + install_subdirectory = os.path.join(IREE_BINARY_DIR, "compiler") + install_args = [ + "-DCMAKE_INSTALL_DO_STRIP=ON", + f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}", + "-P", + os.path.join(install_subdirectory, "cmake_install.cmake"), ] - cmake_args.extend(get_cmake_version_info_args()) + print(f"Installing with: {install_args}", file=sys.stderr) + subprocess.check_call(["cmake"] + install_args, cwd=install_subdirectory) - # These usually flow through the environment, but we add them explicitly - # so that they show clearly in logs (getting them wrong can have bad - # outcomes). - add_env_cmake_setting(cmake_args, "CMAKE_OSX_ARCHITECTURES") - add_env_cmake_setting(cmake_args, "MACOSX_DEPLOYMENT_TARGET", - "CMAKE_OSX_DEPLOYMENT_TARGET") + # Write version.py directly into install dir. + version_py_file = os.path.join( + CMAKE_INSTALL_DIR_ABS, + "python_packages", + "iree_compiler", + "iree", + "compiler", + "version.py", + ) + os.makedirs(os.path.dirname(version_py_file), exist_ok=True) + with open(version_py_file, "wt") as f: + f.write(version_py_content) - # Only do a from-scratch configure if not already configured. - cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") - if not os.path.exists(cmake_cache_file): - print(f"Configuring with: {cmake_args}", file=sys.stderr) - subprocess.check_call(["cmake", IREE_SOURCE_DIR] + cmake_args, - cwd=IREE_BINARY_DIR) - else: - print(f"Not re-configuring (already configured)", file=sys.stderr) - - # Build. - subprocess.check_call(["cmake", "--build", ".", "--target", "compiler/all"], - cwd=IREE_BINARY_DIR) - print("Build complete.", file=sys.stderr) - - # Perform installation on the entire compiler/ tree as this is guaranteed - # to have all of our installation targets. - install_subdirectory = os.path.join(IREE_BINARY_DIR, "compiler") - install_args = [ - "-DCMAKE_INSTALL_DO_STRIP=ON", - f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}", - "-P", - os.path.join(install_subdirectory, "cmake_install.cmake"), - ] - print(f"Installing with: {install_args}", file=sys.stderr) - subprocess.check_call(["cmake"] + install_args, cwd=install_subdirectory) - - # Write version.py directly into install dir. - version_py_file = os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", - "iree_compiler", "iree", "compiler", - "version.py") - os.makedirs(os.path.dirname(version_py_file), exist_ok=True) - with open(version_py_file, "wt") as f: - f.write(version_py_content) - - print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) + print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) class CMakeBuildPy(_build_py): - - def run(self): - # It is critical that the target directory contain all built extensions, - # or else setuptools will helpfully compile an empty binary for us - # (this is the **worst** possible thing it could do). We just copy - # everything. What's another hundred megs between friends? - target_dir = os.path.abspath(self.build_lib) - print(f"Building in target dir: {target_dir}", file=sys.stderr) - os.makedirs(target_dir, exist_ok=True) - print("Copying install to target.", file=sys.stderr) - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - shutil.copytree(os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", - "iree_compiler"), - target_dir, - symlinks=False) - print("Target populated.", file=sys.stderr) + def run(self): + # It is critical that the target directory contain all built extensions, + # or else setuptools will helpfully compile an empty binary for us + # (this is the **worst** possible thing it could do). We just copy + # everything. What's another hundred megs between friends? + target_dir = os.path.abspath(self.build_lib) + print(f"Building in target dir: {target_dir}", file=sys.stderr) + os.makedirs(target_dir, exist_ok=True) + print("Copying install to target.", file=sys.stderr) + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", "iree_compiler"), + target_dir, + symlinks=False, + ) + print("Target populated.", file=sys.stderr) class CustomBuild(_build): - - def run(self): - self.run_command("build_py") - self.run_command("build_ext") - self.run_command("build_scripts") + def run(self): + self.run_command("build_py") + self.run_command("build_ext") + self.run_command("build_scripts") class CMakeExtension(Extension): - - def __init__(self, name, sourcedir=""): - Extension.__init__(self, name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) class NoopBuildExtension(_build_ext): + def __init__(self, *args, **kwargs): + assert False - def __init__(self, *args, **kwargs): - assert False - - def build_extension(self, ext): - pass + def build_extension(self, ext): + pass def generate_version_py(): - return f"""# Auto-generated version info. + return f"""# Auto-generated version info. PACKAGE_SUFFIX = "{PACKAGE_SUFFIX}" VERSION = "{PACKAGE_VERSION}" REVISIONS = {json.dumps(git_versions)} @@ -346,42 +361,48 @@ def generate_version_py(): def find_git_versions(): - revisions = {} - try: - revisions["IREE"] = subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - except subprocess.SubprocessError as e: - print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) - revisions["LLVM_PROJECT"] = find_git_submodule_revision( - "third_party/llvm-project") - revisions["STABLEHLO"] = find_git_submodule_revision("third_party/stablehlo") - return revisions + revisions = {} + try: + revisions["IREE"] = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=IREE_SOURCE_DIR) + .decode("utf-8") + .strip() + ) + except subprocess.SubprocessError as e: + print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) + revisions["LLVM_PROJECT"] = find_git_submodule_revision("third_party/llvm-project") + revisions["STABLEHLO"] = find_git_submodule_revision("third_party/stablehlo") + return revisions def find_git_submodule_revision(submodule_path): - try: - data = subprocess.check_output(["git", "ls-tree", "HEAD", submodule_path], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - columns = re.split("\\s+", data) - return columns[2] - except Exception as e: - print( - f"ERROR: Could not get submodule revision for {submodule_path}" - f" ({e})", - file=sys.stderr) - return "" + try: + data = ( + subprocess.check_output( + ["git", "ls-tree", "HEAD", submodule_path], cwd=IREE_SOURCE_DIR + ) + .decode("utf-8") + .strip() + ) + columns = re.split("\\s+", data) + return columns[2] + except Exception as e: + print( + f"ERROR: Could not get submodule revision for {submodule_path}" f" ({e})", + file=sys.stderr, + ) + return "" prepare_installation() -packages = find_namespace_packages(where=os.path.join(CMAKE_INSTALL_DIR_ABS, - "python_packages", - "iree_compiler"), - include=[ - "iree.compiler", - "iree.compiler.*", - ]) +packages = find_namespace_packages( + where=os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", "iree_compiler"), + include=[ + "iree.compiler", + "iree.compiler.*", + ], +) print(f"Found compiler packages: {packages}") setup( diff --git a/compiler/src/iree/compiler/API/generate_exports.py b/compiler/src/iree/compiler/API/generate_exports.py index 950e72ff6611..1f0545831b03 100755 --- a/compiler/src/iree/compiler/API/generate_exports.py +++ b/compiler/src/iree/compiler/API/generate_exports.py @@ -80,121 +80,124 @@ # The group 'decl' contains the statement after the macro. MACRO_STATEMENT_PATTERN = re.compile( r"\n(MLIR_CAPI_EXPORTED|IREE_EMBED_EXPORTED)\s+(?P[^\;]+);", - re.MULTILINE | re.DOTALL) + re.MULTILINE | re.DOTALL, +) # Given a statement suspected to be a function declaration, extract the # function symbol. -FUNC_DECL_SYMBOL_PATTERN = re.compile(r"(?P\w+)\(", - re.MULTILINE | re.DOTALL) +FUNC_DECL_SYMBOL_PATTERN = re.compile(r"(?P\w+)\(", re.MULTILINE | re.DOTALL) def main(repo_root: Path, api_root: Path): - export_symbols = list(EXPLICIT_EXPORTS) - # Collect symbols from local header files. - for local_name in LOCAL_HEADER_FILES: - export_symbols.extend(collect_header_exports(api_root / local_name)) - - # Collect symbols from iree-dialects header files. - for local_name in IREE_DIALECTS_HEADER_FILES: - export_symbols.extend( - collect_header_exports( - repo_root / - "llvm-external-projects/iree-dialects/include/iree-dialects-c" / - local_name)) - - # Collect symbols from mlir-c header files. - mlir_c_dir = repo_root / "third_party/llvm-project/mlir/include/mlir-c" - for local_name in MLIR_C_HEADER_FILES: - header_file = mlir_c_dir / local_name - if not header_file.exists(): - raise RuntimeError( - f"Expected MLIR-C header file does not exist: {header_file}") - export_symbols.extend(collect_header_exports(header_file)) - - # Generate. - export_symbols.sort() - generate_macos_symbol_list(export_symbols, api_root / "api_exports.macos.lst") - generate_linker_script(export_symbols, api_root / "api_exports.ld") - generate_def_file(export_symbols, api_root / "api_exports.def") - generate_force_extern(export_symbols, api_root / "api_exports.c") + export_symbols = list(EXPLICIT_EXPORTS) + # Collect symbols from local header files. + for local_name in LOCAL_HEADER_FILES: + export_symbols.extend(collect_header_exports(api_root / local_name)) + + # Collect symbols from iree-dialects header files. + for local_name in IREE_DIALECTS_HEADER_FILES: + export_symbols.extend( + collect_header_exports( + repo_root + / "llvm-external-projects/iree-dialects/include/iree-dialects-c" + / local_name + ) + ) + + # Collect symbols from mlir-c header files. + mlir_c_dir = repo_root / "third_party/llvm-project/mlir/include/mlir-c" + for local_name in MLIR_C_HEADER_FILES: + header_file = mlir_c_dir / local_name + if not header_file.exists(): + raise RuntimeError( + f"Expected MLIR-C header file does not exist: {header_file}" + ) + export_symbols.extend(collect_header_exports(header_file)) + + # Generate. + export_symbols.sort() + generate_macos_symbol_list(export_symbols, api_root / "api_exports.macos.lst") + generate_linker_script(export_symbols, api_root / "api_exports.ld") + generate_def_file(export_symbols, api_root / "api_exports.def") + generate_force_extern(export_symbols, api_root / "api_exports.c") def collect_header_exports(header_file: Path): - with open(header_file, "r") as f: - contents = f.read() + with open(header_file, "r") as f: + contents = f.read() - symbols = [] - for m in re.finditer(MACRO_STATEMENT_PATTERN, contents): - decl = m.group("decl") - decl_m = re.search(FUNC_DECL_SYMBOL_PATTERN, decl) - if decl_m: - symbol = decl_m.group("symbol") - symbols.append(symbol) - return symbols + symbols = [] + for m in re.finditer(MACRO_STATEMENT_PATTERN, contents): + decl = m.group("decl") + decl_m = re.search(FUNC_DECL_SYMBOL_PATTERN, decl) + if decl_m: + symbol = decl_m.group("symbol") + symbols.append(symbol) + return symbols def generate_macos_symbol_list(symbols: List[str], file: Path): - with open(file, "wt") as f: - f.write("# Generated by generate_exports.py: Do not edit.\n") - for symbol in symbols: - # Note that cdecl symbols on MacOS are prefixed with "_", same as - # we all did in the 80s but (thankfully) allowing longer than 8 character - # names. - f.write(f"_{symbol}\n") + with open(file, "wt") as f: + f.write("# Generated by generate_exports.py: Do not edit.\n") + for symbol in symbols: + # Note that cdecl symbols on MacOS are prefixed with "_", same as + # we all did in the 80s but (thankfully) allowing longer than 8 character + # names. + f.write(f"_{symbol}\n") def generate_linker_script(symbols: List[str], file: Path): - with open(file, "wt") as f: - f.write("# Generated by generate_exports.py: Do not edit.\n") - f.write("VER_0 {\n") - f.write(" global:\n") - for symbol in symbols: - f.write(f" {symbol};\n") - f.write(" local:\n") - f.write(" *;\n") - f.write("};\n") + with open(file, "wt") as f: + f.write("# Generated by generate_exports.py: Do not edit.\n") + f.write("VER_0 {\n") + f.write(" global:\n") + for symbol in symbols: + f.write(f" {symbol};\n") + f.write(" local:\n") + f.write(" *;\n") + f.write("};\n") def generate_def_file(symbols: List[str], file: Path): - with open(file, "wt") as f: - f.write("; Generated by generate_exports.py: Do not edit.\n") - f.write("EXPORTS\n") - for symbol in symbols: - f.write(f" {symbol}\n") + with open(file, "wt") as f: + f.write("; Generated by generate_exports.py: Do not edit.\n") + f.write("EXPORTS\n") + for symbol in symbols: + f.write(f" {symbol}\n") def generate_force_extern(symbols: List[str], file: Path): - with open(file, "wt") as f: - f.write("// Copyright 2022 The IREE Authors\n") - f.write("//\n") - f.write("// Licensed under the Apache License v2.0 with LLVM Exceptions.\n") - f.write("// See https://llvm.org/LICENSE.txt for license information.\n") - f.write("// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n") - f.write("\n") - f.write("// Generated by generate_exports.py: Do not edit.\n") - f.write("\n") - f.write("#include \n") - f.write("\n") - for symbol in symbols: - f.write(f"extern void {symbol}();\n") - f.write("\n") - f.write("uintptr_t __iree_compiler_hidden_force_extern() {\n") - f.write(" uintptr_t x = 0;\n") - for symbol in symbols: - f.write(f" x += (uintptr_t)&{symbol};\n") - f.write(" return x;\n") - f.write("}\n") + with open(file, "wt") as f: + f.write("// Copyright 2022 The IREE Authors\n") + f.write("//\n") + f.write("// Licensed under the Apache License v2.0 with LLVM Exceptions.\n") + f.write("// See https://llvm.org/LICENSE.txt for license information.\n") + f.write("// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n") + f.write("\n") + f.write("// Generated by generate_exports.py: Do not edit.\n") + f.write("\n") + f.write("#include \n") + f.write("\n") + for symbol in symbols: + f.write(f"extern void {symbol}();\n") + f.write("\n") + f.write("uintptr_t __iree_compiler_hidden_force_extern() {\n") + f.write(" uintptr_t x = 0;\n") + for symbol in symbols: + f.write(f" x += (uintptr_t)&{symbol};\n") + f.write(" return x;\n") + f.write("}\n") if __name__ == "__main__": - script_dir = Path(__file__).parent - repo_root = script_dir - while True: - # Key off of "AUTHORS" file - if (repo_root / "AUTHORS").exists(): - break - repo_root = repo_root.parent - if not repo_root: - raise RuntimeError(f"Could not find root of repo from {script_dir}") - - main(repo_root=repo_root, api_root=script_dir) + script_dir = Path(__file__).parent + repo_root = script_dir + while True: + # Key off of "AUTHORS" file + if (repo_root / "AUTHORS").exists(): + break + repo_root = repo_root.parent + if not repo_root: + raise RuntimeError(f"Could not find root of repo from {script_dir}") + + main(repo_root=repo_root, api_root=script_dir) diff --git a/configure_bazel.py b/configure_bazel.py index 40aea824ce23..abb877acfcee 100644 --- a/configure_bazel.py +++ b/configure_bazel.py @@ -11,59 +11,62 @@ def detect_unix_platform_config(bazelrc): - # This is hoaky. Ideally, bazel had any kind of rational way of selecting - # options from within its environment (key word: "rational"), but sadly, it - # is unintelligible to mere mortals. Why should a build system have a way for - # people to condition their build options on what compiler they are using - # (without descending down the hole of deciphering what a Bazel toolchain is)? - # All I want to do is set a couple of project specific warning options! + # This is hoaky. Ideally, bazel had any kind of rational way of selecting + # options from within its environment (key word: "rational"), but sadly, it + # is unintelligible to mere mortals. Why should a build system have a way for + # people to condition their build options on what compiler they are using + # (without descending down the hole of deciphering what a Bazel toolchain is)? + # All I want to do is set a couple of project specific warning options! - if platform.system() == "Darwin": - print(f"build --config=macos_clang", file=bazelrc) - print(f"build:release --config=macos_clang_release", file=bazelrc) - else: - - # If the user specified a CXX environment var, bazel will later respect that, - # so we just see if it says "clang". - cxx = os.environ.get("CXX") - cc = os.environ.get("CC") - if (cxx is not None and cc is None) or (cxx is None and cc is not None): - print("WARNING: Only one of CXX or CC is set, which can confuse bazel. " - "Recommend: set both appropriately (or none)") - if cc is not None and cxx is not None: - # Persist the variables. - print(f"build --action_env CC=\"{cc}\"", file=bazelrc) - print(f"build --action_env CXX=\"{cxx}\"", file=bazelrc) + if platform.system() == "Darwin": + print(f"build --config=macos_clang", file=bazelrc) + print(f"build:release --config=macos_clang_release", file=bazelrc) else: - print( - "WARNING: CC and CXX are not set, which can cause mismatches between " - "flag configurations and compiler. Recommend setting them explicitly." - ) + # If the user specified a CXX environment var, bazel will later respect that, + # so we just see if it says "clang". + cxx = os.environ.get("CXX") + cc = os.environ.get("CC") + if (cxx is not None and cc is None) or (cxx is None and cc is not None): + print( + "WARNING: Only one of CXX or CC is set, which can confuse bazel. " + "Recommend: set both appropriately (or none)" + ) + if cc is not None and cxx is not None: + # Persist the variables. + print(f'build --action_env CC="{cc}"', file=bazelrc) + print(f'build --action_env CXX="{cxx}"', file=bazelrc) + else: + print( + "WARNING: CC and CXX are not set, which can cause mismatches between " + "flag configurations and compiler. Recommend setting them explicitly." + ) - if cxx is not None and "clang" in cxx: - print( - f"Choosing generic_clang config because CXX is set to clang ({cxx})") - print(f"build --config=generic_clang", file=bazelrc) - print(f"build:release --config=generic_clang_release", file=bazelrc) - else: - print(f"Choosing generic_gcc config by default because no CXX set or " - f"not recognized as clang ({cxx})") - print(f"build --config=generic_gcc", file=bazelrc) - print(f"build:release --config=generic_gcc_release", file=bazelrc) + if cxx is not None and "clang" in cxx: + print(f"Choosing generic_clang config because CXX is set to clang ({cxx})") + print(f"build --config=generic_clang", file=bazelrc) + print(f"build:release --config=generic_clang_release", file=bazelrc) + else: + print( + f"Choosing generic_gcc config by default because no CXX set or " + f"not recognized as clang ({cxx})" + ) + print(f"build --config=generic_gcc", file=bazelrc) + print(f"build:release --config=generic_gcc_release", file=bazelrc) def write_platform(bazelrc): - if platform.system() == "Windows": - print(f"build --config=msvc", file=bazelrc) - print(f"build:release --config=msvc_release", file=bazelrc) - else: - detect_unix_platform_config(bazelrc) + if platform.system() == "Windows": + print(f"build --config=msvc", file=bazelrc) + print(f"build:release --config=msvc_release", file=bazelrc) + else: + detect_unix_platform_config(bazelrc) + if len(sys.argv) > 1: - local_bazelrc = sys.argv[1] + local_bazelrc = sys.argv[1] else: - local_bazelrc = os.path.join(os.path.dirname(__file__), "configured.bazelrc") + local_bazelrc = os.path.join(os.path.dirname(__file__), "configured.bazelrc") with open(local_bazelrc, "wt") as bazelrc: - write_platform(bazelrc) + write_platform(bazelrc) print("Wrote", local_bazelrc) diff --git a/docs/api_docs/python/conf.py b/docs/api_docs/python/conf.py index da752fc3a4e7..06ad6e1e1126 100644 --- a/docs/api_docs/python/conf.py +++ b/docs/api_docs/python/conf.py @@ -16,12 +16,12 @@ # -- Project information ----------------------------------------------------- -project = 'IREE Python API' -copyright = '2021, IREE Authors' -author = 'IREE Authors' +project = "IREE Python API" +copyright = "2021, IREE Authors" +author = "IREE Authors" # The full version, including alpha/beta/rc tags -release = 'snapshot' +release = "snapshot" # -- General configuration --------------------------------------------------- @@ -29,22 +29,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'myst_parser', - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.napoleon', - 'enum_tools.autoenum', + "myst_parser", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", + "enum_tools.autoenum", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} napoleon_google_docstring = True napoleon_numpy_docstring = False @@ -54,12 +54,12 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # -- Markdown from docstrings ------------------------------------------------ diff --git a/experimental/dispatch_profiler/batch_matmul.py b/experimental/dispatch_profiler/batch_matmul.py index 86f1730fdbcf..9b8c401fdfdf 100644 --- a/experimental/dispatch_profiler/batch_matmul.py +++ b/experimental/dispatch_profiler/batch_matmul.py @@ -9,37 +9,40 @@ class BatchMatmulOperation(MatmulOperation): - """Data structure to describe a batch matrix multiplication operation.""" + """Data structure to describe a batch matrix multiplication operation.""" - def __init__(self, bmm_shape, lhs, rhs, result): - assert len(bmm_shape) == 4, "Batch matmul shape must be 4D" - super().__init__(bmm_shape[1:], lhs, rhs, result, bmm_shape[0], 1, - OperationKind.BatchMatmul) + def __init__(self, bmm_shape, lhs, rhs, result): + assert len(bmm_shape) == 4, "Batch matmul shape must be 4D" + super().__init__( + bmm_shape[1:], lhs, rhs, result, bmm_shape[0], 1, OperationKind.BatchMatmul + ) - def name(self): - return f'{OperationKindNames[self.operation_kind]}_'\ - f'{self.batch_count}x{self.M}x{self.N}x{self.K}_'\ - f'{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_'\ - f'{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_'\ - f'{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}' + def name(self): + return ( + f"{OperationKindNames[self.operation_kind]}_" + f"{self.batch_count}x{self.M}x{self.N}x{self.K}_" + f"{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_" + f"{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_" + f"{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}" + ) - def lhs_npy_shape(self): - return f'{self.batch_count}x{super().lhs_npy_shape()}' + def lhs_npy_shape(self): + return f"{self.batch_count}x{super().lhs_npy_shape()}" - def rhs_npy_shape(self): - return f'{self.batch_count}x{super().rhs_npy_shape()}' + def rhs_npy_shape(self): + return f"{self.batch_count}x{super().rhs_npy_shape()}" - def result_npy_shape(self): - return f'{self.batch_count}x{super().result_npy_shape()}' + def result_npy_shape(self): + return f"{self.batch_count}x{super().result_npy_shape()}" class EmitLinalgBatchMatmulDispatch: - """Emitters for the `linalg.batch_matmul` dispatch.""" + """Emitters for the `linalg.batch_matmul` dispatch.""" - def __init__(self): - self.mlir_dialect = MlirDialect.Linalg + def __init__(self): + self.mlir_dialect = MlirDialect.Linalg - self.linalg_row_row_matmul_template = """ + self.linalg_row_row_matmul_template = """ // Dispatch linalg.batch_matmul row-row layout func.func @${operation_name}_${compilation_info_name}( %lhs: tensor<${batch_count}x${problem_m}x${problem_k}x${datatype_lhs}>, @@ -55,179 +58,207 @@ def __init__(self): } """ - def emit(self, dispatch): - """Emit the matmul operation in the MLIR dialect for a single compilation info""" - compilation_info_attribute_template = """{compilation_info = #${compilation_info_name}}""" - compilation_info_attribute_str = SubstituteTemplate( - compilation_info_attribute_template, - {'compilation_info_name': dispatch.configuration.name()}) - compilation_info_attribute = compilation_info_attribute_str \ - if dispatch.configuration.config_type != CompilationConfigType.Default else "" - - values = { - 'operation_name': dispatch.operation.name(), - 'compilation_info_attribute': compilation_info_attribute, - 'batch_count': str(dispatch.operation.batch_count), - 'problem_m': str(dispatch.operation.M), - 'problem_n': str(dispatch.operation.N), - 'problem_k': str(dispatch.operation.K), - 'datatype_lhs': DataTypeName[dispatch.operation.lhs.datatype], - 'datatype_rhs': DataTypeName[dispatch.operation.rhs.datatype], - 'datatype_result': DataTypeName[dispatch.operation.result.datatype], - 'compilation_info_name': dispatch.configuration.name() - } - - return SubstituteTemplate(self.linalg_row_row_matmul_template, values) + def emit(self, dispatch): + """Emit the matmul operation in the MLIR dialect for a single compilation info""" + compilation_info_attribute_template = ( + """{compilation_info = #${compilation_info_name}}""" + ) + compilation_info_attribute_str = SubstituteTemplate( + compilation_info_attribute_template, + {"compilation_info_name": dispatch.configuration.name()}, + ) + compilation_info_attribute = ( + compilation_info_attribute_str + if dispatch.configuration.config_type != CompilationConfigType.Default + else "" + ) + + values = { + "operation_name": dispatch.operation.name(), + "compilation_info_attribute": compilation_info_attribute, + "batch_count": str(dispatch.operation.batch_count), + "problem_m": str(dispatch.operation.M), + "problem_n": str(dispatch.operation.N), + "problem_k": str(dispatch.operation.K), + "datatype_lhs": DataTypeName[dispatch.operation.lhs.datatype], + "datatype_rhs": DataTypeName[dispatch.operation.rhs.datatype], + "datatype_result": DataTypeName[dispatch.operation.result.datatype], + "compilation_info_name": dispatch.configuration.name(), + } + + return SubstituteTemplate(self.linalg_row_row_matmul_template, values) class ReferenceBatchMatmulOp(ReferenceOpInterface): - """Reference implementation for the batch matmul operation in numpy.""" - - def __init__(self, bmm_operation, op_reference_cache_path, dist_lhs, - dist_rhs): - self.bmm_operation = bmm_operation - self.op_reference_cache_path = op_reference_cache_path - - if not self.op_reference_cache_path.exists(): - self.op_reference_cache_path.mkdir() - - # Problem shape. - self.batch_count = bmm_operation.batch_count - self.M = bmm_operation.M - self.N = bmm_operation.N - self.K = bmm_operation.K - - # Data type for the input and result matrices. - self.dtype_lhs = DataTypeNumPyTag[bmm_operation.lhs.datatype] - self.dtype_rhs = DataTypeNumPyTag[bmm_operation.rhs.datatype] - self.dtype_result = DataTypeNumPyTag[bmm_operation.result.datatype] - - # Distribution of the input tensors. - self.dist_lhs = dist_lhs - self.dist_rhs = dist_rhs - - # Filename for the left hand side input tensor. - self.filename_lhs = "batch_count{batch_count}xm{problem_m}xk{problem_k}_"\ - "{tensor_description}_{dist}_lhs.npy".format( - batch_count=self.batch_count, - problem_m=self.M, - problem_k=self.K, - tensor_description=self.bmm_operation.lhs.name(), - dist=DistributionName[self.dist_lhs]) - - # Filename for the right hand side input tensor. - self.filename_rhs = "batch_count{batch_count}xk{problem_k}xn{problem_n}_"\ - "{tensor_description}_{dist}_rhs.npy".format( - batch_count=self.batch_count, - problem_k=self.K, - problem_n=self.N, - tensor_description=self.bmm_operation.rhs.name(), - dist=DistributionName[self.dist_rhs]) - - # Filename for the reference result tensor. - self.filename_reference_result = "batch_count{batch_count}xm{problem_m}xn{problem_n}_"\ - "{tensor_description}_reference_result.npy".format( - batch_count=self.batch_count, - problem_m=self.M, - problem_n=self.N, - tensor_description=self.bmm_operation.result.name()) - - # Filepath for input and output files. - self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs) - self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs) - self.filepath_reference_result = self.op_reference_cache_path.joinpath( - self.filename_reference_result) - - def get_input_filepaths(self): - """Returns the list of input file paths.""" - return [self.filepath_lhs, self.filepath_rhs] - - def get_output_filepaths(self): - """Returns the list of expected output file paths.""" - return [self.filepath_reference_result] - - def __call__(self): - """Generates input data, runs reference numpy.matmul, and save npy files to the output directory.""" - # Generate the input data as np.array for the matmul operation. - lhs_np_array = get_np_array(self.bmm_operation.lhs, - (self.batch_count, self.M, self.K), - self.dist_lhs) - rhs_np_array = get_np_array(self.bmm_operation.rhs, - (self.batch_count, self.K, self.N), - self.dist_rhs) - - # Run the reference np.matmul and generate result np.array. - result = np.matmul(lhs_np_array, rhs_np_array) - - # Save the input data as np.array for the matmul operation. - np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs)) - np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs)) - - # Save the expected result as an np.array. - np.save(self.filepath_reference_result, - np.array(result, dtype=self.dtype_result)) + """Reference implementation for the batch matmul operation in numpy.""" + + def __init__(self, bmm_operation, op_reference_cache_path, dist_lhs, dist_rhs): + self.bmm_operation = bmm_operation + self.op_reference_cache_path = op_reference_cache_path + + if not self.op_reference_cache_path.exists(): + self.op_reference_cache_path.mkdir() + + # Problem shape. + self.batch_count = bmm_operation.batch_count + self.M = bmm_operation.M + self.N = bmm_operation.N + self.K = bmm_operation.K + + # Data type for the input and result matrices. + self.dtype_lhs = DataTypeNumPyTag[bmm_operation.lhs.datatype] + self.dtype_rhs = DataTypeNumPyTag[bmm_operation.rhs.datatype] + self.dtype_result = DataTypeNumPyTag[bmm_operation.result.datatype] + + # Distribution of the input tensors. + self.dist_lhs = dist_lhs + self.dist_rhs = dist_rhs + + # Filename for the left hand side input tensor. + self.filename_lhs = ( + "batch_count{batch_count}xm{problem_m}xk{problem_k}_" + "{tensor_description}_{dist}_lhs.npy".format( + batch_count=self.batch_count, + problem_m=self.M, + problem_k=self.K, + tensor_description=self.bmm_operation.lhs.name(), + dist=DistributionName[self.dist_lhs], + ) + ) + + # Filename for the right hand side input tensor. + self.filename_rhs = ( + "batch_count{batch_count}xk{problem_k}xn{problem_n}_" + "{tensor_description}_{dist}_rhs.npy".format( + batch_count=self.batch_count, + problem_k=self.K, + problem_n=self.N, + tensor_description=self.bmm_operation.rhs.name(), + dist=DistributionName[self.dist_rhs], + ) + ) + + # Filename for the reference result tensor. + self.filename_reference_result = ( + "batch_count{batch_count}xm{problem_m}xn{problem_n}_" + "{tensor_description}_reference_result.npy".format( + batch_count=self.batch_count, + problem_m=self.M, + problem_n=self.N, + tensor_description=self.bmm_operation.result.name(), + ) + ) + + # Filepath for input and output files. + self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs) + self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs) + self.filepath_reference_result = self.op_reference_cache_path.joinpath( + self.filename_reference_result + ) + + def get_input_filepaths(self): + """Returns the list of input file paths.""" + return [self.filepath_lhs, self.filepath_rhs] + + def get_output_filepaths(self): + """Returns the list of expected output file paths.""" + return [self.filepath_reference_result] + + def __call__(self): + """Generates input data, runs reference numpy.matmul, and save npy files to the output directory.""" + # Generate the input data as np.array for the matmul operation. + lhs_np_array = get_np_array( + self.bmm_operation.lhs, (self.batch_count, self.M, self.K), self.dist_lhs + ) + rhs_np_array = get_np_array( + self.bmm_operation.rhs, (self.batch_count, self.K, self.N), self.dist_rhs + ) + + # Run the reference np.matmul and generate result np.array. + result = np.matmul(lhs_np_array, rhs_np_array) + + # Save the input data as np.array for the matmul operation. + np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs)) + np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs)) + + # Save the expected result as an np.array. + np.save( + self.filepath_reference_result, np.array(result, dtype=self.dtype_result) + ) ############################################################################## class CudaBatchMatmulGenerator(CudaMatmulGenerator): - """Batch matmul dispatch generator class. """ - - def __init__(self, args): - """Initializes the batch matmul dispatch generator.""" - super().__init__(args) - - # Predefined batch matmul problem shapes. - self.batch_matmul_shapes = [[16, 512, 64, 512]] - - # Batch matmul dispatches collection. - self.dispatches_collection_list = [] - - def _append_matmul_dispatch_collection(self, bmm_shapes, data_type, - configuration_list): - """Update the batch matmul dispatch collection with the given configuration list.""" - - # Create dispatches collection for each problem shape with the configuration list. - for bmm_shape in bmm_shapes: - operation = BatchMatmulOperation( - bmm_shape,\ - TensorDescription(data_type[0], LayoutType.RowMajor), \ - TensorDescription(data_type[1], LayoutType.RowMajor), \ - TensorDescription(data_type[2], LayoutType.RowMajor)) - - # Filter out configurations that are not supported by LLVM GPU CUDA backend. - supported_configuration_list = self._cuda_supported_configuration_list( - operation, configuration_list) - - # Add default configuration if enabled. - if self.args.default_config: - supported_configuration_list.append( - MatmulCompilationInfo([], [], OperationKind.BatchMatmul, - CompilationConfigType.Default)) - - # Append the dispatches collection. - self.dispatches_collection_list.append(DispatchCollection(\ - operation, supported_configuration_list)) - - def _cuda_matmul_tensor_cores_f16(self): - """Appends a list of matmul dispatches for GPU TensorCore F16 data type.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f16, self.translation_infos, - OperationKind.BatchMatmul) - data_type = [DataType.f16, DataType.f16, DataType.f16] - self._append_matmul_dispatch_collection(self.batch_matmul_shapes, data_type, - configuration_list) - - def _cuda_matmul_tensor_cores_f32(self): - """Appends a list of matmul dispatches for GPU TensorCore F32 data type.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f32, self.translation_infos, - OperationKind.BatchMatmul) - data_type = [DataType.f32, DataType.f32, DataType.f32] - self._append_matmul_dispatch_collection(self.batch_matmul_shapes, data_type, - configuration_list) - - def generate(self): - """Generates a list of matmul operations.""" - self._cuda_matmul_tensor_cores_f16() - self._cuda_matmul_tensor_cores_f32() - return self.dispatches_collection_list + """Batch matmul dispatch generator class.""" + + def __init__(self, args): + """Initializes the batch matmul dispatch generator.""" + super().__init__(args) + + # Predefined batch matmul problem shapes. + self.batch_matmul_shapes = [[16, 512, 64, 512]] + + # Batch matmul dispatches collection. + self.dispatches_collection_list = [] + + def _append_matmul_dispatch_collection( + self, bmm_shapes, data_type, configuration_list + ): + """Update the batch matmul dispatch collection with the given configuration list.""" + + # Create dispatches collection for each problem shape with the configuration list. + for bmm_shape in bmm_shapes: + operation = BatchMatmulOperation( + bmm_shape, + TensorDescription(data_type[0], LayoutType.RowMajor), + TensorDescription(data_type[1], LayoutType.RowMajor), + TensorDescription(data_type[2], LayoutType.RowMajor), + ) + + # Filter out configurations that are not supported by LLVM GPU CUDA backend. + supported_configuration_list = self._cuda_supported_configuration_list( + operation, configuration_list + ) + + # Add default configuration if enabled. + if self.args.default_config: + supported_configuration_list.append( + MatmulCompilationInfo( + [], [], OperationKind.BatchMatmul, CompilationConfigType.Default + ) + ) + + # Append the dispatches collection. + self.dispatches_collection_list.append( + DispatchCollection(operation, supported_configuration_list) + ) + + def _cuda_matmul_tensor_cores_f16(self): + """Appends a list of matmul dispatches for GPU TensorCore F16 data type.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f16, + self.translation_infos, + OperationKind.BatchMatmul, + ) + data_type = [DataType.f16, DataType.f16, DataType.f16] + self._append_matmul_dispatch_collection( + self.batch_matmul_shapes, data_type, configuration_list + ) + + def _cuda_matmul_tensor_cores_f32(self): + """Appends a list of matmul dispatches for GPU TensorCore F32 data type.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f32, + self.translation_infos, + OperationKind.BatchMatmul, + ) + data_type = [DataType.f32, DataType.f32, DataType.f32] + self._append_matmul_dispatch_collection( + self.batch_matmul_shapes, data_type, configuration_list + ) + + def generate(self): + """Generates a list of matmul operations.""" + self._cuda_matmul_tensor_cores_f16() + self._cuda_matmul_tensor_cores_f32() + return self.dispatches_collection_list diff --git a/experimental/dispatch_profiler/compile.py b/experimental/dispatch_profiler/compile.py index f8b11fc3d3c1..12a67bba73dd 100644 --- a/experimental/dispatch_profiler/compile.py +++ b/experimental/dispatch_profiler/compile.py @@ -18,43 +18,48 @@ ############################################################################### if __name__ == "__main__": - ############################################################################### - # Parse command line arguments - ############################################################################### - parser = argparse.ArgumentParser( - description= - "IREE Python compile tool for launching iree-compile for verification and "\ - "profiling. Issues iree-compile for a given backend device and iree-compile "\ - "flags. Uses ThreadPoolExecutor to launch multiple iree-compile processes "\ - "in parallel.") - - args = parse_compile_arguments(parser) - ############################################################################### - - # Manifests metadata for a group of accompanying operations and configurations. - manifest = Manifest(args) - manifest.load() - - # Try and use all CPUs to launch iree-compile in parallel. - cpu_count = os.cpu_count() - if args.num_cpu > 0: - cpu_count = min(cpu_count, args.num_cpu) - - # For all the operations in the manifest, issue iree-compile for verification - # and profiling in parallel using ThreadPoolExecutor and cpu_count threads. - cmds = [] - with ThreadPoolExecutor(max_workers=cpu_count) as executor: - - # For all the operations in the manifest compile, verify, and profile. - for _, dispatch_collection_list in manifest.dispatch_collection_map.items(): - for dispatch_collection in dispatch_collection_list: - # Create an instance of operation_launcher. - operation = dispatch_collection.operation - operation_launcher = IreeToolsLauncher(args, operation) - for configuration in dispatch_collection.configuration_list: - for compile_mode in [CompilationMode.Profile, CompilationMode.Verify]: - cmds.append(executor.submit(\ - operation_launcher.iree_compile, compile_mode)) - - # Wait for all the commands to complete. - results = [cmd.result() for cmd in cmds] + ############################################################################### + # Parse command line arguments + ############################################################################### + parser = argparse.ArgumentParser( + description="IREE Python compile tool for launching iree-compile for verification and " + "profiling. Issues iree-compile for a given backend device and iree-compile " + "flags. Uses ThreadPoolExecutor to launch multiple iree-compile processes " + "in parallel." + ) + + args = parse_compile_arguments(parser) + ############################################################################### + + # Manifests metadata for a group of accompanying operations and configurations. + manifest = Manifest(args) + manifest.load() + + # Try and use all CPUs to launch iree-compile in parallel. + cpu_count = os.cpu_count() + if args.num_cpu > 0: + cpu_count = min(cpu_count, args.num_cpu) + + # For all the operations in the manifest, issue iree-compile for verification + # and profiling in parallel using ThreadPoolExecutor and cpu_count threads. + cmds = [] + with ThreadPoolExecutor(max_workers=cpu_count) as executor: + # For all the operations in the manifest compile, verify, and profile. + for _, dispatch_collection_list in manifest.dispatch_collection_map.items(): + for dispatch_collection in dispatch_collection_list: + # Create an instance of operation_launcher. + operation = dispatch_collection.operation + operation_launcher = IreeToolsLauncher(args, operation) + for configuration in dispatch_collection.configuration_list: + for compile_mode in [ + CompilationMode.Profile, + CompilationMode.Verify, + ]: + cmds.append( + executor.submit( + operation_launcher.iree_compile, compile_mode + ) + ) + + # Wait for all the commands to complete. + results = [cmd.result() for cmd in cmds] diff --git a/experimental/dispatch_profiler/dispatch.py b/experimental/dispatch_profiler/dispatch.py index 86cbe7807529..c003271a0158 100644 --- a/experimental/dispatch_profiler/dispatch.py +++ b/experimental/dispatch_profiler/dispatch.py @@ -9,54 +9,55 @@ ################################################################################ class Dispatch: - """ - Dispatch: A combination of an operation and a configuration is launched by - the dispatch profiler for verification and performance profiling. Note that - a dispatch is not a MLIR operation it is binary executable that is launched - by the profiler. Additionaly, the goal of the tool is to also profile the - performance of the fusions and a dispatch for fusion is a combination of - multiple operations glued together and compiled into a single dispatch. - """ + """ + Dispatch: A combination of an operation and a configuration is launched by + the dispatch profiler for verification and performance profiling. Note that + a dispatch is not a MLIR operation it is binary executable that is launched + by the profiler. Additionaly, the goal of the tool is to also profile the + performance of the fusions and a dispatch for fusion is a combination of + multiple operations glued together and compiled into a single dispatch. + """ - def __init__(self, operation, configuration): - self.operation = operation - self.configuration = configuration - self.is_fused_dispatch = False + def __init__(self, operation, configuration): + self.operation = operation + self.configuration = configuration + self.is_fused_dispatch = False - def name(self): - return f"{self.operation.name()}_{self.configuration.name()}" + def name(self): + return f"{self.operation.name()}_{self.configuration.name()}" ################################################################################ class DispatchCollection: - """ - DispatchCollection: A collection of dispatches that only vary in their - configurations but not in their operations. For example, a collection - of matmul dispatches with different tile sizes. - - We can emit a single MLIR file for all the dispatches in a collection - and compile with single run of iree-compile and them into a single executable - """ - - def __init__(self, operation, configuration_list): - self.operation = operation - self.configuration_list = configuration_list - - def get_dispatches(self): - """Returns a list of dispatches in the collection.""" - dispatches = [] - for configuration in self.configuration_list: - dispatches.append(Dispatch(self.operation, configuration)) - return dispatches - - def append(self, dispatch): - """Appends a dispatch to the collection.""" - if dispatch.operation != self.operation: - raise ValueError( - f"operation {self.operation.name()} does not match the dispatch " - f"collection operation name {dispatch.operation.name()}.") - self.configuration_list.append(dispatch.configuration) - - def num_of_dispatches(self): - """Returns number of dispatches in the collection.""" - return len(self.configuration_list) + """ + DispatchCollection: A collection of dispatches that only vary in their + configurations but not in their operations. For example, a collection + of matmul dispatches with different tile sizes. + + We can emit a single MLIR file for all the dispatches in a collection + and compile with single run of iree-compile and them into a single executable + """ + + def __init__(self, operation, configuration_list): + self.operation = operation + self.configuration_list = configuration_list + + def get_dispatches(self): + """Returns a list of dispatches in the collection.""" + dispatches = [] + for configuration in self.configuration_list: + dispatches.append(Dispatch(self.operation, configuration)) + return dispatches + + def append(self, dispatch): + """Appends a dispatch to the collection.""" + if dispatch.operation != self.operation: + raise ValueError( + f"operation {self.operation.name()} does not match the dispatch " + f"collection operation name {dispatch.operation.name()}." + ) + self.configuration_list.append(dispatch.configuration) + + def num_of_dispatches(self): + """Returns number of dispatches in the collection.""" + return len(self.configuration_list) diff --git a/experimental/dispatch_profiler/generator.py b/experimental/dispatch_profiler/generator.py index 21614b18c146..59e6b95f2352 100644 --- a/experimental/dispatch_profiler/generator.py +++ b/experimental/dispatch_profiler/generator.py @@ -13,17 +13,18 @@ ############################################################################### if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generates MLIR operations for " + "verification and profiling of IREE compiled dispatches." + ) - parser = argparse.ArgumentParser(description="Generates MLIR operations for "\ - "verification and profiling of IREE compiled dispatches.") + args = parse_generator_arguments(parser) - args = parse_generator_arguments(parser) + # Manifest dispatches for a group of accompanying operations and configurations. + manifest = Manifest(args) - # Manifest dispatches for a group of accompanying operations and configurations. - manifest = Manifest(args) + # Load all the pre-defined dispatches in a manifest. + manifest.initialize() - # Load all the pre-defined dispatches in a manifest. - manifest.initialize() - - # Emit the dispatches in MLIR source files. - manifest.emit() + # Emit the dispatches in MLIR source files. + manifest.emit() diff --git a/experimental/dispatch_profiler/launchers.py b/experimental/dispatch_profiler/launchers.py index 1ea4931758a6..e97ecf27713a 100644 --- a/experimental/dispatch_profiler/launchers.py +++ b/experimental/dispatch_profiler/launchers.py @@ -12,210 +12,227 @@ class IreeToolsLauncher: - """Launcher for IREE tools.""" - - def __init__(self, args, operation): - self.operation = operation - - self.generated_path = Path(args.generated_dir, 'generated', - args.mlir_dialect) - - self.args = args - self.benchmark_dispatch_repeat_count = args.batch_size - self.batch_size = args.batch_size - - # paths to source dispatch mlir, compiled vmfb, and logs. - self.operation_path = self.generated_path.joinpath( - OperationKindNames[operation.operation_kind], operation.name()) - - self.source_mlir_file = self.operation_path.joinpath( - operation.name()).with_suffix(".mlir") - - # path to cached numpy refernece input and expected output files. - self.op_reference_cache_path = Path(args.generated_dir, 'generated', - 'reference_cache', operation.name()) - - if not self.op_reference_cache_path.exists(): - self.op_reference_cache_path.mkdir(parents=True, exist_ok=True) - - # path to iree-compile tool. (for compiling the input mlir file to vmfb) - self.iree_compile_path = Path(args.iree_bin_dir, 'iree-compile') - - # path to iree-benchmark-module tool. (for performance benchmarking and profiling) - self.iree_benchmark_module_path = Path(args.iree_bin_dir, - 'iree-benchmark-module') - - # path to iree-run-module tool. (for verification) - self.iree_run_module_path = Path(args.iree_bin_dir, 'iree-run-module') - - # output vmfb files for verification and profiling. - vmfb_filename = f"{operation.name()}" - - if operation.operation_kind == OperationKind.SplitkMatmul: - split_k_suffix = "_".join( - ['split_k_slice', str(operation.split_k_slices)]) - vmfb_filename = f"{vmfb_filename}_{split_k_suffix}" - - self.vmfb_verify_filepath = self.operation_path.joinpath( - self.operation.name()).with_name(f"{vmfb_filename}_verify.vmfb") - self.vmfb_profile_filepath = self.operation_path.joinpath( - self.operation.name()).with_name(f"{vmfb_filename}_profile.vmfb") - - # reference implementation for the operation_kind. - self.reference_impl_map = { - OperationKind.Matmul: ReferenceMatmulOp, - OperationKind.SplitkMatmul: ReferenceMatmulOp, - OperationKind.BatchMatmul: ReferenceBatchMatmulOp, - } - - def iree_compile(self, compilation_mode): - """Compiles the input mlir file to vmfb file.""" - - benchmark_dispatch_repeat_count = self.benchmark_dispatch_repeat_count if compilation_mode == CompilationMode.Profile else 1 - vmfb_filepath = self.vmfb_profile_filepath if compilation_mode == CompilationMode.Profile else self.vmfb_verify_filepath - - # Base iree-compile commandline - cmd = [ - f'{self.iree_compile_path}', - f'{self.source_mlir_file}', - "-o", - f'{vmfb_filepath}', - ] - - # General compilation options - cmd += [f"--iree-hal-target-backends={self.args.device}"] - - if self.args.device == "cuda": - cmd += [f"--iree-hal-cuda-llvm-target-arch={self.args.cuda_arch}"] - if self.operation.operation_kind == OperationKind.SplitkMatmul: - cmd += [ - f"--iree-flow-split-matmul-reduction={self.operation.split_k_slices}" - ] - if self.args.use_mma_sync: - cmd += [f"--iree-codegen-llvmgpu-use-mma-sync"] - if self.args.use_wmma: - cmd += [f"--iree-codegen-llvmgpu-use-wmma"] - - # Compilation options for profiling - cmd += [ - f"--iree-hal-benchmark-dispatch-repeat-count={benchmark_dispatch_repeat_count}" - ] - - # Appends print ir options at the end of the command line. - if self.args.mlir_print_ir_after_all: - cmd += [f"--mlir-print-ir-after-all"] - - if not vmfb_filepath.exists() or self.args.force_compile: - complie_mode_str = CompilationModeNames[compilation_mode] - - print(f"[Compiling ({complie_mode_str})] {' '.join(cmd)}") - - iree_compile_stdout_filepath = self.operation_path.joinpath( - 'iree_compile_cmd_stdout.mlir') - - with open(iree_compile_stdout_filepath, "w") as fp: - subprocess.run(cmd, stderr=fp) - - elif self.args.verbose: - print( - f"Skipping compilation of operation: {vmfb_filepath} since it already exists." - ) - - def verify(self, configuration): - """Verifies the operation with a given configuration.""" - # First compile the operation to a vmfb file. - self.iree_compile(CompilationMode.Verify) - - # Verify using random data distribution. - reference_run = self.reference_impl_map[self.operation.operation_kind]( - self.operation, self.op_reference_cache_path, Distribution.Random, - Distribution.Random) - - if not reference_run.is_cached(): - reference_run() - - # Commandline `iree-run-module` for verification. - cmd = [ - f'{self.iree_run_module_path}', f'--module={self.vmfb_verify_filepath}', - f'--device={self.args.device}' - ] - - # Operation-specific verification command-line. - cmd.append(f'--function={self.operation.name()}_{configuration.name()}') - for input_file_path in reference_run.get_input_filepaths(): - cmd.append(f'--input=@{input_file_path}') - for output_file_path in reference_run.get_output_filepaths(): - cmd.append(f'--expected_output=@{output_file_path}') - - # Print the command if verbose. - if self.args.verbose: - print(f"[Verification] {' '.join(cmd)}") - - # Launch verification. - cmd_output = subprocess.check_output(cmd, text=True) - - # Save the verification command and the output, only if requested - # (file writing could slow down the verification). - if self.args.save_cmds: - filepath = self.operation_path.joinpath("iree_run_module.stdout") - with open(filepath, "w") as fp: - fp.write(f"[Command] $ {' '.join(cmd)}\n") - fp.write(cmd_output) - - # Parse the verification output. - m = re.search(r"\[(?P[a-zA-Z]+)\]", cmd_output) - if m is None: - raise ValueError( - f"Failed to parse verification output by iree-run-module: {cmd_output}" - ) - verification_result = m.group('verification_result') - - if self.args.verbose or verification_result != "SUCCESS": - print(cmd_output) - - return verification_result - - def profile(self, configuration): - """Profiles the operation with a given configuration.""" - # First compile the operation to a vmfb file. - self.iree_compile(CompilationMode.Profile) - - # Commandline `iree-benchmark-module` for profiling. - cmd = [ - f'{self.iree_benchmark_module_path}', - f'--module={self.vmfb_profile_filepath}', f'--device={self.args.device}' - ] - - # Profiling specific flags. - cmd += [f'--benchmark_repetitions={self.args.benchmark_repetitions}'] - cmd += [f'--batch_size={self.batch_size}'] - - # Operation-specific profiling command-line. - cmd += [f'--function={self.operation.name()}_{configuration.name()}'] - cmd += [f'--input={self.operation.lhs_npy_shape()}'] - cmd += [f'--input={self.operation.rhs_npy_shape()}'] - - # Print the command if verbose. - if self.args.verbose: - print(f"[Profiling] {' '.join(cmd)}") - - # Launch profiling. - cmd_output = subprocess.check_output(cmd, - text=True, - stderr=subprocess.STDOUT) - - # Save the profiling command and the output, only if requested - # (file writing could slow down the profiling). - if self.args.save_cmds: - filepath = self.operation_path.joinpath("iree_benchmark_module.stdout") - with open(filepath, "w") as fp: - fp.write(f"[Command] $ {' '.join(cmd)}\n") - fp.write(cmd_output) - - # Parse the profiling output. - m = re.search(r"real_time_median\s+(?P\d+.\d+)\s+ms", cmd_output) - if m is None: - raise ValueError( - f"Failed to parse runtime from benchmark result: {cmd_output}") - runtime_in_ms = float(m.group('runtime')) - return runtime_in_ms + """Launcher for IREE tools.""" + + def __init__(self, args, operation): + self.operation = operation + + self.generated_path = Path(args.generated_dir, "generated", args.mlir_dialect) + + self.args = args + self.benchmark_dispatch_repeat_count = args.batch_size + self.batch_size = args.batch_size + + # paths to source dispatch mlir, compiled vmfb, and logs. + self.operation_path = self.generated_path.joinpath( + OperationKindNames[operation.operation_kind], operation.name() + ) + + self.source_mlir_file = self.operation_path.joinpath( + operation.name() + ).with_suffix(".mlir") + + # path to cached numpy refernece input and expected output files. + self.op_reference_cache_path = Path( + args.generated_dir, "generated", "reference_cache", operation.name() + ) + + if not self.op_reference_cache_path.exists(): + self.op_reference_cache_path.mkdir(parents=True, exist_ok=True) + + # path to iree-compile tool. (for compiling the input mlir file to vmfb) + self.iree_compile_path = Path(args.iree_bin_dir, "iree-compile") + + # path to iree-benchmark-module tool. (for performance benchmarking and profiling) + self.iree_benchmark_module_path = Path( + args.iree_bin_dir, "iree-benchmark-module" + ) + + # path to iree-run-module tool. (for verification) + self.iree_run_module_path = Path(args.iree_bin_dir, "iree-run-module") + + # output vmfb files for verification and profiling. + vmfb_filename = f"{operation.name()}" + + if operation.operation_kind == OperationKind.SplitkMatmul: + split_k_suffix = "_".join(["split_k_slice", str(operation.split_k_slices)]) + vmfb_filename = f"{vmfb_filename}_{split_k_suffix}" + + self.vmfb_verify_filepath = self.operation_path.joinpath( + self.operation.name() + ).with_name(f"{vmfb_filename}_verify.vmfb") + self.vmfb_profile_filepath = self.operation_path.joinpath( + self.operation.name() + ).with_name(f"{vmfb_filename}_profile.vmfb") + + # reference implementation for the operation_kind. + self.reference_impl_map = { + OperationKind.Matmul: ReferenceMatmulOp, + OperationKind.SplitkMatmul: ReferenceMatmulOp, + OperationKind.BatchMatmul: ReferenceBatchMatmulOp, + } + + def iree_compile(self, compilation_mode): + """Compiles the input mlir file to vmfb file.""" + + benchmark_dispatch_repeat_count = ( + self.benchmark_dispatch_repeat_count + if compilation_mode == CompilationMode.Profile + else 1 + ) + vmfb_filepath = ( + self.vmfb_profile_filepath + if compilation_mode == CompilationMode.Profile + else self.vmfb_verify_filepath + ) + + # Base iree-compile commandline + cmd = [ + f"{self.iree_compile_path}", + f"{self.source_mlir_file}", + "-o", + f"{vmfb_filepath}", + ] + + # General compilation options + cmd += [f"--iree-hal-target-backends={self.args.device}"] + + if self.args.device == "cuda": + cmd += [f"--iree-hal-cuda-llvm-target-arch={self.args.cuda_arch}"] + if self.operation.operation_kind == OperationKind.SplitkMatmul: + cmd += [ + f"--iree-flow-split-matmul-reduction={self.operation.split_k_slices}" + ] + if self.args.use_mma_sync: + cmd += [f"--iree-codegen-llvmgpu-use-mma-sync"] + if self.args.use_wmma: + cmd += [f"--iree-codegen-llvmgpu-use-wmma"] + + # Compilation options for profiling + cmd += [ + f"--iree-hal-benchmark-dispatch-repeat-count={benchmark_dispatch_repeat_count}" + ] + + # Appends print ir options at the end of the command line. + if self.args.mlir_print_ir_after_all: + cmd += [f"--mlir-print-ir-after-all"] + + if not vmfb_filepath.exists() or self.args.force_compile: + complie_mode_str = CompilationModeNames[compilation_mode] + + print(f"[Compiling ({complie_mode_str})] {' '.join(cmd)}") + + iree_compile_stdout_filepath = self.operation_path.joinpath( + "iree_compile_cmd_stdout.mlir" + ) + + with open(iree_compile_stdout_filepath, "w") as fp: + subprocess.run(cmd, stderr=fp) + + elif self.args.verbose: + print( + f"Skipping compilation of operation: {vmfb_filepath} since it already exists." + ) + + def verify(self, configuration): + """Verifies the operation with a given configuration.""" + # First compile the operation to a vmfb file. + self.iree_compile(CompilationMode.Verify) + + # Verify using random data distribution. + reference_run = self.reference_impl_map[self.operation.operation_kind]( + self.operation, + self.op_reference_cache_path, + Distribution.Random, + Distribution.Random, + ) + + if not reference_run.is_cached(): + reference_run() + + # Commandline `iree-run-module` for verification. + cmd = [ + f"{self.iree_run_module_path}", + f"--module={self.vmfb_verify_filepath}", + f"--device={self.args.device}", + ] + + # Operation-specific verification command-line. + cmd.append(f"--function={self.operation.name()}_{configuration.name()}") + for input_file_path in reference_run.get_input_filepaths(): + cmd.append(f"--input=@{input_file_path}") + for output_file_path in reference_run.get_output_filepaths(): + cmd.append(f"--expected_output=@{output_file_path}") + + # Print the command if verbose. + if self.args.verbose: + print(f"[Verification] {' '.join(cmd)}") + + # Launch verification. + cmd_output = subprocess.check_output(cmd, text=True) + + # Save the verification command and the output, only if requested + # (file writing could slow down the verification). + if self.args.save_cmds: + filepath = self.operation_path.joinpath("iree_run_module.stdout") + with open(filepath, "w") as fp: + fp.write(f"[Command] $ {' '.join(cmd)}\n") + fp.write(cmd_output) + + # Parse the verification output. + m = re.search(r"\[(?P[a-zA-Z]+)\]", cmd_output) + if m is None: + raise ValueError( + f"Failed to parse verification output by iree-run-module: {cmd_output}" + ) + verification_result = m.group("verification_result") + + if self.args.verbose or verification_result != "SUCCESS": + print(cmd_output) + + return verification_result + + def profile(self, configuration): + """Profiles the operation with a given configuration.""" + # First compile the operation to a vmfb file. + self.iree_compile(CompilationMode.Profile) + + # Commandline `iree-benchmark-module` for profiling. + cmd = [ + f"{self.iree_benchmark_module_path}", + f"--module={self.vmfb_profile_filepath}", + f"--device={self.args.device}", + ] + + # Profiling specific flags. + cmd += [f"--benchmark_repetitions={self.args.benchmark_repetitions}"] + cmd += [f"--batch_size={self.batch_size}"] + + # Operation-specific profiling command-line. + cmd += [f"--function={self.operation.name()}_{configuration.name()}"] + cmd += [f"--input={self.operation.lhs_npy_shape()}"] + cmd += [f"--input={self.operation.rhs_npy_shape()}"] + + # Print the command if verbose. + if self.args.verbose: + print(f"[Profiling] {' '.join(cmd)}") + + # Launch profiling. + cmd_output = subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT) + + # Save the profiling command and the output, only if requested + # (file writing could slow down the profiling). + if self.args.save_cmds: + filepath = self.operation_path.joinpath("iree_benchmark_module.stdout") + with open(filepath, "w") as fp: + fp.write(f"[Command] $ {' '.join(cmd)}\n") + fp.write(cmd_output) + + # Parse the profiling output. + m = re.search(r"real_time_median\s+(?P\d+.\d+)\s+ms", cmd_output) + if m is None: + raise ValueError( + f"Failed to parse runtime from benchmark result: {cmd_output}" + ) + runtime_in_ms = float(m.group("runtime")) + return runtime_in_ms diff --git a/experimental/dispatch_profiler/library.py b/experimental/dispatch_profiler/library.py index 7dbb00b824b1..2a7f34532ff4 100644 --- a/experimental/dispatch_profiler/library.py +++ b/experimental/dispatch_profiler/library.py @@ -27,8 +27,8 @@ # Architecure types ################################################################################################### class ArchType(enum.Enum): - Cpu = auto() - Gpu = auto() + Cpu = auto() + Gpu = auto() ArchTypeNames = { @@ -38,9 +38,9 @@ class ArchType(enum.Enum): class GpuArchType(enum.Enum): - nvptx = auto() - rocm = auto() - spirv = auto() + nvptx = auto() + rocm = auto() + spirv = auto() GpuArchTypeNames = { @@ -53,79 +53,79 @@ class GpuArchType(enum.Enum): # Operation kinds ################################################################################################### class OperationKind(enum.Enum): - Matmul = auto() - BatchMatmul = auto() - SplitkMatmul = auto() - Conv2d = auto() + Matmul = auto() + BatchMatmul = auto() + SplitkMatmul = auto() + Conv2d = auto() OperationKindNames = { - OperationKind.Matmul: 'matmul', - OperationKind.SplitkMatmul: 'matmul_splitk', - OperationKind.BatchMatmul: 'batch_matmul', - OperationKind.Conv2d: 'conv2d' + OperationKind.Matmul: "matmul", + OperationKind.SplitkMatmul: "matmul_splitk", + OperationKind.BatchMatmul: "batch_matmul", + OperationKind.Conv2d: "conv2d", } # MLIR dialects ################################################################################################### class MlirDialect(enum.Enum): - Linalg = auto() - Mhlo = auto() + Linalg = auto() + Mhlo = auto() MlirDialectNames = { - MlirDialect.Linalg: 'linalg', - MlirDialect.Mhlo: 'mhlo', + MlirDialect.Linalg: "linalg", + MlirDialect.Mhlo: "mhlo", } # Compilation modes (verification or benchmarking/profiling) ################################################################################################### class CompilationMode(enum.Enum): - Verify = auto() - Profile = auto() + Verify = auto() + Profile = auto() CompilationModeNames = { - CompilationMode.Verify: 'verify', - CompilationMode.Profile: 'profile', + CompilationMode.Verify: "verify", + CompilationMode.Profile: "profile", } class CompilationConfigType(enum.Enum): - Default = auto() - Custom = auto() + Default = auto() + Custom = auto() CompilationConfigTypeName = { - CompilationConfigType.Default: 'default', - CompilationConfigType.Custom: 'custom', + CompilationConfigType.Default: "default", + CompilationConfigType.Custom: "custom", } # Enumerations for data types and layouts ################################################################################################### class DataType(enum.Enum): - b1 = auto() - u4 = auto() - u8 = auto() - u16 = auto() - u32 = auto() - u64 = auto() - s4 = auto() - s8 = auto() - s16 = auto() - s32 = auto() - s64 = auto() - e4m3 = auto() - e5m2 = auto() - f16 = auto() - bf16 = auto() - f32 = auto() - tf32 = auto() - f64 = auto() - invalid = auto() + b1 = auto() + u4 = auto() + u8 = auto() + u16 = auto() + u32 = auto() + u64 = auto() + s4 = auto() + s8 = auto() + s16 = auto() + s32 = auto() + s64 = auto() + e4m3 = auto() + e5m2 = auto() + f16 = auto() + bf16 = auto() + f32 = auto() + tf32 = auto() + f64 = auto() + invalid = auto() DataTypeName = { @@ -140,8 +140,8 @@ class DataType(enum.Enum): DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", - DataType.e4m3: 'e4m3', - DataType.e5m2: 'e5m2', + DataType.e4m3: "e4m3", + DataType.e5m2: "e5m2", DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", @@ -177,37 +177,34 @@ class DataType(enum.Enum): class LayoutType(enum.Enum): - ColumnMajor = auto() - RowMajor = auto() - NHWC = auto() - NCWH = auto() + ColumnMajor = auto() + RowMajor = auto() + NHWC = auto() + NCWH = auto() # cuBLAS/cuDNN layout type names convention is followed for the layout names. # https://docs.nvidia.com/cuda/cublas/index.html#cublasoperation-t ShortLayoutTypeName = { - LayoutType.ColumnMajor: 'n', - LayoutType.RowMajor: 't', - LayoutType.NHWC: 'nhwc', - LayoutType.NCWH: 'ncwh', + LayoutType.ColumnMajor: "n", + LayoutType.RowMajor: "t", + LayoutType.NHWC: "nhwc", + LayoutType.NCWH: "ncwh", } # Compilation pipelines/translation info. ################################################################################################### class TranslationInfo(enum.Enum): - LLVMGPUMatmulSIMT = auto() - LLVMGPUMatmulTensorCore = auto() - LLVMGPUMatmulTensorCoreMmaSync = auto() + LLVMGPUMatmulSIMT = auto() + LLVMGPUMatmulTensorCore = auto() + LLVMGPUMatmulTensorCoreMmaSync = auto() TranslationInfoTag = { - TranslationInfo.LLVMGPUMatmulSIMT: - "LLVMGPUMatmulSIMT", - TranslationInfo.LLVMGPUMatmulTensorCore: - "LLVMGPUMatmulTensorCore", - TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync: - "LLVMGPUMatmulTensorCoreMmaSync", + TranslationInfo.LLVMGPUMatmulSIMT: "LLVMGPUMatmulSIMT", + TranslationInfo.LLVMGPUMatmulTensorCore: "LLVMGPUMatmulTensorCore", + TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync: "LLVMGPUMatmulTensorCoreMmaSync", } TranslationInfoName = { @@ -220,12 +217,12 @@ class TranslationInfo(enum.Enum): # Distribution of values in a tensor. ################################################################################################### class Distribution(enum.Enum): - Empty = auto() - Zeros = auto() - Ones = auto() - Sequential = auto() - Identity = auto() - Random = auto() + Empty = auto() + Zeros = auto() + Ones = auto() + Sequential = auto() + Identity = auto() + Random = auto() DistributionName = { @@ -246,29 +243,31 @@ class Distribution(enum.Enum): class TensorDescription: - """A class for tensor description.""" + """A class for tensor description.""" - def __init__(self, datatype, layout): - self.datatype = datatype - self.layout = layout + def __init__(self, datatype, layout): + self.datatype = datatype + self.layout = layout - def name(self): - return "%s%s" % (DataTypeName[self.datatype], - ShortLayoutTypeName[self.layout]) + def name(self): + return "%s%s" % (DataTypeName[self.datatype], ShortLayoutTypeName[self.layout]) class TileDescription: - """A class for tile description.""" + """A class for tile description.""" - def __init__(self, threadblock_shape, stages, block_dim): - self.threadblock_shape = threadblock_shape # in number of elements in M, N, K - self.stages = stages # number of shared memory stages in tile K - self.block_dim = block_dim # block dimension in number of threads in x, y, z + def __init__(self, threadblock_shape, stages, block_dim): + self.threadblock_shape = threadblock_shape # in number of elements in M, N, K + self.stages = stages # number of shared memory stages in tile K + self.block_dim = block_dim # block dimension in number of threads in x, y, z - def name(self): - return "%dx%d_%dx%d" % (self.threadblock_shape[0], - self.threadblock_shape[1], - self.threadblock_shape[2], self.stages) + def name(self): + return "%dx%d_%dx%d" % ( + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.stages, + ) ################################################################################################### @@ -277,77 +276,79 @@ def name(self): # functionality they provide becomes apparent and necessary as we move forward. ################################################################################################### def get_np_array(tensor_description, shape, dist): - """Returns a numpy array based on the distribution and shape.""" - # Fix the seed for reproducibility. - np.random.seed(42) - - # Generate the numpy array based on the distribution. - if dist == Distribution.Empty: - return np.empty(shape) - elif dist == Distribution.Zeros: - return np.zeros(shape) - elif dist == Distribution.Ones: - return np.ones(shape) - elif dist == Distribution.Sequential: - return np.arange(np.prod(shape)).reshape(shape) - elif dist == Distribution.Identity: - return np.eye(shape[0], shape[1]) - elif dist == Distribution.Random: - if tensor_description.datatype == DataType.s8: - return np.random.randint(-2, 3, shape) - elif tensor_description.datatype == DataType.u8: - return np.random.randint(0, 4, shape) - elif tensor_description.datatype == DataType.f16 or \ - tensor_description.datatype == DataType.bf16: - return np.random.randint(-3, 4, shape) - elif tensor_description.datatype == DataType.f32: - return np.random.randint(-7, 8, shape) + """Returns a numpy array based on the distribution and shape.""" + # Fix the seed for reproducibility. + np.random.seed(42) + + # Generate the numpy array based on the distribution. + if dist == Distribution.Empty: + return np.empty(shape) + elif dist == Distribution.Zeros: + return np.zeros(shape) + elif dist == Distribution.Ones: + return np.ones(shape) + elif dist == Distribution.Sequential: + return np.arange(np.prod(shape)).reshape(shape) + elif dist == Distribution.Identity: + return np.eye(shape[0], shape[1]) + elif dist == Distribution.Random: + if tensor_description.datatype == DataType.s8: + return np.random.randint(-2, 3, shape) + elif tensor_description.datatype == DataType.u8: + return np.random.randint(0, 4, shape) + elif ( + tensor_description.datatype == DataType.f16 + or tensor_description.datatype == DataType.bf16 + ): + return np.random.randint(-3, 4, shape) + elif tensor_description.datatype == DataType.f32: + return np.random.randint(-7, 8, shape) ################################################################################################### def SubstituteTemplate(template, values): - """Substitutes values into a template string.""" - text = template - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - text = newtext - return text + """Substitutes values into a template string.""" + text = template + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + text = newtext + return text ################################################################################################### class ReferenceOpInterface(ABC): - """Interface for reference implementations.""" + """Interface for reference implementations.""" - @abstractmethod - def get_input_filepaths(self): - """Returns the list of inputs.""" - pass + @abstractmethod + def get_input_filepaths(self): + """Returns the list of inputs.""" + pass - @abstractmethod - def get_output_filepaths(self): - """Returns the list of outputs/.""" - pass + @abstractmethod + def get_output_filepaths(self): + """Returns the list of outputs/.""" + pass - @abstractmethod - def __call__(self): - """Runs the reference implementation.""" - pass + @abstractmethod + def __call__(self): + """Runs the reference implementation.""" + pass - def is_cached(self): - """Returns whether the reference run is cached.""" + def is_cached(self): + """Returns whether the reference run is cached.""" - # Returns False if any of the reference input are missing. - for input_filepath in self.get_input_filepaths(): - if not input_filepath.exists(): - return False + # Returns False if any of the reference input are missing. + for input_filepath in self.get_input_filepaths(): + if not input_filepath.exists(): + return False - # Returns False if any of the reference output are missing. - for output_filepath in self.get_output_filepaths(): - if not output_filepath.exists(): - return False + # Returns False if any of the reference output are missing. + for output_filepath in self.get_output_filepaths(): + if not output_filepath.exists(): + return False - # Returns True if all the reference inputs and outputs are cached. - return True + # Returns True if all the reference inputs and outputs are cached. + return True - ################################################################################################### + ################################################################################################### diff --git a/experimental/dispatch_profiler/manifest.py b/experimental/dispatch_profiler/manifest.py index 8fa126354a38..9254c3c96a57 100644 --- a/experimental/dispatch_profiler/manifest.py +++ b/experimental/dispatch_profiler/manifest.py @@ -14,242 +14,251 @@ ############################################################################### class EmitSourceMLIR: - """Emitters for the operation MLIR source files.""" - - def __init__(self, operation_path, dispatch_collection): - self.operation_path = operation_path - self.dispatch_collection = dispatch_collection - self.operation = dispatch_collection.operation - self.operation_kind = self.operation.operation_kind - self.configuration_list = dispatch_collection.configuration_list - self.operation_filepath = self.operation_path.joinpath( - self.operation.name()).with_suffix(".mlir") - - mlir_configuration_emitter = { - OperationKind.Matmul: EmitMatmulCompilationInfo, - OperationKind.SplitkMatmul: EmitMatmulCompilationInfo, - OperationKind.BatchMatmul: EmitMatmulCompilationInfo, - } - self.configuration_emitter = mlir_configuration_emitter[ - self.operation_kind]() - - mlir_dispatch_emitter = { - OperationKind.Matmul: EmitLinalgMatmulDispatch, - OperationKind.SplitkMatmul: EmitLinalgMatmulDispatch, - OperationKind.BatchMatmul: EmitLinalgBatchMatmulDispatch, - } - self.dispatch_emitter = mlir_dispatch_emitter[self.operation_kind]() - - def __enter__(self): - self.operation_file = open(self.operation_filepath, "w") - self.operation_file.write(f'// Finename: {self.operation_filepath}') - - # Emit all the configuration attribute tags. - for configuration in self.configuration_list: - self.operation_file.write(self.configuration_emitter.emit(configuration)) - return self - - def emit(self): - """Emit the op func.func for each dispatch (operation + configuration)""" - for dispatch in self.dispatch_collection.get_dispatches(): - print( - f" Emitting tuning configuration : {dispatch.configuration.name()}" - ) - self.operation_file.write(self.dispatch_emitter.emit(dispatch)) - - def __exit__(self, exc_type, exc_value, traceback): - self.operation_file.close() + """Emitters for the operation MLIR source files.""" + + def __init__(self, operation_path, dispatch_collection): + self.operation_path = operation_path + self.dispatch_collection = dispatch_collection + self.operation = dispatch_collection.operation + self.operation_kind = self.operation.operation_kind + self.configuration_list = dispatch_collection.configuration_list + self.operation_filepath = self.operation_path.joinpath( + self.operation.name() + ).with_suffix(".mlir") + + mlir_configuration_emitter = { + OperationKind.Matmul: EmitMatmulCompilationInfo, + OperationKind.SplitkMatmul: EmitMatmulCompilationInfo, + OperationKind.BatchMatmul: EmitMatmulCompilationInfo, + } + self.configuration_emitter = mlir_configuration_emitter[self.operation_kind]() + + mlir_dispatch_emitter = { + OperationKind.Matmul: EmitLinalgMatmulDispatch, + OperationKind.SplitkMatmul: EmitLinalgMatmulDispatch, + OperationKind.BatchMatmul: EmitLinalgBatchMatmulDispatch, + } + self.dispatch_emitter = mlir_dispatch_emitter[self.operation_kind]() + + def __enter__(self): + self.operation_file = open(self.operation_filepath, "w") + self.operation_file.write(f"// Finename: {self.operation_filepath}") + + # Emit all the configuration attribute tags. + for configuration in self.configuration_list: + self.operation_file.write(self.configuration_emitter.emit(configuration)) + return self + + def emit(self): + """Emit the op func.func for each dispatch (operation + configuration)""" + for dispatch in self.dispatch_collection.get_dispatches(): + print( + f" Emitting tuning configuration : {dispatch.configuration.name()}" + ) + self.operation_file.write(self.dispatch_emitter.emit(dispatch)) + + def __exit__(self, exc_type, exc_value, traceback): + self.operation_file.close() ############################################################################### class Manifest: - """Manifest collects, filters, and stores dispatches in a data structure. - Manifest organizes the dispatches in a dictionary. - Usage: - 1. Create a manifest object with the command line arguments. - 2(a). Generate dispatches, append them in the manifest, and - serialize them into a file. - 2(b). Load dispatches from a serialized file. - - ```python - # generator.py usage: - manifest = Manifest(args) - manifest.initialize() - - # compile.py or profile.py usage: - manifest = Manifest(args) - manifest.load() - ``` - """ - - def __init__(self, args): - self.args = args - - # Dictionary of operation kind to a list of dispatch collections. We - # initialize the dictionary during the generation of dispatches and - # serialize it to a file. The serialized file is used to load the - # dispatches for compilation and profiling. - # Datatype: OperationKind -> [DispatchCollection] - self.dispatch_collection_map = {} - - # For operation kind-based filtering of dispatches. - self.operation_kind_enabled = [] - - # For name-based filtering of dispatches. - self.dispatch_names = [] - self.ignore_dispatch_names = [] - - if args.operation_kind == 'all': - self.operation_kind_enabled = [] - else: - operations_kind_list = [ - OperationKind.Matmul, - OperationKind.SplitkMatmul, - OperationKind.BatchMatmul, - ] - self.operation_kind_enabled = [ - x for x in operations_kind_list - if OperationKindNames[x] in args.operation_kind.split(',') - ] - - if args.dispatches == 'all': - self.dispatch_names = [] - else: - self.dispatch_names = [x for x in args.dispatches.split(',') if x != ''] - - # Paths to the generated directory (e.g. `./generated/linalg`). - self.generated_path = Path(self.args.generated_dir, 'generated', - self.args.mlir_dialect) - - # Create the directories in self.generated_path, if it does not exist. - if not self.generated_path.exists(): - self.generated_path.mkdir(parents=True, exist_ok=True) - - # Path to the serialized file. - self.serialized_file_path = self.generated_path.joinpath( - 'serialized_file.pkl') - - def _filter_string_matches(self, filter_string, haystack): - """Returns true if all substrings appear in the haystack in order""" - substrings = filter_string.split('*') - for sub in substrings: - idx = haystack.find(sub) - if idx < 0: - return False - haystack = haystack[idx + len(sub):] - return True - - def is_enabled(self, dispatch): - """Rerturns true if pass through filters based various criteria.""" - - # Get the operation and configuration from the dispatch. - operation = dispatch.operation - configuration = dispatch.configuration - - # If the operation is not in the enabled list, return False. - enabled = True - - # If operation_kind filter is enabled and the \ - # operation_kind in not in the enabled list, return False. - if len(self.operation_kind_enabled) and \ - operation.operation_kind not in self.operation_kind_enabled: - enabled = False - - # If dispatch name-based filter regex is enabled match the \ - # dispatch name (operation+configuration) against all regexs \ - # in self.dispatch_names. - if len(self.dispatch_names): - name = dispatch.name() - enabled = False - - # compare against each regex included in self.dispatch_names. - for substr_to_match in self.dispatch_names: - if self._filter_string_matches(substr_to_match, name): - enabled = True - break - - # Return the result of the filter. - return enabled - - def append_dispatch_collection(self, dispatch_collection): - """Appends one instance of DispatchCollection to the manifest.""" - operation_kind = dispatch_collection.operation.operation_kind - if operation_kind not in self.dispatch_collection_map.keys(): - self.dispatch_collection_map[operation_kind] = [] - - # Get all the dispatches from the dispatch_collection. - dispatches = dispatch_collection.get_dispatches() - - # Filter dispatches based on the filter criteria. - filtered_dispatch_collection = DispatchCollection( - dispatch_collection.operation, []) - for dispatch in dispatches: - if self.is_enabled(dispatch): - filtered_dispatch_collection.append(dispatch) - - # Only append the filtered_dispatch_collection if it has an unfiltered configuration. - if len(filtered_dispatch_collection.configuration_list): - self.dispatch_collection_map[operation_kind].append( - filtered_dispatch_collection) - - def append(self, dispatch_collection_list): - """Appends one instance of DispatchCollection to the manifest.""" - for dispatch_collection in dispatch_collection_list: - self.append_dispatch_collection(dispatch_collection) - - def initialize(self): - """Initialize the mainfest object by generating dispatches for supported operations.""" - self.append(CudaMatmulGenerator(self.args).generate()) - self.append(CudaSplitKMatmulGenerator(self.args).generate()) - self.append(CudaBatchMatmulGenerator(self.args).generate()) - - # Serialize the initialized mainfest state. - self.dump() - - def dump(self): - """Serialize (dump) the self.dispatch_collection_map to a pickle file.""" - with open(self.serialized_file_path, 'wb') as f: - pickle.dump(self.dispatch_collection_map, f) - - def load(self): - """Deserialize (load) the self.dispatch_collection_map from a pickle file.""" - if not self.serialized_file_path.exists(): - raise ValueError(f"Could not find : {self.serialized_file_path}") - - with open(self.serialized_file_path, 'rb') as load_file: - self.dispatch_collection_map = pickle.load(load_file) - - def emit(self): - """Emits the operations in the Manifest to the build directory as MLIR source files. - The operations are emitted in the dialect specified by the `mlir_dialect` flag. + """Manifest collects, filters, and stores dispatches in a data structure. + Manifest organizes the dispatches in a dictionary. + Usage: + 1. Create a manifest object with the command line arguments. + 2(a). Generate dispatches, append them in the manifest, and + serialize them into a file. + 2(b). Load dispatches from a serialized file. + + ```python + # generator.py usage: + manifest = Manifest(args) + manifest.initialize() + + # compile.py or profile.py usage: + manifest = Manifest(args) + manifest.load() + ``` """ - # For each operation_kind create a directory and emit the operations with - # all the configurations in the configuration_list into their seperate directories. - for operation_kind, dispatch_collection_list\ - in self.dispatch_collection_map.items(): - - operation_kind_path = self.generated_path.joinpath( - OperationKindNames[operation_kind]) - - # If the operation_kind_path does not exists, create it. - if not operation_kind_path.exists(): - operation_kind_path.mkdir(parents=True, exist_ok=True) - - for dispatch_collection in dispatch_collection_list: - - operation_path = operation_kind_path.joinpath( - dispatch_collection.operation.name()) - - # If the operation_path does not exists, create it. - if not operation_path.exists(): - operation_path.mkdir() - - with EmitSourceMLIR(operation_path, - dispatch_collection) as emit_mlir_source: - mlir_file_path = operation_path.joinpath( - dispatch_collection.operation.name()).with_suffix('.mlir') - print(f"[Generating]: {mlir_file_path}") - - # Emit mlir source file for the dispatch_collection.operation with all the configurations - emit_mlir_source.emit() + def __init__(self, args): + self.args = args + + # Dictionary of operation kind to a list of dispatch collections. We + # initialize the dictionary during the generation of dispatches and + # serialize it to a file. The serialized file is used to load the + # dispatches for compilation and profiling. + # Datatype: OperationKind -> [DispatchCollection] + self.dispatch_collection_map = {} + + # For operation kind-based filtering of dispatches. + self.operation_kind_enabled = [] + + # For name-based filtering of dispatches. + self.dispatch_names = [] + self.ignore_dispatch_names = [] + + if args.operation_kind == "all": + self.operation_kind_enabled = [] + else: + operations_kind_list = [ + OperationKind.Matmul, + OperationKind.SplitkMatmul, + OperationKind.BatchMatmul, + ] + self.operation_kind_enabled = [ + x + for x in operations_kind_list + if OperationKindNames[x] in args.operation_kind.split(",") + ] + + if args.dispatches == "all": + self.dispatch_names = [] + else: + self.dispatch_names = [x for x in args.dispatches.split(",") if x != ""] + + # Paths to the generated directory (e.g. `./generated/linalg`). + self.generated_path = Path( + self.args.generated_dir, "generated", self.args.mlir_dialect + ) + + # Create the directories in self.generated_path, if it does not exist. + if not self.generated_path.exists(): + self.generated_path.mkdir(parents=True, exist_ok=True) + + # Path to the serialized file. + self.serialized_file_path = self.generated_path.joinpath("serialized_file.pkl") + + def _filter_string_matches(self, filter_string, haystack): + """Returns true if all substrings appear in the haystack in order""" + substrings = filter_string.split("*") + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub) :] + return True + + def is_enabled(self, dispatch): + """Rerturns true if pass through filters based various criteria.""" + + # Get the operation and configuration from the dispatch. + operation = dispatch.operation + configuration = dispatch.configuration + + # If the operation is not in the enabled list, return False. + enabled = True + + # If operation_kind filter is enabled and the \ + # operation_kind in not in the enabled list, return False. + if ( + len(self.operation_kind_enabled) + and operation.operation_kind not in self.operation_kind_enabled + ): + enabled = False + + # If dispatch name-based filter regex is enabled match the \ + # dispatch name (operation+configuration) against all regexs \ + # in self.dispatch_names. + if len(self.dispatch_names): + name = dispatch.name() + enabled = False + + # compare against each regex included in self.dispatch_names. + for substr_to_match in self.dispatch_names: + if self._filter_string_matches(substr_to_match, name): + enabled = True + break + + # Return the result of the filter. + return enabled + + def append_dispatch_collection(self, dispatch_collection): + """Appends one instance of DispatchCollection to the manifest.""" + operation_kind = dispatch_collection.operation.operation_kind + if operation_kind not in self.dispatch_collection_map.keys(): + self.dispatch_collection_map[operation_kind] = [] + + # Get all the dispatches from the dispatch_collection. + dispatches = dispatch_collection.get_dispatches() + + # Filter dispatches based on the filter criteria. + filtered_dispatch_collection = DispatchCollection( + dispatch_collection.operation, [] + ) + for dispatch in dispatches: + if self.is_enabled(dispatch): + filtered_dispatch_collection.append(dispatch) + + # Only append the filtered_dispatch_collection if it has an unfiltered configuration. + if len(filtered_dispatch_collection.configuration_list): + self.dispatch_collection_map[operation_kind].append( + filtered_dispatch_collection + ) + + def append(self, dispatch_collection_list): + """Appends one instance of DispatchCollection to the manifest.""" + for dispatch_collection in dispatch_collection_list: + self.append_dispatch_collection(dispatch_collection) + + def initialize(self): + """Initialize the mainfest object by generating dispatches for supported operations.""" + self.append(CudaMatmulGenerator(self.args).generate()) + self.append(CudaSplitKMatmulGenerator(self.args).generate()) + self.append(CudaBatchMatmulGenerator(self.args).generate()) + + # Serialize the initialized mainfest state. + self.dump() + + def dump(self): + """Serialize (dump) the self.dispatch_collection_map to a pickle file.""" + with open(self.serialized_file_path, "wb") as f: + pickle.dump(self.dispatch_collection_map, f) + + def load(self): + """Deserialize (load) the self.dispatch_collection_map from a pickle file.""" + if not self.serialized_file_path.exists(): + raise ValueError(f"Could not find : {self.serialized_file_path}") + + with open(self.serialized_file_path, "rb") as load_file: + self.dispatch_collection_map = pickle.load(load_file) + + def emit(self): + """Emits the operations in the Manifest to the build directory as MLIR source files. + The operations are emitted in the dialect specified by the `mlir_dialect` flag. + """ + + # For each operation_kind create a directory and emit the operations with + # all the configurations in the configuration_list into their seperate directories. + for ( + operation_kind, + dispatch_collection_list, + ) in self.dispatch_collection_map.items(): + operation_kind_path = self.generated_path.joinpath( + OperationKindNames[operation_kind] + ) + + # If the operation_kind_path does not exists, create it. + if not operation_kind_path.exists(): + operation_kind_path.mkdir(parents=True, exist_ok=True) + + for dispatch_collection in dispatch_collection_list: + operation_path = operation_kind_path.joinpath( + dispatch_collection.operation.name() + ) + + # If the operation_path does not exists, create it. + if not operation_path.exists(): + operation_path.mkdir() + + with EmitSourceMLIR( + operation_path, dispatch_collection + ) as emit_mlir_source: + mlir_file_path = operation_path.joinpath( + dispatch_collection.operation.name() + ).with_suffix(".mlir") + print(f"[Generating]: {mlir_file_path}") + + # Emit mlir source file for the dispatch_collection.operation with all the configurations + emit_mlir_source.emit() diff --git a/experimental/dispatch_profiler/matmul.py b/experimental/dispatch_profiler/matmul.py index 5ad8ad59e030..5ad633cdfeef 100644 --- a/experimental/dispatch_profiler/matmul.py +++ b/experimental/dispatch_profiler/matmul.py @@ -12,176 +12,193 @@ ################################################################################ class MatmulOperation: - """Data structure to describe a matrix multiplication operation. - This includes the shape, datatype, and layout of the operands. This data - structure is *independent* of the compilation* and tiling configuration. - It "mostly" contains the parameter that changes the functionality of matmul - operation. The only exception is the split_k_slices parameter, which is - changes the performance of the matmul operation and not the functionality. - """ - - def __init__(self, - matmul_shape, - lhs, - rhs, - result, - batch_count=1, - split_k_slices=1, - operation_kind=OperationKind.Matmul): - """Initializes a matrix multiplication operation. - Matrix-multiple operation: `result[M, N] = lhs[M, K] * rhs[K, N]` - matmul_shape: A tuple representing the matrix multiplication problem shape - in the format (M, N, K), where M is the number of rows in the lhs matrix, - N is the number of columns in the rhs matrix, and K is the number of columns - in the lhs matrix and rows in the rhs matrix. - lhs: A TensorDescription object representing the left-hand-side matrix operand. - rhs: A TensorDescription object representing the right-hand-side matrix operand. - result: A TensorDescription object representing the result matrix operand. + """Data structure to describe a matrix multiplication operation. + This includes the shape, datatype, and layout of the operands. This data + structure is *independent* of the compilation* and tiling configuration. + It "mostly" contains the parameter that changes the functionality of matmul + operation. The only exception is the split_k_slices parameter, which is + changes the performance of the matmul operation and not the functionality. """ - # Parameters that change the matmul operation *functionally*. - self.operation_kind = operation_kind - self.matmul_shape = matmul_shape - self.M = matmul_shape[0] - self.N = matmul_shape[1] - self.K = matmul_shape[2] - self.batch_count = batch_count - self.lhs = lhs # TensorDescription - self.rhs = rhs # TensorDescription - self.result = result # TensorDescription - - # Parameters that change the matmul operation *performance*. - self.split_k_slices = split_k_slices - - def __eq__(self, other): - """Returns true if the matmul operation is *functionally* the same.""" - return self.matmul_shape == other.matmul_shape and \ - self.lhs == other.lhs and \ - self.rhs == other.rhs and \ - self.result == other.result and \ - self.batch_count == other.batch_count - - def name(self): - """Procedurally generated name for the matmul operation. - The name uniquely identifies a matmul operation with matmul shape, - lhs dataype and layout, rhs datatype and layout, and result - datatype and layout. - """ - return f'{OperationKindNames[self.operation_kind]}_'\ - f'{self.M}x{self.N}x{self.K}_'\ - f'{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_'\ - f'{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_'\ - f'{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}' - - def get_argument_dict(self): - """Returns the dictionary of matmul arguments (shape, datatypes, split_k_slices).""" - split_k_mode = "parallel" if self.operation_kind == OperationKind.SplitkMatmul else "N/A" - split_k_slices = self.split_k_slices if self.operation_kind == OperationKind.SplitkMatmul else "N/A" - return { - "batch_count": self.batch_count, - "m": self.M, - "n": self.N, - "k": self.K, - "lhs": self.lhs.name(), - "rhs": self.rhs.name(), - "result": self.result.name(), - "split_k_mode": split_k_mode, - "split_k_slices": split_k_slices - } - - def get_dict_entry(self): - """Returns the dictionary of matmul operation summary.""" - dict_entry = { - "op_kind": OperationKindNames[self.operation_kind], - "Operation": self.name(), - "bytes": self.bytes(), - "flops": self.flops(), - } - dict_entry.update(self.get_argument_dict()) - return dict_entry - - def lhs_npy_shape(self): - """Returns the shape of the lhs numpy array as a string in the format "MxKxDataType".""" - return f"{self.M}x{self.K}x{DataTypeName[self.lhs.datatype]}" - - def rhs_npy_shape(self): - """Returns the shape of the rhs numpy array as a string in the format "KxNxDataType".""" - return f"{self.K}x{self.N}x{DataTypeName[self.rhs.datatype]}" - - def result_npy_shape(self): - """Returns the shape of the result numpy array as a string in the format "MxNxDataType".""" - return f"{self.M}x{self.N}x{DataTypeName[self.result.datatype]}" - - def bytes(self): - """Returns the number of bytes read/written by the matmul operation.""" - bytes = (DataTypeSizeInBits[self.lhs.datatype] * self.M // 8) * self.K + \ - (DataTypeSizeInBits[self.rhs.datatype] * self.K // 8) * self.N + \ - (DataTypeSizeInBits[self.result.datatype] * self.M // 8) * self.N - return bytes * self.batch_count - - def flops(self): - """Returns the number of floating point operations performed by the matmul operation.""" - return 2 * self.M * self.N * self.K * self.batch_count + def __init__( + self, + matmul_shape, + lhs, + rhs, + result, + batch_count=1, + split_k_slices=1, + operation_kind=OperationKind.Matmul, + ): + """Initializes a matrix multiplication operation. + Matrix-multiple operation: `result[M, N] = lhs[M, K] * rhs[K, N]` + matmul_shape: A tuple representing the matrix multiplication problem shape + in the format (M, N, K), where M is the number of rows in the lhs matrix, + N is the number of columns in the rhs matrix, and K is the number of columns + in the lhs matrix and rows in the rhs matrix. + lhs: A TensorDescription object representing the left-hand-side matrix operand. + rhs: A TensorDescription object representing the right-hand-side matrix operand. + result: A TensorDescription object representing the result matrix operand. + """ + + # Parameters that change the matmul operation *functionally*. + self.operation_kind = operation_kind + self.matmul_shape = matmul_shape + self.M = matmul_shape[0] + self.N = matmul_shape[1] + self.K = matmul_shape[2] + self.batch_count = batch_count + self.lhs = lhs # TensorDescription + self.rhs = rhs # TensorDescription + self.result = result # TensorDescription + + # Parameters that change the matmul operation *performance*. + self.split_k_slices = split_k_slices + + def __eq__(self, other): + """Returns true if the matmul operation is *functionally* the same.""" + return ( + self.matmul_shape == other.matmul_shape + and self.lhs == other.lhs + and self.rhs == other.rhs + and self.result == other.result + and self.batch_count == other.batch_count + ) + + def name(self): + """Procedurally generated name for the matmul operation. + The name uniquely identifies a matmul operation with matmul shape, + lhs dataype and layout, rhs datatype and layout, and result + datatype and layout. + """ + return ( + f"{OperationKindNames[self.operation_kind]}_" + f"{self.M}x{self.N}x{self.K}_" + f"{DataTypeName[self.lhs.datatype]}{ShortLayoutTypeName[self.lhs.layout]}_" + f"{DataTypeName[self.rhs.datatype]}{ShortLayoutTypeName[self.rhs.layout]}_" + f"{DataTypeName[self.result.datatype]}{ShortLayoutTypeName[self.result.layout]}" + ) + + def get_argument_dict(self): + """Returns the dictionary of matmul arguments (shape, datatypes, split_k_slices).""" + split_k_mode = ( + "parallel" if self.operation_kind == OperationKind.SplitkMatmul else "N/A" + ) + split_k_slices = ( + self.split_k_slices + if self.operation_kind == OperationKind.SplitkMatmul + else "N/A" + ) + return { + "batch_count": self.batch_count, + "m": self.M, + "n": self.N, + "k": self.K, + "lhs": self.lhs.name(), + "rhs": self.rhs.name(), + "result": self.result.name(), + "split_k_mode": split_k_mode, + "split_k_slices": split_k_slices, + } + + def get_dict_entry(self): + """Returns the dictionary of matmul operation summary.""" + dict_entry = { + "op_kind": OperationKindNames[self.operation_kind], + "Operation": self.name(), + "bytes": self.bytes(), + "flops": self.flops(), + } + dict_entry.update(self.get_argument_dict()) + return dict_entry + + def lhs_npy_shape(self): + """Returns the shape of the lhs numpy array as a string in the format "MxKxDataType".""" + return f"{self.M}x{self.K}x{DataTypeName[self.lhs.datatype]}" + + def rhs_npy_shape(self): + """Returns the shape of the rhs numpy array as a string in the format "KxNxDataType".""" + return f"{self.K}x{self.N}x{DataTypeName[self.rhs.datatype]}" + + def result_npy_shape(self): + """Returns the shape of the result numpy array as a string in the format "MxNxDataType".""" + return f"{self.M}x{self.N}x{DataTypeName[self.result.datatype]}" + + def bytes(self): + """Returns the number of bytes read/written by the matmul operation.""" + bytes = ( + (DataTypeSizeInBits[self.lhs.datatype] * self.M // 8) * self.K + + (DataTypeSizeInBits[self.rhs.datatype] * self.K // 8) * self.N + + (DataTypeSizeInBits[self.result.datatype] * self.M // 8) * self.N + ) + return bytes * self.batch_count + + def flops(self): + """Returns the number of floating point operations performed by the matmul operation.""" + return 2 * self.M * self.N * self.K * self.batch_count ############################################################################## class MatmulCompilationInfo: - """Data structure strictly describes the compilation passes and the tiling configurations. - For a matrix multiplication operation, compilation passes and tiling configuration - influences the performance of the compiled matmul operation, but the functionality. - This data structure should be independent of the matmul operation functionality. - - Any change in this data structure should not affect the functionality of the matmul operation, i.e., - we should be able to use the same reference results for a matrix operation compiled with different - compilation info. - """ - - def __init__(self, - tile_description, - translation_info, - operation_kind=OperationKind.Matmul, - config_type=CompilationConfigType.Custom): - self.tile_description = tile_description # TileDescription - self.translation_info = translation_info # TranslationInfo - self.operation_kind = operation_kind # OperationKind - self.config_type = config_type # CompilationConfigType - - def get_dict_entry(self): - """Returns the dictionary entry for the matmul compilation info.""" - if self.config_type == CompilationConfigType.Default: - return { - "Tile config": "Default", - "Core class": "Default", - "Instruction class": "Default" - } - - translation_info_name = TranslationInfoName[self.translation_info] - return { - "Tile config": self.tile_description.name(), - "Core class": translation_info_name.split('_')[0], - "Instruction class": translation_info_name.split('_')[1], - } - - def name(self): - """Procedurally generated name for the matmul compilation info.""" - if self.config_type == CompilationConfigType.Default: - return "tile_config_default" - - return "tile_config_{tbm}x{tbn}_{tbk}x{stages}_{translation_info}".format( - tbm=self.tile_description.threadblock_shape[0], - tbn=self.tile_description.threadblock_shape[1], - tbk=self.tile_description.threadblock_shape[2], - stages=self.tile_description.stages, - translation_info=TranslationInfoName[self.translation_info]) + """Data structure strictly describes the compilation passes and the tiling configurations. + For a matrix multiplication operation, compilation passes and tiling configuration + influences the performance of the compiled matmul operation, but the functionality. + This data structure should be independent of the matmul operation functionality. + + Any change in this data structure should not affect the functionality of the matmul operation, i.e., + we should be able to use the same reference results for a matrix operation compiled with different + compilation info. + """ + + def __init__( + self, + tile_description, + translation_info, + operation_kind=OperationKind.Matmul, + config_type=CompilationConfigType.Custom, + ): + self.tile_description = tile_description # TileDescription + self.translation_info = translation_info # TranslationInfo + self.operation_kind = operation_kind # OperationKind + self.config_type = config_type # CompilationConfigType + + def get_dict_entry(self): + """Returns the dictionary entry for the matmul compilation info.""" + if self.config_type == CompilationConfigType.Default: + return { + "Tile config": "Default", + "Core class": "Default", + "Instruction class": "Default", + } + + translation_info_name = TranslationInfoName[self.translation_info] + return { + "Tile config": self.tile_description.name(), + "Core class": translation_info_name.split("_")[0], + "Instruction class": translation_info_name.split("_")[1], + } + + def name(self): + """Procedurally generated name for the matmul compilation info.""" + if self.config_type == CompilationConfigType.Default: + return "tile_config_default" + + return "tile_config_{tbm}x{tbn}_{tbk}x{stages}_{translation_info}".format( + tbm=self.tile_description.threadblock_shape[0], + tbn=self.tile_description.threadblock_shape[1], + tbk=self.tile_description.threadblock_shape[2], + stages=self.tile_description.stages, + translation_info=TranslationInfoName[self.translation_info], + ) ################################################################################ class EmitMatmulCompilationInfo: - """Emitters for the matmul compilation info.""" + """Emitters for the matmul compilation info.""" - def __init__(self): - # matmul compilation info template - self.matmul_compilation_info_template = """ + def __init__(self): + # matmul compilation info template + self.matmul_compilation_info_template = """ // matmul compilation info (tile configuration, translation info, workgroup size) #${compilation_info_name} = #iree_codegen.compilation_info< lowering_config = , @@ -189,8 +206,8 @@ def __init__(self): workgroup_size = [${block_dim_x} : index, ${block_dim_y} : index, ${block_dim_z} : index] > """ - # batch matmul and split-k matmul compilation info template - self.batch_matmul_compilation_info_template = """ + # batch matmul and split-k matmul compilation info template + self.batch_matmul_compilation_info_template = """ // batch matmul compilation info (tile configuration, translation info, workgroup size) #${compilation_info_name} = #iree_codegen.compilation_info< lowering_config = , @@ -199,53 +216,52 @@ def __init__(self): > """ - def emit(self, compilation_info): - """Emits the matmul compilation info as a string.""" - if compilation_info.config_type == CompilationConfigType.Default: - return "" - - values = { - 'compilation_info_name': - compilation_info.name(), - 'translation_info': - TranslationInfoTag[compilation_info.translation_info], - 'threadblock_shape_m': - str(compilation_info.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': - str(compilation_info.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': - str(compilation_info.tile_description.threadblock_shape[2]), - 'stages': - str(compilation_info.tile_description.stages), - 'block_dim_x': - str(compilation_info.tile_description.block_dim[0]), - 'block_dim_y': - str(compilation_info.tile_description.block_dim[1]), - 'block_dim_z': - str(compilation_info.tile_description.block_dim[2]), - } - - # linalg.matmul (without split-k) compilation info template. - compilation_info_template = self.matmul_compilation_info_template - - # linalg.batch_matmul and linalg.matmul (with split-k) have different - # compilation info template from the linalg.matmul (without split-k). - if compilation_info.operation_kind == OperationKind.BatchMatmul or \ - compilation_info.operation_kind == OperationKind.SplitkMatmul: - compilation_info_template = self.batch_matmul_compilation_info_template - - return SubstituteTemplate(compilation_info_template, values) + def emit(self, compilation_info): + """Emits the matmul compilation info as a string.""" + if compilation_info.config_type == CompilationConfigType.Default: + return "" + + values = { + "compilation_info_name": compilation_info.name(), + "translation_info": TranslationInfoTag[compilation_info.translation_info], + "threadblock_shape_m": str( + compilation_info.tile_description.threadblock_shape[0] + ), + "threadblock_shape_n": str( + compilation_info.tile_description.threadblock_shape[1] + ), + "threadblock_shape_k": str( + compilation_info.tile_description.threadblock_shape[2] + ), + "stages": str(compilation_info.tile_description.stages), + "block_dim_x": str(compilation_info.tile_description.block_dim[0]), + "block_dim_y": str(compilation_info.tile_description.block_dim[1]), + "block_dim_z": str(compilation_info.tile_description.block_dim[2]), + } + + # linalg.matmul (without split-k) compilation info template. + compilation_info_template = self.matmul_compilation_info_template + + # linalg.batch_matmul and linalg.matmul (with split-k) have different + # compilation info template from the linalg.matmul (without split-k). + if ( + compilation_info.operation_kind == OperationKind.BatchMatmul + or compilation_info.operation_kind == OperationKind.SplitkMatmul + ): + compilation_info_template = self.batch_matmul_compilation_info_template + + return SubstituteTemplate(compilation_info_template, values) ############################################################################### class EmitLinalgMatmulDispatch: - """Emitters for the `linalg.matmul` dispatch.""" + """Emitters for the `linalg.matmul` dispatch.""" - def __init__(self): - self.mlir_dialect = MlirDialect.Linalg + def __init__(self): + self.mlir_dialect = MlirDialect.Linalg - # linalg.matmul mlir template - self.linalg_row_row_matmul_template = """ + # linalg.matmul mlir template + self.linalg_row_row_matmul_template = """ // Dispatch linalg.matmul row-row layout func.func @${operation_name}_${compilation_info_name}( %lhs: tensor<${problem_m}x${problem_k}x${datatype_lhs}>, @@ -261,344 +277,382 @@ def __init__(self): } """ - def emit(self, matmul_dispatch): - """Emit the matmul operation in the MLIR dialect for a single compilation info""" - compilation_info_attribute_template = """{compilation_info = #${compilation_info_name}}""" - compilation_info_attribute_str = SubstituteTemplate( - compilation_info_attribute_template, - {'compilation_info_name': matmul_dispatch.configuration.name()}) - compilation_info_attribute = compilation_info_attribute_str \ - if matmul_dispatch.configuration.config_type != CompilationConfigType.Default else "" - - values = { - 'operation_name': - matmul_dispatch.operation.name(), - 'compilation_info_attribute': - compilation_info_attribute, - 'problem_m': - str(matmul_dispatch.operation.M), - 'problem_n': - str(matmul_dispatch.operation.N), - 'problem_k': - str(matmul_dispatch.operation.K), - 'datatype_lhs': - DataTypeName[matmul_dispatch.operation.lhs.datatype], - 'datatype_rhs': - DataTypeName[matmul_dispatch.operation.rhs.datatype], - 'datatype_result': - DataTypeName[matmul_dispatch.operation.result.datatype], - 'compilation_info_name': - matmul_dispatch.configuration.name() - } - - return SubstituteTemplate(self.linalg_row_row_matmul_template, values) + def emit(self, matmul_dispatch): + """Emit the matmul operation in the MLIR dialect for a single compilation info""" + compilation_info_attribute_template = ( + """{compilation_info = #${compilation_info_name}}""" + ) + compilation_info_attribute_str = SubstituteTemplate( + compilation_info_attribute_template, + {"compilation_info_name": matmul_dispatch.configuration.name()}, + ) + compilation_info_attribute = ( + compilation_info_attribute_str + if matmul_dispatch.configuration.config_type + != CompilationConfigType.Default + else "" + ) + + values = { + "operation_name": matmul_dispatch.operation.name(), + "compilation_info_attribute": compilation_info_attribute, + "problem_m": str(matmul_dispatch.operation.M), + "problem_n": str(matmul_dispatch.operation.N), + "problem_k": str(matmul_dispatch.operation.K), + "datatype_lhs": DataTypeName[matmul_dispatch.operation.lhs.datatype], + "datatype_rhs": DataTypeName[matmul_dispatch.operation.rhs.datatype], + "datatype_result": DataTypeName[matmul_dispatch.operation.result.datatype], + "compilation_info_name": matmul_dispatch.configuration.name(), + } + + return SubstituteTemplate(self.linalg_row_row_matmul_template, values) ############################################################################### class ReferenceMatmulOp(ReferenceOpInterface): - """Reference implementation for the matmul operation in numpy.""" - - def __init__(self, matmul_operation, op_reference_cache_path, dist_lhs, - dist_rhs): - self.matmul_operation = matmul_operation - self.op_reference_cache_path = op_reference_cache_path - - # Problem shape. - self.M = matmul_operation.M - self.N = matmul_operation.N - self.K = matmul_operation.K - - # Data type for the input and result matrices. - self.dtype_lhs = DataTypeNumPyTag[matmul_operation.lhs.datatype] - self.dtype_rhs = DataTypeNumPyTag[matmul_operation.rhs.datatype] - self.dtype_result = DataTypeNumPyTag[matmul_operation.result.datatype] - - # Distribution of the input tensors. - self.dist_lhs = dist_lhs - self.dist_rhs = dist_rhs - - # Filename for the left hand side input tensor. - self.filename_lhs = "m{problem_m}xk{problem_k}_"\ - "{tensor_description}_{dist}_lhs.npy".format( - problem_m=self.M, - problem_k=self.K, - tensor_description=self.matmul_operation.lhs.name(), - dist=DistributionName[self.dist_lhs]) - - # Filename for the right hand side input tensor. - self.filename_rhs = "k{problem_k}xn{problem_n}_"\ - "{tensor_description}_{dist}_rhs.npy".format( - problem_k=self.K, - problem_n=self.N, - tensor_description=self.matmul_operation.rhs.name(), - dist=DistributionName[self.dist_rhs]) - - # Filename for the reference result tensor. - self.filename_reference_result = "m{problem_m}xn{problem_n}_"\ - "{tensor_description}_reference_result.npy".format( - problem_m=self.M, - problem_n=self.N, - tensor_description=self.matmul_operation.result.name()) - - # Filepath for input and output files. - self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs) - self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs) - self.filepath_reference_result = self.op_reference_cache_path.joinpath( - self.filename_reference_result) - - def get_input_filepaths(self): - """Returns the list of input file paths.""" - return [self.filepath_lhs, self.filepath_rhs] - - def get_output_filepaths(self): - """Returns the list of expected output file paths.""" - return [self.filepath_reference_result] - - def __call__(self): - """Generates input data, runs reference numpy.matmul, and save npy files to the output directory.""" - # Generate the input data as np.array for the matmul operation. - lhs_np_array = get_np_array(self.matmul_operation.lhs, (self.M, self.K), - self.dist_lhs) - rhs_np_array = get_np_array(self.matmul_operation.rhs, (self.K, self.N), - self.dist_rhs) - - # Run the reference np.matmul and generate result np.array. - result = np.matmul(lhs_np_array, rhs_np_array) - - # Save the input data as np.array for the matmul operation. - np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs)) - np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs)) - - # Save the expected result as an np.array. - np.save(self.filepath_reference_result, - np.array(result, dtype=self.dtype_result)) + """Reference implementation for the matmul operation in numpy.""" + + def __init__(self, matmul_operation, op_reference_cache_path, dist_lhs, dist_rhs): + self.matmul_operation = matmul_operation + self.op_reference_cache_path = op_reference_cache_path + + # Problem shape. + self.M = matmul_operation.M + self.N = matmul_operation.N + self.K = matmul_operation.K + + # Data type for the input and result matrices. + self.dtype_lhs = DataTypeNumPyTag[matmul_operation.lhs.datatype] + self.dtype_rhs = DataTypeNumPyTag[matmul_operation.rhs.datatype] + self.dtype_result = DataTypeNumPyTag[matmul_operation.result.datatype] + + # Distribution of the input tensors. + self.dist_lhs = dist_lhs + self.dist_rhs = dist_rhs + + # Filename for the left hand side input tensor. + self.filename_lhs = ( + "m{problem_m}xk{problem_k}_" + "{tensor_description}_{dist}_lhs.npy".format( + problem_m=self.M, + problem_k=self.K, + tensor_description=self.matmul_operation.lhs.name(), + dist=DistributionName[self.dist_lhs], + ) + ) + + # Filename for the right hand side input tensor. + self.filename_rhs = ( + "k{problem_k}xn{problem_n}_" + "{tensor_description}_{dist}_rhs.npy".format( + problem_k=self.K, + problem_n=self.N, + tensor_description=self.matmul_operation.rhs.name(), + dist=DistributionName[self.dist_rhs], + ) + ) + + # Filename for the reference result tensor. + self.filename_reference_result = ( + "m{problem_m}xn{problem_n}_" + "{tensor_description}_reference_result.npy".format( + problem_m=self.M, + problem_n=self.N, + tensor_description=self.matmul_operation.result.name(), + ) + ) + + # Filepath for input and output files. + self.filepath_lhs = self.op_reference_cache_path.joinpath(self.filename_lhs) + self.filepath_rhs = self.op_reference_cache_path.joinpath(self.filename_rhs) + self.filepath_reference_result = self.op_reference_cache_path.joinpath( + self.filename_reference_result + ) + + def get_input_filepaths(self): + """Returns the list of input file paths.""" + return [self.filepath_lhs, self.filepath_rhs] + + def get_output_filepaths(self): + """Returns the list of expected output file paths.""" + return [self.filepath_reference_result] + + def __call__(self): + """Generates input data, runs reference numpy.matmul, and save npy files to the output directory.""" + # Generate the input data as np.array for the matmul operation. + lhs_np_array = get_np_array( + self.matmul_operation.lhs, (self.M, self.K), self.dist_lhs + ) + rhs_np_array = get_np_array( + self.matmul_operation.rhs, (self.K, self.N), self.dist_rhs + ) + + # Run the reference np.matmul and generate result np.array. + result = np.matmul(lhs_np_array, rhs_np_array) + + # Save the input data as np.array for the matmul operation. + np.save(self.filepath_lhs, np.array(lhs_np_array, dtype=self.dtype_lhs)) + np.save(self.filepath_rhs, np.array(rhs_np_array, dtype=self.dtype_rhs)) + + # Save the expected result as an np.array. + np.save( + self.filepath_reference_result, np.array(result, dtype=self.dtype_result) + ) class CudaMatmulDispatchChecker: - """Given a matmul dispatch, checks if the dispatch is supported by the target GPU.""" - - def __init__(self, args): - self.args = args - - # CUDA shared memory capacity per SM in KB. - self.sharedMemPerSm = { - "sm_80": 163, # 1KB is reserved for the driver. - "sm_86": 99, # 1KB is reserved for the driver - } - - self.cuda_arch = self.args.cuda_arch - self.cuda_smem_capacity_in_bytes = self.sharedMemPerSm[self.cuda_arch] << 10 - - def _is_tile_aligned_shape(self, dispatch): - """Checks if the given dispatch is valid for CUDA.""" - matmul_shape = dispatch.operation.matmul_shape - threadblock_shape = dispatch.configuration.tile_description.threadblock_shape - if len(matmul_shape) != len(threadblock_shape): - raise ValueError( - "Problem shape and threadblock shape must have the same rank.") - is_aligned = all( - a % b == 0 for a, b in zip(matmul_shape, threadblock_shape)) - return is_aligned - - def _cuda_smem_required_in_bytes(self, dispatch): - """Returns size bytes of shared memory required for a given cuda dispatch.""" - threadblock_shape = dispatch.configuration.tile_description.threadblock_shape - num_stages = dispatch.configuration.tile_description.stages - tile_shape_lhs = threadblock_shape[0] * threadblock_shape[2] - tile_shape_rhs = threadblock_shape[2] * threadblock_shape[1] - return ( - (tile_shape_lhs * DataTypeSizeInBits[dispatch.operation.lhs.datatype] + - tile_shape_rhs * DataTypeSizeInBits[dispatch.operation.rhs.datatype]) * - num_stages) // 8 - - def _is_problem_k_divisible_by_split_k(self, dispatch): - """Checks if the given dispatch is valid for CUDA.""" - return dispatch.operation.K % dispatch.operation.split_k_slices == 0 - - def _is_cuda_smem_avialable(self, dispatch): - """Checks if the given dispatch is valid for CUDA.""" - return self._cuda_smem_required_in_bytes( - dispatch) <= self.cuda_smem_capacity_in_bytes - - def is_valid(self, dispatch): - """Checks if the given dispatch is valid for CUDA.""" - if not self._is_tile_aligned_shape(dispatch): - if self.args.verbose: - print(f"[Warning]: {dispatch.name()} is not aligned is being skipped.") - return False - if not self._is_cuda_smem_avialable(dispatch): - if self.args.verbose: - print(f"[Warning]: {dispatch.name()} requires {self._cuda_smem_required_in_bytes(dispatch)} "\ - f"bytes of shared memory, which is larger than the {self.cuda_arch} capacity "\ - f"{self.cuda_smem_capacity_in_bytes} bytes.") - return False - if (dispatch.operation.split_k_slices > - 1) and (not self._is_problem_k_divisible_by_split_k(dispatch)): - if self.args.verbose: - print(f"[Warning]: {dispatch.name()} problem k is not divisible by {dispatch.operation.split_k_slices} "\ - f"split-k slices, which is not supported on LLVM GPU CUDA backend.") - return False - return True + """Given a matmul dispatch, checks if the dispatch is supported by the target GPU.""" + + def __init__(self, args): + self.args = args + + # CUDA shared memory capacity per SM in KB. + self.sharedMemPerSm = { + "sm_80": 163, # 1KB is reserved for the driver. + "sm_86": 99, # 1KB is reserved for the driver + } + + self.cuda_arch = self.args.cuda_arch + self.cuda_smem_capacity_in_bytes = self.sharedMemPerSm[self.cuda_arch] << 10 + + def _is_tile_aligned_shape(self, dispatch): + """Checks if the given dispatch is valid for CUDA.""" + matmul_shape = dispatch.operation.matmul_shape + threadblock_shape = dispatch.configuration.tile_description.threadblock_shape + if len(matmul_shape) != len(threadblock_shape): + raise ValueError( + "Problem shape and threadblock shape must have the same rank." + ) + is_aligned = all(a % b == 0 for a, b in zip(matmul_shape, threadblock_shape)) + return is_aligned + + def _cuda_smem_required_in_bytes(self, dispatch): + """Returns size bytes of shared memory required for a given cuda dispatch.""" + threadblock_shape = dispatch.configuration.tile_description.threadblock_shape + num_stages = dispatch.configuration.tile_description.stages + tile_shape_lhs = threadblock_shape[0] * threadblock_shape[2] + tile_shape_rhs = threadblock_shape[2] * threadblock_shape[1] + return ( + ( + tile_shape_lhs * DataTypeSizeInBits[dispatch.operation.lhs.datatype] + + tile_shape_rhs * DataTypeSizeInBits[dispatch.operation.rhs.datatype] + ) + * num_stages + ) // 8 + + def _is_problem_k_divisible_by_split_k(self, dispatch): + """Checks if the given dispatch is valid for CUDA.""" + return dispatch.operation.K % dispatch.operation.split_k_slices == 0 + + def _is_cuda_smem_avialable(self, dispatch): + """Checks if the given dispatch is valid for CUDA.""" + return ( + self._cuda_smem_required_in_bytes(dispatch) + <= self.cuda_smem_capacity_in_bytes + ) + + def is_valid(self, dispatch): + """Checks if the given dispatch is valid for CUDA.""" + if not self._is_tile_aligned_shape(dispatch): + if self.args.verbose: + print(f"[Warning]: {dispatch.name()} is not aligned is being skipped.") + return False + if not self._is_cuda_smem_avialable(dispatch): + if self.args.verbose: + print( + f"[Warning]: {dispatch.name()} requires {self._cuda_smem_required_in_bytes(dispatch)} " + f"bytes of shared memory, which is larger than the {self.cuda_arch} capacity " + f"{self.cuda_smem_capacity_in_bytes} bytes." + ) + return False + if (dispatch.operation.split_k_slices > 1) and ( + not self._is_problem_k_divisible_by_split_k(dispatch) + ): + if self.args.verbose: + print( + f"[Warning]: {dispatch.name()} problem k is not divisible by {dispatch.operation.split_k_slices} " + f"split-k slices, which is not supported on LLVM GPU CUDA backend." + ) + return False + return True class CudaMatmulGenerator: - """Matmul dispatch generator class. - Generates a list of pre-defined matmul operations with resonable tuning cofigurations. - The generator function are seperated based on the target backend and the data type. - Please see example `MatmulGenerator._cuda_matmul_tensor_cores_f16` for cuda target - backend and f16 data type.""" - - def __init__(self, args): - """Initializes the matmul generator.""" - self.args = args - self.translation_infos = [ - #TranslationInfo.LLVMGPUMatmulSimt, # CUDA Core (SMIT) - #TranslationInfo.LLVMGPUMatmulTensorCore, # Tensor Core (WMMA) - TranslationInfo. - LLVMGPUMatmulTensorCoreMmaSync, # Tensor Core (MMA.SYNC) - ] - - # List of pre-defined threadblock tile shapes for Tensor Core. - self.tile_descriptions_tensor_cores_f16 = [ - TileDescription([256, 128, 32], 3, [64, 4, 1]), - TileDescription([128, 256, 32], 3, [128, 2, 1]), - TileDescription([128, 128, 64], 4, [64, 2, 1]), - TileDescription([128, 128, 32], 5, [64, 2, 1]), - TileDescription([128, 64, 32], 5, [64, 2, 1]), - TileDescription([64, 64, 64], 5, [64, 2, 1]), - TileDescription([64, 64, 32], 10, [64, 2, 1]), - ] - - self.tile_descriptions_tensor_cores_f32 = [ - TileDescription([128, 256, 16], 3, [128, 2, 1]), - TileDescription([256, 128, 16], 3, [64, 4, 1]), - TileDescription([128, 128, 16], 5, [64, 2, 1]), - TileDescription([128, 128, 32], 3, [64, 2, 1]), - TileDescription([128, 128, 32], 4, [64, 2, 1]), - TileDescription([128, 64, 32], 3, [64, 2, 1]), - TileDescription([128, 64, 16], 5, [64, 2, 1]), - TileDescription([64, 64, 32], 3, [64, 2, 1]), - TileDescription([64, 64, 16], 10, [64, 2, 1]), - ] - - # Create a list of matmul problem and initialize with some *default* shapes. - self.matmul_shapes = [[256, 512, 128], [2560, 2560, 2560], - [3456, 1024, 2048]] - - # Append matmul problem with *user* provided shapes. - self.add_cmd_line_shapes() - - # Matmul dispatches collection. - self.dispatches_collection_list = [] - - def add_cmd_line_shapes(self): - """Adds matmul shapes from command line arguments.""" - - m_list = get_cmd_line_argument_list(self.args.problem_m) - n_list = get_cmd_line_argument_list(self.args.problem_n) - k_list = get_cmd_line_argument_list(self.args.problem_k) - - # If no command line matmul problem shapes are provided, only - # use the default shapes. - if len(m_list) == 0 and len(n_list) == 0 and len(k_list) == 0: - return - - # If any of the command line matmul problem shapes are provided, - # set the default shapes to empty problem dimension. - if len(m_list) == 0: - m_list = [256] - if len(n_list) == 0: - n_list = [256] - if len(k_list) == 0: - k_list = [256] - - # Append the command line matmul problem shapes with user-proivded - # matmul problem shapes. - for m in m_list: - for n in n_list: - for k in k_list: - self.matmul_shapes.append([m, n, k]) - - def _cuda_supported_configuration_list(self, operation, configuration_list): - """Returns a list of supported configurations for CUDA.""" - supported_configuration_list = [] - dispatch_checker = CudaMatmulDispatchChecker(self.args) - for configuration in configuration_list: - if not dispatch_checker.is_valid(Dispatch(operation, configuration)): - continue - supported_configuration_list.append(configuration) - - # Return the supported configuration list. - return supported_configuration_list - - def _get_matmul_custom_compilation_info_list(self, tile_descriptions, - translation_infos, - operation_kind): - """Creates a *custom* list of matmul compilation info.""" - configuration_list = [] - for tile_description in tile_descriptions: - for translation_info in translation_infos: - configuration_list.append( - MatmulCompilationInfo(tile_description, translation_info, - operation_kind, CompilationConfigType.Custom)) - return configuration_list - - def _append_matmul_dispatch_collection(self, matmul_shapes, data_type, - configuration_list): - """Appends the matmul dispatches collection with the given configuration list.""" - - # Create dispatches collection for each matmul_shape x configuration list.. - for matmul_shape in matmul_shapes: - operation = MatmulOperation( - matmul_shape,\ - TensorDescription(data_type[0], LayoutType.RowMajor), \ - TensorDescription(data_type[1], LayoutType.RowMajor), \ - TensorDescription(data_type[2], LayoutType.RowMajor)) - - # Filter out configurations that are not supported by LLVM GPU CUDA backend. - supported_configuration_list = self._cuda_supported_configuration_list( - operation, configuration_list) - - # Add default configuration if enabled. - if self.args.default_config: - supported_configuration_list.append( - MatmulCompilationInfo([], [], OperationKind.Matmul, - CompilationConfigType.Default)) - - # Append the dispatch collection. - self.dispatches_collection_list.append(DispatchCollection(\ - operation, supported_configuration_list)) - - def _cuda_matmul_tensor_cores_f16(self): - """Appends dispatches for TensorCore with F16 input, F16 accum, F16 output.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f16, self.translation_infos, - OperationKind.Matmul) - data_type = [DataType.f16, DataType.f16, DataType.f16] - self._append_matmul_dispatch_collection(self.matmul_shapes, data_type, - configuration_list) - - def _cuda_matmul_tensor_cores_f32(self): - """Appends dispatches for TensorCore with F32 input, F32 accum, F32 output.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f32, self.translation_infos, - OperationKind.Matmul) - data_type = [DataType.f32, DataType.f32, DataType.f32] - self._append_matmul_dispatch_collection(self.matmul_shapes, data_type, - configuration_list) - - def _cuda_matmul_tensor_cores_mixed_precision(self): - """Appends dispatches for TensorCore with F16 input, F32 accum, F32 output.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f16, self.translation_infos, - OperationKind.Matmul) - data_type = [DataType.f16, DataType.f16, DataType.f32] - self._append_matmul_dispatch_collection(self.matmul_shapes, data_type, - configuration_list) - - def generate(self): - """Generates a list of matmul operations.""" - self._cuda_matmul_tensor_cores_f16() - self._cuda_matmul_tensor_cores_f32() - self._cuda_matmul_tensor_cores_mixed_precision() - return self.dispatches_collection_list + """Matmul dispatch generator class. + Generates a list of pre-defined matmul operations with resonable tuning cofigurations. + The generator function are seperated based on the target backend and the data type. + Please see example `MatmulGenerator._cuda_matmul_tensor_cores_f16` for cuda target + backend and f16 data type.""" + + def __init__(self, args): + """Initializes the matmul generator.""" + self.args = args + self.translation_infos = [ + # TranslationInfo.LLVMGPUMatmulSimt, # CUDA Core (SMIT) + # TranslationInfo.LLVMGPUMatmulTensorCore, # Tensor Core (WMMA) + TranslationInfo.LLVMGPUMatmulTensorCoreMmaSync, # Tensor Core (MMA.SYNC) + ] + + # List of pre-defined threadblock tile shapes for Tensor Core. + self.tile_descriptions_tensor_cores_f16 = [ + TileDescription([256, 128, 32], 3, [64, 4, 1]), + TileDescription([128, 256, 32], 3, [128, 2, 1]), + TileDescription([128, 128, 64], 4, [64, 2, 1]), + TileDescription([128, 128, 32], 5, [64, 2, 1]), + TileDescription([128, 64, 32], 5, [64, 2, 1]), + TileDescription([64, 64, 64], 5, [64, 2, 1]), + TileDescription([64, 64, 32], 10, [64, 2, 1]), + ] + + self.tile_descriptions_tensor_cores_f32 = [ + TileDescription([128, 256, 16], 3, [128, 2, 1]), + TileDescription([256, 128, 16], 3, [64, 4, 1]), + TileDescription([128, 128, 16], 5, [64, 2, 1]), + TileDescription([128, 128, 32], 3, [64, 2, 1]), + TileDescription([128, 128, 32], 4, [64, 2, 1]), + TileDescription([128, 64, 32], 3, [64, 2, 1]), + TileDescription([128, 64, 16], 5, [64, 2, 1]), + TileDescription([64, 64, 32], 3, [64, 2, 1]), + TileDescription([64, 64, 16], 10, [64, 2, 1]), + ] + + # Create a list of matmul problem and initialize with some *default* shapes. + self.matmul_shapes = [[256, 512, 128], [2560, 2560, 2560], [3456, 1024, 2048]] + + # Append matmul problem with *user* provided shapes. + self.add_cmd_line_shapes() + + # Matmul dispatches collection. + self.dispatches_collection_list = [] + + def add_cmd_line_shapes(self): + """Adds matmul shapes from command line arguments.""" + + m_list = get_cmd_line_argument_list(self.args.problem_m) + n_list = get_cmd_line_argument_list(self.args.problem_n) + k_list = get_cmd_line_argument_list(self.args.problem_k) + + # If no command line matmul problem shapes are provided, only + # use the default shapes. + if len(m_list) == 0 and len(n_list) == 0 and len(k_list) == 0: + return + + # If any of the command line matmul problem shapes are provided, + # set the default shapes to empty problem dimension. + if len(m_list) == 0: + m_list = [256] + if len(n_list) == 0: + n_list = [256] + if len(k_list) == 0: + k_list = [256] + + # Append the command line matmul problem shapes with user-proivded + # matmul problem shapes. + for m in m_list: + for n in n_list: + for k in k_list: + self.matmul_shapes.append([m, n, k]) + + def _cuda_supported_configuration_list(self, operation, configuration_list): + """Returns a list of supported configurations for CUDA.""" + supported_configuration_list = [] + dispatch_checker = CudaMatmulDispatchChecker(self.args) + for configuration in configuration_list: + if not dispatch_checker.is_valid(Dispatch(operation, configuration)): + continue + supported_configuration_list.append(configuration) + + # Return the supported configuration list. + return supported_configuration_list + + def _get_matmul_custom_compilation_info_list( + self, tile_descriptions, translation_infos, operation_kind + ): + """Creates a *custom* list of matmul compilation info.""" + configuration_list = [] + for tile_description in tile_descriptions: + for translation_info in translation_infos: + configuration_list.append( + MatmulCompilationInfo( + tile_description, + translation_info, + operation_kind, + CompilationConfigType.Custom, + ) + ) + return configuration_list + + def _append_matmul_dispatch_collection( + self, matmul_shapes, data_type, configuration_list + ): + """Appends the matmul dispatches collection with the given configuration list.""" + + # Create dispatches collection for each matmul_shape x configuration list.. + for matmul_shape in matmul_shapes: + operation = MatmulOperation( + matmul_shape, + TensorDescription(data_type[0], LayoutType.RowMajor), + TensorDescription(data_type[1], LayoutType.RowMajor), + TensorDescription(data_type[2], LayoutType.RowMajor), + ) + + # Filter out configurations that are not supported by LLVM GPU CUDA backend. + supported_configuration_list = self._cuda_supported_configuration_list( + operation, configuration_list + ) + + # Add default configuration if enabled. + if self.args.default_config: + supported_configuration_list.append( + MatmulCompilationInfo( + [], [], OperationKind.Matmul, CompilationConfigType.Default + ) + ) + + # Append the dispatch collection. + self.dispatches_collection_list.append( + DispatchCollection(operation, supported_configuration_list) + ) + + def _cuda_matmul_tensor_cores_f16(self): + """Appends dispatches for TensorCore with F16 input, F16 accum, F16 output.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f16, + self.translation_infos, + OperationKind.Matmul, + ) + data_type = [DataType.f16, DataType.f16, DataType.f16] + self._append_matmul_dispatch_collection( + self.matmul_shapes, data_type, configuration_list + ) + + def _cuda_matmul_tensor_cores_f32(self): + """Appends dispatches for TensorCore with F32 input, F32 accum, F32 output.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f32, + self.translation_infos, + OperationKind.Matmul, + ) + data_type = [DataType.f32, DataType.f32, DataType.f32] + self._append_matmul_dispatch_collection( + self.matmul_shapes, data_type, configuration_list + ) + + def _cuda_matmul_tensor_cores_mixed_precision(self): + """Appends dispatches for TensorCore with F16 input, F32 accum, F32 output.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f16, + self.translation_infos, + OperationKind.Matmul, + ) + data_type = [DataType.f16, DataType.f16, DataType.f32] + self._append_matmul_dispatch_collection( + self.matmul_shapes, data_type, configuration_list + ) + + def generate(self): + """Generates a list of matmul operations.""" + self._cuda_matmul_tensor_cores_f16() + self._cuda_matmul_tensor_cores_f32() + self._cuda_matmul_tensor_cores_mixed_precision() + return self.dispatches_collection_list diff --git a/experimental/dispatch_profiler/options.py b/experimental/dispatch_profiler/options.py index 0c88ea28d6f4..b53d18364b9d 100644 --- a/experimental/dispatch_profiler/options.py +++ b/experimental/dispatch_profiler/options.py @@ -20,148 +20,234 @@ def add_typical_arguments(parser): - """Adds typical command line arguments to the parser.""" - parser.add_argument("--iree-bin-dir", default="./tools", \ - help="Directory containing IREE binaries, "\ - "e.g. iree-compile, iree-benchmark-module, "\ - "iree-run-module") - parser.add_argument("--generated-dir", default=".", \ - help="The dispatch profiler scripts generate "\ - "mlir dispatches, compiled vmfbs, and reference_chache "\ - "containing golden npy files in the generated-dir") - parser.add_argument("--operation-kind","--op-kind", \ - dest="operation_kind", default="all", \ - help="Specifies the operation kinds to generate.", \ - choices=["matmul", "conv2d", "all"]) - parser.add_argument("--dispatches", default='', - help="Comma delimited list to filter dispatches by name. "\ - "A dispatch is a combination of operation and tuning "\ - "configuration.") - parser.add_argument("--mlir-dialect", default='linalg', \ - help="MLIR dialect entry point at which operation is emitter.", - choices=["linalg"]) - parser.add_argument("--verbose", action='store_true', \ - help='Prints verbose output and commands executed.') - parser.add_argument("--dry-run", action='store_true', \ - help='Prints commands that will be executed without actually '\ - 'executing them.') - parser.add_argument("--default-config", action='store_true', - help="Adds a dispatch without a pre-defined "\ - "tuning configuration. This dispatch will use "\ - "default configuration from KernelsConfig.cpp.") + """Adds typical command line arguments to the parser.""" + parser.add_argument( + "--iree-bin-dir", + default="./tools", + help="Directory containing IREE binaries, " + "e.g. iree-compile, iree-benchmark-module, " + "iree-run-module", + ) + parser.add_argument( + "--generated-dir", + default=".", + help="The dispatch profiler scripts generate " + "mlir dispatches, compiled vmfbs, and reference_chache " + "containing golden npy files in the generated-dir", + ) + parser.add_argument( + "--operation-kind", + "--op-kind", + dest="operation_kind", + default="all", + help="Specifies the operation kinds to generate.", + choices=["matmul", "conv2d", "all"], + ) + parser.add_argument( + "--dispatches", + default="", + help="Comma delimited list to filter dispatches by name. " + "A dispatch is a combination of operation and tuning " + "configuration.", + ) + parser.add_argument( + "--mlir-dialect", + default="linalg", + help="MLIR dialect entry point at which operation is emitter.", + choices=["linalg"], + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Prints verbose output and commands executed.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Prints commands that will be executed without actually " + "executing them.", + ) + parser.add_argument( + "--default-config", + action="store_true", + help="Adds a dispatch without a pre-defined " + "tuning configuration. This dispatch will use " + "default configuration from KernelsConfig.cpp.", + ) def add_compilation_arguments(parser): - """Adds compilation (not part of iree-compile) command line arguments to the parser.""" - compilation_parser = parser.add_argument_group( - 'Compilation', 'Compilation related options.') - compilation_parser.add_argument("--num-cpu", "-j", \ - dest="num_cpu", type=int, default=-1, \ - help="Number of cpu threads to use for compilation.") - compilation_parser.add_argument("--force-compile", action='store_true', \ - help="Force re-compilation of the operation even "\ - "if .vmfb file is present.") + """Adds compilation (not part of iree-compile) command line arguments to the parser.""" + compilation_parser = parser.add_argument_group( + "Compilation", "Compilation related options." + ) + compilation_parser.add_argument( + "--num-cpu", + "-j", + dest="num_cpu", + type=int, + default=-1, + help="Number of cpu threads to use for compilation.", + ) + compilation_parser.add_argument( + "--force-compile", + action="store_true", + help="Force re-compilation of the operation even " "if .vmfb file is present.", + ) def add_iree_compile_arguments(parser): - """Adds iree-compile command line arguments to the parser.""" - iree_compile_parser = parser.add_argument_group( - 'iree-compile', 'iree-compile related options.') - - iree_compile_parser.add_argument( - "--iree-hal-target-backends", "--device", \ - dest="device", default="cuda", \ - help="Target backends for executable compilation. ", \ - choices=["cuda", "vulkan", "cpu"]) - iree_compile_parser.add_argument( - "--iree-hal-cuda-llvm-target-arch", "--cuda-arch", \ - dest="cuda_arch", default='sm_80', \ - help="Target architecture for the CUDA backend. ", \ - choices=["sm_50", "sm_60", "sm_75", "sm_80", "sm_86"]) - iree_compile_parser.add_argument( - '--iree-hal-benchmark-dispatch-repeat-count', '--batch-size', \ - dest="batch_size", default=100, - help="Number of times dispatch is launched in a loop to "\ - "amortize the launch overhead. This argument is used for "\ - "iree-compile and iree-benchamrk-module. The value used by "\ - "iree-compile and iree-benchamrk-module should be the same.") - iree_compile_parser.add_argument( - '--iree-flow-split-matmul-reduction', '--split-k-slices', \ - dest="split_k_slices", default="", \ - help="Number of slices to split the reduction K-dimension.") - iree_compile_parser.add_argument( - '--iree-codegen-llvmgpu-use-mma-sync', '--use-mma-sync', \ - dest="use_mma_sync", action='store_true', \ - help="Use mma.sync instructions.") - iree_compile_parser.add_argument('--iree-codegen-llvmgpu-use-wmma', '--use-wmma', \ - dest="use_wmma", action='store_true', \ - help="Use wmma instructions.") - iree_compile_parser.add_argument('--mlir-print-ir-after-all', '--print-ir-after-all', \ - dest="mlir_print_ir_after_all", action='store_true', \ - help="Prints IR after all transformations and dumps a "\ - "file print_ir_after_*.mlir file.") + """Adds iree-compile command line arguments to the parser.""" + iree_compile_parser = parser.add_argument_group( + "iree-compile", "iree-compile related options." + ) + + iree_compile_parser.add_argument( + "--iree-hal-target-backends", + "--device", + dest="device", + default="cuda", + help="Target backends for executable compilation. ", + choices=["cuda", "vulkan", "cpu"], + ) + iree_compile_parser.add_argument( + "--iree-hal-cuda-llvm-target-arch", + "--cuda-arch", + dest="cuda_arch", + default="sm_80", + help="Target architecture for the CUDA backend. ", + choices=["sm_50", "sm_60", "sm_75", "sm_80", "sm_86"], + ) + iree_compile_parser.add_argument( + "--iree-hal-benchmark-dispatch-repeat-count", + "--batch-size", + dest="batch_size", + default=100, + help="Number of times dispatch is launched in a loop to " + "amortize the launch overhead. This argument is used for " + "iree-compile and iree-benchamrk-module. The value used by " + "iree-compile and iree-benchamrk-module should be the same.", + ) + iree_compile_parser.add_argument( + "--iree-flow-split-matmul-reduction", + "--split-k-slices", + dest="split_k_slices", + default="", + help="Number of slices to split the reduction K-dimension.", + ) + iree_compile_parser.add_argument( + "--iree-codegen-llvmgpu-use-mma-sync", + "--use-mma-sync", + dest="use_mma_sync", + action="store_true", + help="Use mma.sync instructions.", + ) + iree_compile_parser.add_argument( + "--iree-codegen-llvmgpu-use-wmma", + "--use-wmma", + dest="use_wmma", + action="store_true", + help="Use wmma instructions.", + ) + iree_compile_parser.add_argument( + "--mlir-print-ir-after-all", + "--print-ir-after-all", + dest="mlir_print_ir_after_all", + action="store_true", + help="Prints IR after all transformations and dumps a " + "file print_ir_after_*.mlir file.", + ) def add_verification_arguments(parser): - """Adds verification related arguments to the parser.""" - verification_parser = parser.add_argument_group( - 'Verification', 'Verification related options.') - - verification_parser.add_argument( - "--verification-enabled", default='True', \ - type=str, help="Verify the operation.") - verification_parser.add_argument( - "--verification-providers", default='numpy', \ - choices=["numpy"], - help="Comma delimited list of verification providers.") + """Adds verification related arguments to the parser.""" + verification_parser = parser.add_argument_group( + "Verification", "Verification related options." + ) + + verification_parser.add_argument( + "--verification-enabled", default="True", type=str, help="Verify the operation." + ) + verification_parser.add_argument( + "--verification-providers", + default="numpy", + choices=["numpy"], + help="Comma delimited list of verification providers.", + ) def add_profiling_arguments(parser): - """Adds profiling related arguments to the parser.""" - profiling_parser = parser.add_argument_group( - 'Profiling', 'Profiling (iree-benchmark-module) related options.') - - profiling_parser.add_argument( - "--profiling-enabled", "--benchmark", default='True', \ - type=str, help="Benchmark the operation.") - profiling_parser.add_argument( - "--benchmark-repetitions", default=5, - type=int, help="Number of times benchmark is repeated "\ - "and min, max, median, and average runtimes/gflops are "\ - "reported.") + """Adds profiling related arguments to the parser.""" + profiling_parser = parser.add_argument_group( + "Profiling", "Profiling (iree-benchmark-module) related options." + ) + + profiling_parser.add_argument( + "--profiling-enabled", + "--benchmark", + default="True", + type=str, + help="Benchmark the operation.", + ) + profiling_parser.add_argument( + "--benchmark-repetitions", + default=5, + type=int, + help="Number of times benchmark is repeated " + "and min, max, median, and average runtimes/gflops are " + "reported.", + ) def add_performance_report_arguments(parser): - """Adds performance report related arguments to the parser.""" - - performance_report_parser = parser.add_argument_group( - 'Performance Report', 'Performance report related options.') - - performance_report_parser.add_argument("--output", default='', \ - help="Path to output file for csv readable results.") - performance_report_parser.add_argument("--append", action='store_true', \ - help="Appends the results to existing file. "\ - "o.w., the existing file is overwritten.") - performance_report_parser.add_argument("--tags", default='', \ - help="Inserts leading columns in output table "\ - "and uniform values for each column. Useful for "\ - "generating pivot tables.") + """Adds performance report related arguments to the parser.""" + + performance_report_parser = parser.add_argument_group( + "Performance Report", "Performance report related options." + ) + + performance_report_parser.add_argument( + "--output", default="", help="Path to output file for csv readable results." + ) + performance_report_parser.add_argument( + "--append", + action="store_true", + help="Appends the results to existing file. " + "o.w., the existing file is overwritten.", + ) + performance_report_parser.add_argument( + "--tags", + default="", + help="Inserts leading columns in output table " + "and uniform values for each column. Useful for " + "generating pivot tables.", + ) def add_matmul_arguments(parser): - """Adds matmul related arguments to the parser.""" - - matmul_parser = parser.add_argument_group( - 'Matmul', 'Matrix-multiplication related options.') - matmul_parser.add_argument("--problem-m", default='', \ - help="M dimension of the matrix. "\ - "--problem-m==,*") - matmul_parser.add_argument("--problem-n", default='', \ - help="N dimension of the matrix."\ - "--problem-n==,*") - matmul_parser.add_argument("--problem-k", default='', \ - help="K dimension of the matrix."\ - "--problem-k==,*") + """Adds matmul related arguments to the parser.""" + + matmul_parser = parser.add_argument_group( + "Matmul", "Matrix-multiplication related options." + ) + matmul_parser.add_argument( + "--problem-m", + default="", + help="M dimension of the matrix. " + "--problem-m==,*", + ) + matmul_parser.add_argument( + "--problem-n", + default="", + help="N dimension of the matrix." + "--problem-n==,*", + ) + matmul_parser.add_argument( + "--problem-k", + default="", + help="K dimension of the matrix." + "--problem-k==,*", + ) ############################################################################### @@ -172,73 +258,76 @@ def add_matmul_arguments(parser): def parse_generator_arguments(parser): - """Adds and parse all the arguments for the *generator.py* script.""" - add_typical_arguments(parser) - add_matmul_arguments(parser) - add_iree_compile_arguments(parser) - args = parser.parse_args() - return args + """Adds and parse all the arguments for the *generator.py* script.""" + add_typical_arguments(parser) + add_matmul_arguments(parser) + add_iree_compile_arguments(parser) + args = parser.parse_args() + return args def parse_compile_arguments(parser): - """Adds and parse all the arguments for the *compile.py* script.""" - add_typical_arguments(parser) - add_compilation_arguments(parser) - add_iree_compile_arguments(parser) - args = parser.parse_args() - return args + """Adds and parse all the arguments for the *compile.py* script.""" + add_typical_arguments(parser) + add_compilation_arguments(parser) + add_iree_compile_arguments(parser) + args = parser.parse_args() + return args def parse_profiler_arguments(parser): - """Adds and parse all the arguments for the *profiler.py* script.""" - add_typical_arguments(parser) - add_compilation_arguments(parser) - add_iree_compile_arguments(parser) - add_verification_arguments(parser) - add_profiling_arguments(parser) - add_performance_report_arguments(parser) - - # Additional arguments for the profiler. - parser.add_argument("--save-cmds", action='store_true', \ - help='Saves commands and their output that are executed '\ - 'by the profiler in a file.') - - args = parser.parse_args() - - # Boolenize the string arguments from command line. For these args, it makes easier - # to read and convey the meaning. The boolean arguments below are specified as: - # `--argument=` - args.verification_enabled = False if args.verification_enabled in [ - 'False', 'false', '0' - ] else True - - args.profiling_enabled = False if args.profiling_enabled in [ - 'False', 'false', '0' - ] else True - - return args + """Adds and parse all the arguments for the *profiler.py* script.""" + add_typical_arguments(parser) + add_compilation_arguments(parser) + add_iree_compile_arguments(parser) + add_verification_arguments(parser) + add_profiling_arguments(parser) + add_performance_report_arguments(parser) + + # Additional arguments for the profiler. + parser.add_argument( + "--save-cmds", + action="store_true", + help="Saves commands and their output that are executed " + "by the profiler in a file.", + ) + + args = parser.parse_args() + + # Boolenize the string arguments from command line. For these args, it makes easier + # to read and convey the meaning. The boolean arguments below are specified as: + # `--argument=` + args.verification_enabled = ( + False if args.verification_enabled in ["False", "false", "0"] else True + ) + + args.profiling_enabled = ( + False if args.profiling_enabled in ["False", "false", "0"] else True + ) + + return args ############################################################################### # Helper functions for parsing command line arguments. ############################################################################### def get_cmd_line_argument_ranges(arg): - """Returns a list of values generated by range of the form start:end:increment.""" - if not arg: - return [] - if ':' not in arg: - return [int(arg)] - range_elements = arg.split(':') - start = int(range_elements[0]) - end = int(range_elements[1]) - increment = int(range_elements[2]) if len(range_elements) == 3 else 1 - return range(start, end, increment) + """Returns a list of values generated by range of the form start:end:increment.""" + if not arg: + return [] + if ":" not in arg: + return [int(arg)] + range_elements = arg.split(":") + start = int(range_elements[0]) + end = int(range_elements[1]) + increment = int(range_elements[2]) if len(range_elements) == 3 else 1 + return range(start, end, increment) def get_cmd_line_argument_list(arg): - """Returns a list of values generated by comma delimited string.""" - values = arg.split(',') - range_list = [] - for val in values: - range_list += get_cmd_line_argument_ranges(val) - return range_list + """Returns a list of values generated by comma delimited string.""" + values = arg.split(",") + range_list = [] + for val in values: + range_list += get_cmd_line_argument_ranges(val) + return range_list diff --git a/experimental/dispatch_profiler/performance_report.py b/experimental/dispatch_profiler/performance_report.py index 37ebb2219edd..e924baf3d7a1 100644 --- a/experimental/dispatch_profiler/performance_report.py +++ b/experimental/dispatch_profiler/performance_report.py @@ -11,139 +11,148 @@ class PerformanceResult: - """Performance result of a single run.""" - - def __init__(self, operation, configuration, verification_result, runtime): - self.operation = operation - self.configuration = configuration - self.verification_result = verification_result - self.runtime = runtime # in milliseconds - self.gflops = float(self.operation.flops()) / self.runtime / 1.0e6 - - def print(self): - """Prints the performance result to the console.""" - runtime = (str(self.runtime) if self.runtime != -1.0 else 'Not profiled') - gflops = (str(float(round(self.gflops, 2))) - if self.runtime != -1.0 else 'Not profiled') - - print('---------------------------------------------------------------- ') - print( - f'Dispatch : {"_".join([self.operation.name(), self.configuration.name()])}' - ) - print(f'Provider : IREE Codegen') - print(f'OpKind : {self.operation.operation_kind}') - print(f'Operation : {self.operation.name()}') - print(f'Configuration : {self.configuration.name()}') - # Operation specific arguments. - arg_str = ' '.join([ - f'--{key}={value}' - for key, value in self.operation.get_argument_dict().items() - ]) - wrapped_arg_str = textwrap.fill(arg_str, - width=80, - subsequent_indent=' ') - print(f'Arguments : {wrapped_arg_str}') - print(f'Verification : {self.verification_result}') - print(f'Runtime(ms) : {runtime}') - print(f'GFLOPs : {gflops}') - - def get_dict_entry(self): - """Returns a dictionary with the performance result.""" - runtime = self.runtime if self.runtime != -1.0 else '' - gflops = (float(round(self.gflops, 2)) - if self.runtime != -1.0 else 'Not run') - dict_entry = { - 'Provider': 'IREE Codegen', - 'Verification': self.verification_result, - 'Runtime(ms)': runtime, - 'GFLOPs': gflops, - } - - # Add the operation specific arguments. - dict_entry.update(self.operation.get_dict_entry()) - - # Add the configuration specific arguments. - dict_entry.update(self.configuration.get_dict_entry()) - - return dict_entry + """Performance result of a single run.""" + + def __init__(self, operation, configuration, verification_result, runtime): + self.operation = operation + self.configuration = configuration + self.verification_result = verification_result + self.runtime = runtime # in milliseconds + self.gflops = float(self.operation.flops()) / self.runtime / 1.0e6 + + def print(self): + """Prints the performance result to the console.""" + runtime = str(self.runtime) if self.runtime != -1.0 else "Not profiled" + gflops = ( + str(float(round(self.gflops, 2))) + if self.runtime != -1.0 + else "Not profiled" + ) + + print("---------------------------------------------------------------- ") + print( + f'Dispatch : {"_".join([self.operation.name(), self.configuration.name()])}' + ) + print(f"Provider : IREE Codegen") + print(f"OpKind : {self.operation.operation_kind}") + print(f"Operation : {self.operation.name()}") + print(f"Configuration : {self.configuration.name()}") + # Operation specific arguments. + arg_str = " ".join( + [ + f"--{key}={value}" + for key, value in self.operation.get_argument_dict().items() + ] + ) + wrapped_arg_str = textwrap.fill( + arg_str, width=80, subsequent_indent=" " + ) + print(f"Arguments : {wrapped_arg_str}") + print(f"Verification : {self.verification_result}") + print(f"Runtime(ms) : {runtime}") + print(f"GFLOPs : {gflops}") + + def get_dict_entry(self): + """Returns a dictionary with the performance result.""" + runtime = self.runtime if self.runtime != -1.0 else "" + gflops = float(round(self.gflops, 2)) if self.runtime != -1.0 else "Not run" + dict_entry = { + "Provider": "IREE Codegen", + "Verification": self.verification_result, + "Runtime(ms)": runtime, + "GFLOPs": gflops, + } + + # Add the operation specific arguments. + dict_entry.update(self.operation.get_dict_entry()) + + # Add the configuration specific arguments. + dict_entry.update(self.configuration.get_dict_entry()) + + return dict_entry class PerformanceReport: - """Performance report class is used to store the performance results of multiple runs. - The report can be written to a csv file.""" - - def __init__(self, args): - self.args = args - - # Data members extracted from the args. - self.output_file_path = None - if args.output != '': - self.output_file_path = Path(args.output) - - # List of PerformanceResult. - self.perf_result_vector = [] - - # Additional tags to add to the csv report file. \ - # Useful for generating pivot tables. - self.tags = [] - if args.tags != '': - self.tags = args.tags.split(',') - - # Boolen to check if the header is written to the csv file. - self.is_header_written = False - - # If the args.output set, open the file and write the header. - self.open_mode = 'a' if self.args.append else 'w' - if self.output_file_path: - self.csv_file = open(self.output_file_path, self.open_mode) - - def __del__(self): - """If the args.output set, close the file.""" - if self.output_file_path: - print('Writing performance report to %s' % self.output_file_path) - self.csv_file.close() - - def write_csv_header(self, operation, configuration): - """Write the header to the csv file.""" - - # Create and write the header. - operation_specific_header = list(operation.get_dict_entry().keys()) - configuration_specific_header = list(configuration.get_dict_entry().keys()) - performance_header = ['Verification', 'Runtime(ms)', 'GFLOPs'] - csv_header = operation_specific_header + configuration_specific_header + performance_header - csv_header = ['Provider'] + csv_header - - # If tags are present, add the tags.keys() to the begining of the csv header. - if len(self.tags): - tag_header = [tag.split(':')[0] for tag in self.tags] - csv_header = tag_header + csv_header - - # Create the csv dictionary writer. - self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=csv_header) - - # Write the header if the file is being created. - if self.open_mode == 'w': - self.csv_writer.writeheader() - - def append_perf_result(self, performance_result): - """Appends a performance result to the report. - Additionaly, if args.output set, write the csv_row entry.""" - self.perf_result_vector.append(performance_result) - - if self.output_file_path: - # Write the header if not written. - if not self.is_header_written: - self.write_csv_header(performance_result.operation, - performance_result.configuration) - self.is_header_written = True - - # Create the row entries for performance result. - csv_dict_row = performance_result.get_dict_entry() - - # Create the row entries for tags. - for tag in self.tags: - tag_key, tag_value = tag.split(':') - csv_dict_row[tag_key] = tag_value - - # Write the row. - self.csv_writer.writerow(csv_dict_row) + """Performance report class is used to store the performance results of multiple runs. + The report can be written to a csv file.""" + + def __init__(self, args): + self.args = args + + # Data members extracted from the args. + self.output_file_path = None + if args.output != "": + self.output_file_path = Path(args.output) + + # List of PerformanceResult. + self.perf_result_vector = [] + + # Additional tags to add to the csv report file. \ + # Useful for generating pivot tables. + self.tags = [] + if args.tags != "": + self.tags = args.tags.split(",") + + # Boolen to check if the header is written to the csv file. + self.is_header_written = False + + # If the args.output set, open the file and write the header. + self.open_mode = "a" if self.args.append else "w" + if self.output_file_path: + self.csv_file = open(self.output_file_path, self.open_mode) + + def __del__(self): + """If the args.output set, close the file.""" + if self.output_file_path: + print("Writing performance report to %s" % self.output_file_path) + self.csv_file.close() + + def write_csv_header(self, operation, configuration): + """Write the header to the csv file.""" + + # Create and write the header. + operation_specific_header = list(operation.get_dict_entry().keys()) + configuration_specific_header = list(configuration.get_dict_entry().keys()) + performance_header = ["Verification", "Runtime(ms)", "GFLOPs"] + csv_header = ( + operation_specific_header + + configuration_specific_header + + performance_header + ) + csv_header = ["Provider"] + csv_header + + # If tags are present, add the tags.keys() to the begining of the csv header. + if len(self.tags): + tag_header = [tag.split(":")[0] for tag in self.tags] + csv_header = tag_header + csv_header + + # Create the csv dictionary writer. + self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=csv_header) + + # Write the header if the file is being created. + if self.open_mode == "w": + self.csv_writer.writeheader() + + def append_perf_result(self, performance_result): + """Appends a performance result to the report. + Additionaly, if args.output set, write the csv_row entry.""" + self.perf_result_vector.append(performance_result) + + if self.output_file_path: + # Write the header if not written. + if not self.is_header_written: + self.write_csv_header( + performance_result.operation, performance_result.configuration + ) + self.is_header_written = True + + # Create the row entries for performance result. + csv_dict_row = performance_result.get_dict_entry() + + # Create the row entries for tags. + for tag in self.tags: + tag_key, tag_value = tag.split(":") + csv_dict_row[tag_key] = tag_value + + # Write the row. + self.csv_writer.writerow(csv_dict_row) diff --git a/experimental/dispatch_profiler/profiler.py b/experimental/dispatch_profiler/profiler.py index 2fba3f93967c..f9e460f53fae 100644 --- a/experimental/dispatch_profiler/profiler.py +++ b/experimental/dispatch_profiler/profiler.py @@ -36,60 +36,63 @@ ############################################################################### if __name__ == "__main__": - ############################################################################### - # Parse command line arguments - ############################################################################### - parser = argparse.ArgumentParser(description="IREE Python profiler tool for "\ - "verifcation and performance profiling tool for IREE-compiled "\ - "MLIR operations.") - - args = parse_profiler_arguments(parser) - ############################################################################### - - # Create manifest object and load dispatches. - manifest = Manifest(args) - manifest.load() - - # Performance report - perf_report = PerformanceReport(args) - - # For all the operations in the manifest compile (if needed), verify, and profile. - for _, dispatch_collection_list in manifest.dispatch_collection_map.items(): - for dispatch_collection in dispatch_collection_list: - - operation = dispatch_collection.operation - # Select and create an instance of operation_launcher for the operation. - operation_launcher = IreeToolsLauncher(args, operation) - for configuration in dispatch_collection.configuration_list: - - # Create a dispatch object. - dispatch = Dispatch(operation, configuration) - - # Skip the dispatch if filter returns false. - if not manifest.is_enabled(dispatch): - continue - - # If dry run is enabled, skip the dispatch. - if args.dry_run: - print(f'[Dry run] : {dispatch.name()}') - continue - - # Initialize verification and profiling results. - verification_result = 'Not verified' if not args.verification_enabled else 'Failed' - runtime = -1.0 - - # Launch the operation dispatches for verification and profiling. - if args.verification_enabled: - verification_result = operation_launcher.verify(configuration) - if args.profiling_enabled: - runtime = operation_launcher.profile(configuration) - - # Create performance result. - result = PerformanceResult(operation, configuration, - verification_result, runtime) - - # Print the performance result. - result.print() - - # Append the performance result to the performance report. - perf_report.append_perf_result(result) + ############################################################################### + # Parse command line arguments + ############################################################################### + parser = argparse.ArgumentParser( + description="IREE Python profiler tool for " + "verifcation and performance profiling tool for IREE-compiled " + "MLIR operations." + ) + + args = parse_profiler_arguments(parser) + ############################################################################### + + # Create manifest object and load dispatches. + manifest = Manifest(args) + manifest.load() + + # Performance report + perf_report = PerformanceReport(args) + + # For all the operations in the manifest compile (if needed), verify, and profile. + for _, dispatch_collection_list in manifest.dispatch_collection_map.items(): + for dispatch_collection in dispatch_collection_list: + operation = dispatch_collection.operation + # Select and create an instance of operation_launcher for the operation. + operation_launcher = IreeToolsLauncher(args, operation) + for configuration in dispatch_collection.configuration_list: + # Create a dispatch object. + dispatch = Dispatch(operation, configuration) + + # Skip the dispatch if filter returns false. + if not manifest.is_enabled(dispatch): + continue + + # If dry run is enabled, skip the dispatch. + if args.dry_run: + print(f"[Dry run] : {dispatch.name()}") + continue + + # Initialize verification and profiling results. + verification_result = ( + "Not verified" if not args.verification_enabled else "Failed" + ) + runtime = -1.0 + + # Launch the operation dispatches for verification and profiling. + if args.verification_enabled: + verification_result = operation_launcher.verify(configuration) + if args.profiling_enabled: + runtime = operation_launcher.profile(configuration) + + # Create performance result. + result = PerformanceResult( + operation, configuration, verification_result, runtime + ) + + # Print the performance result. + result.print() + + # Append the performance result to the performance report. + perf_report.append_perf_result(result) diff --git a/experimental/dispatch_profiler/split_k_matmul.py b/experimental/dispatch_profiler/split_k_matmul.py index fc68512b82e6..02f462e2a2d2 100644 --- a/experimental/dispatch_profiler/split_k_matmul.py +++ b/experimental/dispatch_profiler/split_k_matmul.py @@ -10,73 +10,83 @@ class CudaSplitKMatmulGenerator(CudaMatmulGenerator): - """SplitK Matmul dispatch generator class.""" + """SplitK Matmul dispatch generator class.""" - def __init__(self, args): - """Initializes the splitK matmul generator.""" - super().__init__(args) + def __init__(self, args): + """Initializes the splitK matmul generator.""" + super().__init__(args) - # Predefined matmul shapes for splitK matmul. - self.matmul_shapes = [[128, 128, 12288]] + # Predefined matmul shapes for splitK matmul. + self.matmul_shapes = [[128, 128, 12288]] - # Predefined split_k_slices list for splitK matmul. - self.split_k_slices = [2, 4, 16, 18] + # Predefined split_k_slices list for splitK matmul. + self.split_k_slices = [2, 4, 16, 18] - # SplitK matmul dispatches collection list. - self.dispatches_collection_list = [] + # SplitK matmul dispatches collection list. + self.dispatches_collection_list = [] - def _append_matmul_dispatch_collection(self, matmul_shapes, split_k_slices, - data_type, configuration_list): - """Appends the split-k matmul dispatches collection with the given configuration list.""" + def _append_matmul_dispatch_collection( + self, matmul_shapes, split_k_slices, data_type, configuration_list + ): + """Appends the split-k matmul dispatches collection with the given configuration list.""" - # Create dispatches collection for each matmul_shape x split_k_slice x configuration list. - for matmul_shape in matmul_shapes: - for split_k_slice in split_k_slices: - operation = MatmulOperation( - matmul_shape,\ - TensorDescription(data_type[0], LayoutType.RowMajor), \ - TensorDescription(data_type[1], LayoutType.RowMajor), \ - TensorDescription(data_type[2], LayoutType.RowMajor), \ - 1, # batch_count - split_k_slice, - OperationKind.SplitkMatmul) + # Create dispatches collection for each matmul_shape x split_k_slice x configuration list. + for matmul_shape in matmul_shapes: + for split_k_slice in split_k_slices: + operation = MatmulOperation( + matmul_shape, + TensorDescription(data_type[0], LayoutType.RowMajor), + TensorDescription(data_type[1], LayoutType.RowMajor), + TensorDescription(data_type[2], LayoutType.RowMajor), + 1, # batch_count + split_k_slice, + OperationKind.SplitkMatmul, + ) - # Filter out configurations that are not supported by LLVM GPU CUDA backend. - supported_configuration_list = self._cuda_supported_configuration_list( - operation, configuration_list) + # Filter out configurations that are not supported by LLVM GPU CUDA backend. + supported_configuration_list = self._cuda_supported_configuration_list( + operation, configuration_list + ) - # Add default configuration if enabled. - if self.args.default_config: - supported_configuration_list.append( - MatmulCompilationInfo([], [], OperationKind.Matmul, - CompilationConfigType.Default)) + # Add default configuration if enabled. + if self.args.default_config: + supported_configuration_list.append( + MatmulCompilationInfo( + [], [], OperationKind.Matmul, CompilationConfigType.Default + ) + ) - # Append the dispatch collection. - self.dispatches_collection_list.append(DispatchCollection(\ - operation, supported_configuration_list)) + # Append the dispatch collection. + self.dispatches_collection_list.append( + DispatchCollection(operation, supported_configuration_list) + ) - def _cuda_matmul_tensor_cores_f16(self): - """Appends a list of matmul split-k dispatches for GPU TensorCore F16 data type.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f16, self.translation_infos, - OperationKind.SplitkMatmul) - data_type = [DataType.f16, DataType.f16, DataType.f16] - self._append_matmul_dispatch_collection(self.matmul_shapes, - self.split_k_slices, data_type, - configuration_list) + def _cuda_matmul_tensor_cores_f16(self): + """Appends a list of matmul split-k dispatches for GPU TensorCore F16 data type.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f16, + self.translation_infos, + OperationKind.SplitkMatmul, + ) + data_type = [DataType.f16, DataType.f16, DataType.f16] + self._append_matmul_dispatch_collection( + self.matmul_shapes, self.split_k_slices, data_type, configuration_list + ) - def _cuda_matmul_tensor_cores_f32(self): - """Appends a list of matmul split-k dispatches for GPU TensorCore F32 data type.""" - configuration_list = self._get_matmul_custom_compilation_info_list( - self.tile_descriptions_tensor_cores_f32, self.translation_infos, - OperationKind.SplitkMatmul) - data_type = [DataType.f32, DataType.f32, DataType.f32] - self._append_matmul_dispatch_collection(self.matmul_shapes, - self.split_k_slices, data_type, - configuration_list) + def _cuda_matmul_tensor_cores_f32(self): + """Appends a list of matmul split-k dispatches for GPU TensorCore F32 data type.""" + configuration_list = self._get_matmul_custom_compilation_info_list( + self.tile_descriptions_tensor_cores_f32, + self.translation_infos, + OperationKind.SplitkMatmul, + ) + data_type = [DataType.f32, DataType.f32, DataType.f32] + self._append_matmul_dispatch_collection( + self.matmul_shapes, self.split_k_slices, data_type, configuration_list + ) - def generate(self): - """Generates a list of split-k matmul operations.""" - self._cuda_matmul_tensor_cores_f16() - self._cuda_matmul_tensor_cores_f32() - return self.dispatches_collection_list + def generate(self): + """Generates a list of split-k matmul operations.""" + self._cuda_matmul_tensor_cores_f16() + self._cuda_matmul_tensor_cores_f32() + return self.dispatches_collection_list diff --git a/experimental/web/testing/parse_test_list.py b/experimental/web/testing/parse_test_list.py index 8ca2d0655a45..38979c40252f 100644 --- a/experimental/web/testing/parse_test_list.py +++ b/experimental/web/testing/parse_test_list.py @@ -22,117 +22,121 @@ def parse_arguments(): - """Parses command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--ctest_dump", - type=str, - required=True, - help="Path to the output of `ctest --show-only=json-v1`") - parser.add_argument( - "--build_dir", - type=str, - required=True, - help="Path to the CMake build directory (absolute or relative)") - parser.add_argument( - "--output_format", - type=str, - choices=("html", "json"), - default="html", - help= - "Output format, either 'html' for the test runner or 'json' for a list of JSON objects", - ) - parser.add_argument("-o", - "--output", - type=str, - required=True, - help="Output file path") - return parser.parse_args() + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--ctest_dump", + type=str, + required=True, + help="Path to the output of `ctest --show-only=json-v1`", + ) + parser.add_argument( + "--build_dir", + type=str, + required=True, + help="Path to the CMake build directory (absolute or relative)", + ) + parser.add_argument( + "--output_format", + type=str, + choices=("html", "json"), + default="html", + help="Output format, either 'html' for the test runner or 'json' for a list of JSON objects", + ) + parser.add_argument( + "-o", "--output", type=str, required=True, help="Output file path" + ) + return parser.parse_args() def get_normalized_relative_path(absolute_path, root_directory): - # Strip the root directory prefix and get a relative path. - relative_path = os.path.relpath(absolute_path, root_directory) - # Replace the path separator (such as '\' on Windows) with web-style '/'. - normalized_path = relative_path.replace(os.sep, '/') - return normalized_path + # Strip the root directory prefix and get a relative path. + relative_path = os.path.relpath(absolute_path, root_directory) + # Replace the path separator (such as '\' on Windows) with web-style '/'. + normalized_path = relative_path.replace(os.sep, "/") + return normalized_path def parse_ctest_dump(ctest_dump_path, build_dir): - parsed_tests = [] - - # Open the ctest dump JSON file and parse each test. - # https://cmake.org/cmake/help/latest/manual/ctest.1.html#show-as-json-object-model - with open(ctest_dump_path, "rt") as f: - data = json.load(f) - for test in data["tests"]: - parsed_test = { - "testName": test["name"], - "requiredFiles": [], - "args": [], - } - - # Parse the 'command' list into the source file and its arguments. - # /path/to/test_runner.js # such as iree-check-module.js or test.js - # arg 1 # such as --device=local-task - # arg 2 # such as check_vmvx_op.mlir_module.vmfb - test_source_absolute_path = test["command"][0] - parsed_test["sourceFile"] = get_normalized_relative_path( - test_source_absolute_path, build_dir) - - parsed_test["args"] = test["command"][1:] - - # Parse the test "properties". - # Note: required file paths are relative to the working directory. - for property in test["properties"]: - if property["name"] == "REQUIRED_FILES": - parsed_test["requiredFiles"] = property["value"] - elif property["name"] == "WORKING_DIRECTORY": - working_directory_absolute_path = property["value"] - parsed_test["workingDirectory"] = get_normalized_relative_path( - working_directory_absolute_path, build_dir) - - parsed_tests.append(parsed_test) - - print("Parsed {} tests from '{}'".format(len(parsed_tests), ctest_dump_path)) - return parsed_tests + parsed_tests = [] + + # Open the ctest dump JSON file and parse each test. + # https://cmake.org/cmake/help/latest/manual/ctest.1.html#show-as-json-object-model + with open(ctest_dump_path, "rt") as f: + data = json.load(f) + for test in data["tests"]: + parsed_test = { + "testName": test["name"], + "requiredFiles": [], + "args": [], + } + + # Parse the 'command' list into the source file and its arguments. + # /path/to/test_runner.js # such as iree-check-module.js or test.js + # arg 1 # such as --device=local-task + # arg 2 # such as check_vmvx_op.mlir_module.vmfb + test_source_absolute_path = test["command"][0] + parsed_test["sourceFile"] = get_normalized_relative_path( + test_source_absolute_path, build_dir + ) + + parsed_test["args"] = test["command"][1:] + + # Parse the test "properties". + # Note: required file paths are relative to the working directory. + for property in test["properties"]: + if property["name"] == "REQUIRED_FILES": + parsed_test["requiredFiles"] = property["value"] + elif property["name"] == "WORKING_DIRECTORY": + working_directory_absolute_path = property["value"] + parsed_test["workingDirectory"] = get_normalized_relative_path( + working_directory_absolute_path, build_dir + ) + + parsed_tests.append(parsed_test) + + print("Parsed {} tests from '{}'".format(len(parsed_tests), ctest_dump_path)) + return parsed_tests def print_parsed_tests(parsed_tests, output_path, output_format): - with open(output_path, "wt") as f: - if output_format == "html": - print("Outputting parsed tests as HTML to '" + output_path + "'") - for test in parsed_tests: - f.write( - "
  • {testName}
  • \n" - .format(testName=test["testName"], - sourceFile=test["sourceFile"], - workingDirectory=test["workingDirectory"], - requiredFiles="[" + ",".join(test["requiredFiles"]) + "]", - args="[" + ",".join(test["args"]) + "]")) - elif output_format == "json": - print("Outputting parsed tests as JSON to '" + output_path + "'") - f.write(json.dumps(parsed_tests, indent=2)) - else: - raise Exception("Unknown output format: '" + output_format + "'") + with open(output_path, "wt") as f: + if output_format == "html": + print("Outputting parsed tests as HTML to '" + output_path + "'") + for test in parsed_tests: + f.write( + '
  • {testName}
  • \n'.format( + testName=test["testName"], + sourceFile=test["sourceFile"], + workingDirectory=test["workingDirectory"], + requiredFiles="[" + ",".join(test["requiredFiles"]) + "]", + args="[" + ",".join(test["args"]) + "]", + ) + ) + elif output_format == "json": + print("Outputting parsed tests as JSON to '" + output_path + "'") + f.write(json.dumps(parsed_tests, indent=2)) + else: + raise Exception("Unknown output format: '" + output_format + "'") def main(args): - # Refine the provided build directory path to a normalized, absolute path. - build_dir = args.build_dir - if not os.path.isabs(build_dir): - build_dir = os.path.join(os.getcwd(), build_dir) - build_dir = os.path.normpath(build_dir) + # Refine the provided build directory path to a normalized, absolute path. + build_dir = args.build_dir + if not os.path.isabs(build_dir): + build_dir = os.path.join(os.getcwd(), build_dir) + build_dir = os.path.normpath(build_dir) - # Create the output directory as needed (relative paths are fine here). - output_dir = os.path.dirname(args.output) - if output_dir and not os.path.isdir(output_dir): - os.makedirs(output_dir) + # Create the output directory as needed (relative paths are fine here). + output_dir = os.path.dirname(args.output) + if output_dir and not os.path.isdir(output_dir): + os.makedirs(output_dir) - parsed_tests = parse_ctest_dump(args.ctest_dump, build_dir) - parsed_tests.sort(key=lambda test: test["testName"]) + parsed_tests = parse_ctest_dump(args.ctest_dump, build_dir) + parsed_tests.sort(key=lambda test: test["testName"]) - print_parsed_tests(parsed_tests, args.output, args.output_format) + print_parsed_tests(parsed_tests, args.output, args.output_format) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/integrations/tensorflow/lit.cfg.py b/integrations/tensorflow/lit.cfg.py index ae561057d835..1ced0b6b71c6 100644 --- a/integrations/tensorflow/lit.cfg.py +++ b/integrations/tensorflow/lit.cfg.py @@ -20,6 +20,8 @@ config.test_format = lit.formats.ShTest(execute_external=True) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py index 925a2e19f655..8fd45d44429e 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils.py @@ -9,8 +9,7 @@ import collections import os import tempfile -from typing import (Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, - Union) +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union import iree.compiler.tf import iree.runtime @@ -20,57 +19,58 @@ from iree.tf.support import tf_utils flags.DEFINE_bool( - "capture_crash_reproducer", True, + "capture_crash_reproducer", + True, "Captures MLIR crash reproducers in the artifacts directory for crashes " - "and suppresses C++ stack traces.") + "and suppresses C++ stack traces.", +) FLAGS = flags.FLAGS -def _get_tf_import_output_kwargs(artifacts_dir: str, - backend_id: str, - *, - needs_temp_saved_model_dir: bool = False): - """Gets output kwargs dict to pass to tf.compile() for output generation. - - When artifacts_dir is set, writes: - tf_input.mlir: - MLIR for the module in TF's input dialect. - iree_input.mlir: - The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE. - backend_id/compiled.vmfb: - A VM FlatBuffer compiled to the target backend from the IREE MLIR above. - `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash. - - Args: - artifacts_dir: The artifacts directory. - backend_id: The backend id (for artifacts that are backend dependent). - needs_temp_saved_model_dir: Whether a temporary 'saved_model_dir' directory - needs to be set. - - Returns: - A dict of output kwargs. - """ - kwargs = {} - backend_dir = os.path.join(artifacts_dir, backend_id) - os.makedirs(backend_dir, exist_ok=True) - kwargs["output_file"] = os.path.join(backend_dir, "compiled.vmfb") - if needs_temp_saved_model_dir: - kwargs["saved_model_dir"] = os.path.join(artifacts_dir, - "tfmodule.saved_model") - kwargs["save_temp_iree_input"] = os.path.join(artifacts_dir, - "iree_input.mlir") - - # Avoid the crash reproducer under tests or if the flag is false. - if (FLAGS.capture_crash_reproducer): - kwargs["crash_reproducer_path"] = os.path.join( - artifacts_dir, f"reproducer__{backend_id}.mlir") - else: - logging.info("Crash reproducer suppressed") - logging.info( - "Outputting intermediate artifacts (--artifacts_dir is set):\n%s", - "\n".join(f" {k}: {v}" for k, v in kwargs.items())) - return kwargs +def _get_tf_import_output_kwargs( + artifacts_dir: str, backend_id: str, *, needs_temp_saved_model_dir: bool = False +): + """Gets output kwargs dict to pass to tf.compile() for output generation. + + When artifacts_dir is set, writes: + tf_input.mlir: + MLIR for the module in TF's input dialect. + iree_input.mlir: + The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE. + backend_id/compiled.vmfb: + A VM FlatBuffer compiled to the target backend from the IREE MLIR above. + `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash. + + Args: + artifacts_dir: The artifacts directory. + backend_id: The backend id (for artifacts that are backend dependent). + needs_temp_saved_model_dir: Whether a temporary 'saved_model_dir' directory + needs to be set. + + Returns: + A dict of output kwargs. + """ + kwargs = {} + backend_dir = os.path.join(artifacts_dir, backend_id) + os.makedirs(backend_dir, exist_ok=True) + kwargs["output_file"] = os.path.join(backend_dir, "compiled.vmfb") + if needs_temp_saved_model_dir: + kwargs["saved_model_dir"] = os.path.join(artifacts_dir, "tfmodule.saved_model") + kwargs["save_temp_iree_input"] = os.path.join(artifacts_dir, "iree_input.mlir") + + # Avoid the crash reproducer under tests or if the flag is false. + if FLAGS.capture_crash_reproducer: + kwargs["crash_reproducer_path"] = os.path.join( + artifacts_dir, f"reproducer__{backend_id}.mlir" + ) + else: + logging.info("Crash reproducer suppressed") + logging.info( + "Outputting intermediate artifacts (--artifacts_dir is set):\n%s", + "\n".join(f" {k}: {v}" for k, v in kwargs.items()), + ) + return kwargs def _incrementally_compile_tf_module( @@ -79,591 +79,630 @@ def _incrementally_compile_tf_module( exported_names: Sequence[str] = (), artifacts_dir: Optional[str] = None, ) -> Tuple[bytes, Optional[str]]: - """Compile a TensorFlow tf.Module and optionally save compilation artifacts. - - The module blob this creates is not callable. See IreeCompiledModule for an - API that returns a module that can be called without any further steps. - - Args: - module: A tf.Module. - backend_info: BackendInfo with the details for compiling this module. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - - Returns: - A compiled IREE module blob and the path to the compiled VM FlatBuffer if - artifacts_dir is provided. - """ - output_kwargs = (_get_tf_import_output_kwargs( - artifacts_dir, - backend_info.backend_id, - needs_temp_saved_model_dir=True, - ) if artifacts_dir else {}) - - # TODO: Revisit how artifacts_dir is plummed through and figure out how to - # get a meaningful invocation name directly. This isn't really load - # bearing - just adds a bit of usability so long as we have multiple - # methods of saving temp files. - if artifacts_dir: - invocation_id = ( - f"{os.path.basename(artifacts_dir)}__{backend_info.backend_id}") - else: - invocation_id = None - with iree.compiler.TempFileSaver(invocation_id=invocation_id): - immediate_result = iree.compiler.tf.compile_module( - module, - target_backends=backend_info.compiler_targets, - exported_names=exported_names, - **output_kwargs) - - output_file = output_kwargs.get("output_file") - if output_file: - with open(output_file, "rb") as f: - immediate_result = f.read() - return immediate_result, output_file - - -def _incrementally_compile_tf_signature_def_saved_model( - saved_model_dir: str, saved_model_tags: Set[str], backend_info: BackendInfo, - exported_name: str, artifacts_dir: str): - """Compile a SignatureDef SavedModel and optionally save compilation artifacts. - - The module blob this creates is not callable. See IreeCompiledModule for an - API that returns a module that can be called without any further steps. - - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - backend_info: BackendInfo with the details for compiling the saved model. - exported_name: A str representing the signature on the saved model to - compile. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - - Returns: - A compiled IREE module blob and the path to the compiled VM FlatBuffer if - artifacts_dir is provided. - """ - output_kwargs = (_get_tf_import_output_kwargs( - artifacts_dir, backend_info.backend_id) if artifacts_dir else {}) - immediate_result = iree.compiler.tf.compile_saved_model( - saved_model_dir, - import_type="SIGNATURE_DEF", - target_backends=backend_info.compiler_targets, - exported_names=[exported_name], - saved_model_tags=saved_model_tags, - **output_kwargs) - - output_file = output_kwargs.get("output_file") - if output_file: - with open(output_file, "rb") as f: - immediate_result = f.read() - return immediate_result, output_file - - -class _FunctionWrapper(object): - - def __call__(self, *args, **kwargs): - raise NotImplementedError() - - def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]: - """Dummy function to match _IreeFunctionWrapper's API.""" - return ("",), ("",) - - -class CompiledModule(object): - """Base class for the TF and IREE compiled modules.""" - - def __init__( - self, - module_name: str, - backend_info: BackendInfo, - compiled_paths: Union[Dict[str, str], None], - ): - """Shared base constructor – not useful on its own. + """Compile a TensorFlow tf.Module and optionally save compilation artifacts. - Args: - module_name: A name for this compiled module. In most cases this will be - the name of the tf.Module subclass or instance that is compiled. - backend_info: BackendInfo with the details about compiling this module. - compiled_paths: A dictionary mapping compiled method names to file paths - corresponding to their serialized representations. - """ - self.module_name = module_name - self.backend_info = backend_info - self.compiled_paths = compiled_paths - - def reinitialize(self): - """Reinitializes all stateful variables.""" - raise NotImplementedError() - - @classmethod - def create_from_class(cls, - module_class: Type[tf.Module], - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module subclass to the target backend in backend_info. + The module blob this creates is not callable. See IreeCompiledModule for an + API that returns a module that can be called without any further steps. Args: - module_class: The tf.Module subclass to compile. + module: A tf.Module. backend_info: BackendInfo with the details for compiling this module. exported_names: Optional sequence representing the exported names to keep. artifacts_dir: An optional string pointing to where compilation artifacts should be saved. No compilation artifacts will be saved if this is not provided. + + Returns: + A compiled IREE module blob and the path to the compiled VM FlatBuffer if + artifacts_dir is provided. """ - raise NotImplementedError() + output_kwargs = ( + _get_tf_import_output_kwargs( + artifacts_dir, + backend_info.backend_id, + needs_temp_saved_model_dir=True, + ) + if artifacts_dir + else {} + ) + + # TODO: Revisit how artifacts_dir is plummed through and figure out how to + # get a meaningful invocation name directly. This isn't really load + # bearing - just adds a bit of usability so long as we have multiple + # methods of saving temp files. + if artifacts_dir: + invocation_id = f"{os.path.basename(artifacts_dir)}__{backend_info.backend_id}" + else: + invocation_id = None + with iree.compiler.TempFileSaver(invocation_id=invocation_id): + immediate_result = iree.compiler.tf.compile_module( + module, + target_backends=backend_info.compiler_targets, + exported_names=exported_names, + **output_kwargs, + ) + + output_file = output_kwargs.get("output_file") + if output_file: + with open(output_file, "rb") as f: + immediate_result = f.read() + return immediate_result, output_file - @classmethod - def create_from_instance(cls, - module_instance: tf.Module, - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module instance to the target backend in backend_info. - This is only implemented for IreeCompiledModule. +def _incrementally_compile_tf_signature_def_saved_model( + saved_model_dir: str, + saved_model_tags: Set[str], + backend_info: BackendInfo, + exported_name: str, + artifacts_dir: str, +): + """Compile a SignatureDef SavedModel and optionally save compilation artifacts. - Args: - module_instance: The tf.Module instance to compile. - backend_info: BackendInfo with the details for compiling module to IREE. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - raise NotImplementedError() - - @classmethod - def create_from_signature_def_saved_model( - cls, - saved_model_dir: str, - saved_model_tags: Set[str], - module_name: str, - backend_info: BackendInfo, - exported_name: str, - input_names: Sequence[str], - output_names: Sequence[str], - artifacts_dir: Optional[str] = None): - """Compile a SignatureDef SavedModel to the target backend in backend_info. + The module blob this creates is not callable. See IreeCompiledModule for an + API that returns a module that can be called without any further steps. Args: saved_model_dir: Directory of the saved model. saved_model_tags: Optional set of tags to use when loading the model. - module_name: A name for this compiled module. backend_info: BackendInfo with the details for compiling the saved model. exported_name: A str representing the signature on the saved model to compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. artifacts_dir: An optional string pointing to where compilation artifacts should be saved. No compilation artifacts will be saved if this is not provided. - """ - raise NotImplementedError() - - def __getattr__(self, attr: str) -> _FunctionWrapper: - raise NotImplementedError() - - def iree_serializable(self): - return False - def tflite_serializable(self): - return False - - -class _IreeFunctionWrapper(_FunctionWrapper): - """Wraps an IREE function, making it callable.""" - - def __init__(self, context: iree.runtime.SystemContext, f): - self._context = context - self._f = f - self._inputs = None - - def _get_function_inputs(self, args): - - def flatten(entries): - if entries is None: - return [] - if isinstance(entries, list) or isinstance(entries, tuple): - flattened = [] - for entry in entries: - flattened = flattened + flatten(entry) - return flattened - if isinstance(entries, dict): - flattened = [] - for entry in entries: - entry = entries[entry] - flattened = flattened + flatten(entry) - return flattened - return [entries] - - def convert(arr): - ty = [str(d) for d in arr.shape] - dty = str(arr.dtype) - dty = dty.replace("int", "i") - dty = dty.replace("float", "f") - dty = dty.replace("bool", "i1") - ty.append(dty) - ty = "x".join(ty) - arr = np.asarray(arr).flatten() - if arr.size > 0 and np.all(flatten == arr[0]): - value = arr[0] - else: - value = " ".join([str(a) for a in arr]) - return f"{ty}={value}" - - args = flatten(args) - return [convert(a) for a in args] - - def __call__(self, *args, **kwargs): - - self._inputs = self._get_function_inputs(args) - results = self._f(*args, **kwargs) - self._outputs = self._get_function_inputs(results) - return results - - def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]: - """Get cxx serialized inputs and outputs for this function.""" - return self._inputs, self._outputs - - -class IreeCompiledModule(CompiledModule): - """Iree compiled module.""" - - def __init__( - self, - module_name: str, - backend_info: BackendInfo, - compiled_paths: Dict[str, str], - vm_module: iree.runtime.VmModule, - config: iree.runtime.Config, - ): - """Base constructor – Use one of the named constructors instead. - - Args: - module_name: A name for this compiled module. In most cases this will be - the name of the tf.Module subclass or instance that is compiled. - backend_info: BackendInfo with the details about compiling this module. - compiled_paths: A dictionary mapping compiled method names to file paths - corresponding to their serialized representations. - vm_module: A iree.runtime.VmModule containing compilation info to wrap. - config: A iree.runtime.Config containing compilation info to wrap. + Returns: + A compiled IREE module blob and the path to the compiled VM FlatBuffer if + artifacts_dir is provided. """ - super().__init__(module_name, backend_info, compiled_paths) - self._vm_module = vm_module - self._config = config - self.reinitialize() - - @classmethod - def create_from_class(cls, - module_class: Type[tf.Module], - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module subclass to the target backend in backend_info. + output_kwargs = ( + _get_tf_import_output_kwargs(artifacts_dir, backend_info.backend_id) + if artifacts_dir + else {} + ) + immediate_result = iree.compiler.tf.compile_saved_model( + saved_model_dir, + import_type="SIGNATURE_DEF", + target_backends=backend_info.compiler_targets, + exported_names=[exported_name], + saved_model_tags=saved_model_tags, + **output_kwargs, + ) - Args: - module_class: The tf.Module subclass to compile. - backend_info: BackendInfo with the details for compiling module to IREE. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - tf_utils.set_random_seed() - module_instance = module_class() - return cls.create_from_instance(module_instance, backend_info, - exported_names, artifacts_dir) - - @classmethod - def create_from_instance(cls, - module_instance: tf.Module, - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module instance to the target backend in backend_info. + output_file = output_kwargs.get("output_file") + if output_file: + with open(output_file, "rb") as f: + immediate_result = f.read() + return immediate_result, output_file - Args: - module_instance: The tf.Module instance to compile. - backend_info: BackendInfo with the details for compiling module to IREE. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - module_blob, compiled_path = _incrementally_compile_tf_module( - module=module_instance, - backend_info=backend_info, - exported_names=exported_names, - artifacts_dir=artifacts_dir) - config = iree.runtime.Config(driver_name=backend_info.driver) - vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, - module_blob) - compiled_paths = None - if compiled_path is not None: - # IREE bundles every compiled method into the same compiled module. - compiled_paths = collections.defaultdict(lambda: compiled_path) - - module_name = type(module_instance).__name__ - - return cls(module_name, backend_info, compiled_paths, vm_module, config) - - @classmethod - def create_from_signature_def_saved_model( - cls, - saved_model_dir: str, - saved_model_tags: Set[str], - module_name: str, - backend_info: BackendInfo, - exported_name: str, - input_names: Sequence[str], - output_names: Sequence[str], - artifacts_dir: Optional[str] = None): - """Compile a SignatureDef SavedModel to the target backend in backend_info. +class _FunctionWrapper(object): + def __call__(self, *args, **kwargs): + raise NotImplementedError() - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - module_name: A name for this compiled module. - backend_info: BackendInfo with the details for compiling the saved model. - exported_name: A str representing the signature on the saved model to - compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - del input_names # Unused. - del output_names # Unused. - module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model( - saved_model_dir, saved_model_tags, backend_info, exported_name, - artifacts_dir) - config = iree.runtime.Config(driver_name=backend_info.driver) - vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, - module_blob) + def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]: + """Dummy function to match _IreeFunctionWrapper's API.""" + return ("",), ("",) - compiled_paths = None - if compiled_path is not None: - # IREE bundles every compiled method into the same compiled module :) - compiled_paths = collections.defaultdict(lambda: compiled_path) - return cls(module_name, backend_info, compiled_paths, vm_module, config) +class CompiledModule(object): + """Base class for the TF and IREE compiled modules.""" + + def __init__( + self, + module_name: str, + backend_info: BackendInfo, + compiled_paths: Union[Dict[str, str], None], + ): + """Shared base constructor – not useful on its own. + + Args: + module_name: A name for this compiled module. In most cases this will be + the name of the tf.Module subclass or instance that is compiled. + backend_info: BackendInfo with the details about compiling this module. + compiled_paths: A dictionary mapping compiled method names to file paths + corresponding to their serialized representations. + """ + self.module_name = module_name + self.backend_info = backend_info + self.compiled_paths = compiled_paths + + def reinitialize(self): + """Reinitializes all stateful variables.""" + raise NotImplementedError() + + @classmethod + def create_from_class( + cls, + module_class: Type[tf.Module], + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module subclass to the target backend in backend_info. + + Args: + module_class: The tf.Module subclass to compile. + backend_info: BackendInfo with the details for compiling this module. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + raise NotImplementedError() + + @classmethod + def create_from_instance( + cls, + module_instance: tf.Module, + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module instance to the target backend in backend_info. + + This is only implemented for IreeCompiledModule. + + Args: + module_instance: The tf.Module instance to compile. + backend_info: BackendInfo with the details for compiling module to IREE. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + raise NotImplementedError() + + @classmethod + def create_from_signature_def_saved_model( + cls, + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + backend_info: BackendInfo, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], + artifacts_dir: Optional[str] = None, + ): + """Compile a SignatureDef SavedModel to the target backend in backend_info. + + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + module_name: A name for this compiled module. + backend_info: BackendInfo with the details for compiling the saved model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + raise NotImplementedError() + + def __getattr__(self, attr: str) -> _FunctionWrapper: + raise NotImplementedError() + + def iree_serializable(self): + return False + + def tflite_serializable(self): + return False - def reinitialize(self): - """Reinitializes all stateful variables.""" - # set_random_seed is not needed here because the model_class.__init__ is not - # called. - self._context = iree.runtime.SystemContext(vm_modules=[self._vm_module], - config=self._config) - def __getattr__(self, attr: str) -> _IreeFunctionWrapper: - # Try to resolve it as a function. - m = self._context.modules[self._vm_module.name] - f = m[attr] - return _IreeFunctionWrapper(self._context, f) +class _IreeFunctionWrapper(_FunctionWrapper): + """Wraps an IREE function, making it callable.""" + + def __init__(self, context: iree.runtime.SystemContext, f): + self._context = context + self._f = f + self._inputs = None + + def _get_function_inputs(self, args): + def flatten(entries): + if entries is None: + return [] + if isinstance(entries, list) or isinstance(entries, tuple): + flattened = [] + for entry in entries: + flattened = flattened + flatten(entry) + return flattened + if isinstance(entries, dict): + flattened = [] + for entry in entries: + entry = entries[entry] + flattened = flattened + flatten(entry) + return flattened + return [entries] + + def convert(arr): + ty = [str(d) for d in arr.shape] + dty = str(arr.dtype) + dty = dty.replace("int", "i") + dty = dty.replace("float", "f") + dty = dty.replace("bool", "i1") + ty.append(dty) + ty = "x".join(ty) + arr = np.asarray(arr).flatten() + if arr.size > 0 and np.all(flatten == arr[0]): + value = arr[0] + else: + value = " ".join([str(a) for a in arr]) + return f"{ty}={value}" + + args = flatten(args) + return [convert(a) for a in args] + + def __call__(self, *args, **kwargs): + self._inputs = self._get_function_inputs(args) + results = self._f(*args, **kwargs) + self._outputs = self._get_function_inputs(results) + return results + + def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]: + """Get cxx serialized inputs and outputs for this function.""" + return self._inputs, self._outputs - def iree_serializable(self) -> bool: - return self.compiled_paths is not None + +class IreeCompiledModule(CompiledModule): + """Iree compiled module.""" + + def __init__( + self, + module_name: str, + backend_info: BackendInfo, + compiled_paths: Dict[str, str], + vm_module: iree.runtime.VmModule, + config: iree.runtime.Config, + ): + """Base constructor – Use one of the named constructors instead. + + Args: + module_name: A name for this compiled module. In most cases this will be + the name of the tf.Module subclass or instance that is compiled. + backend_info: BackendInfo with the details about compiling this module. + compiled_paths: A dictionary mapping compiled method names to file paths + corresponding to their serialized representations. + vm_module: A iree.runtime.VmModule containing compilation info to wrap. + config: A iree.runtime.Config containing compilation info to wrap. + """ + super().__init__(module_name, backend_info, compiled_paths) + self._vm_module = vm_module + self._config = config + self.reinitialize() + + @classmethod + def create_from_class( + cls, + module_class: Type[tf.Module], + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module subclass to the target backend in backend_info. + + Args: + module_class: The tf.Module subclass to compile. + backend_info: BackendInfo with the details for compiling module to IREE. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + tf_utils.set_random_seed() + module_instance = module_class() + return cls.create_from_instance( + module_instance, backend_info, exported_names, artifacts_dir + ) + + @classmethod + def create_from_instance( + cls, + module_instance: tf.Module, + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module instance to the target backend in backend_info. + + Args: + module_instance: The tf.Module instance to compile. + backend_info: BackendInfo with the details for compiling module to IREE. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + module_blob, compiled_path = _incrementally_compile_tf_module( + module=module_instance, + backend_info=backend_info, + exported_names=exported_names, + artifacts_dir=artifacts_dir, + ) + config = iree.runtime.Config(driver_name=backend_info.driver) + vm_module = iree.runtime.VmModule.from_flatbuffer( + config.vm_instance, module_blob + ) + + compiled_paths = None + if compiled_path is not None: + # IREE bundles every compiled method into the same compiled module. + compiled_paths = collections.defaultdict(lambda: compiled_path) + + module_name = type(module_instance).__name__ + + return cls(module_name, backend_info, compiled_paths, vm_module, config) + + @classmethod + def create_from_signature_def_saved_model( + cls, + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + backend_info: BackendInfo, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], + artifacts_dir: Optional[str] = None, + ): + """Compile a SignatureDef SavedModel to the target backend in backend_info. + + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + module_name: A name for this compiled module. + backend_info: BackendInfo with the details for compiling the saved model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + del input_names # Unused. + del output_names # Unused. + ( + module_blob, + compiled_path, + ) = _incrementally_compile_tf_signature_def_saved_model( + saved_model_dir, + saved_model_tags, + backend_info, + exported_name, + artifacts_dir, + ) + config = iree.runtime.Config(driver_name=backend_info.driver) + vm_module = iree.runtime.VmModule.from_flatbuffer( + config.vm_instance, module_blob + ) + + compiled_paths = None + if compiled_path is not None: + # IREE bundles every compiled method into the same compiled module :) + compiled_paths = collections.defaultdict(lambda: compiled_path) + + return cls(module_name, backend_info, compiled_paths, vm_module, config) + + def reinitialize(self): + """Reinitializes all stateful variables.""" + # set_random_seed is not needed here because the model_class.__init__ is not + # called. + self._context = iree.runtime.SystemContext( + vm_modules=[self._vm_module], config=self._config + ) + + def __getattr__(self, attr: str) -> _IreeFunctionWrapper: + # Try to resolve it as a function. + m = self._context.modules[self._vm_module.name] + f = m[attr] + return _IreeFunctionWrapper(self._context, f) + + def iree_serializable(self) -> bool: + return self.compiled_paths is not None class _TfFunctionWrapper(_FunctionWrapper): - """Wraps a TF function, normalizing it to numpy.""" + """Wraps a TF function, normalizing it to numpy.""" - def __init__(self, f: Callable[..., Any]): - self._f = f + def __init__(self, f: Callable[..., Any]): + self._f = f - def __call__(self, *args, **kwargs): - # TensorFlow will auto-convert all inbound args. - results = self._f(*args, **kwargs) - return tf_utils.convert_to_numpy(results) + def __call__(self, *args, **kwargs): + # TensorFlow will auto-convert all inbound args. + results = self._f(*args, **kwargs) + return tf_utils.convert_to_numpy(results) def _convert_inputs_to_tensors(function): + def decorator(*args, **kwargs): + args = [tf.convert_to_tensor(arg) for arg in args] + kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()} + return function(*args, **kwargs) - def decorator(*args, **kwargs): - args = [tf.convert_to_tensor(arg) for arg in args] - kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()} - return function(*args, **kwargs) - - return decorator + return decorator class SignatureDefSavedModelWrapper(object): - """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'.""" + """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'.""" - def __init__(self, saved_model_dir: str, saved_model_tags: Set[str], - exported_name: str): - self.saved_model = tf.saved_model.load(saved_model_dir, - tags=saved_model_tags) - inference_func = self.saved_model.signatures[exported_name] - inference_func = _convert_inputs_to_tensors(inference_func) - self.__setattr__(exported_name, inference_func) + def __init__( + self, saved_model_dir: str, saved_model_tags: Set[str], exported_name: str + ): + self.saved_model = tf.saved_model.load(saved_model_dir, tags=saved_model_tags) + inference_func = self.saved_model.signatures[exported_name] + inference_func = _convert_inputs_to_tensors(inference_func) + self.__setattr__(exported_name, inference_func) class TfCompiledModule(CompiledModule): - """TensorFlow 'compiled' module. - - This facade exists to provide a complimentary API to IreeCompiledModule and - normalize TensorFlow's output to Numpy. - """ - - def __init__( - self, - module_name: str, - backend_info: BackendInfo, - constructor: Callable[[], tf.Module], - exported_names: Sequence[str], - ): - """Base constructor – Use one of the named constructors instead. - - Args: - module_name: A name for this compiled module. In most cases this will be - the name of the tf.Module subclass or instance that is compiled. - backend_info: BackendInfo with the details about compiling this module. - constructor: A callable (class or function) which returns the tf.Module - subclass instance to wrap. - exported_names: an optional iterable of strings representing which of the - tf.Module subclass instance's functions should be callable. If - exported_names is empty then all functions will be callable. - """ - super().__init__(module_name, backend_info, compiled_paths=None) - self._constructor = constructor - self._exported_names = exported_names - self.reinitialize() - - @classmethod - def create_from_class(cls, - module_class: Type[tf.Module], - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module subclass to the target backend in backend_info. + """TensorFlow 'compiled' module. - Args: - module_class: The tf.Module subclass to compile. - backend_info: BackendInfo with the details for compiling this module. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. + This facade exists to provide a complimentary API to IreeCompiledModule and + normalize TensorFlow's output to Numpy. """ - module_name = module_class.__name__ - constructor = module_class - return cls(module_name, backend_info, constructor, exported_names) - - @classmethod - def create_from_signature_def_saved_model( - cls, - saved_model_dir: str, - saved_model_tags: Set[str], - module_name: str, - backend_info: BackendInfo, - exported_name: str, - input_names: Sequence[str], - output_names: Sequence[str], - artifacts_dir: Optional[str] = None): - """Compile a SignatureDef SavedModel to the target backend in backend_info. - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - module_name: A name for this compiled module. - backend_info: BackendInfo with the details for compiling the saved model. - exported_name: A str representing the signature on the saved model to - compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - constructor = lambda: SignatureDefSavedModelWrapper( - saved_model_dir, saved_model_tags, exported_name) - return cls(module_name, backend_info, constructor, [exported_name]) - - def reinitialize(self): - """Reinitializes all stateful variables.""" - tf_utils.set_random_seed() - self._tf_module = self._constructor() - - def __getattr__(self, attr: str) -> _TfFunctionWrapper: - # Try to resolve it as a function. - exported = not self._exported_names or attr in self._exported_names - if not hasattr(self._tf_module, attr) or not exported: - raise AttributeError(f"The TensorFlow module does not have attr '{attr}'") - f = getattr(self._tf_module, attr) - if not f or not hasattr(f, "__call__"): - raise AttributeError( - f"The TensorFlow module does not have a callable attr '{attr}'") - return _TfFunctionWrapper(f) + def __init__( + self, + module_name: str, + backend_info: BackendInfo, + constructor: Callable[[], tf.Module], + exported_names: Sequence[str], + ): + """Base constructor – Use one of the named constructors instead. + + Args: + module_name: A name for this compiled module. In most cases this will be + the name of the tf.Module subclass or instance that is compiled. + backend_info: BackendInfo with the details about compiling this module. + constructor: A callable (class or function) which returns the tf.Module + subclass instance to wrap. + exported_names: an optional iterable of strings representing which of the + tf.Module subclass instance's functions should be callable. If + exported_names is empty then all functions will be callable. + """ + super().__init__(module_name, backend_info, compiled_paths=None) + self._constructor = constructor + self._exported_names = exported_names + self.reinitialize() + + @classmethod + def create_from_class( + cls, + module_class: Type[tf.Module], + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module subclass to the target backend in backend_info. + + Args: + module_class: The tf.Module subclass to compile. + backend_info: BackendInfo with the details for compiling this module. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + module_name = module_class.__name__ + constructor = module_class + return cls(module_name, backend_info, constructor, exported_names) + + @classmethod + def create_from_signature_def_saved_model( + cls, + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + backend_info: BackendInfo, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], + artifacts_dir: Optional[str] = None, + ): + """Compile a SignatureDef SavedModel to the target backend in backend_info. + + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + module_name: A name for this compiled module. + backend_info: BackendInfo with the details for compiling the saved model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + constructor = lambda: SignatureDefSavedModelWrapper( + saved_model_dir, saved_model_tags, exported_name + ) + return cls(module_name, backend_info, constructor, [exported_name]) + + def reinitialize(self): + """Reinitializes all stateful variables.""" + tf_utils.set_random_seed() + self._tf_module = self._constructor() + + def __getattr__(self, attr: str) -> _TfFunctionWrapper: + # Try to resolve it as a function. + exported = not self._exported_names or attr in self._exported_names + if not hasattr(self._tf_module, attr) or not exported: + raise AttributeError(f"The TensorFlow module does not have attr '{attr}'") + f = getattr(self._tf_module, attr) + if not f or not hasattr(f, "__call__"): + raise AttributeError( + f"The TensorFlow module does not have a callable attr '{attr}'" + ) + return _TfFunctionWrapper(f) def _get_non_inhereted_function_names(cls): - """Gets all methods that cls has that its parents don't have.""" - names = set(dir(cls)) - for parent in cls.__bases__: - names -= set(dir(parent)) - return list(names) - - -def _get_concrete_functions(module_class: Type[tf.Module], - exported_names: Sequence[str] = ()): - """Get concrete functions from non-inherited methods or exported_names.""" - if not len(exported_names): - # Get all method names on 'module_class' that aren't on 'tf.Module'. - exported_names = _get_non_inhereted_function_names(module_class) - instance = module_class() - functions = [] - for name in exported_names: - functions.append(getattr(instance, name).get_concrete_function()) - return functions, exported_names, instance + """Gets all methods that cls has that its parents don't have.""" + names = set(dir(cls)) + for parent in cls.__bases__: + names -= set(dir(parent)) + return list(names) + + +def _get_concrete_functions( + module_class: Type[tf.Module], exported_names: Sequence[str] = () +): + """Get concrete functions from non-inherited methods or exported_names.""" + if not len(exported_names): + # Get all method names on 'module_class' that aren't on 'tf.Module'. + exported_names = _get_non_inhereted_function_names(module_class) + instance = module_class() + functions = [] + for name in exported_names: + functions.append(getattr(instance, name).get_concrete_function()) + return functions, exported_names, instance def tf_module_to_tflite_module_bytes( module_class: Type[tf.Module], exported_names: Sequence[str] = () ) -> Dict[str, bytes]: - """Compiles a tf.Module's methods with TFLite. - - Args: - module_class: A tf.Module subclass to compile with TFLite. - exported_names: an optional iterable of strings representing which of the - module_class's functions should be compiled. If exported_names is empty - then all functions will be compiled. - - Returns: - A dict mapping method names to compiled TFLite module bytes. - """ - tflite_modules = [] - methods, method_names, instance = _get_concrete_functions( - module_class, exported_names) - failed_methods = [] - for method, method_name in zip(methods, method_names): - logging.info("Attempting to convert '%s' to tflite...", method_name) - try: - converter = tf.lite.TFLiteConverter.from_concrete_functions([method], - module_class) - logging.info("...converted '%s' to tflite.", method_name) - tflite_modules.append(converter.convert()) - except Exception as e: - logging.error("Failed to convert '%s' to tflite.", method_name) - logging.error("TFLite excpetion: %s", e) - failed_methods.append(method_name) - - if failed_methods: - raise RuntimeError( - f"Failed to convert the following methods to tflite: {failed_methods}") - - # Keep variables alive until TFLite has done the conversion; ConcreteFunctions - # themselves only keep weak references to variables. - del instance - return dict(zip(method_names, tflite_modules)) + """Compiles a tf.Module's methods with TFLite. + + Args: + module_class: A tf.Module subclass to compile with TFLite. + exported_names: an optional iterable of strings representing which of the + module_class's functions should be compiled. If exported_names is empty + then all functions will be compiled. + + Returns: + A dict mapping method names to compiled TFLite module bytes. + """ + tflite_modules = [] + methods, method_names, instance = _get_concrete_functions( + module_class, exported_names + ) + failed_methods = [] + for method, method_name in zip(methods, method_names): + logging.info("Attempting to convert '%s' to tflite...", method_name) + try: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [method], module_class + ) + logging.info("...converted '%s' to tflite.", method_name) + tflite_modules.append(converter.convert()) + except Exception as e: + logging.error("Failed to convert '%s' to tflite.", method_name) + logging.error("TFLite excpetion: %s", e) + failed_methods.append(method_name) + + if failed_methods: + raise RuntimeError( + f"Failed to convert the following methods to tflite: {failed_methods}" + ) + + # Keep variables alive until TFLite has done the conversion; ConcreteFunctions + # themselves only keep weak references to variables. + del instance + return dict(zip(method_names, tflite_modules)) def tf_signature_def_saved_model_to_tflite_module_bytes( @@ -673,323 +712,349 @@ def tf_signature_def_saved_model_to_tflite_module_bytes( input_names: Sequence[str], output_names: Sequence[str], ) -> Dict[str, bytes]: - """Compiles a SignatureDef SavedModel signature with TFLite. - - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - exported_name: A str representing the signature on the saved model to - compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. - - Returns: - A dict mapping the signature name to the compiled TFLite module bytes. - """ - converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model( - saved_model_dir, - tag_set=saved_model_tags, - signature_key=exported_name, - input_arrays=input_names, - output_arrays=output_names) - tflite_module = converter.convert() - return dict([[exported_name, tflite_module]]) + """Compiles a SignatureDef SavedModel signature with TFLite. + + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + + Returns: + A dict mapping the signature name to the compiled TFLite module bytes. + """ + converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model( + saved_model_dir, + tag_set=saved_model_tags, + signature_key=exported_name, + input_arrays=input_names, + output_arrays=output_names, + ) + tflite_module = converter.convert() + return dict([[exported_name, tflite_module]]) def tflite_module_bytes_to_tflite_interpreters( - tflite_module_bytes: Dict[str, bytes], - artifacts_dir: Optional[str] = None + tflite_module_bytes: Dict[str, bytes], artifacts_dir: Optional[str] = None ) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]: - """Compile a dict of TFLite compiled bytes to TFLite interpreters. - - Args: - tflite_module_bytes: A dict mapping method names to compiled TFLite byte - strings. - artifacts_dir: an optional path to save compilation artifacts to. - - Returns: - A dictionary mapping method names to TFLite interpreters and a dictionary - mapping method names to compiled tflite graph paths (or None if - artifacts_dir is None). - """ - interpreters = dict() - compiled_paths = None - if artifacts_dir is not None: - compiled_paths = dict() - - def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str): - """Save compiled TFLite module bytes and convert into an interpreter.""" - tflite_dir = os.path.join(base_dir, "tflite") - os.makedirs(tflite_dir, exist_ok=True) - tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite") - with open(tflite_path, "wb") as f: - f.write(tflite_module) - - interpreters[method_name] = tf.lite.Interpreter(tflite_path) + """Compile a dict of TFLite compiled bytes to TFLite interpreters. + + Args: + tflite_module_bytes: A dict mapping method names to compiled TFLite byte + strings. + artifacts_dir: an optional path to save compilation artifacts to. + + Returns: + A dictionary mapping method names to TFLite interpreters and a dictionary + mapping method names to compiled tflite graph paths (or None if + artifacts_dir is None). + """ + interpreters = dict() + compiled_paths = None if artifacts_dir is not None: - compiled_paths[method_name] = tflite_path + compiled_paths = dict() - # Load each of the converted methods above into tf.lite.Interpreters. - for method_name, tflite_module in tflite_module_bytes.items(): - if artifacts_dir is None: - with tempfile.TemporaryDirectory() as base_dir: - _interpret_bytes(method_name, tflite_module, base_dir) - else: - _interpret_bytes(method_name, tflite_module, artifacts_dir) + def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str): + """Save compiled TFLite module bytes and convert into an interpreter.""" + tflite_dir = os.path.join(base_dir, "tflite") + os.makedirs(tflite_dir, exist_ok=True) + tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite") + with open(tflite_path, "wb") as f: + f.write(tflite_module) - return interpreters, compiled_paths + interpreters[method_name] = tf.lite.Interpreter(tflite_path) + if artifacts_dir is not None: + compiled_paths[method_name] = tflite_path + # Load each of the converted methods above into tf.lite.Interpreters. + for method_name, tflite_module in tflite_module_bytes.items(): + if artifacts_dir is None: + with tempfile.TemporaryDirectory() as base_dir: + _interpret_bytes(method_name, tflite_module, base_dir) + else: + _interpret_bytes(method_name, tflite_module, artifacts_dir) -class _TfLiteFunctionWrapper(_FunctionWrapper): - """Wraps a TFLite interpreter and makes it behave like a python function.""" - - def __init__(self, interpreter: tf.lite.Interpreter, - output_names: Sequence[str]): - self._interpreter = interpreter - self._output_names = output_names - - def __call__(self, *args, - **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]: - if len(args) and len(kwargs): - raise ValueError("Passing both args and kwargs is not supported by " - "_TfLiteFunctionWrapper") - - if len(args) == 1 and isinstance(args[0], list): - # Specifically to get TFLite to work with keras models that take a list of - # inputs instead of a sequence of args as their inputs, because it decides - # to change the input signature but it still technically works if you - # ignore that it does that. - if len(args) == 1 and isinstance(args[0], list): - args = args[0] - - # Tell TFLite what the shapes of the input tensors are before allocation. - if args: - for arg, detail in zip(args, self._interpreter.get_input_details()): - self._interpreter.resize_tensor_input(detail["index"], arg.shape) - else: - for detail in self._interpreter.get_input_details(): - self._interpreter.resize_tensor_input(detail["index"], - kwargs[detail["name"]].shape) + return interpreters, compiled_paths - # Allocate the (potentially dynamic) tensors. - self._interpreter.allocate_tensors() - # Copy the input data into the allocated tensors. - if args: - for arg, detail in zip(args, self._interpreter.get_input_details()): - self._interpreter.set_tensor(detail["index"], arg) - else: - for detail in self._interpreter.get_input_details(): - self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]]) - - # Execute the function. - self._interpreter.invoke() - - # Extract the outputs from the TFLite interpreter. - outputs = [] - for detail in self._interpreter.get_output_details(): - # Normalize for comparison with IREE. - value = tf_utils.convert_to_numpy( - self._interpreter.get_tensor(detail["index"])) - if self._output_names is not None: - name = detail["name"] - if name not in self._output_names: - raise ValueError(f"Expected '{name}' to be in {self._output_names}") - outputs.append([detail["name"], value]) - else: - outputs.append(value) - - # Process them to match the output of the tf.Module. - if self._output_names is not None: - return dict(outputs) - else: - if len(outputs) == 1: - return outputs[0] - return tuple(outputs) +class _TfLiteFunctionWrapper(_FunctionWrapper): + """Wraps a TFLite interpreter and makes it behave like a python function.""" + + def __init__(self, interpreter: tf.lite.Interpreter, output_names: Sequence[str]): + self._interpreter = interpreter + self._output_names = output_names + + def __call__( + self, *args, **kwargs + ) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]: + if len(args) and len(kwargs): + raise ValueError( + "Passing both args and kwargs is not supported by " + "_TfLiteFunctionWrapper" + ) + + if len(args) == 1 and isinstance(args[0], list): + # Specifically to get TFLite to work with keras models that take a list of + # inputs instead of a sequence of args as their inputs, because it decides + # to change the input signature but it still technically works if you + # ignore that it does that. + if len(args) == 1 and isinstance(args[0], list): + args = args[0] + + # Tell TFLite what the shapes of the input tensors are before allocation. + if args: + for arg, detail in zip(args, self._interpreter.get_input_details()): + self._interpreter.resize_tensor_input(detail["index"], arg.shape) + else: + for detail in self._interpreter.get_input_details(): + self._interpreter.resize_tensor_input( + detail["index"], kwargs[detail["name"]].shape + ) + + # Allocate the (potentially dynamic) tensors. + self._interpreter.allocate_tensors() + + # Copy the input data into the allocated tensors. + if args: + for arg, detail in zip(args, self._interpreter.get_input_details()): + self._interpreter.set_tensor(detail["index"], arg) + else: + for detail in self._interpreter.get_input_details(): + self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]]) + + # Execute the function. + self._interpreter.invoke() + + # Extract the outputs from the TFLite interpreter. + outputs = [] + for detail in self._interpreter.get_output_details(): + # Normalize for comparison with IREE. + value = tf_utils.convert_to_numpy( + self._interpreter.get_tensor(detail["index"]) + ) + if self._output_names is not None: + name = detail["name"] + if name not in self._output_names: + raise ValueError(f"Expected '{name}' to be in {self._output_names}") + outputs.append([detail["name"], value]) + else: + outputs.append(value) + + # Process them to match the output of the tf.Module. + if self._output_names is not None: + return dict(outputs) + else: + if len(outputs) == 1: + return outputs[0] + return tuple(outputs) class TfLiteCompiledModule(CompiledModule): - """Compiles a tf.Module with TFLite and allows it to be called.""" - - def __init__( - self, - module_name: str, - backend_info: BackendInfo, - compiled_paths: Dict[str, str], - interpreters: Dict[str, tf.lite.Interpreter], - output_names: Optional[Sequence[str]] = None, - ): - """Base constructor – Use one of the named constructors instead. - - Args: - module_name: A name for this compiled module. In most cases this will be - the name of the tf.Module subclass or instance that is compiled. - backend_info: BackendInfo with the details about compiling this module. - compiled_paths: A dictionary mapping compiled method names to file paths - corresponding to their serialized representations. - interpreters: A dict of tf.lite.Interpreters to make callable. - """ - super().__init__(module_name, backend_info, compiled_paths) - self._interpreters = interpreters - self._output_names = output_names - - @classmethod - def create_from_class(cls, - module_class: Type[tf.Module], - backend_info: BackendInfo, - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None): - """Compile a tf.Module subclass to the target backend in backend_info. - - Args: - module_class: The tf.Module subclass to compile. - backend_info: BackendInfo with the details for compiling this module. - exported_names: Optional sequence representing the exported names to keep. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - tf_utils.set_random_seed() - tflite_module_bytes = tf_module_to_tflite_module_bytes( - module_class, exported_names) - interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters( - tflite_module_bytes, artifacts_dir) - module_name = module_class.__name__ - return cls(module_name, backend_info, compiled_paths, interpreters) - - @classmethod - def create_from_signature_def_saved_model( - cls, - saved_model_dir: str, - saved_model_tags: Set[str], - module_name: str, - backend_info: BackendInfo, - exported_name: str, - input_names: Sequence[str], - output_names: Sequence[str], - artifacts_dir: Optional[str] = None): - """Compile a SignatureDef SavedModel to the target backend in backend_info. - - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - module_name: A name for this compiled module. - backend_info: BackendInfo with the details for compiling the saved model. - exported_name: A str representing the signature on the saved model to - compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. - artifacts_dir: An optional string pointing to where compilation artifacts - should be saved. No compilation artifacts will be saved if this is not - provided. - """ - tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes( - saved_model_dir, saved_model_tags, exported_name, input_names, - output_names) - interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters( - tflite_module_bytes, artifacts_dir) - return cls(module_name, backend_info, compiled_paths, interpreters, - output_names) - - def reinitialize(self): - """Reinitializes all stateful variables.""" - # This is a noop because TFLite (mostly) doesn't support stateful modules. - pass - - def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper: - # Try to resolve it as an interpreter. - if not attr in self._interpreters: - raise AttributeError( - f"The TFLite module does not have an interpreter for '{attr}'") - return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names) - - def tflite_serializable(self) -> bool: - return self.compiled_paths is not None + """Compiles a tf.Module with TFLite and allows it to be called.""" + + def __init__( + self, + module_name: str, + backend_info: BackendInfo, + compiled_paths: Dict[str, str], + interpreters: Dict[str, tf.lite.Interpreter], + output_names: Optional[Sequence[str]] = None, + ): + """Base constructor – Use one of the named constructors instead. + + Args: + module_name: A name for this compiled module. In most cases this will be + the name of the tf.Module subclass or instance that is compiled. + backend_info: BackendInfo with the details about compiling this module. + compiled_paths: A dictionary mapping compiled method names to file paths + corresponding to their serialized representations. + interpreters: A dict of tf.lite.Interpreters to make callable. + """ + super().__init__(module_name, backend_info, compiled_paths) + self._interpreters = interpreters + self._output_names = output_names + + @classmethod + def create_from_class( + cls, + module_class: Type[tf.Module], + backend_info: BackendInfo, + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ): + """Compile a tf.Module subclass to the target backend in backend_info. + + Args: + module_class: The tf.Module subclass to compile. + backend_info: BackendInfo with the details for compiling this module. + exported_names: Optional sequence representing the exported names to keep. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + tf_utils.set_random_seed() + tflite_module_bytes = tf_module_to_tflite_module_bytes( + module_class, exported_names + ) + interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters( + tflite_module_bytes, artifacts_dir + ) + module_name = module_class.__name__ + return cls(module_name, backend_info, compiled_paths, interpreters) + + @classmethod + def create_from_signature_def_saved_model( + cls, + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + backend_info: BackendInfo, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], + artifacts_dir: Optional[str] = None, + ): + """Compile a SignatureDef SavedModel to the target backend in backend_info. + + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + module_name: A name for this compiled module. + backend_info: BackendInfo with the details for compiling the saved model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + artifacts_dir: An optional string pointing to where compilation artifacts + should be saved. No compilation artifacts will be saved if this is not + provided. + """ + tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes( + saved_model_dir, saved_model_tags, exported_name, input_names, output_names + ) + interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters( + tflite_module_bytes, artifacts_dir + ) + return cls( + module_name, backend_info, compiled_paths, interpreters, output_names + ) + + def reinitialize(self): + """Reinitializes all stateful variables.""" + # This is a noop because TFLite (mostly) doesn't support stateful modules. + pass + + def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper: + # Try to resolve it as an interpreter. + if not attr in self._interpreters: + raise AttributeError( + f"The TFLite module does not have an interpreter for '{attr}'" + ) + return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names) + + def tflite_serializable(self) -> bool: + return self.compiled_paths is not None class BackendInfo: - """Contains information for compiling the specified backend.""" - - _name_to_info = { - "tf": { - "compiled_module_class": TfCompiledModule, - "driver": None, - "compiler_targets": None, - }, - "tflite": { - "compiled_module_class": TfLiteCompiledModule, - "driver": None, - "compiler_targets": None, - }, - "iree_vmvx": { - "compiled_module_class": IreeCompiledModule, - "driver": "local-task", - "compiler_targets": ["vmvx"] - }, - "iree_vulkan": { - "compiled_module_class": IreeCompiledModule, - "driver": "vulkan", - "compiler_targets": ["vulkan-spirv"] - }, - "iree_llvmcpu": { - "compiled_module_class": IreeCompiledModule, - "driver": "local-task", - "compiler_targets": ["llvm-cpu"] - }, - } - - def __init__(self, backend_name: str, backend_id: Optional[str] = None): - """Creates a BackendInfo with the compilation details for backend_name. - - Args: - backend_name: a str specifying which backend to use. Should be one of - 'tf', 'tflite', 'iree_vmvx', 'iree_vulkan', 'iree_llvmcpu'. - backend_id: an optional str specifying what name to use when saving - compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`. - - Raises: - KeyError: if backend_name is not one of ['tf', 'tflite', 'iree_vmvx', - 'iree_vulkan', 'iree_llvmcpu']. - ValueError: if backend_id doesn't start with backend_name. - """ - if backend_name not in self._name_to_info: - raise KeyError( - "Expected backend_name to be one of " - f"{list(self._name_to_info.keys())} but got '{backend_name}'.") - if backend_id is not None and not backend_id.startswith(backend_name): - raise ValueError(f"Expected backend_id to start with '{backend_name}' " - f"but got '{backend_id}'.") - - self.backend_name = backend_name - self.backend_id = backend_name if backend_id is None else backend_id - - info = self._name_to_info[backend_name] - self._compiled_module_class = info["compiled_module_class"] - self.driver = info["driver"] - self.compiler_targets = info["compiler_targets"] - - def compile_from_class(self, - module_class: Type[tf.Module], - exported_names: Sequence[str] = (), - artifacts_dir: Optional[str] = None) -> CompiledModule: - """Creates a 'CompiledModule' for this backend.""" - return self._compiled_module_class.create_from_class( - module_class, self, exported_names, artifacts_dir) - - def compile_signature_def_saved_model( - self, - saved_model_dir: str, - saved_model_tags: Set[str], - module_name: str, - exported_name: str, - input_names: Sequence[str], - output_names: Sequence[str], - artifacts_dir: Optional[str] = None) -> CompiledModule: - return self._compiled_module_class.create_from_signature_def_saved_model( - saved_model_dir, saved_model_tags, module_name, self, exported_name, - input_names, output_names, artifacts_dir) - - @classmethod - def get_all_backends(cls) -> Sequence[BackendInfo]: - """Returns a list of all BackendInfo configurations.""" - return [BackendInfo(backend_name) for backend_name in cls._name_to_info] + """Contains information for compiling the specified backend.""" + + _name_to_info = { + "tf": { + "compiled_module_class": TfCompiledModule, + "driver": None, + "compiler_targets": None, + }, + "tflite": { + "compiled_module_class": TfLiteCompiledModule, + "driver": None, + "compiler_targets": None, + }, + "iree_vmvx": { + "compiled_module_class": IreeCompiledModule, + "driver": "local-task", + "compiler_targets": ["vmvx"], + }, + "iree_vulkan": { + "compiled_module_class": IreeCompiledModule, + "driver": "vulkan", + "compiler_targets": ["vulkan-spirv"], + }, + "iree_llvmcpu": { + "compiled_module_class": IreeCompiledModule, + "driver": "local-task", + "compiler_targets": ["llvm-cpu"], + }, + } + + def __init__(self, backend_name: str, backend_id: Optional[str] = None): + """Creates a BackendInfo with the compilation details for backend_name. + + Args: + backend_name: a str specifying which backend to use. Should be one of + 'tf', 'tflite', 'iree_vmvx', 'iree_vulkan', 'iree_llvmcpu'. + backend_id: an optional str specifying what name to use when saving + compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`. + + Raises: + KeyError: if backend_name is not one of ['tf', 'tflite', 'iree_vmvx', + 'iree_vulkan', 'iree_llvmcpu']. + ValueError: if backend_id doesn't start with backend_name. + """ + if backend_name not in self._name_to_info: + raise KeyError( + "Expected backend_name to be one of " + f"{list(self._name_to_info.keys())} but got '{backend_name}'." + ) + if backend_id is not None and not backend_id.startswith(backend_name): + raise ValueError( + f"Expected backend_id to start with '{backend_name}' " + f"but got '{backend_id}'." + ) + + self.backend_name = backend_name + self.backend_id = backend_name if backend_id is None else backend_id + + info = self._name_to_info[backend_name] + self._compiled_module_class = info["compiled_module_class"] + self.driver = info["driver"] + self.compiler_targets = info["compiler_targets"] + + def compile_from_class( + self, + module_class: Type[tf.Module], + exported_names: Sequence[str] = (), + artifacts_dir: Optional[str] = None, + ) -> CompiledModule: + """Creates a 'CompiledModule' for this backend.""" + return self._compiled_module_class.create_from_class( + module_class, self, exported_names, artifacts_dir + ) + + def compile_signature_def_saved_model( + self, + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], + artifacts_dir: Optional[str] = None, + ) -> CompiledModule: + return self._compiled_module_class.create_from_signature_def_saved_model( + saved_model_dir, + saved_model_tags, + module_name, + self, + exported_name, + input_names, + output_names, + artifacts_dir, + ) + + @classmethod + def get_all_backends(cls) -> Sequence[BackendInfo]: + """Returns a list of all BackendInfo configurations.""" + return [BackendInfo(backend_name) for backend_name in cls._name_to_info] diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils_test.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils_test.py index c3ff32df2614..b5249ebdba74 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils_test.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/module_utils_test.py @@ -15,90 +15,92 @@ class ConstantModule(tf.Module): - - @tf.function(input_signature=[]) - def meaning(self): - return tf.constant([42.]) + @tf.function(input_signature=[]) + def meaning(self): + return tf.constant([42.0]) class StatefulCountingModule(tf.Module): + def __init__(self): + self.count = tf.Variable([0.0]) - def __init__(self): - self.count = tf.Variable([0.]) - - @tf.function(input_signature=[]) - def increment(self): - self.count.assign_add(tf.constant([1.])) + @tf.function(input_signature=[]) + def increment(self): + self.count.assign_add(tf.constant([1.0])) - @tf.function(input_signature=[]) - def get_count(self): - return self.count + @tf.function(input_signature=[]) + def get_count(self): + return self.count class RandomInitModule(tf.Module): + def __init__(self): + self.value = tf.Variable(tf.random.uniform([1])) - def __init__(self): - self.value = tf.Variable(tf.random.uniform([1])) - - @tf.function(input_signature=[]) - def get(self): - return self.value + @tf.function(input_signature=[]) + def get(self): + return self.value class UtilsTests(tf.test.TestCase, parameterized.TestCase): - - def test_artifact_saving(self): - backend_info = module_utils.BackendInfo('iree_vmvx') - with tempfile.TemporaryDirectory() as artifacts_dir: - tf_module = ConstantModule() - iree_module_utils, compiled_path = ( - module_utils._incrementally_compile_tf_module( - tf_module, backend_info=backend_info, - artifacts_dir=artifacts_dir)) - - artifacts_to_check = [ - 'iree_input.mlir', - compiled_path, - ] - for artifact in artifacts_to_check: - artifact_path = os.path.join(artifacts_dir, artifact) - logging.info('Checking path: %s', artifact_path) - self.assertTrue(os.path.exists(artifact_path)) - - @parameterized.named_parameters([ - ('tensorflow', 'tf'), - ('vmvx', 'iree_vmvx'), - ]) - def test_unaltered_state(self, backend_name): - backend_info = module_utils.BackendInfo(backend_name) - module = backend_info.compile_from_class(StatefulCountingModule) - - # Test that incrementing works properly. - self.assertEqual([0.], module.get_count()) - module.increment() - self.assertEqual([1.], module.get_count()) - - module.reinitialize() - # Test reinitialization. - self.assertEqual([0.], module.get_count()) - - @parameterized.named_parameters([ - ('tensorflow', 'tf'), - ('vmvx', 'iree_vmvx'), - ]) - def test_random_initialization(self, backend_name): - backend_info = module_utils.BackendInfo(backend_name) - - # Test compilation is the same. - module_1 = backend_info.compile_from_class(RandomInitModule) - module_2 = backend_info.compile_from_class(RandomInitModule) - self.assertAllEqual(module_1.get(), module_2.get()) - - # Test reinitialization is the same. - old_value = module_1.get() - module_1.reinitialize() - self.assertAllEqual(old_value, module_1.get()) - - -if __name__ == '__main__': - tf.test.main() + def test_artifact_saving(self): + backend_info = module_utils.BackendInfo("iree_vmvx") + with tempfile.TemporaryDirectory() as artifacts_dir: + tf_module = ConstantModule() + ( + iree_module_utils, + compiled_path, + ) = module_utils._incrementally_compile_tf_module( + tf_module, backend_info=backend_info, artifacts_dir=artifacts_dir + ) + + artifacts_to_check = [ + "iree_input.mlir", + compiled_path, + ] + for artifact in artifacts_to_check: + artifact_path = os.path.join(artifacts_dir, artifact) + logging.info("Checking path: %s", artifact_path) + self.assertTrue(os.path.exists(artifact_path)) + + @parameterized.named_parameters( + [ + ("tensorflow", "tf"), + ("vmvx", "iree_vmvx"), + ] + ) + def test_unaltered_state(self, backend_name): + backend_info = module_utils.BackendInfo(backend_name) + module = backend_info.compile_from_class(StatefulCountingModule) + + # Test that incrementing works properly. + self.assertEqual([0.0], module.get_count()) + module.increment() + self.assertEqual([1.0], module.get_count()) + + module.reinitialize() + # Test reinitialization. + self.assertEqual([0.0], module.get_count()) + + @parameterized.named_parameters( + [ + ("tensorflow", "tf"), + ("vmvx", "iree_vmvx"), + ] + ) + def test_random_initialization(self, backend_name): + backend_info = module_utils.BackendInfo(backend_name) + + # Test compilation is the same. + module_1 = backend_info.compile_from_class(RandomInitModule) + module_2 = backend_info.compile_from_class(RandomInitModule) + self.assertAllEqual(module_1.get(), module_2.get()) + + # Test reinitialization is the same. + old_value = module_1.get() + module_1.reinitialize() + self.assertAllEqual(old_value, module_1.get()) + + +if __name__ == "__main__": + tf.test.main() diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils.py index 9a84956c1a21..a49286fffdaf 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils.py @@ -30,94 +30,104 @@ from iree.tf.support import trace_utils import tensorflow.compat.v2 as tf -flags.DEFINE_string("reference_backend", "tf", - "The backend to treat as a source of truth.") -flags.DEFINE_list("target_backends", None, - "Explicit comma-delimited list of target backends.") flags.DEFINE_string( - "artifacts_dir", None, + "reference_backend", "tf", "The backend to treat as a source of truth." +) +flags.DEFINE_list( + "target_backends", None, "Explicit comma-delimited list of target backends." +) +flags.DEFINE_string( + "artifacts_dir", + None, "Specifies a directory to dump compilation artifacts and traces to. " - "Defaults to the OS's tempdir.") + "Defaults to the OS's tempdir.", +) +flags.DEFINE_bool( + "summarize", + True, + "Summarize the inputs and outputs of each module trace logged to disk.", +) flags.DEFINE_bool( - "summarize", True, - "Summarize the inputs and outputs of each module trace logged to disk.") -flags.DEFINE_bool("log_all_traces", False, - "Log all traces to logging.info, even if comparison passes.") + "log_all_traces", + False, + "Log all traces to logging.info, even if comparison passes.", +) FLAGS = flags.FLAGS DEFAULT_INPUT_GENERATOR = tf_utils.uniform def _setup_artifacts_dir(relative_artifacts_dir: str) -> str: - parent_dirs = [ - FLAGS.artifacts_dir, - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), - os.environ.get('TEST_TMPDIR'), - os.path.join(tempfile.gettempdir(), "iree", "modules"), - ] - # Use the most preferred path in parent_dirs that isn't None. - parent_dir = next(parent for parent in parent_dirs if parent is not None) + parent_dirs = [ + FLAGS.artifacts_dir, + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"), + os.environ.get("TEST_TMPDIR"), + os.path.join(tempfile.gettempdir(), "iree", "modules"), + ] + # Use the most preferred path in parent_dirs that isn't None. + parent_dir = next(parent for parent in parent_dirs if parent is not None) - artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir) - logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir) - os.makedirs(artifacts_dir, exist_ok=True) - return artifacts_dir + artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir) + logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir) + os.makedirs(artifacts_dir, exist_ok=True) + return artifacts_dir def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]: - """Decodes --target_backends and creates unique ids for them.""" - backend_names = FLAGS.target_backends - backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1} - backend_ids = [] - - # If there are multiple copies of the same backend_name, index them. e.g. - # backend_names = ["tf", "iree_vmvx", "tf"] - # --> backend_ids = ["tf_0", "iree_vmvx", "tf_1"] - for backend_name in backend_names: - if backend_name in backend_to_index: - backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}") - backend_to_index[backend_name] += 1 - else: - backend_ids.append(backend_name) + """Decodes --target_backends and creates unique ids for them.""" + backend_names = FLAGS.target_backends + backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1} + backend_ids = [] - return backend_names, backend_ids + # If there are multiple copies of the same backend_name, index them. e.g. + # backend_names = ["tf", "iree_vmvx", "tf"] + # --> backend_ids = ["tf_0", "iree_vmvx", "tf_1"] + for backend_name in backend_names: + if backend_name in backend_to_index: + backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}") + backend_to_index[backend_name] += 1 + else: + backend_ids.append(backend_name) + + return backend_names, backend_ids def get_target_backends() -> Sequence[module_utils.BackendInfo]: - """Gets the BackendInfo instances to compare with the reference backend. - - By default all backends in BackendInfo will be used. Specific backends to - run on can be specified using the `--target_backends` flag. - - Returns: - Sequence of BackendInfo that should be used. - """ - if FLAGS.target_backends is not None: - logging.info("Using backends from command line: %s", FLAGS.target_backends) - backend_names, backend_ids = _parse_target_backends() - backends = [ - module_utils.BackendInfo(backend_name, backend_id) - for backend_name, backend_id in zip(backend_names, backend_ids) - ] - else: - # If no backends are specified, use them all. - backends = module_utils.BackendInfo.get_all_backends() - return backends + """Gets the BackendInfo instances to compare with the reference backend. + + By default all backends in BackendInfo will be used. Specific backends to + run on can be specified using the `--target_backends` flag. + + Returns: + Sequence of BackendInfo that should be used. + """ + if FLAGS.target_backends is not None: + logging.info("Using backends from command line: %s", FLAGS.target_backends) + backend_names, backend_ids = _parse_target_backends() + backends = [ + module_utils.BackendInfo(backend_name, backend_id) + for backend_name, backend_id in zip(backend_names, backend_ids) + ] + else: + # If no backends are specified, use them all. + backends = module_utils.BackendInfo.get_all_backends() + return backends @dataclass(frozen=True) class Modules: - """Compiled modules. + """Compiled modules. - Args: - ref_module: Module compiled with the reference backend. - tar_modules: Sequence of modules compiled with the different target - backends. - artifacts_dir: String pointing to where compilation artifacts were saved. - """ - ref_module: module_utils.CompiledModule - tar_modules: Sequence[module_utils.CompiledModule] - artifacts_dir: str + Args: + ref_module: Module compiled with the reference backend. + tar_modules: Sequence of modules compiled with the different target + backends. + artifacts_dir: String pointing to where compilation artifacts were saved. + """ + + ref_module: module_utils.CompiledModule + tar_modules: Sequence[module_utils.CompiledModule] + artifacts_dir: str # We have to use a global variable to store the compiled modules so that we can @@ -130,92 +140,104 @@ class Modules: _global_modules = None -def compile_tf_module(module_class: Type[tf.Module], - exported_names: Sequence[str] = (), - relative_artifacts_dir: str = None) -> Modules: - """Compiles module_class to each backend that we test. - - Args: - module_class: the tf.Module subclass to compile. - exported_names: optional iterable of strings representing which of - module_class's functions to compile. If exported_names is empty all - functions will be compiled. - relative_artifacts_dir: optional string specifying where to save compilation - artifacts within the artifacts_dir. If it is not specified then - module_class.__name__ will be used. - - Returns: - A 'Modules' dataclass containing the reference module, target modules and - artifacts directory. - """ - global _global_modules - if _global_modules is not None: - return _global_modules +def compile_tf_module( + module_class: Type[tf.Module], + exported_names: Sequence[str] = (), + relative_artifacts_dir: str = None, +) -> Modules: + """Compiles module_class to each backend that we test. - # Setup the directory for saving compilation artifacts and traces. - if relative_artifacts_dir is None: - relative_artifacts_dir = module_class.__name__ - artifacts_dir = _setup_artifacts_dir(relative_artifacts_dir) - - # Get the backend information for this test. - ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend, - f"{FLAGS.reference_backend}_ref") - tar_backend_infos = get_target_backends() - - compile_backend = lambda backend_info: backend_info.compile_from_class( - module_class, exported_names, artifacts_dir) - - ref_module = compile_backend(ref_backend_info) - tar_modules = [ - compile_backend(backend_info) for backend_info in tar_backend_infos - ] - _global_modules = Modules(ref_module, tar_modules, artifacts_dir) - return _global_modules - - -def compile_tf_signature_def_saved_model( - saved_model_dir: str, saved_model_tags: Set[str], module_name: str, - exported_name: str, input_names: Sequence[str], - output_names: Sequence[str]) -> Modules: - """Compiles a SignatureDef SavedModel to each backend that we test. - - Args: - saved_model_dir: Directory of the saved model. - saved_model_tags: Optional set of tags to use when loading the model. - module_name: A name for this compiled module. - backend_info: BackendInfo with the details for compiling the saved model. - exported_name: A str representing the signature on the saved model to - compile. - input_names: A sequence of kwargs to feed to the saved model. - output_names: A sequence of named outputs to extract from the saved model. - - Returns: - A 'Modules' dataclass containing the reference module, target modules and - artifacts directory. - """ - global _global_modules - if _global_modules is not None: + Args: + module_class: the tf.Module subclass to compile. + exported_names: optional iterable of strings representing which of + module_class's functions to compile. If exported_names is empty all + functions will be compiled. + relative_artifacts_dir: optional string specifying where to save compilation + artifacts within the artifacts_dir. If it is not specified then + module_class.__name__ will be used. + + Returns: + A 'Modules' dataclass containing the reference module, target modules and + artifacts directory. + """ + global _global_modules + if _global_modules is not None: + return _global_modules + + # Setup the directory for saving compilation artifacts and traces. + if relative_artifacts_dir is None: + relative_artifacts_dir = module_class.__name__ + artifacts_dir = _setup_artifacts_dir(relative_artifacts_dir) + + # Get the backend information for this test. + ref_backend_info = module_utils.BackendInfo( + FLAGS.reference_backend, f"{FLAGS.reference_backend}_ref" + ) + tar_backend_infos = get_target_backends() + + compile_backend = lambda backend_info: backend_info.compile_from_class( + module_class, exported_names, artifacts_dir + ) + + ref_module = compile_backend(ref_backend_info) + tar_modules = [compile_backend(backend_info) for backend_info in tar_backend_infos] + _global_modules = Modules(ref_module, tar_modules, artifacts_dir) return _global_modules - # Setup the directory for saving compilation artifacts and traces. - artifacts_dir = _setup_artifacts_dir(module_name) - - # Get the backend information for this test. - ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend, - f"{FLAGS.reference_backend}_ref") - tar_backend_infos = get_target_backends() - compile_backend = ( - lambda backend_info: backend_info.compile_signature_def_saved_model( - saved_model_dir, saved_model_tags, module_name, exported_name, - input_names, output_names, artifacts_dir)) +def compile_tf_signature_def_saved_model( + saved_model_dir: str, + saved_model_tags: Set[str], + module_name: str, + exported_name: str, + input_names: Sequence[str], + output_names: Sequence[str], +) -> Modules: + """Compiles a SignatureDef SavedModel to each backend that we test. - ref_module = compile_backend(ref_backend_info) - tar_modules = [ - compile_backend(backend_info) for backend_info in tar_backend_infos - ] - _global_modules = Modules(ref_module, tar_modules, artifacts_dir) - return _global_modules + Args: + saved_model_dir: Directory of the saved model. + saved_model_tags: Optional set of tags to use when loading the model. + module_name: A name for this compiled module. + backend_info: BackendInfo with the details for compiling the saved model. + exported_name: A str representing the signature on the saved model to + compile. + input_names: A sequence of kwargs to feed to the saved model. + output_names: A sequence of named outputs to extract from the saved model. + + Returns: + A 'Modules' dataclass containing the reference module, target modules and + artifacts directory. + """ + global _global_modules + if _global_modules is not None: + return _global_modules + + # Setup the directory for saving compilation artifacts and traces. + artifacts_dir = _setup_artifacts_dir(module_name) + + # Get the backend information for this test. + ref_backend_info = module_utils.BackendInfo( + FLAGS.reference_backend, f"{FLAGS.reference_backend}_ref" + ) + tar_backend_infos = get_target_backends() + + compile_backend = ( + lambda backend_info: backend_info.compile_signature_def_saved_model( + saved_model_dir, + saved_model_tags, + module_name, + exported_name, + input_names, + output_names, + artifacts_dir, + ) + ) + + ref_module = compile_backend(ref_backend_info) + tar_modules = [compile_backend(backend_info) for backend_info in tar_backend_infos] + _global_modules = Modules(ref_module, tar_modules, artifacts_dir) + return _global_modules # We use global variables to store the configuration information for @@ -229,423 +251,453 @@ def compile_tf_signature_def_saved_model( class UnitTestSpec: - - def __init__(self, - unit_test_name: str, - input_signature: Sequence[tf.TensorSpec], - input_generator: tf_utils.InputGeneratorType = None, - input_args: Union[Sequence[Any], None] = None, - kwargs: Dict[str, Any] = None): - self.unit_test_name = tf_utils.remove_special_characters(unit_test_name) - self.input_signature = input_signature - self.input_args = input_args - self.kwargs = dict() if kwargs is None else kwargs - self.input_generator = input_generator - - def with_name(self, new_name: str) -> UnitTestSpec: - return UnitTestSpec(new_name, self.input_signature, self.input_generator, - self.input_args, self.kwargs) - - def __str__(self): - return self.unit_test_name + def __init__( + self, + unit_test_name: str, + input_signature: Sequence[tf.TensorSpec], + input_generator: tf_utils.InputGeneratorType = None, + input_args: Union[Sequence[Any], None] = None, + kwargs: Dict[str, Any] = None, + ): + self.unit_test_name = tf_utils.remove_special_characters(unit_test_name) + self.input_signature = input_signature + self.input_args = input_args + self.kwargs = dict() if kwargs is None else kwargs + self.input_generator = input_generator + + def with_name(self, new_name: str) -> UnitTestSpec: + return UnitTestSpec( + new_name, + self.input_signature, + self.input_generator, + self.input_args, + self.kwargs, + ) + + def __str__(self): + return self.unit_test_name def _dictionary_product(dictionary: Dict[Any, Any]) -> List[Dict[Any, Any]]: - """Returns a named cartesian product of dictionary's values. + """Returns a named cartesian product of dictionary's values. - Converts {'a': [1, 2], 'b': [3, 4]} into - [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}] - """ - product = [[]] - for values in dictionary.values(): - # Iteratively grow the elements of the product. - product = [element + [value] for element in product for value in values] - dicts = [{k: v for k, v in zip(dictionary, element)} for element in product] - return dicts + Converts {'a': [1, 2], 'b': [3, 4]} into + [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}] + """ + product = [[]] + for values in dictionary.values(): + # Iteratively grow the elements of the product. + product = [element + [value] for element in product for value in values] + dicts = [{k: v for k, v in zip(dictionary, element)} for element in product] + return dicts def _named_kwargs_product( - kwargs_to_values: Dict[str, Sequence[Any]]) -> Dict[str, Dict[str, Any]]: - """Splits kwargs_to_values into a Cartesian product of its elements.""" - # Validate 'kwargs_to_values' - if kwargs_to_values is None: - kwargs_to_values = dict() # Use only default kwargs. - for kwarg_key, kwarg_values in kwargs_to_values.items(): - if not isinstance(kwarg_values, Sequence): - raise TypeError(f"Expected kwargs_to_values[{repr(kwarg_key)}] to be a " - f"sequence, but got '{type(kwarg_values)}'") - - # Expand across a Cartesian product. - kwargs_product = _dictionary_product(kwargs_to_values) - # {'a': 1, 'b': 3} -> "a_1__b_3" - dict_to_str = lambda d: "__".join([f"{k}_{v}" for k, v in d.items()]) - return {dict_to_str(kwargs): kwargs for kwargs in kwargs_product} + kwargs_to_values: Dict[str, Sequence[Any]] +) -> Dict[str, Dict[str, Any]]: + """Splits kwargs_to_values into a Cartesian product of its elements.""" + # Validate 'kwargs_to_values' + if kwargs_to_values is None: + kwargs_to_values = dict() # Use only default kwargs. + for kwarg_key, kwarg_values in kwargs_to_values.items(): + if not isinstance(kwarg_values, Sequence): + raise TypeError( + f"Expected kwargs_to_values[{repr(kwarg_key)}] to be a " + f"sequence, but got '{type(kwarg_values)}'" + ) + + # Expand across a Cartesian product. + kwargs_product = _dictionary_product(kwargs_to_values) + # {'a': 1, 'b': 3} -> "a_1__b_3" + dict_to_str = lambda d: "__".join([f"{k}_{v}" for k, v in d.items()]) + return {dict_to_str(kwargs): kwargs for kwargs in kwargs_product} def unit_test_specs_from_signatures( signature_shapes: Sequence[Sequence[Sequence[int]]], signature_dtypes: Sequence[tf.DType] = [tf.float32], - input_generators: Union[Sequence[tf_utils.InputGeneratorType], - Dict[str, tf_utils.InputGeneratorType]] = [ - DEFAULT_INPUT_GENERATOR - ], - kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]: - """Generates a Cartesian product of UnitTestSpecs from the given arguments. - - Args: - signature_shapes: - A sequence (representing multiple signatures to test) of sequences - (representing the shapes of the args in those signatures) of ints - (representing the individual sizes of those shapes). - signature_dtypes: - A sequence of dtypes to test each signature with. - input_generators: - Either: - 1. a sequence of input generators to test each of the signature-dtype - pairs with - 2. a dictionary mapping input generator names to input generators to - test each of the signature-dtype pairs with. This format must be used - if any of the generators are lambda functions. - kwargs_to_values: - A dict mapping kwarg names to sequences of values that they can take. - - Returns: - A list of 'UnitTestSpec's generated from the provided arguments. - """ - # Validate 'signature_shapes' - for i, shapes in enumerate(signature_shapes): - if not isinstance(shapes, Sequence): - raise TypeError(f"Expected signature_shapes[{i}] to be a sequence, but " - f"got '{type(shapes)}'") - for j, shape in enumerate(shapes): - if not isinstance(shape, Sequence): - raise TypeError(f"Expected signature_shapes[{i}][{j}] to be a " - f"sequence, but got '{type(shape)}'") - for k, size in enumerate(shape): - if not isinstance(size, int): - raise TypeError(f"Expected signature_shapes[{i}][{j}][{k}] to be an " - f"int but got '{type(size)}") - - # Parse 'signature_shapes' - names_to_shapes = dict() - for signature in signature_shapes: - # Converts [[1, 2, 3], [4, 5]] into 1x2x3_4x5. - signature_key = "_".join( - ["x".join(str(size) for size in shape) for shape in signature]) - names_to_shapes[signature_key] = signature - - # Validate 'signature_dtypes' - for i, dtype in enumerate(signature_dtypes): - if not isinstance(dtype, tf.DType): - raise TypeError( - f"Expected dtypes[{i}] to be a tf.DType, but got '{type(dtype)}'") - - # Parse 'signature_dtypes' - # 'complex64' -> 'c64' - abbreviate = lambda dtype: re.sub(r"([a-z])[a-z]*([0-9]+)", r"\1\2", dtype) - names_to_dtypes = { - abbreviate(dtype.name): dtype for dtype in signature_dtypes - } - - # Validate 'input_generators' - if not isinstance(input_generators, (Sequence, Dict)): - raise TypeError("Expected 'input_generators' to be a sequence or " - f"dictionary, but got '{type(input_generators)}'") - if isinstance(input_generators, Sequence): - for i, generator in enumerate(input_generators): - if generator.__name__ == "": + input_generators: Union[ + Sequence[tf_utils.InputGeneratorType], Dict[str, tf_utils.InputGeneratorType] + ] = [DEFAULT_INPUT_GENERATOR], + kwargs_to_values: Dict[str, Sequence[Any]] = None, +) -> List[UnitTestSpec]: + """Generates a Cartesian product of UnitTestSpecs from the given arguments. + + Args: + signature_shapes: + A sequence (representing multiple signatures to test) of sequences + (representing the shapes of the args in those signatures) of ints + (representing the individual sizes of those shapes). + signature_dtypes: + A sequence of dtypes to test each signature with. + input_generators: + Either: + 1. a sequence of input generators to test each of the signature-dtype + pairs with + 2. a dictionary mapping input generator names to input generators to + test each of the signature-dtype pairs with. This format must be used + if any of the generators are lambda functions. + kwargs_to_values: + A dict mapping kwarg names to sequences of values that they can take. + + Returns: + A list of 'UnitTestSpec's generated from the provided arguments. + """ + # Validate 'signature_shapes' + for i, shapes in enumerate(signature_shapes): + if not isinstance(shapes, Sequence): + raise TypeError( + f"Expected signature_shapes[{i}] to be a sequence, but " + f"got '{type(shapes)}'" + ) + for j, shape in enumerate(shapes): + if not isinstance(shape, Sequence): + raise TypeError( + f"Expected signature_shapes[{i}][{j}] to be a " + f"sequence, but got '{type(shape)}'" + ) + for k, size in enumerate(shape): + if not isinstance(size, int): + raise TypeError( + f"Expected signature_shapes[{i}][{j}][{k}] to be an " + f"int but got '{type(size)}" + ) + + # Parse 'signature_shapes' + names_to_shapes = dict() + for signature in signature_shapes: + # Converts [[1, 2, 3], [4, 5]] into 1x2x3_4x5. + signature_key = "_".join( + ["x".join(str(size) for size in shape) for shape in signature] + ) + names_to_shapes[signature_key] = signature + + # Validate 'signature_dtypes' + for i, dtype in enumerate(signature_dtypes): + if not isinstance(dtype, tf.DType): + raise TypeError( + f"Expected dtypes[{i}] to be a tf.DType, but got '{type(dtype)}'" + ) + + # Parse 'signature_dtypes' + # 'complex64' -> 'c64' + abbreviate = lambda dtype: re.sub(r"([a-z])[a-z]*([0-9]+)", r"\1\2", dtype) + names_to_dtypes = {abbreviate(dtype.name): dtype for dtype in signature_dtypes} + + # Validate 'input_generators' + if not isinstance(input_generators, (Sequence, Dict)): raise TypeError( - f"'input_generators' was a sequence but input_generators[{i}] was " - "lambda function. 'input_generators' must be a dictionary if " - "lambda functions are used.") - - # Parse 'input_generators' - if isinstance(input_generators, Sequence): - names_to_generators = {gen.__name__: gen for gen in input_generators} - else: - names_to_generators = input_generators - - # Validate and parse 'kwargs_to_values' - names_to_kwargs = _named_kwargs_product(kwargs_to_values) - - # Create a Cartesian product through all specifications and their names. - specs = [ - names_to_shapes, names_to_dtypes, names_to_generators, names_to_kwargs - ] - # pytype: disable=attribute-error - key_product = itertools.product(*[list(spec.keys()) for spec in specs]) - value_product = itertools.product(*[list(spec.values()) for spec in specs]) - # pytype: enable=attribute-error - - # Generate a UnitTestSpec for each element in the above product. - unit_tests = [] - for keys, (shapes, dtype, generator, kwargs) in zip(key_product, - value_product): - unit_test_name = "__".join(key for key in keys if key) - input_signature = [tf.TensorSpec(shape, dtype) for shape in shapes] - unit_tests.append( - UnitTestSpec( - unit_test_name=unit_test_name, - input_signature=input_signature, - input_generator=generator, - input_args=None, - kwargs=kwargs, - )) - return unit_tests + "Expected 'input_generators' to be a sequence or " + f"dictionary, but got '{type(input_generators)}'" + ) + if isinstance(input_generators, Sequence): + for i, generator in enumerate(input_generators): + if generator.__name__ == "": + raise TypeError( + f"'input_generators' was a sequence but input_generators[{i}] was " + "lambda function. 'input_generators' must be a dictionary if " + "lambda functions are used." + ) + + # Parse 'input_generators' + if isinstance(input_generators, Sequence): + names_to_generators = {gen.__name__: gen for gen in input_generators} + else: + names_to_generators = input_generators + + # Validate and parse 'kwargs_to_values' + names_to_kwargs = _named_kwargs_product(kwargs_to_values) + + # Create a Cartesian product through all specifications and their names. + specs = [names_to_shapes, names_to_dtypes, names_to_generators, names_to_kwargs] + # pytype: disable=attribute-error + key_product = itertools.product(*[list(spec.keys()) for spec in specs]) + value_product = itertools.product(*[list(spec.values()) for spec in specs]) + # pytype: enable=attribute-error + + # Generate a UnitTestSpec for each element in the above product. + unit_tests = [] + for keys, (shapes, dtype, generator, kwargs) in zip(key_product, value_product): + unit_test_name = "__".join(key for key in keys if key) + input_signature = [tf.TensorSpec(shape, dtype) for shape in shapes] + unit_tests.append( + UnitTestSpec( + unit_test_name=unit_test_name, + input_signature=input_signature, + input_generator=generator, + input_args=None, + kwargs=kwargs, + ) + ) + return unit_tests def unit_test_specs_from_args( names_to_input_args: Dict[str, Sequence[Any]], - kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]: - """Generates a Cartesian product of UnitTestSpecs from the given arguments. - - Args: - signature_shapes: - A dict mapping names for input arguments to the arguments themselves. - kwargs_to_values: - A dict mapping kwarg names to sequences of values that they can take. - - Returns: - A list of 'UnitTestSpec's generated from the provided arguments. - """ - # Validate and parse 'kwargs_to_values' - names_to_kwargs = _named_kwargs_product(kwargs_to_values) - - # Create a Cartesian product through all specifications and their names. - specs = [names_to_input_args, names_to_kwargs] - key_product = itertools.product(*[list(spec.keys()) for spec in specs]) - value_product = itertools.product(*[list(spec.values()) for spec in specs]) - - # Generate a UnitTestSpec for each element in the above product. - unit_tests = [] - for keys, (input_args, kwargs) in zip(key_product, value_product): - unit_test_name = "__".join(key for key in keys if key) - input_signature = tf_utils.apply_function( - input_args, - lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x))) - unit_tests.append( - UnitTestSpec( - unit_test_name=unit_test_name, - input_signature=input_signature, - input_generator=None, - input_args=input_args, - kwargs=kwargs, - )) - return unit_tests - - -def tf_function_unit_test(input_generator: tf_utils.InputGeneratorType = None, - input_args: Sequence[Any] = None, - atol: float = None, - rtol: float = None, - name: str = None, - static_signature: Sequence[tf.TensorSpec] = None, - **tf_function_kwargs): - """Creates a tf.function that can be used to generate unit_tests. - - If 'input_generator' and 'input_args' are unspecified then the function will - be tested using random uniform data. - - Args: - input_generator: - an optional callable taking a shape and dtype that returns input data for - the unit_test. - input_args: - an optional sequence of values to pass as positional args to the function. - atol: - optional, the absolute tolerance to use when comparing the decorated - function's output. - rtol: - optional, the relative tolerance to use when comparing the decorated - function's output. - name: - optional, the name to reference this function with. Must be used if - decorating a lambda. - static_signature: - optional, a signature with the same structure as 'input_signature'. Used - to specify the correct shape for data generation when dynamic dims are - provided. - - Raises: - ValueError: if 'input_generator' and 'input_args' are both specified. - - Returns: - A tf.function with the additional attributes 'input_generator' (from above) - 'trace_kwargs' (from 'atol' and 'rtol' above), and with an updated - __name__ attribute if 'name' was specified. - """ - - def _store_unit_test_info(function): - # Validate arguments. - if input_generator is not None and input_args is not None: - raise ValueError( - "'input_generator' and 'input_args' cannot both be specified.") - - function = tf.function(**tf_function_kwargs)(function) - - # Set function.__name__ - if name is not None: - function.__name__ = name - elif function.__name__ == "": - raise ValueError("The 'name' kwarg must be provided when decorating a " - "lambda function.") - - global _global_unit_test_configs - if function.__name__ not in _global_unit_test_configs: - - if static_signature is not None: - signature = static_signature - else: - signature = function.input_signature - - if input_generator is not None: - # Use the user-specificed input_generator. - get_trace_args = lambda: tf_utils.generate_inputs( - signature, input_generator) - elif input_args is not None: - # Use the user-specified input_args. - get_trace_args = lambda: copy.deepcopy(input_args) - else: - # No user data specification – default to using random uniform data. - get_trace_args = lambda: tf_utils.generate_inputs( - signature, DEFAULT_INPUT_GENERATOR) - - _global_unit_test_configs[function.__name__] = dict( - get_trace_args=get_trace_args, - trace_kwargs=dict(atol=atol, rtol=rtol)) - - return function - - return _store_unit_test_info - - -class TestModule(tf.Module): - """Thin tf.Module wrapper with helper methods for tf_function_unit_tests.""" - - @classmethod - def get_tf_function_unit_tests(cls): - """Get all tf_function_unit_test-created tf.functions on the class.""" - # Initialize the module to ensure that _global_unit_test_configs has the - # info for all of the unit_tests. (Only doing this if - # _global_unit_test_configs is empty wouldn't address the case where some - # unit_tests are defined on the class and some are generated by __init__). - cls() + kwargs_to_values: Dict[str, Sequence[Any]] = None, +) -> List[UnitTestSpec]: + """Generates a Cartesian product of UnitTestSpecs from the given arguments. - tf_function_unit_tests = list(_global_unit_test_configs.keys()) - if not len(tf_function_unit_tests): - raise ValueError( - "'get_tf_function_unit_tests' was called but no tests were found.") - return tf_function_unit_tests + Args: + signature_shapes: + A dict mapping names for input arguments to the arguments themselves. + kwargs_to_values: + A dict mapping kwarg names to sequences of values that they can take. + Returns: + A list of 'UnitTestSpec's generated from the provided arguments. + """ + # Validate and parse 'kwargs_to_values' + names_to_kwargs = _named_kwargs_product(kwargs_to_values) + + # Create a Cartesian product through all specifications and their names. + specs = [names_to_input_args, names_to_kwargs] + key_product = itertools.product(*[list(spec.keys()) for spec in specs]) + value_product = itertools.product(*[list(spec.values()) for spec in specs]) + + # Generate a UnitTestSpec for each element in the above product. + unit_tests = [] + for keys, (input_args, kwargs) in zip(key_product, value_product): + unit_test_name = "__".join(key for key in keys if key) + input_signature = tf_utils.apply_function( + input_args, lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x)) + ) + unit_tests.append( + UnitTestSpec( + unit_test_name=unit_test_name, + input_signature=input_signature, + input_generator=None, + input_args=input_args, + kwargs=kwargs, + ) + ) + return unit_tests + + +def tf_function_unit_test( + input_generator: tf_utils.InputGeneratorType = None, + input_args: Sequence[Any] = None, + atol: float = None, + rtol: float = None, + name: str = None, + static_signature: Sequence[tf.TensorSpec] = None, + **tf_function_kwargs, +): + """Creates a tf.function that can be used to generate unit_tests. + + If 'input_generator' and 'input_args' are unspecified then the function will + be tested using random uniform data. -class TracedModuleTestCase(tf.test.TestCase): - """Compiles a tf.Module to multiple backends to test their correctness.""" - - def setUp(self) -> None: - # Runs before each unit test. - super().setUp() - self._modules.ref_module.reinitialize() - for module in self._modules.tar_modules: - module.reinitialize() - - @classmethod - def generate_unit_tests(cls, module_class: Type[TestModule]): - """Generates tests for each 'tf_function_unit_test' on 'module_class'.""" - for function_name in module_class.get_tf_function_unit_tests(): - # We have to pass the closure arguments 'function_name', 'get_args' and - # 'kwargs' to 'trace' via a kwarg instead of using it directly in the body - # because 'function_name' and 'unit_test_config' are overwritten in each - # iteration of this loop, and python will only use the most recent version - # of each. If we didn't do this, then we would only test the last function - # in this loop. The same is true for passing 'trace' to 'unit_test'. - unit_test_config = _global_unit_test_configs[function_name] - - # Runs the inputs through a (traced) module. - def trace(module, - function_name=function_name, - get_args=unit_test_config["get_trace_args"], - kwargs=unit_test_config["trace_kwargs"]): - getattr(module, function_name)(*get_args(), **kwargs) + Args: + input_generator: + an optional callable taking a shape and dtype that returns input data for + the unit_test. + input_args: + an optional sequence of values to pass as positional args to the function. + atol: + optional, the absolute tolerance to use when comparing the decorated + function's output. + rtol: + optional, the relative tolerance to use when comparing the decorated + function's output. + name: + optional, the name to reference this function with. Must be used if + decorating a lambda. + static_signature: + optional, a signature with the same structure as 'input_signature'. Used + to specify the correct shape for data generation when dynamic dims are + provided. + + Raises: + ValueError: if 'input_generator' and 'input_args' are both specified. + + Returns: + A tf.function with the additional attributes 'input_generator' (from above) + 'trace_kwargs' (from 'atol' and 'rtol' above), and with an updated + __name__ attribute if 'name' was specified. + """ - # Give the trace the name of the tf.function that it is testing. - trace.__name__ = function_name + def _store_unit_test_info(function): + # Validate arguments. + if input_generator is not None and input_args is not None: + raise ValueError( + "'input_generator' and 'input_args' cannot both be specified." + ) + + function = tf.function(**tf_function_kwargs)(function) + + # Set function.__name__ + if name is not None: + function.__name__ = name + elif function.__name__ == "": + raise ValueError( + "The 'name' kwarg must be provided when decorating a " + "lambda function." + ) + + global _global_unit_test_configs + if function.__name__ not in _global_unit_test_configs: + if static_signature is not None: + signature = static_signature + else: + signature = function.input_signature + + if input_generator is not None: + # Use the user-specificed input_generator. + get_trace_args = lambda: tf_utils.generate_inputs( + signature, input_generator + ) + elif input_args is not None: + # Use the user-specified input_args. + get_trace_args = lambda: copy.deepcopy(input_args) + else: + # No user data specification – default to using random uniform data. + get_trace_args = lambda: tf_utils.generate_inputs( + signature, DEFAULT_INPUT_GENERATOR + ) + + _global_unit_test_configs[function.__name__] = dict( + get_trace_args=get_trace_args, trace_kwargs=dict(atol=atol, rtol=rtol) + ) + + return function + + return _store_unit_test_info - # Runs 'trace' on modules compiled to each backend and compares them. - def unit_test(self, trace=trace): - self.compare_backends(trace, self._modules) - # Make 'unit_test' a function on the TracedModuleTestCase, which tells - # the test runner to run it. - unit_test.__name__ = f"test_{function_name}" - if hasattr(cls, unit_test.__name__): - raise ValueError("Tried to generate multiple instances of the " - f"unit_test '{unit_test.__name__}'.") - setattr(cls, unit_test.__name__, unit_test) +class TestModule(tf.Module): + """Thin tf.Module wrapper with helper methods for tf_function_unit_tests.""" - def compare_backends(self, - trace_function: Callable[[trace_utils.TracedModule], - None], - modules: Modules) -> None: - """Run the reference and target backends on trace_function and compare them. + @classmethod + def get_tf_function_unit_tests(cls): + """Get all tf_function_unit_test-created tf.functions on the class.""" + # Initialize the module to ensure that _global_unit_test_configs has the + # info for all of the unit_tests. (Only doing this if + # _global_unit_test_configs is empty wouldn't address the case where some + # unit_tests are defined on the class and some are generated by __init__). + cls() - Random seeds for tensorflow, numpy and python are set before each invocation - of trace_function. + tf_function_unit_tests = list(_global_unit_test_configs.keys()) + if not len(tf_function_unit_tests): + raise ValueError( + "'get_tf_function_unit_tests' was called but no tests were found." + ) + return tf_function_unit_tests - Args: - trace_function: a function accepting a TracedModule as its argument. - """ - # Create Traces for each backend. - ref_trace = trace_utils.Trace(modules.ref_module, trace_function) - tar_traces = [ - trace_utils.Trace(module, trace_function) - for module in modules.tar_modules - ] - # Run the traces through trace_function with their associated modules. - tf_utils.set_random_seed() - trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace)) - if FLAGS.log_all_traces: - logging.info(ref_trace) - for module, trace in zip(modules.tar_modules, tar_traces): - tf_utils.set_random_seed() - trace_function(trace_utils.TracedModule(module, trace)) - if FLAGS.log_all_traces: - logging.info(trace) - - # Compare each target trace of trace_function with the reference trace. - failed_backend_indices = [] - error_messages = [] - for i, tar_trace in enumerate(tar_traces): - logging.info("Comparing the reference backend '%s' with '%s'", - ref_trace.backend_id, tar_trace.backend_id) - traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace) - if not traces_match: - failed_backend_indices.append(i) - error_messages.extend(errors) - - # Save the results to disk before validating. - ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace) - ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize) - ref_trace.serialize(ref_trace_dir) - for tar_trace in tar_traces: - tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, - tar_trace) - tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize) - tar_trace.serialize(tar_trace_dir) - - # Validate results. - if failed_backend_indices: - # Extract info for logging. - failed_backends = [ - tar_traces[i].backend_id for i in failed_backend_indices - ] - error_list = ''.join([f'\n - {message}' for message in error_messages]) - self.fail( - "Comparison between the reference backend and the following targets " - f"failed: {failed_backends}. Errors: {error_list}\n" - "See the logs above for more details about the non-matching calls.") - - @classmethod - def tearDownClass(cls) -> None: - # Runs after all unit tests are completed. - super().tearDownClass() +class TracedModuleTestCase(tf.test.TestCase): + """Compiles a tf.Module to multiple backends to test their correctness.""" + + def setUp(self) -> None: + # Runs before each unit test. + super().setUp() + self._modules.ref_module.reinitialize() + for module in self._modules.tar_modules: + module.reinitialize() + + @classmethod + def generate_unit_tests(cls, module_class: Type[TestModule]): + """Generates tests for each 'tf_function_unit_test' on 'module_class'.""" + for function_name in module_class.get_tf_function_unit_tests(): + # We have to pass the closure arguments 'function_name', 'get_args' and + # 'kwargs' to 'trace' via a kwarg instead of using it directly in the body + # because 'function_name' and 'unit_test_config' are overwritten in each + # iteration of this loop, and python will only use the most recent version + # of each. If we didn't do this, then we would only test the last function + # in this loop. The same is true for passing 'trace' to 'unit_test'. + unit_test_config = _global_unit_test_configs[function_name] + + # Runs the inputs through a (traced) module. + def trace( + module, + function_name=function_name, + get_args=unit_test_config["get_trace_args"], + kwargs=unit_test_config["trace_kwargs"], + ): + getattr(module, function_name)(*get_args(), **kwargs) + + # Give the trace the name of the tf.function that it is testing. + trace.__name__ = function_name + + # Runs 'trace' on modules compiled to each backend and compares them. + def unit_test(self, trace=trace): + self.compare_backends(trace, self._modules) + + # Make 'unit_test' a function on the TracedModuleTestCase, which tells + # the test runner to run it. + unit_test.__name__ = f"test_{function_name}" + if hasattr(cls, unit_test.__name__): + raise ValueError( + "Tried to generate multiple instances of the " + f"unit_test '{unit_test.__name__}'." + ) + setattr(cls, unit_test.__name__, unit_test) + + def compare_backends( + self, + trace_function: Callable[[trace_utils.TracedModule], None], + modules: Modules, + ) -> None: + """Run the reference and target backends on trace_function and compare them. + + Random seeds for tensorflow, numpy and python are set before each invocation + of trace_function. + + Args: + trace_function: a function accepting a TracedModule as its argument. + """ + # Create Traces for each backend. + ref_trace = trace_utils.Trace(modules.ref_module, trace_function) + tar_traces = [ + trace_utils.Trace(module, trace_function) for module in modules.tar_modules + ] + + # Run the traces through trace_function with their associated modules. + tf_utils.set_random_seed() + trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace)) + if FLAGS.log_all_traces: + logging.info(ref_trace) + for module, trace in zip(modules.tar_modules, tar_traces): + tf_utils.set_random_seed() + trace_function(trace_utils.TracedModule(module, trace)) + if FLAGS.log_all_traces: + logging.info(trace) + + # Compare each target trace of trace_function with the reference trace. + failed_backend_indices = [] + error_messages = [] + for i, tar_trace in enumerate(tar_traces): + logging.info( + "Comparing the reference backend '%s' with '%s'", + ref_trace.backend_id, + tar_trace.backend_id, + ) + traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace) + if not traces_match: + failed_backend_indices.append(i) + error_messages.extend(errors) + + # Save the results to disk before validating. + ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace) + ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize) + ref_trace.serialize(ref_trace_dir) + for tar_trace in tar_traces: + tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, tar_trace) + tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize) + tar_trace.serialize(tar_trace_dir) + + # Validate results. + if failed_backend_indices: + # Extract info for logging. + failed_backends = [tar_traces[i].backend_id for i in failed_backend_indices] + error_list = "".join([f"\n - {message}" for message in error_messages]) + self.fail( + "Comparison between the reference backend and the following targets " + f"failed: {failed_backends}. Errors: {error_list}\n" + "See the logs above for more details about the non-matching calls." + ) + + @classmethod + def tearDownClass(cls) -> None: + # Runs after all unit tests are completed. + super().tearDownClass() diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils_test.py index b42758111b50..298a09aca546 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils_test.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_test_utils_test.py @@ -12,76 +12,75 @@ class TfFunctionUnitTestModule(tf_test_utils.TestModule): - - @tf_test_utils.tf_function_unit_test(input_signature=[]) - def no_args(self): - return np.array([True], dtype=bool) - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([4]), - tf.TensorSpec([4]), - ]) - def default_uniform_inputs(self, a, b): - return a + b - - @tf_test_utils.tf_function_unit_test( - input_signature=[ - tf.TensorSpec([4]), - tf.TensorSpec([4]), - ], - input_generator=tf_utils.ndarange, - ) - def custom_input_generator(self, a, b): - return a + b - - @tf_test_utils.tf_function_unit_test( - input_signature=[ - tf.TensorSpec([4]), - tf.TensorSpec([4]), - ], - input_args=[ - np.array([0, 1, 2, 3], np.float32), - -np.array([0, 1, 2, 3], np.float32), - ], - ) - def custom_input_args(self, a, b): - return a + b - - # This test will fail if atol is not successfully set. - @tf_test_utils.tf_function_unit_test( - input_signature=[ - tf.TensorSpec([128, 3072], tf.float32), - tf.TensorSpec([3072, 256], tf.float32), - ], - atol=1e-2, - ) - def high_tolerance(self, a, b): - return tf.matmul(a, b) + @tf_test_utils.tf_function_unit_test(input_signature=[]) + def no_args(self): + return np.array([True], dtype=bool) + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([4]), + tf.TensorSpec([4]), + ] + ) + def default_uniform_inputs(self, a, b): + return a + b + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([4]), + tf.TensorSpec([4]), + ], + input_generator=tf_utils.ndarange, + ) + def custom_input_generator(self, a, b): + return a + b + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([4]), + tf.TensorSpec([4]), + ], + input_args=[ + np.array([0, 1, 2, 3], np.float32), + -np.array([0, 1, 2, 3], np.float32), + ], + ) + def custom_input_args(self, a, b): + return a + b + + # This test will fail if atol is not successfully set. + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([128, 3072], tf.float32), + tf.TensorSpec([3072, 256], tf.float32), + ], + atol=1e-2, + ) + def high_tolerance(self, a, b): + return tf.matmul(a, b) class TestUtilsTests(tf.test.TestCase): - - def test_tf_function_unittet(self): - - class TfFunctionUnittestTest(tf_test_utils.TracedModuleTestCase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module( - TfFunctionUnitTestModule) - - TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnitTestModule) - test_case = TfFunctionUnittestTest() - self.assertTrue(hasattr(test_case, 'test_no_args')) - self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs')) - self.assertTrue(hasattr(test_case, 'test_custom_input_generator')) - self.assertTrue(hasattr(test_case, 'test_custom_input_args')) - self.assertTrue(hasattr(test_case, 'test_high_tolerance')) - - # Will throw an error if 'atol' is not set. - test_case = TfFunctionUnittestTest() - test_case.test_high_tolerance() - - -if __name__ == '__main__': - tf.test.main() + def test_tf_function_unittet(self): + class TfFunctionUnittestTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module( + TfFunctionUnitTestModule + ) + + TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnitTestModule) + test_case = TfFunctionUnittestTest() + self.assertTrue(hasattr(test_case, "test_no_args")) + self.assertTrue(hasattr(test_case, "test_default_uniform_inputs")) + self.assertTrue(hasattr(test_case, "test_custom_input_generator")) + self.assertTrue(hasattr(test_case, "test_custom_input_args")) + self.assertTrue(hasattr(test_case, "test_high_tolerance")) + + # Will throw an error if 'atol' is not set. + test_case = TfFunctionUnittestTest() + test_case.test_high_tolerance() + + +if __name__ == "__main__": + tf.test.main() diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils.py index e84ba5cbf988..aa7251a1b02e 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils.py @@ -8,275 +8,300 @@ import os import random import re -from typing import (Any, Callable, Mapping, Optional, Sequence, Set, Tuple, - Union) +from typing import Any, Callable, Mapping, Optional, Sequence, Set, Tuple, Union import iree.runtime import numpy as np import tensorflow.compat.v2 as tf from absl import logging -InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]], - np.ndarray] +InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]], np.ndarray] def set_random_seed(seed: int = 0) -> None: - """Set random seed for tf, np and random.""" - tf.random.set_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -def uniform(shape: Sequence[int], - dtype: Union[tf.DType, np.dtype] = np.float32, - low: float = -1.0, - high: float = 1.0) -> np.ndarray: - """np.random.uniform with simplified API and dtype and bool support.""" - # pytype doesn't understand the ternary with tf as "Any" - dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype # pytype: disable=attribute-error - if dtype == bool: - return np.random.choice(2, shape).astype(bool) - else: - values = np.random.uniform(size=shape, low=low, high=high) - if np.issubdtype(dtype, np.integer): - values = np.round(values) - return values.astype(dtype) + """Set random seed for tf, np and random.""" + tf.random.set_seed(seed) + random.seed(seed) + np.random.seed(seed) + + +def uniform( + shape: Sequence[int], + dtype: Union[tf.DType, np.dtype] = np.float32, + low: float = -1.0, + high: float = 1.0, +) -> np.ndarray: + """np.random.uniform with simplified API and dtype and bool support.""" + # pytype doesn't understand the ternary with tf as "Any" + dtype = ( + dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype + ) # pytype: disable=attribute-error + if dtype == bool: + return np.random.choice(2, shape).astype(bool) + else: + values = np.random.uniform(size=shape, low=low, high=high) + if np.issubdtype(dtype, np.integer): + values = np.round(values) + return values.astype(dtype) -def ndarange(shape: Sequence[int], - dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray: - """np.ndarange for arbitrary input shapes.""" - # pytype doesn't understand the ternary with tf as "Any" - dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype # pytype: disable=attribute-error - return np.arange(np.prod(shape), dtype=dtype).reshape(shape) +def ndarange( + shape: Sequence[int], dtype: Union[tf.DType, np.dtype] = np.float32 +) -> np.ndarray: + """np.ndarange for arbitrary input shapes.""" + # pytype doesn't understand the ternary with tf as "Any" + dtype = ( + dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype + ) # pytype: disable=attribute-error + return np.arange(np.prod(shape), dtype=dtype).reshape(shape) def random_permutation( - shape: Sequence[int], - dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray: - """Returns a random permutation of [0, np.prod(shape)).""" - values = ndarange(shape, dtype) - np.random.shuffle(values) - return values + shape: Sequence[int], dtype: Union[tf.DType, np.dtype] = np.float32 +) -> np.ndarray: + """Returns a random permutation of [0, np.prod(shape)).""" + values = ndarange(shape, dtype) + np.random.shuffle(values) + return values def apply_function(values, function): - """Applies 'function' recursively to the inputted values.""" - if isinstance(values, list): - return [apply_function(v, function) for v in values] - elif isinstance(values, tuple): - return tuple(apply_function(v, function) for v in values) - elif isinstance(values, Mapping): - return {k: apply_function(v, function) for k, v in values.items()} - else: - return function(values) + """Applies 'function' recursively to the inputted values.""" + if isinstance(values, list): + return [apply_function(v, function) for v in values] + elif isinstance(values, tuple): + return tuple(apply_function(v, function) for v in values) + elif isinstance(values, Mapping): + return {k: apply_function(v, function) for k, v in values.items()} + else: + return function(values) def generate_inputs( spec, # Union[Sequence[tf.TensorSpec], tf.TensorSpec] input_generator: InputGeneratorType, ) -> Sequence[np.ndarray]: - """Generates inputs for a given input signature using 'input_generator'.""" - make_static = lambda shape: [dim if dim is not None else 2 for dim in shape] - generate = lambda spec: input_generator(make_static(spec.shape), spec.dtype) - return apply_function(spec, generate) + """Generates inputs for a given input signature using 'input_generator'.""" + make_static = lambda shape: [dim if dim is not None else 2 for dim in shape] + generate = lambda spec: input_generator(make_static(spec.shape), spec.dtype) + return apply_function(spec, generate) def convert_to_numpy(values: Any) -> Any: - """Converts any tf.Tensor, int, float, bool or list values to numpy.""" - return apply_function(values, iree.runtime.normalize_value) + """Converts any tf.Tensor, int, float, bool or list values to numpy.""" + return apply_function(values, iree.runtime.normalize_value) def to_mlir_type(dtype: np.dtype) -> str: - """Returns a string that denotes the type 'dtype' in MLIR style.""" - if not isinstance(dtype, np.dtype): - # Handle np.int8 _not_ being a dtype. - dtype = np.dtype(dtype) - bits = dtype.itemsize * 8 - if np.issubdtype(dtype, np.integer): - return f"i{bits}" - elif np.issubdtype(dtype, np.floating): - return f"f{bits}" - else: - raise TypeError(f"Expected integer or floating type, but got {dtype}") - - -def get_shape_and_dtype(array: np.ndarray, - allow_non_mlir_dtype: bool = False) -> str: - shape_dtype = [str(dim) for dim in list(array.shape)] - if np.issubdtype(array.dtype, np.number): - shape_dtype.append(to_mlir_type(array.dtype)) - elif allow_non_mlir_dtype: - shape_dtype.append(f"") - else: - raise TypeError(f"Expected integer or floating type, but got {array.dtype}") - return "x".join(shape_dtype) - - -def save_input_values(inputs: Sequence[np.ndarray], - artifacts_dir: Optional[str] = None) -> str: - """Saves input values with IREE tools format if 'artifacts_dir' is set.""" - result = [] - for array in inputs: - shape_dtype = get_shape_and_dtype(array) - values = " ".join([str(x) for x in array.flatten()]) - result.append(f"{shape_dtype}={values}") - result = "\n".join(result) - if artifacts_dir is not None: - inputs_path = os.path.join(artifacts_dir, "inputs.txt") - logging.info("Saving IREE input values to: %s", inputs_path) - with open(inputs_path, "w") as f: - f.write(result) - f.write("\n") - return result + """Returns a string that denotes the type 'dtype' in MLIR style.""" + if not isinstance(dtype, np.dtype): + # Handle np.int8 _not_ being a dtype. + dtype = np.dtype(dtype) + bits = dtype.itemsize * 8 + if np.issubdtype(dtype, np.integer): + return f"i{bits}" + elif np.issubdtype(dtype, np.floating): + return f"f{bits}" + else: + raise TypeError(f"Expected integer or floating type, but got {dtype}") + + +def get_shape_and_dtype(array: np.ndarray, allow_non_mlir_dtype: bool = False) -> str: + shape_dtype = [str(dim) for dim in list(array.shape)] + if np.issubdtype(array.dtype, np.number): + shape_dtype.append(to_mlir_type(array.dtype)) + elif allow_non_mlir_dtype: + shape_dtype.append(f"") + else: + raise TypeError(f"Expected integer or floating type, but got {array.dtype}") + return "x".join(shape_dtype) + + +def save_input_values( + inputs: Sequence[np.ndarray], artifacts_dir: Optional[str] = None +) -> str: + """Saves input values with IREE tools format if 'artifacts_dir' is set.""" + result = [] + for array in inputs: + shape_dtype = get_shape_and_dtype(array) + values = " ".join([str(x) for x in array.flatten()]) + result.append(f"{shape_dtype}={values}") + result = "\n".join(result) + if artifacts_dir is not None: + inputs_path = os.path.join(artifacts_dir, "inputs.txt") + logging.info("Saving IREE input values to: %s", inputs_path) + with open(inputs_path, "w") as f: + f.write(result) + f.write("\n") + return result def remove_special_characters(value: str) -> str: - """Replaces special characters with '_' while keeping instances of '__'.""" - normalized_parts = [] - for part in value.split("__"): - part = re.sub(r"[^a-zA-Z0-9_]", "_", part) # Remove special characters. - part = re.sub(r"_+", "_", part) # Remove duplicate "_". - part = part.strip("_") # Don't end or start in "_". - normalized_parts.append(part) - return "__".join(normalized_parts) + """Replaces special characters with '_' while keeping instances of '__'.""" + normalized_parts = [] + for part in value.split("__"): + part = re.sub(r"[^a-zA-Z0-9_]", "_", part) # Remove special characters. + part = re.sub(r"_+", "_", part) # Remove duplicate "_". + part = part.strip("_") # Don't end or start in "_". + normalized_parts.append(part) + return "__".join(normalized_parts) def is_complex(tensors: Union[Sequence[tf.TensorSpec], tf.TensorSpec]) -> bool: - if isinstance(tensors, Sequence): - for tensor in tensors: - if is_complex(tensor): - return True - return False - else: - return tensors.dtype.is_complex # pytype: disable=attribute-error + if isinstance(tensors, Sequence): + for tensor in tensors: + if is_complex(tensor): + return True + return False + else: + return tensors.dtype.is_complex # pytype: disable=attribute-error def _complex_wrapper(function): - """Wraps a tf.function to allow compiling functions of complex numbers.""" + """Wraps a tf.function to allow compiling functions of complex numbers.""" - def decorator(*args, **kwargs): - inputs = [] - for real, imag in zip(args[::2], args[1::2]): - inputs.append(tf.complex(real, imag)) - result = function(*inputs, **kwargs) - return tf.math.real(result), tf.math.imag(result) + def decorator(*args, **kwargs): + inputs = [] + for real, imag in zip(args[::2], args[1::2]): + inputs.append(tf.complex(real, imag)) + result = function(*inputs, **kwargs) + return tf.math.real(result), tf.math.imag(result) - return decorator + return decorator def rewrite_complex_signature(function, signature: Sequence[tf.TensorSpec]): - """Compatibility layer for testing complex numbers.""" - if not all([spec.dtype.is_complex for spec in signature]): - raise NotImplementedError("Signatures with mixed complex and non-complex " - "tensor specs are not supported.") + """Compatibility layer for testing complex numbers.""" + if not all([spec.dtype.is_complex for spec in signature]): + raise NotImplementedError( + "Signatures with mixed complex and non-complex " + "tensor specs are not supported." + ) - # Rewrite the signature, replacing all complex tensors with pairs of real - # and imaginary tensors. - real_imag_signature = [] - for spec in signature: - new_dtype = tf.float32 if spec.dtype.size == 8 else tf.float64 - real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype)) - real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype)) + # Rewrite the signature, replacing all complex tensors with pairs of real + # and imaginary tensors. + real_imag_signature = [] + for spec in signature: + new_dtype = tf.float32 if spec.dtype.size == 8 else tf.float64 + real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype)) + real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype)) - return _complex_wrapper(function), real_imag_signature + return _complex_wrapper(function), real_imag_signature def make_dims_dynamic(spec: tf.TensorSpec) -> tf.TensorSpec: - """Gives a tf.TensorSpec dynamic dims.""" - return tf.TensorSpec([None] * len(spec.shape), spec.dtype) - - -def check_same(ref: Any, tar: Any, rtol: float, - atol: float) -> Tuple[bool, Union[str, None]]: - """Checks that ref and tar have identical datastructures and values.""" - # Check for matching types. - if not isinstance(tar, type(ref)): - error = ("Expected ref and tar to have the same type but got " - f"'{type(ref)}' and '{type(tar)}'") - logging.error(error) - return False, error - - if ref is None: - # Nothing to compare (e.g. the called method had no outputs). - return True, None - - # Recursive check for dicts. - if isinstance(ref, dict): - if ref.keys() != tar.keys(): - error = ("Expected ref and tar to have the same keys, but got " - f"'{ref.keys()}' and '{tar.keys()}'") - logging.error(error) - return False, error - # Check that all of the dictionaries' values are the same. - for key in ref: - same, error = check_same(ref[key], tar[key], rtol, atol) - if not same: - return same, error - - # Recursive check for iterables. - elif isinstance(ref, list) or isinstance(ref, tuple): - if len(ref) != len(tar): - error = ("Expected ref and tar to have the same length, but got " - f"{len(ref)} and {len(tar)}") - logging.error(error) - return False, error - # Check that all of the iterables' values are the same. - for i in range(len(ref)): - same, error = check_same(ref[i], tar[i], rtol, atol) - if not same: - return same, error - - # Base check for numpy arrays. - elif isinstance(ref, np.ndarray): - # TODO(#5359): Simplify this and verify that the types are actually the same - # Ignore np.bool != np.int8 because the IREE python runtime awkwardly - # returns np.int8s instead of np.bools. - if ref.dtype != tar.dtype and not ( - (ref.dtype == bool and tar.dtype == np.int8) or - (ref.dtype == np.int8 and tar.dtype == bool)): - error = ("Expected ref and tar to have the same dtype, but got " - f"'{ref.dtype}' and '{tar.dtype}'") - logging.error(error) - return False, error - - if ref.size == tar.size == 0: - return True, None - - if np.issubdtype(ref.dtype, np.floating): - same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True) - abs_diff = np.max(np.abs(ref - tar)) - rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar))) - diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, " - f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}") - if not same: - error = ("Floating point difference between ref and tar was too " - f"large. {diff_string}") + """Gives a tf.TensorSpec dynamic dims.""" + return tf.TensorSpec([None] * len(spec.shape), spec.dtype) + + +def check_same( + ref: Any, tar: Any, rtol: float, atol: float +) -> Tuple[bool, Union[str, None]]: + """Checks that ref and tar have identical datastructures and values.""" + # Check for matching types. + if not isinstance(tar, type(ref)): + error = ( + "Expected ref and tar to have the same type but got " + f"'{type(ref)}' and '{type(tar)}'" + ) logging.error(error) - else: - error = None - logging.info( - "Floating point difference between ref and tar was within " - "tolerance. %s", diff_string) - return same, error - elif np.issubdtype(ref.dtype, np.integer): - same = np.array_equal(ref, tar) - if not same: - abs_diff = np.max(np.abs(ref - tar)) - error = ("Expected array equality between ref and tar, but got " - f"a max elementwise difference of {abs_diff}") - logging.error(error) - else: - error = None - return same, error + return False, error + + if ref is None: + # Nothing to compare (e.g. the called method had no outputs). + return True, None + + # Recursive check for dicts. + if isinstance(ref, dict): + if ref.keys() != tar.keys(): + error = ( + "Expected ref and tar to have the same keys, but got " + f"'{ref.keys()}' and '{tar.keys()}'" + ) + logging.error(error) + return False, error + # Check that all of the dictionaries' values are the same. + for key in ref: + same, error = check_same(ref[key], tar[key], rtol, atol) + if not same: + return same, error + + # Recursive check for iterables. + elif isinstance(ref, list) or isinstance(ref, tuple): + if len(ref) != len(tar): + error = ( + "Expected ref and tar to have the same length, but got " + f"{len(ref)} and {len(tar)}" + ) + logging.error(error) + return False, error + # Check that all of the iterables' values are the same. + for i in range(len(ref)): + same, error = check_same(ref[i], tar[i], rtol, atol) + if not same: + return same, error + + # Base check for numpy arrays. + elif isinstance(ref, np.ndarray): + # TODO(#5359): Simplify this and verify that the types are actually the same + # Ignore np.bool != np.int8 because the IREE python runtime awkwardly + # returns np.int8s instead of np.bools. + if ref.dtype != tar.dtype and not ( + (ref.dtype == bool and tar.dtype == np.int8) + or (ref.dtype == np.int8 and tar.dtype == bool) + ): + error = ( + "Expected ref and tar to have the same dtype, but got " + f"'{ref.dtype}' and '{tar.dtype}'" + ) + logging.error(error) + return False, error + + if ref.size == tar.size == 0: + return True, None + + if np.issubdtype(ref.dtype, np.floating): + same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True) + abs_diff = np.max(np.abs(ref - tar)) + rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar))) + diff_string = ( + f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, " + f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}" + ) + if not same: + error = ( + "Floating point difference between ref and tar was too " + f"large. {diff_string}" + ) + logging.error(error) + else: + error = None + logging.info( + "Floating point difference between ref and tar was within " + "tolerance. %s", + diff_string, + ) + return same, error + elif np.issubdtype(ref.dtype, np.integer): + same = np.array_equal(ref, tar) + if not same: + abs_diff = np.max(np.abs(ref - tar)) + error = ( + "Expected array equality between ref and tar, but got " + f"a max elementwise difference of {abs_diff}" + ) + logging.error(error) + else: + error = None + return same, error + else: + return np.array_equal(ref, tar), None + + # Base check for native number types. + elif isinstance(ref, (int, float)): + return ref == tar, None + + # If outputs end up here then an extra branch for that type should be added. else: - return np.array_equal(ref, tar), None - - # Base check for native number types. - elif isinstance(ref, (int, float)): - return ref == tar, None - - # If outputs end up here then an extra branch for that type should be added. - else: - raise TypeError(f"Encountered results with unexpected type {type(ref)}") - return True, None + raise TypeError(f"Encountered results with unexpected type {type(ref)}") + return True, None diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils_test.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils_test.py index f26e8aa877c6..ad8018978c3f 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils_test.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/tf_utils_test.py @@ -12,81 +12,81 @@ class UtilsTests(tf.test.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + [ + ("int8_to_i8", np.int8, "i8"), + ("int32_to_i32", np.int32, "i32"), + ("float32_to_f32", np.float32, "f32"), + ("float64_to_f64", np.float64, "f64"), + ] + ) + def test_to_mlir_type(self, numpy_type, mlir_type): + self.assertEqual(tf_utils.to_mlir_type(numpy_type), mlir_type) - @parameterized.named_parameters([('int8_to_i8', np.int8, 'i8'), - ('int32_to_i32', np.int32, 'i32'), - ('float32_to_f32', np.float32, 'f32'), - ('float64_to_f64', np.float64, 'f64')]) - def test_to_mlir_type(self, numpy_type, mlir_type): - self.assertEqual(tf_utils.to_mlir_type(numpy_type), mlir_type) + @parameterized.named_parameters( + [ + ("single_i32", [np.array([1, 2], dtype=np.int32)], "2xi32=1 2"), + ("single_f32", [np.array([1, 2], dtype=np.float32)], "2xf32=1.0 2.0"), + ] + ) + def test_save_input_values(self, inputs, inputs_str): + self.assertEqual(tf_utils.save_input_values(inputs), inputs_str) - @parameterized.named_parameters([ - ('single_i32', [np.array([1, 2], dtype=np.int32)], '2xi32=1 2'), - ('single_f32', [np.array([1, 2], dtype=np.float32)], '2xf32=1.0 2.0'), - ]) - def test_save_input_values(self, inputs, inputs_str): - self.assertEqual(tf_utils.save_input_values(inputs), inputs_str) + def test_apply_function(self): + inputs = [1, [2, 3], (4, 5), {"6": 6, "78": [7, 8]}] + expected = [0, [1, 2], (3, 4), {"6": 5, "78": [6, 7]}] + result = tf_utils.apply_function(inputs, lambda x: x - 1) + self.assertEqual(result, expected) + self.assertNotEqual(inputs, expected) - def test_apply_function(self): - inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}] - expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}] - result = tf_utils.apply_function(inputs, lambda x: x - 1) - self.assertEqual(result, expected) - self.assertNotEqual(inputs, expected) + @parameterized.named_parameters( + [ + { + "testcase_name": "all the same", + "array_c": np.array([0, 1, 2]), + "array_d": np.array(["0", "1", "2"]), + "array_e": np.array([0.0, 0.1, 0.2]), + "tar_same": True, + }, + { + "testcase_name": "wrong int", + "array_c": np.array([1, 1, 2]), + "array_d": np.array(["0", "1", "2"]), + "array_e": np.array([0.0, 0.1, 0.2]), + "tar_same": False, + }, + { + "testcase_name": "wrong string", + "array_c": np.array([0, 1, 2]), + "array_d": np.array(["a", "1", "2"]), + "array_e": np.array([0.0, 0.1, 0.2]), + "tar_same": False, + }, + { + "testcase_name": "wrong float", + "array_c": np.array([0, 1, 2]), + "array_d": np.array(["0", "1", "2"]), + "array_e": np.array([1.0, 0.1, 0.2]), + "tar_same": False, + }, + ] + ) + def test_recursive_check_same(self, array_c, array_d, array_e, tar_same): + ref = { + "a": 1, + "b": [ + {"c": np.array([0, 1, 2])}, + {"d": np.array(["0", "1", "2"])}, + {"e": np.array([0.0, 0.1, 0.2])}, + ], + } + tar = { + "a": 1, + "b": [{"c": array_c}, {"d": array_d}, {"e": array_e}], + } + same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6) + self.assertEqual(tar_same, same) - @parameterized.named_parameters([ - { - 'testcase_name': 'all the same', - 'array_c': np.array([0, 1, 2]), - 'array_d': np.array(['0', '1', '2']), - 'array_e': np.array([0.0, 0.1, 0.2]), - 'tar_same': True, - }, - { - 'testcase_name': 'wrong int', - 'array_c': np.array([1, 1, 2]), - 'array_d': np.array(['0', '1', '2']), - 'array_e': np.array([0.0, 0.1, 0.2]), - 'tar_same': False, - }, - { - 'testcase_name': 'wrong string', - 'array_c': np.array([0, 1, 2]), - 'array_d': np.array(['a', '1', '2']), - 'array_e': np.array([0.0, 0.1, 0.2]), - 'tar_same': False, - }, - { - 'testcase_name': 'wrong float', - 'array_c': np.array([0, 1, 2]), - 'array_d': np.array(['0', '1', '2']), - 'array_e': np.array([1.0, 0.1, 0.2]), - 'tar_same': False, - }, - ]) - def test_recursive_check_same(self, array_c, array_d, array_e, tar_same): - # yapf: disable - ref = { - 'a': 1, - 'b': [ - {'c': np.array([0, 1, 2])}, - {'d': np.array(['0', '1', '2'])}, - {'e': np.array([0.0, 0.1, 0.2])} - ], - } - tar = { - 'a': 1, - 'b': [ - {'c': array_c}, - {'d': array_d}, - {'e': array_e} - ], - } - # yapf: enable - same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6) - self.assertEqual(tar_same, same) - - -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils.py index c32e8ea55486..7a0d4449788a 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils.py @@ -30,383 +30,405 @@ def _zfill_width(length: int) -> Union[int, None]: - return int(np.ceil(np.log10(length))) if length else None + return int(np.ceil(np.log10(length))) if length else None def get_trace_dir(artifacts_dir: str, trace: Trace) -> str: - trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces", - trace.function_name) - os.makedirs(trace_dir, exist_ok=True) - return trace_dir + trace_dir = os.path.join( + artifacts_dir, trace.backend_id, "traces", trace.function_name + ) + os.makedirs(trace_dir, exist_ok=True) + return trace_dir class ModuleCall: - - def __init__(self, - method: str, - inputs: Tuple[Any], - outputs: Tuple[Any], - serialized_inputs: Tuple[str], - serialized_outputs: Tuple[str], - rtol: float = 1e-5, - atol: float = 1e-5): - """Records the details of a call to a CompiledModule.""" - self.method = method - - # Deepcopy to safegard against mutation. - self.inputs = copy.deepcopy(inputs) - if outputs is not None: - outputs = copy.deepcopy(outputs) - else: - outputs = tuple() - self.outputs = outputs if isinstance(outputs, tuple) else (outputs,) - - self.serialized_inputs = serialized_inputs - self.serialized_outputs = serialized_outputs - - self.rtol = rtol - self.atol = atol - - def get_tolerances(self) -> Tuple[float, float]: - """Gets the floating point tolerances associated with this call.""" - return self.rtol, self.atol - - def _get_shape_and_dtype(self, value: Any) -> str: - if isinstance(value, np.ndarray): - return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True) - else: - return str(type(value)) - - def __str__(self): - prior_printoptions = np.get_printoptions() - np.set_printoptions(linewidth=NUMPY_LINEWIDTH) - - header = f"Method: {self.method}" - inputs = "\n".join( - [textwrap.indent(str(value), INDENT) for value in self.inputs]) - input_shapes = ", ".join( - self._get_shape_and_dtype(value) for value in self.inputs) - - outputs = "\n".join( - [textwrap.indent(str(value), INDENT) for value in self.outputs]) - output_shapes = ", ".join( - self._get_shape_and_dtype(value) for value in self.outputs) - - tolerances = textwrap.indent(f"rtol={self.rtol}, atol={self.atol}", INDENT) - body = (f"Inputs: {input_shapes}\n{inputs}\n" + def __init__( + self, + method: str, + inputs: Tuple[Any], + outputs: Tuple[Any], + serialized_inputs: Tuple[str], + serialized_outputs: Tuple[str], + rtol: float = 1e-5, + atol: float = 1e-5, + ): + """Records the details of a call to a CompiledModule.""" + self.method = method + + # Deepcopy to safegard against mutation. + self.inputs = copy.deepcopy(inputs) + if outputs is not None: + outputs = copy.deepcopy(outputs) + else: + outputs = tuple() + self.outputs = outputs if isinstance(outputs, tuple) else (outputs,) + + self.serialized_inputs = serialized_inputs + self.serialized_outputs = serialized_outputs + + self.rtol = rtol + self.atol = atol + + def get_tolerances(self) -> Tuple[float, float]: + """Gets the floating point tolerances associated with this call.""" + return self.rtol, self.atol + + def _get_shape_and_dtype(self, value: Any) -> str: + if isinstance(value, np.ndarray): + return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True) + else: + return str(type(value)) + + def __str__(self): + prior_printoptions = np.get_printoptions() + np.set_printoptions(linewidth=NUMPY_LINEWIDTH) + + header = f"Method: {self.method}" + inputs = "\n".join( + [textwrap.indent(str(value), INDENT) for value in self.inputs] + ) + input_shapes = ", ".join( + self._get_shape_and_dtype(value) for value in self.inputs + ) + + outputs = "\n".join( + [textwrap.indent(str(value), INDENT) for value in self.outputs] + ) + output_shapes = ", ".join( + self._get_shape_and_dtype(value) for value in self.outputs + ) + + tolerances = textwrap.indent(f"rtol={self.rtol}, atol={self.atol}", INDENT) + body = ( + f"Inputs: {input_shapes}\n{inputs}\n" f"Outputs: {output_shapes}\n{outputs}" - f"\nTolerances:\n{tolerances}") - result = f"{header}\n{textwrap.indent(body, INDENT)}" - - np.set_printoptions(**prior_printoptions) - return result - - def serialize(self, call_dir: str) -> None: - """Stores a serialized copy of this call. - - Can be loaded via ModuleCall.load(call_dir) - - Args: - call_dir: str, the path to the directory to serialize this call to. - """ - os.makedirs(call_dir, exist_ok=True) - - metadata = { - "method": self.method, - "serialized_inputs": self.serialized_inputs, - "serialized_outputs": self.serialized_outputs, - "rtol": self.rtol, - "atol": self.atol - } - with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f: - pickle.dump(metadata, f) - - width = _zfill_width(len(self.inputs)) - for i, value in enumerate(self.inputs): - path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl") - with open(path, "wb") as f: - pickle.dump(value, f) - - width = _zfill_width(len(self.outputs)) - for i, value in enumerate(self.outputs): - path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl") - with open(path, "wb") as f: - pickle.dump(value, f) - - @staticmethod - def load(call_dir: str) -> ModuleCall: - """Loads and returns a trace serialized with ModuleCall.serialize.""" - with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f: - kwargs = pickle.load(f) - - for result_type in ["input", "output"]: - key = f"{result_type}s" # inputs or outputs - kwargs[key] = [] - - files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl")) - for filename in sorted(files): - with open(filename, "rb") as f: - kwargs[key].append(pickle.load(f)) - - # Convert to tuple to match python's return type for multiple results. - kwargs[key] = tuple(kwargs[key]) - - return ModuleCall(**kwargs) + f"\nTolerances:\n{tolerances}" + ) + result = f"{header}\n{textwrap.indent(body, INDENT)}" + + np.set_printoptions(**prior_printoptions) + return result + + def serialize(self, call_dir: str) -> None: + """Stores a serialized copy of this call. + + Can be loaded via ModuleCall.load(call_dir) + + Args: + call_dir: str, the path to the directory to serialize this call to. + """ + os.makedirs(call_dir, exist_ok=True) + + metadata = { + "method": self.method, + "serialized_inputs": self.serialized_inputs, + "serialized_outputs": self.serialized_outputs, + "rtol": self.rtol, + "atol": self.atol, + } + with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f: + pickle.dump(metadata, f) + + width = _zfill_width(len(self.inputs)) + for i, value in enumerate(self.inputs): + path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl") + with open(path, "wb") as f: + pickle.dump(value, f) + + width = _zfill_width(len(self.outputs)) + for i, value in enumerate(self.outputs): + path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl") + with open(path, "wb") as f: + pickle.dump(value, f) + + @staticmethod + def load(call_dir: str) -> ModuleCall: + """Loads and returns a trace serialized with ModuleCall.serialize.""" + with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f: + kwargs = pickle.load(f) + + for result_type in ["input", "output"]: + key = f"{result_type}s" # inputs or outputs + kwargs[key] = [] + + files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl")) + for filename in sorted(files): + with open(filename, "rb") as f: + kwargs[key].append(pickle.load(f)) + + # Convert to tuple to match python's return type for multiple results. + kwargs[key] = tuple(kwargs[key]) + + return ModuleCall(**kwargs) class Trace: - """Stores the inputs and outputs of a series of calls to a module.""" - - def __init__(self, - module: Union[module_utils.CompiledModule, None], - function: Union[Callable[[TracedModule], None], None], - _load_dict: Optional[Dict[str, Any]] = None): - """Extracts metadata from module and function and initializes. - - Example usage: - def forward_pass(...): - ... - module = IreeCompiledModule(...) - trace = Trace(module, forward_pass) - forward_pass(TracedModule(module, trace)) - - Args: - module: the module who's outputs this trace will record. - function: the function that module will be traced on. - _load_dict: used internally - """ - if _load_dict is None: - # Extract metadata from module and function. - self.module_name = module.module_name - self.compiled_paths = module.compiled_paths - self.backend_name = module.backend_info.backend_name - self.backend_id = module.backend_info.backend_id - self.backend_driver = module.backend_info.driver - self.iree_serializable = module.iree_serializable() - self.tflite_serializable = module.tflite_serializable() - self.function_name = function.__name__ - self.function_sourcefile = inspect.getsourcefile(function) - source, start_line = inspect.getsourcelines(function) - self.function_line_numbers = (start_line, start_line + len(source)) - self.function_source = "".join(source) - - self.calls = [] - else: - self.module_name = _load_dict["module_name"] - self.compiled_paths = _load_dict["compiled_paths"] - self.backend_name = _load_dict["backend_name"] - self.backend_id = _load_dict["backend_id"] - self.backend_driver = _load_dict["backend_driver"] - self.iree_serializable = _load_dict["iree_serializable"] - self.tflite_serializable = _load_dict["tflite_serializable"] - self.function_name = _load_dict["function_name"] - self.function_sourcefile = _load_dict["function_sourcefile"] - self.function_line_numbers = _load_dict["function_line_numbers"] - self.function_source = _load_dict["function_source"] - self.calls = _load_dict["calls"] - - def __str__(self): - header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' " - f"on function '{self.function_name}':") - # Give each call a number so it's easier to compare between multiple traces. - calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)] - calls = textwrap.indent("\n".join(calls), prefix=INDENT) - return f"{header}\n{calls}" - - def __iter__(self): - for call in self.calls: - yield call - - def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None: - """Saves a human-readable string representation of this trace to disk. - - Args: - trace_dir: str, path to the directory to save the trace in. - summarize: a bool controlling whether numpy should summarize the inputs - and outputs if they're large. Setting this to False is very slow for - large outputs. - """ - prior_printoptions = np.get_printoptions() - np.set_printoptions( - linewidth=NUMPY_LINEWIDTH, - threshold=None if summarize else sys.maxsize, - edgeitems=10) # Can show more items since they won't clutter the logs. - - path = os.path.join(trace_dir, "log.txt") - with open(path, "w") as f: - f.write(str(self)) - f.write("\n") - - np.set_printoptions(**prior_printoptions) - - def serialize(self, trace_dir: str) -> None: - """Stores a serialized copy of this trace in trace_dir. - - It can be loaded via `Trace.load(trace_dir)`. - - Args: - trace_dir: str, path to the directory to serialize the trace to. - """ - - compiled_paths = None - if self.compiled_paths is not None: - # Convert to a dict to avoid the issues with serializing defaultdicts. - compiled_paths = dict(self.compiled_paths) - - # Python serialization. - metadata = { - "module_name": self.module_name, - "compiled_paths": compiled_paths, - "backend_name": self.backend_name, - "backend_id": self.backend_id, - "backend_driver": self.backend_driver, - "iree_serializable": self.iree_serializable, - "tflite_serializable": self.tflite_serializable, - "function_name": self.function_name, - "function_sourcefile": self.function_sourcefile, - "function_line_numbers": self.function_line_numbers, - "function_source": self.function_source - } - with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f: - pickle.dump(metadata, f) - - width = _zfill_width(len(self.calls)) - for i, call in enumerate(self.calls): - call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}") - call.serialize(call_dir) - - # C++ benchmark serialization. - if self.iree_serializable or self.tflite_serializable: - entry_function = self.calls[0].method - compiled_path = self.compiled_paths[entry_function] - - if self.iree_serializable: - serialized_inputs = self.calls[0].serialized_inputs - flagfile = [ - f"--module={compiled_path}", - f"--device={self.backend_driver}", - f"--function={entry_function}", - ] + [f"--input=\"{input}\"" for input in serialized_inputs] - with open(os.path.join(trace_dir, "flagfile"), "w") as f: - f.writelines(line + "\n" for line in flagfile) - else: - with open(os.path.join(trace_dir, "graph_path"), "w") as f: - f.writelines(compiled_path + "\n") - - @staticmethod - def load(trace_dir: str) -> Trace: - """Loads and returns a trace serialized with Trace.serialize. - - Args: - trace_dir: str, path to the directory of the serialized trace. - - Returns: - A Trace deserialized from trace_dir. - """ - with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f: - load_dict = pickle.load(f) - call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*"))) - calls = [ModuleCall.load(call_dir) for call_dir in call_dirs] - load_dict["calls"] = calls - return Trace(module=None, function=None, _load_dict=load_dict) + """Stores the inputs and outputs of a series of calls to a module.""" + + def __init__( + self, + module: Union[module_utils.CompiledModule, None], + function: Union[Callable[[TracedModule], None], None], + _load_dict: Optional[Dict[str, Any]] = None, + ): + """Extracts metadata from module and function and initializes. + + Example usage: + def forward_pass(...): + ... + module = IreeCompiledModule(...) + trace = Trace(module, forward_pass) + forward_pass(TracedModule(module, trace)) + + Args: + module: the module who's outputs this trace will record. + function: the function that module will be traced on. + _load_dict: used internally + """ + if _load_dict is None: + # Extract metadata from module and function. + self.module_name = module.module_name + self.compiled_paths = module.compiled_paths + self.backend_name = module.backend_info.backend_name + self.backend_id = module.backend_info.backend_id + self.backend_driver = module.backend_info.driver + self.iree_serializable = module.iree_serializable() + self.tflite_serializable = module.tflite_serializable() + self.function_name = function.__name__ + self.function_sourcefile = inspect.getsourcefile(function) + source, start_line = inspect.getsourcelines(function) + self.function_line_numbers = (start_line, start_line + len(source)) + self.function_source = "".join(source) + + self.calls = [] + else: + self.module_name = _load_dict["module_name"] + self.compiled_paths = _load_dict["compiled_paths"] + self.backend_name = _load_dict["backend_name"] + self.backend_id = _load_dict["backend_id"] + self.backend_driver = _load_dict["backend_driver"] + self.iree_serializable = _load_dict["iree_serializable"] + self.tflite_serializable = _load_dict["tflite_serializable"] + self.function_name = _load_dict["function_name"] + self.function_sourcefile = _load_dict["function_sourcefile"] + self.function_line_numbers = _load_dict["function_line_numbers"] + self.function_source = _load_dict["function_source"] + self.calls = _load_dict["calls"] + + def __str__(self): + header = ( + f"Trace of {self.module_name} compiled to '{self.backend_id}' " + f"on function '{self.function_name}':" + ) + # Give each call a number so it's easier to compare between multiple traces. + calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)] + calls = textwrap.indent("\n".join(calls), prefix=INDENT) + return f"{header}\n{calls}" + + def __iter__(self): + for call in self.calls: + yield call + + def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None: + """Saves a human-readable string representation of this trace to disk. + + Args: + trace_dir: str, path to the directory to save the trace in. + summarize: a bool controlling whether numpy should summarize the inputs + and outputs if they're large. Setting this to False is very slow for + large outputs. + """ + prior_printoptions = np.get_printoptions() + np.set_printoptions( + linewidth=NUMPY_LINEWIDTH, + threshold=None if summarize else sys.maxsize, + edgeitems=10, + ) # Can show more items since they won't clutter the logs. + + path = os.path.join(trace_dir, "log.txt") + with open(path, "w") as f: + f.write(str(self)) + f.write("\n") + + np.set_printoptions(**prior_printoptions) + + def serialize(self, trace_dir: str) -> None: + """Stores a serialized copy of this trace in trace_dir. + + It can be loaded via `Trace.load(trace_dir)`. + + Args: + trace_dir: str, path to the directory to serialize the trace to. + """ + + compiled_paths = None + if self.compiled_paths is not None: + # Convert to a dict to avoid the issues with serializing defaultdicts. + compiled_paths = dict(self.compiled_paths) + + # Python serialization. + metadata = { + "module_name": self.module_name, + "compiled_paths": compiled_paths, + "backend_name": self.backend_name, + "backend_id": self.backend_id, + "backend_driver": self.backend_driver, + "iree_serializable": self.iree_serializable, + "tflite_serializable": self.tflite_serializable, + "function_name": self.function_name, + "function_sourcefile": self.function_sourcefile, + "function_line_numbers": self.function_line_numbers, + "function_source": self.function_source, + } + with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f: + pickle.dump(metadata, f) + + width = _zfill_width(len(self.calls)) + for i, call in enumerate(self.calls): + call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}") + call.serialize(call_dir) + + # C++ benchmark serialization. + if self.iree_serializable or self.tflite_serializable: + entry_function = self.calls[0].method + compiled_path = self.compiled_paths[entry_function] + + if self.iree_serializable: + serialized_inputs = self.calls[0].serialized_inputs + flagfile = [ + f"--module={compiled_path}", + f"--device={self.backend_driver}", + f"--function={entry_function}", + ] + [f'--input="{input}"' for input in serialized_inputs] + with open(os.path.join(trace_dir, "flagfile"), "w") as f: + f.writelines(line + "\n" for line in flagfile) + else: + with open(os.path.join(trace_dir, "graph_path"), "w") as f: + f.writelines(compiled_path + "\n") + + @staticmethod + def load(trace_dir: str) -> Trace: + """Loads and returns a trace serialized with Trace.serialize. + + Args: + trace_dir: str, path to the directory of the serialized trace. + + Returns: + A Trace deserialized from trace_dir. + """ + with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f: + load_dict = pickle.load(f) + call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*"))) + calls = [ModuleCall.load(call_dir) for call_dir in call_dirs] + load_dict["calls"] = calls + return Trace(module=None, function=None, _load_dict=load_dict) class TracedModule: - - def __init__(self, module: module_utils.CompiledModule, trace: Trace): - """Wraps a CompiledModule so that all inputs and outputs are traced. - - The TracedModule returned will have an API almost identical to that of the - passed CompiledModule. The only changes is that if the keywords `rtol` or - `atol` are passed to one of the CompiledModule's methods, then they will be - used to set the tolerance for comparing that call to the same call in - another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)` - would be the same as calling `module.add(a, b)`. - - Args: - module: the CompiledModule to trace. - trace: the Trace to record calls to this module with. - """ - self._module = module - self._trace = trace - - def _trace_call(self, method: module_utils._FunctionWrapper, - method_name: str): - """Decorates a CompiledModule method to capture its inputs and outputs.""" - - def call(*args, **kwargs): - # Pop manually specified tolerances from the kwargs (if any). - tolerances = {} - tolerances["rtol"] = kwargs.pop("rtol", None) - tolerances["atol"] = kwargs.pop("atol", None) - # Only pass these to ModuleCall if they were specified by the user. - tolerances = {k: v for k, v in tolerances.items() if v is not None} - - # Ensure the inputs are numpy inputs. - args = tf_utils.convert_to_numpy(args) - kwargs = tf_utils.convert_to_numpy(kwargs) - - # Run the method and record the details of the call. - outputs = method(*args, **kwargs) - serialized_inputs, serialized_outputs = method.get_serialized_values() - self._trace.calls.append( - ModuleCall(method_name, args, outputs, serialized_inputs, - serialized_outputs, **tolerances)) - return outputs - - return call - - def __getattr__(self, attr): - # Try to resolve it as an attr on self._module. - if not hasattr(self._module, attr): - raise AttributeError(f"The compiled module does not have attr '{attr}'") - module_attr = getattr(self._module, attr) - if not hasattr(module_attr, "__call__"): - # e.g. traced_module.backend - return module_attr - else: - # e.g. traced_module.simple_mul(a, b) - return self._trace_call(module_attr, method_name=attr) - - -def compare_traces(ref_trace: Trace, - tar_trace: Trace) -> Tuple[bool, Sequence[str]]: - traces_match = True - error_messages = [] - - # Check that all method invocations match. - ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace] - tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace] - if ref_methods != tar_methods: - # Raise a ValueError instead of returning False since this is an - # unexpected error. - raise ValueError( - "The reference and target traces have different call structures:\n" - f"Reference: {ref_methods}\nTarget: {tar_methods}") - - for ref_call, tar_call in zip(ref_trace, tar_trace): - logging.info("Comparing calls to '%s'", ref_call.method) - rtol, atol = ref_call.get_tolerances() - - inputs_match, error_message = tf_utils.check_same(ref_call.inputs, - tar_call.inputs, rtol, - atol) - if not inputs_match: - error_messages.append(error_message) - logging.error("Inputs did not match.") - outputs_match, error_message = tf_utils.check_same(ref_call.outputs, - tar_call.outputs, rtol, - atol) - if not outputs_match: - error_messages.append(error_message) - logging.error("Outputs did not match.") - calls_match = inputs_match and outputs_match - - if not calls_match: - logging.error("Comparision between '%s' and '%s' failed on method '%s'", - ref_trace.backend_id, tar_trace.backend_id, ref_call.method) - logging.error("Reference call '%s':\n%s", ref_trace.backend_id, ref_call) - logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call) - - traces_match = traces_match and calls_match - return traces_match, error_messages + def __init__(self, module: module_utils.CompiledModule, trace: Trace): + """Wraps a CompiledModule so that all inputs and outputs are traced. + + The TracedModule returned will have an API almost identical to that of the + passed CompiledModule. The only changes is that if the keywords `rtol` or + `atol` are passed to one of the CompiledModule's methods, then they will be + used to set the tolerance for comparing that call to the same call in + another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)` + would be the same as calling `module.add(a, b)`. + + Args: + module: the CompiledModule to trace. + trace: the Trace to record calls to this module with. + """ + self._module = module + self._trace = trace + + def _trace_call(self, method: module_utils._FunctionWrapper, method_name: str): + """Decorates a CompiledModule method to capture its inputs and outputs.""" + + def call(*args, **kwargs): + # Pop manually specified tolerances from the kwargs (if any). + tolerances = {} + tolerances["rtol"] = kwargs.pop("rtol", None) + tolerances["atol"] = kwargs.pop("atol", None) + # Only pass these to ModuleCall if they were specified by the user. + tolerances = {k: v for k, v in tolerances.items() if v is not None} + + # Ensure the inputs are numpy inputs. + args = tf_utils.convert_to_numpy(args) + kwargs = tf_utils.convert_to_numpy(kwargs) + + # Run the method and record the details of the call. + outputs = method(*args, **kwargs) + serialized_inputs, serialized_outputs = method.get_serialized_values() + self._trace.calls.append( + ModuleCall( + method_name, + args, + outputs, + serialized_inputs, + serialized_outputs, + **tolerances, + ) + ) + return outputs + + return call + + def __getattr__(self, attr): + # Try to resolve it as an attr on self._module. + if not hasattr(self._module, attr): + raise AttributeError(f"The compiled module does not have attr '{attr}'") + module_attr = getattr(self._module, attr) + if not hasattr(module_attr, "__call__"): + # e.g. traced_module.backend + return module_attr + else: + # e.g. traced_module.simple_mul(a, b) + return self._trace_call(module_attr, method_name=attr) + + +def compare_traces(ref_trace: Trace, tar_trace: Trace) -> Tuple[bool, Sequence[str]]: + traces_match = True + error_messages = [] + + # Check that all method invocations match. + ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace] + tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace] + if ref_methods != tar_methods: + # Raise a ValueError instead of returning False since this is an + # unexpected error. + raise ValueError( + "The reference and target traces have different call structures:\n" + f"Reference: {ref_methods}\nTarget: {tar_methods}" + ) + + for ref_call, tar_call in zip(ref_trace, tar_trace): + logging.info("Comparing calls to '%s'", ref_call.method) + rtol, atol = ref_call.get_tolerances() + + inputs_match, error_message = tf_utils.check_same( + ref_call.inputs, tar_call.inputs, rtol, atol + ) + if not inputs_match: + error_messages.append(error_message) + logging.error("Inputs did not match.") + outputs_match, error_message = tf_utils.check_same( + ref_call.outputs, tar_call.outputs, rtol, atol + ) + if not outputs_match: + error_messages.append(error_message) + logging.error("Outputs did not match.") + calls_match = inputs_match and outputs_match + + if not calls_match: + logging.error( + "Comparision between '%s' and '%s' failed on method '%s'", + ref_trace.backend_id, + tar_trace.backend_id, + ref_call.method, + ) + logging.error("Reference call '%s':\n%s", ref_trace.backend_id, ref_call) + logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call) + + traces_match = traces_match and calls_match + return traces_match, error_messages diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils_test.py b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils_test.py index 65f1fdf7db2f..fb10201df1b2 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils_test.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tf/support/trace_utils_test.py @@ -16,132 +16,134 @@ class StatefulCountingModule(tf.Module): + def __init__(self): + self.count = tf.Variable([0.0]) - def __init__(self): - self.count = tf.Variable([0.]) + @tf.function(input_signature=[]) + def increment(self): + self.count.assign_add(tf.constant([1.0])) - @tf.function(input_signature=[]) - def increment(self): - self.count.assign_add(tf.constant([1.])) + @tf.function(input_signature=[]) + def get_count(self): + return self.count - @tf.function(input_signature=[]) - def get_count(self): - return self.count + @tf.function(input_signature=[tf.TensorSpec([1])]) + def increment_by(self, value): + self.count.assign_add(value) - @tf.function(input_signature=[tf.TensorSpec([1])]) - def increment_by(self, value): - self.count.assign_add(value) + @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])]) + def increment_by_max(self, a, b): + result = tf.maximum(a, b) + self.count.assign_add(result) + return result - @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])]) - def increment_by_max(self, a, b): - result = tf.maximum(a, b) - self.count.assign_add(result) - return result - - @tf.function(input_signature=[]) - def decrement(self): - self.count.assign_sub(tf.constant([1.])) + @tf.function(input_signature=[]) + def decrement(self): + self.count.assign_sub(tf.constant([1.0])) class TestUtilsTests(tf.test.TestCase, parameterized.TestCase): - - def test_trace_inputs_and_outputs(self): - - def trace_function(module): - # No inputs or outputs - module.increment() - # Only inputs - module.increment_by(np.array([81.], dtype=np.float32)) - # Only outputs - module.get_count() - - module = module_utils.TfCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('tf')) - trace = trace_utils.Trace(module, trace_function) - trace_function(trace_utils.TracedModule(module, trace)) - - self.assertIsInstance(trace.calls[0].inputs, tuple) - self.assertEmpty(trace.calls[0].inputs) - self.assertIsInstance(trace.calls[0].outputs, tuple) - self.assertEmpty(trace.calls[0].outputs) - - self.assertAllClose(trace.calls[1].inputs[0], [81.]) - self.assertAllClose(trace.calls[2].outputs[0], [82.]) - - def test_nonmatching_methods(self): - - def tf_function(module): - module.increment() - module.increment() - - def vmvx_function(module): - module.increment() - module.decrement() - - tf_module = module_utils.TfCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('tf')) - tf_trace = trace_utils.Trace(tf_module, tf_function) - tf_function(trace_utils.TracedModule(tf_module, tf_trace)) - - vmvx_module = module_utils.IreeCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('iree_vmvx')) - vmvx_trace = trace_utils.Trace(vmvx_module, vmvx_function) - vmvx_function(trace_utils.TracedModule(vmvx_module, vmvx_trace)) - - with self.assertRaises(ValueError): - trace_utils.compare_traces(tf_trace, vmvx_trace) - - def test_nonmatching_inputs(self): - - def tf_function(module): - module.increment_by(np.array([42.], dtype=np.float32)) - - def vmvx_function(module): - module.increment_by(np.array([22.], dtype=np.float32)) - - tf_module = module_utils.TfCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('tf')) - tf_trace = trace_utils.Trace(tf_module, tf_function) - tf_function(trace_utils.TracedModule(tf_module, tf_trace)) - - vmvx_module = module_utils.IreeCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('iree_vmvx')) - vmvx_trace = trace_utils.Trace(vmvx_module, vmvx_function) - vmvx_function(trace_utils.TracedModule(vmvx_module, vmvx_trace)) - - same, error_messages = trace_utils.compare_traces(tf_trace, vmvx_trace) - self.assertFalse(same) - - def test_trace_serialize_and_load(self): - - def trace_function(module): - module.increment() - module.increment_by(np.array([81.], dtype=np.float32)) - module.increment_by_max(np.array([81], dtype=np.float32), - np.array([92], dtype=np.float32)) - module.get_count() - - module = module_utils.IreeCompiledModule.create_from_class( - StatefulCountingModule, module_utils.BackendInfo('iree_vmvx')) - trace = trace_utils.Trace(module, trace_function) - trace_function(trace_utils.TracedModule(module, trace)) - - with tempfile.TemporaryDirectory() as artifacts_dir: - trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace) - trace.serialize(trace_function_dir) - self.assertTrue( - os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl'))) - loaded_trace = trace_utils.Trace.load(trace_function_dir) - - # Check all calls match. - self.assertTrue(trace_utils.compare_traces(trace, loaded_trace)) - - # Check all other metadata match. - self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys()) - for key in trace.__dict__.keys(): - if key != 'calls': - self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key]) - - -if __name__ == '__main__': - tf.test.main() + def test_trace_inputs_and_outputs(self): + def trace_function(module): + # No inputs or outputs + module.increment() + # Only inputs + module.increment_by(np.array([81.0], dtype=np.float32)) + # Only outputs + module.get_count() + + module = module_utils.TfCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("tf") + ) + trace = trace_utils.Trace(module, trace_function) + trace_function(trace_utils.TracedModule(module, trace)) + + self.assertIsInstance(trace.calls[0].inputs, tuple) + self.assertEmpty(trace.calls[0].inputs) + self.assertIsInstance(trace.calls[0].outputs, tuple) + self.assertEmpty(trace.calls[0].outputs) + + self.assertAllClose(trace.calls[1].inputs[0], [81.0]) + self.assertAllClose(trace.calls[2].outputs[0], [82.0]) + + def test_nonmatching_methods(self): + def tf_function(module): + module.increment() + module.increment() + + def vmvx_function(module): + module.increment() + module.decrement() + + tf_module = module_utils.TfCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("tf") + ) + tf_trace = trace_utils.Trace(tf_module, tf_function) + tf_function(trace_utils.TracedModule(tf_module, tf_trace)) + + vmvx_module = module_utils.IreeCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("iree_vmvx") + ) + vmvx_trace = trace_utils.Trace(vmvx_module, vmvx_function) + vmvx_function(trace_utils.TracedModule(vmvx_module, vmvx_trace)) + + with self.assertRaises(ValueError): + trace_utils.compare_traces(tf_trace, vmvx_trace) + + def test_nonmatching_inputs(self): + def tf_function(module): + module.increment_by(np.array([42.0], dtype=np.float32)) + + def vmvx_function(module): + module.increment_by(np.array([22.0], dtype=np.float32)) + + tf_module = module_utils.TfCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("tf") + ) + tf_trace = trace_utils.Trace(tf_module, tf_function) + tf_function(trace_utils.TracedModule(tf_module, tf_trace)) + + vmvx_module = module_utils.IreeCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("iree_vmvx") + ) + vmvx_trace = trace_utils.Trace(vmvx_module, vmvx_function) + vmvx_function(trace_utils.TracedModule(vmvx_module, vmvx_trace)) + + same, error_messages = trace_utils.compare_traces(tf_trace, vmvx_trace) + self.assertFalse(same) + + def test_trace_serialize_and_load(self): + def trace_function(module): + module.increment() + module.increment_by(np.array([81.0], dtype=np.float32)) + module.increment_by_max( + np.array([81], dtype=np.float32), np.array([92], dtype=np.float32) + ) + module.get_count() + + module = module_utils.IreeCompiledModule.create_from_class( + StatefulCountingModule, module_utils.BackendInfo("iree_vmvx") + ) + trace = trace_utils.Trace(module, trace_function) + trace_function(trace_utils.TracedModule(module, trace)) + + with tempfile.TemporaryDirectory() as artifacts_dir: + trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace) + trace.serialize(trace_function_dir) + self.assertTrue( + os.path.exists(os.path.join(trace_function_dir, "metadata.pkl")) + ) + loaded_trace = trace_utils.Trace.load(trace_function_dir) + + # Check all calls match. + self.assertTrue(trace_utils.compare_traces(trace, loaded_trace)) + + # Check all other metadata match. + self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys()) + for key in trace.__dict__.keys(): + if key != "calls": + self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/__init__.py b/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/__init__.py index 6154bc9a8c1c..11c0a0680fed 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/__init__.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/__init__.py @@ -14,8 +14,8 @@ def get_tool(exe_name: str) -> Optional[str]: - if platform.system() == "Windows": - exe_name = exe_name + ".exe" - this_path = os.path.dirname(__file__) - tool_path = os.path.join(this_path, exe_name) - return tool_path + if platform.system() == "Windows": + exe_name = exe_name + ".exe" + this_path = os.path.dirname(__file__) + tool_path = os.path.join(this_path, exe_name) + return tool_path diff --git a/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/scripts/iree_import_tf/__main__.py b/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/scripts/iree_import_tf/__main__.py index 9efebda2bfb7..cd44c1e30204 100644 --- a/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/scripts/iree_import_tf/__main__.py +++ b/integrations/tensorflow/python_projects/iree_tf/iree/tools/tf/scripts/iree_import_tf/__main__.py @@ -12,98 +12,104 @@ def main(): - parser = argparse.ArgumentParser() - parser.add_argument('saved_model_path', - help='Path to the saved model directory to import.') - parser.add_argument('-o', - '--output_path', - dest='output_path', - required=True, - help='Path to the mlir file name to output.') - parser.add_argument( - '--tf-savedmodel-exported-names', - required=False, - help= - "List of exported names for cases that the model has ambiguous exports") - parser.add_argument( - "--tf-import-type", - default="savedmodel_v2", - help= - "Import type for legacy saved models ('savedmodel_v2' or 'savedmodel_v1')" - ) - parser.add_argument( - "--tf-savedmodel-tags", - default="serve", - help="Tags used to indicate which MetaGraphDef to import, separated by " - "','") + parser = argparse.ArgumentParser() + parser.add_argument( + "saved_model_path", help="Path to the saved model directory to import." + ) + parser.add_argument( + "-o", + "--output_path", + dest="output_path", + required=True, + help="Path to the mlir file name to output.", + ) + parser.add_argument( + "--tf-savedmodel-exported-names", + required=False, + help="List of exported names for cases that the model has ambiguous exports", + ) + parser.add_argument( + "--tf-import-type", + default="savedmodel_v2", + help="Import type for legacy saved models ('savedmodel_v2' or 'savedmodel_v1')", + ) + parser.add_argument( + "--tf-savedmodel-tags", + default="serve", + help="Tags used to indicate which MetaGraphDef to import, separated by " "','", + ) - # Deprecated and unused. Kept in place so callers of the old tool don't break - # when using the new tool. - parser.add_argument('--output-format', - dest='_', - required=False, - help=argparse.SUPPRESS) - args = parser.parse_args() + # Deprecated and unused. Kept in place so callers of the old tool don't break + # when using the new tool. + parser.add_argument( + "--output-format", dest="_", required=False, help=argparse.SUPPRESS + ) + args = parser.parse_args() - saved_model_dir = args.saved_model_path - exported_names = args.tf_savedmodel_exported_names - output_path = args.output_path - import_type = args.tf_import_type - tags = args.tf_savedmodel_tags - import_saved_model(output_path=output_path, - saved_model_dir=saved_model_dir, - exported_names=exported_names, - import_type=import_type, - tags=tags) - - -def import_saved_model(*, output_path, saved_model_dir, exported_names, - import_type, tags): - # From here there be dragons. - from tensorflow.python import pywrap_mlir - if import_type == "savedmodel_v2": - result = pywrap_mlir.experimental_convert_saved_model_to_mlir( - saved_model_dir, exported_names=exported_names, show_debug_info=False) - elif import_type == "savedmodel_v1": - # You saw it here, folks: The TF team just adds random positional params - # without explanation or default. So we detect and default them on our - # own. Because this is normal and fine. - sig = inspect.signature( - pywrap_mlir.experimental_convert_saved_model_v1_to_mlir) - dumb_extra_kwargs = {} - if "include_variables_in_initializers" in sig.parameters: - dumb_extra_kwargs["include_variables_in_initializers"] = False - if "upgrade_legacy" in sig.parameters: - dumb_extra_kwargs["upgrade_legacy"] = False - if "lift_variables" in sig.parameters: - dumb_extra_kwargs["lift_variables"] = True - result = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir( - saved_model_dir, + saved_model_dir = args.saved_model_path + exported_names = args.tf_savedmodel_exported_names + output_path = args.output_path + import_type = args.tf_import_type + tags = args.tf_savedmodel_tags + import_saved_model( + output_path=output_path, + saved_model_dir=saved_model_dir, exported_names=exported_names, + import_type=import_type, tags=tags, - show_debug_info=False, - **dumb_extra_kwargs) - else: - raise ValueError(f"Unsupported import type: '{import_type}'") - # The import to MLIR produces public functions like __inference__{name}_2222 - # but the conversion pipeline requires a single public @main function. - # Not sure how this was allowed to happen, but regex to the rescue. - # This is fine and normal, and totally to be expected. :( - result = re.sub(r"func @__inference_(.+)_[0-9]+\(", r"func @\1(", result) - pipeline = ["tf-lower-to-mlprogram-and-hlo"] - result = pywrap_mlir.experimental_run_pass_pipeline(result, - ",".join(pipeline), - show_debug_info=False) + ) + + +def import_saved_model( + *, output_path, saved_model_dir, exported_names, import_type, tags +): + # From here there be dragons. + from tensorflow.python import pywrap_mlir + + if import_type == "savedmodel_v2": + result = pywrap_mlir.experimental_convert_saved_model_to_mlir( + saved_model_dir, exported_names=exported_names, show_debug_info=False + ) + elif import_type == "savedmodel_v1": + # You saw it here, folks: The TF team just adds random positional params + # without explanation or default. So we detect and default them on our + # own. Because this is normal and fine. + sig = inspect.signature(pywrap_mlir.experimental_convert_saved_model_v1_to_mlir) + dumb_extra_kwargs = {} + if "include_variables_in_initializers" in sig.parameters: + dumb_extra_kwargs["include_variables_in_initializers"] = False + if "upgrade_legacy" in sig.parameters: + dumb_extra_kwargs["upgrade_legacy"] = False + if "lift_variables" in sig.parameters: + dumb_extra_kwargs["lift_variables"] = True + result = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir( + saved_model_dir, + exported_names=exported_names, + tags=tags, + show_debug_info=False, + **dumb_extra_kwargs, + ) + else: + raise ValueError(f"Unsupported import type: '{import_type}'") + # The import to MLIR produces public functions like __inference__{name}_2222 + # but the conversion pipeline requires a single public @main function. + # Not sure how this was allowed to happen, but regex to the rescue. + # This is fine and normal, and totally to be expected. :( + result = re.sub(r"func @__inference_(.+)_[0-9]+\(", r"func @\1(", result) + pipeline = ["tf-lower-to-mlprogram-and-hlo"] + result = pywrap_mlir.experimental_run_pass_pipeline( + result, ",".join(pipeline), show_debug_info=False + ) - # TODO: The experimental_write_bytecode function does not register the - # stablehlo dialect. Once fixed, remove this bypass. - WRITE_BYTECODE = False - if WRITE_BYTECODE: - result = pywrap_mlir.experimental_write_bytecode(output_path, result) - else: - with open(output_path, "wt") as f: - f.write(result) + # TODO: The experimental_write_bytecode function does not register the + # stablehlo dialect. Once fixed, remove this bypass. + WRITE_BYTECODE = False + if WRITE_BYTECODE: + result = pywrap_mlir.experimental_write_bytecode(output_path, result) + else: + with open(output_path, "wt") as f: + f.write(result) if __name__ == "__main__": - main() + main() diff --git a/integrations/tensorflow/python_projects/iree_tf/setup.py b/integrations/tensorflow/python_projects/iree_tf/setup.py index d59f430a0249..da94a7d1cf5c 100644 --- a/integrations/tensorflow/python_projects/iree_tf/setup.py +++ b/integrations/tensorflow/python_projects/iree_tf/setup.py @@ -15,9 +15,9 @@ import platform from setuptools import setup, find_namespace_packages -README = r''' +README = r""" TensorFlow TF Compiler Tools -''' +""" exe_suffix = ".exe" if platform.system() == "Windows" else "" @@ -28,15 +28,15 @@ def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) + with open(VERSION_INFO_FILE, "rt") as f: + return json.load(f) try: - version_info = load_version_info() + version_info = load_version_info() except FileNotFoundError: - print("version_info.json not found. Using defaults") - version_info = {} + print("version_info.json not found. Using defaults") + version_info = {} PACKAGE_SUFFIX = version_info.get("package-suffix") or "" PACKAGE_VERSION = version_info.get("package-version") or "0.1dev1" @@ -60,13 +60,17 @@ def load_version_info(): "Programming Language :: Python :: 3.10", ], python_requires=">=3.8", - packages=find_namespace_packages(include=[ - "iree.tools.tf", - "iree.tools.tf.*", - "iree.tf.support", - ]), + packages=find_namespace_packages( + include=[ + "iree.tools.tf", + "iree.tools.tf.*", + "iree.tf.support", + ] + ), package_data={ - "iree.tools.tf": [f"iree-import-tf{exe_suffix}",], + "iree.tools.tf": [ + f"iree-import-tf{exe_suffix}", + ], }, entry_points={ "console_scripts": [ diff --git a/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/__init__.py b/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/__init__.py index 1d8001b318ea..4916a877c1aa 100644 --- a/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/__init__.py +++ b/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/__init__.py @@ -14,8 +14,8 @@ def get_tool(exe_name: str) -> Optional[str]: - if platform.system() == "Windows": - exe_name = exe_name + ".exe" - this_path = os.path.dirname(__file__) - tool_path = os.path.join(this_path, exe_name) - return tool_path + if platform.system() == "Windows": + exe_name = exe_name + ".exe" + this_path = os.path.dirname(__file__) + tool_path = os.path.join(this_path, exe_name) + return tool_path diff --git a/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/scripts/iree_import_tflite/__main__.py b/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/scripts/iree_import_tflite/__main__.py index 7bf64b6824b2..d6b5cb7e6457 100644 --- a/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/scripts/iree_import_tflite/__main__.py +++ b/integrations/tensorflow/python_projects/iree_tflite/iree/tools/tflite/scripts/iree_import_tflite/__main__.py @@ -15,49 +15,48 @@ def main(): - parser = argparse.ArgumentParser() - parser.add_argument('flatbuffer', help='') - parser.add_argument( - '-o', - '--output-path', - dest='output_path', - required=True, - help='Path to the mlirbc file name to output.', - ) - parser.add_argument( - '--input-array', - dest='input_arrays', - action='append', - help='Input tensor, if different from the default inputs', - ) - parser.add_argument( - '--output-array', - dest='output_arrays', - action='append', - help='Output tensor, if different from the default inputs', - ) + parser = argparse.ArgumentParser() + parser.add_argument("flatbuffer", help="") + parser.add_argument( + "-o", + "--output-path", + dest="output_path", + required=True, + help="Path to the mlirbc file name to output.", + ) + parser.add_argument( + "--input-array", + dest="input_arrays", + action="append", + help="Input tensor, if different from the default inputs", + ) + parser.add_argument( + "--output-array", + dest="output_arrays", + action="append", + help="Output tensor, if different from the default inputs", + ) - # Deprecated and unused. Kept in place so callers of the old tool don't break - # when using the new tool. - parser.add_argument( - '--output-format', - dest='output_format', - required=False, - default='mlir-bytecode', - help=argparse.SUPPRESS, - ) - args = parser.parse_args() + # Deprecated and unused. Kept in place so callers of the old tool don't break + # when using the new tool. + parser.add_argument( + "--output-format", + dest="output_format", + required=False, + default="mlir-bytecode", + help=argparse.SUPPRESS, + ) + args = parser.parse_args() - if args.output_format != 'mlir-bytecode': - logging.warning( - 'output-format option is deprecated, emitting MLIR bytecode') + if args.output_format != "mlir-bytecode": + logging.warning("output-format option is deprecated, emitting MLIR bytecode") - tflite_to_tosa( - flatbuffer=args.flatbuffer, - bytecode=args.output_path, - ordered_input_arrays=args.input_arrays, - ordered_output_arrays=args.output_arrays, - ) + tflite_to_tosa( + flatbuffer=args.flatbuffer, + bytecode=args.output_path, + ordered_input_arrays=args.input_arrays, + ordered_output_arrays=args.output_arrays, + ) def tflite_to_tosa( @@ -67,14 +66,14 @@ def tflite_to_tosa( ordered_input_arrays=None, ordered_output_arrays=None, ): - experimental_tflite_to_tosa_bytecode( - flatbuffer, - bytecode, - use_external_constant, - ordered_input_arrays, - ordered_output_arrays, - ) + experimental_tflite_to_tosa_bytecode( + flatbuffer, + bytecode, + use_external_constant, + ordered_input_arrays, + ordered_output_arrays, + ) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/integrations/tensorflow/python_projects/iree_tflite/setup.py b/integrations/tensorflow/python_projects/iree_tflite/setup.py index aa106108f0ef..6c8823c6038d 100644 --- a/integrations/tensorflow/python_projects/iree_tflite/setup.py +++ b/integrations/tensorflow/python_projects/iree_tflite/setup.py @@ -15,9 +15,9 @@ import platform from setuptools import setup, find_namespace_packages -README = r''' +README = r""" TensorFlow TFLite Compiler Tools -''' +""" exe_suffix = ".exe" if platform.system() == "Windows" else "" @@ -28,15 +28,15 @@ def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) + with open(VERSION_INFO_FILE, "rt") as f: + return json.load(f) try: - version_info = load_version_info() + version_info = load_version_info() except FileNotFoundError: - print("version_info.json not found. Using defaults") - version_info = {} + print("version_info.json not found. Using defaults") + version_info = {} PACKAGE_SUFFIX = version_info.get("package-suffix") or "" PACKAGE_VERSION = version_info.get("package-version") or "0.1dev1" @@ -60,12 +60,16 @@ def load_version_info(): "Programming Language :: Python :: 3.10", ], python_requires=">=3.8", - packages=find_namespace_packages(include=[ - "iree.tools.tflite", - "iree.tools.tflite.*", - ]), + packages=find_namespace_packages( + include=[ + "iree.tools.tflite", + "iree.tools.tflite.*", + ] + ), package_data={ - "iree.tools.tflite": [f"iree-import-tflite{exe_suffix}",], + "iree.tools.tflite": [ + f"iree-import-tflite{exe_suffix}", + ], }, entry_points={ "console_scripts": [ diff --git a/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py b/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py index 3da89699e637..29fe7ead702a 100755 --- a/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py +++ b/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py @@ -15,54 +15,57 @@ from typing import Sequence # The symbols to show in the table if the operation is supported or not. -SUCCESS_ELEMENT = 'PASS ✓' -FAILURE_ELEMENT = 'FAIL ✗' +SUCCESS_ELEMENT = "PASS ✓" +FAILURE_ELEMENT = "FAIL ✗" def main(): - dir = os.path.dirname(__file__) - readme_file_path = os.path.join(dir, 'README.md') - old_lines = read_file(readme_file_path) - - files = list(Path(dir).glob('*.run')) - num_files = len(files) - - models = [[0 for x in range(2)] for y in range(num_files)] - print(f"Processing {num_files} files") - - for i in range(num_files): - name = os.path.basename(files[i].name).replace('.run', '') - models[i][0] = name.ljust(20) - - with open(files[i], 'r') as file: - models[i][1] = FAILURE_ELEMENT if 'XFAIL' in file.read( - ) else SUCCESS_ELEMENT - - with open(readme_file_path, 'w') as tflite_model_documentation: - tflite_model_documentation.write('# TFLite integration tests status\n\n' \ - 'This dashboard shows the models that are currently being tested on IREE\'s\n' \ - 'presubmits. If any tests are added or changed, please run\n' \ - 'update_tflite_model_documentation.py to update this table.\n\n' \ - '| Model | Status |\n' \ - '| ------------------ | ------------------ |\n') - tflite_model_documentation.write(create_markdown_table(models)) - - new_lines = read_file(readme_file_path) - if new_lines == old_lines: - print(f"{readme_file_path} required no update") - else: - print(f"Updated {readme_file_path} with latest test status") + dir = os.path.dirname(__file__) + readme_file_path = os.path.join(dir, "README.md") + old_lines = read_file(readme_file_path) + + files = list(Path(dir).glob("*.run")) + num_files = len(files) + + models = [[0 for x in range(2)] for y in range(num_files)] + print(f"Processing {num_files} files") + + for i in range(num_files): + name = os.path.basename(files[i].name).replace(".run", "") + models[i][0] = name.ljust(20) + + with open(files[i], "r") as file: + models[i][1] = ( + FAILURE_ELEMENT if "XFAIL" in file.read() else SUCCESS_ELEMENT + ) + + with open(readme_file_path, "w") as tflite_model_documentation: + tflite_model_documentation.write( + "# TFLite integration tests status\n\n" + "This dashboard shows the models that are currently being tested on IREE's\n" + "presubmits. If any tests are added or changed, please run\n" + "update_tflite_model_documentation.py to update this table.\n\n" + "| Model | Status |\n" + "| ------------------ | ------------------ |\n" + ) + tflite_model_documentation.write(create_markdown_table(models)) + + new_lines = read_file(readme_file_path) + if new_lines == old_lines: + print(f"{readme_file_path} required no update") + else: + print(f"Updated {readme_file_path} with latest test status") def read_file(file_path): - with open(file_path, 'r') as file: - return file.readlines() + with open(file_path, "r") as file: + return file.readlines() def create_markdown_table(rows: Sequence[Sequence[str]]): - """Converts a 2D array to a Markdown table.""" - return '\n'.join([' | '.join(row) for row in rows]) + """Converts a 2D array to a Markdown table.""" + return "\n".join([" | ".join(row) for row in rows]) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/integrations/tensorflow/test/lit.cfg.py b/integrations/tensorflow/test/lit.cfg.py index 5b5b62e1c934..9da1bb8dcd4b 100644 --- a/integrations/tensorflow/test/lit.cfg.py +++ b/integrations/tensorflow/test/lit.cfg.py @@ -21,9 +21,11 @@ llvm_config.with_system_environment("VK_ICD_FILENAMES") # Put execution artifacts in the temp dir. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) # name: The name of this test suite. config.name = "TENSORFLOW_TESTS" @@ -36,7 +38,7 @@ # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) -#config.use_default_substitutions() +# config.use_default_substitutions() config.excludes = [ "lit.cfg.py", "lit.site.cfg.py", @@ -46,20 +48,24 @@ "imagenet_test_data.py", ] -config.substitutions.extend([ - ("%PYTHON", os.getenv("PYTHON", sys.executable)), -]) +config.substitutions.extend( + [ + ("%PYTHON", os.getenv("PYTHON", sys.executable)), + ] +) # Add our local projects to the PYTHONPATH -python_projects_dir = os.path.join(os.path.dirname(__file__), "..", - "python_projects") +python_projects_dir = os.path.join(os.path.dirname(__file__), "..", "python_projects") test_src_dir = os.path.join(os.path.dirname(__file__), "python") -llvm_config.with_environment("PYTHONPATH", [ - test_src_dir, - os.path.join(python_projects_dir, "iree_tf"), - os.path.join(python_projects_dir, "iree_tflite"), -], - append_path=True) +llvm_config.with_environment( + "PYTHONPATH", + [ + test_src_dir, + os.path.join(python_projects_dir, "iree_tf"), + os.path.join(python_projects_dir, "iree_tflite"), + ], + append_path=True, +) # Enable features based on -D FEATURES=hugetest,vulkan # syntax. @@ -67,9 +73,9 @@ disable_features_param = lit_config.params.get("DISABLE_FEATURES") disable_features = [] if disable_features_param: - disable_features = disable_features_param.split(",") + disable_features = disable_features_param.split(",") if "llvmcpu" not in disable_features: - config.available_features.add("llvmcpu") + config.available_features.add("llvmcpu") features_param = lit_config.params.get("FEATURES") if features_param: - config.available_features.update(features_param.split(",")) + config.available_features.update(features_param.split(",")) diff --git a/integrations/tensorflow/test/python/generate_runner.py b/integrations/tensorflow/test/python/generate_runner.py index 6442c827f168..6e6bedfaba29 100644 --- a/integrations/tensorflow/test/python/generate_runner.py +++ b/integrations/tensorflow/test/python/generate_runner.py @@ -21,73 +21,70 @@ def main(args): - variant = args[0] - flags = args[1] - src_file_specs = args[2:] - src_files = [ - transform_src_file_spec_to_src_file(spec) for spec in src_file_specs - ] - module_names = [transform_src_file_to_module(f) for f in src_files] - run_files = [ - transform_src_file_spec_to_run_file(spec, variant) - for spec in src_file_specs - ] - for module, run_file in zip(module_names, run_files): - if os.path.exists(run_file): - print(f"SKIPPING (exists): {run_file}") - continue - print(f"CREATE RUN FILE: {module} -> {run_file}") - os.makedirs(os.path.dirname(run_file), exist_ok=True) - with open(run_file, "wt") as f: - print(f"# REQUIRES: {variant}", file=f) - print(f"# RUN: %PYTHON -m {module} {flags}", file=f) + variant = args[0] + flags = args[1] + src_file_specs = args[2:] + src_files = [transform_src_file_spec_to_src_file(spec) for spec in src_file_specs] + module_names = [transform_src_file_to_module(f) for f in src_files] + run_files = [ + transform_src_file_spec_to_run_file(spec, variant) for spec in src_file_specs + ] + for module, run_file in zip(module_names, run_files): + if os.path.exists(run_file): + print(f"SKIPPING (exists): {run_file}") + continue + print(f"CREATE RUN FILE: {module} -> {run_file}") + os.makedirs(os.path.dirname(run_file), exist_ok=True) + with open(run_file, "wt") as f: + print(f"# REQUIRES: {variant}", file=f) + print(f"# RUN: %PYTHON -m {module} {flags}", file=f) def transform_src_file_spec_to_src_file(spec: str): - try: - colon_pos = spec.index(":") - except ValueError: - return spec - return spec[0:colon_pos] + try: + colon_pos = spec.index(":") + except ValueError: + return spec + return spec[0:colon_pos] def transform_src_file_to_module(file_name): - module_name = file_name.replace("/", ".") - if (module_name.endswith(".py")): - module_name = module_name[0:-3] - return module_name + module_name = file_name.replace("/", ".") + if module_name.endswith(".py"): + module_name = module_name[0:-3] + return module_name def transform_src_file_spec_to_run_file(spec, variant): - # Transform path:alias, defaulting to the basename if the alias is not - # specified. - file_path = spec - file_name = os.path.basename(file_path) - colon_pos = -1 - try: - colon_pos = spec.index(":") - except ValueError: - pass - if colon_pos > -1: - # Explicit alias. - file_path = spec[0:colon_pos] - file_name = spec[colon_pos + 1:] - print(f"FILE PATH = {file_path}") - else: - # Auto detect the alias from the basename. + # Transform path:alias, defaulting to the basename if the alias is not + # specified. + file_path = spec file_name = os.path.basename(file_path) - if file_name.endswith(".py"): - file_name = file_name[0:-3] - if file_name.endswith("_test"): - file_name = file_name[0:-5] + colon_pos = -1 + try: + colon_pos = spec.index(":") + except ValueError: + pass + if colon_pos > -1: + # Explicit alias. + file_path = spec[0:colon_pos] + file_name = spec[colon_pos + 1 :] + print(f"FILE PATH = {file_path}") + else: + # Auto detect the alias from the basename. + file_name = os.path.basename(file_path) + if file_name.endswith(".py"): + file_name = file_name[0:-3] + if file_name.endswith("_test"): + file_name = file_name[0:-5] - main_test_dir = os.path.join(os.path.dirname(__file__), "..") - parent_path = os.path.dirname(file_path) + main_test_dir = os.path.join(os.path.dirname(__file__), "..") + parent_path = os.path.dirname(file_path) - file_name = f"{variant}__{file_name}.run" - run_file = os.path.join(main_test_dir, parent_path, file_name) - return run_file + file_name = f"{variant}__{file_name}.run" + run_file = os.path.join(main_test_dir, parent_path, file_name) + return run_file if __name__ == "__main__": - main(sys.argv[1:]) + main(sys.argv[1:]) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/batch_norm_test.py b/integrations/tensorflow/test/python/iree_tf_tests/batch_norm_test.py index 544a44325a9a..cdb0bea71f64 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/batch_norm_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/batch_norm_test.py @@ -13,49 +13,50 @@ class BatchNormModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([4, 16], tf.float32), - tf.TensorSpec([16], tf.float32), - tf.TensorSpec([16], tf.float32), - tf.TensorSpec([16], tf.float32), - tf.TensorSpec([16], tf.float32), - ]) - def batch_norm_inference(self, x, mean, variance, offset, scale): - return tf.nn.batch_normalization(x, - mean=mean, - variance=variance, - offset=offset, - scale=scale, - variance_epsilon=1e-4) + @tf.function( + input_signature=[ + tf.TensorSpec([4, 16], tf.float32), + tf.TensorSpec([16], tf.float32), + tf.TensorSpec([16], tf.float32), + tf.TensorSpec([16], tf.float32), + tf.TensorSpec([16], tf.float32), + ] + ) + def batch_norm_inference(self, x, mean, variance, offset, scale): + return tf.nn.batch_normalization( + x, + mean=mean, + variance=variance, + offset=offset, + scale=scale, + variance_epsilon=1e-4, + ) class BatchNormTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(BatchNormModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(BatchNormModule) - - def test_batch_norm_inference(self): - - def batch_norm_inference(module): - # Note: scaling by a small value to increase numerical stability. - x = tf_utils.uniform((4, 16)) * 1e-3 - mean = tf_utils.uniform((16,)) * 1e-3 - variance = tf_utils.uniform((16,), low=0.0) * 1e-3 - offset = tf_utils.uniform((16,)) * 1e-3 - scale = tf_utils.uniform((16,)) * 1e-3 - module.batch_norm_inference(x, mean, variance, offset, scale) + def test_batch_norm_inference(self): + def batch_norm_inference(module): + # Note: scaling by a small value to increase numerical stability. + x = tf_utils.uniform((4, 16)) * 1e-3 + mean = tf_utils.uniform((16,)) * 1e-3 + variance = tf_utils.uniform((16,), low=0.0) * 1e-3 + offset = tf_utils.uniform((16,)) * 1e-3 + scale = tf_utils.uniform((16,)) * 1e-3 + module.batch_norm_inference(x, mean, variance, offset, scale) - self.compare_backends(batch_norm_inference, self._modules) + self.compare_backends(batch_norm_inference, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/batch_to_space_nd_test.py b/integrations/tensorflow/test/python/iree_tf_tests/batch_to_space_nd_test.py index a7cd76abb50f..c801e53d5f63 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/batch_to_space_nd_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/batch_to_space_nd_test.py @@ -13,36 +13,33 @@ class BatchtoSpaceModule(tf.Module): - - @tf.function(input_signature=[tf.TensorSpec([3, 5, 2], tf.float32)]) - def batch_to_space_nd(self, batched): - block_shape = [3] - paddings = [[3, 4]] - return tf.compat.v1.batch_to_space_nd(batched, block_shape, paddings) + @tf.function(input_signature=[tf.TensorSpec([3, 5, 2], tf.float32)]) + def batch_to_space_nd(self, batched): + block_shape = [3] + paddings = [[3, 4]] + return tf.compat.v1.batch_to_space_nd(batched, block_shape, paddings) class BatchtoSpaceTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(BatchtoSpaceModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(BatchtoSpaceModule) - - def test_space_to_batch_inference(self): - - def space_to_batch_inference(module): - x = np.linspace(0, 29, 30, dtype=np.float32) - x = np.reshape(x, [3, 5, 2]) - module.batch_to_space_nd(x) + def test_space_to_batch_inference(self): + def space_to_batch_inference(module): + x = np.linspace(0, 29, 30, dtype=np.float32) + x = np.reshape(x, [3, 5, 2]) + module.batch_to_space_nd(x) - self.compare_backends(space_to_batch_inference, self._modules) + self.compare_backends(space_to_batch_inference, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/broadcast_to_test.py b/integrations/tensorflow/test/python/iree_tf_tests/broadcast_to_test.py index 300fdfce736e..c4f43a3ae7ed 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/broadcast_to_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/broadcast_to_test.py @@ -11,40 +11,36 @@ class BroadcastToModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass - - @tf.function(input_signature=[ - tf.TensorSpec([], tf.float32), - tf.TensorSpec([2], tf.int32) - ]) - def scalar_broadcast_to(self, x, shape): - return tf.broadcast_to(x, shape) + @tf.function( + input_signature=[tf.TensorSpec([], tf.float32), tf.TensorSpec([2], tf.int32)] + ) + def scalar_broadcast_to(self, x, shape): + return tf.broadcast_to(x, shape) class BroadcastToTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(BroadcastToModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(BroadcastToModule) - - def test_scalar_broadcast_to(self): - - def scalar_broadcast_to(module): - x = np.array(1, dtype=np.float32) - shape = np.array([3, 3], dtype=np.int32) - result = module.scalar_broadcast_to(x, shape) + def test_scalar_broadcast_to(self): + def scalar_broadcast_to(module): + x = np.array(1, dtype=np.float32) + shape = np.array([3, 3], dtype=np.int32) + result = module.scalar_broadcast_to(x, shape) - self.compare_backends(scalar_broadcast_to, self._modules) + self.compare_backends(scalar_broadcast_to, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/broadcasting_test.py b/integrations/tensorflow/test/python/iree_tf_tests/broadcasting_test.py index 07ae6ec57fbf..f168fcca13ed 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/broadcasting_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/broadcasting_test.py @@ -13,55 +13,52 @@ class BroadcastingModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([None], tf.float32), - tf.TensorSpec([None], tf.float32), - ]) - def add(self, lhs, rhs): - return lhs + rhs + @tf.function( + input_signature=[ + tf.TensorSpec([None], tf.float32), + tf.TensorSpec([None], tf.float32), + ] + ) + def add(self, lhs, rhs): + return lhs + rhs class BroadcastingTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(BroadcastingModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(BroadcastingModule) - - def test_add_same_shape(self): - - def add_same_shape(module): - lhs = tf_utils.uniform([4]) - rhs = tf_utils.uniform([4]) - module.add(lhs, rhs) - - self.compare_backends(add_same_shape, self._modules) - - def test_add_broadcast_lhs(self): + def test_add_same_shape(self): + def add_same_shape(module): + lhs = tf_utils.uniform([4]) + rhs = tf_utils.uniform([4]) + module.add(lhs, rhs) - def add_broadcast_lhs(module): - lhs = tf_utils.uniform([1]) - rhs = tf_utils.uniform([4]) - module.add(lhs, rhs) + self.compare_backends(add_same_shape, self._modules) - self.compare_backends(add_broadcast_lhs, self._modules) + def test_add_broadcast_lhs(self): + def add_broadcast_lhs(module): + lhs = tf_utils.uniform([1]) + rhs = tf_utils.uniform([4]) + module.add(lhs, rhs) - def test_add_broadcast_rhs(self): + self.compare_backends(add_broadcast_lhs, self._modules) - def add_broadcast_rhs(module): - lhs = tf_utils.uniform([4]) - rhs = tf_utils.uniform([1]) - module.add(lhs, rhs) + def test_add_broadcast_rhs(self): + def add_broadcast_rhs(module): + lhs = tf_utils.uniform([4]) + rhs = tf_utils.uniform([1]) + module.add(lhs, rhs) - self.compare_backends(add_broadcast_rhs, self._modules) + self.compare_backends(add_broadcast_rhs, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/concat_test.py b/integrations/tensorflow/test/python/iree_tf_tests/concat_test.py index 3dbfd4526583..d45939904f12 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/concat_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/concat_test.py @@ -13,85 +13,87 @@ class ConcatOpsModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([1, 5, 0], tf.float32), - tf.TensorSpec([1, 5, 1], tf.float32), - ]) - def concat_zero_dim(self, a, b): - return tf.concat([a, b], axis=2) - - @tf.function(input_signature=[ - tf.TensorSpec([1, 5, 1], tf.float32), - tf.TensorSpec([1, 5, 1], tf.float32), - ]) - def concat0axis(self, a, b): - return tf.concat([a, b], axis=0) - - @tf.function(input_signature=[ - tf.TensorSpec([1, 5, 1], tf.float32), - tf.TensorSpec([1, 5, 1], tf.float32), - ]) - def concat1axis(self, a, b): - return tf.concat([a, b], axis=1) - - @tf.function(input_signature=[ - tf.TensorSpec([1, 5, 1], tf.float32), - tf.TensorSpec([1, 5, 1], tf.float32), - ]) - def concat2axis(self, a, b): - return tf.concat([a, b], axis=2) + @tf.function( + input_signature=[ + tf.TensorSpec([1, 5, 0], tf.float32), + tf.TensorSpec([1, 5, 1], tf.float32), + ] + ) + def concat_zero_dim(self, a, b): + return tf.concat([a, b], axis=2) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, 5, 1], tf.float32), + tf.TensorSpec([1, 5, 1], tf.float32), + ] + ) + def concat0axis(self, a, b): + return tf.concat([a, b], axis=0) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, 5, 1], tf.float32), + tf.TensorSpec([1, 5, 1], tf.float32), + ] + ) + def concat1axis(self, a, b): + return tf.concat([a, b], axis=1) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, 5, 1], tf.float32), + tf.TensorSpec([1, 5, 1], tf.float32), + ] + ) + def concat2axis(self, a, b): + return tf.concat([a, b], axis=2) class ConcatOpsTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule) - - def test_concat_zero_dim(self): - - def concat_zero_dim(module): - a = tf_utils.uniform([1, 5, 0]) - b = tf_utils.uniform([1, 5, 1]) - module.concat_zero_dim(a, b) - - self.compare_backends(concat_zero_dim, self._modules) - - def test_concat0axis(self): - - def concat0axis(module): - a = tf_utils.uniform([1, 5, 1]) - b = tf_utils.uniform([1, 5, 1]) - module.concat0axis(a, b) + def test_concat_zero_dim(self): + def concat_zero_dim(module): + a = tf_utils.uniform([1, 5, 0]) + b = tf_utils.uniform([1, 5, 1]) + module.concat_zero_dim(a, b) - self.compare_backends(concat0axis, self._modules) + self.compare_backends(concat_zero_dim, self._modules) - def test_concat1axis(self): + def test_concat0axis(self): + def concat0axis(module): + a = tf_utils.uniform([1, 5, 1]) + b = tf_utils.uniform([1, 5, 1]) + module.concat0axis(a, b) - def concat1axis(module): - a = tf_utils.uniform([1, 5, 1]) - b = tf_utils.uniform([1, 5, 1]) - module.concat1axis(a, b) + self.compare_backends(concat0axis, self._modules) - self.compare_backends(concat1axis, self._modules) + def test_concat1axis(self): + def concat1axis(module): + a = tf_utils.uniform([1, 5, 1]) + b = tf_utils.uniform([1, 5, 1]) + module.concat1axis(a, b) - def test_concat2axis(self): + self.compare_backends(concat1axis, self._modules) - def concat2axis(module): - a = tf_utils.uniform([1, 5, 1]) - b = tf_utils.uniform([1, 5, 1]) - module.concat2axis(a, b) + def test_concat2axis(self): + def concat2axis(module): + a = tf_utils.uniform([1, 5, 1]) + b = tf_utils.uniform([1, 5, 1]) + module.concat2axis(a, b) - self.compare_backends(concat2axis, self._modules) + self.compare_backends(concat2axis, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/control_flow_test.py b/integrations/tensorflow/test/python/iree_tf_tests/control_flow_test.py index 0196a7f3bdfe..5f1b33ff25ac 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/control_flow_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/control_flow_test.py @@ -11,51 +11,47 @@ class ControlFlowModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass - - @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) - def collatz(self, a): - i = 0. - while a > 1.: - i = i + 1. - if (a % 2.) > 0.: - a = 3. * a + 1. - else: - a = a / 2. - return i + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def collatz(self, a): + i = 0.0 + while a > 1.0: + i = i + 1.0 + if (a % 2.0) > 0.0: + a = 3.0 * a + 1.0 + else: + a = a / 2.0 + return i class ControlFlowTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ControlFlowModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ControlFlowModule) - - def test_short_sequence(self): - - def short_sequence(module): - input_array = np.array(9., dtype=np.float32) - module.collatz(input_array) - - self.compare_backends(short_sequence, self._modules) + def test_short_sequence(self): + def short_sequence(module): + input_array = np.array(9.0, dtype=np.float32) + module.collatz(input_array) - def test_long_sequence(self): + self.compare_backends(short_sequence, self._modules) - def long_sequence(module): - input_array = np.array(178., dtype=np.float32) - module.collatz(input_array) + def test_long_sequence(self): + def long_sequence(module): + input_array = np.array(178.0, dtype=np.float32) + module.collatz(input_array) - self.compare_backends(long_sequence, self._modules) + self.compare_backends(long_sequence, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/conv_test.py b/integrations/tensorflow/test/python/iree_tf_tests/conv_test.py index 3373905e34a2..aa4c956215a5 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/conv_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/conv_test.py @@ -12,121 +12,141 @@ class Conv2dModule(tf_test_utils.TestModule): - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 1], tf.float32), - tf.TensorSpec([1, 1, 1, 1], tf.float32), - ]) - def conv2d_1451x1111_valid(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 1], tf.float32), - tf.TensorSpec([2, 2, 1, 1], tf.float32), - ]) - def conv2d_1451x2211_dilated_valid(self, img, kernel): - return tf.nn.conv2d(img, - kernel, [1, 1, 1, 1], - "VALID", - dilations=[1, 2, 1, 1], - name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 2, 2, 3], tf.float32), - ]) - def conv2d_1452x2223_dilated_valid(self, img, kernel): - return tf.nn.conv2d(img, - kernel, [1, 1, 1, 1], - "VALID", - dilations=[1, 2, 1, 1], - name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([2, 4, 5, 1], tf.float32), - tf.TensorSpec([1, 1, 1, 1], tf.float32), - ]) - def conv2d_2451x1111_valid(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 1], tf.float32), - tf.TensorSpec([2, 3, 1, 1], tf.float32), - ]) - def conv2d_1451x2311_valid(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 1], tf.float32), - tf.TensorSpec([2, 3, 1, 1], tf.float32), - ]) - def conv2d_1451x2311_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([2, 4, 5, 1], tf.float32), - tf.TensorSpec([2, 3, 1, 1], tf.float32), - ]) - def conv2d_2451x2311_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 2], tf.float32), - tf.TensorSpec([3, 2, 2, 1], tf.float32), - ]) - def conv2d_1452x3221_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 1], tf.float32), - tf.TensorSpec([1, 1, 1, 2], tf.float32), - ]) - def conv2d_1451x1112_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 2], tf.float32), - tf.TensorSpec([1, 1, 2, 2], tf.float32), - ]) - def conv2d_1452x1122_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 2, 2, 3], tf.float32), - ]) - def conv2d_1452x2223_same(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([1, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 2, 2, 3], tf.float32), - ]) - def conv2d_1452x2223_valid(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") - - @tf_test_utils.tf_function_unit_test(input_signature=[ - tf.TensorSpec([2, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 2, 2, 3], tf.float32), - ]) - def conv2d_2452x2223_valid(self, img, kernel): - return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 1], tf.float32), + tf.TensorSpec([1, 1, 1, 1], tf.float32), + ] + ) + def conv2d_1451x1111_valid(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 1], tf.float32), + tf.TensorSpec([2, 2, 1, 1], tf.float32), + ] + ) + def conv2d_1451x2211_dilated_valid(self, img, kernel): + return tf.nn.conv2d( + img, kernel, [1, 1, 1, 1], "VALID", dilations=[1, 2, 1, 1], name="result" + ) + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 2, 2, 3], tf.float32), + ] + ) + def conv2d_1452x2223_dilated_valid(self, img, kernel): + return tf.nn.conv2d( + img, kernel, [1, 1, 1, 1], "VALID", dilations=[1, 2, 1, 1], name="result" + ) + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([2, 4, 5, 1], tf.float32), + tf.TensorSpec([1, 1, 1, 1], tf.float32), + ] + ) + def conv2d_2451x1111_valid(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 1], tf.float32), + tf.TensorSpec([2, 3, 1, 1], tf.float32), + ] + ) + def conv2d_1451x2311_valid(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 1], tf.float32), + tf.TensorSpec([2, 3, 1, 1], tf.float32), + ] + ) + def conv2d_1451x2311_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([2, 4, 5, 1], tf.float32), + tf.TensorSpec([2, 3, 1, 1], tf.float32), + ] + ) + def conv2d_2451x2311_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 2], tf.float32), + tf.TensorSpec([3, 2, 2, 1], tf.float32), + ] + ) + def conv2d_1452x3221_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 1], tf.float32), + tf.TensorSpec([1, 1, 1, 2], tf.float32), + ] + ) + def conv2d_1451x1112_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 2], tf.float32), + tf.TensorSpec([1, 1, 2, 2], tf.float32), + ] + ) + def conv2d_1452x1122_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 2, 2, 3], tf.float32), + ] + ) + def conv2d_1452x2223_same(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([1, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 2, 2, 3], tf.float32), + ] + ) + def conv2d_1452x2223_valid(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + + @tf_test_utils.tf_function_unit_test( + input_signature=[ + tf.TensorSpec([2, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 2, 2, 3], tf.float32), + ] + ) + def conv2d_2452x2223_valid(self, img, kernel): + return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") class ConvTest(tf_test_utils.TracedModuleTestCase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(Conv2dModule) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(Conv2dModule) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - ConvTest.generate_unit_tests(Conv2dModule) - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + ConvTest.generate_unit_tests(Conv2dModule) + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/conv_transpose_test.py b/integrations/tensorflow/test/python/iree_tf_tests/conv_transpose_test.py index 19ac6f7d7e5d..8501eb103f45 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/conv_transpose_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/conv_transpose_test.py @@ -12,92 +12,88 @@ class ConvTransposeModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([2, 2, 1, 1], tf.float32), - tf.TensorSpec([1, 2, 4, 1], tf.float32), - ]) - def conv2d_transpose_same(self, filt, img): - input_sizes = [1, 2, 4, 1] - strides = [1, 1, 1, 1] - padding = "SAME" - return tf.nn.conv2d_transpose(img, - filt, - input_sizes, - strides, - padding, - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([1, 4, 2, 3], tf.float32), - tf.TensorSpec([1, 1, 4, 3], tf.float32), - ]) - def conv2d_transpose_dilated_w(self, filt, img): - input_sizes = [1, 1, 10, 2] - strides = [1, 1, 2, 1] - padding = "VALID" - return tf.nn.conv2d_transpose(img, - filt, - input_sizes, - strides, - padding, - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([4, 1, 2, 3], tf.float32), - tf.TensorSpec([1, 4, 1, 3], tf.float32), - ]) - def conv2d_transpose_dilated_h(self, filt, img): - input_sizes = [1, 10, 1, 2] - strides = [1, 2, 1, 1] - padding = "VALID" - return tf.nn.conv2d_transpose(img, - filt, - input_sizes, - strides, - padding, - name="result") + @tf.function( + input_signature=[ + tf.TensorSpec([2, 2, 1, 1], tf.float32), + tf.TensorSpec([1, 2, 4, 1], tf.float32), + ] + ) + def conv2d_transpose_same(self, filt, img): + input_sizes = [1, 2, 4, 1] + strides = [1, 1, 1, 1] + padding = "SAME" + return tf.nn.conv2d_transpose( + img, filt, input_sizes, strides, padding, name="result" + ) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, 4, 2, 3], tf.float32), + tf.TensorSpec([1, 1, 4, 3], tf.float32), + ] + ) + def conv2d_transpose_dilated_w(self, filt, img): + input_sizes = [1, 1, 10, 2] + strides = [1, 1, 2, 1] + padding = "VALID" + return tf.nn.conv2d_transpose( + img, filt, input_sizes, strides, padding, name="result" + ) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 1, 2, 3], tf.float32), + tf.TensorSpec([1, 4, 1, 3], tf.float32), + ] + ) + def conv2d_transpose_dilated_h(self, filt, img): + input_sizes = [1, 10, 1, 2] + strides = [1, 2, 1, 1] + padding = "VALID" + return tf.nn.conv2d_transpose( + img, filt, input_sizes, strides, padding, name="result" + ) class ConvTransposeTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ConvTransposeModule) + + def test_transposed(self): + def transposed(module): + kernel = tf_utils.uniform([2, 2, 1, 1], dtype=np.float32) + img = tf_utils.uniform([1, 2, 4, 1], dtype=np.float32) + + module.conv2d_transpose_same(kernel, img) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ConvTransposeModule) + self.compare_backends(transposed, self._modules) - # yapf: disable - def test_transposed(self): - def transposed(module): - kernel = tf_utils.uniform([2, 2, 1, 1], dtype=np.float32) - img = tf_utils.uniform([1, 2, 4, 1], dtype=np.float32) + def test_transposed_dilated_w(self): + def transposed(module): + kernel = tf_utils.uniform([1, 4, 2, 3], dtype=np.float32) + img = tf_utils.uniform([1, 1, 4, 3], dtype=np.float32) - module.conv2d_transpose_same(kernel, img) - self.compare_backends(transposed, self._modules) + module.conv2d_transpose_dilated_w(kernel, img) - def test_transposed_dilated_w(self): - def transposed(module): - kernel = tf_utils.uniform([1, 4, 2, 3], dtype=np.float32) - img = tf_utils.uniform([1, 1, 4, 3], dtype=np.float32) + self.compare_backends(transposed, self._modules) - module.conv2d_transpose_dilated_w(kernel, img) - self.compare_backends(transposed, self._modules) + def test_transposed_dilated_h(self): + def transposed(module): + kernel = tf_utils.uniform([4, 1, 2, 3], dtype=np.float32) + img = tf_utils.uniform([1, 4, 1, 3], dtype=np.float32) - def test_transposed_dilated_h(self): - def transposed(module): - kernel = tf_utils.uniform([4, 1, 2, 3], dtype=np.float32) - img = tf_utils.uniform([1, 4, 1, 3], dtype=np.float32) + module.conv2d_transpose_dilated_h(kernel, img) - module.conv2d_transpose_dilated_h(kernel, img) - self.compare_backends(transposed, self._modules) - # yapf: enable + self.compare_backends(transposed, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/depth_conv_test.py b/integrations/tensorflow/test/python/iree_tf_tests/depth_conv_test.py index ab63b3346910..c73bedf8f914 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/depth_conv_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/depth_conv_test.py @@ -12,111 +12,107 @@ class DepthConv2dModule(tf.Module): - - # TODO(ataei): Add dilation and strided tests. - @tf.function(input_signature=[ - tf.TensorSpec([2, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 2, 2, 3], tf.float32), - ]) - def conv2d_2452x2423_valid(self, img, kernel): - return tf.nn.depthwise_conv2d(img, - kernel, [1, 1, 1, 1], - "VALID", - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([2, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 4, 2, 3], tf.float32), - ]) - def conv2d_2452x2423_same(self, img, kernel): - return tf.nn.depthwise_conv2d(img, - kernel, [1, 1, 1, 1], - "SAME", - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([2, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 4, 2, 3], tf.float32), - ]) - def conv2d_2452x2423_valid_stride_2(self, img, kernel): - return tf.nn.depthwise_conv2d(img, - kernel, [1, 2, 2, 1], - "VALID", - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([2, 4, 5, 2], tf.float32), - tf.TensorSpec([2, 4, 2, 3], tf.float32), - ]) - def conv2d_2452x2423_same_stride_2(self, img, kernel): - return tf.nn.depthwise_conv2d(img, - kernel, [1, 2, 2, 1], - "SAME", - name="result") - - @tf.function(input_signature=[ - tf.TensorSpec([2, 4, 5, 4], tf.float32), - tf.TensorSpec([2, 4, 4, 1], tf.float32), - ]) - def conv2d_2453x2441_same_stride_1(self, img, kernel): - return tf.nn.depthwise_conv2d(img, - kernel, [1, 1, 1, 1], - "SAME", - name="result") + # TODO(ataei): Add dilation and strided tests. + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 2, 2, 3], tf.float32), + ] + ) + def conv2d_2452x2423_valid(self, img, kernel): + return tf.nn.depthwise_conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result") + + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 4, 2, 3], tf.float32), + ] + ) + def conv2d_2452x2423_same(self, img, kernel): + return tf.nn.depthwise_conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") + + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 4, 2, 3], tf.float32), + ] + ) + def conv2d_2452x2423_valid_stride_2(self, img, kernel): + return tf.nn.depthwise_conv2d(img, kernel, [1, 2, 2, 1], "VALID", name="result") + + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4, 5, 2], tf.float32), + tf.TensorSpec([2, 4, 2, 3], tf.float32), + ] + ) + def conv2d_2452x2423_same_stride_2(self, img, kernel): + return tf.nn.depthwise_conv2d(img, kernel, [1, 2, 2, 1], "SAME", name="result") + + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4, 5, 4], tf.float32), + tf.TensorSpec([2, 4, 4, 1], tf.float32), + ] + ) + def conv2d_2453x2441_same_stride_1(self, img, kernel): + return tf.nn.depthwise_conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result") class ConvTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule) + + def test_batched_feature_unpadded(self): + def batched_feature_unpadded(module): + i = tf_utils.ndarange([2, 4, 5, 2]) + k = tf_utils.ndarange([2, 2, 2, 3]) + module.conv2d_2452x2423_valid(i, k) + + self.compare_backends(batched_feature_unpadded, self._modules) + + def test_batched_feature_unpadded_same(self): + def batched_feature_unpadded_same(module): + i = tf_utils.ndarange([2, 4, 5, 2]) + k = tf_utils.ndarange([2, 4, 2, 3]) + module.conv2d_2452x2423_same(i, k) + + self.compare_backends(batched_feature_unpadded_same, self._modules) + + def test_batched_feature_unpadded_same_stride_2(self): + def batched_feature_unpadded_same_stride_2(module): + i = tf_utils.ndarange([2, 4, 5, 2]) + k = tf_utils.ndarange([2, 4, 2, 3]) + module.conv2d_2452x2423_valid_stride_2(i, k) + + self.compare_backends(batched_feature_unpadded_same_stride_2, self._modules) + + def test_batched_feature_padded_same_stride_2(self): + def batched_feature_padded_same_stride_2(module): + i = tf_utils.ndarange([2, 4, 5, 2]) + k = tf_utils.ndarange([2, 4, 2, 3]) + module.conv2d_2452x2423_same_stride_2(i, k) + + self.compare_backends(batched_feature_padded_same_stride_2, self._modules) + + def test_batched_feature_padded_same_stride_1_output_1(self): + def batched_feature_padded_same_stride_1_output_1(module): + i = tf_utils.ndarange([2, 4, 5, 4]) + k = tf_utils.ndarange([2, 4, 4, 1]) + module.conv2d_2453x2441_same_stride_1(i, k) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule) - - # yapf: disable - def test_batched_feature_unpadded(self): - def batched_feature_unpadded(module): - i = tf_utils.ndarange([2, 4, 5, 2]) - k = tf_utils.ndarange([2, 2, 2, 3]) - module.conv2d_2452x2423_valid(i, k) - self.compare_backends(batched_feature_unpadded, self._modules) - - def test_batched_feature_unpadded_same(self): - def batched_feature_unpadded_same(module): - i = tf_utils.ndarange([2, 4, 5, 2]) - k = tf_utils.ndarange([2, 4, 2, 3]) - module.conv2d_2452x2423_same(i, k) - self.compare_backends(batched_feature_unpadded_same, self._modules) - - def test_batched_feature_unpadded_same_stride_2(self): - def batched_feature_unpadded_same_stride_2(module): - i = tf_utils.ndarange([2, 4, 5, 2]) - k = tf_utils.ndarange([2, 4, 2, 3]) - module.conv2d_2452x2423_valid_stride_2(i, k) - self.compare_backends(batched_feature_unpadded_same_stride_2, - self._modules) - - def test_batched_feature_padded_same_stride_2(self): - def batched_feature_padded_same_stride_2(module): - i = tf_utils.ndarange([2, 4, 5, 2]) - k = tf_utils.ndarange([2, 4, 2, 3]) - module.conv2d_2452x2423_same_stride_2(i, k) - self.compare_backends(batched_feature_padded_same_stride_2, self._modules) - - def test_batched_feature_padded_same_stride_1_output_1(self): - def batched_feature_padded_same_stride_1_output_1(module): - i = tf_utils.ndarange([2, 4, 5, 4]) - k = tf_utils.ndarange([2, 4, 4, 1]) - module.conv2d_2453x2441_same_stride_1(i, k) - self.compare_backends(batched_feature_padded_same_stride_1_output_1, - self._modules) - # yapf: enable + self.compare_backends( + batched_feature_padded_same_stride_1_output_1, self._modules + ) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_relu_test.py b/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_relu_test.py index e049c9e47c2f..5784f41bfcf4 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_relu_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_relu_test.py @@ -21,63 +21,57 @@ class DynamicMlpReluModule(tf.Module): - - def __init__(self, - hidden_1_dim=256, - hidden_2_dim=256, - input_dim=28 * 28, - classes=10): - super().__init__() - tf_utils.set_random_seed() - self.hidden_1_dim = hidden_1_dim - self.hidden_2_dim = hidden_2_dim - self.input_dim = input_dim - self.classes = classes - self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim])) - self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim, - hidden_2_dim])) - self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes])) - self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim])) - self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim])) - self.out_bias = tf.Variable(tf.random.normal([classes])) - - # Compile with dynamic batch dim. - self.predict = tf.function( - input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict) - - def mlp(self, x): - layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias)) - layer_2 = tf.nn.relu( - tf.add(tf.matmul(layer_1, self.h2_weights), self.h2_bias)) - return tf.nn.relu( - tf.add(tf.matmul(layer_2, self.out_weights), self.out_bias)) - - def predict(self, x): - return tf.nn.softmax(self.mlp(x)) + def __init__( + self, hidden_1_dim=256, hidden_2_dim=256, input_dim=28 * 28, classes=10 + ): + super().__init__() + tf_utils.set_random_seed() + self.hidden_1_dim = hidden_1_dim + self.hidden_2_dim = hidden_2_dim + self.input_dim = input_dim + self.classes = classes + self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim])) + self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim, hidden_2_dim])) + self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes])) + self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim])) + self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim])) + self.out_bias = tf.Variable(tf.random.normal([classes])) + + # Compile with dynamic batch dim. + self.predict = tf.function( + input_signature=[tf.TensorSpec([None, self.input_dim])] + )(self.predict) + + def mlp(self, x): + layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias)) + layer_2 = tf.nn.relu(tf.add(tf.matmul(layer_1, self.h2_weights), self.h2_bias)) + return tf.nn.relu(tf.add(tf.matmul(layer_2, self.out_weights), self.out_bias)) + + def predict(self, x): + return tf.nn.softmax(self.mlp(x)) class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module( + DynamicMlpReluModule, exported_names=["predict"] + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(DynamicMlpReluModule, - exported_names=["predict"]) - - def test_dynamic_batch(self): - - def dynamic_batch(module): - x = tf_utils.uniform([3, 28 * 28]) * 1e-3 - module.predict(x) + def test_dynamic_batch(self): + def dynamic_batch(module): + x = tf_utils.uniform([3, 28 * 28]) * 1e-3 + module.predict(x) - self.compare_backends(dynamic_batch, self._modules) + self.compare_backends(dynamic_batch, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_test.py b/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_test.py index a648f2491aa9..ce2be9576e2b 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/dynamic_mlp_test.py @@ -17,63 +17,57 @@ class DynamicMlpModule(tf.Module): - - def __init__(self, - hidden_1_dim=256, - hidden_2_dim=256, - input_dim=28 * 28, - classes=10): - super().__init__() - tf_utils.set_random_seed() - self.hidden_1_dim = hidden_1_dim - self.hidden_2_dim = hidden_2_dim - self.input_dim = input_dim - self.classes = classes - self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim])) - self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim, - hidden_2_dim])) - self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes])) - self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim])) - self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim])) - self.out_bias = tf.Variable(tf.random.normal([classes])) - - # Compile with dynamic batch dim. - self.predict = tf.function( - input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict) - - def mlp(self, x): - layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias)) - layer_2 = tf.sigmoid( - tf.add(tf.matmul(layer_1, self.h2_weights), self.h2_bias)) - return tf.sigmoid( - tf.add(tf.matmul(layer_2, self.out_weights), self.out_bias)) - - def predict(self, x): - return tf.nn.softmax(self.mlp(x)) + def __init__( + self, hidden_1_dim=256, hidden_2_dim=256, input_dim=28 * 28, classes=10 + ): + super().__init__() + tf_utils.set_random_seed() + self.hidden_1_dim = hidden_1_dim + self.hidden_2_dim = hidden_2_dim + self.input_dim = input_dim + self.classes = classes + self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim])) + self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim, hidden_2_dim])) + self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes])) + self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim])) + self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim])) + self.out_bias = tf.Variable(tf.random.normal([classes])) + + # Compile with dynamic batch dim. + self.predict = tf.function( + input_signature=[tf.TensorSpec([None, self.input_dim])] + )(self.predict) + + def mlp(self, x): + layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias)) + layer_2 = tf.sigmoid(tf.add(tf.matmul(layer_1, self.h2_weights), self.h2_bias)) + return tf.sigmoid(tf.add(tf.matmul(layer_2, self.out_weights), self.out_bias)) + + def predict(self, x): + return tf.nn.softmax(self.mlp(x)) class DynamicMlpTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module( + DynamicMlpModule, exported_names=["predict"] + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(DynamicMlpModule, - exported_names=["predict"]) - - def test_dynamic_batch(self): - - def dynamic_batch(module): - x = tf_utils.uniform([3, 28 * 28]) * 1e-3 - module.predict(x) + def test_dynamic_batch(self): + def dynamic_batch(module): + x = tf_utils.uniform([3, 28 * 28]) * 1e-3 + module.predict(x) - self.compare_backends(dynamic_batch, self._modules) + self.compare_backends(dynamic_batch, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/einsum_dynamic_test.py b/integrations/tensorflow/test/python/iree_tf_tests/einsum_dynamic_test.py index 11df38ee3db7..1132135534b6 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/einsum_dynamic_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/einsum_dynamic_test.py @@ -16,121 +16,147 @@ class EinsumDynamicModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([None, None], tf.float32), - ]) - def einsum_dynamic_dim_identity(self, x): - return tf.einsum('ij', x) - - @tf.function(input_signature=[ - tf.TensorSpec([None, None, None], tf.float32), - ]) - def einsum_dynamic_rank_identity(self, x): - return tf.einsum('...', x) - - @tf.function(input_signature=[ - tf.TensorSpec([None, LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_dynamic_dim_transpose(self, x): - return tf.einsum('bij -> bji', x) - - @tf.function(input_signature=[ - tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_dynamic_rank_diag(self, x): - return tf.einsum('...ii -> ...i', x) - - @tf.function(input_signature=[ - tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_dynamic_dim_sum(self, x): - return tf.einsum('abij -> ab', x) - - @tf.function(input_signature=[ - tf.TensorSpec([None, None], tf.float32), - tf.TensorSpec([None, None], tf.float32), - ]) - def einsum_dynamic_dim_matmul(self, lhs, rhs): - return tf.einsum('ij, jk -> ik', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([None, LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_dynamic_dim_lhs_batch(self, lhs, rhs): - return tf.einsum('bij, jk -> bik', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([None, None, 8, 6], tf.float32), - tf.TensorSpec([12, 6, 4], tf.float32), - ]) - def einsum_dynamic_rank_split_heads(self, seq, weights): - # l: seq_len, m: d_model, h: num_heads, d: attention_depth - return tf.einsum('...lm, hmd -> ...hld', seq, weights) + @tf.function( + input_signature=[ + tf.TensorSpec([None, None], tf.float32), + ] + ) + def einsum_dynamic_dim_identity(self, x): + return tf.einsum("ij", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, None], tf.float32), + ] + ) + def einsum_dynamic_rank_identity(self, x): + return tf.einsum("...", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_dynamic_dim_transpose(self, x): + return tf.einsum("bij -> bji", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_dynamic_rank_diag(self, x): + return tf.einsum("...ii -> ...i", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_dynamic_dim_sum(self, x): + return tf.einsum("abij -> ab", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None], tf.float32), + tf.TensorSpec([None, None], tf.float32), + ] + ) + def einsum_dynamic_dim_matmul(self, lhs, rhs): + return tf.einsum("ij, jk -> ik", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_dynamic_dim_lhs_batch(self, lhs, rhs): + return tf.einsum("bij, jk -> bik", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, 8, 6], tf.float32), + tf.TensorSpec([12, 6, 4], tf.float32), + ] + ) + def einsum_dynamic_rank_split_heads(self, seq, weights): + # l: seq_len, m: d_model, h: num_heads, d: attention_depth + return tf.einsum("...lm, hmd -> ...hld", seq, weights) class EinsumDynamicTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(EinsumDynamicModule) + + def test_einsum_dynamic_dim_identity(self): + def einsum_dynamic_dim_identity(module): + module.einsum_dynamic_dim_identity(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_dynamic_dim_identity, self._modules) + + def test_einsum_dynamic_rank_identity(self): + def einsum_dynamic_rank_identity(module): + module.einsum_dynamic_rank_identity( + tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM]) + ) + + self.compare_backends(einsum_dynamic_rank_identity, self._modules) + + def test_einsum_dynamic_dim_transpose(self): + def einsum_dynamic_dim_transpose(module): + module.einsum_dynamic_dim_transpose( + tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM]) + ) + + self.compare_backends(einsum_dynamic_dim_transpose, self._modules) + + def test_einsum_dynamic_rank_diag(self): + def einsum_dynamic_rank_diag(module): + module.einsum_dynamic_rank_diag( + tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM]) + ) + + self.compare_backends(einsum_dynamic_rank_diag, self._modules) + + def test_einsum_dynamic_dim_sum(self): + def einsum_dynamic_dim_sum(module): + module.einsum_dynamic_dim_sum( + tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM]) + ) + + self.compare_backends(einsum_dynamic_dim_sum, self._modules) + + def test_einsum_dynamic_dim_matmul(self): + def einsum_dynamic_dim_matmul(module): + module.einsum_dynamic_dim_matmul( + tf_utils.ndarange([LEFT_DIM, INNER_DIM]), + tf_utils.ndarange([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(einsum_dynamic_dim_matmul, self._modules) + + def test_einsum_dynamic_dim_lhs_batch(self): + def einsum_dynamic_dim_lhs_batch(module): + module.einsum_dynamic_dim_lhs_batch( + tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]), + tf_utils.ndarange([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(einsum_dynamic_dim_lhs_batch, self._modules) + + def test_einsum_dynamic_rank_split_heads(self): + def einsum_dynamic_rank_split_heads(module): + module.einsum_dynamic_rank_split_heads( + tf_utils.ndarange([BATCH_DIM, BATCH_DIM, 8, 6]), + tf_utils.ndarange([12, 6, 4]), + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(EinsumDynamicModule) - - # yapf: disable - def test_einsum_dynamic_dim_identity(self): - def einsum_dynamic_dim_identity(module): - module.einsum_dynamic_dim_identity( - tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_dim_identity, self._modules) - - def test_einsum_dynamic_rank_identity(self): - def einsum_dynamic_rank_identity(module): - module.einsum_dynamic_rank_identity( - tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_rank_identity, self._modules) - - def test_einsum_dynamic_dim_transpose(self): - def einsum_dynamic_dim_transpose(module): - module.einsum_dynamic_dim_transpose( - tf_utils.ndarange([BATCH_DIM, LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_dim_transpose, self._modules) - - def test_einsum_dynamic_rank_diag(self): - def einsum_dynamic_rank_diag(module): - module.einsum_dynamic_rank_diag( - tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_rank_diag, self._modules) - - def test_einsum_dynamic_dim_sum(self): - def einsum_dynamic_dim_sum(module): - module.einsum_dynamic_dim_sum( - tf_utils.ndarange([BATCH_DIM, BATCH_DIM, LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_dim_sum, self._modules) - - def test_einsum_dynamic_dim_matmul(self): - def einsum_dynamic_dim_matmul(module): - module.einsum_dynamic_dim_matmul( - tf_utils.ndarange([LEFT_DIM, INNER_DIM]), - tf_utils.ndarange([INNER_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_dim_matmul, self._modules) - - def test_einsum_dynamic_dim_lhs_batch(self): - def einsum_dynamic_dim_lhs_batch(module): - module.einsum_dynamic_dim_lhs_batch( - tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]), - tf_utils.ndarange([INNER_DIM, RIGHT_DIM])) - self.compare_backends(einsum_dynamic_dim_lhs_batch, self._modules) - - def test_einsum_dynamic_rank_split_heads(self): - def einsum_dynamic_rank_split_heads(module): - module.einsum_dynamic_rank_split_heads( - tf_utils.ndarange([BATCH_DIM, BATCH_DIM, 8, 6]), - tf_utils.ndarange([12, 6, 4])) - self.compare_backends(einsum_dynamic_rank_split_heads, self._modules) - # yapf: enable + self.compare_backends(einsum_dynamic_rank_split_heads, self._modules) if __name__ == "__main__": - if hasattr(tf, "enable_v2_behavior"): - tf.enable_v2_behavior() - tf.test.main() + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() diff --git a/integrations/tensorflow/test/python/iree_tf_tests/einsum_static_test.py b/integrations/tensorflow/test/python/iree_tf_tests/einsum_static_test.py index 00ea10c66d01..4e82f29cf060 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/einsum_static_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/einsum_static_test.py @@ -16,202 +16,248 @@ class EinsumStaticModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_identity(self, x): - return tf.einsum('ij', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_implicit_transpose(self, x): - return tf.einsum('ji', x) # :woozy: - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_explicit_transpose(self, x): - return tf.einsum('ij -> ji', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_implicit_trace(self, x): - return tf.einsum('ii', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_explicit_trace(self, x): - return tf.einsum('ii ->', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_diag(self, x): - return tf.einsum('ii -> i', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_sum(self, x): - return tf.einsum('ij ->', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_sum_axis_0(self, x): - return tf.einsum('ij -> j', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_sum_axis_1(self, x): - return tf.einsum('ij -> i', x) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_matmul(self, lhs, rhs): - return tf.einsum('ij, jk -> ik', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_lhs_batch(self, lhs, rhs): - return tf.einsum('bij, jk -> bik', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def einsum_broadcast_singleton_dimension(self, lhs, rhs): - return tf.einsum('lij, rjk -> rik', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([BATCH_DIM, 8, 6], tf.float32), - tf.TensorSpec([12, 6, 4], tf.float32), - ]) - def einsum_split_heads(self, seq, weights): - # l: seq_len, m: d_model, h: num_heads, d: attention_depth - return tf.einsum('blm, hmd -> bhld', seq, weights) - - @tf.function(input_signature=[ - tf.TensorSpec([BATCH_DIM, 5, 3, 2, 6], tf.float32), - tf.TensorSpec([BATCH_DIM, 5, 6], tf.float32), - ]) - def einsum_batched_high_rank_matrix_vector_mul(self, lhs, rhs): - return tf.einsum('bijxy, biy -> bijx', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([BATCH_DIM, 2, 6], tf.float32), - tf.TensorSpec([BATCH_DIM, 5, 3, 6], tf.float32), - ]) - def einsum_batched_matrix_high_rank_vector_mul(self, lhs, rhs): - return tf.einsum('bxy, bijy -> bijx', lhs, rhs) + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_identity(self, x): + return tf.einsum("ij", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_implicit_transpose(self, x): + return tf.einsum("ji", x) # :woozy: + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_explicit_transpose(self, x): + return tf.einsum("ij -> ji", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_implicit_trace(self, x): + return tf.einsum("ii", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_explicit_trace(self, x): + return tf.einsum("ii ->", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_diag(self, x): + return tf.einsum("ii -> i", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_sum(self, x): + return tf.einsum("ij ->", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_sum_axis_0(self, x): + return tf.einsum("ij -> j", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_sum_axis_1(self, x): + return tf.einsum("ij -> i", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_matmul(self, lhs, rhs): + return tf.einsum("ij, jk -> ik", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_lhs_batch(self, lhs, rhs): + return tf.einsum("bij, jk -> bik", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def einsum_broadcast_singleton_dimension(self, lhs, rhs): + return tf.einsum("lij, rjk -> rik", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([BATCH_DIM, 8, 6], tf.float32), + tf.TensorSpec([12, 6, 4], tf.float32), + ] + ) + def einsum_split_heads(self, seq, weights): + # l: seq_len, m: d_model, h: num_heads, d: attention_depth + return tf.einsum("blm, hmd -> bhld", seq, weights) + + @tf.function( + input_signature=[ + tf.TensorSpec([BATCH_DIM, 5, 3, 2, 6], tf.float32), + tf.TensorSpec([BATCH_DIM, 5, 6], tf.float32), + ] + ) + def einsum_batched_high_rank_matrix_vector_mul(self, lhs, rhs): + return tf.einsum("bijxy, biy -> bijx", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([BATCH_DIM, 2, 6], tf.float32), + tf.TensorSpec([BATCH_DIM, 5, 3, 6], tf.float32), + ] + ) + def einsum_batched_matrix_high_rank_vector_mul(self, lhs, rhs): + return tf.einsum("bxy, bijy -> bijx", lhs, rhs) class EinsumStaticTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(EinsumStaticModule) + + def test_einsum_identity(self): + def einsum_identity(module): + module.einsum_identity(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_identity, self._modules) + + def test_einsum_implicit_transpose(self): + def einsum_implicit_transpose(module): + module.einsum_implicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_implicit_transpose, self._modules) + + def test_einsum_explicit_transpose(self): + def einsum_explicit_transpose(module): + module.einsum_explicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_explicit_transpose, self._modules) + + def test_einsum_implicit_trace(self): + def einsum_implicit_trace(module): + module.einsum_implicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_implicit_trace, self._modules) + + def test_einsum_explicit_trace(self): + def einsum_explicit_trace(module): + module.einsum_explicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_explicit_trace, self._modules) + + def test_einsum_diag(self): + def einsum_diag(module): + module.einsum_diag(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_diag, self._modules) + + def test_einsum_sum(self): + def einsum_sum(module): + module.einsum_sum(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_sum, self._modules) + + def test_einsum_sum_axis_0(self): + def einsum_sum_axis_0(module): + module.einsum_sum_axis_0(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_sum_axis_0, self._modules) + + def test_einsum_sum_axis_1(self): + def einsum_sum_axis_1(module): + module.einsum_sum_axis_1(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) + + self.compare_backends(einsum_sum_axis_1, self._modules) + + def test_einsum_matmul(self): + def einsum_matmul(module): + module.einsum_matmul( + tf_utils.ndarange([LEFT_DIM, INNER_DIM]), + tf_utils.ndarange([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(einsum_matmul, self._modules) + + def test_einsum_lhs_batch(self): + def einsum_lhs_batch(module): + module.einsum_lhs_batch( + tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]), + tf_utils.ndarange([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(einsum_lhs_batch, self._modules) + + def test_einsum_broadcast_singleton_dimension(self): + def einsum_broadcast_singleton_dimension(module): + module.einsum_broadcast_singleton_dimension( + tf_utils.ndarange([1, LEFT_DIM, INNER_DIM]), + tf_utils.ndarange([BATCH_DIM, INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(einsum_broadcast_singleton_dimension, self._modules) + + def test_einsum_split_heads(self): + def einsum_split_heads(module): + module.einsum_split_heads( + tf_utils.ndarange([BATCH_DIM, 8, 6]), tf_utils.ndarange([12, 6, 4]) + ) + + self.compare_backends(einsum_split_heads, self._modules) + + def test_einsum_batched_high_rank_matrix_vector_mul(self): + def einsum_batched_high_rank_matrix_vector_mul(module): + module.einsum_batched_high_rank_matrix_vector_mul( + tf_utils.ndarange([BATCH_DIM, 5, 3, 2, 6]), + tf_utils.ndarange([BATCH_DIM, 5, 6]), + ) + + self.compare_backends(einsum_batched_high_rank_matrix_vector_mul, self._modules) + + def test_einsum_batched_matrix_high_rank_vector_mul(self): + def einsum_batched_matrix_high_rank_vector_mul(module): + module.einsum_batched_matrix_high_rank_vector_mul( + tf_utils.ndarange([BATCH_DIM, 2, 6]), + tf_utils.ndarange([BATCH_DIM, 5, 3, 6]), + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(EinsumStaticModule) - - # yapf: disable - def test_einsum_identity(self): - def einsum_identity(module): - module.einsum_identity(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_identity, self._modules) - - def test_einsum_implicit_transpose(self): - def einsum_implicit_transpose(module): - module.einsum_implicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_implicit_transpose, self._modules) - - def test_einsum_explicit_transpose(self): - def einsum_explicit_transpose(module): - module.einsum_explicit_transpose(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_explicit_transpose, self._modules) - - def test_einsum_implicit_trace(self): - def einsum_implicit_trace(module): - module.einsum_implicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_implicit_trace, self._modules) - - def test_einsum_explicit_trace(self): - def einsum_explicit_trace(module): - module.einsum_explicit_trace(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_explicit_trace, self._modules) - - def test_einsum_diag(self): - def einsum_diag(module): - module.einsum_diag(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_diag, self._modules) - - def test_einsum_sum(self): - def einsum_sum(module): - module.einsum_sum(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_sum, self._modules) - - def test_einsum_sum_axis_0(self): - def einsum_sum_axis_0(module): - module.einsum_sum_axis_0(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_sum_axis_0, self._modules) - - def test_einsum_sum_axis_1(self): - def einsum_sum_axis_1(module): - module.einsum_sum_axis_1(tf_utils.ndarange([LEFT_DIM, RIGHT_DIM])) - self.compare_backends(einsum_sum_axis_1, self._modules) - - def test_einsum_matmul(self): - def einsum_matmul(module): - module.einsum_matmul(tf_utils.ndarange([LEFT_DIM, INNER_DIM]), - tf_utils.ndarange([INNER_DIM, RIGHT_DIM])) - self.compare_backends(einsum_matmul, self._modules) - - def test_einsum_lhs_batch(self): - def einsum_lhs_batch(module): - module.einsum_lhs_batch( - tf_utils.ndarange([BATCH_DIM, LEFT_DIM, INNER_DIM]), - tf_utils.ndarange([INNER_DIM, RIGHT_DIM])) - self.compare_backends(einsum_lhs_batch, self._modules) - - def test_einsum_broadcast_singleton_dimension(self): - def einsum_broadcast_singleton_dimension(module): - module.einsum_broadcast_singleton_dimension( - tf_utils.ndarange([1, LEFT_DIM, INNER_DIM]), - tf_utils.ndarange([BATCH_DIM, INNER_DIM, RIGHT_DIM])) - self.compare_backends(einsum_broadcast_singleton_dimension, self._modules) - - def test_einsum_split_heads(self): - def einsum_split_heads(module): - module.einsum_split_heads(tf_utils.ndarange([BATCH_DIM, 8, 6]), - tf_utils.ndarange([12, 6, 4])) - self.compare_backends(einsum_split_heads, self._modules) - - def test_einsum_batched_high_rank_matrix_vector_mul(self): - def einsum_batched_high_rank_matrix_vector_mul(module): - module.einsum_batched_high_rank_matrix_vector_mul( - tf_utils.ndarange([BATCH_DIM, 5, 3, 2, 6]), - tf_utils.ndarange([BATCH_DIM, 5, 6])) - self.compare_backends(einsum_batched_high_rank_matrix_vector_mul, - self._modules) - - def test_einsum_batched_matrix_high_rank_vector_mul(self): - def einsum_batched_matrix_high_rank_vector_mul(module): - module.einsum_batched_matrix_high_rank_vector_mul( - tf_utils.ndarange([BATCH_DIM, 2, 6]), - tf_utils.ndarange([BATCH_DIM, 5, 3, 6])) - self.compare_backends(einsum_batched_matrix_high_rank_vector_mul, - self._modules) - # yapf: enable + self.compare_backends(einsum_batched_matrix_high_rank_vector_mul, self._modules) if __name__ == "__main__": - if hasattr(tf, "enable_v2_behavior"): - tf.enable_v2_behavior() - tf.test.main() + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() diff --git a/integrations/tensorflow/test/python/iree_tf_tests/einsum_vector_test.py b/integrations/tensorflow/test/python/iree_tf_tests/einsum_vector_test.py index 228f26bda0d1..147faff436de 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/einsum_vector_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/einsum_vector_test.py @@ -13,92 +13,110 @@ class EinsumVectorModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_identity(self, x): - return tf.einsum('i', x) - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_sum(self, x): - return tf.einsum('i ->', x) - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_mul(self, lhs, rhs): - return tf.einsum('i, i -> i', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_implicit_inner_product(self, lhs, rhs): - return tf.einsum('i, i', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_explicit_inner_product(self, lhs, rhs): - return tf.einsum('i, i ->', lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([VECTOR_DIM], tf.float32), - tf.TensorSpec([VECTOR_DIM], tf.float32), - ]) - def einsum_outer_product(self, lhs, rhs): - return tf.einsum('i, j -> ij', lhs, rhs) + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_identity(self, x): + return tf.einsum("i", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_sum(self, x): + return tf.einsum("i ->", x) + + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_mul(self, lhs, rhs): + return tf.einsum("i, i -> i", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_implicit_inner_product(self, lhs, rhs): + return tf.einsum("i, i", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_explicit_inner_product(self, lhs, rhs): + return tf.einsum("i, i ->", lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([VECTOR_DIM], tf.float32), + tf.TensorSpec([VECTOR_DIM], tf.float32), + ] + ) + def einsum_outer_product(self, lhs, rhs): + return tf.einsum("i, j -> ij", lhs, rhs) class EinsumVectorTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(EinsumVectorModule) + + def test_einsum_identity(self): + def einsum_identity(module): + module.einsum_identity(tf_utils.ndarange([VECTOR_DIM])) + + self.compare_backends(einsum_identity, self._modules) + + def test_einsum_sum(self): + def einsum_sum(module): + module.einsum_sum(tf_utils.ndarange([VECTOR_DIM])) + + self.compare_backends(einsum_sum, self._modules) + + def test_einsum_mul(self): + def einsum_mul(module): + module.einsum_mul( + tf_utils.ndarange([VECTOR_DIM]), tf_utils.ndarange([VECTOR_DIM]) + ) + + self.compare_backends(einsum_mul, self._modules) + + def test_einsum_implicit_inner_product(self): + def einsum_implicit_inner_product(module): + module.einsum_implicit_inner_product( + tf_utils.ndarange([VECTOR_DIM]), tf_utils.ndarange([VECTOR_DIM]) + ) + + self.compare_backends(einsum_implicit_inner_product, self._modules) + + def test_einsum_explicit_inner_product(self): + def einsum_explicit_inner_product(module): + module.einsum_explicit_inner_product( + tf_utils.ndarange([VECTOR_DIM]), tf_utils.ndarange([VECTOR_DIM]) + ) + + self.compare_backends(einsum_explicit_inner_product, self._modules) + + def test_einsum_outer_product(self): + def einsum_outer_product(module): + module.einsum_outer_product( + tf_utils.ndarange([VECTOR_DIM]), tf_utils.ndarange([VECTOR_DIM]) + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(EinsumVectorModule) - - # yapf: disable - def test_einsum_identity(self): - def einsum_identity(module): - module.einsum_identity(tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_identity, self._modules) - - def test_einsum_sum(self): - def einsum_sum(module): - module.einsum_sum(tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_sum, self._modules) - - def test_einsum_mul(self): - def einsum_mul(module): - module.einsum_mul(tf_utils.ndarange([VECTOR_DIM]), - tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_mul, self._modules) - - def test_einsum_implicit_inner_product(self): - def einsum_implicit_inner_product(module): - module.einsum_implicit_inner_product(tf_utils.ndarange([VECTOR_DIM]), - tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_implicit_inner_product, self._modules) - - def test_einsum_explicit_inner_product(self): - def einsum_explicit_inner_product(module): - module.einsum_explicit_inner_product(tf_utils.ndarange([VECTOR_DIM]), - tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_explicit_inner_product, self._modules) - - def test_einsum_outer_product(self): - def einsum_outer_product(module): - module.einsum_outer_product(tf_utils.ndarange([VECTOR_DIM]), - tf_utils.ndarange([VECTOR_DIM])) - self.compare_backends(einsum_outer_product, self._modules) - # yapf: enable + self.compare_backends(einsum_outer_product, self._modules) if __name__ == "__main__": - if hasattr(tf, "enable_v2_behavior"): - tf.enable_v2_behavior() - tf.test.main() + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() diff --git a/integrations/tensorflow/test/python/iree_tf_tests/fft_test.py b/integrations/tensorflow/test/python/iree_tf_tests/fft_test.py index 7135a05c073a..3cba32aa0428 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/fft_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/fft_test.py @@ -11,86 +11,120 @@ class FftModule(tf.Module): - - complex_input_signature = [ - tf.TensorSpec([16], tf.float32), - tf.TensorSpec([16], tf.float32) - ] - - @tf.function(input_signature=complex_input_signature) - def fft(self, real_array, imag_array): - complex_in = tf.complex(real_array, imag_array) - complex_out = tf.signal.fft(complex_in) - return tf.math.real(complex_out), tf.math.imag(complex_out) - - @tf.function(input_signature=complex_input_signature) - def ifft(self, real_array, imag_array): - complex_in = tf.complex(real_array, imag_array) - complex_out = tf.signal.ifft(complex_in) - return tf.math.real(complex_out), tf.math.imag(complex_out) - - @tf.function(input_signature=[tf.TensorSpec([32], tf.float32)]) - def rfft(self, real_array): - complex_out = tf.signal.rfft(real_array) - return tf.math.real(complex_out), tf.math.imag(complex_out) - - # TODO(natashaknk): Enable IRFFT tests when Linalg on tensors changes land. - # @tf.function(input_signature=complex_input_signature) - # def irfft(self, real_array, imag_array): - # complex_in = tf.complex(real_array, imag_array) - # real_out = tf.signal.irfft(complex_in) - # return real_out + complex_input_signature = [ + tf.TensorSpec([16], tf.float32), + tf.TensorSpec([16], tf.float32), + ] + + @tf.function(input_signature=complex_input_signature) + def fft(self, real_array, imag_array): + complex_in = tf.complex(real_array, imag_array) + complex_out = tf.signal.fft(complex_in) + return tf.math.real(complex_out), tf.math.imag(complex_out) + + @tf.function(input_signature=complex_input_signature) + def ifft(self, real_array, imag_array): + complex_in = tf.complex(real_array, imag_array) + complex_out = tf.signal.ifft(complex_in) + return tf.math.real(complex_out), tf.math.imag(complex_out) + + @tf.function(input_signature=[tf.TensorSpec([32], tf.float32)]) + def rfft(self, real_array): + complex_out = tf.signal.rfft(real_array) + return tf.math.real(complex_out), tf.math.imag(complex_out) + + # TODO(natashaknk): Enable IRFFT tests when Linalg on tensors changes land. + # @tf.function(input_signature=complex_input_signature) + # def irfft(self, real_array, imag_array): + # complex_in = tf.complex(real_array, imag_array) + # real_out = tf.signal.irfft(complex_in) + # return real_out class FftTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(FftModule) + self.real_array = np.array( + [ + 9.0, + 1.0, + 4.5, + -0.3, + 10.0, + -1.0, + 5.5, + 0.3, + 299.0, + 3.5, + -0.777, + 2, + 1.7, + 3.5, + -4.5, + 0.0, + ], + dtype=np.float32, + ) + self.imag_array = np.array( + [ + 0.0, + -1.0, + 17.7, + 10.0, + 0.0, + -11.0, + 2763, + 0.0, + 0.0, + -1.5, + 16.8, + 100.0, + 0.0, + -111.0, + 2.3, + 1.0, + ], + dtype=np.float32, + ) + + # Required since pffft requires a minimum of 32 elements for real ffts. + self.long_real_array = np.concatenate( + (self.real_array, self.real_array), axis=None + ) + + def test_fft(self): + def fft(module): + module.fft(self.real_array, self.imag_array, rtol=1e-4) + + self.compare_backends(fft, self._modules) + + def test_ifft(self): + def ifft(module): + module.ifft(self.real_array, self.imag_array, rtol=1e-4) + + self.compare_backends(ifft, self._modules) + + def test_rfft(self): + def rfft(module): + module.rfft(self.long_real_array, rtol=1e-4) + + self.compare_backends(rfft, self._modules) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(FftModule) - self.real_array = np.array([ - 9., 1., 4.5, -0.3, 10., -1., 5.5, 0.3, 299., 3.5, -0.777, 2, 1.7, 3.5, - -4.5, 0.0 - ], - dtype=np.float32) - self.imag_array = np.array([ - 0., -1., 17.7, 10., 0., -11., 2763, 0., 0., -1.5, 16.8, 100., 0., -111., - 2.3, 1. - ], - dtype=np.float32) - - # Required since pffft requires a minimum of 32 elements for real ffts. - self.long_real_array = np.concatenate((self.real_array, self.real_array), - axis=None) - - # yapf: disable - def test_fft(self): - def fft(module): - module.fft(self.real_array, self.imag_array, rtol=1e-4) - self.compare_backends(fft, self._modules) - - def test_ifft(self): - def ifft(module): - module.ifft(self.real_array, self.imag_array, rtol=1e-4) - self.compare_backends(ifft, self._modules) - - def test_rfft(self): - def rfft(module): - module.rfft(self.long_real_array, rtol=1e-4) - self.compare_backends(rfft, self._modules) # TODO(natashaknk): Enable IRFFT tests when Linalg on tensors changes land. # def test_irfft(self): # def irfft(module): # module.irfft(self.real_array, self.imag_array, rtol=1e-4) # self.compare_backends(irfft, self._modules) -# yapf: enable def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() + -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/fill_test.py b/integrations/tensorflow/test/python/iree_tf_tests/fill_test.py index ef1166466d8c..d9e14f5332ae 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/fill_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/fill_test.py @@ -11,40 +11,36 @@ class FillModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass - - @tf.function(input_signature=[ - tf.TensorSpec([2], tf.int32), - tf.TensorSpec([], tf.float32) - ]) - def fill(self, dims, value): - return tf.fill(dims, value) + @tf.function( + input_signature=[tf.TensorSpec([2], tf.int32), tf.TensorSpec([], tf.float32)] + ) + def fill(self, dims, value): + return tf.fill(dims, value) class FillTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(FillModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(FillModule) - - def test_fill(self): - - def fill(module): - dims = np.array([2, 3], dtype=np.int32) - value = np.array(9., dtype=np.float32) - module.fill(dims, value) + def test_fill(self): + def fill(module): + dims = np.array([2, 3], dtype=np.int32) + value = np.array(9.0, dtype=np.float32) + module.fill(dims, value) - self.compare_backends(fill, self._modules) + self.compare_backends(fill, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/gather_test.py b/integrations/tensorflow/test/python/iree_tf_tests/gather_test.py index dfdbaca50429..a183d18656ce 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/gather_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/gather_test.py @@ -12,107 +12,118 @@ class GatherModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([4, 8], tf.float32), - tf.TensorSpec([], tf.int32) - ]) - def gather_axis0_scalar(self, params, indices): - return tf.gather(params, indices) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 8], tf.float32), - tf.TensorSpec([2], tf.int32) - ]) - def gather_axis0_batch0(self, params, indices): - return tf.gather(params, indices) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 7, 8], tf.float32), - tf.TensorSpec([2], tf.int32) - ]) - def gather_axis1_batch0(self, params, indices): - return tf.gather(params, indices, axis=1) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 7, 8, 2], tf.float32), - tf.TensorSpec([4, 1], tf.int32) - ]) - def gather_axis2_batch1(self, params, indices): - return tf.gather(params, indices, axis=2, batch_dims=1) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 7, 8, 2], tf.float32), - tf.TensorSpec([4, 1], tf.int32) - ]) - def gather_axis1_batch1(self, params, indices): - return tf.gather(params, indices, axis=1, batch_dims=1) - - @tf.function(input_signature=[ - tf.TensorSpec([2, 4], tf.int32), - tf.TensorSpec([2, 4], tf.int32) - ]) - def gather_axis2_batch2(self, params, indices): - return tf.gather(params, indices, axis=1, batch_dims=1) + @tf.function( + input_signature=[tf.TensorSpec([4, 8], tf.float32), tf.TensorSpec([], tf.int32)] + ) + def gather_axis0_scalar(self, params, indices): + return tf.gather(params, indices) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 8], tf.float32), + tf.TensorSpec([2], tf.int32), + ] + ) + def gather_axis0_batch0(self, params, indices): + return tf.gather(params, indices) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 7, 8], tf.float32), + tf.TensorSpec([2], tf.int32), + ] + ) + def gather_axis1_batch0(self, params, indices): + return tf.gather(params, indices, axis=1) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 7, 8, 2], tf.float32), + tf.TensorSpec([4, 1], tf.int32), + ] + ) + def gather_axis2_batch1(self, params, indices): + return tf.gather(params, indices, axis=2, batch_dims=1) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 7, 8, 2], tf.float32), + tf.TensorSpec([4, 1], tf.int32), + ] + ) + def gather_axis1_batch1(self, params, indices): + return tf.gather(params, indices, axis=1, batch_dims=1) + + @tf.function( + input_signature=[ + tf.TensorSpec([2, 4], tf.int32), + tf.TensorSpec([2, 4], tf.int32), + ] + ) + def gather_axis2_batch2(self, params, indices): + return tf.gather(params, indices, axis=1, batch_dims=1) class GatherTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(GatherModule) + + def test_gather_axis0_scalar(self): + def gather_axis0_scalar(module): + indices = np.array(2, dtype=np.int32) + params = tf_utils.ndarange([4, 8]) + module.gather_axis0_scalar(params, indices) + + self.compare_backends(gather_axis0_scalar, self._modules) + + def test_gather_axis0_batch0(self): + def gather_axis0_batch0(module): + indices = np.array([2, 3], dtype=np.int32) + params = tf_utils.ndarange([4, 8]) + module.gather_axis0_batch0(params, indices) + + self.compare_backends(gather_axis0_batch0, self._modules) + + def test_gather_axis1_batch0(self): + def gather_axis1_batch0(module): + indices = np.array([2, 3], dtype=np.int32) + params = tf_utils.ndarange([4, 7, 8]) + module.gather_axis1_batch0(params, indices) + + self.compare_backends(gather_axis1_batch0, self._modules) + + def test_gather_axis2_batch1(self): + def gather_axis2_batch1(module): + indices = np.array([[2], [3], [0], [1]], dtype=np.int32) + params = tf_utils.ndarange([4, 7, 8, 2]) + module.gather_axis2_batch1(params, indices) + + self.compare_backends(gather_axis2_batch1, self._modules) + + def test_gather_axis1_batch1(self): + def gather_axis1_batch1(module): + indices = np.array([[2], [3], [0], [1]], dtype=np.int32) + params = tf_utils.ndarange([4, 7, 8, 2]) + module.gather_axis1_batch1(params, indices) + + self.compare_backends(gather_axis1_batch1, self._modules) + + def test_gather_axis2_batch2(self): + def gather_axis2_batch2(module): + indices = np.array([[0, 1, 2, 3], [3, 2, 1, 0]], dtype=np.int32) + values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32) + module.gather_axis2_batch2(values, indices) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(GatherModule) - - # yapf: disable - def test_gather_axis0_scalar(self): - def gather_axis0_scalar(module): - indices = np.array(2, dtype=np.int32) - params = tf_utils.ndarange([4, 8]) - module.gather_axis0_scalar(params, indices) - self.compare_backends(gather_axis0_scalar, self._modules) - - def test_gather_axis0_batch0(self): - def gather_axis0_batch0(module): - indices = np.array([2, 3], dtype=np.int32) - params = tf_utils.ndarange([4, 8]) - module.gather_axis0_batch0(params, indices) - self.compare_backends(gather_axis0_batch0, self._modules) - - def test_gather_axis1_batch0(self): - def gather_axis1_batch0(module): - indices = np.array([2, 3], dtype=np.int32) - params = tf_utils.ndarange([4, 7, 8]) - module.gather_axis1_batch0(params, indices) - self.compare_backends(gather_axis1_batch0, self._modules) - - def test_gather_axis2_batch1(self): - def gather_axis2_batch1(module): - indices = np.array([[2], [3], [0], [1]], dtype=np.int32) - params = tf_utils.ndarange([4, 7, 8, 2]) - module.gather_axis2_batch1(params, indices) - self.compare_backends(gather_axis2_batch1, self._modules) - - def test_gather_axis1_batch1(self): - def gather_axis1_batch1(module): - indices = np.array([[2], [3], [0], [1]], dtype=np.int32) - params = tf_utils.ndarange([4, 7, 8, 2]) - module.gather_axis1_batch1(params, indices) - self.compare_backends(gather_axis1_batch1, self._modules) - - def test_gather_axis2_batch2(self): - def gather_axis2_batch2(module): - indices = np.array([[0, 1, 2, 3], [3, 2, 1, 0]], dtype=np.int32) - values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32) - module.gather_axis2_batch2(values, indices) - self.compare_backends(gather_axis2_batch2, self._modules) - # yapf: enable + self.compare_backends(gather_axis2_batch2, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/image_resize_test.py b/integrations/tensorflow/test/python/iree_tf_tests/image_resize_test.py index 5600560fd94c..fffd61d6cb44 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/image_resize_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/image_resize_test.py @@ -12,50 +12,46 @@ class ResizeImageModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass + @tf.function(input_signature=[tf.TensorSpec([1, 52, 37, 1], tf.int32)]) + def downsample_nearest_neighbor(self, image): + size = np.asarray([8, 7], dtype=np.int32) + return tf.image.resize_nearest_neighbor(image, size) - @tf.function(input_signature=[tf.TensorSpec([1, 52, 37, 1], tf.int32)]) - def downsample_nearest_neighbor(self, image): - size = np.asarray([8, 7], dtype=np.int32) - return tf.image.resize_nearest_neighbor(image, size) - - @tf.function(input_signature=[tf.TensorSpec([1, 8, 7, 1], tf.int32)]) - def upsample_nearest_neighbor(self, image): - size = np.asarray([52, 37], dtype=np.int32) - return tf.image.resize_nearest_neighbor(image, size) + @tf.function(input_signature=[tf.TensorSpec([1, 8, 7, 1], tf.int32)]) + def upsample_nearest_neighbor(self, image): + size = np.asarray([52, 37], dtype=np.int32) + return tf.image.resize_nearest_neighbor(image, size) class ResizeImageTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ResizeImageModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ResizeImageModule) - - def test_downsample_nearest_neighbor(self): - - def downsample_nearest_neighbor(module): - img = tf_utils.ndarange([1, 52, 37, 1], dtype=np.int32) - module.downsample_nearest_neighbor(img) - - self.compare_backends(downsample_nearest_neighbor, self._modules) + def test_downsample_nearest_neighbor(self): + def downsample_nearest_neighbor(module): + img = tf_utils.ndarange([1, 52, 37, 1], dtype=np.int32) + module.downsample_nearest_neighbor(img) - def test_upsample_nearest_neighbor(self): + self.compare_backends(downsample_nearest_neighbor, self._modules) - def upsample_nearest_neighbor(module): - img = tf_utils.ndarange([1, 8, 7, 1], dtype=np.int32) - module.upsample_nearest_neighbor(img) + def test_upsample_nearest_neighbor(self): + def upsample_nearest_neighbor(module): + img = tf_utils.ndarange([1, 8, 7, 1], dtype=np.int32) + module.upsample_nearest_neighbor(img) - self.compare_backends(upsample_nearest_neighbor, self._modules) + self.compare_backends(upsample_nearest_neighbor, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/linspace_test.py b/integrations/tensorflow/test/python/iree_tf_tests/linspace_test.py index 88a4dcc49c70..1894aad37a77 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/linspace_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/linspace_test.py @@ -11,43 +11,39 @@ class LinspaceModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass - - @tf.function(input_signature=[ - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32) - ]) - def linspace(self, start, stop): - # 'num' is const because XLA's iota operation does not support dynamic - # shapes. - num = np.array(3, dtype=np.int32) - return tf.linspace(start, stop, num) + @tf.function( + input_signature=[tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32)] + ) + def linspace(self, start, stop): + # 'num' is const because XLA's iota operation does not support dynamic + # shapes. + num = np.array(3, dtype=np.int32) + return tf.linspace(start, stop, num) class LinspaceTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(LinspaceModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(LinspaceModule) - - def test_linspace(self): - - def linspace(module): - start = np.array(10., dtype=np.float32) - stop = np.array(12., dtype=np.float32) - module.linspace(start, stop) + def test_linspace(self): + def linspace(module): + start = np.array(10.0, dtype=np.float32) + stop = np.array(12.0, dtype=np.float32) + module.linspace(start, stop) - self.compare_backends(linspace, self._modules) + self.compare_backends(linspace, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/mandelbrot_test.py b/integrations/tensorflow/test/python/iree_tf_tests/mandelbrot_test.py index 8c0c6161c47d..7ffe0e18f776 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/mandelbrot_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/mandelbrot_test.py @@ -10,100 +10,103 @@ def complex_add(a_re, a_im, b_re, b_im): - return a_re + b_re, a_im + b_im + return a_re + b_re, a_im + b_im def complex_mul(a_re, a_im, b_re, b_im): - c_re = a_re * b_re - a_im * b_im - c_im = a_re * b_im + a_im * b_re - return c_re, c_im + c_re = a_re * b_re - a_im * b_im + c_im = a_re * b_im + a_im * b_re + return c_re, c_im # This is a fun but quite interesting example because the return value and most # of the interior computations are dynamically shaped. class MandelbrotModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.int32), - tf.TensorSpec([], tf.int32) - ]) - def calculate(self, center_re, center_im, view_size, view_pixels, - num_iterations): - """Calculates an image which represents the Mandelbrot set. - - Args: - center_re: The center point of the view (real part). - center_im: The center point of the view (imaginary part). - view_size: The view will display a square with this size. - view_pixels: The returned image will be a square with this many pixels on - a side. - num_iterations: The number of iterations to use for determining escape. - - Returns: - A tensor of pixels with shape [view_size, view_size] which represents - the mandelbrot set. - """ - re_min = center_re - view_size / 2. - re_max = center_re + view_size / 2. - im_min = center_im - view_size / 2. - im_max = center_im + view_size / 2. - re_coords = tf.linspace(re_min, re_max, view_pixels) - im_coords = tf.linspace(im_min, im_max, view_pixels) - - # Generate flat list of real and imaginary parts of the points to test. - # This requires taking all pairs of re_coords and im_coords, which we - # do by broadcasting into a 2d matrix (real part is broadcasted "vertically" - # and imaginary part is broadcasted "horizontally"). - # We use a Nx1 * 1xN -> NxN matmul to do the broadcast. - c_re = tf.reshape( - tf.matmul(tf.ones([view_pixels, 1]), - tf.reshape(re_coords, [1, view_pixels])), [-1]) - c_im = tf.reshape( - tf.matmul(tf.reshape(im_coords, [view_pixels, 1]), - tf.ones([1, view_pixels])), [-1]) - - z_re = tf.zeros_like(c_re) - z_im = tf.zeros_like(c_im) - for _ in range(num_iterations): - square_re, square_im = complex_mul(z_re, z_im, z_re, z_im) - z_re, z_im = complex_add(square_re, square_im, c_re, c_im) - - # Calculate if the points are in the set (that is, if their orbit under the - # recurrence relationship has diverged). - z_abs = tf.sqrt(z_re**2 + z_im**2) - z_abs = tf.where(tf.math.is_nan(z_abs), 100. * tf.ones_like(z_abs), z_abs) - in_the_set = tf.where(z_abs > 50., tf.ones_like(z_abs), - tf.zeros_like(z_abs)) - # Return an image - return tf.reshape(in_the_set, shape=[view_pixels, view_pixels]) + @tf.function( + input_signature=[ + tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.int32), + tf.TensorSpec([], tf.int32), + ] + ) + def calculate(self, center_re, center_im, view_size, view_pixels, num_iterations): + """Calculates an image which represents the Mandelbrot set. + + Args: + center_re: The center point of the view (real part). + center_im: The center point of the view (imaginary part). + view_size: The view will display a square with this size. + view_pixels: The returned image will be a square with this many pixels on + a side. + num_iterations: The number of iterations to use for determining escape. + + Returns: + A tensor of pixels with shape [view_size, view_size] which represents + the mandelbrot set. + """ + re_min = center_re - view_size / 2.0 + re_max = center_re + view_size / 2.0 + im_min = center_im - view_size / 2.0 + im_max = center_im + view_size / 2.0 + re_coords = tf.linspace(re_min, re_max, view_pixels) + im_coords = tf.linspace(im_min, im_max, view_pixels) + + # Generate flat list of real and imaginary parts of the points to test. + # This requires taking all pairs of re_coords and im_coords, which we + # do by broadcasting into a 2d matrix (real part is broadcasted "vertically" + # and imaginary part is broadcasted "horizontally"). + # We use a Nx1 * 1xN -> NxN matmul to do the broadcast. + c_re = tf.reshape( + tf.matmul( + tf.ones([view_pixels, 1]), tf.reshape(re_coords, [1, view_pixels]) + ), + [-1], + ) + c_im = tf.reshape( + tf.matmul( + tf.reshape(im_coords, [view_pixels, 1]), tf.ones([1, view_pixels]) + ), + [-1], + ) + + z_re = tf.zeros_like(c_re) + z_im = tf.zeros_like(c_im) + for _ in range(num_iterations): + square_re, square_im = complex_mul(z_re, z_im, z_re, z_im) + z_re, z_im = complex_add(square_re, square_im, c_re, c_im) + + # Calculate if the points are in the set (that is, if their orbit under the + # recurrence relationship has diverged). + z_abs = tf.sqrt(z_re**2 + z_im**2) + z_abs = tf.where(tf.math.is_nan(z_abs), 100.0 * tf.ones_like(z_abs), z_abs) + in_the_set = tf.where(z_abs > 50.0, tf.ones_like(z_abs), tf.zeros_like(z_abs)) + # Return an image + return tf.reshape(in_the_set, shape=[view_pixels, view_pixels]) class MandelbrotTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(MandelbrotModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(MandelbrotModule) - - def test_mandelbrot(self): - - def mandelbrot(module): - # Basic view of the entire set. - module.calculate(-0.7, 0.0, 3.0, 400, 100) - # This is a much more detailed view, so more iterations are needed. - module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000) + def test_mandelbrot(self): + def mandelbrot(module): + # Basic view of the entire set. + module.calculate(-0.7, 0.0, 3.0, 400, 100) + # This is a much more detailed view, so more iterations are needed. + module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000) - self.compare_backends(mandelbrot, self._modules) + self.compare_backends(mandelbrot, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_dynamic_test.py b/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_dynamic_test.py index 01fac2ed1a31..466e30b0a04c 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_dynamic_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_dynamic_test.py @@ -12,74 +12,86 @@ class MatrixOpsDynamicModule(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, 4, 2], tf.float32), + tf.TensorSpec([None, None, 2, 4], tf.float32), + ] + ) + def matmul_high_rank_batch(self, lhs, rhs): + return tf.matmul(lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, None], tf.float32), + tf.TensorSpec([None, None, None], tf.float32), + ] + ) + def matmul_dynamic(self, lhs, rhs): + return tf.matmul(lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, None], tf.float32), + tf.TensorSpec([None, None], tf.float32), + ] + ) + def matmul_dynamic_lhs_batch(self, lhs, rhs): + return tf.matmul(lhs, rhs) - @tf.function(input_signature=[ - tf.TensorSpec([None, None, 4, 2], tf.float32), - tf.TensorSpec([None, None, 2, 4], tf.float32), - ]) - def matmul_high_rank_batch(self, lhs, rhs): - return tf.matmul(lhs, rhs) - @tf.function(input_signature=[ - tf.TensorSpec([None, None, None], tf.float32), - tf.TensorSpec([None, None, None], tf.float32), - ]) - def matmul_dynamic(self, lhs, rhs): - return tf.matmul(lhs, rhs) +class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule) - @tf.function(input_signature=[ - tf.TensorSpec([None, None, None], tf.float32), - tf.TensorSpec([None, None], tf.float32), - ]) - def matmul_dynamic_lhs_batch(self, lhs, rhs): - return tf.matmul(lhs, rhs) + def test_matmul_high_rank_batch(self): + def matmul_high_rank_batch(module): + module.matmul_high_rank_batch( + tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]) + ) + self.compare_backends(matmul_high_rank_batch, self._modules) -class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase): + def test_matmul_dynamic_matching_batch(self): + def matmul_dynamic_matching_batch(module): + module.matmul_dynamic( + tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]) + ) + + self.compare_backends(matmul_dynamic_matching_batch, self._modules) + + def test_matmul_dynamic_broadcast_lhs(self): + def matmul_dynamic_broadcast_lhs(module): + module.matmul_dynamic( + tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]) + ) + + self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules) + + def test_matmul_dynamic_broadcast_rhs(self): + def matmul_dynamic_broadcast_rhs(module): + module.matmul_dynamic( + tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]) + ) + + self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules) + + def test_matmul_dynamic_rank_broadcasting(self): + def matmul_dynamic_rank_broadcasting(module): + module.matmul_dynamic_lhs_batch( + tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]) + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule) - - # yapf: disable - def test_matmul_high_rank_batch(self): - def matmul_high_rank_batch(module): - module.matmul_high_rank_batch( - tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4])) - self.compare_backends(matmul_high_rank_batch, self._modules) - - def test_matmul_dynamic_matching_batch(self): - def matmul_dynamic_matching_batch(module): - module.matmul_dynamic( - tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4])) - self.compare_backends(matmul_dynamic_matching_batch, self._modules) - - def test_matmul_dynamic_broadcast_lhs(self): - def matmul_dynamic_broadcast_lhs(module): - module.matmul_dynamic( - tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4])) - self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules) - - def test_matmul_dynamic_broadcast_rhs(self): - def matmul_dynamic_broadcast_rhs(module): - module.matmul_dynamic( - tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4])) - self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules) - - def test_matmul_dynamic_rank_broadcasting(self): - def matmul_dynamic_rank_broadcasting(module): - module.matmul_dynamic_lhs_batch( - tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4])) - self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules) - # yapf: enable + self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_static_test.py b/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_static_test.py index 412339a2d803..648c26cf406f 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_static_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/matrix_ops_static_test.py @@ -17,78 +17,91 @@ class MatrixOpsStaticModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def basic_matmul(self, lhs, rhs): - return tf.matmul(lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def matmul_lhs_batch(self, lhs, rhs): - return tf.matmul(lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def matmul_rhs_batch(self, lhs, rhs): - return tf.matmul(lhs, rhs) - - @tf.function(input_signature=[ - tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32), - tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), - ]) - def matmul_broadcast_singleton_dimension(self, lhs, rhs): - return tf.matmul(lhs, rhs) + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def basic_matmul(self, lhs, rhs): + return tf.matmul(lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([BATCH_DIM, LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def matmul_lhs_batch(self, lhs, rhs): + return tf.matmul(lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def matmul_rhs_batch(self, lhs, rhs): + return tf.matmul(lhs, rhs) + + @tf.function( + input_signature=[ + tf.TensorSpec([1, LEFT_DIM, INNER_DIM], tf.float32), + tf.TensorSpec([BATCH_DIM, INNER_DIM, RIGHT_DIM], tf.float32), + ] + ) + def matmul_broadcast_singleton_dimension(self, lhs, rhs): + return tf.matmul(lhs, rhs) class MatrixOpsStaticTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule) + + def test_basic_matmul(self): + def basic_matmul(module): + module.basic_matmul( + tf_utils.uniform([LEFT_DIM, INNER_DIM]), + tf_utils.uniform([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(basic_matmul, self._modules) + + def test_matmul_lhs_batch(self): + def matmul_lhs_batch(module): + module.matmul_lhs_batch( + tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]), + tf_utils.uniform([INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(matmul_lhs_batch, self._modules) + + def test_matmul_rhs_batch(self): + def matmul_rhs_batch(module): + module.matmul_rhs_batch( + tf_utils.uniform([LEFT_DIM, INNER_DIM]), + tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]), + ) + + self.compare_backends(matmul_rhs_batch, self._modules) + + def test_matmul_broadcast_singleton_dimension(self): + def matmul_broadcast_singleton_dimension(module): + module.matmul_broadcast_singleton_dimension( + tf_utils.uniform([1, LEFT_DIM, INNER_DIM]), + tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]), + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule) - - # yapf: disable - def test_basic_matmul(self): - def basic_matmul(module): - module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]), - tf_utils.uniform([INNER_DIM, RIGHT_DIM])) - self.compare_backends(basic_matmul, self._modules) - - def test_matmul_lhs_batch(self): - def matmul_lhs_batch(module): - module.matmul_lhs_batch( - tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]), - tf_utils.uniform([INNER_DIM, RIGHT_DIM])) - self.compare_backends(matmul_lhs_batch, self._modules) - - def test_matmul_rhs_batch(self): - def matmul_rhs_batch(module): - module.matmul_rhs_batch( - tf_utils.uniform([LEFT_DIM, INNER_DIM]), - tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM])) - self.compare_backends(matmul_rhs_batch, self._modules) - - def test_matmul_broadcast_singleton_dimension(self): - def matmul_broadcast_singleton_dimension(module): - module.matmul_broadcast_singleton_dimension( - tf_utils.uniform([1, LEFT_DIM, INNER_DIM]), - tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM])) - self.compare_backends(matmul_broadcast_singleton_dimension, self._modules) - # yapf: enable + self.compare_backends(matmul_broadcast_singleton_dimension, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/mobile_bert_squad_test.py b/integrations/tensorflow/test/python/iree_tf_tests/mobile_bert_squad_test.py index 2a3aaff17229..8c554db109d0 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/mobile_bert_squad_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/mobile_bert_squad_test.py @@ -20,64 +20,70 @@ FLAGS = flags.FLAGS -flags.DEFINE_boolean('use_quantized_weights', False, - 'Whether to use quantized or floating point weights.') +flags.DEFINE_boolean( + "use_quantized_weights", + False, + "Whether to use quantized or floating point weights.", +) MAX_SEQ_LENGTH = 384 # Max input sequence length used in mobilebert_squad. -FILE_NAME = 'mobilebert_squad_savedmodels' +FILE_NAME = "mobilebert_squad_savedmodels" MODEL_URL = posixpath.join( - f'https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/{FILE_NAME}.tar.gz' + f"https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/{FILE_NAME}.tar.gz" ) class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase): - """Tests of MobileBertSquad.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - model_type = 'quant_saved_model' if FLAGS.use_quantized_weights else 'float' - - # Get_file will download the model weights from a publicly available folder, - # save them to cache_dir=~/.keras/datasets/ and return a path to them. - model_path = tf.keras.utils.get_file( - FILE_NAME, - MODEL_URL, - untar=True, - cache_dir=tf_test_utils._setup_artifacts_dir("download")) - model_dir = os.path.dirname(model_path) - extracted_name = FILE_NAME.split('.')[0] - model_path = os.path.join(model_dir, extracted_name, model_type) - - self._modules = tf_test_utils.compile_tf_signature_def_saved_model( - saved_model_dir=model_path, - saved_model_tags=set(['serve']), - module_name='MobileBertSquad', - exported_name='serving_default', - input_names=['input_ids', 'input_mask', 'segment_ids'], - output_names=['start_logits', 'end_logits']) - - def test_serving_default(self): - - def serving_default(module): - input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) - input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) - segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) - - module.serving_default(input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - atol=1e0) - - self.compare_backends(serving_default, self._modules) + """Tests of MobileBertSquad.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + model_type = "quant_saved_model" if FLAGS.use_quantized_weights else "float" + + # Get_file will download the model weights from a publicly available folder, + # save them to cache_dir=~/.keras/datasets/ and return a path to them. + model_path = tf.keras.utils.get_file( + FILE_NAME, + MODEL_URL, + untar=True, + cache_dir=tf_test_utils._setup_artifacts_dir("download"), + ) + model_dir = os.path.dirname(model_path) + extracted_name = FILE_NAME.split(".")[0] + model_path = os.path.join(model_dir, extracted_name, model_type) + + self._modules = tf_test_utils.compile_tf_signature_def_saved_model( + saved_model_dir=model_path, + saved_model_tags=set(["serve"]), + module_name="MobileBertSquad", + exported_name="serving_default", + input_names=["input_ids", "input_mask", "segment_ids"], + output_names=["start_logits", "end_logits"], + ) + + def test_serving_default(self): + def serving_default(module): + input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) + input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) + segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32) + + module.serving_default( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + atol=1e0, + ) + + self.compare_backends(serving_default, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/pytree_test.py b/integrations/tensorflow/test/python/iree_tf_tests/pytree_test.py index ba77687b773b..71329b4563b1 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/pytree_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/pytree_test.py @@ -12,36 +12,34 @@ # Empty lists and dicts are currently unsupported. IREE also currently cannot # represent multiple sequence types, so we turn all sequences into tuples. class PyTreeModule(tf_test_utils.TestModule): + @tf_test_utils.tf_function_unit_test(input_signature=[]) + def output_tuple_len_1(self): + return (0,) - @tf_test_utils.tf_function_unit_test(input_signature=[]) - def output_tuple_len_1(self): - return (0,) + @tf_test_utils.tf_function_unit_test(input_signature=[]) + def output_tuple_len_2(self): + return 0, 1 - @tf_test_utils.tf_function_unit_test(input_signature=[]) - def output_tuple_len_2(self): - return 0, 1 + @tf_test_utils.tf_function_unit_test(input_signature=[]) + def output_tuple_len_3(self): + return 0, 1, 2 - @tf_test_utils.tf_function_unit_test(input_signature=[]) - def output_tuple_len_3(self): - return 0, 1, 2 - - @tf_test_utils.tf_function_unit_test(input_signature=[]) - def output_nested_pytree(self): - return {"key_a": (0, 1, 2), "key_b": (0, 1, {"key_c": (0, 1)})} + @tf_test_utils.tf_function_unit_test(input_signature=[]) + def output_nested_pytree(self): + return {"key_a": (0, 1, 2), "key_b": (0, 1, {"key_c": (0, 1)})} class PyTreeTest(tf_test_utils.TracedModuleTestCase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(PyTreeModule) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(PyTreeModule) def main(argv): - del argv # Unused - PyTreeTest.generate_unit_tests(PyTreeModule) - tf.test.main() + del argv # Unused + PyTreeTest.generate_unit_tests(PyTreeModule) + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/quantization_dyn_test.py b/integrations/tensorflow/test/python/iree_tf_tests/quantization_dyn_test.py index 14b89302705c..7e0baaab71b9 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/quantization_dyn_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/quantization_dyn_test.py @@ -13,37 +13,31 @@ class QuantizationDynModule(tf.Module): - - @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)]) - def fake_quant(self, x): - return tf.quantization.fake_quant_with_min_max_args(x, - min=-6, - max=6, - num_bits=8, - narrow_range=False, - name=None) + @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)]) + def fake_quant(self, x): + return tf.quantization.fake_quant_with_min_max_args( + x, min=-6, max=6, num_bits=8, narrow_range=False, name=None + ) class QuantizationDynTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule) - - def test_fake_quant(self): - - def abs(module): - module.fake_quant(tf_utils.uniform([32], low=-6, high=6)) + def test_fake_quant(self): + def abs(module): + module.fake_quant(tf_utils.uniform([32], low=-6, high=6)) - self.compare_backends(abs, self._modules) + self.compare_backends(abs, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/quantization_test.py b/integrations/tensorflow/test/python/iree_tf_tests/quantization_test.py index c96982a2f565..826a13866a13 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/quantization_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/quantization_test.py @@ -13,34 +13,30 @@ class QuantizationModule(tf_test_utils.TestModule): - - @tf_test_utils.tf_function_unit_test( - input_signature=[tf.TensorSpec([32], tf.float32)], - input_generator=lambda *args: tf_utils.uniform(*args, low=-6, high=6)) - def fake_quant(self, x): - return tf.quantization.fake_quant_with_min_max_args(x, - min=-6, - max=6, - num_bits=8, - narrow_range=False, - name=None) + @tf_test_utils.tf_function_unit_test( + input_signature=[tf.TensorSpec([32], tf.float32)], + input_generator=lambda *args: tf_utils.uniform(*args, low=-6, high=6), + ) + def fake_quant(self, x): + return tf.quantization.fake_quant_with_min_max_args( + x, min=-6, max=6, num_bits=8, narrow_range=False, name=None + ) class QuantizationTest(tf_test_utils.TracedModuleTestCase): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(QuantizationModule) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(QuantizationModule) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() - QuantizationTest.generate_unit_tests(QuantizationModule) - tf.test.main() + QuantizationTest.generate_unit_tests(QuantizationModule) + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/range_test.py b/integrations/tensorflow/test/python/iree_tf_tests/range_test.py index 08d7c4dacb3e..822afd4297a0 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/range_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/range_test.py @@ -11,42 +11,41 @@ class RangeModule(tf.Module): + def __init__(self): + pass - def __init__(self): - pass - - @tf.function(input_signature=[ - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32) - ]) - def range(self, start, stop, delta): - return tf.range(start, stop, delta) + @tf.function( + input_signature=[ + tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.float32), + ] + ) + def range(self, start, stop, delta): + return tf.range(start, stop, delta) class RangeTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(RangeModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(RangeModule) - - def test_range(self): - - def range(module): - start = np.array(3., dtype=np.float32) - stop = np.array(12., dtype=np.float32) - delta = np.array(3, dtype=np.float32) - result = module.range(start, stop, delta) + def test_range(self): + def range(module): + start = np.array(3.0, dtype=np.float32) + stop = np.array(12.0, dtype=np.float32) + delta = np.array(3, dtype=np.float32) + result = module.range(start, stop, delta) - self.compare_backends(range, self._modules) + self.compare_backends(range, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/resource_ops_test.py b/integrations/tensorflow/test/python/iree_tf_tests/resource_ops_test.py index 3bf41cfee9df..e0b6d9a63cd6 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/resource_ops_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/resource_ops_test.py @@ -11,52 +11,48 @@ class ResourcesOpsModule(tf.Module): + def __init__(self): + super().__init__() + self.counter = tf.Variable(0.0) - def __init__(self): - super().__init__() - self.counter = tf.Variable(0.0) + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def add_assign(self, value): + return self.counter.assign_add(value) - @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) - def add_assign(self, value): - return self.counter.assign_add(value) + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def set_value(self, new_value): + self.counter.assign(new_value) - @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) - def set_value(self, new_value): - self.counter.assign(new_value) - - @tf.function(input_signature=[]) - def get_value(self): - return self.counter + @tf.function(input_signature=[]) + def get_value(self): + return self.counter class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule) - - def test_add_assign(self): - - def add_assign(module): - module.add_assign(np.array(9., dtype=np.float32)) - - self.compare_backends(add_assign, self._modules) + def test_add_assign(self): + def add_assign(module): + module.add_assign(np.array(9.0, dtype=np.float32)) - def test_assign_get(self): + self.compare_backends(add_assign, self._modules) - def assign_get(module): - module.set_value(np.array(9., dtype=np.float32)) - return module.get_value() + def test_assign_get(self): + def assign_get(module): + module.set_value(np.array(9.0, dtype=np.float32)) + return module.get_value() - self.compare_backends(assign_get, self._modules) + self.compare_backends(assign_get, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/ring_buffer_test.py b/integrations/tensorflow/test/python/iree_tf_tests/ring_buffer_test.py index 3dd345dd7b32..0fb1541199e9 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/ring_buffer_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/ring_buffer_test.py @@ -15,198 +15,202 @@ class RingBuffer(tf.Module): - """Implements a RingBuffer.""" - - def __init__(self, buffer_size, dims, dtype): - self._buffer_size = buffer_size - self._dims = dims - - # buffer has size [buffer_size, dims] - # only the first dimension is used for updating buffer in a ring manner - self._buffer = tf.Variable(tf.zeros((self._buffer_size,) + dims, - dtype=dtype), - trainable=False, - name="RingBuffer") - # Size of the data available for reading - self._data_size = tf.Variable(0, - trainable=False, - dtype=tf.int32, - name="FramerBuffer/Size") - # The index pointing to the head of the data available for reading - self._read_head = tf.Variable(0, - trainable=False, - dtype=tf.int32, - name="FramerBuffer/Head") - - @property - def dtype(self): - return self._buffer.dtype - - @property - def dims(self): - return self._dims - - @tf.function - def get_write_headroom(self): - """Gets the available writable headroom. - - Returns: - integer scalar tensor of headroom. - """ - return self._buffer_size - self._data_size - - @tf.function - def get_read_available(self): - """Gets the available readable entries. - - Returns: - integer scalar tensor of headroom. - """ - return self._data_size - - @tf.function - def write(self, elements): - """Writes elements to the ringbuffer. - - Args: - elements: Elements to write. - - Returns: - Whether the write was successful (always True for now). - """ - elements_size = tf.shape(elements)[0] - start = tf.math.floormod( - self._read_head.read_value() + self._data_size.read_value(), - self._buffer_size) - indices = tf.math.floormod(tf.range(start, limit=start + elements_size), - self._buffer_size) - - tf.compat.v1.scatter_update(self._buffer, indices, elements) - - # special case when addition of new data, exceed _buffer_size: - # we start overwriting existing data in circular manner - # and need to update _read_head - if tf.greater(self._data_size + elements_size, self._buffer_size): - self._read_head.assign( - tf.math.floormod( - self._read_head.read_value() + self._data_size + - tf.math.floormod(elements_size, self._buffer_size), - self._buffer_size)) - - self._data_size.assign( - tf.minimum(self._data_size + elements_size, self._buffer_size)) - return tf.convert_to_tensor(True) - - @tf.function - def read(self, length, offset=0, consume=True): - """Reads elements from the ringbuffer. - - This will unconditionally read from the buffer and will produce undefined - outputs if attempting to read past the end. This does not consume from - the read buffer. - - Args: - length: The length of data to read. - offset: The offset into the readable area to begin. - consume: Consumes the read data (default true). - - Returns: - Tensor of elements with shape [length, dims...]. - """ - start = self._read_head + offset - indices = tf.math.floormod(tf.range(start, limit=start + length), - self._buffer_size) - result = tf.gather(self._buffer, indices) - if consume: - self.consume(length, offset) - return result - - @tf.function - def consume(self, length, offset=0): - """Consumes elements from the buffer. - - Args: - length: The length of data to read. - offset: The offset into the readable area to begin. - """ - start = self._read_head + offset - self._read_head.assign(tf.math.floormod(start + length, self._buffer_size)) - self._data_size.assign(self._data_size - length) + """Implements a RingBuffer.""" + + def __init__(self, buffer_size, dims, dtype): + self._buffer_size = buffer_size + self._dims = dims + + # buffer has size [buffer_size, dims] + # only the first dimension is used for updating buffer in a ring manner + self._buffer = tf.Variable( + tf.zeros((self._buffer_size,) + dims, dtype=dtype), + trainable=False, + name="RingBuffer", + ) + # Size of the data available for reading + self._data_size = tf.Variable( + 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Size" + ) + # The index pointing to the head of the data available for reading + self._read_head = tf.Variable( + 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Head" + ) + + @property + def dtype(self): + return self._buffer.dtype + + @property + def dims(self): + return self._dims + + @tf.function + def get_write_headroom(self): + """Gets the available writable headroom. + + Returns: + integer scalar tensor of headroom. + """ + return self._buffer_size - self._data_size + + @tf.function + def get_read_available(self): + """Gets the available readable entries. + + Returns: + integer scalar tensor of headroom. + """ + return self._data_size + + @tf.function + def write(self, elements): + """Writes elements to the ringbuffer. + + Args: + elements: Elements to write. + + Returns: + Whether the write was successful (always True for now). + """ + elements_size = tf.shape(elements)[0] + start = tf.math.floormod( + self._read_head.read_value() + self._data_size.read_value(), + self._buffer_size, + ) + indices = tf.math.floormod( + tf.range(start, limit=start + elements_size), self._buffer_size + ) + + tf.compat.v1.scatter_update(self._buffer, indices, elements) + + # special case when addition of new data, exceed _buffer_size: + # we start overwriting existing data in circular manner + # and need to update _read_head + if tf.greater(self._data_size + elements_size, self._buffer_size): + self._read_head.assign( + tf.math.floormod( + self._read_head.read_value() + + self._data_size + + tf.math.floormod(elements_size, self._buffer_size), + self._buffer_size, + ) + ) + + self._data_size.assign( + tf.minimum(self._data_size + elements_size, self._buffer_size) + ) + return tf.convert_to_tensor(True) + + @tf.function + def read(self, length, offset=0, consume=True): + """Reads elements from the ringbuffer. + + This will unconditionally read from the buffer and will produce undefined + outputs if attempting to read past the end. This does not consume from + the read buffer. + + Args: + length: The length of data to read. + offset: The offset into the readable area to begin. + consume: Consumes the read data (default true). + + Returns: + Tensor of elements with shape [length, dims...]. + """ + start = self._read_head + offset + indices = tf.math.floormod( + tf.range(start, limit=start + length), self._buffer_size + ) + result = tf.gather(self._buffer, indices) + if consume: + self.consume(length, offset) + return result + + @tf.function + def consume(self, length, offset=0): + """Consumes elements from the buffer. + + Args: + length: The length of data to read. + offset: The offset into the readable area to begin. + """ + start = self._read_head + offset + self._read_head.assign(tf.math.floormod(start + length, self._buffer_size)) + self._data_size.assign(self._data_size - length) class StatefulRingBuffer(tf.keras.layers.Layer): - - def __init__(self, state_shape=None, consume=False, **kwargs): - super().__init__(**kwargs) - self.state_shape = state_shape - self.consume = consume - - def build(self, input_shape): - super(StatefulRingBuffer, self).build(input_shape) - buffer_size = self.state_shape[1] - self.rb = RingBuffer(buffer_size=buffer_size, - dims=(self.state_shape[2],), - dtype=tf.float32) - - def call(self, inputs): - self.rb.write(inputs) - return self.rb.read(1, consume=self.consume) - - def get_config(self): - config = { - "state_shape": self.state_shape, - "consume": self.consume, - } - base_config = super(StatefulRingBuffer, self).get_config() - return dict(list(base_config.items()) + list(config.items())) + def __init__(self, state_shape=None, consume=False, **kwargs): + super().__init__(**kwargs) + self.state_shape = state_shape + self.consume = consume + + def build(self, input_shape): + super(StatefulRingBuffer, self).build(input_shape) + buffer_size = self.state_shape[1] + self.rb = RingBuffer( + buffer_size=buffer_size, dims=(self.state_shape[2],), dtype=tf.float32 + ) + + def call(self, inputs): + self.rb.write(inputs) + return self.rb.read(1, consume=self.consume) + + def get_config(self): + config = { + "state_shape": self.state_shape, + "consume": self.consume, + } + base_config = super(StatefulRingBuffer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) class StatefulRingBufferModule(tf.Module): + def __init__(self): + super().__init__() + state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE] + self.rb = StatefulRingBuffer(state_shape=state_shape) - def __init__(self): - super().__init__() - state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE] - self.rb = StatefulRingBuffer(state_shape=state_shape) - - @tf.function( - input_signature=[tf.TensorSpec([BATCH_SIZE, FEATURE_SIZE], tf.float32)]) - def predict(self, x): - return self.rb(x) + @tf.function( + input_signature=[tf.TensorSpec([BATCH_SIZE, FEATURE_SIZE], tf.float32)] + ) + def predict(self, x): + return self.rb(x) class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module( + StatefulRingBufferModule, exported_names=["predict"] + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(StatefulRingBufferModule, - exported_names=["predict"]) - - def test_stateful_ringbuffer(self): - - def stateful_ringbuffer(module): - input1 = np.array([[1.0, 2.0]], dtype=np.float32) - module.predict(input1) - # output = np.array([[1.0, 2.0]], dtype=np.float32) + def test_stateful_ringbuffer(self): + def stateful_ringbuffer(module): + input1 = np.array([[1.0, 2.0]], dtype=np.float32) + module.predict(input1) + # output = np.array([[1.0, 2.0]], dtype=np.float32) - # ring buffer is not filled yet so data from first cycle will be returned. - input2 = np.array([[3.0, 4.0]], dtype=np.float32) - module.predict(input2) - # output = np.array([[1.0, 2.0]], dtype=np.float32) + # ring buffer is not filled yet so data from first cycle will be returned. + input2 = np.array([[3.0, 4.0]], dtype=np.float32) + module.predict(input2) + # output = np.array([[1.0, 2.0]], dtype=np.float32) - # on 3rd cycle we overwrite oldest data and return data from 2nd cycle. - input3 = np.array([[5.0, 6.0]], dtype=np.float32) - module.predict(input3) - # output = np.array([[3.0, 4.0]], dtype=np.float32) + # on 3rd cycle we overwrite oldest data and return data from 2nd cycle. + input3 = np.array([[5.0, 6.0]], dtype=np.float32) + module.predict(input3) + # output = np.array([[3.0, 4.0]], dtype=np.float32) - self.compare_backends(stateful_ringbuffer, self._modules) + self.compare_backends(stateful_ringbuffer, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/scatter_update_test.py b/integrations/tensorflow/test/python/iree_tf_tests/scatter_update_test.py index 6519e78e22a7..9754f2579abc 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/scatter_update_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/scatter_update_test.py @@ -12,74 +12,79 @@ class ScatterUpdateModule(tf.Module): - - def __init__(self): - pass - - @tf.function(input_signature=[ - tf.TensorSpec([8], tf.int32), - tf.TensorSpec([3, 1], tf.int32), - tf.TensorSpec([3], tf.int32) - ]) - def scatter_update_1D(self, tensor, indices, updates): - return tf.tensor_scatter_nd_update(tensor, indices, updates) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 3], tf.int32), - tf.TensorSpec([3, 2], tf.int32), - tf.TensorSpec([3], tf.int32) - ]) - def scatter_update_2D(self, tensor, indices, updates): - return tf.tensor_scatter_nd_update(tensor, indices, updates) - - @tf.function(input_signature=[ - tf.TensorSpec([4, 3], tf.int32), - tf.TensorSpec([1, 1], tf.int32), - tf.TensorSpec([1, 3], tf.int32) - ]) - def scatter_update_2D_slice(self, tensor, indices, updates): - return tf.tensor_scatter_nd_update(tensor, indices, updates) + def __init__(self): + pass + + @tf.function( + input_signature=[ + tf.TensorSpec([8], tf.int32), + tf.TensorSpec([3, 1], tf.int32), + tf.TensorSpec([3], tf.int32), + ] + ) + def scatter_update_1D(self, tensor, indices, updates): + return tf.tensor_scatter_nd_update(tensor, indices, updates) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 3], tf.int32), + tf.TensorSpec([3, 2], tf.int32), + tf.TensorSpec([3], tf.int32), + ] + ) + def scatter_update_2D(self, tensor, indices, updates): + return tf.tensor_scatter_nd_update(tensor, indices, updates) + + @tf.function( + input_signature=[ + tf.TensorSpec([4, 3], tf.int32), + tf.TensorSpec([1, 1], tf.int32), + tf.TensorSpec([1, 3], tf.int32), + ] + ) + def scatter_update_2D_slice(self, tensor, indices, updates): + return tf.tensor_scatter_nd_update(tensor, indices, updates) class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule) + + def test_scatter_update_1D(self): + def scatter_update_1D(module): + tensor = np.ones([8], dtype=np.int32) + indices = np.array([[4], [5], [6]], dtype=np.int32) + updates = np.array([9, 10, 11], dtype=np.int32) + module.scatter_update_1D(tensor, indices, updates) + + self.compare_backends(scatter_update_1D, self._modules) + + def test_scatter_update_2D(self): + def scatter_update_2D(module): + tensor = np.ones([4, 3], dtype=np.int32) + indices = np.array([[1, 0], [2, 1], [3, 2]], dtype=np.int32) + updates = np.array([2, 5, 8], dtype=np.int32) + module.scatter_update_2D(tensor, indices, updates) + + self.compare_backends(scatter_update_2D, self._modules) + + def test_scatter_update_2D_slice(self): + def scatter_update_2D_slice(module): + tensor = np.ones([4, 3], dtype=np.int32) + indices = np.array([[1]], dtype=np.int32) + updates = np.array([[2, 3, 4]], dtype=np.int32) + module.scatter_update_2D_slice(tensor, indices, updates) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule) - - # yapf: disable - def test_scatter_update_1D(self): - def scatter_update_1D(module): - tensor = np.ones([8], dtype=np.int32) - indices = np.array([[4], [5], [6]], dtype=np.int32) - updates = np.array([9, 10, 11], dtype=np.int32) - module.scatter_update_1D(tensor, indices, updates) - self.compare_backends(scatter_update_1D, self._modules) - - def test_scatter_update_2D(self): - def scatter_update_2D(module): - tensor = np.ones([4, 3], dtype=np.int32) - indices = np.array([[1, 0], [2, 1], [3, 2]], dtype=np.int32) - updates = np.array([2, 5, 8], dtype=np.int32) - module.scatter_update_2D(tensor, indices, updates) - self.compare_backends(scatter_update_2D, self._modules) - - def test_scatter_update_2D_slice(self): - def scatter_update_2D_slice(module): - tensor = np.ones([4, 3], dtype=np.int32) - indices = np.array([[1]], dtype=np.int32) - updates = np.array([[2, 3, 4]], dtype=np.int32) - module.scatter_update_2D_slice(tensor, indices, updates) - self.compare_backends(scatter_update_2D_slice, self._modules) - # yapf: enable + self.compare_backends(scatter_update_2D_slice, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/simple_arithmetic_test.py b/integrations/tensorflow/test/python/iree_tf_tests/simple_arithmetic_test.py index c6a4ae878bf9..ac8c1c500400 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/simple_arithmetic_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/simple_arithmetic_test.py @@ -13,55 +13,52 @@ class SimpleArithmeticModule(tf.Module): - - @tf.function(input_signature=[ - tf.TensorSpec([4], tf.float32), - tf.TensorSpec([4], tf.float32) - ]) - def simple_mul(self, a, b): - return a * b - - @tf.function(input_signature=[ - tf.TensorSpec([128, 3072], tf.float32), - tf.TensorSpec([3072, 256], tf.float32), - ]) - def simple_matmul(self, a, b): - return tf.matmul(a, b) + @tf.function( + input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)] + ) + def simple_mul(self, a, b): + return a * b + + @tf.function( + input_signature=[ + tf.TensorSpec([128, 3072], tf.float32), + tf.TensorSpec([3072, 256], tf.float32), + ] + ) + def simple_matmul(self, a, b): + return tf.matmul(a, b) class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule) - - def test_simple_mul(self): - - def simple_mul(module): - a = np.array([1., 2., 3., 4.], dtype=np.float32) - b = np.array([400., 5., 6., 7.], dtype=np.float32) - c = module.simple_mul(a, b) - module.simple_mul(a, c) - - self.compare_backends(simple_mul, self._modules) + def test_simple_mul(self): + def simple_mul(module): + a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + b = np.array([400.0, 5.0, 6.0, 7.0], dtype=np.float32) + c = module.simple_mul(a, b) + module.simple_mul(a, c) - def test_simple_matmul(self): + self.compare_backends(simple_mul, self._modules) - def simple_matmul(module): - # Note: scaling by a small value to increase numerical stability. - a = tf_utils.uniform((128, 3072)) * 1e-3 - b = tf_utils.uniform((3072, 256)) * 1e-3 - module.simple_matmul(a, b) + def test_simple_matmul(self): + def simple_matmul(module): + # Note: scaling by a small value to increase numerical stability. + a = tf_utils.uniform((128, 3072)) * 1e-3 + b = tf_utils.uniform((3072, 256)) * 1e-3 + module.simple_matmul(a, b) - self.compare_backends(simple_matmul, self._modules) + self.compare_backends(simple_matmul, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/simple_stateful_test.py b/integrations/tensorflow/test/python/iree_tf_tests/simple_stateful_test.py index c2d88cab6fbf..4bcf2828087e 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/simple_stateful_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/simple_stateful_test.py @@ -11,41 +11,38 @@ class SimpleStatefulModule(tf.Module): + def __init__(self): + super().__init__() + self.counter = tf.Variable(0.0) - def __init__(self): - super().__init__() - self.counter = tf.Variable(0.0) + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def inc_by(self, x): + self.counter.assign(self.counter + x) - @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) - def inc_by(self, x): - self.counter.assign(self.counter + x) - - @tf.function(input_signature=[]) - def get_state(self): - return self.counter + @tf.function(input_signature=[]) + def get_state(self): + return self.counter class StatefulTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule) - - def test_stateful(self): - - def get_state(module): - module.inc_by(np.array(1., dtype=np.float32)) - module.get_state() + def test_stateful(self): + def get_state(module): + module.inc_by(np.array(1.0, dtype=np.float32)) + module.get_state() - self.compare_backends(get_state, self._modules) + self.compare_backends(get_state, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/sliding_window_test.py b/integrations/tensorflow/test/python/iree_tf_tests/sliding_window_test.py index fcbad1796af6..b9daf74a6ba3 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/sliding_window_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/sliding_window_test.py @@ -15,86 +15,85 @@ class SlidingWindow(tf.keras.layers.Layer): - # It is another version of a ring buffer - # during call() it appends new update and remove the oldest one + # It is another version of a ring buffer + # during call() it appends new update and remove the oldest one - def __init__(self, state_shape=None, **kwargs): - super().__init__(**kwargs) + def __init__(self, state_shape=None, **kwargs): + super().__init__(**kwargs) - self.state_shape = state_shape + self.state_shape = state_shape - def build(self, input_shape): - super(SlidingWindow, self).build(input_shape) + def build(self, input_shape): + super(SlidingWindow, self).build(input_shape) - self.states = self.add_weight( - name="states", - shape=self.state_shape, # [batch, time, feature] - trainable=False, - initializer=tf.zeros_initializer) + self.states = self.add_weight( + name="states", + shape=self.state_shape, # [batch, time, feature] + trainable=False, + initializer=tf.zeros_initializer, + ) - def call(self, inputs): + def call(self, inputs): + # [batch_size, 1, feature_dim] + inputs_time = tf.keras.backend.expand_dims(inputs, -2) - # [batch_size, 1, feature_dim] - inputs_time = tf.keras.backend.expand_dims(inputs, -2) + # remove latest row [batch_size, (memory_size-1), feature_dim] + memory = self.states[:, 1 : self.state_shape[1], :] - # remove latest row [batch_size, (memory_size-1), feature_dim] - memory = self.states[:, 1:self.state_shape[1], :] + # add new row [batch_size, memory_size, feature_dim] + memory = tf.keras.backend.concatenate([memory, inputs_time], 1) - # add new row [batch_size, memory_size, feature_dim] - memory = tf.keras.backend.concatenate([memory, inputs_time], 1) + self.states.assign(memory) - self.states.assign(memory) + return self.states - return self.states - - def get_config(self): - config = { - "state_shape": self.state_shape, - } - base_config = super(SlidingWindow, self).get_config() - return dict(list(base_config.items()) + list(config.items())) + def get_config(self): + config = { + "state_shape": self.state_shape, + } + base_config = super(SlidingWindow, self).get_config() + return dict(list(base_config.items()) + list(config.items())) class SlidingWindowModule(tf.Module): + def __init__(self): + super().__init__() + state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE] + self.sw = SlidingWindow(state_shape=state_shape) - def __init__(self): - super().__init__() - state_shape = [BATCH_SIZE, TIME_SIZE, FEATURE_SIZE] - self.sw = SlidingWindow(state_shape=state_shape) - - @tf.function( - input_signature=[tf.TensorSpec([BATCH_SIZE, FEATURE_SIZE], tf.float32)]) - def predict(self, x): - return self.sw(x) + @tf.function( + input_signature=[tf.TensorSpec([BATCH_SIZE, FEATURE_SIZE], tf.float32)] + ) + def predict(self, x): + return self.sw(x) class SlidingWindowTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module( + SlidingWindowModule, exported_names=["predict"] + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(SlidingWindowModule, - exported_names=["predict"]) - - def test_sliding_window(self): - - def sliding_window(module): - input1 = np.array([[1.0, 2.0]], dtype=np.float32) - result1 = module.predict(input1) - # output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32) + def test_sliding_window(self): + def sliding_window(module): + input1 = np.array([[1.0, 2.0]], dtype=np.float32) + result1 = module.predict(input1) + # output1 = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 2.0]], dtype=np.float32) - input2 = np.array([[3.0, 4.0]], dtype=np.float32) - result2 = module.predict(input2) - # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + input2 = np.array([[3.0, 4.0]], dtype=np.float32) + result2 = module.predict(input2) + # output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - self.compare_backends(sliding_window, self._modules) + self.compare_backends(sliding_window, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tf_tests/space_to_batch_nd_test.py b/integrations/tensorflow/test/python/iree_tf_tests/space_to_batch_nd_test.py index 44bbbf33f404..e949ff9ee5a6 100644 --- a/integrations/tensorflow/test/python/iree_tf_tests/space_to_batch_nd_test.py +++ b/integrations/tensorflow/test/python/iree_tf_tests/space_to_batch_nd_test.py @@ -13,36 +13,33 @@ class SpaceToBatchModule(tf.Module): - - @tf.function(input_signature=[tf.TensorSpec([1, 8, 2], tf.float32)]) - def batch_to_space_nd(self, x): - block_shape = [3] - paddings = [[3, 4]] - return tf.space_to_batch_nd(x, block_shape, paddings) + @tf.function(input_signature=[tf.TensorSpec([1, 8, 2], tf.float32)]) + def batch_to_space_nd(self, x): + block_shape = [3] + paddings = [[3, 4]] + return tf.space_to_batch_nd(x, block_shape, paddings) class SpaceToBatchTest(tf_test_utils.TracedModuleTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._modules = tf_test_utils.compile_tf_module(SpaceToBatchModule) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._modules = tf_test_utils.compile_tf_module(SpaceToBatchModule) - - def test_space_to_batch_inference(self): - - def space_to_batch_inference(module): - x = np.linspace(0, 15, 16, dtype=np.float32) - x = np.reshape(x, [1, 8, 2]) - module.batch_to_space_nd(x) + def test_space_to_batch_inference(self): + def space_to_batch_inference(module): + x = np.linspace(0, 15, 16, dtype=np.float32) + x = np.reshape(x, [1, 8, 2]) + module.batch_to_space_nd(x) - self.compare_backends(space_to_batch_inference, self._modules) + self.compare_backends(space_to_batch_inference, self._modules) def main(argv): - del argv # Unused - if hasattr(tf, 'enable_v2_behavior'): - tf.enable_v2_behavior() - tf.test.main() + del argv # Unused + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main() -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py index aa67a3f6c742..1cb3ee1b0c72 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/cartoon_gan_test.py @@ -11,13 +11,12 @@ class CartoonGanTest(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(CartoonGanTest, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(CartoonGanTest, self).__init__(model_path, *args, **kwargs) + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py index b5d9f26470e4..d0a103b5465a 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/east_text_detector_test.py @@ -8,32 +8,36 @@ import numpy from . import test_util -model_path = "https://tfhub.dev/sayakpaul/lite-model/east-text-detector/dr/1?lite-format=tflite" +model_path = ( + "https://tfhub.dev/sayakpaul/lite-model/east-text-detector/dr/1?lite-format=tflite" +) class EastTextDetectorTest(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(EastTextDetectorTest, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(EastTextDetectorTest, self).__init__(model_path, *args, **kwargs) + def compare_results(self, iree_results, tflite_results, details): + super(EastTextDetectorTest, self).compare_results( + iree_results, tflite_results, details + ) + self.assertTrue( + numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all() + ) - def compare_results(self, iree_results, tflite_results, details): - super(EastTextDetectorTest, self).compare_results(iree_results, - tflite_results, details) - self.assertTrue( - numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()) + # The second return is extremely noisy as it is not a binary classification. To handle we + # check normalized correlation with an expectation of "close enough". + iree_norm = numpy.sqrt(iree_results[1] * iree_results[1]) + tflite_norm = numpy.sqrt(tflite_results[1] * tflite_results[1]) - # The second return is extremely noisy as it is not a binary classification. To handle we - # check normalized correlation with an expectation of "close enough". - iree_norm = numpy.sqrt(iree_results[1] * iree_results[1]) - tflite_norm = numpy.sqrt(tflite_results[1] * tflite_results[1]) + correlation = numpy.average( + iree_results[1] * tflite_results[1] / iree_norm / tflite_norm + ) + self.assertTrue(numpy.isclose(correlation, 1.0, atol=1e-2).all()) - correlation = numpy.average(iree_results[1] * tflite_results[1] / - iree_norm / tflite_norm) - self.assertTrue(numpy.isclose(correlation, 1.0, atol=1e-2).all()) + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py index 5d98e7e425d5..cccf17f09598 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/gpt2_test.py @@ -13,29 +13,32 @@ # This test is a massive download and excluded due to causing timeouts. class GPT2Test(test_util.TFLiteModelTest): - - def __init__(self, *args, **kwargs): - super(GPT2Test, self).__init__(model_path, *args, **kwargs) - - # Inputs modified to be useful mobilebert inputs. - def generate_inputs(self, input_details): - args = [] - args.append( - numpy.random.randint(low=0, - high=256, - size=input_details[0]["shape"], - dtype=input_details[0]["dtype"])) - return args - - def compare_results(self, iree_results, tflite_results, details): - super(GPT2Test, self).compare_results(iree_results, tflite_results, details) - for i in range(len(iree_results)): - self.assertTrue( - numpy.isclose(iree_results[i], tflite_results[i], atol=5e-3).all()) - - def test_compile_tflite(self): - self.compile_and_execute() - - -if __name__ == '__main__': - absl.testing.absltest.main() + def __init__(self, *args, **kwargs): + super(GPT2Test, self).__init__(model_path, *args, **kwargs) + + # Inputs modified to be useful mobilebert inputs. + def generate_inputs(self, input_details): + args = [] + args.append( + numpy.random.randint( + low=0, + high=256, + size=input_details[0]["shape"], + dtype=input_details[0]["dtype"], + ) + ) + return args + + def compare_results(self, iree_results, tflite_results, details): + super(GPT2Test, self).compare_results(iree_results, tflite_results, details) + for i in range(len(iree_results)): + self.assertTrue( + numpy.isclose(iree_results[i], tflite_results[i], atol=5e-3).all() + ) + + def test_compile_tflite(self): + self.compile_and_execute() + + +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/imagenet_test_data.py b/integrations/tensorflow/test/python/iree_tfl_tests/imagenet_test_data.py index 0de2cf13568a..b0702eefe9b2 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/imagenet_test_data.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/imagenet_test_data.py @@ -6,11 +6,11 @@ # Returns a sample image in the Imagenet dataset in uint8. def generate_input(workdir, input_details): - # We use an image of apples since this is an easy example. - img_path = "https://storage.googleapis.com/iree-model-artifacts/ILSVRC2012_val_00000023.JPEG" - local_path = "/".join([workdir, "ILSVRC2012_val_00000023.JPEG"]) - urllib.request.urlretrieve(img_path, local_path) + # We use an image of apples since this is an easy example. + img_path = "https://storage.googleapis.com/iree-model-artifacts/ILSVRC2012_val_00000023.JPEG" + local_path = "/".join([workdir, "ILSVRC2012_val_00000023.JPEG"]) + urllib.request.urlretrieve(img_path, local_path) - shape = input_details[0]["shape"] - im = np.array(Image.open(local_path).resize((shape[1], shape[2]))) - return im.reshape(shape) + shape = input_details[0]["shape"] + im = np.array(Image.open(local_path).resize((shape[1], shape[2]))) + return im.reshape(shape) diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py index 71f2044916e5..42777a8e39e5 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mnasnet_test.py @@ -12,13 +12,12 @@ class MnasnetTest(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(MnasnetTest, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(MnasnetTest, self).__init__(model_path, *args, **kwargs) + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py index 924537df01dc..1ad9e6133927 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py @@ -10,45 +10,45 @@ class MobileBertTest(test_util.TFLiteModelTest): - - def __init__(self, *args, **kwargs): - super(MobileBertTest, self).__init__(model_path, *args, **kwargs) - - # Inputs modified to be useful mobilebert inputs. - def generate_inputs(self, input_details): - for input in input_details: - absl.logging.info("\t%s, %s", str(input["shape"]), - input["dtype"].__name__) - - input_0 = np.asarray(squad_test_data._INPUT_WORD_ID, - dtype=input_details[0]["dtype"]) - input_1 = np.asarray(squad_test_data._INPUT_TYPE_ID, - dtype=input_details[1]["dtype"]) - input_2 = np.asarray(squad_test_data._INPUT_MASK, - dtype=input_details[2]["dtype"]) - return [ - input_0.reshape(input_details[0]["shape"]), - input_1.reshape(input_details[1]["shape"]), - input_2.reshape(input_details[2]["shape"]) - ] - - def compare_results(self, iree_results, tflite_results, details): - super(MobileBertTest, self).compare_results(iree_results, tflite_results, - details) - # We have confirmed in large scale accuracy tests that differences as large - # as 5.0 is acceptable. We later further relaxed from 5.0 to 7.0 in - # https://github.com/openxla/iree/pull/9337 when quantized Softmax got - # de-quantized, which should be numerically correct albeit not bit-exact. - # The actual observed max error was ~ 6.36. The value 7.0 is that rounded up - # to the next integer. - self.assertTrue( - np.isclose(iree_results[0], tflite_results[0], atol=7.0).all()) - self.assertTrue( - np.isclose(iree_results[1], tflite_results[1], atol=7.0).all()) - - def test_compile_tflite(self): - self.compile_and_execute() - - -if __name__ == '__main__': - absl.testing.absltest.main() + def __init__(self, *args, **kwargs): + super(MobileBertTest, self).__init__(model_path, *args, **kwargs) + + # Inputs modified to be useful mobilebert inputs. + def generate_inputs(self, input_details): + for input in input_details: + absl.logging.info("\t%s, %s", str(input["shape"]), input["dtype"].__name__) + + input_0 = np.asarray( + squad_test_data._INPUT_WORD_ID, dtype=input_details[0]["dtype"] + ) + input_1 = np.asarray( + squad_test_data._INPUT_TYPE_ID, dtype=input_details[1]["dtype"] + ) + input_2 = np.asarray( + squad_test_data._INPUT_MASK, dtype=input_details[2]["dtype"] + ) + return [ + input_0.reshape(input_details[0]["shape"]), + input_1.reshape(input_details[1]["shape"]), + input_2.reshape(input_details[2]["shape"]), + ] + + def compare_results(self, iree_results, tflite_results, details): + super(MobileBertTest, self).compare_results( + iree_results, tflite_results, details + ) + # We have confirmed in large scale accuracy tests that differences as large + # as 5.0 is acceptable. We later further relaxed from 5.0 to 7.0 in + # https://github.com/openxla/iree/pull/9337 when quantized Softmax got + # de-quantized, which should be numerically correct albeit not bit-exact. + # The actual observed max error was ~ 6.36. The value 7.0 is that rounded up + # to the next integer. + self.assertTrue(np.isclose(iree_results[0], tflite_results[0], atol=7.0).all()) + self.assertTrue(np.isclose(iree_results[1], tflite_results[1], atol=7.0).all()) + + def test_compile_tflite(self): + self.compile_and_execute() + + +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v1_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v1_test.py index ab9f15dfa1c7..2a3525b0ac8d 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v1_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v1_test.py @@ -12,19 +12,20 @@ class MobilenetV1Test(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(MobilenetV1Test, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(MobilenetV1Test, self).__init__(model_path, *args, **kwargs) + def compare_results(self, iree_results, tflite_results, details): + super(MobilenetV1Test, self).compare_results( + iree_results, tflite_results, details + ) + self.assertTrue( + numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all() + ) - def compare_results(self, iree_results, tflite_results, details): - super(MobilenetV1Test, self).compare_results(iree_results, tflite_results, - details) - self.assertTrue( - numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all()) + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3-large_uint8_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3-large_uint8_test.py index 9b303f16159e..9e90809224d5 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3-large_uint8_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3-large_uint8_test.py @@ -10,29 +10,30 @@ class MobilenetV3LargeUint8Test(test_util.TFLiteModelTest): - - def __init__(self, *args, **kwargs): - super(MobilenetV3LargeUint8Test, self).__init__(model_path, *args, **kwargs) - - def compare_results(self, iree_results, tflite_results, details): - super(MobilenetV3LargeUint8Test, - self).compare_results(iree_results, tflite_results, details) - # Dequantize outputs. - zero_point = details[0]['quantization_parameters']['zero_points'][0] - scale = details[0]['quantization_parameters']['scales'][0] - dequantized_iree_results = (iree_results - zero_point) * scale - dequantized_tflite_results = (tflite_results - zero_point) * scale - self.assertTrue( - numpy.isclose(dequantized_iree_results, - dequantized_tflite_results, - atol=5e-3).all()) - - def generate_inputs(self, input_details): - return [imagenet_test_data.generate_input(self.workdir, input_details)] - - def test_compile_tflite(self): - self.compile_and_execute() - - -if __name__ == '__main__': - absl.testing.absltest.main() + def __init__(self, *args, **kwargs): + super(MobilenetV3LargeUint8Test, self).__init__(model_path, *args, **kwargs) + + def compare_results(self, iree_results, tflite_results, details): + super(MobilenetV3LargeUint8Test, self).compare_results( + iree_results, tflite_results, details + ) + # Dequantize outputs. + zero_point = details[0]["quantization_parameters"]["zero_points"][0] + scale = details[0]["quantization_parameters"]["scales"][0] + dequantized_iree_results = (iree_results - zero_point) * scale + dequantized_tflite_results = (tflite_results - zero_point) * scale + self.assertTrue( + numpy.isclose( + dequantized_iree_results, dequantized_tflite_results, atol=5e-3 + ).all() + ) + + def generate_inputs(self, input_details): + return [imagenet_test_data.generate_input(self.workdir, input_details)] + + def test_compile_tflite(self): + self.compile_and_execute() + + +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py index e77e3951a64e..5d8769c8cf9e 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilenet_v3_test.py @@ -12,19 +12,20 @@ class MobilenetV3Test(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(MobilenetV3Test, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(MobilenetV3Test, self).__init__(model_path, *args, **kwargs) + def compare_results(self, iree_results, tflite_results, details): + super(MobilenetV3Test, self).compare_results( + iree_results, tflite_results, details + ) + self.assertTrue( + numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all() + ) - def compare_results(self, iree_results, tflite_results, details): - super(MobilenetV3Test, self).compare_results(iree_results, tflite_results, - details) - self.assertTrue( - numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all()) + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py index 5b7662ce671b..70bcab42f33b 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/person_detect_test.py @@ -15,46 +15,52 @@ class PersonDetectTest(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(PersonDetectTest, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(PersonDetectTest, self).__init__(model_path, *args, **kwargs) - - def compare_results(self, iree_results, tflite_results, details): - super(PersonDetectTest, self).compare_results(iree_results, tflite_results, - details) - self.assertTrue( - numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()) - - # TFLite is broken with this model so we hardcode the input/output details. - def setup_tflite(self): - self.input_details = [{ - "shape": [1, 96, 96, 1], - "dtype": numpy.int8, - "index": 0, - }] - self.output_details = [{ - "shape": [1, 2], - "dtype": numpy.int8, - }] - - # The input has known expected values. We hardcode this value. - def invoke_tflite(self, args): - return [numpy.array([[-113, 113]], dtype=numpy.int8)] - - def generate_inputs(self, input_details): - img_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp" - local_path = "/".join([self.workdir, "person.bmp"]) - urllib.request.urlretrieve(img_path, local_path) - - shape = input_details[0]["shape"] - im = numpy.array(Image.open(local_path).resize( - (shape[1], shape[2]))).astype(input_details[0]["dtype"]) - args = [im.reshape(shape)] - return args - - def test_compile_tflite(self): - self.compile_and_execute() - - -if __name__ == '__main__': - absl.testing.absltest.main() + def compare_results(self, iree_results, tflite_results, details): + super(PersonDetectTest, self).compare_results( + iree_results, tflite_results, details + ) + self.assertTrue( + numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all() + ) + + # TFLite is broken with this model so we hardcode the input/output details. + def setup_tflite(self): + self.input_details = [ + { + "shape": [1, 96, 96, 1], + "dtype": numpy.int8, + "index": 0, + } + ] + self.output_details = [ + { + "shape": [1, 2], + "dtype": numpy.int8, + } + ] + + # The input has known expected values. We hardcode this value. + def invoke_tflite(self, args): + return [numpy.array([[-113, 113]], dtype=numpy.int8)] + + def generate_inputs(self, input_details): + img_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea20c54e86d41e3a38312546/tensorflow/lite/micro/examples/person_detection/testdata/person.bmp" + local_path = "/".join([self.workdir, "person.bmp"]) + urllib.request.urlretrieve(img_path, local_path) + + shape = input_details[0]["shape"] + im = numpy.array(Image.open(local_path).resize((shape[1], shape[2]))).astype( + input_details[0]["dtype"] + ) + args = [im.reshape(shape)] + return args + + def test_compile_tflite(self): + self.compile_and_execute() + + +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/posenet_i8_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/posenet_i8_test.py index 3ec354028187..33c3b4b464d9 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/posenet_i8_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/posenet_i8_test.py @@ -16,36 +16,38 @@ class PosenetI8Test(test_util.TFLiteModelTest): - - def __init__(self, *args, **kwargs): - super(PosenetI8Test, self).__init__(model_path, *args, **kwargs) - - def compare_results(self, iree_results, tflite_results, details): - super(PosenetI8Test, self).compare_results(iree_results, tflite_results, - details) - # This value is a discretized location of the persons joints. If we are - # *close* to the expected position we can consider this good enough. - self.assertTrue( - numpy.isclose(iree_results[0][:, :, :, 0], - tflite_results[0][:, :, :, 0], - atol=25e-3).all()) - self.assertTrue( - numpy.isclose(iree_results[0][:, :, :, 1], - tflite_results[0][:, :, :, 1], - atol=25e-3).all()) - - def generate_inputs(self, input_details): - local_path = "/".join([self.workdir, "person.jpg"]) - urllib.request.urlretrieve(model_input, local_path) - - shape = input_details[0]["shape"] - im = numpy.array(Image.open(local_path).resize((shape[1], shape[2]))) - args = [im.reshape(shape)] - return args - - def test_compile_tflite(self): - self.compile_and_execute() - - -if __name__ == '__main__': - absl.testing.absltest.main() + def __init__(self, *args, **kwargs): + super(PosenetI8Test, self).__init__(model_path, *args, **kwargs) + + def compare_results(self, iree_results, tflite_results, details): + super(PosenetI8Test, self).compare_results( + iree_results, tflite_results, details + ) + # This value is a discretized location of the persons joints. If we are + # *close* to the expected position we can consider this good enough. + self.assertTrue( + numpy.isclose( + iree_results[0][:, :, :, 0], tflite_results[0][:, :, :, 0], atol=25e-3 + ).all() + ) + self.assertTrue( + numpy.isclose( + iree_results[0][:, :, :, 1], tflite_results[0][:, :, :, 1], atol=25e-3 + ).all() + ) + + def generate_inputs(self, input_details): + local_path = "/".join([self.workdir, "person.jpg"]) + urllib.request.urlretrieve(model_input, local_path) + + shape = input_details[0]["shape"] + im = numpy.array(Image.open(local_path).resize((shape[1], shape[2]))) + args = [im.reshape(shape)] + return args + + def test_compile_tflite(self): + self.compile_and_execute() + + +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/resnet_50_int8_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/resnet_50_int8_test.py index 9f07d1435650..226608d57349 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/resnet_50_int8_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/resnet_50_int8_test.py @@ -10,24 +10,24 @@ class Resnet50Int8Test(test_util.TFLiteModelTest): + def __init__(self, *args, **kwargs): + super(Resnet50Int8Test, self).__init__(model_path, *args, **kwargs) - def __init__(self, *args, **kwargs): - super(Resnet50Int8Test, self).__init__(model_path, *args, **kwargs) + def compare_results(self, iree_results, tflite_results, details): + super(Resnet50Int8Test, self).compare_results( + iree_results, tflite_results, details + ) + self.assertTrue(numpy.isclose(iree_results, tflite_results, atol=0.3).all()) - def compare_results(self, iree_results, tflite_results, details): - super(Resnet50Int8Test, self).compare_results(iree_results, tflite_results, - details) - self.assertTrue(numpy.isclose(iree_results, tflite_results, atol=0.3).all()) + def generate_inputs(self, input_details): + inputs = imagenet_test_data.generate_input(self.workdir, input_details) + # Normalize inputs to [-1, 1]. + inputs = (inputs.astype("float32") / 127.5) - 1 + return [inputs] - def generate_inputs(self, input_details): - inputs = imagenet_test_data.generate_input(self.workdir, input_details) - # Normalize inputs to [-1, 1]. - inputs = (inputs.astype('float32') / 127.5) - 1 - return [inputs] + def test_compile_tflite(self): + self.compile_and_execute() - def test_compile_tflite(self): - self.compile_and_execute() - -if __name__ == '__main__': - absl.testing.absltest.main() +if __name__ == "__main__": + absl.testing.absltest.main() diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/squad_test_data.py b/integrations/tensorflow/test/python/iree_tfl_tests/squad_test_data.py index 20a73f3a7c6c..d624d51f78cb 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/squad_test_data.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/squad_test_data.py @@ -1,58 +1,1159 @@ # An example input combination from the Squad 1.1 dataset. _INPUT_WORD_ID = [ - 101, 2129, 2116, 19576, 2015, 2106, 3854, 4679, 2486, 1029, 102, 1996, - 14169, 2165, 2019, 2220, 2599, 1999, 3565, 4605, 2753, 1998, 2196, 11145, - 1012, 8446, 2001, 3132, 2011, 7573, 1005, 1055, 3639, 1010, 2029, 14159, - 2032, 2698, 2335, 1998, 3140, 2032, 2046, 2093, 20991, 2015, 1010, 2164, - 1037, 19576, 2029, 2027, 6757, 2005, 1037, 7921, 1012, 7573, 15674, 3854, - 4679, 2001, 2315, 3565, 4605, 12041, 1010, 3405, 2274, 3948, 10455, 1010, - 1016, 13714, 14918, 1010, 1998, 2048, 3140, 19576, 2015, 1012, 102, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + 101, + 2129, + 2116, + 19576, + 2015, + 2106, + 3854, + 4679, + 2486, + 1029, + 102, + 1996, + 14169, + 2165, + 2019, + 2220, + 2599, + 1999, + 3565, + 4605, + 2753, + 1998, + 2196, + 11145, + 1012, + 8446, + 2001, + 3132, + 2011, + 7573, + 1005, + 1055, + 3639, + 1010, + 2029, + 14159, + 2032, + 2698, + 2335, + 1998, + 3140, + 2032, + 2046, + 2093, + 20991, + 2015, + 1010, + 2164, + 1037, + 19576, + 2029, + 2027, + 6757, + 2005, + 1037, + 7921, + 1012, + 7573, + 15674, + 3854, + 4679, + 2001, + 2315, + 3565, + 4605, + 12041, + 1010, + 3405, + 2274, + 3948, + 10455, + 1010, + 1016, + 13714, + 14918, + 1010, + 1998, + 2048, + 3140, + 19576, + 2015, + 1012, + 102, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ] _INPUT_TYPE_ID = [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0 + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ] _INPUT_MASK = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ] diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py index dc5321b28573..bbbfd87233ec 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/test_util.py @@ -18,154 +18,156 @@ import urllib.request targets = { - 'llvmcpu': 'llvm-cpu', - 'vmvx': 'vmvx', - 'vulkan': 'vulkan-spirv', + "llvmcpu": "llvm-cpu", + "vmvx": "vmvx", + "vulkan": "vulkan-spirv", } configs = { - 'llvmcpu': 'local-task', - 'vmvx': 'local-task', - 'vulkan': 'vulkan', + "llvmcpu": "local-task", + "vmvx": "local-task", + "vulkan": "vulkan", } -absl.flags.DEFINE_string('target_backend', 'llvmcpu', 'model path to execute') +absl.flags.DEFINE_string("target_backend", "llvmcpu", "model path to execute") absl.flags.DEFINE_string( - "artifacts_dir", None, + "artifacts_dir", + None, "Specifies a directory to dump compilation artifacts and traces to. " - "Defaults to the OS's tempdir.") + "Defaults to the OS's tempdir.", +) def _setup_artifacts_dir(relative_artifacts_dir: str) -> str: - parent_dirs = [ - FLAGS.artifacts_dir, - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), - os.environ.get('TEST_TMPDIR'), - os.path.join(tempfile.gettempdir(), "iree", "modules"), - ] - # Use the most preferred path in parent_dirs that isn't None. - parent_dir = next(parent for parent in parent_dirs if parent is not None) - - artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir) - absl.logging.info("Saving compilation artifacts and traces to '%s'", - artifacts_dir) - os.makedirs(artifacts_dir, exist_ok=True) - return artifacts_dir + parent_dirs = [ + FLAGS.artifacts_dir, + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"), + os.environ.get("TEST_TMPDIR"), + os.path.join(tempfile.gettempdir(), "iree", "modules"), + ] + # Use the most preferred path in parent_dirs that isn't None. + parent_dir = next(parent for parent in parent_dirs if parent is not None) + artifacts_dir = os.path.join(parent_dir, relative_artifacts_dir) + absl.logging.info("Saving compilation artifacts and traces to '%s'", artifacts_dir) + os.makedirs(artifacts_dir, exist_ok=True) + return artifacts_dir -class TFLiteModelTest(testing.absltest.TestCase): - def __init__(self, model_path, *args, **kwargs): - super(TFLiteModelTest, self).__init__(*args, **kwargs) - self.model_path = model_path - - def setUp(self): - if self.model_path is None: - return - self.workdir = _setup_artifacts_dir("download") - print(f"TMPDIR = {self.workdir}") - self.tflite_file = '/'.join([self.workdir, 'model.mlirbc']) - self.tflite_ir = '/'.join([self.workdir, 'tflite.mlirbc']) - self.iree_ir = '/'.join([self.workdir, 'tosa.mlirbc']) - if os.path.exists(self.model_path): - self.tflite_file = self.model_path - else: - urllib.request.urlretrieve(self.model_path, self.tflite_file) - self.binary = '/'.join([self.workdir, 'module.vmfb']) - - def generate_inputs(self, input_details): - args = [] - for input in input_details: - absl.logging.info("\t%s, %s", str(input["shape"]), - input["dtype"].__name__) - args.append(np.zeros(shape=input["shape"], dtype=input["dtype"])) - return args - - def compare_results(self, iree_results, tflite_results, details): - self.assertEqual(len(iree_results), len(tflite_results), - "Number of results do not match") - - for i in range(len(details)): - iree_result = iree_results[i] - tflite_result = tflite_results[i] - iree_result = iree_result.astype(np.single) - tflite_result = tflite_result.astype(np.single) - self.assertEqual(iree_result.shape, tflite_result.shape) - maxError = np.max(np.abs(iree_result - tflite_result)) - absl.logging.info("Max error (%d): %f", i, maxError) - - def setup_tflite(self): - absl.logging.info("Setting up tflite interpreter") - self.tflite_interpreter = tf.lite.Interpreter(model_path=self.tflite_file) - self.tflite_interpreter.allocate_tensors() - self.input_details = self.tflite_interpreter.get_input_details() - self.output_details = self.tflite_interpreter.get_output_details() - - def setup_iree(self): - absl.logging.info("Setting up iree runtime") - with open(self.binary, 'rb') as f: - config = iree_rt.Config(configs[absl.flags.FLAGS.target_backend]) - self.iree_context = iree_rt.SystemContext(config=config) - vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, f.read()) - self.iree_context.add_vm_module(vm_module) - - def invoke_tflite(self, args): - for i, input in enumerate(args): - self.tflite_interpreter.set_tensor(self.input_details[i]['index'], input) - start = time.perf_counter() - self.tflite_interpreter.invoke() - end = time.perf_counter() - tflite_results = [] - absl.logging.info(f"Invocation time: {end - start:0.4f} seconds") - for output_detail in self.output_details: - tflite_results.append( - np.array(self.tflite_interpreter.get_tensor(output_detail['index']))) - - for i in range(len(self.output_details)): - dtype = self.output_details[i]["dtype"] - tflite_results[i] = tflite_results[i].astype(dtype) - return tflite_results - - def invoke_iree(self, args): - invoke = self.iree_context.modules.module["main"] - start = time.perf_counter() - iree_results = invoke(*args) - end = time.perf_counter() - absl.logging.info(f"Invocation time: {end - start:0.4f} seconds") - if not isinstance(iree_results, tuple): - iree_results = (iree_results,) - return iree_results - - def compile_and_execute(self): - self.assertIsNotNone(self.model_path) - - absl.logging.info("Setting up for IREE") - iree_tflite_compile.compile_file( - self.tflite_file, - input_type="tosa", - output_file=self.binary, - save_temp_tfl_input=self.tflite_ir, - save_temp_iree_input=self.iree_ir, - target_backends=[targets[absl.flags.FLAGS.target_backend]], - import_only=False) - - self.setup_tflite() - self.setup_iree() - - absl.logging.info("Setting up test inputs") - args = self.generate_inputs(self.input_details) - - absl.logging.info("Invoking TFLite") - tflite_results = self.invoke_tflite(args) - - absl.logging.info("Invoke IREE") - iree_results = self.invoke_iree(args) - - # Fix type information for unsigned cases. - iree_results = list(iree_results) - for i in range(len(self.output_details)): - dtype = self.output_details[i]["dtype"] - iree_results[i] = iree_results[i].astype(dtype) - - self.compare_results(iree_results, tflite_results, self.output_details) +class TFLiteModelTest(testing.absltest.TestCase): + def __init__(self, model_path, *args, **kwargs): + super(TFLiteModelTest, self).__init__(*args, **kwargs) + self.model_path = model_path + + def setUp(self): + if self.model_path is None: + return + self.workdir = _setup_artifacts_dir("download") + print(f"TMPDIR = {self.workdir}") + self.tflite_file = "/".join([self.workdir, "model.mlirbc"]) + self.tflite_ir = "/".join([self.workdir, "tflite.mlirbc"]) + self.iree_ir = "/".join([self.workdir, "tosa.mlirbc"]) + if os.path.exists(self.model_path): + self.tflite_file = self.model_path + else: + urllib.request.urlretrieve(self.model_path, self.tflite_file) + self.binary = "/".join([self.workdir, "module.vmfb"]) + + def generate_inputs(self, input_details): + args = [] + for input in input_details: + absl.logging.info("\t%s, %s", str(input["shape"]), input["dtype"].__name__) + args.append(np.zeros(shape=input["shape"], dtype=input["dtype"])) + return args + + def compare_results(self, iree_results, tflite_results, details): + self.assertEqual( + len(iree_results), len(tflite_results), "Number of results do not match" + ) + + for i in range(len(details)): + iree_result = iree_results[i] + tflite_result = tflite_results[i] + iree_result = iree_result.astype(np.single) + tflite_result = tflite_result.astype(np.single) + self.assertEqual(iree_result.shape, tflite_result.shape) + maxError = np.max(np.abs(iree_result - tflite_result)) + absl.logging.info("Max error (%d): %f", i, maxError) + + def setup_tflite(self): + absl.logging.info("Setting up tflite interpreter") + self.tflite_interpreter = tf.lite.Interpreter(model_path=self.tflite_file) + self.tflite_interpreter.allocate_tensors() + self.input_details = self.tflite_interpreter.get_input_details() + self.output_details = self.tflite_interpreter.get_output_details() + + def setup_iree(self): + absl.logging.info("Setting up iree runtime") + with open(self.binary, "rb") as f: + config = iree_rt.Config(configs[absl.flags.FLAGS.target_backend]) + self.iree_context = iree_rt.SystemContext(config=config) + vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, f.read()) + self.iree_context.add_vm_module(vm_module) + + def invoke_tflite(self, args): + for i, input in enumerate(args): + self.tflite_interpreter.set_tensor(self.input_details[i]["index"], input) + start = time.perf_counter() + self.tflite_interpreter.invoke() + end = time.perf_counter() + tflite_results = [] + absl.logging.info(f"Invocation time: {end - start:0.4f} seconds") + for output_detail in self.output_details: + tflite_results.append( + np.array(self.tflite_interpreter.get_tensor(output_detail["index"])) + ) + + for i in range(len(self.output_details)): + dtype = self.output_details[i]["dtype"] + tflite_results[i] = tflite_results[i].astype(dtype) + return tflite_results + + def invoke_iree(self, args): + invoke = self.iree_context.modules.module["main"] + start = time.perf_counter() + iree_results = invoke(*args) + end = time.perf_counter() + absl.logging.info(f"Invocation time: {end - start:0.4f} seconds") + if not isinstance(iree_results, tuple): + iree_results = (iree_results,) + return iree_results + + def compile_and_execute(self): + self.assertIsNotNone(self.model_path) + + absl.logging.info("Setting up for IREE") + iree_tflite_compile.compile_file( + self.tflite_file, + input_type="tosa", + output_file=self.binary, + save_temp_tfl_input=self.tflite_ir, + save_temp_iree_input=self.iree_ir, + target_backends=[targets[absl.flags.FLAGS.target_backend]], + import_only=False, + ) + + self.setup_tflite() + self.setup_iree() + + absl.logging.info("Setting up test inputs") + args = self.generate_inputs(self.input_details) + + absl.logging.info("Invoking TFLite") + tflite_results = self.invoke_tflite(args) + + absl.logging.info("Invoke IREE") + iree_results = self.invoke_iree(args) + + # Fix type information for unsigned cases. + iree_results = list(iree_results) + for i in range(len(self.output_details)): + dtype = self.output_details[i]["dtype"] + iree_results[i] = iree_results[i].astype(dtype) + + self.compare_results(iree_results, tflite_results, self.output_details) diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py index e315cb3573d9..2dc4516884f0 100644 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py +++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py @@ -7,77 +7,84 @@ # MLIR. # pytype: skip-file try: - from .. import ir - from ..dialects import pdl - from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from typing import Optional, Sequence, Union + from .. import ir + from ..dialects import pdl + from ._ods_common import ( + extend_opview_class as _ods_extend_opview_class, + segmented_accessor as _ods_segmented_accessor, + equally_sized_accessor as _ods_equally_sized_accessor, + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) + from typing import Optional, Sequence, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e BoolArg = Optional[Union[bool, ir.BoolAttr]] IntArg = Optional[Union[int, ir.IntegerAttr]] IntListArg = Optional[Union[Sequence[int], ir.ArrayAttr]] -IntListListArg = Optional[Union[Sequence[Union[Sequence[int], ir.ArrayAttr]], - ir.ArrayAttr]] +IntListListArg = Optional[ + Union[Sequence[Union[Sequence[int], ir.ArrayAttr]], ir.ArrayAttr] +] StringArg = Optional[Union[str, ir.StringAttr]] StringListArg = Optional[Union[Sequence[str], ir.ArrayAttr]] def _defaulted_ensure(f): + def inner(value, default=None): + assert value is not None or default is not None + return f(default if value is None else value) - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner + return inner @_defaulted_ensure def _ensure_int_array_attr(value: IntListArg): - i64 = ir.IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ir.ArrayAttr.get([ir.IntegerAttr.get(i64, i) for i in value]) - return value + i64 = ir.IntegerType.get_signless(64) + if isinstance(value, Sequence): + return ir.ArrayAttr.get([ir.IntegerAttr.get(i64, i) for i in value]) + return value @_defaulted_ensure def _ensure_string_array_attr(value: StringListArg): - if isinstance(value, Sequence): - return ir.ArrayAttr.get([ir.StringAttr.get(str(i)) for i in value]) - return value + if isinstance(value, Sequence): + return ir.ArrayAttr.get([ir.StringAttr.get(str(i)) for i in value]) + return value @_defaulted_ensure def _ensure_array_of_array_attr(value: IntListListArg): - if isinstance(value, Sequence): - return ir.ArrayAttr.get([_ensure_int_array_attr(inner) for inner in value]) - return value + if isinstance(value, Sequence): + return ir.ArrayAttr.get([_ensure_int_array_attr(inner) for inner in value]) + return value @_defaulted_ensure def _ensure_int_attr(value: IntArg): - if isinstance(value, int): - return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) - return value + if isinstance(value, int): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + return value @_defaulted_ensure def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return ir.BoolAttr.get(value) - return value + if isinstance(value, bool): + return ir.BoolAttr.get(value) + return value @_defaulted_ensure def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return ir.StringAttr.get(value) - return value + if isinstance(value, str): + return ir.StringAttr.get(value) + return value def _count_expected_loops(tile_sizes: ir.ArrayAttr) -> int: - # Number of loops = number of tile sizes != 0 - zero = _ensure_int_attr(0) - return len(list(tile_sizes)) - list(tile_sizes).count(zero) + # Number of loops = number of tile sizes != 0 + zero = _ensure_int_attr(0) + return len(list(tile_sizes)) - list(tile_sizes).count(zero) ##===----------------------------------------------------------------------===## @@ -86,26 +93,30 @@ def _count_expected_loops(tile_sizes: ir.ArrayAttr) -> int: class TileToLinalgExtTileOp: - """Specialization for the TileToLinalgExtTileOp class.""" + """Specialization for the TileToLinalgExtTileOp class.""" - def __init__(self, - target: Union[ir.Value, ir.Operation, ir.OpView], - *, - sizes: IntListArg = None, - loc=None, - ip=None): - sizes = _ensure_int_array_attr(sizes, []) - operation_type = pdl.OperationType.get() - super().__init__(operation_type, target, sizes, loc=loc, ip=ip) + def __init__( + self, + target: Union[ir.Value, ir.Operation, ir.OpView], + *, + sizes: IntListArg = None, + loc=None, + ip=None + ): + sizes = _ensure_int_array_attr(sizes, []) + operation_type = pdl.OperationType.get() + super().__init__(operation_type, target, sizes, loc=loc, ip=ip) class FuseIntoContainingOp: - """Specialization for the FuseIntoContainingOp class.""" - - def __init__(self, - producerOp: Union[ir.Value, ir.Operation, ir.OpView], - *, - containingOp: Union[ir.Value, ir.Operation, ir.OpView], - loc=None, - ip=None): - super().__init__([], producerOp, containingOp, loc=loc, ip=ip) + """Specialization for the FuseIntoContainingOp class.""" + + def __init__( + self, + producerOp: Union[ir.Value, ir.Operation, ir.OpView], + *, + containingOp: Union[ir.Value, ir.Operation, ir.OpView], + loc=None, + ip=None + ): + super().__init__([], producerOp, containingOp, loc=loc, ip=ip) diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py index 754c9fac82db..f1f89550bd42 100644 --- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py +++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_structured_transform_ops_ext.py @@ -7,68 +7,78 @@ # MLIR. # pytype: skip-file try: - from ..ir import * - from ..dialects import pdl - from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from typing import Optional, overload, Sequence, Union + from ..ir import * + from ..dialects import pdl + from ._ods_common import ( + extend_opview_class as _ods_extend_opview_class, + segmented_accessor as _ods_segmented_accessor, + equally_sized_accessor as _ods_equally_sized_accessor, + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) + from typing import Optional, overload, Sequence, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e BoolArg = Optional[Union[bool, BoolAttr]] IntListArg = Optional[Union[Sequence[int], ArrayAttr]] StringArg = Optional[Union[str, StringAttr]] def _defaulted_ensure(f): + def inner(value, default=None): + assert value is not None or default is not None + return f(default if value is None else value) - def inner(value, default=None): - assert value is not None or default is not None - return f(default if value is None else value) - - return inner + return inner @_defaulted_ensure def _ensure_int_array_attr(value: IntListArg): - i64 = IntegerType.get_signless(64) - if isinstance(value, Sequence): - return ArrayAttr.get([IntegerAttr.get(i64, i) for i in value]) - return value + i64 = IntegerType.get_signless(64) + if isinstance(value, Sequence): + return ArrayAttr.get([IntegerAttr.get(i64, i) for i in value]) + return value @_defaulted_ensure def _ensure_bool_attr(value: BoolArg): - if isinstance(value, bool): - return BoolAttr.get(value) - return value + if isinstance(value, bool): + return BoolAttr.get(value) + return value @_defaulted_ensure def _ensure_string_attr(value: StringArg): - if isinstance(value, str): - return StringAttr.get(value) - return value + if isinstance(value, str): + return StringAttr.get(value) + return value class LowerToLLVMOp: - """Specialization for the LowerToLLVMOp class.""" + """Specialization for the LowerToLLVMOp class.""" - def __init__(self, - *, - reassociate_fp_reductions: BoolArg = None, - enable_index_optimizations: BoolArg = None, - enable_arm_neon: BoolArg = None, - enable_arm_sve: BoolArg = None, - enable_amx: BoolArg = None, - enable_x86vector: BoolArg = None, - enable_async: BoolArg = None, - loc=None, - ip=None): - super().__init__(_ensure_bool_attr(reassociate_fp_reductions, False), - _ensure_bool_attr(enable_index_optimizations, False), - _ensure_bool_attr(enable_arm_neon, False), - _ensure_bool_attr(enable_arm_sve, False), - _ensure_bool_attr(enable_amx, False), - _ensure_bool_attr(enable_x86vector, False), - _ensure_bool_attr(enable_async, False), - loc=loc, - ip=ip) + def __init__( + self, + *, + reassociate_fp_reductions: BoolArg = None, + enable_index_optimizations: BoolArg = None, + enable_arm_neon: BoolArg = None, + enable_arm_sve: BoolArg = None, + enable_amx: BoolArg = None, + enable_x86vector: BoolArg = None, + enable_async: BoolArg = None, + loc=None, + ip=None + ): + super().__init__( + _ensure_bool_attr(reassociate_fp_reductions, False), + _ensure_bool_attr(enable_index_optimizations, False), + _ensure_bool_attr(enable_arm_neon, False), + _ensure_bool_attr(enable_arm_sve, False), + _ensure_bool_attr(enable_amx, False), + _ensure_bool_attr(enable_x86vector, False), + _ensure_bool_attr(enable_async, False), + loc=loc, + ip=ip, + ) diff --git a/llvm-external-projects/iree-dialects/test/lit.cfg.py b/llvm-external-projects/iree-dialects/test/lit.cfg.py index 0a3fe29f3db3..c0cd5c23a025 100644 --- a/llvm-external-projects/iree-dialects/test/lit.cfg.py +++ b/llvm-external-projects/iree-dialects/test/lit.cfg.py @@ -21,65 +21,74 @@ # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'IREE_DIALECTS' +config.name = "IREE_DIALECTS" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.py'] +config.suffixes = [".mlir", ".py"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.iree_dialects_obj_root, 'test') +config.test_exec_root = os.path.join(config.iree_dialects_obj_root, "test") -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) config.substitutions.append( - ('%resources_dir', os.path.join(config.iree_dialects_obj_root, - 'resources'))) + ("%resources_dir", os.path.join(config.iree_dialects_obj_root, "resources")) +) -llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) -#llvm_config.use_default_substitutions() +# llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. config.excludes = [ - 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt', - 'lit.cfg.py', 'lit.site.cfg.py' + "Inputs", + "Examples", + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", + "lit.cfg.py", + "lit.site.cfg.py", ] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.iree_dialects_obj_root, 'test') -config.standalone_tools_dir = os.path.join(config.iree_dialects_obj_root, 'bin') +config.test_exec_root = os.path.join(config.iree_dialects_obj_root, "test") +config.standalone_tools_dir = os.path.join(config.iree_dialects_obj_root, "bin") # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', - config.llvm_tools_binary_dir, - append_path=True) +llvm_config.with_environment("PATH", config.llvm_tools_binary_dir, append_path=True) tool_dirs = [config.llvm_tools_binary_dir] tools = [ - ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), + ToolSubst("%PYTHON", config.python_executable, unresolved="ignore"), # Since we build iree-dialects out of tree, we don't have a common tools # directory, so substitute binaries needed to an explicit path. ToolSubst( - 'iree-dialects-opt', - os.path.join(config.iree_dialects_obj_root, - 'tools/iree-dialects-opt/iree-dialects-opt')) + "iree-dialects-opt", + os.path.join( + config.iree_dialects_obj_root, "tools/iree-dialects-opt/iree-dialects-opt" + ), + ), ] llvm_config.add_tool_substitutions(tools, tool_dirs) if config.enable_bindings_python: - llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.iree_dialects_obj_root, 'python_packages', - 'iree_dialects'), - ], - append_path=True) + llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join( + config.iree_dialects_obj_root, "python_packages", "iree_dialects" + ), + ], + append_path=True, + ) diff --git a/llvm-external-projects/iree-dialects/test/python/smoketest.py b/llvm-external-projects/iree-dialects/test/python/smoketest.py index 1a0402b3a665..fe1abb7088dd 100644 --- a/llvm-external-projects/iree-dialects/test/python/smoketest.py +++ b/llvm-external-projects/iree-dialects/test/python/smoketest.py @@ -6,6 +6,6 @@ from iree.compiler.dialects import iree_linalg_transform with iree.compiler.ir.Context() as ctx: - iree_d.register_dialect() - iree_linalg_ext.register_dialect() - iree_linalg_transform.register_dialect() + iree_d.register_dialect() + iree_linalg_ext.register_dialect() + iree_linalg_transform.register_dialect() diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py index c3b821423601..72b236c33930 100644 --- a/runtime/bindings/python/iree/runtime/array_interop.py +++ b/runtime/bindings/python/iree/runtime/array_interop.py @@ -28,200 +28,211 @@ def _device_implements(np_function): - """Decorator that registers a base class implementation.""" + """Decorator that registers a base class implementation.""" - def decorator(func): - _DEVICE_HANDLED_FUNCTIONS[np_function] = func - return func + def decorator(func): + _DEVICE_HANDLED_FUNCTIONS[np_function] = func + return func - return decorator + return decorator class DeviceArray(numpy.lib.mixins.NDArrayOperatorsMixin): - """An IREE device array. - - Device arrays can be in one of two states: - 1. Host accessible: The array will be backed by host accessible memory - and can have the usual things done with it that one expects to be - able to do with an ndarray. - 2. Device resident: The array is just a handle to a device resident - Buffer (and BufferView wrapper). Metadata about the array are accessible - (shape and dtype) but anything that touches the data cannot be accessed - in this state. - - How a device array comes into existence controls how it can transition - between these states: - * A user can create a DeviceArray explicitly with a device allocator. - Such an array will not be implicitly convertible to host accessible, - although accessors exist to do so. - * When created by the platform with a synchronization policy, then - implicit transfer back to the host will trigger appropriate waits and - be performed automatically (this is the common case for function return - values if not otherwise configured, as an example). - """ - - def __init__(self, - device: HalDevice, - buffer_view: HalBufferView, - implicit_host_transfer: bool = False, - override_dtype=None): - self._device = device - self._buffer_view = buffer_view - self._implicit_host_transfer = implicit_host_transfer - self._override_dtype = override_dtype - - # If the array is host accessible, these will be non-None. - self._mapped_memory: Optional[MappedMemory] = None - self._host_array: Optional[np.ndarray] = None - - def __array__(self, dtype=None): - self._transfer_to_host(True) - if dtype is None: - return self._host_array - else: - return self._host_array.__array__(dtype) # pytype: disable=attribute-error - - def __array_function__(self, func, types, args, kwargs): - if func in _DEVICE_HANDLED_FUNCTIONS: - return _DEVICE_HANDLED_FUNCTIONS[func](*args, **kwargs) - - # Anything else forces a transfer to host and then delegates to the - # host array. - host_array = self.to_host() - return host_array.__array_function__(func, types, args, kwargs) # pytype: disable=attribute-error - - def __repr__(self): - return f"" - - @property - def is_host_accessible(self): - """Whether this array is currently host accessible.""" - return self._host_array is not None - - def to_host(self) -> np.ndarray: - self._transfer_to_host(False) - return self._host_array - - def _transfer_to_host(self, implicit): - if self._host_array is not None: - return - if implicit and not self._implicit_host_transfer: - raise ValueError( - "DeviceArray cannot be implicitly transferred to the host: " - "if necessary, do an explicit transfer via .to_host()") - self._mapped_memory, self._host_array = self._map_to_host() - - def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: - # TODO: When synchronization is enabled, need to block here. - raw_dtype = self._get_raw_dtype() - mapped_memory = self._buffer_view.map() - host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype) - # Detect if we need to force an explicit conversion. This happens when - # we were requested to pretend that the array is in a specific dtype, - # even if that is not representable on the device. You guessed it: - # this is to support bools. - if self._override_dtype is not None and self._override_dtype != raw_dtype: - host_array = host_array.astype(self._override_dtype) - return mapped_memory, host_array - - def _get_raw_dtype(self): - return HalElementType.map_to_dtype(self._buffer_view.element_type) - - @property - def dtype(self): - if self._override_dtype: - return self._override_dtype - return self._get_raw_dtype() - - @property - def shape(self): - return np.shape(self) - - def astype(self, dtype, casting="unsafe", copy=True): - if self.dtype == dtype and not copy: - return self - host_ary = self.to_host() - return host_ary.astype(dtype, casting=casting, copy=copy) - - def reshape(self, *args): - # TODO(scotttodd): add a native impl with a new buffer_view of the same data - # TODO(scotttodd): return DeviceArray instead of host ndarray? - host_ary = self.to_host() - return host_ary.reshape(*args) - - def __iter__(self): - host_ary = self.to_host() - return host_ary.__iter__() - - def __getitem__(self, index): - host_ary = self.to_host() - return host_ary.__getitem__(index) - - def __reduce__(self): - # Since this is used for making deep copies and pickling, we map - # separately from any interactive state. We just reduce to the actual - # host ndarray, which supports the necessary serialization protocols. - _, host_array = self._map_to_host() - return _restore_reduced_array, (host_array,) + """An IREE device array. + + Device arrays can be in one of two states: + 1. Host accessible: The array will be backed by host accessible memory + and can have the usual things done with it that one expects to be + able to do with an ndarray. + 2. Device resident: The array is just a handle to a device resident + Buffer (and BufferView wrapper). Metadata about the array are accessible + (shape and dtype) but anything that touches the data cannot be accessed + in this state. + + How a device array comes into existence controls how it can transition + between these states: + * A user can create a DeviceArray explicitly with a device allocator. + Such an array will not be implicitly convertible to host accessible, + although accessors exist to do so. + * When created by the platform with a synchronization policy, then + implicit transfer back to the host will trigger appropriate waits and + be performed automatically (this is the common case for function return + values if not otherwise configured, as an example). + """ + + def __init__( + self, + device: HalDevice, + buffer_view: HalBufferView, + implicit_host_transfer: bool = False, + override_dtype=None, + ): + self._device = device + self._buffer_view = buffer_view + self._implicit_host_transfer = implicit_host_transfer + self._override_dtype = override_dtype + + # If the array is host accessible, these will be non-None. + self._mapped_memory: Optional[MappedMemory] = None + self._host_array: Optional[np.ndarray] = None + + def __array__(self, dtype=None): + self._transfer_to_host(True) + if dtype is None: + return self._host_array + else: + return self._host_array.__array__(dtype) # pytype: disable=attribute-error + + def __array_function__(self, func, types, args, kwargs): + if func in _DEVICE_HANDLED_FUNCTIONS: + return _DEVICE_HANDLED_FUNCTIONS[func](*args, **kwargs) + + # Anything else forces a transfer to host and then delegates to the + # host array. + host_array = self.to_host() + return host_array.__array_function__( + func, types, args, kwargs + ) # pytype: disable=attribute-error + + def __repr__(self): + return f"" + + @property + def is_host_accessible(self): + """Whether this array is currently host accessible.""" + return self._host_array is not None + + def to_host(self) -> np.ndarray: + self._transfer_to_host(False) + return self._host_array + + def _transfer_to_host(self, implicit): + if self._host_array is not None: + return + if implicit and not self._implicit_host_transfer: + raise ValueError( + "DeviceArray cannot be implicitly transferred to the host: " + "if necessary, do an explicit transfer via .to_host()" + ) + self._mapped_memory, self._host_array = self._map_to_host() + + def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]: + # TODO: When synchronization is enabled, need to block here. + raw_dtype = self._get_raw_dtype() + mapped_memory = self._buffer_view.map() + host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype) + # Detect if we need to force an explicit conversion. This happens when + # we were requested to pretend that the array is in a specific dtype, + # even if that is not representable on the device. You guessed it: + # this is to support bools. + if self._override_dtype is not None and self._override_dtype != raw_dtype: + host_array = host_array.astype(self._override_dtype) + return mapped_memory, host_array + + def _get_raw_dtype(self): + return HalElementType.map_to_dtype(self._buffer_view.element_type) + + @property + def dtype(self): + if self._override_dtype: + return self._override_dtype + return self._get_raw_dtype() + + @property + def shape(self): + return np.shape(self) + + def astype(self, dtype, casting="unsafe", copy=True): + if self.dtype == dtype and not copy: + return self + host_ary = self.to_host() + return host_ary.astype(dtype, casting=casting, copy=copy) + + def reshape(self, *args): + # TODO(scotttodd): add a native impl with a new buffer_view of the same data + # TODO(scotttodd): return DeviceArray instead of host ndarray? + host_ary = self.to_host() + return host_ary.reshape(*args) + + def __iter__(self): + host_ary = self.to_host() + return host_ary.__iter__() + + def __getitem__(self, index): + host_ary = self.to_host() + return host_ary.__getitem__(index) + + def __reduce__(self): + # Since this is used for making deep copies and pickling, we map + # separately from any interactive state. We just reduce to the actual + # host ndarray, which supports the necessary serialization protocols. + _, host_array = self._map_to_host() + return _restore_reduced_array, (host_array,) def _restore_reduced_array(ary): - return ary + return ary # Function implementations with custom behavior. @_device_implements(np.shape) def _(arr: DeviceArray): - return arr._buffer_view.shape + return arr._buffer_view.shape @_device_implements(np.reshape) def _(arr: DeviceArray, *args): - return arr.reshape(*args) - - -def asdevicearray(device: HalDevice, - a, - dtype=None, - *, - implicit_host_transfer: bool = False, - memory_type=MemoryType.DEVICE_LOCAL, - allowed_usage=(BufferUsage.DEFAULT | BufferUsage.MAPPING), - element_type: Optional[HalElementType] = None) -> DeviceArray: - """Helper to create a DeviceArray from an arbitrary array like. - - This is similar in purpose and usage to np.asarray, except that it takes - a device as the first argument. This may not be the best mechanism for - getting a DeviceArray, depending on your use case, but it is reliable - and simple. This function may make a defensive copy or cause implicit - transfers to satisfy the request. If this is important to you, then a lower - level API is likely more appropriate. - - Note that additional flags `memory_type`, `allowed_usage` and `element_type` - are only hints if creating a new DeviceArray. If `a` is already a DeviceArray, - they are ignored. - """ - if isinstance(a, DeviceArray): - if dtype is None: - return a - # Need to do a conversion, which we currently do not support on the - # device, so transfer back to the host. - logging.warn( - "Implicit dtype conversion of a DeviceArray forces a host transfer") - # First get an ndarray. - a = np.asarray(a, dtype=dtype) - element_type = map_dtype_to_element_type(a.dtype) - if element_type is None: - raise ValueError(f"Could not map dtype {a.dtype} to IREE element type") - buffer_view = device.allocator.allocate_buffer_copy( - memory_type=memory_type, - allowed_usage=allowed_usage, - buffer=a, - element_type=element_type) - return DeviceArray(device, - buffer_view, - implicit_host_transfer=implicit_host_transfer, - override_dtype=a.dtype) + return arr.reshape(*args) + + +def asdevicearray( + device: HalDevice, + a, + dtype=None, + *, + implicit_host_transfer: bool = False, + memory_type=MemoryType.DEVICE_LOCAL, + allowed_usage=(BufferUsage.DEFAULT | BufferUsage.MAPPING), + element_type: Optional[HalElementType] = None, +) -> DeviceArray: + """Helper to create a DeviceArray from an arbitrary array like. + + This is similar in purpose and usage to np.asarray, except that it takes + a device as the first argument. This may not be the best mechanism for + getting a DeviceArray, depending on your use case, but it is reliable + and simple. This function may make a defensive copy or cause implicit + transfers to satisfy the request. If this is important to you, then a lower + level API is likely more appropriate. + + Note that additional flags `memory_type`, `allowed_usage` and `element_type` + are only hints if creating a new DeviceArray. If `a` is already a DeviceArray, + they are ignored. + """ + if isinstance(a, DeviceArray): + if dtype is None: + return a + # Need to do a conversion, which we currently do not support on the + # device, so transfer back to the host. + logging.warn( + "Implicit dtype conversion of a DeviceArray forces a host transfer" + ) + # First get an ndarray. + a = np.asarray(a, dtype=dtype) + element_type = map_dtype_to_element_type(a.dtype) + if element_type is None: + raise ValueError(f"Could not map dtype {a.dtype} to IREE element type") + buffer_view = device.allocator.allocate_buffer_copy( + memory_type=memory_type, + allowed_usage=allowed_usage, + buffer=a, + element_type=element_type, + ) + return DeviceArray( + device, + buffer_view, + implicit_host_transfer=implicit_host_transfer, + override_dtype=a.dtype, + ) # NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should @@ -248,8 +259,8 @@ def asdevicearray(device: HalDevice, def map_dtype_to_element_type(dtype) -> Optional[HalElementType]: - for match_dtype, element_type in _DTYPE_TO_HAL_ELEMENT_TYPE: - if match_dtype == dtype: - return element_type - else: - return None + for match_dtype, element_type in _DTYPE_TO_HAL_ELEMENT_TYPE: + if match_dtype == dtype: + return element_type + else: + return None diff --git a/runtime/bindings/python/iree/runtime/benchmark.py b/runtime/bindings/python/iree/runtime/benchmark.py index c6ee5f94a0e7..d301469ef653 100644 --- a/runtime/bindings/python/iree/runtime/benchmark.py +++ b/runtime/bindings/python/iree/runtime/benchmark.py @@ -28,7 +28,8 @@ ] BenchmarkResult = namedtuple( - "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters") + "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters" +) DTYPE_TO_ABI_TYPE = { numpy.dtype(numpy.float32): "f32", @@ -42,88 +43,90 @@ class BenchmarkToolError(Exception): - """Benchmark exception that preserves the command line and error output.""" + """Benchmark exception that preserves the command line and error output.""" - def __init__(self, message): - self.message = message - super().__init__(self.message) + def __init__(self, message): + self.message = message + super().__init__(self.message) def benchmark_exe(): - return os.path.join(os.path.dirname(__file__), "iree-benchmark-module") + return os.path.join(os.path.dirname(__file__), "iree-benchmark-module") def benchmark_module(module, entry_functiong=None, inputs=[], **kwargs): - funcs = [a for a in module.function_names if a != "__init"] - if entry_functiong is None: - if len(funcs) > 1: - raise ValueError(f"No function specified with multiple options {funcs}") - entry_functiong = funcs[0] - - # Throw an error - if entry_functiong not in funcs: - raise ValueError( - f"Attempted to benchmark unknown function {entry_functiong} of options {funcs}" + funcs = [a for a in module.function_names if a != "__init"] + if entry_functiong is None: + if len(funcs) > 1: + raise ValueError(f"No function specified with multiple options {funcs}") + entry_functiong = funcs[0] + + # Throw an error + if entry_functiong not in funcs: + raise ValueError( + f"Attempted to benchmark unknown function {entry_functiong} of options {funcs}" + ) + + flatbuffer = module.stashed_flatbuffer_blob + function = module.lookup_function(entry_functiong) + args = [iree.runtime.benchmark_exe()] + args.append(f"--function={funcs[0]}") + + for k in kwargs: + v = kwargs[k] + args.append(f"--{k}={v}") + + for inp in inputs: + if isinstance(inp, str): + args.append(f"--input={inp}") + continue + shape = "x".join([str(d) for d in inp.shape]) + abitype = DTYPE_TO_ABI_TYPE[inp.dtype] + values = inp.flatten() + if numpy.all(values[0] == values): + values = str(values[0]) + else: + values = ",".join([str(v) for v in values]) + + args.append(f"--input={shape}x{abitype}={values}") + args.append(f"--module=-") + + call = subprocess.Popen( + args=args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - - flatbuffer = module.stashed_flatbuffer_blob - function = module.lookup_function(entry_functiong) - args = [iree.runtime.benchmark_exe()] - args.append(f"--function={funcs[0]}") - - for k in kwargs: - v = kwargs[k] - args.append(f"--{k}={v}") - - for inp in inputs: - if isinstance(inp, str): - args.append(f"--input={inp}") - continue - shape = "x".join([str(d) for d in inp.shape]) - abitype = DTYPE_TO_ABI_TYPE[inp.dtype] - values = inp.flatten() - if numpy.all(values[0] == values): - values = str(values[0]) - else: - values = ",".join([str(v) for v in values]) - - args.append(f"--input={shape}x{abitype}={values}") - args.append(f"--module=-") - - call = subprocess.Popen(args=args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - out, err = call.communicate(input=flatbuffer) - - err = err.decode() - if "INVALID_ARGUMENT;" in err: - raise ValueError("Invalid inputs specified for benchmarking") - - # In the event benchmarking runs but encounteres an internal error, - # return the internal error instead of benchmark results. - if "INTERNAL; CUDA driver error" in str(out): - raise BenchmarkToolError(str(out)) - - # Grab individual results by line (skip header lines) - bench_lines = out.decode().split("\n")[3:] - benchmark_results = [] - for line in bench_lines: - split = line.split() - if len(split) == 0: - continue - benchmark_name = split[0] - time = " ".join(split[1:3]) - cpu_time = " ".join(split[3:5]) - iterations = split[5] - user_counters = None - if len(split) > 5: - user_counters = split[6] - benchmark_results.append( - BenchmarkResult(benchmark_name=benchmark_name, - time=time, - cpu_time=cpu_time, - iterations=iterations, - user_counters=user_counters)) - - return benchmark_results + out, err = call.communicate(input=flatbuffer) + + err = err.decode() + if "INVALID_ARGUMENT;" in err: + raise ValueError("Invalid inputs specified for benchmarking") + + # In the event benchmarking runs but encounteres an internal error, + # return the internal error instead of benchmark results. + if "INTERNAL; CUDA driver error" in str(out): + raise BenchmarkToolError(str(out)) + + # Grab individual results by line (skip header lines) + bench_lines = out.decode().split("\n")[3:] + benchmark_results = [] + for line in bench_lines: + split = line.split() + if len(split) == 0: + continue + benchmark_name = split[0] + time = " ".join(split[1:3]) + cpu_time = " ".join(split[3:5]) + iterations = split[5] + user_counters = None + if len(split) > 5: + user_counters = split[6] + benchmark_results.append( + BenchmarkResult( + benchmark_name=benchmark_name, + time=time, + cpu_time=cpu_time, + iterations=iterations, + user_counters=user_counters, + ) + ) + + return benchmark_results diff --git a/runtime/bindings/python/iree/runtime/function.py b/runtime/bindings/python/iree/runtime/function.py index 5ad16241b940..1eb917a8b7d2 100644 --- a/runtime/bindings/python/iree/runtime/function.py +++ b/runtime/bindings/python/iree/runtime/function.py @@ -31,7 +31,8 @@ DeviceArray, ) from .flags import ( - FUNCTION_INPUT_VALIDATION,) + FUNCTION_INPUT_VALIDATION, +) __all__ = [ "FunctionInvoker", @@ -39,155 +40,167 @@ class Invocation: - __slots__ = [ - "current_arg", - "current_desc", - "current_return_list", - "current_return_index", - "device", - ] - - def __init__(self, device: HalDevice): - self.device = device - # Captured during arg/ret processing to emit better error messages. - self.current_arg = None - self.current_desc = None - self.current_return_list = None - self.current_return_index = 0 - - def summarize_arg_error(self) -> str: - if self.current_arg is None: - return "" - if isinstance(self.current_arg, np.ndarray): - current_arg_repr = ( - f"ndarray({self.current_arg.shape}, {self.current_arg.dtype})") - else: - current_arg_repr = repr(self.current_arg) - return f"{repr(current_arg_repr)} with description {self.current_desc}" - - def summarize_return_error(self) -> str: - if self.current_return_list is None: - return "" - try: - vm_repr = f"{self.current_return_index}@{self.current_return_list}" - except: - vm_repr = "" - return f"{vm_repr} with description {self.current_desc}" + __slots__ = [ + "current_arg", + "current_desc", + "current_return_list", + "current_return_index", + "device", + ] + + def __init__(self, device: HalDevice): + self.device = device + # Captured during arg/ret processing to emit better error messages. + self.current_arg = None + self.current_desc = None + self.current_return_list = None + self.current_return_index = 0 + + def summarize_arg_error(self) -> str: + if self.current_arg is None: + return "" + if isinstance(self.current_arg, np.ndarray): + current_arg_repr = ( + f"ndarray({self.current_arg.shape}, {self.current_arg.dtype})" + ) + else: + current_arg_repr = repr(self.current_arg) + return f"{repr(current_arg_repr)} with description {self.current_desc}" + + def summarize_return_error(self) -> str: + if self.current_return_list is None: + return "" + try: + vm_repr = f"{self.current_return_index}@{self.current_return_list}" + except: + vm_repr = "" + return f"{vm_repr} with description {self.current_desc}" class FunctionInvoker: - """Wraps a VmFunction, enabling invocations against it.""" - __slots__ = [ - "_vm_context", - "_device", - "_vm_function", - "_abi_dict", - "_arg_descs", - "_arg_packer", - "_ret_descs", - "_has_inlined_results", - "_tracer", - ] - - def __init__(self, vm_context: VmContext, device: HalDevice, - vm_function: VmFunction, - tracer: Optional[tracing.ContextTracer]): - self._vm_context = vm_context - # TODO: Needing to know the precise device to allocate on here is bad - # layering and will need to be fixed in some fashion if/when doing - # heterogenous dispatch. - self._device = device - self._vm_function = vm_function - self._tracer = tracer - self._abi_dict = None - self._arg_descs = None - self._ret_descs = None - self._has_inlined_results = False - self._parse_abi_dict(vm_function) - self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs) - - @property - def vm_function(self) -> VmFunction: - return self._vm_function - - def __call__(self, *args, **kwargs): - invoke_context = InvokeContext(self._device) - arg_list = self._arg_packer.pack(invoke_context, args, kwargs) - - call_trace = None # type: Optional[tracing.CallTrace] - if self._tracer: - call_trace = self._tracer.start_call(self._vm_function) - try: - # Initialize the capacity to our total number of args, since we should - # be below that when doing a flat invocation. May want to be more - # conservative here when considering nesting. - inv = Invocation(self._device) - ret_descs = self._ret_descs - - ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1) - if call_trace: - call_trace.add_vm_list(arg_list, "args") - self._invoke(arg_list, ret_list) - if call_trace: - call_trace.add_vm_list(ret_list, "results") - - # Un-inline the results to align with reflection, as needed. - reflection_aligned_ret_list = ret_list - if self._has_inlined_results: - reflection_aligned_ret_list = VmVariantList(1) - reflection_aligned_ret_list.push_list(ret_list) - returns = _extract_vm_sequence_to_python(inv, reflection_aligned_ret_list, - ret_descs) - return_arity = len(returns) - if return_arity == 1: - return returns[0] - elif return_arity == 0: - return None - else: - return tuple(returns) - finally: - if call_trace: - call_trace.end_call() - - # Break out invoke so it shows up in profiles. - def _invoke(self, arg_list, ret_list): - self._vm_context.invoke(self._vm_function, arg_list, ret_list) - - def _parse_abi_dict(self, vm_function: VmFunction): - reflection = vm_function.reflection - abi_json = reflection.get("iree.abi") - if abi_json is None: - # It is valid to have no reflection data, and rely on pure dynamic - # dispatch. - logging.debug( - "Function lacks reflection data. Interop will be limited: %r", - vm_function) - return - try: - self._abi_dict = json.loads(abi_json) - except json.JSONDecodeError as e: - raise RuntimeError( - f"Reflection metadata is not valid JSON: {abi_json}") from e - try: - self._arg_descs = self._abi_dict["a"] - self._ret_descs = self._abi_dict["r"] - except KeyError as e: - raise RuntimeError( - f"Malformed function reflection metadata: {reflection}") from e - if not isinstance(self._arg_descs, list) or not isinstance( - self._ret_descs, list): - raise RuntimeError( - f"Malformed function reflection metadata structure: {reflection}") - - # Detect whether the results are a slist/stuple/sdict, which indicates - # that they are inlined with the function's results. - if len(self._ret_descs) == 1: - maybe_inlined = self._ret_descs[0] - if maybe_inlined and maybe_inlined[0] in ["slist", "stuple", "sdict"]: - self._has_inlined_results = True - - def __repr__(self): - return repr(self._vm_function) + """Wraps a VmFunction, enabling invocations against it.""" + + __slots__ = [ + "_vm_context", + "_device", + "_vm_function", + "_abi_dict", + "_arg_descs", + "_arg_packer", + "_ret_descs", + "_has_inlined_results", + "_tracer", + ] + + def __init__( + self, + vm_context: VmContext, + device: HalDevice, + vm_function: VmFunction, + tracer: Optional[tracing.ContextTracer], + ): + self._vm_context = vm_context + # TODO: Needing to know the precise device to allocate on here is bad + # layering and will need to be fixed in some fashion if/when doing + # heterogenous dispatch. + self._device = device + self._vm_function = vm_function + self._tracer = tracer + self._abi_dict = None + self._arg_descs = None + self._ret_descs = None + self._has_inlined_results = False + self._parse_abi_dict(vm_function) + self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs) + + @property + def vm_function(self) -> VmFunction: + return self._vm_function + + def __call__(self, *args, **kwargs): + invoke_context = InvokeContext(self._device) + arg_list = self._arg_packer.pack(invoke_context, args, kwargs) + + call_trace = None # type: Optional[tracing.CallTrace] + if self._tracer: + call_trace = self._tracer.start_call(self._vm_function) + try: + # Initialize the capacity to our total number of args, since we should + # be below that when doing a flat invocation. May want to be more + # conservative here when considering nesting. + inv = Invocation(self._device) + ret_descs = self._ret_descs + + ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1) + if call_trace: + call_trace.add_vm_list(arg_list, "args") + self._invoke(arg_list, ret_list) + if call_trace: + call_trace.add_vm_list(ret_list, "results") + + # Un-inline the results to align with reflection, as needed. + reflection_aligned_ret_list = ret_list + if self._has_inlined_results: + reflection_aligned_ret_list = VmVariantList(1) + reflection_aligned_ret_list.push_list(ret_list) + returns = _extract_vm_sequence_to_python( + inv, reflection_aligned_ret_list, ret_descs + ) + return_arity = len(returns) + if return_arity == 1: + return returns[0] + elif return_arity == 0: + return None + else: + return tuple(returns) + finally: + if call_trace: + call_trace.end_call() + + # Break out invoke so it shows up in profiles. + def _invoke(self, arg_list, ret_list): + self._vm_context.invoke(self._vm_function, arg_list, ret_list) + + def _parse_abi_dict(self, vm_function: VmFunction): + reflection = vm_function.reflection + abi_json = reflection.get("iree.abi") + if abi_json is None: + # It is valid to have no reflection data, and rely on pure dynamic + # dispatch. + logging.debug( + "Function lacks reflection data. Interop will be limited: %r", + vm_function, + ) + return + try: + self._abi_dict = json.loads(abi_json) + except json.JSONDecodeError as e: + raise RuntimeError( + f"Reflection metadata is not valid JSON: {abi_json}" + ) from e + try: + self._arg_descs = self._abi_dict["a"] + self._ret_descs = self._abi_dict["r"] + except KeyError as e: + raise RuntimeError( + f"Malformed function reflection metadata: {reflection}" + ) from e + if not isinstance(self._arg_descs, list) or not isinstance( + self._ret_descs, list + ): + raise RuntimeError( + f"Malformed function reflection metadata structure: {reflection}" + ) + + # Detect whether the results are a slist/stuple/sdict, which indicates + # that they are inlined with the function's results. + if len(self._ret_descs) == 1: + maybe_inlined = self._ret_descs[0] + if maybe_inlined and maybe_inlined[0] in ["slist", "stuple", "sdict"]: + self._has_inlined_results = True + + def __repr__(self): + return repr(self._vm_function) # VM to Python converters. All take: @@ -198,70 +211,69 @@ def __repr__(self): # Return the corresponding Python object. -def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int, - desc): - # The descriptor for an ndarray is like: - # ["ndarray", "", , ...] - # ex: ['ndarray', 'i32', 1, 25948] - buffer_view = vm_list.get_as_object(vm_index, HalBufferView) - dtype_str = desc[1] - try: - dtype = ABI_TYPE_TO_DTYPE[dtype_str] - except KeyError: - _raise_return_error(inv, f"unrecognized dtype '{dtype_str}'") - x = DeviceArray(inv.device, - buffer_view, - implicit_host_transfer=True, - override_dtype=dtype) - return x +def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): + # The descriptor for an ndarray is like: + # ["ndarray", "", , ...] + # ex: ['ndarray', 'i32', 1, 25948] + buffer_view = vm_list.get_as_object(vm_index, HalBufferView) + dtype_str = desc[1] + try: + dtype = ABI_TYPE_TO_DTYPE[dtype_str] + except KeyError: + _raise_return_error(inv, f"unrecognized dtype '{dtype_str}'") + x = DeviceArray( + inv.device, buffer_view, implicit_host_transfer=True, override_dtype=dtype + ) + return x def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): - # The descriptor for an sdict is like: - # ['sdict', ['key1', value1], ...] - sub_vm_list = vm_list.get_as_list(vm_index) - item_keys = [] - item_descs = [] - for k, d in desc[1:]: - item_keys.append(k) - item_descs.append(d) - py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) - return dict(zip(item_keys, py_items)) + # The descriptor for an sdict is like: + # ['sdict', ['key1', value1], ...] + sub_vm_list = vm_list.get_as_list(vm_index) + item_keys = [] + item_descs = [] + for k, d in desc[1:]: + item_keys.append(k) + item_descs.append(d) + py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) + return dict(zip(item_keys, py_items)) def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): - # The descriptor for an slist is like: - # ['slist, item1, ...] - sub_vm_list = vm_list.get_as_list(vm_index) - item_descs = desc[1:] - py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) - return py_items + # The descriptor for an slist is like: + # ['slist, item1, ...] + sub_vm_list = vm_list.get_as_list(vm_index) + item_descs = desc[1:] + py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs) + return py_items def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): - return tuple(_vm_to_slist(inv, vm_list, vm_index, desc)) + return tuple(_vm_to_slist(inv, vm_list, vm_index, desc)) def _vm_to_scalar(type_bound: type): + def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): + value = vm_list.get_variant(vm_index) + if not isinstance(value, type_bound): + raise ReturnError( + f"expected an {type_bound} value but got {value.__class__}" + ) + return value - def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): - value = vm_list.get_variant(vm_index) - if not isinstance(value, type_bound): - raise ReturnError( - f"expected an {type_bound} value but got {value.__class__}") - return value - - return convert + return convert def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): - # The descriptor for a pylist is like: - # ['pylist', element_type] - sub_vm_list = vm_list.get_as_list(vm_index) - element_type_desc = desc[1:] - py_items = _extract_vm_sequence_to_python( - inv, sub_vm_list, element_type_desc * len(sub_vm_list)) - return py_items + # The descriptor for a pylist is like: + # ['pylist', element_type] + sub_vm_list = vm_list.get_as_list(vm_index) + element_type_desc = desc[1:] + py_items = _extract_vm_sequence_to_python( + inv, sub_vm_list, element_type_desc * len(sub_vm_list) + ) + return py_items VM_TO_PYTHON_CONVERTERS = { @@ -270,7 +282,6 @@ def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): "slist": _vm_to_slist, "stuple": _vm_to_stuple, "py_homogeneous_list": _vm_to_pylist, - # Scalars. "i8": _vm_to_scalar(int), "i16": _vm_to_scalar(int), @@ -296,96 +307,97 @@ def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc): # When we get an ndarray as an argument and are implicitly mapping it to a # buffer view, flags for doing so. IMPLICIT_BUFFER_ARG_MEMORY_TYPE = MemoryType.DEVICE_LOCAL -IMPLICIT_BUFFER_ARG_USAGE = (BufferUsage.DEFAULT | BufferUsage.MAPPING) +IMPLICIT_BUFFER_ARG_USAGE = BufferUsage.DEFAULT | BufferUsage.MAPPING def _is_ndarray_descriptor(desc): - return desc and desc[0] == "ndarray" + return desc and desc[0] == "ndarray" def _is_0d_ndarray_descriptor(desc): - # Example: ["ndarray", "f32", 0] - return desc and desc[0] == "ndarray" and desc[2] == 0 + # Example: ["ndarray", "f32", 0] + return desc and desc[0] == "ndarray" and desc[2] == 0 def _cast_scalar_to_ndarray(inv: Invocation, x, desc): - # Example descriptor: ["ndarray", "f32", 0] - dtype_str = desc[1] - try: - dtype = ABI_TYPE_TO_DTYPE[dtype_str] - except KeyError: - _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'") - return dtype(x) + # Example descriptor: ["ndarray", "f32", 0] + dtype_str = desc[1] + try: + dtype = ABI_TYPE_TO_DTYPE[dtype_str] + except KeyError: + _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'") + return dtype(x) class ArgumentError(ValueError): - pass + pass class ReturnError(ValueError): - pass + pass -def _raise_argument_error(inv: Invocation, - summary: str, - e: Optional[Exception] = None): - new_e = ArgumentError( - f"Error passing argument: {summary} " - f"(while encoding argument {inv.summarize_arg_error()})") - if e: - raise new_e from e - else: - raise new_e +def _raise_argument_error(inv: Invocation, summary: str, e: Optional[Exception] = None): + new_e = ArgumentError( + f"Error passing argument: {summary} " + f"(while encoding argument {inv.summarize_arg_error()})" + ) + if e: + raise new_e from e + else: + raise new_e -def _raise_return_error(inv: Invocation, - summary: str, - e: Optional[Exception] = None): - new_e = ReturnError(f"Error processing function return: {summary} " - f"(while decoding return {inv.summarize_return_error()})") - if e: - raise new_e from e - else: - raise new_e +def _raise_return_error(inv: Invocation, summary: str, e: Optional[Exception] = None): + new_e = ReturnError( + f"Error processing function return: {summary} " + f"(while decoding return {inv.summarize_return_error()})" + ) + if e: + raise new_e from e + else: + raise new_e def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs): - vm_list_arity = len(vm_list) - if descs is None: - descs = [None] * vm_list_arity - elif vm_list_arity != len(descs): - _raise_return_error( - inv, f"mismatched return arity: {vm_list_arity} vs {len(descs)}") - results = [] - for vm_index, desc in zip(range(vm_list_arity), descs): - inv.current_return_list = vm_list - inv.current_return_index = vm_index - inv.current_desc = desc - if desc is None: - # Dynamic (non reflection mode). - converted = vm_list.get_variant(vm_index) - # Special case: Upgrade HalBufferView to a DeviceArray. We do that here - # since this is higher level and it preserves layering. Note that - # the reflection case also does this conversion. - if isinstance(converted, VmRef): - converted_buffer_view = converted.deref(HalBufferView, True) - if converted_buffer_view: - converted = DeviceArray(inv.device, - converted_buffer_view, - implicit_host_transfer=True) - else: - # Known type descriptor. - vm_type = desc if isinstance(desc, str) else desc[0] - try: - converter = VM_TO_PYTHON_CONVERTERS[vm_type] - except KeyError: - _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}") - try: - converted = converter(inv, vm_list, vm_index, desc) - except ReturnError: - raise - except Exception as e: - _raise_return_error(inv, f"exception converting from VM type to Python", - e) - results.append(converted) - return results + vm_list_arity = len(vm_list) + if descs is None: + descs = [None] * vm_list_arity + elif vm_list_arity != len(descs): + _raise_return_error( + inv, f"mismatched return arity: {vm_list_arity} vs {len(descs)}" + ) + results = [] + for vm_index, desc in zip(range(vm_list_arity), descs): + inv.current_return_list = vm_list + inv.current_return_index = vm_index + inv.current_desc = desc + if desc is None: + # Dynamic (non reflection mode). + converted = vm_list.get_variant(vm_index) + # Special case: Upgrade HalBufferView to a DeviceArray. We do that here + # since this is higher level and it preserves layering. Note that + # the reflection case also does this conversion. + if isinstance(converted, VmRef): + converted_buffer_view = converted.deref(HalBufferView, True) + if converted_buffer_view: + converted = DeviceArray( + inv.device, converted_buffer_view, implicit_host_transfer=True + ) + else: + # Known type descriptor. + vm_type = desc if isinstance(desc, str) else desc[0] + try: + converter = VM_TO_PYTHON_CONVERTERS[vm_type] + except KeyError: + _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}") + try: + converted = converter(inv, vm_list, vm_index, desc) + except ReturnError: + raise + except Exception as e: + _raise_return_error( + inv, f"exception converting from VM type to Python", e + ) + results.append(converted) + return results diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_module/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_module/__main__.py index f9672dd4b3b3..528bc2ac2d46 100644 --- a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_module/__main__.py +++ b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_module/__main__.py @@ -10,12 +10,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = os.path.join(os.path.dirname(__file__), "..", "..", - "iree-benchmark-module") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-benchmark-module") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py index 007e96ecd429..ba64bd658af6 100644 --- a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py +++ b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py @@ -10,12 +10,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = os.path.join(os.path.dirname(__file__), "..", "..", - "iree-benchmark-trace") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-benchmark-trace") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py index a5509a3d013a..7fdac38eef85 100644 --- a/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py +++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py @@ -10,11 +10,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-module") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-module") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py index 08dced348eba..d75a65b244ac 100644 --- a/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py +++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py @@ -10,11 +10,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-trace") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-trace") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py index 58f2118d7c3b..f5e36fab6d90 100644 --- a/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py +++ b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py @@ -10,12 +10,11 @@ def main(args=None): - if args is None: - args = sys.argv[1:] - exe = os.path.join(os.path.dirname(__file__), "..", "..", - "iree-tracy-capture") - return subprocess.call(args=[exe] + args) + if args is None: + args = sys.argv[1:] + exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-tracy-capture") + return subprocess.call(args=[exe] + args) if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py index 2423d2f5f2ad..8a2c204141bf 100644 --- a/runtime/bindings/python/iree/runtime/system_api.py +++ b/runtime/bindings/python/iree/runtime/system_api.py @@ -49,262 +49,273 @@ class Config: - """System configuration.""" - - device: _binding.HalDevice - vm_instance: _binding.VmInstance - default_vm_modules: Tuple[_binding.VmModule, ...] - tracer: Optional[tracing.Tracer] - - def __init__(self, - driver_name: Optional[str] = None, - *, - device: Optional[_binding.HalDevice] = None, - tracer: Optional[tracing.Tracer] = None): - # Either use an explicit device or auto config based on driver names. - if device is not None and driver_name is not None: - raise ValueError( - "Either 'device' or 'driver_name' can be specified (not both)") - if device is not None: - self.device = device - else: - self.device = get_first_device( - driver_name.split(",") if driver_name is not None else None) - - self.vm_instance = _binding.VmInstance() - hal_module = _binding.create_hal_module(self.vm_instance, self.device) - self.default_vm_modules = (hal_module,) - self.tracer = tracer or tracing.get_default_tracer() - if self.tracer and self.tracer.enabled: - logging.info("IREE runtime tracing calls to path: %s", - self.tracer.trace_path) - else: - self.tracer = None - - -def _bool_to_int8( - array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]: - if not isinstance(array, np.ndarray): + """System configuration.""" + + device: _binding.HalDevice + vm_instance: _binding.VmInstance + default_vm_modules: Tuple[_binding.VmModule, ...] + tracer: Optional[tracing.Tracer] + + def __init__( + self, + driver_name: Optional[str] = None, + *, + device: Optional[_binding.HalDevice] = None, + tracer: Optional[tracing.Tracer] = None, + ): + # Either use an explicit device or auto config based on driver names. + if device is not None and driver_name is not None: + raise ValueError( + "Either 'device' or 'driver_name' can be specified (not both)" + ) + if device is not None: + self.device = device + else: + self.device = get_first_device( + driver_name.split(",") if driver_name is not None else None + ) + + self.vm_instance = _binding.VmInstance() + hal_module = _binding.create_hal_module(self.vm_instance, self.device) + self.default_vm_modules = (hal_module,) + self.tracer = tracer or tracing.get_default_tracer() + if self.tracer and self.tracer.enabled: + logging.info( + "IREE runtime tracing calls to path: %s", self.tracer.trace_path + ) + else: + self.tracer = None + + +def _bool_to_int8(array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]: + if not isinstance(array, np.ndarray): + return array + + # IREE models booleans as i8s. + # TODO(#5359): This cast should be moved into the function abi. + if array.dtype == bool: + array = array.astype(np.int8) return array - # IREE models booleans as i8s. - # TODO(#5359): This cast should be moved into the function abi. - if array.dtype == bool: - array = array.astype(np.int8) - return array - -def normalize_value( - value: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]: - """Normalizes the given value for input to (or comparison with) IREE.""" - if value is None: - # Exclude None from falling through to blanket np.asarray conversion. - return value +def normalize_value(value: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]: + """Normalizes the given value for input to (or comparison with) IREE.""" + if value is None: + # Exclude None from falling through to blanket np.asarray conversion. + return value - if isinstance(value, (list, tuple, dict)): - return value + if isinstance(value, (list, tuple, dict)): + return value - array = np.asarray(value) - # TODO(#5359): Move into the function abi. - if isinstance(value, (bool, int, float)): - # Manually convert ints and floats to 32 bits. - if array.dtype == np.float64: - array = array.astype(np.float32) - elif array.dtype == np.int64: - array = array.astype(np.int32) + array = np.asarray(value) + # TODO(#5359): Move into the function abi. + if isinstance(value, (bool, int, float)): + # Manually convert ints and floats to 32 bits. + if array.dtype == np.float64: + array = array.astype(np.float32) + elif array.dtype == np.int64: + array = array.astype(np.int32) - return array + return array def _convert_lists_to_tuples(pytree): - if isinstance(pytree, Sequence): - return tuple(_convert_lists_to_tuples(leaf) for leaf in pytree) - elif isinstance(pytree, Mapping): - for key in pytree: - pytree[key] = _convert_lists_to_tuples(pytree[key]) - return pytree - else: - return pytree + if isinstance(pytree, Sequence): + return tuple(_convert_lists_to_tuples(leaf) for leaf in pytree) + elif isinstance(pytree, Mapping): + for key in pytree: + pytree[key] = _convert_lists_to_tuples(pytree[key]) + return pytree + else: + return pytree class BoundModule: - """Wraps a VmModule with its context and provides nice python accessors. - - Resolves item access (["foo"]) as function resolution. - """ - - def __init__(self, context: SystemContext, vm_module: _binding.VmModule): - self._context = context - self._tracer = self._context._config.tracer - self._vm_module = vm_module - self._lazy_functions = dict() - - # Let the tracing infra create a traced module. - self.traced_module = None - if self._tracer: - self.traced_module = self._tracer.persist_vm_module(vm_module) - - @property - def name(self): - return self._vm_module.name - - @property - def vm_module(self): - return self._vm_module - - def __getattr__(self, name): - try: - return self[name] - except KeyError: - raise AttributeError(name) - - def __getitem__(self, name): - vm_function = self._lazy_functions.get(name) - if vm_function is not None: - return vm_function - - vm_function = self._vm_module.lookup_function(name) - if vm_function is None: - raise KeyError(f"Function '{name}' not found in module '{self}'") - - # TODO: Needing to know the precise device to allocate on here is bad - # layering and will need to be fixed in some fashion if/when doing - # heterogenous dispatch. - return FunctionInvoker(self._context.vm_context, - self._context.config.device, vm_function, - self._context._tracer) - - def __repr__(self): - return f"" + """Wraps a VmModule with its context and provides nice python accessors. + + Resolves item access (["foo"]) as function resolution. + """ + + def __init__(self, context: SystemContext, vm_module: _binding.VmModule): + self._context = context + self._tracer = self._context._config.tracer + self._vm_module = vm_module + self._lazy_functions = dict() + + # Let the tracing infra create a traced module. + self.traced_module = None + if self._tracer: + self.traced_module = self._tracer.persist_vm_module(vm_module) + + @property + def name(self): + return self._vm_module.name + + @property + def vm_module(self): + return self._vm_module + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(self, name): + vm_function = self._lazy_functions.get(name) + if vm_function is not None: + return vm_function + + vm_function = self._vm_module.lookup_function(name) + if vm_function is None: + raise KeyError(f"Function '{name}' not found in module '{self}'") + + # TODO: Needing to know the precise device to allocate on here is bad + # layering and will need to be fixed in some fashion if/when doing + # heterogenous dispatch. + return FunctionInvoker( + self._context.vm_context, + self._context.config.device, + vm_function, + self._context._tracer, + ) + + def __repr__(self): + return f"" class BoundModules(dict): - """Provides nice python accessors for a dict of BoundModules.""" + """Provides nice python accessors for a dict of BoundModules.""" - def __getattr__(self, name): - try: - return self[name] - except KeyError: - raise AttributeError(name) + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) class SystemContext: - """Global system.""" - - def __init__(self, vm_modules=None, config: Optional[Config] = None): - self._config = config if config is not None else Config() - self._is_dynamic = vm_modules is None - if self._is_dynamic: - init_vm_modules = None - else: - init_vm_modules = self._config.default_vm_modules + tuple(vm_modules) - - self._vm_context = _binding.VmContext(instance=self._config.vm_instance, - modules=init_vm_modules) - - if self._is_dynamic: - self._vm_context.register_modules(self._config.default_vm_modules) - self._bound_modules = BoundModules([ - (m.name, BoundModule(self, m)) - for m in self._config.default_vm_modules - ]) - else: - self._bound_modules = BoundModules([ - (m.name, BoundModule(self, m)) for m in init_vm_modules - ]) - - self._tracer = None # type: Optional[tracing.ContextTracer] - if self._config.tracer: - self._tracer = tracing.ContextTracer( - self._config.tracer, - is_dynamic=self._is_dynamic, - modules=[bm.traced_module for bm in self._bound_modules.values()]) - - @property - def vm_context(self) -> _binding.VmContext: - return self._vm_context - - @property - def is_dynamic(self) -> bool: - return self._is_dynamic - - @property - def config(self) -> Config: - return self._config - - @property - def instance(self) -> _binding.VmInstance: - return self._config.vm_instance - - @property - def modules(self) -> BoundModules: - return self._bound_modules - - def add_module_dependency(self, name, minimum_version=0): - resolved_module = _binding.VmModule.resolve_module_dependency( - self._config.vm_instance, name, minimum_version) - self._vm_context.register_modules([resolved_module]) - - def add_vm_modules(self, vm_modules): - assert self._is_dynamic, "Cannot 'add_module' on a static context" - for m in vm_modules: - if m.name in self._bound_modules: - raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'") - bound_module = BoundModule(self, m) - self._bound_modules[m.name] = bound_module - if self._tracer: - self._tracer.add_module(bound_module.traced_module) - self._vm_context.register_modules(vm_modules) - - def add_vm_module(self, vm_module): - self.add_vm_modules((vm_module,)) + """Global system.""" + + def __init__(self, vm_modules=None, config: Optional[Config] = None): + self._config = config if config is not None else Config() + self._is_dynamic = vm_modules is None + if self._is_dynamic: + init_vm_modules = None + else: + init_vm_modules = self._config.default_vm_modules + tuple(vm_modules) + + self._vm_context = _binding.VmContext( + instance=self._config.vm_instance, modules=init_vm_modules + ) + + if self._is_dynamic: + self._vm_context.register_modules(self._config.default_vm_modules) + self._bound_modules = BoundModules( + [ + (m.name, BoundModule(self, m)) + for m in self._config.default_vm_modules + ] + ) + else: + self._bound_modules = BoundModules( + [(m.name, BoundModule(self, m)) for m in init_vm_modules] + ) + + self._tracer = None # type: Optional[tracing.ContextTracer] + if self._config.tracer: + self._tracer = tracing.ContextTracer( + self._config.tracer, + is_dynamic=self._is_dynamic, + modules=[bm.traced_module for bm in self._bound_modules.values()], + ) + + @property + def vm_context(self) -> _binding.VmContext: + return self._vm_context + + @property + def is_dynamic(self) -> bool: + return self._is_dynamic + + @property + def config(self) -> Config: + return self._config + + @property + def instance(self) -> _binding.VmInstance: + return self._config.vm_instance + + @property + def modules(self) -> BoundModules: + return self._bound_modules + + def add_module_dependency(self, name, minimum_version=0): + resolved_module = _binding.VmModule.resolve_module_dependency( + self._config.vm_instance, name, minimum_version + ) + self._vm_context.register_modules([resolved_module]) + + def add_vm_modules(self, vm_modules): + assert self._is_dynamic, "Cannot 'add_module' on a static context" + for m in vm_modules: + if m.name in self._bound_modules: + raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'") + bound_module = BoundModule(self, m) + self._bound_modules[m.name] = bound_module + if self._tracer: + self._tracer.add_module(bound_module.traced_module) + self._vm_context.register_modules(vm_modules) + + def add_vm_module(self, vm_module): + self.add_vm_modules((vm_module,)) def load_vm_modules(*vm_modules, config: Optional[Config] = None): - """Loads VmModules into a new SystemContext and returns them.""" - context = SystemContext(vm_modules=vm_modules, config=config) - bound_modules = [context.modules[m.name] for m in vm_modules] - return bound_modules + """Loads VmModules into a new SystemContext and returns them.""" + context = SystemContext(vm_modules=vm_modules, config=config) + bound_modules = [context.modules[m.name] for m in vm_modules] + return bound_modules def load_vm_module(vm_module, config: Optional[Config] = None): - """Loads a VmModule into a new SystemContext and returns it.""" - return load_vm_modules(vm_module, config=config)[0] - - -def load_vm_flatbuffer(vm_flatbuffer: bytes, - *, - driver: Optional[str] = None, - backend: Optional[str] = None) -> BoundModule: - """Loads a VM Flatbuffer into a callable module. - - Either 'driver' or 'backend' must be specified. - """ - if driver is None and backend is None: - raise ValueError("Either 'driver' or 'backend' must be specified, but got " - "'None' for both.") - if backend is not None and driver is not None: - raise ValueError("Cannot specify both 'driver' and a 'backend' to infer " - "the driver from.") - if backend is not None: - driver = TARGET_BACKEND_TO_DRIVER[backend] - config = Config(driver) - vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance, - vm_flatbuffer) - bound_module = load_vm_module(vm_module, config) - return bound_module + """Loads a VmModule into a new SystemContext and returns it.""" + return load_vm_modules(vm_module, config=config)[0] + + +def load_vm_flatbuffer( + vm_flatbuffer: bytes, *, driver: Optional[str] = None, backend: Optional[str] = None +) -> BoundModule: + """Loads a VM Flatbuffer into a callable module. + + Either 'driver' or 'backend' must be specified. + """ + if driver is None and backend is None: + raise ValueError( + "Either 'driver' or 'backend' must be specified, but got " + "'None' for both." + ) + if backend is not None and driver is not None: + raise ValueError( + "Cannot specify both 'driver' and a 'backend' to infer " "the driver from." + ) + if backend is not None: + driver = TARGET_BACKEND_TO_DRIVER[backend] + config = Config(driver) + vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer) + bound_module = load_vm_module(vm_module, config) + return bound_module # TODO: There should be an API for mmap'ing the file which should be used # instead of reading into memory. -def load_vm_flatbuffer_file(path: str, - *, - driver: Optional[str] = None, - backend: Optional[str] = None) -> BoundModule: - """Loads a file containing a VM Flatbuffer into a callable module. - - Either 'driver' or 'backend' must be specified. - """ - with open(path, "rb") as f: - vm_flatbuffer = f.read() - return load_vm_flatbuffer(vm_flatbuffer, driver=driver, backend=backend) +def load_vm_flatbuffer_file( + path: str, *, driver: Optional[str] = None, backend: Optional[str] = None +) -> BoundModule: + """Loads a file containing a VM Flatbuffer into a callable module. + + Either 'driver' or 'backend' must be specified. + """ + with open(path, "rb") as f: + vm_flatbuffer = f.read() + return load_vm_flatbuffer(vm_flatbuffer, driver=driver, backend=backend) diff --git a/runtime/bindings/python/iree/runtime/system_setup.py b/runtime/bindings/python/iree/runtime/system_setup.py index 99c276d999a1..0560003d5f1d 100644 --- a/runtime/bindings/python/iree/runtime/system_setup.py +++ b/runtime/bindings/python/iree/runtime/system_setup.py @@ -21,78 +21,80 @@ def query_available_drivers() -> Collection[str]: - """Returns a collection of driver names that are available.""" - return HalDriver.query() + """Returns a collection of driver names that are available.""" + return HalDriver.query() def get_driver(device_uri: str) -> HalDriver: - """Returns a HAL driver by device_uri (or driver name).""" - return get_cached_hal_driver(device_uri) + """Returns a HAL driver by device_uri (or driver name).""" + return get_cached_hal_driver(device_uri) def get_device(device_uri: str, cache: bool = True) -> HalDevice: - """Gets a cached device by URI. - - Args: - device_uri: The URI of the device, either just a driver name for the - default or a fully qualified "driver://path?params". - cache: Whether to cache the device (default True). - Returns: - A HalDevice. - """ - with _LOCK: - if cache: - existing = _GLOBAL_DEVICES_BY_URI.get(device_uri) - if existing is not None: - return existing - - driver = get_driver(device_uri) - device = driver.create_device_by_uri(device_uri) - - if cache: - _GLOBAL_DEVICES_BY_URI[device_uri] = device - return device - - -def get_first_device(device_uris: Optional[Sequence[str]] = None, - cache: bool = True) -> HalDevice: - """Gets the first valid (cached) device for a prioritized list of names. - - If no driver_names are given, and an environment variable of - IREE_DEFAULT_DEVICE is available, then it is treated as a comma delimitted - list of driver names to try. - - This is meant to be used for default/automagic startup and is not suitable - for any kind of multi-device setup. - - Args: - device_uris: Explicit list of device URIs to try. - cache: Whether to cache the device (default True). - Returns: - A HalDevice instance. - """ - # Parse from environment or defaults if not explicitly provided. - if device_uris is None: - device_uris = os.environ.get("IREE_DEFAULT_DEVICE") + """Gets a cached device by URI. + + Args: + device_uri: The URI of the device, either just a driver name for the + default or a fully qualified "driver://path?params". + cache: Whether to cache the device (default True). + Returns: + A HalDevice. + """ + with _LOCK: + if cache: + existing = _GLOBAL_DEVICES_BY_URI.get(device_uri) + if existing is not None: + return existing + + driver = get_driver(device_uri) + device = driver.create_device_by_uri(device_uri) + + if cache: + _GLOBAL_DEVICES_BY_URI[device_uri] = device + return device + + +def get_first_device( + device_uris: Optional[Sequence[str]] = None, cache: bool = True +) -> HalDevice: + """Gets the first valid (cached) device for a prioritized list of names. + + If no driver_names are given, and an environment variable of + IREE_DEFAULT_DEVICE is available, then it is treated as a comma delimitted + list of driver names to try. + + This is meant to be used for default/automagic startup and is not suitable + for any kind of multi-device setup. + + Args: + device_uris: Explicit list of device URIs to try. + cache: Whether to cache the device (default True). + Returns: + A HalDevice instance. + """ + # Parse from environment or defaults if not explicitly provided. if device_uris is None: - device_uris = DEFAULT_DRIVER_NAMES - device_uris = [s.strip() for s in device_uris.split(",")] - - last_exception = None - for device_uri in device_uris: - try: - return get_device(device_uri, cache=cache) - except ValueError: - # Driver not known. - continue - except Exception as ex: - # Failure to create driver. - logging.info(f"Failed to create device {device_uri}: {ex}") - last_exception = ex - continue - - if last_exception: - raise RuntimeError("Could not create device. " - "Exception for last tried follows.") from last_exception - else: - raise ValueError(f"No device found from list {device_uris}") + device_uris = os.environ.get("IREE_DEFAULT_DEVICE") + if device_uris is None: + device_uris = DEFAULT_DRIVER_NAMES + device_uris = [s.strip() for s in device_uris.split(",")] + + last_exception = None + for device_uri in device_uris: + try: + return get_device(device_uri, cache=cache) + except ValueError: + # Driver not known. + continue + except Exception as ex: + # Failure to create driver. + logging.info(f"Failed to create device {device_uri}: {ex}") + last_exception = ex + continue + + if last_exception: + raise RuntimeError( + "Could not create device. " "Exception for last tried follows." + ) from last_exception + else: + raise ValueError(f"No device found from list {device_uris}") diff --git a/runtime/bindings/python/iree/runtime/tracing.py b/runtime/bindings/python/iree/runtime/tracing.py index 8650654a8f9f..bda1df9f2890 100644 --- a/runtime/bindings/python/iree/runtime/tracing.py +++ b/runtime/bindings/python/iree/runtime/tracing.py @@ -16,11 +16,11 @@ from . import _binding try: - import yaml + import yaml except ModuleNotFoundError: - _has_yaml = False + _has_yaml = False else: - _has_yaml = True + _has_yaml = True __all__ = [ "get_default_tracer", @@ -32,139 +32,150 @@ class Tracer: - """Object for tracing calls made into the runtime.""" - - def __init__(self, trace_path: str): - if not _has_yaml: - self.enabled = False - logging.warning("PyYAML not installed: tracing will be disabled") - return - self.enabled = True - self.trace_path = trace_path - os.makedirs(trace_path, exist_ok=True) - self._name_count = dict() # type: Dict[str, int] - - def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule": - # Depending on how the module was created, there are different bits - # of information available to reconstruct. - name = vm_module.name - flatbuffer_blob = vm_module.stashed_flatbuffer_blob - if flatbuffer_blob: - save_path = os.path.join(self.trace_path, - self.get_unique_name(f"{name}.vmfb")) - logging.info("Saving traced vmfb to %s", save_path) - with open(save_path, "wb") as f: - f.write(flatbuffer_blob) - return TracedModule(self, vm_module, save_path) - - # No persistent form, but likely they are built-in modules. - return TracedModule(self, vm_module) - - def get_unique_name(self, local_name: str) -> str: - if local_name not in self._name_count: - self._name_count[local_name] = 1 - return local_name - stem, ext = os.path.splitext(local_name) - index = self._name_count[local_name] - self._name_count[local_name] += 1 - unique_name = f"{stem}__{index}{ext}" - return unique_name + """Object for tracing calls made into the runtime.""" + + def __init__(self, trace_path: str): + if not _has_yaml: + self.enabled = False + logging.warning("PyYAML not installed: tracing will be disabled") + return + self.enabled = True + self.trace_path = trace_path + os.makedirs(trace_path, exist_ok=True) + self._name_count = dict() # type: Dict[str, int] + + def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule": + # Depending on how the module was created, there are different bits + # of information available to reconstruct. + name = vm_module.name + flatbuffer_blob = vm_module.stashed_flatbuffer_blob + if flatbuffer_blob: + save_path = os.path.join( + self.trace_path, self.get_unique_name(f"{name}.vmfb") + ) + logging.info("Saving traced vmfb to %s", save_path) + with open(save_path, "wb") as f: + f.write(flatbuffer_blob) + return TracedModule(self, vm_module, save_path) + + # No persistent form, but likely they are built-in modules. + return TracedModule(self, vm_module) + + def get_unique_name(self, local_name: str) -> str: + if local_name not in self._name_count: + self._name_count[local_name] = 1 + return local_name + stem, ext = os.path.splitext(local_name) + index = self._name_count[local_name] + self._name_count[local_name] += 1 + unique_name = f"{stem}__{index}{ext}" + return unique_name class TracedModule: - """Wraps a VmModule with additional information for tracing.""" - - def __init__(self, - parent: Tracer, - vm_module: _binding.VmModule, - vmfb_path: Optional[str] = None): - self._parent = parent - self._vm_module = vm_module - self._vmfb_path = vmfb_path - - def serialize(self): - module_record = {"name": self._vm_module.name} - if self._vmfb_path: - module_record["type"] = "bytecode" - module_record["path"] = os.path.relpath(self._vmfb_path, - self._parent.trace_path) - else: - module_record["type"] = "builtin" - - return module_record + """Wraps a VmModule with additional information for tracing.""" + + def __init__( + self, + parent: Tracer, + vm_module: _binding.VmModule, + vmfb_path: Optional[str] = None, + ): + self._parent = parent + self._vm_module = vm_module + self._vmfb_path = vmfb_path + + def serialize(self): + module_record = {"name": self._vm_module.name} + if self._vmfb_path: + module_record["type"] = "bytecode" + module_record["path"] = os.path.relpath( + self._vmfb_path, self._parent.trace_path + ) + else: + module_record["type"] = "builtin" + + return module_record class ContextTracer: - """Traces invocations against a context.""" - - def __init__(self, parent: Tracer, is_dynamic: bool, - modules: Sequence[TracedModule]): - self._parent = parent - self._modules = list(modules) # type: List[TracedModule] - self._frame_count = 0 - self._file_path = os.path.join(parent.trace_path, - parent.get_unique_name("calls.yaml")) - if os.path.exists(self._file_path): - # Truncate the file. - with open(self._file_path, "wt"): - pass - else: - os.makedirs(os.path.dirname(parent.trace_path), exist_ok=True) - logging.info("Tracing context events to: %s", self._file_path) - self.emit_frame({ - "type": "context_load", - }) - for module in self._modules: - self.emit_frame({ - "type": "module_load", - "module": module.serialize(), - }) - - def add_module(self, module: TracedModule): - self._modules.append(module) - self.emit_frame({ - "type": "module_load", - "module": module.serialize(), - }) - - def start_call(self, function: _binding.VmFunction): - logging.info("Tracing call to %s.%s", function.module_name, function.name) - - # Start assembling the call record. - record = { - "type": "call", - "function": "%s.%s" % (function.module_name, function.name), - } - return CallTrace(self, record) - - def emit_frame(self, frame: dict): - self._frame_count += 1 - with open(self._file_path, "at") as f: - if self._frame_count != 1: - f.write("---\n") - contents = yaml.dump(frame, sort_keys=False) - f.write(contents) + """Traces invocations against a context.""" + + def __init__( + self, parent: Tracer, is_dynamic: bool, modules: Sequence[TracedModule] + ): + self._parent = parent + self._modules = list(modules) # type: List[TracedModule] + self._frame_count = 0 + self._file_path = os.path.join( + parent.trace_path, parent.get_unique_name("calls.yaml") + ) + if os.path.exists(self._file_path): + # Truncate the file. + with open(self._file_path, "wt"): + pass + else: + os.makedirs(os.path.dirname(parent.trace_path), exist_ok=True) + logging.info("Tracing context events to: %s", self._file_path) + self.emit_frame( + { + "type": "context_load", + } + ) + for module in self._modules: + self.emit_frame( + { + "type": "module_load", + "module": module.serialize(), + } + ) + + def add_module(self, module: TracedModule): + self._modules.append(module) + self.emit_frame( + { + "type": "module_load", + "module": module.serialize(), + } + ) + + def start_call(self, function: _binding.VmFunction): + logging.info("Tracing call to %s.%s", function.module_name, function.name) + + # Start assembling the call record. + record = { + "type": "call", + "function": "%s.%s" % (function.module_name, function.name), + } + return CallTrace(self, record) + + def emit_frame(self, frame: dict): + self._frame_count += 1 + with open(self._file_path, "at") as f: + if self._frame_count != 1: + f.write("---\n") + contents = yaml.dump(frame, sort_keys=False) + f.write(contents) class CallTrace: + def __init__(self, parent: ContextTracer, record: dict): + self._parent = parent + self._record = record - def __init__(self, parent: ContextTracer, record: dict): - self._parent = parent - self._record = record - - def add_vm_list(self, vm_list: _binding.VmVariantList, key: str): - mapped = [] - for i in range(len(vm_list)): - mapped.append(vm_list.get_serialized_trace_value(i)) - self._record[key] = mapped + def add_vm_list(self, vm_list: _binding.VmVariantList, key: str): + mapped = [] + for i in range(len(vm_list)): + mapped.append(vm_list.get_serialized_trace_value(i)) + self._record[key] = mapped - def end_call(self): - self._parent.emit_frame(self._record) + def end_call(self): + self._parent.emit_frame(self._record) def get_default_tracer() -> Optional[Tracer]: - """Gets a default call tracer based on environment variables.""" - default_path = os.getenv(TRACE_PATH_ENV_KEY) - if not default_path: - return None - return Tracer(default_path) + """Gets a default call tracer based on environment variables.""" + default_path = os.getenv(TRACE_PATH_ENV_KEY) + if not default_path: + return None + return Tracer(default_path) diff --git a/runtime/bindings/python/tests/array_interop_test.py b/runtime/bindings/python/tests/array_interop_test.py index 602c1db08d10..89c325ad8d92 100644 --- a/runtime/bindings/python/tests/array_interop_test.py +++ b/runtime/bindings/python/tests/array_interop_test.py @@ -13,156 +13,156 @@ class DeviceHalTest(unittest.TestCase): - - def setUp(self): - super().setUp() - self.device = iree.runtime.get_device("local-task") - self.allocator = self.device.allocator - # Make sure device setup maintains proper references. - gc.collect() - - def testGcShutdownFiasco(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - ary = iree.runtime.asdevicearray(self.device, init_ary) - - # Drop all references to backing objects in reverse order to try to - # trigger heap use-after-free on bad shutdown order. - self.allocator = None - gc.collect() - self.device = None - gc.collect() - - # Now drop the ary and make sure nothing crashes (which would indicate - # a reference counting problem of some kind): The array should retain - # everything that it needs to stay live. - ary = None - gc.collect() - - def testMetadataAttributes(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - ary = iree.runtime.asdevicearray(self.device, init_ary) - self.assertEqual([3, 4], ary.shape) - self.assertEqual(np.int32, ary.dtype) - - def testExplicitHostTransfer(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - ary = iree.runtime.asdevicearray(self.device, init_ary) - self.assertEqual(repr(ary), "") - self.assertFalse(ary.is_host_accessible) - - # Explicit transfer. - cp = ary.to_host() - np.testing.assert_array_equal(cp, init_ary) - self.assertTrue(ary.is_host_accessible) - - def testOverrideDtype(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - buffer_view = self.allocator.allocate_buffer_copy( - memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, - allowed_usage=iree.runtime.BufferUsage.DEFAULT, - buffer=init_ary, - element_type=iree.runtime.HalElementType.SINT_32) - - ary = iree.runtime.DeviceArray(self.device, - buffer_view, - override_dtype=np.float32) - - # Explicit transfer. - cp = ary.to_host() - self.assertEqual(cp.dtype, np.float32) - np.testing.assert_array_equal(cp, init_ary.astype(np.float32)) - self.assertTrue(ary.is_host_accessible) - - def testIllegalImplicitHostTransfer(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - ary = iree.runtime.asdevicearray(self.device, init_ary) - # Implicit transfer. - with self.assertRaises(ValueError): - _ = np.asarray(ary) - - def testImplicitHostArithmetic(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - sum = ary + init_ary - np.testing.assert_array_equal(sum, init_ary + 2) - self.assertTrue(ary.is_host_accessible) - - def testArrayFunctions(self): - init_ary = np.zeros([3, 4], dtype=np.float32) + 2 - ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - f = np.isfinite(ary) - self.assertTrue(f.all()) - - def testIteration(self): - init_ary = np.array([0, 1, 2, 3, 4, 5]) - ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - - for index, value in enumerate(ary): - self.assertEqual(index, value) - - def testSubscriptable(self): - init_ary = np.array([0, 1, 2, 3, 4, 5]) - ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - - for index in range(0, 6): - value = ary[index] - self.assertEqual(index, value) - - def testReshape(self): - init_ary = np.zeros([3, 4], dtype=np.float32) + 2 - ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - reshaped = ary.reshape((4, 3)) - self.assertEqual((4, 3), reshaped.shape) - - np_reshaped = np.reshape(ary, (2, 2, 3)) - self.assertEqual((2, 2, 3), np_reshaped.shape) - - def testDeepcopy(self): - init_ary = np.zeros([3, 4], dtype=np.float32) + 2 - orig_ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - copy_ary = copy.deepcopy(orig_ary) - self.assertIsNot(orig_ary, copy_ary) - np.testing.assert_array_equal(orig_ary, copy_ary) - - def testAsType(self): - init_ary = np.zeros([3, 4], dtype=np.int32) + 2 - orig_ary = iree.runtime.asdevicearray(self.device, - init_ary, - implicit_host_transfer=True) - # Same dtype, no copy. - i32_nocopy = orig_ary.astype(np.int32, copy=False) - self.assertIs(orig_ary, i32_nocopy) - - # Same dtype, copy. - i32_nocopy = orig_ary.astype(np.int32) - self.assertIsNot(orig_ary, i32_nocopy) - np.testing.assert_array_equal(orig_ary, i32_nocopy) - - # Different dtype, copy. - f32_copy = orig_ary.astype(np.float32) - self.assertIsNot(orig_ary, f32_copy) - self.assertEqual(f32_copy.dtype, np.float32) - np.testing.assert_array_equal(orig_ary.astype(np.float32), f32_copy) - - def testBool(self): - init_ary = np.zeros([3, 4], dtype=np.bool_) - init_ary[1] = True # Set some non-zero value. - ary = iree.runtime.asdevicearray(self.device, init_ary) - self.assertEqual(repr(ary), "") - np.testing.assert_array_equal(ary.to_host(), init_ary) + def setUp(self): + super().setUp() + self.device = iree.runtime.get_device("local-task") + self.allocator = self.device.allocator + # Make sure device setup maintains proper references. + gc.collect() + + def testGcShutdownFiasco(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + ary = iree.runtime.asdevicearray(self.device, init_ary) + + # Drop all references to backing objects in reverse order to try to + # trigger heap use-after-free on bad shutdown order. + self.allocator = None + gc.collect() + self.device = None + gc.collect() + + # Now drop the ary and make sure nothing crashes (which would indicate + # a reference counting problem of some kind): The array should retain + # everything that it needs to stay live. + ary = None + gc.collect() + + def testMetadataAttributes(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + ary = iree.runtime.asdevicearray(self.device, init_ary) + self.assertEqual([3, 4], ary.shape) + self.assertEqual(np.int32, ary.dtype) + + def testExplicitHostTransfer(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + ary = iree.runtime.asdevicearray(self.device, init_ary) + self.assertEqual(repr(ary), "") + self.assertFalse(ary.is_host_accessible) + + # Explicit transfer. + cp = ary.to_host() + np.testing.assert_array_equal(cp, init_ary) + self.assertTrue(ary.is_host_accessible) + + def testOverrideDtype(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + buffer_view = self.allocator.allocate_buffer_copy( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + buffer=init_ary, + element_type=iree.runtime.HalElementType.SINT_32, + ) + + ary = iree.runtime.DeviceArray( + self.device, buffer_view, override_dtype=np.float32 + ) + + # Explicit transfer. + cp = ary.to_host() + self.assertEqual(cp.dtype, np.float32) + np.testing.assert_array_equal(cp, init_ary.astype(np.float32)) + self.assertTrue(ary.is_host_accessible) + + def testIllegalImplicitHostTransfer(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + ary = iree.runtime.asdevicearray(self.device, init_ary) + # Implicit transfer. + with self.assertRaises(ValueError): + _ = np.asarray(ary) + + def testImplicitHostArithmetic(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + sum = ary + init_ary + np.testing.assert_array_equal(sum, init_ary + 2) + self.assertTrue(ary.is_host_accessible) + + def testArrayFunctions(self): + init_ary = np.zeros([3, 4], dtype=np.float32) + 2 + ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + f = np.isfinite(ary) + self.assertTrue(f.all()) + + def testIteration(self): + init_ary = np.array([0, 1, 2, 3, 4, 5]) + ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + + for index, value in enumerate(ary): + self.assertEqual(index, value) + + def testSubscriptable(self): + init_ary = np.array([0, 1, 2, 3, 4, 5]) + ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + + for index in range(0, 6): + value = ary[index] + self.assertEqual(index, value) + + def testReshape(self): + init_ary = np.zeros([3, 4], dtype=np.float32) + 2 + ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + reshaped = ary.reshape((4, 3)) + self.assertEqual((4, 3), reshaped.shape) + + np_reshaped = np.reshape(ary, (2, 2, 3)) + self.assertEqual((2, 2, 3), np_reshaped.shape) + + def testDeepcopy(self): + init_ary = np.zeros([3, 4], dtype=np.float32) + 2 + orig_ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + copy_ary = copy.deepcopy(orig_ary) + self.assertIsNot(orig_ary, copy_ary) + np.testing.assert_array_equal(orig_ary, copy_ary) + + def testAsType(self): + init_ary = np.zeros([3, 4], dtype=np.int32) + 2 + orig_ary = iree.runtime.asdevicearray( + self.device, init_ary, implicit_host_transfer=True + ) + # Same dtype, no copy. + i32_nocopy = orig_ary.astype(np.int32, copy=False) + self.assertIs(orig_ary, i32_nocopy) + + # Same dtype, copy. + i32_nocopy = orig_ary.astype(np.int32) + self.assertIsNot(orig_ary, i32_nocopy) + np.testing.assert_array_equal(orig_ary, i32_nocopy) + + # Different dtype, copy. + f32_copy = orig_ary.astype(np.float32) + self.assertIsNot(orig_ary, f32_copy) + self.assertEqual(f32_copy.dtype, np.float32) + np.testing.assert_array_equal(orig_ary.astype(np.float32), f32_copy) + + def testBool(self): + init_ary = np.zeros([3, 4], dtype=np.bool_) + init_ary[1] = True # Set some non-zero value. + ary = iree.runtime.asdevicearray(self.device, init_ary) + self.assertEqual(repr(ary), "") + np.testing.assert_array_equal(ary.to_host(), init_ary) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/devices_cli_test.py b/runtime/bindings/python/tests/devices_cli_test.py index b81137df25f6..26c614d11088 100644 --- a/runtime/bindings/python/tests/devices_cli_test.py +++ b/runtime/bindings/python/tests/devices_cli_test.py @@ -14,55 +14,54 @@ def run_cli(*args) -> Tuple[int, str, str]: - capture_stdout = StringIO() - capture_stderr = StringIO() - sys.stdout = capture_stdout - sys.stderr = capture_stderr - try: - rc = cli.main(args) - finally: - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - return rc, capture_stdout.getvalue(), capture_stderr.getvalue() + capture_stdout = StringIO() + capture_stderr = StringIO() + sys.stdout = capture_stdout + sys.stderr = capture_stderr + try: + rc = cli.main(args) + finally: + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + return rc, capture_stdout.getvalue(), capture_stderr.getvalue() class DevicesCliTest(unittest.TestCase): - - def testLs(self): - rc, output, err = run_cli("ls") - self.assertEqual(rc, 0) - self.assertIn("vmvx:0\tdefault", output) - - def testLsTryCreate(self): - rc, output, err = run_cli("ls", "--try-create") - self.assertEqual(rc, 0) - self.assertIn("vmvx:0\tdefault\tSUCCESS", output) - - def testLsTryCreateExplicitDriver(self): - rc, output, err = run_cli("ls", "--try-create", "-d", "vmvx") - self.assertEqual(rc, 0) - self.assertIn("vmvx:0\tdefault\tSUCCESS", output) - - def testLsTryCreateExplicitDriverNotFound(self): - rc, output, err = run_cli("ls", "--try-create", "-d", "DOES_NOT_EXIST") - self.assertEqual(rc, 0) - self.assertIn("Could not create driver DOES_NOT_EXIST", err) - - def testTestIndexedDevice(self): - rc, output, err = run_cli("test", "vmvx:0") - self.assertEqual(rc, 0) - self.assertIn("Creating device vmvx:0... SUCCESS", output) - - def testTestDefaultDevice(self): - rc, output, err = run_cli("test", "vmvx") - self.assertEqual(rc, 0) - self.assertIn("Creating device vmvx... SUCCESS", output) - - def testTestNonExisting(self): - rc, output, err = run_cli("test", "NOT_EXISTING") - self.assertEqual(rc, 1) - self.assertIn("Creating device NOT_EXISTING... ERROR", output) + def testLs(self): + rc, output, err = run_cli("ls") + self.assertEqual(rc, 0) + self.assertIn("vmvx:0\tdefault", output) + + def testLsTryCreate(self): + rc, output, err = run_cli("ls", "--try-create") + self.assertEqual(rc, 0) + self.assertIn("vmvx:0\tdefault\tSUCCESS", output) + + def testLsTryCreateExplicitDriver(self): + rc, output, err = run_cli("ls", "--try-create", "-d", "vmvx") + self.assertEqual(rc, 0) + self.assertIn("vmvx:0\tdefault\tSUCCESS", output) + + def testLsTryCreateExplicitDriverNotFound(self): + rc, output, err = run_cli("ls", "--try-create", "-d", "DOES_NOT_EXIST") + self.assertEqual(rc, 0) + self.assertIn("Could not create driver DOES_NOT_EXIST", err) + + def testTestIndexedDevice(self): + rc, output, err = run_cli("test", "vmvx:0") + self.assertEqual(rc, 0) + self.assertIn("Creating device vmvx:0... SUCCESS", output) + + def testTestDefaultDevice(self): + rc, output, err = run_cli("test", "vmvx") + self.assertEqual(rc, 0) + self.assertIn("Creating device vmvx... SUCCESS", output) + + def testTestNonExisting(self): + rc, output, err = run_cli("test", "NOT_EXISTING") + self.assertEqual(rc, 1) + self.assertIn("Creating device NOT_EXISTING... ERROR", output) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/flags_test.py b/runtime/bindings/python/tests/flags_test.py index 56309abfd91a..6cfd7d47e690 100644 --- a/runtime/bindings/python/tests/flags_test.py +++ b/runtime/bindings/python/tests/flags_test.py @@ -10,15 +10,14 @@ class FlagsTest(unittest.TestCase): + def testParse(self): + # --help is always available if flags are. + rt.flags.parse_flags("--help") - def testParse(self): - # --help is always available if flags are. - rt.flags.parse_flags("--help") - - def testParseError(self): - with self.assertRaisesRegex(ValueError, "flag 'barbar' not recognized"): - rt.flags.parse_flags("--barbar") + def testParseError(self): + with self.assertRaisesRegex(ValueError, "flag 'barbar' not recognized"): + rt.flags.parse_flags("--barbar") if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/function_test.py b/runtime/bindings/python/tests/function_test.py index 4c23cc61ef14..95625e906b9a 100644 --- a/runtime/bindings/python/tests/function_test.py +++ b/runtime/bindings/python/tests/function_test.py @@ -18,577 +18,668 @@ class MockVmContext: + def __init__(self, invoke_callback): + self._invoke_callback = invoke_callback + self.invocations = [] - def __init__(self, invoke_callback): - self._invoke_callback = invoke_callback - self.invocations = [] + def invoke(self, vm_function, arg_list, ret_list): + self._invoke_callback(arg_list, ret_list) + self.invocations.append((vm_function, arg_list, ret_list)) + print(f"INVOKE: {arg_list} -> {ret_list}") - def invoke(self, vm_function, arg_list, ret_list): - self._invoke_callback(arg_list, ret_list) - self.invocations.append((vm_function, arg_list, ret_list)) - print(f"INVOKE: {arg_list} -> {ret_list}") - - @property - def mock_arg_reprs(self): - return repr([arg_list for _, arg_list, _ in self.invocations]) + @property + def mock_arg_reprs(self): + return repr([arg_list for _, arg_list, _ in self.invocations]) class MockVmFunction: - - def __init__(self, reflection): - self.reflection = reflection + def __init__(self, reflection): + self.reflection = reflection class FunctionTest(unittest.TestCase): - - @classmethod - def setUpClass(cls): - # Doesn't matter what device. We just need one. - config = rt.Config("local-task") - cls.device = config.device - - def testNoReflectionScalars(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - ret_list.push_int(4) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(1, 2) - self.assertEqual("[]", vm_context.mock_arg_reprs) - self.assertEqual((3, 4), result) - - def testKeywordArgs(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - "i32", - ["named", "a", "i32"], - ["named", "b", "i32"], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(-1, a=1, b=2) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - self.assertEqual(3, result) - - def testListArg(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": - json.dumps({ - "a": [["slist", "i32", "i32"],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker([2, 3]) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - - def testListArgNoReflection(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker([2, 3]) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - - def testListArgArityMismatch(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": - json.dumps({ - "a": [["slist", "i32", "i32"],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, - "expected a sequence with 2 values. got:"): - _ = invoker([2, 3, 4]) - - def testTupleArg(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": - json.dumps({ - "a": [["stuple", "i32", "i32"],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker((2, 3)) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - - def testDictArg(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [["sdict", ["a", "i32"], ["b", "i32"]],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker({"b": 3, "a": 2}) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - - def testDictArgArityMismatch(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [["sdict", ["a", "i32"], ["b", "i32"]],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, - "expected a dict with 2 values. got:"): - _ = invoker({"a": 2, "b": 3, "c": 4}) - - def testDictArgKeyError(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [["sdict", ["a", "i32"], ["b", "i32"]],], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "): - _ = invoker({"a": 2, "c": 3}) - - def testDictArgNoReflection(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker({"b": 3, "a": 2}) - self.assertEqual("[]", - vm_context.mock_arg_reprs) - - def testInlinedResults(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - ret_list.push_int(4) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [], - "r": [["slist", "i32", "i32"]], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - self.assertEqual([3, 4], result) - - def testNestedResults(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - sub_list = VmVariantList(2) - sub_dict = VmVariantList(2) - sub_dict.push_int(100) - sub_dict.push_int(200) - sub_list.push_list(sub_dict) - sub_list.push_int(6) - ret_list.push_list(sub_list) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [], - "r": [ - "i32", - [ - "slist", - ["sdict", ["bar", "i32"], ["foo", "i32"]], - "i64", - ] - ], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result) - - def testMissingPositional(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - "i32", - ["named", "a", "i32"], - ["named", "b", "i32"], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "mismatched call arity:"): - result = invoker(a=1, b=1) - - def testMissingPositionalNdarray(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - ["ndarray", "i32", 1, 1], - ["named", "a", ["ndarray", "i32", 1, 1]], - ["named", "b", ["ndarray", "i32", 1, 1]], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "mismatched call arity:"): - result = invoker(a=1, b=1) - - def testMissingKeyword(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - "i32", - ["named", "a", "i32"], - ["named", "b", "i32"], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "mismatched call arity:"): - result = invoker(-1, a=1) - - def testMissingKeywordNdArray(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - ["ndarray", "i32", 1, 1], - ["named", "a", ["ndarray", "i32", 1, 1]], - ["named", "b", ["ndarray", "i32", 1, 1]], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "mismatched call arity:"): - result = invoker(-1, a=1) - - def testExtraKeyword(self): - - def invoke(arg_list, ret_list): - ret_list.push_int(3) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction( - reflection={ - "iree.abi": - json.dumps({ - "a": [ - "i32", - ["named", "a", "i32"], - ["named", "b", "i32"], - ], - "r": ["i32",], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"): - result = invoker(-1, a=1, b=2, c=3) - - def testNdarrayArg(self): - arg_array = np.asarray([1, 0], dtype=np.int32) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [["ndarray", "i32", 1, 2]], - "r": [], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(arg_array) - self.assertEqual("", - repr(invoked_arg_list)) - - def testDeviceArrayArg(self): - # Note that since the device array is set up to disallow implicit host - # transfers, this also verifies that no accidental/automatic transfers - # are done as part of marshalling the array to the function. - arg_array = rt.asdevicearray(self.device, - np.asarray([1, 0], dtype=np.int32), - implicit_host_transfer=False) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [["ndarray", "i32", 1, 2]], - "r": [], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(arg_array) - self.assertEqual("", - repr(invoked_arg_list)) - - def testBufferViewArg(self): - arg_buffer_view = self.device.allocator.allocate_buffer_copy( - memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, - allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, - buffer=np.asarray([1, 0], dtype=np.int32), - element_type=rt.HalElementType.SINT_32) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [["ndarray", "i32", 1, 2]], - "r": [], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker(arg_buffer_view) - self.assertEqual("", - repr(invoked_arg_list)) - - def testNdarrayArgNoReflection(self): - arg_array = np.asarray([1, 0], dtype=np.int32) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(arg_array) - self.assertEqual("", - repr(invoked_arg_list)) - - def testDeviceArrayArgNoReflection(self): - # Note that since the device array is set up to disallow implicit host - # transfers, this also verifies that no accidental/automatic transfers - # are done as part of marshalling the array to the function. - arg_array = rt.asdevicearray(self.device, - np.asarray([1, 0], dtype=np.int32), - implicit_host_transfer=False) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker(arg_array) - self.assertEqual("", - repr(invoked_arg_list)) - - def testBufferViewArgNoReflection(self): - arg_buffer_view = self.device.allocator.allocate_buffer_copy( - memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, - allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, - buffer=np.asarray([1, 0], dtype=np.int32), - element_type=rt.HalElementType.SINT_32) - - invoked_arg_list = None - - def invoke(arg_list, ret_list): - nonlocal invoked_arg_list - invoked_arg_list = arg_list - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - _ = invoker(arg_buffer_view) - self.assertEqual("", - repr(invoked_arg_list)) - - def testReturnBufferView(self): - result_array = np.asarray([1, 0], dtype=np.int32) - - def invoke(arg_list, ret_list): - buffer_view = self.device.allocator.allocate_buffer_copy( - memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, - allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, - buffer=result_array, - element_type=rt.HalElementType.SINT_32) - ret_list.push_ref(buffer_view) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [], - "r": [["ndarray", "i32", 1, 2]], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - np.testing.assert_array_equal([1, 0], result) - - def testReturnBufferViewNoReflection(self): - result_array = np.asarray([1, 0], dtype=np.int32) - - def invoke(arg_list, ret_list): - buffer_view = self.device.allocator.allocate_buffer_copy( - memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, - allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, - buffer=result_array, - element_type=rt.HalElementType.SINT_32) - ret_list.push_ref(buffer_view) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={}) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - np.testing.assert_array_equal([1, 0], result) - - # TODO: Fill out all return types. - def testReturnTypeNdArrayBool(self): - result_array = np.asarray([1, 0], dtype=np.int8) - - def invoke(arg_list, ret_list): - buffer_view = self.device.allocator.allocate_buffer_copy( - memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, - allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, - buffer=result_array, - element_type=rt.HalElementType.UINT_8) - ret_list.push_ref(buffer_view) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": json.dumps({ - "a": [], - "r": [["ndarray", "i1", 1, 2]], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - # assertEqual on bool arrays is fraught for... reasons. - np.testing.assert_array_equal([True, False], result) - - def testReturnTypeList(self): - vm_list = VmVariantList(2) - vm_list.push_int(1) - vm_list.push_int(2) - - def invoke(arg_list, ret_list): - ret_list.push_list(vm_list) - - vm_context = MockVmContext(invoke) - vm_function = MockVmFunction(reflection={ - "iree.abi": - json.dumps({ - "a": [], - "r": [["py_homogeneous_list", "i64"]], - }) - }) - invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) - result = invoker() - self.assertEqual("[1, 2]", repr(result)) + @classmethod + def setUpClass(cls): + # Doesn't matter what device. We just need one. + config = rt.Config("local-task") + cls.device = config.device + + def testNoReflectionScalars(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + ret_list.push_int(4) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(1, 2) + self.assertEqual("[]", vm_context.mock_arg_reprs) + self.assertEqual((3, 4), result) + + def testKeywordArgs(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + "i32", + ["named", "a", "i32"], + ["named", "b", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(-1, a=1, b=2) + self.assertEqual("[]", vm_context.mock_arg_reprs) + self.assertEqual(3, result) + + def testListArg(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["slist", "i32", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker([2, 3]) + self.assertEqual( + "[]", vm_context.mock_arg_reprs + ) + + def testListArgNoReflection(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker([2, 3]) + self.assertEqual( + "[]", vm_context.mock_arg_reprs + ) + + def testListArgArityMismatch(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["slist", "i32", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex( + ValueError, "expected a sequence with 2 values. got:" + ): + _ = invoker([2, 3, 4]) + + def testTupleArg(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["stuple", "i32", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker((2, 3)) + self.assertEqual( + "[]", vm_context.mock_arg_reprs + ) + + def testDictArg(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["sdict", ["a", "i32"], ["b", "i32"]], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker({"b": 3, "a": 2}) + self.assertEqual( + "[]", vm_context.mock_arg_reprs + ) + + def testDictArgArityMismatch(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["sdict", ["a", "i32"], ["b", "i32"]], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "expected a dict with 2 values. got:"): + _ = invoker({"a": 2, "b": 3, "c": 4}) + + def testDictArgKeyError(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["sdict", ["a", "i32"], ["b", "i32"]], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "): + _ = invoker({"a": 2, "c": 3}) + + def testDictArgNoReflection(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker({"b": 3, "a": 2}) + self.assertEqual( + "[]", vm_context.mock_arg_reprs + ) + + def testInlinedResults(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + ret_list.push_int(4) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [], + "r": [["slist", "i32", "i32"]], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + self.assertEqual([3, 4], result) + + def testNestedResults(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + sub_list = VmVariantList(2) + sub_dict = VmVariantList(2) + sub_dict.push_int(100) + sub_dict.push_int(200) + sub_list.push_list(sub_dict) + sub_list.push_int(6) + ret_list.push_list(sub_list) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [], + "r": [ + "i32", + [ + "slist", + ["sdict", ["bar", "i32"], ["foo", "i32"]], + "i64", + ], + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result) + + def testMissingPositional(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + "i32", + ["named", "a", "i32"], + ["named", "b", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "mismatched call arity:"): + result = invoker(a=1, b=1) + + def testMissingPositionalNdarray(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["ndarray", "i32", 1, 1], + ["named", "a", ["ndarray", "i32", 1, 1]], + ["named", "b", ["ndarray", "i32", 1, 1]], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "mismatched call arity:"): + result = invoker(a=1, b=1) + + def testMissingKeyword(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + "i32", + ["named", "a", "i32"], + ["named", "b", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "mismatched call arity:"): + result = invoker(-1, a=1) + + def testMissingKeywordNdArray(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + ["ndarray", "i32", 1, 1], + ["named", "a", ["ndarray", "i32", 1, 1]], + ["named", "b", ["ndarray", "i32", 1, 1]], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "mismatched call arity:"): + result = invoker(-1, a=1) + + def testExtraKeyword(self): + def invoke(arg_list, ret_list): + ret_list.push_int(3) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [ + "i32", + ["named", "a", "i32"], + ["named", "b", "i32"], + ], + "r": [ + "i32", + ], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"): + result = invoker(-1, a=1, b=2, c=3) + + def testNdarrayArg(self): + arg_array = np.asarray([1, 0], dtype=np.int32) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [["ndarray", "i32", 1, 2]], + "r": [], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(arg_array) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testDeviceArrayArg(self): + # Note that since the device array is set up to disallow implicit host + # transfers, this also verifies that no accidental/automatic transfers + # are done as part of marshalling the array to the function. + arg_array = rt.asdevicearray( + self.device, + np.asarray([1, 0], dtype=np.int32), + implicit_host_transfer=False, + ) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [["ndarray", "i32", 1, 2]], + "r": [], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(arg_array) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testBufferViewArg(self): + arg_buffer_view = self.device.allocator.allocate_buffer_copy( + memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, + allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, + buffer=np.asarray([1, 0], dtype=np.int32), + element_type=rt.HalElementType.SINT_32, + ) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [["ndarray", "i32", 1, 2]], + "r": [], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker(arg_buffer_view) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testNdarrayArgNoReflection(self): + arg_array = np.asarray([1, 0], dtype=np.int32) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(arg_array) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testDeviceArrayArgNoReflection(self): + # Note that since the device array is set up to disallow implicit host + # transfers, this also verifies that no accidental/automatic transfers + # are done as part of marshalling the array to the function. + arg_array = rt.asdevicearray( + self.device, + np.asarray([1, 0], dtype=np.int32), + implicit_host_transfer=False, + ) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker(arg_array) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testBufferViewArgNoReflection(self): + arg_buffer_view = self.device.allocator.allocate_buffer_copy( + memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, + allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, + buffer=np.asarray([1, 0], dtype=np.int32), + element_type=rt.HalElementType.SINT_32, + ) + + invoked_arg_list = None + + def invoke(arg_list, ret_list): + nonlocal invoked_arg_list + invoked_arg_list = arg_list + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + _ = invoker(arg_buffer_view) + self.assertEqual( + "", repr(invoked_arg_list) + ) + + def testReturnBufferView(self): + result_array = np.asarray([1, 0], dtype=np.int32) + + def invoke(arg_list, ret_list): + buffer_view = self.device.allocator.allocate_buffer_copy( + memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, + allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, + buffer=result_array, + element_type=rt.HalElementType.SINT_32, + ) + ret_list.push_ref(buffer_view) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [], + "r": [["ndarray", "i32", 1, 2]], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + np.testing.assert_array_equal([1, 0], result) + + def testReturnBufferViewNoReflection(self): + result_array = np.asarray([1, 0], dtype=np.int32) + + def invoke(arg_list, ret_list): + buffer_view = self.device.allocator.allocate_buffer_copy( + memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, + allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, + buffer=result_array, + element_type=rt.HalElementType.SINT_32, + ) + ret_list.push_ref(buffer_view) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction(reflection={}) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + np.testing.assert_array_equal([1, 0], result) + + # TODO: Fill out all return types. + def testReturnTypeNdArrayBool(self): + result_array = np.asarray([1, 0], dtype=np.int8) + + def invoke(arg_list, ret_list): + buffer_view = self.device.allocator.allocate_buffer_copy( + memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE, + allowed_usage=IMPLICIT_BUFFER_ARG_USAGE, + buffer=result_array, + element_type=rt.HalElementType.UINT_8, + ) + ret_list.push_ref(buffer_view) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [], + "r": [["ndarray", "i1", 1, 2]], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + # assertEqual on bool arrays is fraught for... reasons. + np.testing.assert_array_equal([True, False], result) + + def testReturnTypeList(self): + vm_list = VmVariantList(2) + vm_list.push_int(1) + vm_list.push_int(2) + + def invoke(arg_list, ret_list): + ret_list.push_list(vm_list) + + vm_context = MockVmContext(invoke) + vm_function = MockVmFunction( + reflection={ + "iree.abi": json.dumps( + { + "a": [], + "r": [["py_homogeneous_list", "i64"]], + } + ) + } + ) + invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None) + result = invoker() + self.assertEqual("[1, 2]", repr(result)) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py index 654039209796..38eaf3d0f22c 100644 --- a/runtime/bindings/python/tests/hal_test.py +++ b/runtime/bindings/python/tests/hal_test.py @@ -12,136 +12,147 @@ class NonDeviceHalTest(unittest.TestCase): - - def testEnums(self): - print("MemoryType:", iree.runtime.MemoryType) - print("HOST_VISIBLE:", int(iree.runtime.MemoryType.HOST_VISIBLE)) - - # Enum and/or operations on BufferCompatibility. - self.assertEqual( - iree.runtime.BufferCompatibility.IMPORTABLE | - iree.runtime.BufferCompatibility.EXPORTABLE, - int(iree.runtime.BufferCompatibility.IMPORTABLE) | - int(iree.runtime.BufferCompatibility.EXPORTABLE)) - self.assertEqual( - iree.runtime.BufferCompatibility.EXPORTABLE & - iree.runtime.BufferCompatibility.EXPORTABLE, - int(iree.runtime.BufferCompatibility.EXPORTABLE)) - - # Enum and/or operations on BufferUsage. - self.assertEqual( - iree.runtime.BufferUsage.TRANSFER | iree.runtime.BufferUsage.MAPPING, - int(iree.runtime.BufferUsage.TRANSFER) | - int(iree.runtime.BufferUsage.MAPPING)) - self.assertEqual( - iree.runtime.BufferUsage.TRANSFER & iree.runtime.BufferUsage.TRANSFER, - int(iree.runtime.BufferUsage.TRANSFER)) - - # Enum and/or operations on MemoryAccess. - self.assertEqual( - iree.runtime.MemoryAccess.READ | iree.runtime.MemoryAccess.WRITE, - int(iree.runtime.MemoryAccess.READ) | - int(iree.runtime.MemoryAccess.WRITE)) - self.assertEqual( - iree.runtime.MemoryAccess.ALL & iree.runtime.MemoryAccess.READ, - int(iree.runtime.MemoryAccess.READ)) - - # Enum and/or operations on MemoryType. - self.assertEqual( - iree.runtime.MemoryType.DEVICE_LOCAL | - iree.runtime.MemoryType.HOST_VISIBLE, - int(iree.runtime.MemoryType.DEVICE_LOCAL) | - int(iree.runtime.MemoryType.HOST_VISIBLE)) - self.assertEqual( - iree.runtime.MemoryType.OPTIMAL & iree.runtime.MemoryType.OPTIMAL, - int(iree.runtime.MemoryType.OPTIMAL)) + def testEnums(self): + print("MemoryType:", iree.runtime.MemoryType) + print("HOST_VISIBLE:", int(iree.runtime.MemoryType.HOST_VISIBLE)) + + # Enum and/or operations on BufferCompatibility. + self.assertEqual( + iree.runtime.BufferCompatibility.IMPORTABLE + | iree.runtime.BufferCompatibility.EXPORTABLE, + int(iree.runtime.BufferCompatibility.IMPORTABLE) + | int(iree.runtime.BufferCompatibility.EXPORTABLE), + ) + self.assertEqual( + iree.runtime.BufferCompatibility.EXPORTABLE + & iree.runtime.BufferCompatibility.EXPORTABLE, + int(iree.runtime.BufferCompatibility.EXPORTABLE), + ) + + # Enum and/or operations on BufferUsage. + self.assertEqual( + iree.runtime.BufferUsage.TRANSFER | iree.runtime.BufferUsage.MAPPING, + int(iree.runtime.BufferUsage.TRANSFER) + | int(iree.runtime.BufferUsage.MAPPING), + ) + self.assertEqual( + iree.runtime.BufferUsage.TRANSFER & iree.runtime.BufferUsage.TRANSFER, + int(iree.runtime.BufferUsage.TRANSFER), + ) + + # Enum and/or operations on MemoryAccess. + self.assertEqual( + iree.runtime.MemoryAccess.READ | iree.runtime.MemoryAccess.WRITE, + int(iree.runtime.MemoryAccess.READ) | int(iree.runtime.MemoryAccess.WRITE), + ) + self.assertEqual( + iree.runtime.MemoryAccess.ALL & iree.runtime.MemoryAccess.READ, + int(iree.runtime.MemoryAccess.READ), + ) + + # Enum and/or operations on MemoryType. + self.assertEqual( + iree.runtime.MemoryType.DEVICE_LOCAL | iree.runtime.MemoryType.HOST_VISIBLE, + int(iree.runtime.MemoryType.DEVICE_LOCAL) + | int(iree.runtime.MemoryType.HOST_VISIBLE), + ) + self.assertEqual( + iree.runtime.MemoryType.OPTIMAL & iree.runtime.MemoryType.OPTIMAL, + int(iree.runtime.MemoryType.OPTIMAL), + ) class DeviceHalTest(unittest.TestCase): - - def setUp(self): - super().setUp() - self.device = iree.runtime.get_device("local-task") - self.allocator = self.device.allocator - gc.collect() - - def testTrim(self): - self.allocator.trim() - # Just running is sufficient. - - def testProfilingDefaults(self): - self.device.begin_profiling() - self.device.end_profiling() - # Just running is sufficient. - - def testProfilingOptions(self): - self.device.begin_profiling(mode="queue", file_path="foo.rdc") - self.device.end_profiling() - # Just running is sufficient. - - def testProfilingInvalidOptions(self): - with self.assertRaisesRegex(ValueError, "unrecognized profiling mode"): - self.device.begin_profiling(mode="SOMETHING THAT DOESN'T EXIST") - - def testStatistics(self): - stats_dict = self.allocator.statistics - stats_str = self.allocator.formatted_statistics - if self.allocator.has_statistics: - self.assertIn("host_bytes_peak", stats_dict) - self.assertIn("host_bytes_allocated", stats_dict) - self.assertIn("host_bytes_freed", stats_dict) - self.assertIn("device_bytes_peak", stats_dict) - self.assertIn("device_bytes_allocated", stats_dict) - self.assertIn("device_bytes_freed", stats_dict) - self.assertIn("HOST_LOCAL", stats_str) - - def testQueryCompatibility(self): - compat = self.allocator.query_buffer_compatibility( - memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, - allowed_usage=iree.runtime.BufferUsage.DEFAULT, - intended_usage=iree.runtime.BufferUsage.DEFAULT, - allocation_size=1024) - print("COMPAT:", compat) - self.assertTrue( - bool(compat & int(iree.runtime.BufferCompatibility.ALLOCATABLE)), - "should be allocatable") - self.assertTrue( - bool(compat & int(iree.runtime.BufferCompatibility.IMPORTABLE)), - "should be importable") - self.assertTrue( - bool(compat & int(iree.runtime.BufferCompatibility.EXPORTABLE)), - "should be exportable") - - def testAllocateBuffer(self): - buffer = self.allocator.allocate_buffer( - memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, - allowed_usage=iree.runtime.BufferUsage.DEFAULT, - allocation_size=13) - print("BUFFER:", buffer) - - def testAllocateBufferCopy(self): - ary = np.zeros([3, 4], dtype=np.int32) + 2 - buffer = self.allocator.allocate_buffer_copy( - memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, - allowed_usage=iree.runtime.BufferUsage.DEFAULT, - buffer=ary) - self.assertEqual( - repr(buffer), - "" - ) - - def testAllocateBufferViewCopy(self): - ary = np.zeros([3, 4], dtype=np.int32) + 2 - buffer = self.allocator.allocate_buffer_copy( - memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, - allowed_usage=iree.runtime.BufferUsage.DEFAULT, - buffer=ary, - element_type=iree.runtime.HalElementType.SINT_32) - self.assertEqual( - repr(buffer), - "" - ) + def setUp(self): + super().setUp() + self.device = iree.runtime.get_device("local-task") + self.allocator = self.device.allocator + gc.collect() + + def testTrim(self): + self.allocator.trim() + # Just running is sufficient. + + def testProfilingDefaults(self): + self.device.begin_profiling() + self.device.end_profiling() + # Just running is sufficient. + + def testProfilingOptions(self): + self.device.begin_profiling(mode="queue", file_path="foo.rdc") + self.device.end_profiling() + # Just running is sufficient. + + def testProfilingInvalidOptions(self): + with self.assertRaisesRegex(ValueError, "unrecognized profiling mode"): + self.device.begin_profiling(mode="SOMETHING THAT DOESN'T EXIST") + + def testStatistics(self): + stats_dict = self.allocator.statistics + stats_str = self.allocator.formatted_statistics + if self.allocator.has_statistics: + self.assertIn("host_bytes_peak", stats_dict) + self.assertIn("host_bytes_allocated", stats_dict) + self.assertIn("host_bytes_freed", stats_dict) + self.assertIn("device_bytes_peak", stats_dict) + self.assertIn("device_bytes_allocated", stats_dict) + self.assertIn("device_bytes_freed", stats_dict) + self.assertIn("HOST_LOCAL", stats_str) + + def testQueryCompatibility(self): + compat = self.allocator.query_buffer_compatibility( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + intended_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=1024, + ) + print("COMPAT:", compat) + self.assertTrue( + bool(compat & int(iree.runtime.BufferCompatibility.ALLOCATABLE)), + "should be allocatable", + ) + self.assertTrue( + bool(compat & int(iree.runtime.BufferCompatibility.IMPORTABLE)), + "should be importable", + ) + self.assertTrue( + bool(compat & int(iree.runtime.BufferCompatibility.EXPORTABLE)), + "should be exportable", + ) + + def testAllocateBuffer(self): + buffer = self.allocator.allocate_buffer( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + allocation_size=13, + ) + print("BUFFER:", buffer) + + def testAllocateBufferCopy(self): + ary = np.zeros([3, 4], dtype=np.int32) + 2 + buffer = self.allocator.allocate_buffer_copy( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + buffer=ary, + ) + self.assertEqual( + repr(buffer), + "", + ) + + def testAllocateBufferViewCopy(self): + ary = np.zeros([3, 4], dtype=np.int32) + 2 + buffer = self.allocator.allocate_buffer_copy( + memory_type=iree.runtime.MemoryType.DEVICE_LOCAL, + allowed_usage=iree.runtime.BufferUsage.DEFAULT, + buffer=ary, + element_type=iree.runtime.HalElementType.SINT_32, + ) + self.assertEqual( + repr(buffer), + "", + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/package_test.py b/runtime/bindings/python/tests/package_test.py index 111008663f2c..facbe314c862 100644 --- a/runtime/bindings/python/tests/package_test.py +++ b/runtime/bindings/python/tests/package_test.py @@ -19,13 +19,13 @@ SETUP_PY_DIR = sys.argv[1] if not os.path.exists(os.path.join(SETUP_PY_DIR, "setup.py")): - print("ERROR: Expected directory containing setup.py as argument") + print("ERROR: Expected directory containing setup.py as argument") print(f"Using setup.py directory: {SETUP_PY_DIR}") # Figure out where to stage output. TEMP_DIR = os.getenv("TEST_TMPDIR") if not TEMP_DIR: - TEMP_DIR = tempfile.gettempdir() + TEMP_DIR = tempfile.gettempdir() # Create the venv. VENV_DIR = os.path.join(TEMP_DIR, "iree_runtime_venv") @@ -34,35 +34,34 @@ venv_python = None for venv_bin in [ os.path.join(VENV_DIR, "bin"), # Posix. - os.path.join(VENV_DIR, "Scripts") # Windows. + os.path.join(VENV_DIR, "Scripts"), # Windows. ]: - if os.path.exists(os.path.join(venv_bin, "activate")): - venv_python = os.path.join(venv_bin, "python") + if os.path.exists(os.path.join(venv_bin, "activate")): + venv_python = os.path.join(venv_bin, "python") if not venv_python: - print("ERROR: Could not find venv python") + print("ERROR: Could not find venv python") venv_bin = os.path.dirname(venv_python) print(f"Running with python: {venv_python}") # Install the package. -subprocess.check_call([ - venv_python, "-m", "pip", "install", "--force-reinstall", SETUP_PY_DIR + "/" -]) +subprocess.check_call( + [venv_python, "-m", "pip", "install", "--force-reinstall", SETUP_PY_DIR + "/"] +) # Run some sanity checks. if "PYTHONPATH" in os.environ: - del os.environ["PYTHONPATH"] + del os.environ["PYTHONPATH"] print("***** Sanity checking that module loads...") subprocess.check_call( - [venv_python, "-c", "import iree.runtime; print('Runtime loaded')"], - cwd=venv_bin) + [venv_python, "-c", "import iree.runtime; print('Runtime loaded')"], cwd=venv_bin +) # Check tools. def check_tool(tool_name: str, args: List[str]): - print(f"**** Checking tool {tool_name} with args {args}") - subprocess.check_call([os.path.join(venv_bin, tool_name)] + args, - cwd=venv_bin) + print(f"**** Checking tool {tool_name} with args {args}") + subprocess.check_call([os.path.join(venv_bin, tool_name)] + args, cwd=venv_bin) check_tool("iree-benchmark-module", ["--help"]) diff --git a/runtime/bindings/python/tests/py_module_test.py b/runtime/bindings/python/tests/py_module_test.py index 831b51082e20..05f60788133e 100644 --- a/runtime/bindings/python/tests/py_module_test.py +++ b/runtime/bindings/python/tests/py_module_test.py @@ -13,292 +13,286 @@ class PyModuleInterfaceTest(unittest.TestCase): + def setUp(self): + self._instance = rt.VmInstance() + + def testEmptyModuleLifecycle(self): + iface = rt.PyModuleInterface("test1", NONE_CTOR) + print(iface) + self.assertFalse(iface.initialized) + m = iface.create() + print(iface) + self.assertTrue(iface.initialized) + print(m) + m = None + gc.collect() + print(iface) + self.assertTrue(iface.destroyed) + + def testEmptyModuleInstance(self): + iface = rt.PyModuleInterface("test1", NONE_CTOR) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + self.assertTrue(iface.initialized) + print(context) + + # Make sure no circular refs and that everything frees. + context = None + m = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testMultiModuleInstance(self): + calls = [] + + def ctor(iface): + calls.append(iface) + return None + + iface = rt.PyModuleInterface("test1", ctor) + m = iface.create() + context1 = rt.VmContext(self._instance, modules=(m,)) + self.assertTrue(iface.initialized) + context2 = rt.VmContext(self._instance, modules=(m,)) + self.assertTrue(iface.initialized) + self.assertEqual(2, len(calls)) + + # Make sure no circular refs and that everything frees. + calls = None + context1 = None + m = None + context2 = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testVoidFunctionExport(self): + messages = [] + + class Methods: + def __init__(self, iface): + self.iface = iface + self.counter = 0 + + def say_hello(self): + messages.append(f"Hello! Your number is {self.counter}") + print(messages[-1]) + self.counter += 1 + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("say_hello", "0v", Methods.say_hello) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + f = m.lookup_function("say_hello") + self.assertIsNotNone(f) + args = rt.VmVariantList(0) + results = rt.VmVariantList(0) + + # Invoke twice - should produce two messages. + context.invoke(f, args, results) + context.invoke(f, args, results) + self.assertListEqual( + messages, + [ + "Hello! Your number is 0", + "Hello! Your number is 1", + ], + ) + + # Make sure no circular refs and that everything frees. + context = None + m = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testPythonException(self): + messages = [] + + class Methods: + def __init__(self, iface): + pass + + def do_it(self): + raise ValueError("This is from Python") + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("do_it", "0v", Methods.do_it) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + f = m.lookup_function("do_it") + self.assertIsNotNone(f) + args = rt.VmVariantList(0) + results = rt.VmVariantList(0) + + # We are testing here that the Python level exception is caught and + # translated to an IREE status (surfacing as a RuntimeError) vs percolating + # through the C call stack. + with self.assertRaisesRegex(RuntimeError, "ValueError: This is from Python"): + context.invoke(f, args, results) + + # Make sure no circular refs and that everything frees. + context = None + m = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testPrimitiveArguments(self): + values = [] + + class Methods: + def __init__(self, iface): + pass + + def do_it(self, a, b): + values.append((a, b)) + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("do_it_i32", "0ii", Methods.do_it) + iface.export("do_it_i64", "0II", Methods.do_it) + iface.export("do_it_f32", "0ff", Methods.do_it) + iface.export("do_it_f64", "0FF", Methods.do_it) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + + args = rt.VmVariantList(2) + results = rt.VmVariantList(0) + args.push_int(42) + args.push_int(43) + context.invoke(m.lookup_function("do_it_i32"), args, results) + context.invoke(m.lookup_function("do_it_i64"), args, results) + + args = rt.VmVariantList(2) + args.push_float(2.0) + args.push_float(4.0) + # TODO: Python doesn't have 32bit floats, so we are populating f64 args. + # These are coming back as zeros, and I expected something to be + # doing a conversion? The same is being done with i64 above but is + # working there. + context.invoke(m.lookup_function("do_it_f32"), args, results) + context.invoke(m.lookup_function("do_it_f64"), args, results) + + print(values) + self.assertEqual(repr(values), "[(42, 43), (42, 43), (0.0, 0.0), (2.0, 4.0)]") + + # Make sure no circular refs and that everything frees. + context = None + m = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testPrimitiveResults(self): + next_results = None + + class Methods: + def __init__(self, iface): + pass + + def do_it(self): + return next_results + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("do_it_i32", "0v_ii", Methods.do_it) + iface.export("do_it_i64", "0v_II", Methods.do_it) + iface.export("do_it_f32", "0v_ff", Methods.do_it) + iface.export("do_it_f64", "0v_FF", Methods.do_it) + iface.export("do_it_unary_i32", "0v_i", Methods.do_it) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + + args = rt.VmVariantList(0) + + # i32 + results = rt.VmVariantList(2) + next_results = (42, 43) + context.invoke(m.lookup_function("do_it_i32"), args, results) + self.assertEqual(repr(results), "") + + # i64 + results = rt.VmVariantList(2) + next_results = (42, 43) + context.invoke(m.lookup_function("do_it_i64"), args, results) + self.assertEqual(repr(results), "") + + # f32 + results = rt.VmVariantList(2) + next_results = (2.0, 4.0) + context.invoke(m.lookup_function("do_it_f32"), args, results) + self.assertEqual(repr(results), "") + + # f64 + results = rt.VmVariantList(2) + next_results = (2.0, 4.0) + context.invoke(m.lookup_function("do_it_f64"), args, results) + self.assertEqual(repr(results), "") + + # Unary special case. + results = rt.VmVariantList(1) + next_results = 42 + context.invoke(m.lookup_function("do_it_unary_i32"), args, results) + self.assertEqual(repr(results), "") + + # Make sure no circular refs and that everything frees. + context = None + m = None + gc.collect() + self.assertTrue(iface.destroyed) + + def testRefArguments(self): + values = [] + + class Methods: + def __init__(self, iface): + pass + + def do_it(self, a, b): + values.append((a.deref(rt.VmVariantList), b.deref(rt.VmVariantList))) + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("do_it", "0rr", Methods.do_it) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) - def setUp(self): - self._instance = rt.VmInstance() - - def testEmptyModuleLifecycle(self): - iface = rt.PyModuleInterface("test1", NONE_CTOR) - print(iface) - self.assertFalse(iface.initialized) - m = iface.create() - print(iface) - self.assertTrue(iface.initialized) - print(m) - m = None - gc.collect() - print(iface) - self.assertTrue(iface.destroyed) - - def testEmptyModuleInstance(self): - iface = rt.PyModuleInterface("test1", NONE_CTOR) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - self.assertTrue(iface.initialized) - print(context) - - # Make sure no circular refs and that everything frees. - context = None - m = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testMultiModuleInstance(self): - calls = [] - - def ctor(iface): - calls.append(iface) - return None - - iface = rt.PyModuleInterface("test1", ctor) - m = iface.create() - context1 = rt.VmContext(self._instance, modules=(m,)) - self.assertTrue(iface.initialized) - context2 = rt.VmContext(self._instance, modules=(m,)) - self.assertTrue(iface.initialized) - self.assertEqual(2, len(calls)) - - # Make sure no circular refs and that everything frees. - calls = None - context1 = None - m = None - context2 = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testVoidFunctionExport(self): - messages = [] - - class Methods: - - def __init__(self, iface): - self.iface = iface - self.counter = 0 - - def say_hello(self): - messages.append(f"Hello! Your number is {self.counter}") - print(messages[-1]) - self.counter += 1 - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("say_hello", "0v", Methods.say_hello) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - f = m.lookup_function("say_hello") - self.assertIsNotNone(f) - args = rt.VmVariantList(0) - results = rt.VmVariantList(0) - - # Invoke twice - should produce two messages. - context.invoke(f, args, results) - context.invoke(f, args, results) - self.assertListEqual(messages, [ - "Hello! Your number is 0", - "Hello! Your number is 1", - ]) - - # Make sure no circular refs and that everything frees. - context = None - m = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testPythonException(self): - messages = [] - - class Methods: - - def __init__(self, iface): - pass - - def do_it(self): - raise ValueError("This is from Python") - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("do_it", "0v", Methods.do_it) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - f = m.lookup_function("do_it") - self.assertIsNotNone(f) - args = rt.VmVariantList(0) - results = rt.VmVariantList(0) - - # We are testing here that the Python level exception is caught and - # translated to an IREE status (surfacing as a RuntimeError) vs percolating - # through the C call stack. - with self.assertRaisesRegex(RuntimeError, - "ValueError: This is from Python"): - context.invoke(f, args, results) - - # Make sure no circular refs and that everything frees. - context = None - m = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testPrimitiveArguments(self): - values = [] - - class Methods: - - def __init__(self, iface): - pass - - def do_it(self, a, b): - values.append((a, b)) - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("do_it_i32", "0ii", Methods.do_it) - iface.export("do_it_i64", "0II", Methods.do_it) - iface.export("do_it_f32", "0ff", Methods.do_it) - iface.export("do_it_f64", "0FF", Methods.do_it) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - - args = rt.VmVariantList(2) - results = rt.VmVariantList(0) - args.push_int(42) - args.push_int(43) - context.invoke(m.lookup_function("do_it_i32"), args, results) - context.invoke(m.lookup_function("do_it_i64"), args, results) - - args = rt.VmVariantList(2) - args.push_float(2.0) - args.push_float(4.0) - # TODO: Python doesn't have 32bit floats, so we are populating f64 args. - # These are coming back as zeros, and I expected something to be - # doing a conversion? The same is being done with i64 above but is - # working there. - context.invoke(m.lookup_function("do_it_f32"), args, results) - context.invoke(m.lookup_function("do_it_f64"), args, results) - - print(values) - self.assertEqual(repr(values), - "[(42, 43), (42, 43), (0.0, 0.0), (2.0, 4.0)]") - - # Make sure no circular refs and that everything frees. - context = None - m = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testPrimitiveResults(self): - next_results = None - - class Methods: - - def __init__(self, iface): - pass - - def do_it(self): - return next_results - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("do_it_i32", "0v_ii", Methods.do_it) - iface.export("do_it_i64", "0v_II", Methods.do_it) - iface.export("do_it_f32", "0v_ff", Methods.do_it) - iface.export("do_it_f64", "0v_FF", Methods.do_it) - iface.export("do_it_unary_i32", "0v_i", Methods.do_it) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - - args = rt.VmVariantList(0) - - # i32 - results = rt.VmVariantList(2) - next_results = (42, 43) - context.invoke(m.lookup_function("do_it_i32"), args, results) - self.assertEqual(repr(results), "") - - # i64 - results = rt.VmVariantList(2) - next_results = (42, 43) - context.invoke(m.lookup_function("do_it_i64"), args, results) - self.assertEqual(repr(results), "") - - # f32 - results = rt.VmVariantList(2) - next_results = (2.0, 4.0) - context.invoke(m.lookup_function("do_it_f32"), args, results) - self.assertEqual(repr(results), "") - - # f64 - results = rt.VmVariantList(2) - next_results = (2.0, 4.0) - context.invoke(m.lookup_function("do_it_f64"), args, results) - self.assertEqual(repr(results), "") - - # Unary special case. - results = rt.VmVariantList(1) - next_results = (42) - context.invoke(m.lookup_function("do_it_unary_i32"), args, results) - self.assertEqual(repr(results), "") - - # Make sure no circular refs and that everything frees. - context = None - m = None - gc.collect() - self.assertTrue(iface.destroyed) - - def testRefArguments(self): - values = [] - - class Methods: - - def __init__(self, iface): - pass - - def do_it(self, a, b): - values.append((a.deref(rt.VmVariantList), b.deref(rt.VmVariantList))) - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("do_it", "0rr", Methods.do_it) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - - # These lists just happen to be reference objects we know how to - # create. - arg0 = rt.VmVariantList(1) - arg0.push_int(42) - arg1 = rt.VmVariantList(1) - arg1.push_int(84) - - args = rt.VmVariantList(2) - args.push_list(arg0) - args.push_list(arg1) - results = rt.VmVariantList(2) - context.invoke(m.lookup_function("do_it"), args, results) - print("REF VALUES:", values) - self.assertEqual(repr(values), - "[(, )]") - - def testRefResults(self): - - class Methods: - - def __init__(self, iface): - pass - - def do_it(self): # These lists just happen to be reference objects we know how to # create. - r0 = rt.VmVariantList(1) - r0.push_int(42) - r1 = rt.VmVariantList(1) - r1.push_int(84) - return r0.ref, r1.ref - - iface = rt.PyModuleInterface("test1", Methods) - iface.export("do_it", "0v_rr", Methods.do_it) - m = iface.create() - context = rt.VmContext(self._instance, modules=(m,)) - - args = rt.VmVariantList(0) - results = rt.VmVariantList(2) - context.invoke(m.lookup_function("do_it"), args, results) - print("REF RESULTS:", results) - self.assertEqual(repr(results), "") + arg0 = rt.VmVariantList(1) + arg0.push_int(42) + arg1 = rt.VmVariantList(1) + arg1.push_int(84) + + args = rt.VmVariantList(2) + args.push_list(arg0) + args.push_list(arg1) + results = rt.VmVariantList(2) + context.invoke(m.lookup_function("do_it"), args, results) + print("REF VALUES:", values) + self.assertEqual( + repr(values), "[(, )]" + ) + + def testRefResults(self): + class Methods: + def __init__(self, iface): + pass + + def do_it(self): + # These lists just happen to be reference objects we know how to + # create. + r0 = rt.VmVariantList(1) + r0.push_int(42) + r1 = rt.VmVariantList(1) + r1.push_int(84) + return r0.ref, r1.ref + + iface = rt.PyModuleInterface("test1", Methods) + iface.export("do_it", "0v_rr", Methods.do_it) + m = iface.create() + context = rt.VmContext(self._instance, modules=(m,)) + + args = rt.VmVariantList(0) + results = rt.VmVariantList(2) + context.invoke(m.lookup_function("do_it"), args, results) + print("REF RESULTS:", results) + self.assertEqual(repr(results), "") if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/runtime/bindings/python/tests/system_api_test.py b/runtime/bindings/python/tests/system_api_test.py index 227df92107c2..be5f09646f76 100644 --- a/runtime/bindings/python/tests/system_api_test.py +++ b/runtime/bindings/python/tests/system_api_test.py @@ -18,8 +18,8 @@ def create_simple_mul_module(instance): - binary = iree.compiler.compile_str( - """ + binary = iree.compiler.compile_str( + """ module @arithmetic { func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> @@ -27,129 +27,127 @@ def create_simple_mul_module(instance): } } """, - target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS, - ) - m = iree.runtime.VmModule.from_flatbuffer(instance, binary) - return m + target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS, + ) + m = iree.runtime.VmModule.from_flatbuffer(instance, binary) + return m class SystemApiTest(unittest.TestCase): - - def test_non_existing_driver(self): - with self.assertRaisesRegex(ValueError, "No device found from list"): - config = iree.runtime.Config("nothere1,nothere2") - - def test_subsequent_driver(self): - config = iree.runtime.Config("nothere1,local-task") - - def test_multi_config_caches(self): - config1 = iree.runtime.Config("nothere1,local-sync") - config2 = iree.runtime.Config("nothere1,local-sync") - self.assertIs(config1.device, config2.device) - - def test_empty_dynamic(self): - ctx = iree.runtime.SystemContext() - self.assertTrue(ctx.is_dynamic) - self.assertIn("hal", ctx.modules) - self.assertEqual(ctx.modules.hal.name, "hal") - - def test_empty_static(self): - ctx = iree.runtime.SystemContext(vm_modules=()) - self.assertFalse(ctx.is_dynamic) - self.assertIn("hal", ctx.modules) - self.assertEqual(ctx.modules.hal.name, "hal") - - def test_custom_dynamic(self): - ctx = iree.runtime.SystemContext() - self.assertTrue(ctx.is_dynamic) - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") - f = ctx.modules.arithmetic["simple_mul"] - f_repr = repr(f) - logging.info("f_repr: %s", f_repr) - self.assertEqual(f_repr, "") - - def test_duplicate_module(self): - ctx = iree.runtime.SystemContext() - self.assertTrue(ctx.is_dynamic) - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - with self.assertRaisesRegex(ValueError, "arithmetic"): - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - - def test_static_invoke(self): - ctx = iree.runtime.SystemContext() - self.assertTrue(ctx.is_dynamic) - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") - f = ctx.modules.arithmetic["simple_mul"] - arg0 = np.array([1., 2., 3., 4.], dtype=np.float32) - arg1 = np.array([4., 5., 6., 7.], dtype=np.float32) - results = f(arg0, arg1) - np.testing.assert_allclose(results, [4., 10., 18., 28.]) - - def test_chained_invoke(self): - # This ensures that everything works if DeviceArrays are returned - # and input to functions. - ctx = iree.runtime.SystemContext() - self.assertTrue(ctx.is_dynamic) - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") - f = ctx.modules.arithmetic["simple_mul"] - arg0 = np.array([1., 2., 3., 4.], dtype=np.float32) - arg1 = np.array([4., 5., 6., 7.], dtype=np.float32) - results = f(arg0, arg1) - results2 = f(results, results) - np.testing.assert_allclose(results2, [16., 100., 324., 784.]) - - def test_tracing_explicit(self): - with tempfile.TemporaryDirectory() as temp_dir: - tracer = iree.runtime.Tracer(temp_dir) - config = iree.runtime.Config("local-task", tracer=tracer) - self.verify_tracing(config, temp_dir) - - def test_tracing_from_environment(self): - original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY) - try: - with tempfile.TemporaryDirectory() as temp_dir: - os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir - config = iree.runtime.Config("local-task") - self.verify_tracing(config, temp_dir) - finally: - if original: - os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original - - def verify_tracing(self, config, temp_dir): - logging.info("Tracing test to: %s", temp_dir) - ctx = iree.runtime.SystemContext(config=config) - ctx.add_vm_module(create_simple_mul_module(ctx.instance)) - f = ctx.modules.arithmetic["simple_mul"] - arg0 = np.array([1., 2., 3., 4.], dtype=np.float32) - arg1 = np.array([4., 5., 6., 7.], dtype=np.float32) - results = f(arg0, arg1) - self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb"))) - self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml"))) - # TODO: Once replay is possible, verify that. - - def test_load_vm_module(self): - ctx = iree.runtime.SystemContext() - arithmetic = iree.runtime.load_vm_module( - create_simple_mul_module(ctx.instance)) - arg0 = np.array([1., 2., 3., 4.], dtype=np.float32) - arg1 = np.array([4., 5., 6., 7.], dtype=np.float32) - results = arithmetic.simple_mul(arg0, arg1) - print("SIMPLE_MUL RESULTS:", results) - np.testing.assert_allclose(results, [4., 10., 18., 28.]) - - def test_load_multiple_modules(self): - # Doing default device configuration multiple times should be valid - # (if this were instantiating drivers multiple times, it can trigger - # a crash, depending on whether the driver supports multi-instantiation). - ctx = iree.runtime.SystemContext() - m = create_simple_mul_module(ctx.instance) - m1 = iree.runtime.load_vm_module(m) - m2 = iree.runtime.load_vm_module(m) + def test_non_existing_driver(self): + with self.assertRaisesRegex(ValueError, "No device found from list"): + config = iree.runtime.Config("nothere1,nothere2") + + def test_subsequent_driver(self): + config = iree.runtime.Config("nothere1,local-task") + + def test_multi_config_caches(self): + config1 = iree.runtime.Config("nothere1,local-sync") + config2 = iree.runtime.Config("nothere1,local-sync") + self.assertIs(config1.device, config2.device) + + def test_empty_dynamic(self): + ctx = iree.runtime.SystemContext() + self.assertTrue(ctx.is_dynamic) + self.assertIn("hal", ctx.modules) + self.assertEqual(ctx.modules.hal.name, "hal") + + def test_empty_static(self): + ctx = iree.runtime.SystemContext(vm_modules=()) + self.assertFalse(ctx.is_dynamic) + self.assertIn("hal", ctx.modules) + self.assertEqual(ctx.modules.hal.name, "hal") + + def test_custom_dynamic(self): + ctx = iree.runtime.SystemContext() + self.assertTrue(ctx.is_dynamic) + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") + f = ctx.modules.arithmetic["simple_mul"] + f_repr = repr(f) + logging.info("f_repr: %s", f_repr) + self.assertEqual(f_repr, "") + + def test_duplicate_module(self): + ctx = iree.runtime.SystemContext() + self.assertTrue(ctx.is_dynamic) + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + with self.assertRaisesRegex(ValueError, "arithmetic"): + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + + def test_static_invoke(self): + ctx = iree.runtime.SystemContext() + self.assertTrue(ctx.is_dynamic) + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") + f = ctx.modules.arithmetic["simple_mul"] + arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32) + results = f(arg0, arg1) + np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0]) + + def test_chained_invoke(self): + # This ensures that everything works if DeviceArrays are returned + # and input to functions. + ctx = iree.runtime.SystemContext() + self.assertTrue(ctx.is_dynamic) + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + self.assertEqual(ctx.modules.arithmetic.name, "arithmetic") + f = ctx.modules.arithmetic["simple_mul"] + arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32) + results = f(arg0, arg1) + results2 = f(results, results) + np.testing.assert_allclose(results2, [16.0, 100.0, 324.0, 784.0]) + + def test_tracing_explicit(self): + with tempfile.TemporaryDirectory() as temp_dir: + tracer = iree.runtime.Tracer(temp_dir) + config = iree.runtime.Config("local-task", tracer=tracer) + self.verify_tracing(config, temp_dir) + + def test_tracing_from_environment(self): + original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY) + try: + with tempfile.TemporaryDirectory() as temp_dir: + os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir + config = iree.runtime.Config("local-task") + self.verify_tracing(config, temp_dir) + finally: + if original: + os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original + + def verify_tracing(self, config, temp_dir): + logging.info("Tracing test to: %s", temp_dir) + ctx = iree.runtime.SystemContext(config=config) + ctx.add_vm_module(create_simple_mul_module(ctx.instance)) + f = ctx.modules.arithmetic["simple_mul"] + arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32) + results = f(arg0, arg1) + self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb"))) + self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml"))) + # TODO: Once replay is possible, verify that. + + def test_load_vm_module(self): + ctx = iree.runtime.SystemContext() + arithmetic = iree.runtime.load_vm_module(create_simple_mul_module(ctx.instance)) + arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32) + results = arithmetic.simple_mul(arg0, arg1) + print("SIMPLE_MUL RESULTS:", results) + np.testing.assert_allclose(results, [4.0, 10.0, 18.0, 28.0]) + + def test_load_multiple_modules(self): + # Doing default device configuration multiple times should be valid + # (if this were instantiating drivers multiple times, it can trigger + # a crash, depending on whether the driver supports multi-instantiation). + ctx = iree.runtime.SystemContext() + m = create_simple_mul_module(ctx.instance) + m1 = iree.runtime.load_vm_module(m) + m2 = iree.runtime.load_vm_module(m) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/runtime/bindings/python/tests/system_setup_test.py b/runtime/bindings/python/tests/system_setup_test.py index 3248b7ca423c..2d0ddf9ea2cb 100644 --- a/runtime/bindings/python/tests/system_setup_test.py +++ b/runtime/bindings/python/tests/system_setup_test.py @@ -11,64 +11,61 @@ class DeviceSetupTest(unittest.TestCase): + def testQueryDriversDevices(self): + driver_names = ss.query_available_drivers() + print(f"Drivers: {driver_names}") + self.assertIn("local-sync", driver_names) + self.assertIn("local-task", driver_names) - def testQueryDriversDevices(self): - driver_names = ss.query_available_drivers() - print(f"Drivers: {driver_names}") - self.assertIn("local-sync", driver_names) - self.assertIn("local-task", driver_names) + for driver_name in ["local-sync", "local-task"]: + driver = ss.get_driver(driver_name) + print(f"Driver {driver_name}: {driver}") + device_infos = driver.query_available_devices() + print(f"DeviceInfos: {device_infos}") + if driver_name == "local-sync": + # We happen to know that this should have one device_info + self.assertEqual( + device_infos, [{"device_id": 0, "path": "", "name": "default"}] + ) - for driver_name in ["local-sync", "local-task"]: - driver = ss.get_driver(driver_name) - print(f"Driver {driver_name}: {driver}") - device_infos = driver.query_available_devices() - print(f"DeviceInfos: {device_infos}") - if driver_name == "local-sync": - # We happen to know that this should have one device_info - self.assertEqual(device_infos, [{ - "device_id": 0, - "path": "", - "name": "default" - }]) + def testCreateBadDeviceId(self): + driver = ss.get_driver("local-sync") + with self.assertRaises( + ValueError, + msg="Device id 5555 not found. Available devices: [{ device_id:0, path:'', name:'default'}]", + ): + _ = driver.create_device(5555) - def testCreateBadDeviceId(self): - driver = ss.get_driver("local-sync") - with self.assertRaises( - ValueError, - msg= - "Device id 5555 not found. Available devices: [{ device_id:0, path:'', name:'default'}]" - ): - _ = driver.create_device(5555) + def testCreateDevice(self): + driver = ss.get_driver("local-sync") + infos = driver.query_available_devices() + # Each record is a dict: + # {"device_id": obj, "path": str, "name": str}. + device1 = driver.create_device(infos[0]["device_id"]) + # Should also take the info dict directly for convenience. + device2 = driver.create_device(infos[0]) - def testCreateDevice(self): - driver = ss.get_driver("local-sync") - infos = driver.query_available_devices() - # Each record is a dict: - # {"device_id": obj, "path": str, "name": str}. - device1 = driver.create_device(infos[0]["device_id"]) - # Should also take the info dict directly for convenience. - device2 = driver.create_device(infos[0]) + def testCreateDeviceByName(self): + device1 = ss.get_device("local-task") + device2 = ss.get_device("local-sync") + device3 = ss.get_device("local-sync") + device4 = ss.get_device("local-sync", cache=False) + self.assertIsNot(device1, device2) + self.assertIsNot(device3, device4) + self.assertIs(device2, device3) - def testCreateDeviceByName(self): - device1 = ss.get_device("local-task") - device2 = ss.get_device("local-sync") - device3 = ss.get_device("local-sync") - device4 = ss.get_device("local-sync", cache=False) - self.assertIsNot(device1, device2) - self.assertIsNot(device3, device4) - self.assertIs(device2, device3) + with self.assertRaises(ValueError, msg="Device not found: local-sync://1"): + _ = ss.get_device("local-sync://1") - with self.assertRaises(ValueError, msg="Device not found: local-sync://1"): - _ = ss.get_device("local-sync://1") - - def testCreateDeviceWithAllocators(self): - driver = ss.get_driver("local-sync") - infos = driver.query_available_devices() - device1 = driver.create_device(infos[0]["device_id"], allocators=[]) - device2 = driver.create_device(infos[0]["device_id"], - allocators=["caching", "debug"]) + def testCreateDeviceWithAllocators(self): + driver = ss.get_driver("local-sync") + infos = driver.query_available_devices() + device1 = driver.create_device(infos[0]["device_id"], allocators=[]) + device2 = driver.create_device( + infos[0]["device_id"], allocators=["caching", "debug"] + ) if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - unittest.main() + logging.basicConfig(level=logging.INFO) + unittest.main() diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py index e997f462c5cd..f412c7b77c39 100644 --- a/runtime/bindings/python/tests/vm_test.py +++ b/runtime/bindings/python/tests/vm_test.py @@ -22,209 +22,204 @@ def compile_add_scalar(): - global COMPILED_ADD_SCALAR - if not COMPILED_ADD_SCALAR: - COMPILED_ADD_SCALAR = iree.compiler.compile_str( - """ - func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 { - %0 = arith.addi %arg0, %arg1 : i32 - return %0 : i32 - } - """, - target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS) - return COMPILED_ADD_SCALAR + global COMPILED_ADD_SCALAR + if not COMPILED_ADD_SCALAR: + COMPILED_ADD_SCALAR = iree.compiler.compile_str( + """ + func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 { + %0 = arith.addi %arg0, %arg1 : i32 + return %0 : i32 + } + """, + target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS, + ) + return COMPILED_ADD_SCALAR def create_add_scalar_module(instance): - binary = compile_add_scalar() - m = iree.runtime.VmModule.from_flatbuffer(instance, binary) - return m + binary = compile_add_scalar() + m = iree.runtime.VmModule.from_flatbuffer(instance, binary) + return m def create_simple_static_mul_module(instance): - binary = iree.compiler.compile_str( - """ - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - """, - target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS, - ) - m = iree.runtime.VmModule.from_flatbuffer(instance, binary) - return m + binary = iree.compiler.compile_str( + """ + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + """, + target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS, + ) + m = iree.runtime.VmModule.from_flatbuffer(instance, binary) + return m def create_simple_dynamic_abs_module(instance): - binary = iree.compiler.compile_str( - """ - func.func @dynamic_abs(%arg0: tensor) -> tensor { - %0 = math.absf %arg0 : tensor - return %0 : tensor - } - """, - target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS, - ) - m = iree.runtime.VmModule.from_flatbuffer(instance, binary) - return m + binary = iree.compiler.compile_str( + """ + func.func @dynamic_abs(%arg0: tensor) -> tensor { + %0 = math.absf %arg0 : tensor + return %0 : tensor + } + """, + target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS, + ) + m = iree.runtime.VmModule.from_flatbuffer(instance, binary) + return m class VmTest(unittest.TestCase): - - @classmethod - def setUp(self): - self.instance = iree.runtime.VmInstance() - self.device = iree.runtime.get_device( - iree.compiler.core.DEFAULT_TESTING_DRIVER) - self.hal_module = iree.runtime.create_hal_module(self.instance, self.device) - - def test_context_id(self): - context1 = iree.runtime.VmContext(self.instance) - context2 = iree.runtime.VmContext(self.instance) - self.assertNotEqual(context2.context_id, context1.context_id) - - def test_module_basics(self): - m = create_simple_static_mul_module(self.instance) - f = m.lookup_function("simple_mul") - self.assertGreaterEqual(f.ordinal, 0) - notfound = m.lookup_function("notfound") - self.assertIs(notfound, None) - - def test_dynamic_module_context(self): - context = iree.runtime.VmContext(self.instance) - m = create_simple_static_mul_module(self.instance) - context.register_modules([self.hal_module, m]) - - def test_static_module_context(self): - m = create_simple_static_mul_module(self.instance) - logging.info("module: %s", m) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - logging.info("context: %s", context) - - def test_dynamic_shape_compile(self): - m = create_simple_dynamic_abs_module(self.instance) - logging.info("module: %s", m) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - logging.info("context: %s", context) - - def test_add_scalar_new_abi(self): - m = create_add_scalar_module(self.instance) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - f = m.lookup_function("add_scalar") - finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) - result = finv(5, 6) - logging.info("result: %s", result) - self.assertEqual(result, 11) - - def test_unaligned_buffer_error(self): - buffer = memoryview(b"foobar") - with self.assertRaisesRegex(ValueError, "unaligned buffer"): - # One byte into a heap buffer will never satisfy alignment - # constraints. - iree.runtime.VmModule.wrap_buffer(self.instance, buffer[1:]) - - def test_from_buffer_unaligned_warns(self): - binary = compile_add_scalar() - # One byte into a heap buffer will never satisfy alignment - # constraints. - unaligned_binary = memoryview(b"1" + binary)[1:] - with self.assertWarnsRegex(UserWarning, - "Making copy of unaligned VmModule buffer"): - iree.runtime.VmModule.from_buffer(self.instance, unaligned_binary) - - def test_mmap_implicit_unmap(self): - binary = compile_add_scalar() - with tempfile.NamedTemporaryFile(delete=False) as tf: - tf.write(binary) - tf.flush() - vmfb_file_path = tf.name - - # Note that on Windows, an open file cannot be mapped. - try: - m = iree.runtime.VmModule.mmap(self.instance, vmfb_file_path) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - f = m.lookup_function("add_scalar") - finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) - result = finv(5, 6) - logging.info("result: %s", result) - self.assertEqual(result, 11) - - del finv - del f - del context - del m - gc.collect() - finally: - # On Windows, a mapped file cannot be deleted and this will fail if - # the mapping was not cleaned up properly. - os.unlink(vmfb_file_path) - - def test_mmap_destroy_callback(self): - binary = compile_add_scalar() - with tempfile.NamedTemporaryFile(delete=False) as tf: - tf.write(binary) - tf.flush() - vmfb_file_path = tf.name - - destroyed = [False] - - def on_destroy(): - print("on_destroy callback") - try: - os.unlink(vmfb_file_path) - except: - print("exception while unlinking mapped file") - traceback.print_exc(file=sys.stdout) - raise - destroyed[0] = True - - # Note that on Windows, an open file cannot be mapped. - m = iree.runtime.VmModule.mmap(self.instance, - vmfb_file_path, - destroy_callback=on_destroy) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - f = m.lookup_function("add_scalar") - finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) - result = finv(5, 6) - logging.info("result: %s", result) - self.assertEqual(result, 11) - - del finv - del f - del context - del m - gc.collect() - self.assertTrue(destroyed[0]) - - def test_synchronous_dynamic_shape_invoke_function_new_abi(self): - m = create_simple_dynamic_abs_module(self.instance) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - f = m.lookup_function("dynamic_abs") - finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) - arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32) - result = finv(arg0) - logging.info("result: %s", result) - np.testing.assert_allclose(result, [[1., 2.], [3., 4.]]) - - def test_synchronous_invoke_function_new_abi(self): - m = create_simple_static_mul_module(self.instance) - context = iree.runtime.VmContext(self.instance, - modules=[self.hal_module, m]) - f = m.lookup_function("simple_mul") - finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) - arg0 = np.array([1., 2., 3., 4.], dtype=np.float32) - arg1 = np.array([4., 5., 6., 7.], dtype=np.float32) - result = finv(arg0, arg1) - logging.info("result: %s", result) - np.testing.assert_allclose(result, [4., 10., 18., 28.]) + @classmethod + def setUp(self): + self.instance = iree.runtime.VmInstance() + self.device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER) + self.hal_module = iree.runtime.create_hal_module(self.instance, self.device) + + def test_context_id(self): + context1 = iree.runtime.VmContext(self.instance) + context2 = iree.runtime.VmContext(self.instance) + self.assertNotEqual(context2.context_id, context1.context_id) + + def test_module_basics(self): + m = create_simple_static_mul_module(self.instance) + f = m.lookup_function("simple_mul") + self.assertGreaterEqual(f.ordinal, 0) + notfound = m.lookup_function("notfound") + self.assertIs(notfound, None) + + def test_dynamic_module_context(self): + context = iree.runtime.VmContext(self.instance) + m = create_simple_static_mul_module(self.instance) + context.register_modules([self.hal_module, m]) + + def test_static_module_context(self): + m = create_simple_static_mul_module(self.instance) + logging.info("module: %s", m) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + logging.info("context: %s", context) + + def test_dynamic_shape_compile(self): + m = create_simple_dynamic_abs_module(self.instance) + logging.info("module: %s", m) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + logging.info("context: %s", context) + + def test_add_scalar_new_abi(self): + m = create_add_scalar_module(self.instance) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + f = m.lookup_function("add_scalar") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + result = finv(5, 6) + logging.info("result: %s", result) + self.assertEqual(result, 11) + + def test_unaligned_buffer_error(self): + buffer = memoryview(b"foobar") + with self.assertRaisesRegex(ValueError, "unaligned buffer"): + # One byte into a heap buffer will never satisfy alignment + # constraints. + iree.runtime.VmModule.wrap_buffer(self.instance, buffer[1:]) + + def test_from_buffer_unaligned_warns(self): + binary = compile_add_scalar() + # One byte into a heap buffer will never satisfy alignment + # constraints. + unaligned_binary = memoryview(b"1" + binary)[1:] + with self.assertWarnsRegex( + UserWarning, "Making copy of unaligned VmModule buffer" + ): + iree.runtime.VmModule.from_buffer(self.instance, unaligned_binary) + + def test_mmap_implicit_unmap(self): + binary = compile_add_scalar() + with tempfile.NamedTemporaryFile(delete=False) as tf: + tf.write(binary) + tf.flush() + vmfb_file_path = tf.name + + # Note that on Windows, an open file cannot be mapped. + try: + m = iree.runtime.VmModule.mmap(self.instance, vmfb_file_path) + context = iree.runtime.VmContext( + self.instance, modules=[self.hal_module, m] + ) + f = m.lookup_function("add_scalar") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + result = finv(5, 6) + logging.info("result: %s", result) + self.assertEqual(result, 11) + + del finv + del f + del context + del m + gc.collect() + finally: + # On Windows, a mapped file cannot be deleted and this will fail if + # the mapping was not cleaned up properly. + os.unlink(vmfb_file_path) + + def test_mmap_destroy_callback(self): + binary = compile_add_scalar() + with tempfile.NamedTemporaryFile(delete=False) as tf: + tf.write(binary) + tf.flush() + vmfb_file_path = tf.name + + destroyed = [False] + + def on_destroy(): + print("on_destroy callback") + try: + os.unlink(vmfb_file_path) + except: + print("exception while unlinking mapped file") + traceback.print_exc(file=sys.stdout) + raise + destroyed[0] = True + + # Note that on Windows, an open file cannot be mapped. + m = iree.runtime.VmModule.mmap( + self.instance, vmfb_file_path, destroy_callback=on_destroy + ) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + f = m.lookup_function("add_scalar") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + result = finv(5, 6) + logging.info("result: %s", result) + self.assertEqual(result, 11) + + del finv + del f + del context + del m + gc.collect() + self.assertTrue(destroyed[0]) + + def test_synchronous_dynamic_shape_invoke_function_new_abi(self): + m = create_simple_dynamic_abs_module(self.instance) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + f = m.lookup_function("dynamic_abs") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + arg0 = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32) + result = finv(arg0) + logging.info("result: %s", result) + np.testing.assert_allclose(result, [[1.0, 2.0], [3.0, 4.0]]) + + def test_synchronous_invoke_function_new_abi(self): + m = create_simple_static_mul_module(self.instance) + context = iree.runtime.VmContext(self.instance, modules=[self.hal_module, m]) + f = m.lookup_function("simple_mul") + finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None) + arg0 = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + arg1 = np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32) + result = finv(arg0, arg1) + logging.info("result: %s", result) + np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0]) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py index 3e0a951c6ff3..db8c7e8e669e 100644 --- a/runtime/bindings/python/tests/vm_types_test.py +++ b/runtime/bindings/python/tests/vm_types_test.py @@ -12,96 +12,99 @@ class VmTypesTest(unittest.TestCase): + @classmethod + def setUp(self): + # Ensures types are registered. + self.instance = rt.VmInstance() - @classmethod - def setUp(self): - # Ensures types are registered. - self.instance = rt.VmInstance() + def testRefProtocol(self): + lst1 = rt.VmVariantList(0) + ref = lst1.__iree_vm_ref__ + ref2 = lst1.ref + print(ref) + print(ref2) + self.assertEqual(ref, ref2) + self.assertNotEqual(ref, False) + lst2 = rt.VmVariantList.__iree_vm_cast__(ref) + print(lst2) + lst3 = ref.deref(rt.VmVariantList) + print(lst3) + self.assertEqual(lst1, lst2) + self.assertEqual(lst2, lst3) + self.assertNotEqual(lst1, False) + self.assertTrue(ref.isinstance(rt.VmVariantList)) - def testRefProtocol(self): - lst1 = rt.VmVariantList(0) - ref = lst1.__iree_vm_ref__ - ref2 = lst1.ref - print(ref) - print(ref2) - self.assertEqual(ref, ref2) - self.assertNotEqual(ref, False) - lst2 = rt.VmVariantList.__iree_vm_cast__(ref) - print(lst2) - lst3 = ref.deref(rt.VmVariantList) - print(lst3) - self.assertEqual(lst1, lst2) - self.assertEqual(lst2, lst3) - self.assertNotEqual(lst1, False) - self.assertTrue(ref.isinstance(rt.VmVariantList)) + def test_variant_list(self): + l = rt.VmVariantList(5) + logging.info("variant_list: %s", l) + self.assertEqual(l.size, 0) - def test_variant_list(self): - l = rt.VmVariantList(5) - logging.info("variant_list: %s", l) - self.assertEqual(l.size, 0) + def test_variant_list_i64(self): + l = rt.VmVariantList(5) + # Push a value that exceeds 32-bit range. + l.push_int(10 * 1000 * 1000 * 1000) + self.assertEqual(str(l), "") - def test_variant_list_i64(self): - l = rt.VmVariantList(5) - # Push a value that exceeds 32-bit range. - l.push_int(10 * 1000 * 1000 * 1000) - self.assertEqual(str(l), "") + def test_variant_list_buffers(self): + device = rt.get_device("local-sync") + ET = rt.HalElementType + for dt, et in ( + (np.int8, ET.SINT_8), # + (np.int16, ET.SINT_16), # + (np.int32, ET.SINT_32), # + (np.int64, ET.SINT_64), # + (np.uint8, ET.UINT_8), # + (np.uint16, ET.UINT_16), # + (np.uint32, ET.UINT_32), # + (np.uint64, ET.UINT_64), # + (np.float16, ET.FLOAT_16), # + (np.float32, ET.FLOAT_32), # + (np.float64, ET.FLOAT_64), # + (np.complex64, ET.COMPLEX_64), # + (np.complex128, ET.COMPLEX_128), + ): + lst = rt.VmVariantList(5) + ary1 = np.asarray([1, 2, 3, 4], dtype=dt) + bv1 = device.allocator.allocate_buffer_copy( + memory_type=rt.MemoryType.DEVICE_LOCAL, + allowed_usage=(rt.BufferUsage.DEFAULT | rt.BufferUsage.MAPPING), + buffer=ary1, + element_type=et, + ) + lst.push_ref(bv1) + ary2 = rt.DeviceArray( + device, + lst.get_as_object(0, rt.HalBufferView), + implicit_host_transfer=True, + ) + np.testing.assert_array_equal(ary1, ary2) + with self.assertRaises(IndexError): + lst.get_as_object(1, rt.HalBufferView) - def test_variant_list_buffers(self): - device = rt.get_device("local-sync") - ET = rt.HalElementType - for dt, et in ( - (np.int8, ET.SINT_8), # - (np.int16, ET.SINT_16), # - (np.int32, ET.SINT_32), # - (np.int64, ET.SINT_64), # - (np.uint8, ET.UINT_8), # - (np.uint16, ET.UINT_16), # - (np.uint32, ET.UINT_32), # - (np.uint64, ET.UINT_64), # - (np.float16, ET.FLOAT_16), # - (np.float32, ET.FLOAT_32), # - (np.float64, ET.FLOAT_64), # - (np.complex64, ET.COMPLEX_64), # - (np.complex128, ET.COMPLEX_128)): - lst = rt.VmVariantList(5) - ary1 = np.asarray([1, 2, 3, 4], dtype=dt) - bv1 = device.allocator.allocate_buffer_copy( - memory_type=rt.MemoryType.DEVICE_LOCAL, - allowed_usage=(rt.BufferUsage.DEFAULT | rt.BufferUsage.MAPPING), - buffer=ary1, - element_type=et) - lst.push_ref(bv1) - ary2 = rt.DeviceArray(device, - lst.get_as_object(0, rt.HalBufferView), - implicit_host_transfer=True) - np.testing.assert_array_equal(ary1, ary2) - with self.assertRaises(IndexError): - lst.get_as_object(1, rt.HalBufferView) + def test_variant_list_list(self): + lst1 = rt.VmVariantList(5) + lst2 = rt.VmVariantList(5) + lst1.push_list(lst2) + self.assertEqual("", str(lst1)) + lstout = lst1.get_as_list(0) + self.assertEqual("", str(lstout)) + with self.assertRaises(IndexError): + lst1.get_as_list(1) - def test_variant_list_list(self): - lst1 = rt.VmVariantList(5) - lst2 = rt.VmVariantList(5) - lst1.push_list(lst2) - self.assertEqual("", str(lst1)) - lstout = lst1.get_as_list(0) - self.assertEqual("", str(lstout)) - with self.assertRaises(IndexError): - lst1.get_as_list(1) + def test_vm_buffer(self): + b1 = rt.VmBuffer(10, alignment=0, mutable=True) + print(b1) + contents = memoryview(b1) + contents[0:] = b"0123456789" + self.assertEqual(bytes(b1), b"0123456789") - def test_vm_buffer(self): - b1 = rt.VmBuffer(10, alignment=0, mutable=True) - print(b1) - contents = memoryview(b1) - contents[0:] = b'0123456789' - self.assertEqual(bytes(b1), b'0123456789') - - def test_vm_buffer_ro(self): - b1 = rt.VmBuffer(10, alignment=16, mutable=False) - contents = memoryview(b1) - with self.assertRaises(TypeError): - contents[0:] = b'0123456789' + def test_vm_buffer_ro(self): + b1 = rt.VmBuffer(10, alignment=16, mutable=False) + contents = memoryview(b1) + with self.assertRaises(TypeError): + contents[0:] = b"0123456789" if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/runtime/lit.cfg.py b/runtime/lit.cfg.py index 77a049855cc0..cc344cbed8cd 100644 --- a/runtime/lit.cfg.py +++ b/runtime/lit.cfg.py @@ -20,13 +20,17 @@ config.test_format = lit.formats.ShTest(execute_external=True) # Forward all IREE environment variables passthrough_env_vars = ["VK_ICD_FILENAMES"] -config.environment.update({ - k: v - for k, v in os.environ.items() - if k.startswith("IREE_") or k in passthrough_env_vars -}) +config.environment.update( + { + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars + } +) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) diff --git a/runtime/setup.py b/runtime/setup.py index c827ef1457f1..9aee63996eac 100644 --- a/runtime/setup.py +++ b/runtime/setup.py @@ -45,23 +45,23 @@ def check_pip_version(): - from packaging import version - - # Pip versions < 22.0.3 default to out of tree builds, which is quite - # incompatible with what we do (and has other issues). Pip >= 22.0.4 - # removed this option entirely and are only in-tree builds. Since the - # old behavior can silently produce unworking installations, we aggressively - # suppress it. - try: - import pip - except ModuleNotFoundError: - # If pip not installed, we are obviously not trying to package via pip. - pass - else: - if (version.parse(pip.__version__) < version.parse("21.3")): - print("ERROR: pip version >= 21.3 required") - print("Upgrade: pip install pip --upgrade") - sys.exit(2) + from packaging import version + + # Pip versions < 22.0.3 default to out of tree builds, which is quite + # incompatible with what we do (and has other issues). Pip >= 22.0.4 + # removed this option entirely and are only in-tree builds. Since the + # old behavior can silently produce unworking installations, we aggressively + # suppress it. + try: + import pip + except ModuleNotFoundError: + # If pip not installed, we are obviously not trying to package via pip. + pass + else: + if version.parse(pip.__version__) < version.parse("21.3"): + print("ERROR: pip version >= 21.3 required") + print("Upgrade: pip install pip --upgrade") + sys.exit(2) check_pip_version() @@ -85,256 +85,271 @@ def check_pip_version(): IS_CONFIGURED = CONFIGURED_SOURCE_DIR[0] != "@" if IS_CONFIGURED: - IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR - IREE_BINARY_DIR = CONFIGURED_BINARY_DIR - print( - f"Running setup.py from build tree: " - f"SOURCE_DIR = {IREE_SOURCE_DIR} " - f"BINARY_DIR = {IREE_BINARY_DIR}", - file=sys.stderr) + IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR + IREE_BINARY_DIR = CONFIGURED_BINARY_DIR + print( + f"Running setup.py from build tree: " + f"SOURCE_DIR = {IREE_SOURCE_DIR} " + f"BINARY_DIR = {IREE_BINARY_DIR}", + file=sys.stderr, + ) else: - IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..") - IREE_BINARY_DIR = os.getenv("IREE_RUNTIME_API_CMAKE_BUILD_DIR") - if not IREE_BINARY_DIR: - # Note that setuptools always builds into a "build" directory that - # is a sibling of setup.py, so we just colonize a sub-directory of that - # by default. - IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build") - print( - f"Running setup.py from source tree: " - f"SOURCE_DIR = {IREE_SOURCE_DIR} " - f"BINARY_DIR = {IREE_BINARY_DIR}", - file=sys.stderr) + IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..") + IREE_BINARY_DIR = os.getenv("IREE_RUNTIME_API_CMAKE_BUILD_DIR") + if not IREE_BINARY_DIR: + # Note that setuptools always builds into a "build" directory that + # is a sibling of setup.py, so we just colonize a sub-directory of that + # by default. + IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build") + print( + f"Running setup.py from source tree: " + f"SOURCE_DIR = {IREE_SOURCE_DIR} " + f"BINARY_DIR = {IREE_BINARY_DIR}", + file=sys.stderr, + ) # Setup and get version information. VERSION_INFO_FILE = os.path.join(IREE_SOURCE_DIR, "version_info.json") def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) + with open(VERSION_INFO_FILE, "rt") as f: + return json.load(f) def find_git_versions(): - revisions = {} - try: - revisions["IREE"] = subprocess.check_output( - ["git", "rev-parse", "HEAD"], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - except subprocess.SubprocessError as e: - print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) - return revisions + revisions = {} + try: + revisions["IREE"] = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=IREE_SOURCE_DIR) + .decode("utf-8") + .strip() + ) + except subprocess.SubprocessError as e: + print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr) + return revisions def find_git_submodule_revision(submodule_path): - try: - data = subprocess.check_output(["git", "ls-tree", "HEAD", submodule_path], - cwd=IREE_SOURCE_DIR).decode("utf-8").strip() - columns = re.split("\\s+", data) - return columns[2] - except Exception as e: - print( - f"ERROR: Could not get submodule revision for {submodule_path}" - f" ({e})", - file=sys.stderr) - return "" + try: + data = ( + subprocess.check_output( + ["git", "ls-tree", "HEAD", submodule_path], cwd=IREE_SOURCE_DIR + ) + .decode("utf-8") + .strip() + ) + columns = re.split("\\s+", data) + return columns[2] + except Exception as e: + print( + f"ERROR: Could not get submodule revision for {submodule_path}" f" ({e})", + file=sys.stderr, + ) + return "" try: - version_info = load_version_info() + version_info = load_version_info() except FileNotFoundError: - print("version_info.json not found. Using defaults", file=sys.stderr) - version_info = {} + print("version_info.json not found. Using defaults", file=sys.stderr) + version_info = {} git_versions = find_git_versions() PACKAGE_SUFFIX = version_info.get("package-suffix") or "" PACKAGE_VERSION = version_info.get("package-version") if not PACKAGE_VERSION: - PACKAGE_VERSION = f"0.dev0+{git_versions.get('IREE') or '0'}" + PACKAGE_VERSION = f"0.dev0+{git_versions.get('IREE') or '0'}" def maybe_nuke_cmake_cache(): - # From run to run under pip, we can end up with different paths to ninja, - # which isn't great and will confuse cmake. Detect if the location of - # ninja changes and force a cache flush. - ninja_path = "" - try: - import ninja - except ModuleNotFoundError: - pass - else: - ninja_path = ninja.__file__ - expected_stamp_contents = f"{sys.executable}\n{ninja_path}" - - # In order to speed things up on CI and not rebuild everything, we nuke - # the CMakeCache.txt file if the path to the Python interpreter changed. - # Ideally, CMake would let us reconfigure this dynamically... but it does - # not (and gets very confused). - PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt") - if os.path.exists(PYTHON_STAMP_FILE): - with open(PYTHON_STAMP_FILE, "rt") as f: - actual_stamp_contents = f.read() - if actual_stamp_contents == expected_stamp_contents: - # All good. - return - - # Mismatch or not found. Clean it. - cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") - if os.path.exists(cmake_cache_file): - print("Removing CMakeCache.txt because Python version changed", - file=sys.stderr) - os.remove(cmake_cache_file) - - # Also clean the install directory. This avoids version specific pileups - # of binaries that can occur with repeated builds against different - # Python versions. - if os.path.exists(CMAKE_INSTALL_DIR_ABS): - print( - f"Removing CMake install dir because Python version changed: " - f"{CMAKE_INSTALL_DIR_ABS}", - file=sys.stderr) - shutil.rmtree(CMAKE_INSTALL_DIR_ABS) - - # And write. - with open(PYTHON_STAMP_FILE, "wt") as f: - f.write(expected_stamp_contents) + # From run to run under pip, we can end up with different paths to ninja, + # which isn't great and will confuse cmake. Detect if the location of + # ninja changes and force a cache flush. + ninja_path = "" + try: + import ninja + except ModuleNotFoundError: + pass + else: + ninja_path = ninja.__file__ + expected_stamp_contents = f"{sys.executable}\n{ninja_path}" + + # In order to speed things up on CI and not rebuild everything, we nuke + # the CMakeCache.txt file if the path to the Python interpreter changed. + # Ideally, CMake would let us reconfigure this dynamically... but it does + # not (and gets very confused). + PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt") + if os.path.exists(PYTHON_STAMP_FILE): + with open(PYTHON_STAMP_FILE, "rt") as f: + actual_stamp_contents = f.read() + if actual_stamp_contents == expected_stamp_contents: + # All good. + return + + # Mismatch or not found. Clean it. + cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") + if os.path.exists(cmake_cache_file): + print("Removing CMakeCache.txt because Python version changed", file=sys.stderr) + os.remove(cmake_cache_file) + + # Also clean the install directory. This avoids version specific pileups + # of binaries that can occur with repeated builds against different + # Python versions. + if os.path.exists(CMAKE_INSTALL_DIR_ABS): + print( + f"Removing CMake install dir because Python version changed: " + f"{CMAKE_INSTALL_DIR_ABS}", + file=sys.stderr, + ) + shutil.rmtree(CMAKE_INSTALL_DIR_ABS) + + # And write. + with open(PYTHON_STAMP_FILE, "wt") as f: + f.write(expected_stamp_contents) def get_env_cmake_option(name: str, default_value: bool = False) -> str: - svalue = os.getenv(name) - if not svalue: - svalue = "ON" if default_value else "OFF" - return f"-D{name}={svalue}" + svalue = os.getenv(name) + if not svalue: + svalue = "ON" if default_value else "OFF" + return f"-D{name}={svalue}" def add_env_cmake_setting(args, env_name: str, cmake_name=None) -> str: - svalue = os.getenv(env_name) - if svalue is not None: - if not cmake_name: - cmake_name = env_name - args.append(f"-D{cmake_name}={svalue}") + svalue = os.getenv(env_name) + if svalue is not None: + if not cmake_name: + cmake_name = env_name + args.append(f"-D{cmake_name}={svalue}") def prepare_installation(): - subprocess.check_call(["cmake", "--version"]) - version_py_content = generate_version_py() - print(f"Generating version.py:\n{version_py_content}", file=sys.stderr) - - if not IS_CONFIGURED: - # Build from source tree. - os.makedirs(IREE_BINARY_DIR, exist_ok=True) - maybe_nuke_cmake_cache() - print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr) - print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) - cfg = "Release" - cmake_args = [ - "-GNinja", - "--log-level=VERBOSE", - "-DIREE_BUILD_PYTHON_BINDINGS=ON", - "-DIREE_BUILD_COMPILER=OFF", - "-DIREE_BUILD_SAMPLES=OFF", - "-DIREE_BUILD_TESTS=OFF", - "-DPython3_EXECUTABLE={}".format(sys.executable), - "-DCMAKE_BUILD_TYPE={}".format(cfg), - get_env_cmake_option("IREE_HAL_DRIVER_CUDA"), - get_env_cmake_option("IREE_HAL_DRIVER_VULKAN", - "OFF" if platform.system() == "Darwin" else "ON"), - get_env_cmake_option("IREE_ENABLE_RUNTIME_TRACING"), - get_env_cmake_option("IREE_BUILD_TRACY"), - get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), + subprocess.check_call(["cmake", "--version"]) + version_py_content = generate_version_py() + print(f"Generating version.py:\n{version_py_content}", file=sys.stderr) + + if not IS_CONFIGURED: + # Build from source tree. + os.makedirs(IREE_BINARY_DIR, exist_ok=True) + maybe_nuke_cmake_cache() + print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr) + print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) + cfg = "Release" + cmake_args = [ + "-GNinja", + "--log-level=VERBOSE", + "-DIREE_BUILD_PYTHON_BINDINGS=ON", + "-DIREE_BUILD_COMPILER=OFF", + "-DIREE_BUILD_SAMPLES=OFF", + "-DIREE_BUILD_TESTS=OFF", + "-DPython3_EXECUTABLE={}".format(sys.executable), + "-DCMAKE_BUILD_TYPE={}".format(cfg), + get_env_cmake_option("IREE_HAL_DRIVER_CUDA"), + get_env_cmake_option( + "IREE_HAL_DRIVER_VULKAN", + "OFF" if platform.system() == "Darwin" else "ON", + ), + get_env_cmake_option("IREE_ENABLE_RUNTIME_TRACING"), + get_env_cmake_option("IREE_BUILD_TRACY"), + get_env_cmake_option("IREE_ENABLE_CPUINFO", "ON"), + ] + add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") + add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER_H") + + # These usually flow through the environment, but we add them explicitly + # so that they show clearly in logs (getting them wrong can have bad + # outcomes). + add_env_cmake_setting(cmake_args, "CMAKE_OSX_ARCHITECTURES") + add_env_cmake_setting( + cmake_args, "MACOSX_DEPLOYMENT_TARGET", "CMAKE_OSX_DEPLOYMENT_TARGET" + ) + + # Only do a from-scratch configure if not already configured. + cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") + if not os.path.exists(cmake_cache_file): + print(f"Configuring with: {cmake_args}", file=sys.stderr) + subprocess.check_call( + ["cmake", IREE_SOURCE_DIR] + cmake_args, cwd=IREE_BINARY_DIR + ) + else: + print(f"Not re-configuring (already configured)", file=sys.stderr) + + # Build. Since we have restricted to just the runtime, build everything + # so as to avoid fragility with more targeted selection criteria. + subprocess.check_call(["cmake", "--build", "."], cwd=IREE_BINARY_DIR) + print("Build complete.", file=sys.stderr) + + # Install the component we care about. + install_args = [ + "-DCMAKE_INSTALL_DO_STRIP=ON", + f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}/", + f"-DCMAKE_INSTALL_COMPONENT=IreePythonPackage-runtime", + "-P", + os.path.join(IREE_BINARY_DIR, "cmake_install.cmake"), ] - add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER") - add_env_cmake_setting(cmake_args, "IREE_TRACING_PROVIDER_H") + print(f"Installing with: {install_args}", file=sys.stderr) + subprocess.check_call(["cmake"] + install_args, cwd=IREE_BINARY_DIR) - # These usually flow through the environment, but we add them explicitly - # so that they show clearly in logs (getting them wrong can have bad - # outcomes). - add_env_cmake_setting(cmake_args, "CMAKE_OSX_ARCHITECTURES") - add_env_cmake_setting(cmake_args, "MACOSX_DEPLOYMENT_TARGET", - "CMAKE_OSX_DEPLOYMENT_TARGET") + # Write version.py directly into install dir. + version_py_file = os.path.join( + CMAKE_INSTALL_DIR_ABS, + "python_packages", + "iree_runtime", + "iree", + "runtime", + "version.py", + ) + os.makedirs(os.path.dirname(version_py_file), exist_ok=True) + with open(version_py_file, "wt") as f: + f.write(version_py_content) - # Only do a from-scratch configure if not already configured. - cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt") - if not os.path.exists(cmake_cache_file): - print(f"Configuring with: {cmake_args}", file=sys.stderr) - subprocess.check_call(["cmake", IREE_SOURCE_DIR] + cmake_args, - cwd=IREE_BINARY_DIR) - else: - print(f"Not re-configuring (already configured)", file=sys.stderr) - - # Build. Since we have restricted to just the runtime, build everything - # so as to avoid fragility with more targeted selection criteria. - subprocess.check_call(["cmake", "--build", "."], cwd=IREE_BINARY_DIR) - print("Build complete.", file=sys.stderr) - - # Install the component we care about. - install_args = [ - "-DCMAKE_INSTALL_DO_STRIP=ON", - f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}/", - f"-DCMAKE_INSTALL_COMPONENT=IreePythonPackage-runtime", - "-P", - os.path.join(IREE_BINARY_DIR, "cmake_install.cmake"), - ] - print(f"Installing with: {install_args}", file=sys.stderr) - subprocess.check_call(["cmake"] + install_args, cwd=IREE_BINARY_DIR) - - # Write version.py directly into install dir. - version_py_file = os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", - "iree_runtime", "iree", "runtime", - "version.py") - os.makedirs(os.path.dirname(version_py_file), exist_ok=True) - with open(version_py_file, "wt") as f: - f.write(version_py_content) - - print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) + print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr) class CMakeBuildPy(_build_py): - - def run(self): - # It is critical that the target directory contain all built extensions, - # or else setuptools will helpfully compile an empty binary for us - # (this is the **worst** possible thing it could do). We just copy - # everything. What's another hundred megs between friends? - target_dir = os.path.abspath(self.build_lib) - print(f"Building in target dir: {target_dir}", file=sys.stderr) - os.makedirs(target_dir, exist_ok=True) - print("Copying install to target.", file=sys.stderr) - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - shutil.copytree(os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", - "iree_runtime"), - target_dir, - symlinks=False) - print("Target populated.", file=sys.stderr) + def run(self): + # It is critical that the target directory contain all built extensions, + # or else setuptools will helpfully compile an empty binary for us + # (this is the **worst** possible thing it could do). We just copy + # everything. What's another hundred megs between friends? + target_dir = os.path.abspath(self.build_lib) + print(f"Building in target dir: {target_dir}", file=sys.stderr) + os.makedirs(target_dir, exist_ok=True) + print("Copying install to target.", file=sys.stderr) + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", "iree_runtime"), + target_dir, + symlinks=False, + ) + print("Target populated.", file=sys.stderr) class CustomBuild(_build): - - def run(self): - self.run_command("build_py") - self.run_command("build_ext") - self.run_command("build_scripts") + def run(self): + self.run_command("build_py") + self.run_command("build_ext") + self.run_command("build_scripts") class CMakeExtension(Extension): - - def __init__(self, name, sourcedir=""): - Extension.__init__(self, name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) class NoopBuildExtension(_build_ext): + def __init__(self, *args, **kwargs): + assert False - def __init__(self, *args, **kwargs): - assert False - - def build_extension(self, ext): - pass + def build_extension(self, ext): + pass def generate_version_py(): - return f"""# Auto-generated version info. + return f"""# Auto-generated version info. PACKAGE_SUFFIX = "{PACKAGE_SUFFIX}" VERSION = "{PACKAGE_VERSION}" REVISIONS = {json.dumps(git_versions)} @@ -343,24 +358,27 @@ def generate_version_py(): prepare_installation() -packages = find_namespace_packages(where=os.path.join(CMAKE_INSTALL_DIR_ABS, - "python_packages", - "iree_runtime"), - include=[ - "iree._runtime", - "iree.runtime", - "iree.runtime.*", - ]) +packages = find_namespace_packages( + where=os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages", "iree_runtime"), + include=[ + "iree._runtime", + "iree.runtime", + "iree.runtime.*", + ], +) print(f"Found runtime packages: {packages}") with open( - os.path.join(IREE_SOURCE_DIR, "runtime", "bindings", "python", "iree", - "runtime", "README.md"), "rt") as f: - README = f.read() + os.path.join( + IREE_SOURCE_DIR, "runtime", "bindings", "python", "iree", "runtime", "README.md" + ), + "rt", +) as f: + README = f.read() custom_package_suffix = os.getenv("IREE_RUNTIME_CUSTOM_PACKAGE_SUFFIX") if not custom_package_suffix: - custom_package_suffix = "" + custom_package_suffix = "" setup( name=f"iree-runtime{custom_package_suffix}{PACKAGE_SUFFIX}", diff --git a/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py b/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py index 1f739c3e27b3..2d96ac9e479b 100644 --- a/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py +++ b/runtime/src/iree/tooling/testdata/npy/generate_npy_files.py @@ -8,42 +8,42 @@ import numpy as np # zero bytes -with open('empty.npy', 'wb') as f: - f.flush() +with open("empty.npy", "wb") as f: + f.flush() # single array -with open('single.npy', 'wb') as f: - np.save(f, np.array([1.1, 2.2, 3.3], dtype=np.float32)) +with open("single.npy", "wb") as f: + np.save(f, np.array([1.1, 2.2, 3.3], dtype=np.float32)) # multiple arrays -with open('multiple.npy', 'wb') as f: - np.save(f, np.array([1.1, 2.2, 3.3], dtype=np.float32)) - np.save(f, np.array([[0, 1], [2, 3]], dtype=np.int32)) - np.save(f, np.array(42, dtype=np.int32)) +with open("multiple.npy", "wb") as f: + np.save(f, np.array([1.1, 2.2, 3.3], dtype=np.float32)) + np.save(f, np.array([[0, 1], [2, 3]], dtype=np.int32)) + np.save(f, np.array(42, dtype=np.int32)) # arrays of various shapes -with open('array_shapes.npy', 'wb') as f: - np.save(f, np.array(1, dtype=np.int8)) - np.save(f, np.array([], dtype=np.int8)) - np.save(f, np.array([1], dtype=np.int8)) - np.save(f, np.array([[1], [2]], dtype=np.int8)) - np.save(f, np.array([[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.int8)) - np.save(f, np.array([[1, 2], [3, 4]], dtype=np.int8)) - np.save(f, np.array([[[1], [2]], [[3], [4]]], dtype=np.int8)) +with open("array_shapes.npy", "wb") as f: + np.save(f, np.array(1, dtype=np.int8)) + np.save(f, np.array([], dtype=np.int8)) + np.save(f, np.array([1], dtype=np.int8)) + np.save(f, np.array([[1], [2]], dtype=np.int8)) + np.save(f, np.array([[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.int8)) + np.save(f, np.array([[1, 2], [3, 4]], dtype=np.int8)) + np.save(f, np.array([[[1], [2]], [[3], [4]]], dtype=np.int8)) # arrays of various types -with open('array_types.npy', 'wb') as f: - np.save(f, np.array([True, False], dtype=np.bool_)) - np.save(f, np.array([-1, 1], dtype=np.int8)) - np.save(f, np.array([-20000, 20000], dtype=np.int16)) - np.save(f, np.array([-2000000, 2000000], dtype=np.int32)) - np.save(f, np.array([-20000000000, 20000000000], dtype=np.int64)) - np.save(f, np.array([1, 255], dtype=np.uint8)) - np.save(f, np.array([1, 65535], dtype=np.uint16)) - np.save(f, np.array([1, 4294967295], dtype=np.uint32)) - np.save(f, np.array([1, 18446744073709551615], dtype=np.uint64)) - np.save(f, np.array([-1.1, 1.1], dtype=np.float16)) - np.save(f, np.array([-1.1, 1.1], dtype=np.float32)) - np.save(f, np.array([-1.1, 1.1], dtype=np.float64)) - np.save(f, np.array([1 + 5j, 2 + 6j], dtype=np.complex64)) - np.save(f, np.array([1 + 5j, 2 + 6j], dtype=np.complex128)) +with open("array_types.npy", "wb") as f: + np.save(f, np.array([True, False], dtype=np.bool_)) + np.save(f, np.array([-1, 1], dtype=np.int8)) + np.save(f, np.array([-20000, 20000], dtype=np.int16)) + np.save(f, np.array([-2000000, 2000000], dtype=np.int32)) + np.save(f, np.array([-20000000000, 20000000000], dtype=np.int64)) + np.save(f, np.array([1, 255], dtype=np.uint8)) + np.save(f, np.array([1, 65535], dtype=np.uint16)) + np.save(f, np.array([1, 4294967295], dtype=np.uint32)) + np.save(f, np.array([1, 18446744073709551615], dtype=np.uint64)) + np.save(f, np.array([-1.1, 1.1], dtype=np.float16)) + np.save(f, np.array([-1.1, 1.1], dtype=np.float32)) + np.save(f, np.array([-1.1, 1.1], dtype=np.float64)) + np.save(f, np.array([1 + 5j, 2 + 6j], dtype=np.complex64)) + np.save(f, np.array([1 + 5j, 2 + 6j], dtype=np.complex128)) diff --git a/samples/colab/test_notebooks.py b/samples/colab/test_notebooks.py index a714bae7a3eb..8dbbba268be2 100755 --- a/samples/colab/test_notebooks.py +++ b/samples/colab/test_notebooks.py @@ -24,36 +24,37 @@ class ColabNotebookTests(unittest.TestCase): - """Tests running all Colab notebooks in this directory.""" + """Tests running all Colab notebooks in this directory.""" - @classmethod - def generateTests(cls): - repo_root = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - script_path = os.path.join(repo_root, - "build_tools/testing/run_python_notebook.sh") + @classmethod + def generateTests(cls): + repo_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + script_path = os.path.join( + repo_root, "build_tools/testing/run_python_notebook.sh" + ) - # Create a test case for each notebook in this folder. - notebooks_path = os.path.join(repo_root, "samples/colab/") - for notebook_path in glob.glob(notebooks_path + "*.ipynb"): - notebook_name = os.path.basename(notebook_path) + # Create a test case for each notebook in this folder. + notebooks_path = os.path.join(repo_root, "samples/colab/") + for notebook_path in glob.glob(notebooks_path + "*.ipynb"): + notebook_name = os.path.basename(notebook_path) - def unit_test(self, notebook_path=notebook_path): + def unit_test(self, notebook_path=notebook_path): + completed_process = subprocess.run([script_path, notebook_path]) + self.assertEqual(completed_process.returncode, 0) - completed_process = subprocess.run([script_path, notebook_path]) - self.assertEqual(completed_process.returncode, 0) + if notebook_name in NOTEBOOKS_TO_SKIP: + unit_test = unittest.skip("Skip requested")(unit_test) + elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL: + unit_test = unittest.expectedFailure(unit_test) - if notebook_name in NOTEBOOKS_TO_SKIP: - unit_test = unittest.skip("Skip requested")(unit_test) - elif notebook_name in NOTEBOOKS_EXPECTED_TO_FAIL: - unit_test = unittest.expectedFailure(unit_test) - - # Add 'unit_test' to this class, so the test runner runs it. - unit_test.__name__ = f"test_{notebook_name}" - setattr(cls, unit_test.__name__, unit_test) + # Add 'unit_test' to this class, so the test runner runs it. + unit_test.__name__ = f"test_{notebook_name}" + setattr(cls, unit_test.__name__, unit_test) if __name__ == "__main__": - ColabNotebookTests.generateTests() - logging.basicConfig(level=logging.DEBUG) - unittest.main() + ColabNotebookTests.generateTests() + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/samples/compiler_plugins/simple_io_sample/test/run_mock.py b/samples/compiler_plugins/simple_io_sample/test/run_mock.py index db9454a89cdd..942c05105790 100644 --- a/samples/compiler_plugins/simple_io_sample/test/run_mock.py +++ b/samples/compiler_plugins/simple_io_sample/test/run_mock.py @@ -17,22 +17,20 @@ print(f"--- Loading {input_file}") with open(input_file, "rb") as f: - vmfb_contents = f.read() + vmfb_contents = f.read() def create_simple_io_module(): + class SimpleIO: + def __init__(self, iface): + ... - class SimpleIO: + def print_impl(self): + print("+++ HELLO FROM SIMPLE_IO") - def __init__(self, iface): - ... - - def print_impl(self): - print("+++ HELLO FROM SIMPLE_IO") - - iface = rt.PyModuleInterface("simple_io", SimpleIO) - iface.export("print", "0v_v", SimpleIO.print_impl) - return iface.create() + iface = rt.PyModuleInterface("simple_io", SimpleIO) + iface.export("print", "0v_v", SimpleIO.print_impl) + return iface.create() config = rt.Config("local-sync") diff --git a/samples/lit.cfg.py b/samples/lit.cfg.py index 77a049855cc0..cc344cbed8cd 100644 --- a/samples/lit.cfg.py +++ b/samples/lit.cfg.py @@ -20,13 +20,17 @@ config.test_format = lit.formats.ShTest(execute_external=True) # Forward all IREE environment variables passthrough_env_vars = ["VK_ICD_FILENAMES"] -config.environment.update({ - k: v - for k, v in os.environ.items() - if k.startswith("IREE_") or k in passthrough_env_vars -}) +config.environment.update( + { + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars + } +) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) diff --git a/samples/py_custom_module/decode_secret_message.py b/samples/py_custom_module/decode_secret_message.py index 42682b4e90d7..e28d3e3f3848 100644 --- a/samples/py_custom_module/decode_secret_message.py +++ b/samples/py_custom_module/decode_secret_message.py @@ -29,113 +29,113 @@ def create_tokenizer_module(): - """Creates a module which defines some custom methods for decoding.""" - - class Detokenizer: - - def __init__(self, iface): - # Any class state here is maintained per-context. - self.start_of_text = True - self.start_of_sentence = True - - def reset(self): - self.start_of_text = True - self.start_of_sentence = True - - def accumtokens(self, ids_tensor_ref, token_list_ref): - # TODO: This little dance to turn BufferView refs into real arrays... is not good. - ids_bv = ids_tensor_ref.deref(rt.HalBufferView) - ids_array = ids_bv.map().asarray( - ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type)) - token_list = token_list_ref.deref(rt.VmVariantList) - for index in range(ids_array.shape[0]): - token_id = ids_array[index] - token = TOKEN_TABLE[token_id] - - # And this dance to make a buffer... is also not good. - # A real implementation would just map the constant memory, etc. - buffer = rt.VmBuffer(len(token)) - buffer_view = memoryview(buffer) - buffer_view[:] = token - token_list.push_ref(buffer) - return ids_array.shape[0] - - def jointokens(self, token_list_ref): - # The world's dumbest detokenizer. Ideally, the state tracking - # would be in a module private type that got retained and passed - # back through. - token_list = token_list_ref.deref(rt.VmVariantList) - text = bytearray() - for i in range(len(token_list)): - item = bytes(token_list.get_as_object(i, rt.VmBuffer)) - if item == b".": - text.extend(b".") - self.start_of_sentence = True - else: - if not self.start_of_text: - text.extend(b" ") - else: - self.start_of_text = False - if self.start_of_sentence: - text.extend(item[0:1].decode("utf-8").upper().encode("utf-8")) - text.extend(item[1:]) - self.start_of_sentence = False - else: - text.extend(item) - - # TODO: This dance to make a buffer is still bad. - results = rt.VmBuffer(len(text)) - memoryview(results)[:] = text - return results.ref - - iface = rt.PyModuleInterface("detokenizer", Detokenizer) - iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens) - iface.export("jointokens", "0r_r", Detokenizer.jointokens) - iface.export("reset", "0v_v", Detokenizer.reset) - return iface.create() + """Creates a module which defines some custom methods for decoding.""" + + class Detokenizer: + def __init__(self, iface): + # Any class state here is maintained per-context. + self.start_of_text = True + self.start_of_sentence = True + + def reset(self): + self.start_of_text = True + self.start_of_sentence = True + + def accumtokens(self, ids_tensor_ref, token_list_ref): + # TODO: This little dance to turn BufferView refs into real arrays... is not good. + ids_bv = ids_tensor_ref.deref(rt.HalBufferView) + ids_array = ids_bv.map().asarray( + ids_bv.shape, rt.HalElementType.map_to_dtype(ids_bv.element_type) + ) + token_list = token_list_ref.deref(rt.VmVariantList) + for index in range(ids_array.shape[0]): + token_id = ids_array[index] + token = TOKEN_TABLE[token_id] + + # And this dance to make a buffer... is also not good. + # A real implementation would just map the constant memory, etc. + buffer = rt.VmBuffer(len(token)) + buffer_view = memoryview(buffer) + buffer_view[:] = token + token_list.push_ref(buffer) + return ids_array.shape[0] + + def jointokens(self, token_list_ref): + # The world's dumbest detokenizer. Ideally, the state tracking + # would be in a module private type that got retained and passed + # back through. + token_list = token_list_ref.deref(rt.VmVariantList) + text = bytearray() + for i in range(len(token_list)): + item = bytes(token_list.get_as_object(i, rt.VmBuffer)) + if item == b".": + text.extend(b".") + self.start_of_sentence = True + else: + if not self.start_of_text: + text.extend(b" ") + else: + self.start_of_text = False + if self.start_of_sentence: + text.extend(item[0:1].decode("utf-8").upper().encode("utf-8")) + text.extend(item[1:]) + self.start_of_sentence = False + else: + text.extend(item) + + # TODO: This dance to make a buffer is still bad. + results = rt.VmBuffer(len(text)) + memoryview(results)[:] = text + return results.ref + + iface = rt.PyModuleInterface("detokenizer", Detokenizer) + iface.export("accumtokens", "0rr_i", Detokenizer.accumtokens) + iface.export("jointokens", "0r_r", Detokenizer.jointokens) + iface.export("reset", "0v_v", Detokenizer.reset) + return iface.create() def compile(): - return compiler.tools.compile_file(os.path.join(os.path.dirname(__file__), - "main.mlir"), - target_backends=["vmvx"]) + return compiler.tools.compile_file( + os.path.join(os.path.dirname(__file__), "main.mlir"), target_backends=["vmvx"] + ) def main(): - print("Compiling...") - vmfb_contents = compile() - print("Decoding secret message...") - config = rt.Config("local-sync") - main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents) - modules = config.default_vm_modules + ( - create_tokenizer_module(), - main_module, - ) - context = rt.SystemContext(vm_modules=modules, config=config) - - # First message. - count = context.modules.main.add_tokens( - np.asarray([5, 10, 11, 1, 3, 4, 5, 7, 12], dtype=np.int32)) - print(f"ADDED {count} tokens") - - # Second message. - count = context.modules.main.add_tokens(np.asarray([2, 13], dtype=np.int32)) - print(f"ADDED {count} tokens") - - text = bytes(context.modules.main.get_results().deref(rt.VmBuffer)) - print(f"RESULTS: {text}") - - assert text == b"So long and thanks for all so fish. Bye now" - - # Reset and decode some more. - context.modules.main.reset() - count = context.modules.main.add_tokens( - np.asarray([0, 14, 12], dtype=np.int32)) - print(f"ADDED {count} tokens") - text = bytes(context.modules.main.get_results().deref(rt.VmBuffer)) - print(f"RESULTS: {text}") - assert text == b"Hi there." + print("Compiling...") + vmfb_contents = compile() + print("Decoding secret message...") + config = rt.Config("local-sync") + main_module = rt.VmModule.from_flatbuffer(config.vm_instance, vmfb_contents) + modules = config.default_vm_modules + ( + create_tokenizer_module(), + main_module, + ) + context = rt.SystemContext(vm_modules=modules, config=config) + + # First message. + count = context.modules.main.add_tokens( + np.asarray([5, 10, 11, 1, 3, 4, 5, 7, 12], dtype=np.int32) + ) + print(f"ADDED {count} tokens") + + # Second message. + count = context.modules.main.add_tokens(np.asarray([2, 13], dtype=np.int32)) + print(f"ADDED {count} tokens") + + text = bytes(context.modules.main.get_results().deref(rt.VmBuffer)) + print(f"RESULTS: {text}") + + assert text == b"So long and thanks for all so fish. Bye now" + + # Reset and decode some more. + context.modules.main.reset() + count = context.modules.main.add_tokens(np.asarray([0, 14, 12], dtype=np.int32)) + print(f"ADDED {count} tokens") + text = bytes(context.modules.main.get_results().deref(rt.VmBuffer)) + print(f"RESULTS: {text}") + assert text == b"Hi there." if __name__ == "__main__": - main() + main() diff --git a/samples/vision_inference/convert_image.py b/samples/vision_inference/convert_image.py index 6253ccaf6f03..ee1ef440e80e 100644 --- a/samples/vision_inference/convert_image.py +++ b/samples/vision_inference/convert_image.py @@ -16,12 +16,12 @@ # Read image from stdin (in any format supported by PIL). with Image.open(sys.stdin.buffer) as color_img: - # Resize to 28x28, matching what the program expects. - resized_color_img = color_img.resize((28, 28)) - # Convert to grayscale. - grayscale_img = resized_color_img.convert('L') - # Rescale to a float32 in range [0.0, 1.0]. - grayscale_arr = np.array(grayscale_img) - grayscale_arr_f32 = grayscale_arr.astype(np.float32) / 255.0 - # Write bytes back out to stdout. - sys.stdout.buffer.write(grayscale_arr_f32.tobytes()) + # Resize to 28x28, matching what the program expects. + resized_color_img = color_img.resize((28, 28)) + # Convert to grayscale. + grayscale_img = resized_color_img.convert("L") + # Rescale to a float32 in range [0.0, 1.0]. + grayscale_arr = np.array(grayscale_img) + grayscale_arr_f32 = grayscale_arr.astype(np.float32) / 255.0 + # Write bytes back out to stdout. + sys.stdout.buffer.write(grayscale_arr_f32.tobytes()) diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 3fde6e68e8fa..53dbbb4d3843 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -22,47 +22,47 @@ # as this also includes accumulator-specific types like i32. @enum.unique class MatrixElemTypeId(enum.Enum): - I8 = "i8" - I32 = "i32" - F32 = "f32" - F16 = "f16" + I8 = "i8" + I32 = "i32" + F32 = "f32" + F16 = "f16" # Enumerates of the collections of shapes that we can generate tests for. # The values are the accepted values for the --shapes= flag. @enum.unique class ShapesId(enum.Enum): - SMALL = "small" - LARGE = "large" - GPU_LARGE = "gpu_large" - GPU_LARGE_ALIGNED = "gpu_large_aligned" + SMALL = "small" + LARGE = "large" + GPU_LARGE = "gpu_large" + GPU_LARGE_ALIGNED = "gpu_large_aligned" # Enumerates of the collections of compilation info that we can generate tests # for. The values are the accepted values for the --compilation_info= flag. @enum.unique class CompilationInfoId(enum.Enum): - NONE = "" - LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt" - LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore" - LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync" - SPIRVVectorizeMali = "SPIRVVectorizeMali" - SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA" + NONE = "" + LLVMGPUMatmulSimt = "LLVMGPUMatmulSimt" + LLVMGPUMatmulTensorCore = "LLVMGPUMatmulTensorCore" + LLVMGPUMatmulTensorCoreMmaSync = "LLVMGPUMatmulTensorCoreMmaSync" + SPIRVVectorizeMali = "SPIRVVectorizeMali" + SPIRVVectorizeNVIDIA = "SPIRVVectorizeNVIDIA" # Enumerates ways to construct MLIR tensor types. @enum.unique class Dynamicity(enum.Enum): - DYNAMIC = "dynamic" # Use '?' everywhere. Example: tensor. - STATIC = "static" # Use fixed values everywhere. Example: tensor<4x6xf32>. - MIXED = "mixed" # Randomly mix '?' and values. Example: tensor. + DYNAMIC = "dynamic" # Use '?' everywhere. Example: tensor. + STATIC = "static" # Use fixed values everywhere. Example: tensor<4x6xf32>. + MIXED = "mixed" # Randomly mix '?' and values. Example: tensor. # Enumerates ways to initialize matrix buffer contents. @enum.unique class MatrixGenerator(enum.Enum): - ZERO = "zero" # Fill with zeros - RANDOM = "random" # Fill with (deterministic) pseudorandom values. + ZERO = "zero" # Fill with zeros + RANDOM = "random" # Fill with (deterministic) pseudorandom values. # Describes the shape of a matrix multiplication in the usual convention: @@ -72,204 +72,214 @@ class MatrixGenerator(enum.Enum): # (C = A * B). @dataclasses.dataclass class TestShape: - m: int - k: int - n: int - accumulate: bool + m: int + k: int + n: int + accumulate: bool # Describes how to construct compilation info for the testcase. @dataclasses.dataclass class CompilationInfo: - # Lowering Config - tile_sizes: typing.List[typing.List[int]] - # Translation Info - dispatch_lowering_pass_pipeline: str - workload_per_wg: typing.List[int] - software_pipeline_depth: int - # Compilation info - workgroup_size: typing.List[int] - - # Prints the workgroup size as 'index' types - def workgroup_size_str(self): - return "[" + ", ".join([f"{size} : index" for size in self.workgroup_size - ]) + "]" + # Lowering Config + tile_sizes: typing.List[typing.List[int]] + # Translation Info + dispatch_lowering_pass_pipeline: str + workload_per_wg: typing.List[int] + software_pipeline_depth: int + # Compilation info + workgroup_size: typing.List[int] + + # Prints the workgroup size as 'index' types + def workgroup_size_str(self): + return ( + "[" + ", ".join([f"{size} : index" for size in self.workgroup_size]) + "]" + ) # Returns the list of TestShape's to use for the collection of shapes # identified by shapes_id. def get_test_shapes(shapes_id: ShapesId): - # Notes: - # 1. Be conservative in adding more shapes, as that can increase both the - # build and execution latency of tests. The build latency is nearly the - # same for all shapes, while execution latency grows cubicly i.e. - # linearly with m*k*n. - # 2. Some shapes are commented out: they used to be tested but have been - # disabled to improve the trade-off between test coverage and build - # latency. - if shapes_id == ShapesId.SMALL: - return [ - # square matrices. Start by the simplest case of 1x1x1. - TestShape(m=1, k=1, n=1, accumulate=True), - TestShape(m=1, k=1, n=1, accumulate=False), - # test 9x9x9 because as many kernel M0/K0/N0 dims are equal to 8, - # this will often be the smallest value that exercises something above - # the kernel's size. - TestShape(m=9, k=9, n=9, accumulate=True), - # rectangular matrices. - # >= 2x differences between M/N/K dims may exercise tiling corner cases - # not exercised by nearly-square matrices. - TestShape(m=6, k=13, n=3, accumulate=True), - TestShape(m=15, k=37, n=7, accumulate=False), - TestShape(m=81, k=19, n=41, accumulate=True), - # shapes involving vectors (i.e. most rectangular cases) - # This is particularly relevant because we have dedicated kernels for - # the matrix*vector / vector*matrix case. - TestShape(m=1, k=10, n=10, accumulate=True), # vector*matrix - TestShape(m=1, k=10, n=10, accumulate=False), # vector*matrix - TestShape(m=10, k=1, n=10, accumulate=True), # outer-product - TestShape(m=10, k=10, n=1, accumulate=True), # matrix*vector - TestShape(m=10, k=10, n=1, accumulate=False), # matrix*vector - ] - if shapes_id == ShapesId.LARGE: - return [ - # some random large sizes - TestShape(m=123, k=456, n=789, accumulate=True), - TestShape(m=654, k=321, n=234, accumulate=False), - # shapes involving vectors (i.e. most rectangular cases) - TestShape(m=1, k=1000, n=1000, accumulate=True), # large vector*matrix - TestShape(m=1000, k=1000, n=1, accumulate=True), # large matrix*vector - TestShape(m=1000, k=1000, n=1, accumulate=False), # large matrix*vector - # Be conservative in adding larger shapes. They can result in - # high latency tests. If you have to, consider splitting them - # out in a way that constrains the latency impact, e.g. by - # running on fewer backends/drivers or with fewer generators - # (see get_test_generators). - ] - if shapes_id == ShapesId.GPU_LARGE_ALIGNED: - return [ - TestShape(m=256, k=128, n=512, accumulate=True), - TestShape(m=256, k=128, n=512, accumulate=False), - ] - if shapes_id == ShapesId.GPU_LARGE: - return [ - # unaligned cases. - TestShape(m=457, k=330, n=512, accumulate=False), - TestShape(m=457, k=330, n=514, accumulate=False), - TestShape(m=438, k=330, n=514, accumulate=False), - TestShape(m=540, k=332, n=516, accumulate=False), - TestShape(m=1000, k=4, n=512, accumulate=False), - TestShape(m=4, k=1000, n=512, accumulate=False), - TestShape(m=512, k=1000, n=4, accumulate=False), - TestShape(m=512, k=128, n=500, accumulate=False), - TestShape(m=457, k=160, n=512, accumulate=False), - TestShape(m=512, k=330, n=512, accumulate=False), - ] - - raise ValueError(shapes_id) + # Notes: + # 1. Be conservative in adding more shapes, as that can increase both the + # build and execution latency of tests. The build latency is nearly the + # same for all shapes, while execution latency grows cubicly i.e. + # linearly with m*k*n. + # 2. Some shapes are commented out: they used to be tested but have been + # disabled to improve the trade-off between test coverage and build + # latency. + if shapes_id == ShapesId.SMALL: + return [ + # square matrices. Start by the simplest case of 1x1x1. + TestShape(m=1, k=1, n=1, accumulate=True), + TestShape(m=1, k=1, n=1, accumulate=False), + # test 9x9x9 because as many kernel M0/K0/N0 dims are equal to 8, + # this will often be the smallest value that exercises something above + # the kernel's size. + TestShape(m=9, k=9, n=9, accumulate=True), + # rectangular matrices. + # >= 2x differences between M/N/K dims may exercise tiling corner cases + # not exercised by nearly-square matrices. + TestShape(m=6, k=13, n=3, accumulate=True), + TestShape(m=15, k=37, n=7, accumulate=False), + TestShape(m=81, k=19, n=41, accumulate=True), + # shapes involving vectors (i.e. most rectangular cases) + # This is particularly relevant because we have dedicated kernels for + # the matrix*vector / vector*matrix case. + TestShape(m=1, k=10, n=10, accumulate=True), # vector*matrix + TestShape(m=1, k=10, n=10, accumulate=False), # vector*matrix + TestShape(m=10, k=1, n=10, accumulate=True), # outer-product + TestShape(m=10, k=10, n=1, accumulate=True), # matrix*vector + TestShape(m=10, k=10, n=1, accumulate=False), # matrix*vector + ] + if shapes_id == ShapesId.LARGE: + return [ + # some random large sizes + TestShape(m=123, k=456, n=789, accumulate=True), + TestShape(m=654, k=321, n=234, accumulate=False), + # shapes involving vectors (i.e. most rectangular cases) + TestShape(m=1, k=1000, n=1000, accumulate=True), # large vector*matrix + TestShape(m=1000, k=1000, n=1, accumulate=True), # large matrix*vector + TestShape(m=1000, k=1000, n=1, accumulate=False), # large matrix*vector + # Be conservative in adding larger shapes. They can result in + # high latency tests. If you have to, consider splitting them + # out in a way that constrains the latency impact, e.g. by + # running on fewer backends/drivers or with fewer generators + # (see get_test_generators). + ] + if shapes_id == ShapesId.GPU_LARGE_ALIGNED: + return [ + TestShape(m=256, k=128, n=512, accumulate=True), + TestShape(m=256, k=128, n=512, accumulate=False), + ] + if shapes_id == ShapesId.GPU_LARGE: + return [ + # unaligned cases. + TestShape(m=457, k=330, n=512, accumulate=False), + TestShape(m=457, k=330, n=514, accumulate=False), + TestShape(m=438, k=330, n=514, accumulate=False), + TestShape(m=540, k=332, n=516, accumulate=False), + TestShape(m=1000, k=4, n=512, accumulate=False), + TestShape(m=4, k=1000, n=512, accumulate=False), + TestShape(m=512, k=1000, n=4, accumulate=False), + TestShape(m=512, k=128, n=500, accumulate=False), + TestShape(m=457, k=160, n=512, accumulate=False), + TestShape(m=512, k=330, n=512, accumulate=False), + ] + + raise ValueError(shapes_id) # Returns the list of Dynamicity's to use for the collection of shapes # identified by shapes_id. def get_dynamicities(shapes_id: ShapesId): - if shapes_id == ShapesId.GPU_LARGE or shapes_id == ShapesId.GPU_LARGE_ALIGNED: - return [ - Dynamicity.STATIC, - ] - else: - return [ - Dynamicity.DYNAMIC, - Dynamicity.STATIC, - ] - raise ValueError(shapes_id) + if shapes_id == ShapesId.GPU_LARGE or shapes_id == ShapesId.GPU_LARGE_ALIGNED: + return [ + Dynamicity.STATIC, + ] + else: + return [ + Dynamicity.DYNAMIC, + Dynamicity.STATIC, + ] + raise ValueError(shapes_id) @dataclasses.dataclass class TileWorkgroupSizePair: - tile_size: typing.List[typing.List[int]] - workgroup_size: typing.List[int] + tile_size: typing.List[typing.List[int]] + workgroup_size: typing.List[int] # Constructs a TileWorkgroupSizePair for SPIRV Targets enforcing the constraints between # the workgroup_size and tile size -def get_spirv_tile_workgroup_size_pair(workgroup_size, - t_tile_k, - t_tile_m=4, - t_tile_n=4): - x, y, z = workgroup_size - wg_tile_m = y * t_tile_m - wg_tile_n = x * t_tile_n - return TileWorkgroupSizePair( - [[wg_tile_m, wg_tile_n], [t_tile_m, t_tile_n], [0, 0, t_tile_k]], - workgroup_size) +def get_spirv_tile_workgroup_size_pair( + workgroup_size, t_tile_k, t_tile_m=4, t_tile_n=4 +): + x, y, z = workgroup_size + wg_tile_m = y * t_tile_m + wg_tile_n = x * t_tile_n + return TileWorkgroupSizePair( + [[wg_tile_m, wg_tile_n], [t_tile_m, t_tile_n], [0, 0, t_tile_k]], workgroup_size + ) # Returns all the TileWorkgroupSizePairs for a given SPIRV Target def get_all_spirv_tile_workgroup_size_pairs(t_tile_k): - tile_workgroup_size_pairs = [ - get_spirv_tile_workgroup_size_pair([32, 8, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([16, 8, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([64, 2, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([8, 8, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([16, 2, 1], t_tile_k), - get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), - ] - return tile_workgroup_size_pairs + tile_workgroup_size_pairs = [ + get_spirv_tile_workgroup_size_pair([32, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([16, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([64, 2, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([8, 8, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([16, 2, 1], t_tile_k), + get_spirv_tile_workgroup_size_pair([32, 1, 1], t_tile_k), + ] + return tile_workgroup_size_pairs # Returns the list of CompilationInfo's to use for the CompilationInfoId. def get_test_compilation_infos( compilation_info_id: CompilationInfoId, lhs_rhs_type: MatrixElemTypeId ) -> typing.List[typing.Optional[CompilationInfo]]: - if compilation_info_id == CompilationInfoId.NONE: - return [None] - if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt: - tile_workgroup_size_pairs = [ - TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]), - TileWorkgroupSizePair([[128, 64, 8]], [16, 8, 1]), - TileWorkgroupSizePair([[16, 256, 32]], [64, 2, 1]), - TileWorkgroupSizePair([[8, 32, 32]], [8, 8, 1]), - TileWorkgroupSizePair([[8, 128, 4]], [32, 1, 1]), - TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]), - TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]), - ] - elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA: - tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(32) - elif compilation_info_id == CompilationInfoId.SPIRVVectorizeMali: - tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(4) - elif compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCore or compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCoreMmaSync: - tile_workgroup_size_pairs = [] - ## WarpShape = 2x2 - tile_workgroup_size_pairs.append( - TileWorkgroupSizePair([[32, 32, 16]], [64, 2, 1])) - tile_workgroup_size_pairs.append( - TileWorkgroupSizePair([[64, 64, 64]], [64, 2, 1])) - - ## WarpShape = 4x1 - tile_workgroup_size_pairs.append( - TileWorkgroupSizePair([[32, 32, 32]], [64, 1, 1])) - - ## WarpShape = 2x2 with large tiles using larger Shared Memory capacity. - if lhs_rhs_type == MatrixElemTypeId.F16: - tile_workgroup_size_pairs.append( - TileWorkgroupSizePair([[128, 128, 64]], [64, 2, 1])) - elif lhs_rhs_type == MatrixElemTypeId.F32: - tile_workgroup_size_pairs.append( - TileWorkgroupSizePair([[128, 128, 16]], [64, 2, 1])) - - compilation_infos = [] - for tile_workgroup_size_pair in tile_workgroup_size_pairs: - compilation_infos.append( - CompilationInfo( - tile_sizes=tile_workgroup_size_pair.tile_size, - dispatch_lowering_pass_pipeline=compilation_info_id.value, - workload_per_wg=[ - a for a in reversed(tile_workgroup_size_pair.tile_size[0:2]) - ], - workgroup_size=tile_workgroup_size_pair.workgroup_size, - software_pipeline_depth=3)) - return compilation_infos + if compilation_info_id == CompilationInfoId.NONE: + return [None] + if compilation_info_id == CompilationInfoId.LLVMGPUMatmulSimt: + tile_workgroup_size_pairs = [ + TileWorkgroupSizePair([[32, 128, 32]], [32, 8, 1]), + TileWorkgroupSizePair([[128, 64, 8]], [16, 8, 1]), + TileWorkgroupSizePair([[16, 256, 32]], [64, 2, 1]), + TileWorkgroupSizePair([[8, 32, 32]], [8, 8, 1]), + TileWorkgroupSizePair([[8, 128, 4]], [32, 1, 1]), + TileWorkgroupSizePair([[16, 64, 4]], [16, 2, 1]), + TileWorkgroupSizePair([[1, 128, 8]], [32, 1, 1]), + ] + elif compilation_info_id == CompilationInfoId.SPIRVVectorizeNVIDIA: + tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(32) + elif compilation_info_id == CompilationInfoId.SPIRVVectorizeMali: + tile_workgroup_size_pairs = get_all_spirv_tile_workgroup_size_pairs(4) + elif ( + compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCore + or compilation_info_id == CompilationInfoId.LLVMGPUMatmulTensorCoreMmaSync + ): + tile_workgroup_size_pairs = [] + ## WarpShape = 2x2 + tile_workgroup_size_pairs.append( + TileWorkgroupSizePair([[32, 32, 16]], [64, 2, 1]) + ) + tile_workgroup_size_pairs.append( + TileWorkgroupSizePair([[64, 64, 64]], [64, 2, 1]) + ) + + ## WarpShape = 4x1 + tile_workgroup_size_pairs.append( + TileWorkgroupSizePair([[32, 32, 32]], [64, 1, 1]) + ) + + ## WarpShape = 2x2 with large tiles using larger Shared Memory capacity. + if lhs_rhs_type == MatrixElemTypeId.F16: + tile_workgroup_size_pairs.append( + TileWorkgroupSizePair([[128, 128, 64]], [64, 2, 1]) + ) + elif lhs_rhs_type == MatrixElemTypeId.F32: + tile_workgroup_size_pairs.append( + TileWorkgroupSizePair([[128, 128, 16]], [64, 2, 1]) + ) + + compilation_infos = [] + for tile_workgroup_size_pair in tile_workgroup_size_pairs: + compilation_infos.append( + CompilationInfo( + tile_sizes=tile_workgroup_size_pair.tile_size, + dispatch_lowering_pass_pipeline=compilation_info_id.value, + workload_per_wg=[ + a for a in reversed(tile_workgroup_size_pair.tile_size[0:2]) + ], + workgroup_size=tile_workgroup_size_pair.workgroup_size, + software_pipeline_depth=3, + ) + ) + return compilation_infos # Intentionally fixed seed! We want full reproducibility here, both across runs @@ -283,29 +293,29 @@ def get_test_compilation_infos( # such as 'tensor'. None means a dynamic size, similar to '?' in MLIR. @dataclasses.dataclass class DimSize: - value: typing.Optional[int] + value: typing.Optional[int] # Generates a compile-time MLIR size value, i.e. either a fixed positive integer # or None (which maps to MLIR '?') depending on dynamicity. def shape_dim(x: int, dynamicity: Dynamicity): - if dynamicity == Dynamicity.DYNAMIC: - return DimSize(None) - elif dynamicity == Dynamicity.STATIC: - return DimSize(x) - else: - raise ValueError(dynamicity) + if dynamicity == Dynamicity.DYNAMIC: + return DimSize(None) + elif dynamicity == Dynamicity.STATIC: + return DimSize(x) + else: + raise ValueError(dynamicity) # Stringification used for generating MLIR types, e.g. tensor. def int_or_question_mark(s: DimSize): - return s.value or "?" + return s.value or "?" # Stringification used for generating alphanumeric identifiers, e.g. # func.func @somefunction_DYNxDYNxf32, where we can't use "?" characters. def int_or_DYN(s: DimSize): - return s.value or "DYN" + return s.value or "DYN" # Describes the fully resolved shape dimensions of all 3 input matrices, @@ -315,27 +325,27 @@ def int_or_DYN(s: DimSize): # These string values are used to generate MLIR function names and tensor shapes. @dataclasses.dataclass class TestInputMatricesShapes: - lhs_rows: DimSize - lhs_cols: DimSize - rhs_rows: DimSize - rhs_cols: DimSize - acc_rows: DimSize - acc_cols: DimSize + lhs_rows: DimSize + lhs_cols: DimSize + rhs_rows: DimSize + rhs_cols: DimSize + acc_rows: DimSize + acc_cols: DimSize # Helper for generate_function. Generates TestInputMatricesShapes, i.e. # converts from the runtime shape dimensions in TestShape and given dynamicity to # the set of shapes to be used in a test function's input tensors. def generate_shapes(shape: TestShape, dynamicity: Dynamicity): - shapes = TestInputMatricesShapes( - lhs_rows=shape_dim(shape.m, dynamicity), - lhs_cols=shape_dim(shape.k, dynamicity), - rhs_rows=shape_dim(shape.k, dynamicity), - rhs_cols=shape_dim(shape.n, dynamicity), - acc_rows=shape_dim(shape.m, dynamicity), - acc_cols=shape_dim(shape.n, dynamicity), - ) - return shapes + shapes = TestInputMatricesShapes( + lhs_rows=shape_dim(shape.m, dynamicity), + lhs_cols=shape_dim(shape.k, dynamicity), + rhs_rows=shape_dim(shape.k, dynamicity), + rhs_cols=shape_dim(shape.n, dynamicity), + acc_rows=shape_dim(shape.m, dynamicity), + acc_cols=shape_dim(shape.n, dynamicity), + ) + return shapes # Helper for generate_function. @@ -345,33 +355,36 @@ def generate_function_name( acc_type: MatrixElemTypeId, shapes: TestInputMatricesShapes, accumulate: bool, - compilation_info: typing.Optional[CompilationInfo] = None): - input_t = lhs_rhs_type.value - acc_t = acc_type.value - lhs_m = int_or_DYN(shapes.lhs_rows) - lhs_k = int_or_DYN(shapes.lhs_cols) - rhs_k = int_or_DYN(shapes.rhs_rows) - rhs_n = int_or_DYN(shapes.rhs_cols) - acc_m = int_or_DYN(shapes.acc_rows) - acc_n = int_or_DYN(shapes.acc_cols) - - info = "" - if compilation_info: - tile_sizes = list(itertools.chain(*compilation_info.tile_sizes)) - tile_workgroup_key = "_".join([ - str(a) for a in tile_sizes - ]) + "_" + "_".join([str(a) for a in compilation_info.workgroup_size]) - info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}" - - matmul_kind = "matmul_accumulate" if accumulate else "matmul" - return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}" + compilation_info: typing.Optional[CompilationInfo] = None, +): + input_t = lhs_rhs_type.value + acc_t = acc_type.value + lhs_m = int_or_DYN(shapes.lhs_rows) + lhs_k = int_or_DYN(shapes.lhs_cols) + rhs_k = int_or_DYN(shapes.rhs_rows) + rhs_n = int_or_DYN(shapes.rhs_cols) + acc_m = int_or_DYN(shapes.acc_rows) + acc_n = int_or_DYN(shapes.acc_cols) + + info = "" + if compilation_info: + tile_sizes = list(itertools.chain(*compilation_info.tile_sizes)) + tile_workgroup_key = ( + "_".join([str(a) for a in tile_sizes]) + + "_" + + "_".join([str(a) for a in compilation_info.workgroup_size]) + ) + info = f"_for_{compilation_info.dispatch_lowering_pass_pipeline}_{tile_workgroup_key}" + + matmul_kind = "matmul_accumulate" if accumulate else "matmul" + return f"{matmul_kind}_{lhs_m}x{lhs_k}x{input_t}_times_{rhs_k}x{rhs_n}x{input_t}_into_{acc_m}x{acc_n}x{acc_t}{info}" # Represents a generated test function. @dataclasses.dataclass class MLIRFunction: - name: str - definition: str + name: str + definition: str # Generates a test function in the generated MLIR code. @@ -382,76 +395,90 @@ def generate_function( acc_type: MatrixElemTypeId, shape: TestShape, dynamicity: Dynamicity, - compilation_info: typing.Optional[CompilationInfo] = None): - shapes = generate_shapes(shape, dynamicity) - func_name = generate_function_name(lhs_rhs_type, acc_type, shapes, - shape.accumulate, compilation_info) - lhs_m = int_or_question_mark(shapes.lhs_rows) - lhs_k = int_or_question_mark(shapes.lhs_cols) - rhs_k = int_or_question_mark(shapes.rhs_rows) - rhs_n = int_or_question_mark(shapes.rhs_cols) - acc_m = int_or_question_mark(shapes.acc_rows) - acc_n = int_or_question_mark(shapes.acc_cols) - lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>" - rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>" - acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>" - - # Compilation info is optional; prints empty string by default. - func_definition = "" - compilation_info_attr = "" - if compilation_info: - if "SPIRV" in compilation_info.dispatch_lowering_pass_pipeline == "SPIRVVectorizeMali": - dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize" - elif compilation_info.dispatch_lowering_pass_pipeline == "SPIRVVectorizeNVIDIA": - # TODO: change to test SPIRVMatmulPromoteVectorize too - dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize" + compilation_info: typing.Optional[CompilationInfo] = None, +): + shapes = generate_shapes(shape, dynamicity) + func_name = generate_function_name( + lhs_rhs_type, acc_type, shapes, shape.accumulate, compilation_info + ) + lhs_m = int_or_question_mark(shapes.lhs_rows) + lhs_k = int_or_question_mark(shapes.lhs_cols) + rhs_k = int_or_question_mark(shapes.rhs_rows) + rhs_n = int_or_question_mark(shapes.rhs_cols) + acc_m = int_or_question_mark(shapes.acc_rows) + acc_n = int_or_question_mark(shapes.acc_cols) + lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>" + rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>" + acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>" + + # Compilation info is optional; prints empty string by default. + func_definition = "" + compilation_info_attr = "" + if compilation_info: + if ( + "SPIRV" + in compilation_info.dispatch_lowering_pass_pipeline + == "SPIRVVectorizeMali" + ): + dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize" + elif compilation_info.dispatch_lowering_pass_pipeline == "SPIRVVectorizeNVIDIA": + # TODO: change to test SPIRVMatmulPromoteVectorize too + dispatch_lowering_pass_pipeline = "SPIRVBaseVectorize" + else: + dispatch_lowering_pass_pipeline = ( + compilation_info.dispatch_lowering_pass_pipeline + ) + compilation_info_string = ( + f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n" + f" lowering_config = ,\n" + f" translation_info = <{dispatch_lowering_pass_pipeline}\n" + f" pipeline_depth = {compilation_info.software_pipeline_depth}>,\n" + f" workgroup_size = {compilation_info.workgroup_size_str()}>\n" + ) + compilation_info_attr = ( + f"{{compilation_info = #compilation{generate_function.compilation_index}}} " + ) + func_definition = func_definition + compilation_info_string + generate_function.compilation_index += 1 + + if shape.accumulate: + func_definition = func_definition + ( + f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n" + f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" + f" return %result: {acc_tensor_type}\n" + f"}}\n" + ) else: - dispatch_lowering_pass_pipeline = compilation_info.dispatch_lowering_pass_pipeline - compilation_info_string = ( - f"#compilation{generate_function.compilation_index} = #iree_codegen.compilation_info<\n" - f" lowering_config = ,\n" - f" translation_info = <{dispatch_lowering_pass_pipeline}\n" - f" pipeline_depth = {compilation_info.software_pipeline_depth}>,\n" - f" workgroup_size = {compilation_info.workgroup_size_str()}>\n") - compilation_info_attr = f"{{compilation_info = #compilation{generate_function.compilation_index}}} " - func_definition = func_definition + compilation_info_string - generate_function.compilation_index += 1 - - if shape.accumulate: - func_definition = func_definition + ( - f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n" - f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" - f" return %result: {acc_tensor_type}\n" - f"}}\n") - else: - literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0" - acc_dyn_sizes = [] - if acc_m == "?": - func_definition = func_definition + ( - f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n" - f" %c0 = arith.constant 0 : index\n" - f" %c1 = arith.constant 1 : index\n" - f" %acc_dim0 = tensor.dim %lhs, %c0 : {lhs_tensor_type}\n" - f" %acc_dim1 = tensor.dim %rhs, %c1 : {rhs_tensor_type}\n" - f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n" - f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n" - f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n" - f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" - f" return %result: {acc_tensor_type}\n" - f"}}\n") - else: - func_definition = func_definition + ( - f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n" - f" %init_acc = tensor.empty() : {acc_tensor_type}\n" - f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n" - f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n" - f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" - f" return %result: {acc_tensor_type}\n" - f"}}\n") - return MLIRFunction( - name=func_name, - definition=func_definition, - ) + literal_zero_for_acc_type = "0.0" if "f" in acc_type.value else "0" + acc_dyn_sizes = [] + if acc_m == "?": + func_definition = func_definition + ( + f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n" + f" %c0 = arith.constant 0 : index\n" + f" %c1 = arith.constant 1 : index\n" + f" %acc_dim0 = tensor.dim %lhs, %c0 : {lhs_tensor_type}\n" + f" %acc_dim1 = tensor.dim %rhs, %c1 : {rhs_tensor_type}\n" + f" %init_acc = tensor.empty(%acc_dim0, %acc_dim1) : {acc_tensor_type}\n" + f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n" + f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n" + f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" + f" return %result: {acc_tensor_type}\n" + f"}}\n" + ) + else: + func_definition = func_definition + ( + f"func.func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}) -> {acc_tensor_type} {{\n" + f" %init_acc = tensor.empty() : {acc_tensor_type}\n" + f" %c0_acc_type = arith.constant {literal_zero_for_acc_type}: {acc_type.value}\n" + f" %acc = linalg.fill ins(%c0_acc_type : {acc_type.value}) outs(%init_acc : {acc_tensor_type}) -> {acc_tensor_type}\n" + f" %result = linalg.matmul {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" + f" return %result: {acc_tensor_type}\n" + f"}}\n" + ) + return MLIRFunction( + name=func_name, + definition=func_definition, + ) # Counter for producing unique compilation info attrs @@ -465,172 +492,197 @@ def generate_function( def contents_generator_tag(generator: MatrixGenerator): - if generator == MatrixGenerator.ZERO: - return "" - elif generator == MatrixGenerator.RANDOM: - global pseudorandom_generator_seed - pseudorandom_generator_seed = pseudorandom_generator_seed + 1 - return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}" - else: - raise ValueError(generator) + if generator == MatrixGenerator.ZERO: + return "" + elif generator == MatrixGenerator.RANDOM: + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed + 1 + return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}" + else: + raise ValueError(generator) # Generate a matrix function argument in the output trace, as a dictionary # to be passed to yaml.dump. -def generate_trace_matrix_arg(matrix_shape: list, - element_type: MatrixElemTypeId, - generator: MatrixGenerator): - result = { - "type": "hal.buffer_view", - "shape": matrix_shape, - "element_type": element_type.value, - } - generator_tag = contents_generator_tag(generator) - if generator_tag: - result["contents_generator"] = generator_tag - return result +def generate_trace_matrix_arg( + matrix_shape: list, element_type: MatrixElemTypeId, generator: MatrixGenerator +): + result = { + "type": "hal.buffer_view", + "shape": matrix_shape, + "element_type": element_type.value, + } + generator_tag = contents_generator_tag(generator) + if generator_tag: + result["contents_generator"] = generator_tag + return result # Generates the output trace for a testcase i.e. a single test function call, # as a dictionary to be passed to yaml.dump. -def generate_trace(func_name: str, lhs_rhs_type: MatrixElemTypeId, - acc_type: MatrixElemTypeId, shape: TestShape): - args = [ - generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type, - MatrixGenerator.RANDOM), - generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type, - MatrixGenerator.RANDOM), - ] - if shape.accumulate: - args.append( - generate_trace_matrix_arg([shape.m, shape.n], acc_type, - MatrixGenerator.RANDOM)) - - result = generate_trace_matrix_arg([shape.m, shape.n], acc_type, - MatrixGenerator.ZERO) - return { - "type": "call", - "function": "module." + func_name, - "args": args, - "results": [result], - } +def generate_trace( + func_name: str, + lhs_rhs_type: MatrixElemTypeId, + acc_type: MatrixElemTypeId, + shape: TestShape, +): + args = [ + generate_trace_matrix_arg( + [shape.m, shape.k], lhs_rhs_type, MatrixGenerator.RANDOM + ), + generate_trace_matrix_arg( + [shape.k, shape.n], lhs_rhs_type, MatrixGenerator.RANDOM + ), + ] + if shape.accumulate: + args.append( + generate_trace_matrix_arg( + [shape.m, shape.n], acc_type, MatrixGenerator.RANDOM + ) + ) + + result = generate_trace_matrix_arg( + [shape.m, shape.n], acc_type, MatrixGenerator.ZERO + ) + return { + "type": "call", + "function": "module." + func_name, + "args": args, + "results": [result], + } # Generates all output files' contents as strings. -def generate(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId, - shapes_id: ShapesId, compilation_info_id: CompilationInfoId): - function_definitions = {} - traces = [] - - for compilation_info in get_test_compilation_infos(compilation_info_id, - lhs_rhs_type): - for shape in get_test_shapes(shapes_id): - for dynamicity in get_dynamicities(shapes_id): - function = generate_function(lhs_rhs_type, acc_type, shape, dynamicity, - compilation_info) - # Different testcases may differ only by runtime parameters but - # share the same code. For example, dynamic-shapes testcases - # share the same code involing tensor even though the runtime - # value in the trace are different. That's why we append conditionally - # to traces, but unconditionally to function_definitions. - if function.name not in function_definitions: - function_definitions[function.name] = function.definition - traces.append( - generate_trace(function.name, lhs_rhs_type, acc_type, shape)) - - return (function_definitions, traces) +def generate( + lhs_rhs_type: MatrixElemTypeId, + acc_type: MatrixElemTypeId, + shapes_id: ShapesId, + compilation_info_id: CompilationInfoId, +): + function_definitions = {} + traces = [] + + for compilation_info in get_test_compilation_infos( + compilation_info_id, lhs_rhs_type + ): + for shape in get_test_shapes(shapes_id): + for dynamicity in get_dynamicities(shapes_id): + function = generate_function( + lhs_rhs_type, acc_type, shape, dynamicity, compilation_info + ) + # Different testcases may differ only by runtime parameters but + # share the same code. For example, dynamic-shapes testcases + # share the same code involing tensor even though the runtime + # value in the trace are different. That's why we append conditionally + # to traces, but unconditionally to function_definitions. + if function.name not in function_definitions: + function_definitions[function.name] = function.definition + traces.append( + generate_trace(function.name, lhs_rhs_type, acc_type, shape) + ) + + return (function_definitions, traces) def parse_arguments(): - parser = argparse.ArgumentParser(description="Generator of e2e matmul tests") - parser.add_argument("--output_code", - type=str, - help="Path of output .mlir file", - required=True) - parser.add_argument("--output_trace", - type=str, - help="Path of output .yaml trace file", - required=True) - parser.add_argument("--lhs_rhs_type", - type=str, - choices=["i8", "f32", "f16"], - help="Numeric type of input matrices", - required=True) - parser.add_argument("--shapes", - type=str, - choices=[s.value for s in ShapesId], - help="Collection of matrix shapes to test", - required=True) - parser.add_argument("--compilation_info", - type=str, - choices=[i.value for i in CompilationInfoId], - help="Collection of compilation info setups to test", - default="", - required=False) - - parser.add_argument( - "--module_path", - type=str, - help= - "Module path (typically .vmfb) to be referenced in the output trace. Should match the output path of the iree-compile command generating the module.", - required=True) - parser.add_argument( - "--requirements", - type=str, - help= - "Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.", - required=False) - return parser.parse_args() + parser = argparse.ArgumentParser(description="Generator of e2e matmul tests") + parser.add_argument( + "--output_code", type=str, help="Path of output .mlir file", required=True + ) + parser.add_argument( + "--output_trace", + type=str, + help="Path of output .yaml trace file", + required=True, + ) + parser.add_argument( + "--lhs_rhs_type", + type=str, + choices=["i8", "f32", "f16"], + help="Numeric type of input matrices", + required=True, + ) + parser.add_argument( + "--shapes", + type=str, + choices=[s.value for s in ShapesId], + help="Collection of matrix shapes to test", + required=True, + ) + parser.add_argument( + "--compilation_info", + type=str, + choices=[i.value for i in CompilationInfoId], + help="Collection of compilation info setups to test", + default="", + required=False, + ) + + parser.add_argument( + "--module_path", + type=str, + help="Module path (typically .vmfb) to be referenced in the output trace. Should match the output path of the iree-compile command generating the module.", + required=True, + ) + parser.add_argument( + "--requirements", + type=str, + help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.", + required=False, + ) + return parser.parse_args() def write_code_file(function_definitions, filename): - with open(filename, "w") as file: - for funcname in function_definitions: - file.write(function_definitions[funcname] + "\n") + with open(filename, "w") as file: + for funcname in function_definitions: + file.write(function_definitions[funcname] + "\n") def write_trace_file(traces, filename, module_path, requirements): - yaml_documents = [ - { - "type": "context_load", - }, - { - "type": "module_load", - "module": { - "name": "hal", - "type": "builtin", - } - }, - { - "type": "module_load", - "module": { - "name": "module", - "type": "bytecode", - "path": os.path.relpath(module_path, os.path.dirname(filename)) - } - }, - ] - if requirements: - yaml_documents.append({ - "type": "requirements", - "target_features": [req.lstrip("+") for req in requirements.split(",")], - }) - - for trace in traces: - yaml_documents.append(trace) - - dumped_yaml = yaml.dump_all(yaml_documents) - - # TODO: This regex substitution is a hack as I couldn't figure how to have - # PyYAML dump our custom contents_generator into the desired format, e.g. - # contents_generator: !tag:iree:fully_specified_pseudorandom 368 - # Someone with better knowledge of YAML is welcome to fix this, possibly by - # changing that format if that's appropriate! So long as the e2e_matmul tests - # pass. - processed_yaml = re.sub(r"'(![^']*)'", "\\1", dumped_yaml) - - with open(filename, "w") as file: - file.write(processed_yaml) + yaml_documents = [ + { + "type": "context_load", + }, + { + "type": "module_load", + "module": { + "name": "hal", + "type": "builtin", + }, + }, + { + "type": "module_load", + "module": { + "name": "module", + "type": "bytecode", + "path": os.path.relpath(module_path, os.path.dirname(filename)), + }, + }, + ] + if requirements: + yaml_documents.append( + { + "type": "requirements", + "target_features": [req.lstrip("+") for req in requirements.split(",")], + } + ) + + for trace in traces: + yaml_documents.append(trace) + + dumped_yaml = yaml.dump_all(yaml_documents) + + # TODO: This regex substitution is a hack as I couldn't figure how to have + # PyYAML dump our custom contents_generator into the desired format, e.g. + # contents_generator: !tag:iree:fully_specified_pseudorandom 368 + # Someone with better knowledge of YAML is welcome to fix this, possibly by + # changing that format if that's appropriate! So long as the e2e_matmul tests + # pass. + processed_yaml = re.sub(r"'(![^']*)'", "\\1", dumped_yaml) + + with open(filename, "w") as file: + file.write(processed_yaml) # For now, the accumulator type can always be inferred from the input LHS/RHS @@ -638,24 +690,24 @@ def write_trace_file(traces, filename, module_path, requirements): # where the same input types are used with different accumulator types, e.g. # f16 inputs with both f16 and f32 accumulator. def infer_acc_type(lhs_rhs_type: MatrixElemTypeId): - if lhs_rhs_type == MatrixElemTypeId.I8: - return MatrixElemTypeId.I32 - else: - return lhs_rhs_type + if lhs_rhs_type == MatrixElemTypeId.I8: + return MatrixElemTypeId.I32 + else: + return lhs_rhs_type def main(args): - lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type) - acc_type = infer_acc_type(lhs_rhs_type) - shapes_id = ShapesId(args.shapes) - compilation_info_id = CompilationInfoId(args.compilation_info) - (function_definitions, traces) = generate(lhs_rhs_type, acc_type, shapes_id, - compilation_info_id) + lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type) + acc_type = infer_acc_type(lhs_rhs_type) + shapes_id = ShapesId(args.shapes) + compilation_info_id = CompilationInfoId(args.compilation_info) + (function_definitions, traces) = generate( + lhs_rhs_type, acc_type, shapes_id, compilation_info_id + ) - write_code_file(function_definitions, args.output_code) - write_trace_file(traces, args.output_trace, args.module_path, - args.requirements) + write_code_file(function_definitions, args.output_code) + write_trace_file(traces, args.output_trace, args.module_path, args.requirements) if __name__ == "__main__": - main(parse_arguments()) + main(parse_arguments()) diff --git a/tests/e2e/models/mnist_train_test/datasets.py b/tests/e2e/models/mnist_train_test/datasets.py index 8fcc36dcb82f..2caa7c53c9f6 100644 --- a/tests/e2e/models/mnist_train_test/datasets.py +++ b/tests/e2e/models/mnist_train_test/datasets.py @@ -18,67 +18,70 @@ def _download(url, filename): - """Download a url to a file in the JAX data temp directory.""" - if not path.exists(_DATA): - os.makedirs(_DATA) - out_file = path.join(_DATA, filename) - if not path.isfile(out_file): - urllib.request.urlretrieve(url, out_file) - print("downloaded {} to {}".format(url, _DATA)) + """Download a url to a file in the JAX data temp directory.""" + if not path.exists(_DATA): + os.makedirs(_DATA) + out_file = path.join(_DATA, filename) + if not path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + print("downloaded {} to {}".format(url, _DATA)) def _partial_flatten(x): - """Flatten all but the first dimension of an ndarray.""" - return np.reshape(x, (x.shape[0], -1)) + """Flatten all but the first dimension of an ndarray.""" + return np.reshape(x, (x.shape[0], -1)) def _one_hot(x, k, dtype=np.float32): - """Create a one-hot encoding of x of size k.""" - return np.array(x[:, None] == np.arange(k), dtype) + """Create a one-hot encoding of x of size k.""" + return np.array(x[:, None] == np.arange(k), dtype) def mnist_raw(): - """Download and parse the raw MNIST dataset.""" - # CVDF mirror of http://yann.lecun.com/exdb/mnist/ - base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" - - def parse_labels(filename): - with gzip.open(filename, "rb") as fh: - _ = struct.unpack(">II", fh.read(8)) - return np.array(array.array("B", fh.read()), dtype=np.uint8) - - def parse_images(filename): - with gzip.open(filename, "rb") as fh: - _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) - return np.array(array.array("B", fh.read()), - dtype=np.uint8).reshape(num_data, rows, cols) - - for filename in [ - "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", - "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz" - ]: - _download(base_url + filename, filename) - - train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) - train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) - test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) - test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) - - return train_images, train_labels, test_images, test_labels + """Download and parse the raw MNIST dataset.""" + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + + def parse_labels(filename): + with gzip.open(filename, "rb") as fh: + _ = struct.unpack(">II", fh.read(8)) + return np.array(array.array("B", fh.read()), dtype=np.uint8) + + def parse_images(filename): + with gzip.open(filename, "rb") as fh: + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) + return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( + num_data, rows, cols + ) + + for filename in [ + "train-images-idx3-ubyte.gz", + "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + "t10k-labels-idx1-ubyte.gz", + ]: + _download(base_url + filename, filename) + + train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) + train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) + test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) + + return train_images, train_labels, test_images, test_labels def mnist(permute_train=False): - """Download, parse and process MNIST data to unit scale and one-hot labels.""" - train_images, train_labels, test_images, test_labels = mnist_raw() + """Download, parse and process MNIST data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = mnist_raw() - train_images = _partial_flatten(train_images) / np.float32(255.) - test_images = _partial_flatten(test_images) / np.float32(255.) - train_labels = _one_hot(train_labels, 10) - test_labels = _one_hot(test_labels, 10) + train_images = _partial_flatten(train_images) / np.float32(255.0) + test_images = _partial_flatten(test_images) / np.float32(255.0) + train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) - if permute_train: - perm = np.random.RandomState(0).permutation(train_images.shape[0]) - train_images = train_images[perm] - train_labels = train_labels[perm] + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] - return train_images, train_labels, test_images, test_labels + return train_images, train_labels, test_images, test_labels diff --git a/tests/e2e/models/mnist_train_test/generate_test_data.py b/tests/e2e/models/mnist_train_test/generate_test_data.py index a5e5d6a28fb2..52dda817a6a2 100644 --- a/tests/e2e/models/mnist_train_test/generate_test_data.py +++ b/tests/e2e/models/mnist_train_test/generate_test_data.py @@ -23,189 +23,203 @@ def get_example_batch(): - batch_size = 128 - train_images, train_labels, _, _ = datasets.mnist() - num_train = train_images.shape[0] - num_complete_batches, leftover = divmod(num_train, batch_size) - num_batches = num_complete_batches + bool(leftover) + batch_size = 128 + train_images, train_labels, _, _ = datasets.mnist() + num_train = train_images.shape[0] + num_complete_batches, leftover = divmod(num_train, batch_size) + num_batches = num_complete_batches + bool(leftover) - def data_stream(): - rng = npr.RandomState(0) - while True: - perm = rng.permutation(num_train) - for i in range(num_batches): - batch_idx = perm[i * batch_size:(i + 1) * batch_size] - yield train_images[batch_idx], train_labels[batch_idx] + def data_stream(): + rng = npr.RandomState(0) + while True: + perm = rng.permutation(num_train) + for i in range(num_batches): + batch_idx = perm[i * batch_size : (i + 1) * batch_size] + yield train_images[batch_idx], train_labels[batch_idx] - batches = data_stream() - return next(batches) + batches = data_stream() + return next(batches) def get_model(): - init_random_params, predict = stax.serial( - Dense(128), - Relu, - Dense(128), - Relu, - Dense(10), - LogSoftmax, - ) - return init_random_params, predict + init_random_params, predict = stax.serial( + Dense(128), + Relu, + Dense(128), + Relu, + Dense(10), + LogSoftmax, + ) + return init_random_params, predict def loss(params, batch, predict_fn): - inputs, targets = batch - preds = predict_fn(params, inputs) - return -jnp.mean(jnp.sum(preds * targets, axis=1)) + inputs, targets = batch + preds = predict_fn(params, inputs) + return -jnp.mean(jnp.sum(preds * targets, axis=1)) def create_iree_jax_module(): - init_random_params, forward = get_model() + init_random_params, forward = get_model() - rng = random.PRNGKey(12345) - _, init_params = init_random_params(rng, (-1, 28 * 28)) - opt_init, opt_update, opt_get_params = optimizers.momentum(0.001, mass=0.9) - opt_state = opt_init(init_params) + rng = random.PRNGKey(12345) + _, init_params = init_random_params(rng, (-1, 28 * 28)) + opt_init, opt_update, opt_get_params = optimizers.momentum(0.001, mass=0.9) + opt_state = opt_init(init_params) - example_batch = get_example_batch() + example_batch = get_example_batch() - class IreeJaxMnistModule(Program): - _opt_state = opt_state + class IreeJaxMnistModule(Program): + _opt_state = opt_state - def get_params(self): - return opt_get_params(self._opt_state) + def get_params(self): + return opt_get_params(self._opt_state) - def get_opt_state(self): - return self._opt_state + def get_opt_state(self): + return self._opt_state - def set_opt_state(self, new_opt_state=like(opt_state)): - self._opt_state = new_opt_state + def set_opt_state(self, new_opt_state=like(opt_state)): + self._opt_state = new_opt_state - def initialize(self, rng=like(rng)): - self._opt_state = self._initialize_optimizer(rng) + def initialize(self, rng=like(rng)): + self._opt_state = self._initialize_optimizer(rng) - def update(self, batch=like(example_batch)): - new_opt_state = self._update_step(batch, self._opt_state) - self._opt_state = new_opt_state + def update(self, batch=like(example_batch)): + new_opt_state = self._update_step(batch, self._opt_state) + self._opt_state = new_opt_state - def forward(self, inputs=like(example_batch[0])): - return self._forward(opt_get_params(self._opt_state), inputs) + def forward(self, inputs=like(example_batch[0])): + return self._forward(opt_get_params(self._opt_state), inputs) - @kernel - def _initialize_optimizer(rng): - _, init_params = init_random_params(rng, (-1, 28 * 28)) - return opt_init(init_params) + @kernel + def _initialize_optimizer(rng): + _, init_params = init_random_params(rng, (-1, 28 * 28)) + return opt_init(init_params) - @kernel - def _update_step(batch, opt_state): - params = opt_get_params(opt_state) - return opt_update(0, grad(loss)(params, batch, forward), opt_state) + @kernel + def _update_step(batch, opt_state): + params = opt_get_params(opt_state) + return opt_update(0, grad(loss)(params, batch, forward), opt_state) - @kernel - def _forward(params, inputs): - return forward(params, inputs) + @kernel + def _forward(params, inputs): + return forward(params, inputs) - return IreeJaxMnistModule() + return IreeJaxMnistModule() def build_mlir_module(output_filepath): - module = create_iree_jax_module() - with open(output_filepath, "wb") as f: - Program.get_mlir_module(module).operation.write_bytecode(f) + module = create_iree_jax_module() + with open(output_filepath, "wb") as f: + Program.get_mlir_module(module).operation.write_bytecode(f) def build_jax_module(): - init_random_params, forward = get_model() - - rng = random.PRNGKey(12345) - _, init_params = init_random_params(rng, (-1, 28 * 28)) - opt_init, opt_update, opt_get_params = optimizers.momentum(0.001, mass=0.9) - opt_state = opt_init(init_params) - - example_batch = get_example_batch() - - class JaxMnistModule: - _opt_state = opt_state - - def get_params(self): - return opt_get_params(self._opt_state) - - def get_opt_state(self): - return self._opt_state - - def set_opt_state(self, new_opt_state): - self._opt_state = new_opt_state - - def initialize(self, rng): - self._opt_state = JaxMnistModule._initialize_optimizer(rng) - - def update(self, batch): - new_opt_state = JaxMnistModule._update_step(batch, self._opt_state) - self._opt_state = new_opt_state - - def forward(self, inputs): - return JaxMnistModule._forward(opt_get_params(self._opt_state), inputs) - - @jax.jit - def _initialize_optimizer(rng): - _, init_params = init_random_params(rng, (-1, 28 * 28)) - return opt_init(init_params) - - @jax.jit - def _update_step(batch, opt_state): - params = opt_get_params(opt_state) - return opt_update(0, grad(loss)(params, batch, forward), opt_state) - - @jax.jit - def _forward(params, inputs): - return forward(params, inputs) - - return JaxMnistModule() - - -def generate_test_data(output_mlir_filepath: str, batch_filepath: str, - expected_optimizer_state_after_init_filepath: str, - expected_optimizer_state_after_train_step_filepath: str, - expected_prediction_after_train_step_filepath: str): - build_mlir_module(output_mlir_filepath) - example_batch = get_example_batch() - np.savez_compressed(batch_filepath, *example_batch) - jax_module = build_jax_module() - jax_module.update(example_batch) - np.savez_compressed(expected_optimizer_state_after_train_step_filepath, - *tree_flatten(jax_module.get_opt_state())[0]) - prediction_jax = jax_module.forward(example_batch[0]) - np.savez_compressed(expected_prediction_after_train_step_filepath, - prediction_jax) - rng = random.PRNGKey(6789) - jax_module.initialize(rng) - np.savez_compressed(expected_optimizer_state_after_init_filepath, - *tree_flatten(jax_module.get_opt_state())[0]) + init_random_params, forward = get_model() + + rng = random.PRNGKey(12345) + _, init_params = init_random_params(rng, (-1, 28 * 28)) + opt_init, opt_update, opt_get_params = optimizers.momentum(0.001, mass=0.9) + opt_state = opt_init(init_params) + + example_batch = get_example_batch() + + class JaxMnistModule: + _opt_state = opt_state + + def get_params(self): + return opt_get_params(self._opt_state) + + def get_opt_state(self): + return self._opt_state + + def set_opt_state(self, new_opt_state): + self._opt_state = new_opt_state + + def initialize(self, rng): + self._opt_state = JaxMnistModule._initialize_optimizer(rng) + + def update(self, batch): + new_opt_state = JaxMnistModule._update_step(batch, self._opt_state) + self._opt_state = new_opt_state + + def forward(self, inputs): + return JaxMnistModule._forward(opt_get_params(self._opt_state), inputs) + + @jax.jit + def _initialize_optimizer(rng): + _, init_params = init_random_params(rng, (-1, 28 * 28)) + return opt_init(init_params) + + @jax.jit + def _update_step(batch, opt_state): + params = opt_get_params(opt_state) + return opt_update(0, grad(loss)(params, batch, forward), opt_state) + + @jax.jit + def _forward(params, inputs): + return forward(params, inputs) + + return JaxMnistModule() + + +def generate_test_data( + output_mlir_filepath: str, + batch_filepath: str, + expected_optimizer_state_after_init_filepath: str, + expected_optimizer_state_after_train_step_filepath: str, + expected_prediction_after_train_step_filepath: str, +): + build_mlir_module(output_mlir_filepath) + example_batch = get_example_batch() + np.savez_compressed(batch_filepath, *example_batch) + jax_module = build_jax_module() + jax_module.update(example_batch) + np.savez_compressed( + expected_optimizer_state_after_train_step_filepath, + *tree_flatten(jax_module.get_opt_state())[0] + ) + prediction_jax = jax_module.forward(example_batch[0]) + np.savez_compressed(expected_prediction_after_train_step_filepath, prediction_jax) + rng = random.PRNGKey(6789) + jax_module.initialize(rng) + np.savez_compressed( + expected_optimizer_state_after_init_filepath, + *tree_flatten(jax_module.get_opt_state())[0] + ) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--output_mlir_filepath", - help="Output to the compiled IREE Jax MLIR model.", - type=str, - default="mnist_train.mlirbc") - parser.add_argument("--batch_filepath", type=str, default="batch.npz") - parser.add_argument("--expected_optimizer_state_after_init_filepath", - type=str, - default="expected_optimizer_state_after_init.npz") - parser.add_argument("--expected_optimizer_state_after_train_step_filepath", - type=str, - default="expected_optimizer_state_after_train_step.npz") - parser.add_argument("--expected_prediction_after_train_step_filepath", - type=str, - default="expected_prediction_after_train_step.npz") - return parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_mlir_filepath", + help="Output to the compiled IREE Jax MLIR model.", + type=str, + default="mnist_train.mlirbc", + ) + parser.add_argument("--batch_filepath", type=str, default="batch.npz") + parser.add_argument( + "--expected_optimizer_state_after_init_filepath", + type=str, + default="expected_optimizer_state_after_init.npz", + ) + parser.add_argument( + "--expected_optimizer_state_after_train_step_filepath", + type=str, + default="expected_optimizer_state_after_train_step.npz", + ) + parser.add_argument( + "--expected_prediction_after_train_step_filepath", + type=str, + default="expected_prediction_after_train_step.npz", + ) + return parser.parse_args() def generate_test_data_cli(): - kwargs = vars(parse_args()) - generate_test_data(**kwargs) + kwargs = vars(parse_args()) + generate_test_data(**kwargs) if __name__ == "__main__": - generate_test_data_cli() + generate_test_data_cli() diff --git a/tests/e2e/models/mnist_train_test/mnist_train_test.py b/tests/e2e/models/mnist_train_test/mnist_train_test.py index 578ad090fef5..f339face945d 100644 --- a/tests/e2e/models/mnist_train_test/mnist_train_test.py +++ b/tests/e2e/models/mnist_train_test/mnist_train_test.py @@ -20,106 +20,118 @@ MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.a49ba1535a45ac0f3e6be22a7ed5dddf4a53cd1f41126af938f0667b998f8e11.tar" -Tensor = TypeVar('Tensor') +Tensor = TypeVar("Tensor") def build_module(artifacts_dir: str): - vmfb_file = os.path.join(artifacts_dir, "mnist_train.vmfb") - compile_file(input_file=os.path.join(artifacts_dir, "mnist_train.mlirbc"), - output_file=vmfb_file, - target_backends=[args.target_backend], - input_type=InputType.STABLEHLO) - return load_vm_flatbuffer_file(vmfb_file, driver=args.driver) + vmfb_file = os.path.join(artifacts_dir, "mnist_train.vmfb") + compile_file( + input_file=os.path.join(artifacts_dir, "mnist_train.mlirbc"), + output_file=vmfb_file, + target_backends=[args.target_backend], + input_type=InputType.STABLEHLO, + ) + return load_vm_flatbuffer_file(vmfb_file, driver=args.driver) def load_data(data_dir: str): - batch = list(np.load(os.path.join(data_dir, "batch.npz")).values()) - expected_optimizer_state_after_init = list( - np.load(os.path.join(data_dir, - "expected_optimizer_state_after_init.npz")).values()) - expected_optimizer_state_after_train_step = list( - np.load( - os.path.join( - data_dir, - "expected_optimizer_state_after_train_step.npz")).values()) - expected_prediction_after_train_step = list( - np.load(os.path.join( - data_dir, "expected_prediction_after_train_step.npz")).values())[0] - return ( - batch, - expected_optimizer_state_after_init, - expected_optimizer_state_after_train_step, - expected_prediction_after_train_step, - ) + batch = list(np.load(os.path.join(data_dir, "batch.npz")).values()) + expected_optimizer_state_after_init = list( + np.load( + os.path.join(data_dir, "expected_optimizer_state_after_init.npz") + ).values() + ) + expected_optimizer_state_after_train_step = list( + np.load( + os.path.join(data_dir, "expected_optimizer_state_after_train_step.npz") + ).values() + ) + expected_prediction_after_train_step = list( + np.load( + os.path.join(data_dir, "expected_prediction_after_train_step.npz") + ).values() + )[0] + return ( + batch, + expected_optimizer_state_after_init, + expected_optimizer_state_after_train_step, + expected_prediction_after_train_step, + ) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--target_backend", type=str, default="llvm-cpu") - parser.add_argument("--driver", type=str, default="local-task") - return parser.parse_known_args() + parser = argparse.ArgumentParser() + parser.add_argument("--target_backend", type=str, default="llvm-cpu") + parser.add_argument("--driver", type=str, default="local-task") + return parser.parse_known_args() DEFAULT_REL_TOLERANCE = 1e-5 DEFAULT_ABS_TOLERANCE = 1e-5 -def allclose(a: Tensor, - b: Tensor, - rtol=DEFAULT_REL_TOLERANCE, - atol=DEFAULT_ABS_TOLERANCE): - return np.allclose(np.asarray(a), np.asarray(b), rtol, atol) +def allclose( + a: Tensor, b: Tensor, rtol=DEFAULT_REL_TOLERANCE, atol=DEFAULT_ABS_TOLERANCE +): + return np.allclose(np.asarray(a), np.asarray(b), rtol, atol) def assert_array_list_compare(array_compare_fn, a: Tensor, b: Tensor): - assert (len(a) == len(b)) - for x, y in zip(a, b): - np.testing.assert_array_compare(array_compare_fn, x, y) + assert len(a) == len(b) + for x, y in zip(a, b): + np.testing.assert_array_compare(array_compare_fn, x, y) -def assert_array_list_allclose(a: List[Tensor], - b: List[Tensor], - rtol=DEFAULT_REL_TOLERANCE, - atol=DEFAULT_ABS_TOLERANCE): - assert_array_list_compare(lambda x, y: allclose(x, y, rtol, atol), a, b) +def assert_array_list_allclose( + a: List[Tensor], + b: List[Tensor], + rtol=DEFAULT_REL_TOLERANCE, + atol=DEFAULT_ABS_TOLERANCE, +): + assert_array_list_compare(lambda x, y: allclose(x, y, rtol, atol), a, b) def download_test_data(out_path: str): - urlretrieve(MODEL_ARTIFACTS_URL, out_path) + urlretrieve(MODEL_ARTIFACTS_URL, out_path) def extract_test_data(archive_path: str, out_dir: str): - with tarfile.open(archive_path) as tar: - tar.extractall(out_dir) + with tarfile.open(archive_path) as tar: + tar.extractall(out_dir) class MnistTrainTest(unittest.TestCase): - - def test_mnist_training(self): - with tempfile.TemporaryDirectory() as tmp_dir: - archive_path = os.path.join(tmp_dir, "mnist_train.tar") - download_test_data(archive_path) - extract_test_data(archive_path, tmp_dir) - module = build_module(tmp_dir) - ( - batch, - expected_optimizer_state_after_init, - expected_optimizer_state_after_train_step, - expected_prediction_after_train_step, - ) = load_data(tmp_dir) - - module.update(*batch) - assert_array_list_allclose(module.get_opt_state(), - expected_optimizer_state_after_train_step) - prediction = module.forward(batch[0]) - np.testing.assert_allclose(prediction, expected_prediction_after_train_step, - DEFAULT_REL_TOLERANCE, DEFAULT_ABS_TOLERANCE) - rng_state = np.array([0, 6789], dtype=np.int32) - module.initialize(rng_state) - assert_array_list_allclose(module.get_opt_state(), - expected_optimizer_state_after_init) - - -if __name__ == '__main__': - args, remaining_args = parse_args() - unittest.main(argv=[sys.argv[0]] + remaining_args) + def test_mnist_training(self): + with tempfile.TemporaryDirectory() as tmp_dir: + archive_path = os.path.join(tmp_dir, "mnist_train.tar") + download_test_data(archive_path) + extract_test_data(archive_path, tmp_dir) + module = build_module(tmp_dir) + ( + batch, + expected_optimizer_state_after_init, + expected_optimizer_state_after_train_step, + expected_prediction_after_train_step, + ) = load_data(tmp_dir) + + module.update(*batch) + assert_array_list_allclose( + module.get_opt_state(), expected_optimizer_state_after_train_step + ) + prediction = module.forward(batch[0]) + np.testing.assert_allclose( + prediction, + expected_prediction_after_train_step, + DEFAULT_REL_TOLERANCE, + DEFAULT_ABS_TOLERANCE, + ) + rng_state = np.array([0, 6789], dtype=np.int32) + module.initialize(rng_state) + assert_array_list_allclose( + module.get_opt_state(), expected_optimizer_state_after_init + ) + + +if __name__ == "__main__": + args, remaining_args = parse_args() + unittest.main(argv=[sys.argv[0]] + remaining_args) diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py index 5ed367afa616..e6d8b7c45e0d 100644 --- a/tests/lit.cfg.py +++ b/tests/lit.cfg.py @@ -28,15 +28,19 @@ # WindowsLinkerTool uses these from vcvarsall "VCTOOLSINSTALLDIR", "UNIVERSALCRTSDKDIR", - "UCRTVERSION" + "UCRTVERSION", ] -config.environment.update({ - k: v - for k, v in os.environ.items() - if k.startswith("IREE_") or k in passthrough_env_vars -}) +config.environment.update( + { + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars + } +) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) diff --git a/tools/lit.cfg.py b/tools/lit.cfg.py index 6c8ed401ffa1..c74928c710c6 100644 --- a/tools/lit.cfg.py +++ b/tools/lit.cfg.py @@ -21,17 +21,23 @@ config.test_format = lit.formats.ShTest(execute_external=True) # Forward all IREE environment variables passthrough_env_vars = ["VK_ICD_FILENAMES"] -config.environment.update({ - k: v - for k, v in os.environ.items() - if k.startswith("IREE_") or k in passthrough_env_vars -}) +config.environment.update( + { + k: v + for k, v in os.environ.items() + if k.startswith("IREE_") or k in passthrough_env_vars + } +) # Use the most preferred temp directory. -config.test_exec_root = (os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") or - os.environ.get("TEST_TMPDIR") or - os.path.join(tempfile.gettempdir(), "lit")) +config.test_exec_root = ( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + or os.environ.get("TEST_TMPDIR") + or os.path.join(tempfile.gettempdir(), "lit") +) -config.substitutions.extend([ - ("%PYTHON", os.getenv("PYTHON", sys.executable)), -]) +config.substitutions.extend( + [ + ("%PYTHON", os.getenv("PYTHON", sys.executable)), + ] +) diff --git a/tools/test/echo_npy.py b/tools/test/echo_npy.py index d88b55dcd22c..0e59f990a28a 100644 --- a/tools/test/echo_npy.py +++ b/tools/test/echo_npy.py @@ -8,9 +8,9 @@ import os import sys -with open(os.path.realpath(sys.argv[1]), 'rb') as f: - f.seek(0, 2) - file_len = f.tell() - f.seek(0, 0) - while f.tell() < file_len: - print(numpy.load(f)) +with open(os.path.realpath(sys.argv[1]), "rb") as f: + f.seek(0, 2) + file_len = f.tell() + f.seek(0, 0) + while f.tell() < file_len: + print(numpy.load(f))