Skip to content

Commit 877deed

Browse files
authored
Merge pull request #240 from ROCm/ci-upstream-sync-127_1
CI: 02/25/25 upstream sync
2 parents 9cac506 + 1217ba9 commit 877deed

File tree

275 files changed

+18962
-8889
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

275 files changed

+18962
-8889
lines changed

.bazelrc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# All default build options below. These apply to all build commands.
33
# #############################################################################
44
# Make Bazel print out all options from rc files.
5-
build --announce_rc
5+
common --announce_rc
66

77
# By default, execute all actions locally.
88
build --spawn_strategy=local
@@ -11,7 +11,7 @@ build --spawn_strategy=local
1111
# automatically when building on Linux.
1212
build --enable_platform_specific_config
1313

14-
build --experimental_cc_shared_library
14+
common --experimental_cc_shared_library
1515

1616
# Do not use C-Ares when building gRPC.
1717
build --define=grpc_no_ares=true

.github/workflows/asan.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ jobs:
6262
run: |
6363
source ${GITHUB_WORKSPACE}/venv/bin/activate
6464
cd jax
65-
pip install -r build/test-requirements.txt
65+
pip install uv~=0.5.30
66+
uv pip install -r build/test-requirements.txt
6667
- name: Build and install JAX
6768
env:
6869
ASAN_OPTIONS: detect_leaks=0
@@ -73,8 +74,8 @@ jobs:
7374
--bazel_options=--color=yes \
7475
--bazel_options=--copt=-fsanitize=address \
7576
--clang_path=/usr/bin/clang-18
76-
pip install dist/jaxlib-*.whl
77-
pip install -e .
77+
uv pip install dist/jaxlib-*.whl \
78+
-e .
7879
- name: Run tests
7980
env:
8081
ASAN_OPTIONS: detect_leaks=0

.github/workflows/bazel_cuda_non_rbe.yml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,12 @@ on:
4242

4343
jobs:
4444
run-tests:
45+
defaults:
46+
run:
47+
# Explicitly set the shell to bash
48+
shell: bash
4549
runs-on: ${{ inputs.runner }}
46-
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
50+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
4751

4852
env:
4953
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
@@ -68,11 +72,23 @@ jobs:
6872
echo "ARCH=${arch}" >> $GITHUB_ENV
6973
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
7074
- name: Download the wheel artifacts from GCS
75+
id: download-wheel-artifacts
76+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
77+
# fails. Instead, we verify the outcome in the next step so that we can print a more
78+
# informative error message.
79+
continue-on-error: true
7180
run: >-
7281
mkdir -p $(pwd)/dist &&
7382
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
7483
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
7584
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
85+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
86+
if: steps.download-wheel-artifacts.outcome == 'failure'
87+
run: |
88+
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
89+
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
90+
echo "Skipping the test run."
91+
exit 1
7692
# Halt for testing
7793
- name: Wait For Connection
7894
uses: google-ml-infra/actions/ci_connection@main

.github/workflows/build_artifacts.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
with:
132132
halt-dispatch-input: ${{ inputs.halt-for-connection }}
133133
- name: Build ${{ inputs.artifact }}
134-
timeout-minutes: 30
134+
timeout-minutes: 60
135135
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}"
136136
- name: Upload artifacts to a GCS bucket (non-Windows runs)
137137
if: >-

.github/workflows/ci-build.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ jobs:
6767
python-version: ${{ matrix.python-version }}
6868
- name: Install dependencies
6969
run: |
70-
pip install uv
70+
pip install uv~=0.5.30
7171
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
7272
7373
- name: Run tests
@@ -105,7 +105,7 @@ jobs:
105105
python-version: ${{ matrix.python-version }}
106106
- name: Install dependencies
107107
run: |
108-
pip install uv
108+
pip install uv~=0.5.30
109109
uv pip install --system -r docs/requirements.txt
110110
- name: Test documentation
111111
env:
@@ -133,7 +133,7 @@ jobs:
133133
python-version: ${{ matrix.python-version }}
134134
- name: Install dependencies
135135
run: |
136-
pip install uv
136+
pip install uv~=0.5.30
137137
uv pip install --system -r docs/requirements.txt
138138
- name: Render documentation
139139
run: |
@@ -159,8 +159,9 @@ jobs:
159159
python-version: ${{ matrix.python-version }}
160160
- name: Install dependencies
161161
run: |
162-
pip install uv
163-
uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
162+
pip install uv~=0.5.30
163+
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
164+
uv pip install --system --pre tensorflow==2.19.0rc0
164165
165166
- name: Run tests
166167
env:
@@ -189,6 +190,7 @@ jobs:
189190
python-version: 3.12
190191
- name: Install JAX
191192
run: |
193+
pip install uv~=0.5.30
192194
pip install uv
193195
uv pip install --system .
194196
- name: Build and install example project

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,42 +59,38 @@ jobs:
5959
git config --global --add safe.directory "$GITHUB_WORKSPACE"
6060
- name: Install JAX test requirements
6161
run: |
62-
$PYTHON -m pip install -U -r build/test-requirements.txt
63-
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
62+
$PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
6463
- name: Install JAX
6564
run: |
66-
$PYTHON -m pip uninstall -y jax jaxlib libtpu
65+
$PYTHON -m uv pip uninstall jax jaxlib libtpu
6766
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
6867
# Build and install jaxlib at head
6968
$PYTHON build/build.py build --wheels=jaxlib \
7069
--bazel_options=--config=rbe_linux_x86_64 \
7170
--local_xla_path="$(pwd)/xla" \
7271
--verbose
7372
74-
$PYTHON -m pip install dist/*.whl
75-
76-
# Install "jax" at head
77-
$PYTHON -m pip install -U -e .
78-
79-
# Install libtpu
80-
$PYTHON -m pip install --pre libtpu \
81-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
73+
# Install jaxlib, "jax" at head, and libtpu
74+
$PYTHON -m uv pip install dist/*.whl \
75+
-U -e . \
76+
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8277
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
83-
$PYTHON -m pip install .[tpu] \
78+
$PYTHON -m uv pip install .[tpu] \
8479
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8580
8681
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
87-
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
88-
$PYTHON -m pip install --pre libtpu \
89-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
90-
$PYTHON -m pip install requests
82+
$PYTHON -m uv pip install \
83+
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
84+
libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
85+
requests
9186
9287
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
93-
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9488
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
95-
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
96-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
97-
$PYTHON -m pip install requests
89+
$PYTHON -m uv pip install \
90+
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
91+
libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
92+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
93+
requests
9894
else
9995
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
10096
exit 1

.github/workflows/cloud-tpu-ci-presubmit.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ jobs:
7474
git config --global --add safe.directory "$GITHUB_WORKSPACE"
7575
- name: Install JAX test requirements
7676
run: |
77-
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
78-
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
77+
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
7978
- name: Build jaxlib at head with latest XLA
8079
run: |
8180
# Build and install jaxlib at head
@@ -86,7 +85,7 @@ jobs:
8685
--verbose
8786
8887
# Install libtpu
89-
$JAXCI_PYTHON -m pip install --pre libtpu \
88+
$JAXCI_PYTHON -m uv pip install --pre libtpu \
9089
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
9190
# Halt for testing
9291
- name: Wait For Connection

.github/workflows/jax-array-api.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ jobs:
3737
python-version: ${{ matrix.python-version }}
3838
- name: Install dependencies
3939
run: |
40-
pip install uv
41-
uv pip install --system .[ci]
42-
uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
40+
pip install uv~=0.5.30
41+
uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt
4342
- name: Run the test suite
4443
env:
4544
ARRAY_API_TESTS_MODULE: jax.numpy

.github/workflows/metal_plugin_ci.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,14 @@ jobs:
3535
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
3636
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
3737
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
38-
pip install -U pip numpy wheel
39-
pip install absl-py pytest
38+
pip install uv~=0.5.30
39+
uv pip install -U pip numpy wheel absl-py pytest
4040
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
41-
pip install --pre jaxlib \
41+
uv pip install --pre jaxlib \
4242
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
4343
fi;
4444
cd jax
45-
pip install .
46-
pip install jax-metal
45+
uv pip install . jax-metal
4746
- name: Run test
4847
run: |
4948
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate

.github/workflows/pytest_cpu.yml

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ on:
2929
type: string
3030
required: true
3131
default: "0"
32+
install-jax-current-commit:
33+
description: "Should the 'jax' package be installed from the current commit?"
34+
type: string
35+
required: true
36+
default: "1"
3237
gcs_download_uri:
3338
description: "GCS location prefix from where the artifacts should be downloaded"
3439
required: true
@@ -57,6 +62,7 @@ jobs:
5762
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
5863
JAXCI_PYTHON: "python${{ inputs.python }}"
5964
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
65+
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
6066

6167
steps:
6268
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -79,18 +85,47 @@ jobs:
7985
echo "ARCH=${arch}" >> $GITHUB_ENV
8086
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
8187
- name: Download jaxlib wheel from GCS (non-Windows runs)
88+
id: download-wheel-artifacts-nw
89+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
90+
# fails. Instead, we verify the outcome in the step below so that we can print a more
91+
# informative error message.
92+
continue-on-error: true
8293
if: ${{ !contains(inputs.runner, 'windows-x86') }}
83-
run: >-
94+
run: |
8495
mkdir -p $(pwd)/dist &&
8596
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
97+
98+
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
99+
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
100+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
101+
fi
86102
- name: Download jaxlib wheel from GCS (Windows runs)
103+
id: download-wheel-artifacts-w
104+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
105+
# fails. Instead, we verify the outcome in step below so that we can print a more
106+
# informative error message.
107+
continue-on-error: true
87108
if: ${{ contains(inputs.runner, 'windows-x86') }}
88109
shell: cmd
89-
run: >-
90-
mkdir dist &&
91-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
110+
run: |
111+
mkdir dist
112+
@REM Use `call` so that we can run sequential gsutil commands on Windows
113+
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
114+
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
115+
116+
@REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
117+
if not "${{ inputs.install-jax-current-commit }}"=="1" (
118+
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
119+
)
120+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
121+
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
122+
run: |
123+
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
124+
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
125+
echo "Skipping the test run."
126+
exit 1
92127
- name: Install Python dependencies
93-
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
128+
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
94129
# Halt for testing
95130
- name: Wait For Connection
96131
uses: google-ml-infra/actions/ci_connection@main

.github/workflows/pytest_cuda.yml

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ on:
3434
type: string
3535
required: true
3636
default: "0"
37+
install-jax-current-commit:
38+
description: "Should the 'jax' package be installed from the current commit?"
39+
type: string
40+
required: true
41+
default: "1"
3742
gcs_download_uri:
3843
description: "GCS location prefix from where the artifacts should be downloaded"
3944
required: true
@@ -47,16 +52,21 @@ on:
4752

4853
jobs:
4954
run-tests:
55+
defaults:
56+
run:
57+
# Explicitly set the shell to bash
58+
shell: bash
5059
runs-on: ${{ inputs.runner }}
5160
# TODO: Update to the generic ML ecosystem test containers when they are ready.
52-
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') ||
53-
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }}
61+
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
62+
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
5463
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
5564

5665
env:
5766
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
5867
JAXCI_PYTHON: "python${{ inputs.python }}"
5968
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
69+
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
6070

6171
steps:
6272
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -73,13 +83,30 @@ jobs:
7383
echo "ARCH=${arch}" >> $GITHUB_ENV
7484
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
7585
- name: Download the wheel artifacts from GCS
76-
run: >-
86+
id: download-wheel-artifacts
87+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
88+
# fails. Instead, we verify the outcome in the next step so that we can print a more
89+
# informative error message.
90+
continue-on-error: true
91+
run: |
7792
mkdir -p $(pwd)/dist &&
7893
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
7994
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
8095
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
96+
97+
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
98+
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
99+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
100+
fi
101+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
102+
if: steps.download-wheel-artifacts.outcome == 'failure'
103+
run: |
104+
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
105+
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
106+
echo "Skipping the test run."
107+
exit 1
81108
- name: Install Python dependencies
82-
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
109+
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
83110
# Halt for testing
84111
- name: Wait For Connection
85112
uses: google-ml-infra/actions/ci_connection@main

0 commit comments

Comments
 (0)