Skip to content

Find pyversion at build time 0.5.0 #328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: rocm-jaxlib-v0.5.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 81 additions & 4 deletions build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ import argparse
import os
import subprocess
import sys
from typing import List


DEFAULT_GPU_DEVICE_TARGETS = (
"gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
)


def image_by_name(name):
Expand All @@ -40,7 +46,11 @@ def dist_wheels(
rocm_build_job="",
rocm_build_num="",
compiler="gcc",
gpu_device_targets: List[str] = None,
):
if not gpu_device_targets:
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")

if xla_path:
xla_path = os.path.abspath(xla_path)

Expand All @@ -63,6 +73,7 @@ def dist_wheels(
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image,
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
".",
]

Expand All @@ -85,6 +96,8 @@ def dist_wheels(
pyver_string,
"--compiler",
compiler,
"--gpu-device-targets",
",".join(gpu_device_targets),
]

if xla_path:
Expand Down Expand Up @@ -158,10 +171,14 @@ def dist_docker(
tag="rocm/jax-dev",
dockerfile=None,
keep_image=True,
gpu_device_targets: List[str] = None,
):
if not dockerfile:
dockerfile = "build/rocm/Dockerfile.ms"

if not gpu_device_targets:
gpu_device_targets = DEFAULT_GPU_DEVICE_TARGETS.split(",")

python_version = python_versions[0]

md = _fetch_jax_metadata(xla_path)
Expand All @@ -174,6 +191,7 @@ def dist_docker(
"--target",
"rt_build",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=GPU_DEVICE_TARGETS=%s" % " ".join(gpu_device_targets),
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--build-arg=BASE_DOCKER=%s" % base_docker,
Expand Down Expand Up @@ -238,6 +256,55 @@ def test(image_name):
subprocess.check_call(cmd)


def canonicalize_python_versions(versions: List[str]):
if isinstance(versions, str):
raise ValueError("'versions' must be a list of strings: versions=%r" % versions)

cleaned = []
for v in versions:
tup = v.split(".")
major = tup[0]
minor = tup[1]
rev = None
if len(tup) > 2 and tup[2]:
rev = tup[2]

cleaned.append("%s.%s" % (major, minor))

return cleaned


def parse_gpu_targets(targets_string):
# catch case where targets_string was empty.
# None should already be caught by argparse, but
# it doesn't hurt to check twice
if not targets_string:
targets_string = DEFAULT_GPU_DEVICE_TARGETS

if "," in targets_string:
targets = targets_string.split(",")
elif " " in targets_string:
targets = targets_string.split(" ")
else:
targets = targets_string

res = []
# cleanup and validation
for t in targets:
if not t:
continue

if not t.startswith("gfx"):
raise ValueError("Invalid GPU architecture target: %r" % t)

res.append(t.strip())

if not res:
raise ValueError("GPU_DEVICE_TARGETS cannot be empty")

return res


def parse_args():
p = argparse.ArgumentParser()
p.add_argument(
Expand All @@ -249,7 +316,7 @@ def parse_args():
p.add_argument(
"--python-versions",
type=lambda x: x.split(","),
default="3.12",
default=["3.12"],
help="Comma separated list of CPython versions to build wheels for",
)

Expand Down Expand Up @@ -281,6 +348,11 @@ def parse_args():
choices=["gcc", "clang"],
help="Compiler backend to use when compiling jax/jaxlib",
)
p.add_argument(
"--gpu-device-targets",
default=DEFAULT_GPU_DEVICE_TARGETS,
help="List of AMDGPU device targets passed from job",
)

subp = p.add_subparsers(dest="action", required=True)

Expand All @@ -299,15 +371,18 @@ def parse_args():

def main():
args = parse_args()
gpu_device_targets = parse_gpu_targets(args.gpu_device_targets)
python_versions = canonicalize_python_versions(args.python_versions)

if args.action == "dist_wheels":
dist_wheels(
args.rocm_version,
args.python_versions,
python_versions,
args.xla_source_dir,
args.rocm_build_job,
args.rocm_build_num,
compiler=args.compiler,
gpu_device_targets=gpu_device_targets,
)

elif args.action == "test":
Expand All @@ -316,22 +391,24 @@ def main():
elif args.action == "dist_docker":
dist_wheels(
args.rocm_version,
args.python_versions,
python_versions,
args.xla_source_dir,
args.rocm_build_job,
args.rocm_build_num,
compiler=args.compiler,
gpu_device_targets=gpu_device_targets,
)
dist_docker(
args.rocm_version,
args.base_docker,
args.python_versions,
python_versions,
args.xla_source_dir,
rocm_build_job=args.rocm_build_job,
rocm_build_num=args.rocm_build_num,
tag=args.image_tag,
dockerfile=args.dockerfile,
keep_image=args.keep_image,
gpu_device_targets=gpu_device_targets,
)


Expand Down
14 changes: 14 additions & 0 deletions build/rocm/ci_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ ROCM_BUILD_NUM=""
BASE_DOCKER="ubuntu:22.04"
CUSTOM_INSTALL=""
JAX_USE_CLANG=""
GPU_DEVICE_TARGETS=""
POSITIONAL_ARGS=()

RUNTIME_FLAG=0
Expand Down Expand Up @@ -98,6 +99,18 @@ while [[ $# -gt 0 ]]; do
JAX_USE_CLANG="$2"
shift 2
;;
--gpu_device_targets)
if [[ "$2" == "--custom_install" ]]; then
GPU_DEVICE_TARGETS=""
shift 2
elif [[ -n "$2" ]]; then
GPU_DEVICE_TARGETS="$2"
shift 2
else
GPU_DEVICE_TARGETS=""
shift 1
fi
;;
*)
POSITIONAL_ARGS+=("$1")
shift
Expand Down Expand Up @@ -164,6 +177,7 @@ fi
--rocm-build-job=$ROCM_BUILD_JOB \
--rocm-build-num=$ROCM_BUILD_NUM \
--compiler=$JAX_COMPILER \
--gpu-device-targets="${GPU_DEVICE_TARGETS}" \
dist_docker \
--dockerfile $DOCKERFILE_PATH \
--image-tag $DOCKER_IMG_NAME
Expand Down
81 changes: 81 additions & 0 deletions build/rocm/test_ci_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import importlib.util
import importlib.machinery


def load_ci_build():
spec = importlib.util.spec_from_loader(
"ci_build", importlib.machinery.SourceFileLoader("ci_build", "./ci_build")
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod


ci_build = load_ci_build()


class CIBuildTestCase(unittest.TestCase):
def test_parse_gpu_targets_spaces(self):
targets = ["gfx908", "gfx940", "gfx1201"]
r = ci_build.parse_gpu_targets(" ".join(targets))
self.assertEqual(r, targets)

def test_parse_gpu_targets_commas(self):
targets = ["gfx908", "gfx940", "gfx1201"]
r = ci_build.parse_gpu_targets(",".join(targets))
self.assertEqual(r, targets)

def test_parse_gpu_targets_empty_string(self):
expected = ci_build.DEFAULT_GPU_DEVICE_TARGETS.split(",")
r = ci_build.parse_gpu_targets("")
self.assertEqual(r, expected)

def test_parse_gpu_targets_whitespace_only(self):
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ")

def test_parse_gpu_targets_invalid_arch(self):
targets = ["gfx908", "gfx940", "--oops", "/jax"]
self.assertRaises(ValueError, ci_build.parse_gpu_targets, " ".join(targets))

def test_canonicalize_python_versions(self):
versions = ["3.10.0", "3.11.0", "3.12.0"]
exp = ["3.10", "3.11", "3.12"]
res = ci_build.canonicalize_python_versions(versions)
self.assertEqual(res, exp)

def test_canonicalize_python_versions_scalar(self):
versions = ["3.10.0"]
exp = ["3.10"]
res = ci_build.canonicalize_python_versions(versions)
self.assertEqual(res, exp)

def test_canonicalize_python_versions_no_revision_part(self):
versions = ["3.10", "3.11"]
res = ci_build.canonicalize_python_versions(versions)
self.assertEqual(res, versions)

def test_canonicalize_python_versions_string(self):
versions = "3.10.0"
self.assertRaises(ValueError, ci_build.canonicalize_python_versions, versions)


if __name__ == "__main__":
unittest.main()
19 changes: 14 additions & 5 deletions build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@
import subprocess
import shutil
import sys
from typing import List


LOG = logging.getLogger(__name__)


GPU_DEVICE_TARGETS = "gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
DEFAULT_GPU_DEVICE_TARGETS = (
"gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"
)


def build_rocm_path(rocm_version_str):
Expand All @@ -46,11 +49,11 @@ def build_rocm_path(rocm_version_str):
return os.path.realpath("/opt/rocm")


def update_rocm_targets(rocm_path, targets):
def update_rocm_targets(rocm_path: str, targets: List[str]):
target_fp = os.path.join(rocm_path, "bin/target.lst")
version_fp = os.path.join(rocm_path, ".info/version")
with open(target_fp, "w") as fd:
fd.write("%s\n" % targets)
fd.write("%s\n" % " ".join(targets))

# mimic touch
open(version_fp, "a").close()
Expand Down Expand Up @@ -250,7 +253,7 @@ def parse_args():
)
p.add_argument(
"--python-versions",
default=["3.10.19,3.12"],
default="3.10.19,3.12",
help="Comma separated CPython versions that wheels will be built and output for",
)
p.add_argument(
Expand All @@ -265,6 +268,11 @@ def parse_args():
default="gcc",
help="Compiler backend to use when compiling jax/jaxlib",
)
p.add_argument(
"--gpu-device-targets",
default=DEFAULT_GPU_DEVICE_TARGETS,
help="Comma separated list of GPU device targets passed from job",
)

p.add_argument("jax_path", help="Directory where JAX source directory is located")

Expand All @@ -285,6 +293,7 @@ def find_wheels(path):
def main():
args = parse_args()
python_versions = args.python_versions.split(",")
gpu_device_targets = args.gpu_device_targets.split(",")

print("ROCM_VERSION=%s" % args.rocm_version)
print("PYTHON_VERSIONS=%r" % python_versions)
Expand All @@ -294,7 +303,7 @@ def main():

rocm_path = build_rocm_path(args.rocm_version)

update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)
update_rocm_targets(rocm_path, gpu_device_targets)

for py in python_versions:
build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler)
Expand Down