Skip to content

Commit 6ecb399

Browse files
Update patch to ensure maxtext images are downgraded to 0.7.0
Remove extra comment Force install jax nightly only in jaii nightly images Remove duplicate block
1 parent cc9a196 commit 6ecb399

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

.github/workflows/RunTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
device_name: v4-8
5252
cloud_runner: linux-x86-n2-16-buildkit
5353
build_mode: jax_ai_image
54-
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
54+
base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1
5555

5656
gpu_image:
5757
needs: prelim

maxtext_jax_ai_image.Dockerfile

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,18 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg
4949

5050
# Install google-tunix for TPU devices, skip for GPU
5151
RUN if [ "$DEVICE" = "tpu" ]; then \
52+
echo "TPU device detected. Installing google-tunix."; \
5253
python3 -m pip install 'google-tunix>=0.1.2'; \
53-
# TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600)
54-
python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \
55-
fi
54+
\
55+
if [[ "$JAX_AI_IMAGE_BASEIMAGE" == *"nightly"* ]]; then \
56+
echo "Nightly image detected. Uninstalling base JAX and installing pre-release."; \
57+
pip uninstall -y jax jaxlib libtpu; \
58+
pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
59+
else \
60+
echo "Non-nightly image. Installing JAX 0.7.0."; \
61+
python3 -m pip install jax[tpu]==0.7.0; \
62+
fi; \
63+
fi
5664

5765
# Now copy the remaining code (source files that may change frequently)
5866
COPY . .
@@ -68,7 +76,14 @@ RUN if [ "$TEST_TYPE" = "xlml" ] || [ "$TEST_TYPE" = "unit_test" ]; then \
6876
fi
6977

7078
# Run the script available in JAX AI base image to generate the manifest file
71-
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
79+
RUN if [ -d "/jax-ai-image" ]; then \
80+
echo "Found /jax-ai-image directory. Running with 'jax-ai-image' path."; \
81+
bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \
82+
else \
83+
echo "/jax-ai-image not found. Running with 'jax-stable-stack' path."; \
84+
bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \
85+
fi
86+
7287

7388
# Install (editable) MaxText
7489
RUN test -f '/tmp/venv_created' && "$(tail -n1 /tmp/venv_created)"/bin/activate ; pip install --no-dependencies -e .

0 commit comments

Comments
 (0)