def canonicalize_python_versions(versions: List[str]) -> List[str]:
    """Reduce each CPython version string to its ``major.minor`` form.

    Args:
        versions: Version strings such as ``"3.10"`` or ``"3.10.0"``. Any
            component after the minor number is dropped.

    Returns:
        A new list with one ``"major.minor"`` string per input entry, in the
        same order.

    Raises:
        ValueError: If ``versions`` is a bare string instead of a list, or if
            an entry lacks a major or a minor component.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(versions, str):
        raise ValueError("'versions' must be a list of strings: versions=%r" % versions)

    cleaned = []
    for v in versions:
        parts = v.strip().split(".")
        # Reject entries like "3" or "3." explicitly instead of failing with
        # an IndexError or emitting a malformed "3." version.
        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError("Invalid Python version (expected major.minor): %r" % v)

        cleaned.append("%s.%s" % (parts[0], parts[1]))

    return cleaned


def parse_gpu_targets(targets_string):
    """Split a GPU device target string into a validated list of targets.

    Targets may be separated by commas, whitespace, or a mix of both. A
    string holding a single target (no separator at all) yields a
    one-element list.

    Args:
        targets_string: Separated list of targets, e.g.
            ``"gfx908,gfx90a"`` or ``"gfx908 gfx90a"``. An empty or ``None``
            value selects ``DEFAULT_GPU_DEVICE_TARGETS``.

    Returns:
        List of target names with surrounding whitespace removed.

    Raises:
        ValueError: If any target does not start with ``"gfx"``, or if no
            target remains after cleanup (e.g. whitespace-only input).
    """
    # catch case where targets_string was empty.
    # None should already be caught by argparse, but
    # it doesn't hurt to check twice
    if not targets_string:
        targets_string = DEFAULT_GPU_DEVICE_TARGETS

    # Treat commas as whitespace and let str.split() do the tokenizing: this
    # drops empty entries, strips each token before validation, and keeps a
    # lone target intact instead of iterating over its characters.
    candidates = targets_string.replace(",", " ").split()

    res = []
    for t in candidates:
        if not t.startswith("gfx"):
            raise ValueError("Invalid GPU architecture target: %r" % t)

        res.append(t)

    if not res:
        raise ValueError("GPU_DEVICE_TARGETS cannot be empty")

    return res
args.action == "dist_docker": dist_wheels( args.rocm_version, - args.python_versions, + python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, + gpu_device_targets=gpu_device_targets, ) dist_docker( args.rocm_version, args.base_docker, - args.python_versions, + python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, rocm_build_num=args.rocm_build_num, tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, + gpu_device_targets=gpu_device_targets, ) diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 386f70ee1a96..d8b98490b076 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -51,6 +51,7 @@ ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" +GPU_DEVICE_TARGETS="" POSITIONAL_ARGS=() RUNTIME_FLAG=0 @@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do JAX_USE_CLANG="$2" shift 2 ;; + --gpu_device_targets) + if [[ "$2" == "--custom_install" ]]; then + GPU_DEVICE_TARGETS="" + shift 2 + elif [[ -n "$2" ]]; then + GPU_DEVICE_TARGETS="$2" + shift 2 + else + GPU_DEVICE_TARGETS="" + shift 1 + fi + ;; *) POSITIONAL_ARGS+=("$1") shift @@ -164,6 +177,7 @@ fi --rocm-build-job=$ROCM_BUILD_JOB \ --rocm-build-num=$ROCM_BUILD_NUM \ --compiler=$JAX_COMPILER \ + --gpu-device-targets="${GPU_DEVICE_TARGETS}" \ dist_docker \ --dockerfile $DOCKERFILE_PATH \ --image-tag $DOCKER_IMG_NAME diff --git a/build/rocm/test_ci_build.py b/build/rocm/test_ci_build.py new file mode 100644 index 000000000000..354da937d3e5 --- /dev/null +++ b/build/rocm/test_ci_build.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def load_ci_build():
    """Load the extensionless ``ci_build`` script as an importable module.

    The script has no ``.py`` suffix, so it cannot be imported by name; a
    ``SourceFileLoader`` is used instead. The script is located relative to
    this test file rather than the current working directory, so the tests
    can be started from any directory.

    Returns:
        The executed ``ci_build`` module object.
    """
    import os  # local import: only needed to locate the sibling script

    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ci_build")
    loader = importlib.machinery.SourceFileLoader("ci_build", script_path)
    spec = importlib.util.spec_from_loader("ci_build", loader)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


# Loaded once at import time; every test below reads helpers from it.
ci_build = load_ci_build()


class CIBuildTestCase(unittest.TestCase):
    """Unit tests for the argument-cleanup helpers in ``ci_build``."""

    def test_parse_gpu_targets_spaces(self):
        targets = ["gfx908", "gfx940", "gfx1201"]
        r = ci_build.parse_gpu_targets(" ".join(targets))
        self.assertEqual(r, targets)

    def test_parse_gpu_targets_commas(self):
        targets = ["gfx908", "gfx940", "gfx1201"]
        r = ci_build.parse_gpu_targets(",".join(targets))
        self.assertEqual(r, targets)

    def test_parse_gpu_targets_empty_string(self):
        # An empty string falls back to the module's default target list.
        expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",")
        r = ci_build.parse_gpu_targets("")
        self.assertEqual(r, expected)

    def test_parse_gpu_targets_whitespace_only(self):
        self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ")

    def test_parse_gpu_targets_invalid_arch(self):
        # Entries that are not "gfx*" names must be rejected.
        targets = ["gfx908", "gfx940", "--oops", "/jax"]
        self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets))

    def test_canonicalize_python_versions(self):
        versions = ["3.10.0", "3.11.0", "3.12.0"]
        exp = ["3.10", "3.11", "3.12"]
        res = ci_build.canonicalize_python_versions(versions)
        self.assertEqual(res, exp)

    def test_canonicalize_python_versions_scalar(self):
        versions = ["3.10.0"]
        exp = ["3.10"]
        res = ci_build.canonicalize_python_versions(versions)
        self.assertEqual(res, exp)

    def test_canonicalize_python_versions_no_revision_part(self):
        versions = ["3.10", "3.11"]
        res = ci_build.canonicalize_python_versions(versions)
        self.assertEqual(res, versions)

    def test_canonicalize_python_versions_string(self):
        # A bare string (not a list) is a caller error.
        versions = "3.10.0"
        self.assertRaises(ValueError, ci_build.canonicalize_python_versions, versions)


if __name__ == "__main__":
    unittest.main()
directory is located") @@ -285,6 +293,7 @@ def find_wheels(path): def main(): args = parse_args() python_versions = args.python_versions.split(",") + gpu_device_targets = args.gpu_device_targets.split(",") print("ROCM_VERSION=%s" % args.rocm_version) print("PYTHON_VERSIONS=%r" % python_versions) @@ -294,7 +303,7 @@ def main(): rocm_path = build_rocm_path(args.rocm_version) - update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + update_rocm_targets(rocm_path, gpu_device_targets) for py in python_versions: build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)