diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index 3787f2207..6e5ce347f 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -51,7 +51,7 @@ jobs: device_name: v4-8 cloud_runner: linux-x86-n2-16-buildkit build_mode: jax_ai_image - base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest + base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.0-rev1 gpu_image: needs: prelim diff --git a/maxtext_jax_ai_image.Dockerfile b/maxtext_jax_ai_image.Dockerfile index cd2dd457a..d51d1c324 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/maxtext_jax_ai_image.Dockerfile @@ -49,10 +49,18 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg # Install google-tunix for TPU devices, skip for GPU RUN if [ "$DEVICE" = "tpu" ]; then \ + echo "TPU device detected. Installing google-tunix."; \ python3 -m pip install 'google-tunix>=0.1.2'; \ - # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) - python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ - fi +\ + if [[ "$JAX_AI_IMAGE_BASEIMAGE" == *"nightly"* ]]; then \ + echo "Nightly image detected. Uninstalling base JAX and installing pre-release."; \ + pip uninstall -y jax jaxlib libtpu; \ + pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ + else \ + echo "Non-nightly image. Installing JAX 0.7.0."; \ + python3 -m pip install jax[tpu]==0.7.0; \ + fi; \ + fi # Now copy the remaining code (source files that may change frequently) COPY . . @@ -68,7 +76,14 @@ RUN if [ "$TEST_TYPE" = "xlml" ] || [ "$TEST_TYPE" = "unit_test" ]; then \ fi # Run the script available in JAX AI base image to generate the manifest file -RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH +RUN if [ -d "/jax-ai-image" ]; then \ + echo "Found /jax-ai-image directory. Running with 'jax-ai-image' path."; \ + bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \ + else \ + echo "/jax-ai-image not found. Running with 'jax-stable-stack' path."; \ + bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \ + fi + # Install (editable) MaxText RUN test -f '/tmp/venv_created' && "$(tail -n1 /tmp/venv_created)"/bin/activate ; pip install --no-dependencies -e .