@@ -49,10 +49,18 @@ RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg
4949
5050#  Install google-tunix for TPU devices, skip for GPU
5151RUN  if [ "$DEVICE"  = "tpu"  ]; then \
52+         echo "TPU device detected. Installing google-tunix." ; \
5253        python3 -m pip install 'google-tunix>=0.1.2' ; \
53-         #  TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600)
54-         python3 -m pip install 'jax==0.7.0'  'jaxlib==0.7.0' ; \
55-   fi
54+ \
55+         if [[ "$JAX_AI_IMAGE_BASEIMAGE"  == *"nightly" * ]]; then \
56+             echo "Nightly image detected. Uninstalling base JAX and installing pre-release." ; \
57+             pip uninstall -y jax jaxlib libtpu; \
58+             pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
59+         else \
60+             echo "Non-nightly image. Installing JAX 0.7.0." ; \
61+             python3 -m pip install jax[tpu]==0.7.0; \
62+         fi; \
63+     fi
5664
5765#  Now copy the remaining code (source files that may change frequently)
5866COPY  . .
@@ -68,7 +76,14 @@ RUN if [ "$TEST_TYPE" = "xlml" ] || [ "$TEST_TYPE" = "unit_test" ]; then \
6876    fi
6977
7078#  Run the script available in JAX AI base image to generate the manifest file
71- RUN  bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
79+ RUN  if [ -d "/jax-ai-image"  ]; then \
80+         echo "Found /jax-ai-image directory. Running with 'jax-ai-image' path." ; \
81+         bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \
82+     else \
83+         echo "/jax-ai-image not found. Running with 'jax-stable-stack' path." ; \
84+         bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH; \
85+     fi
86+ 
7287
7388#  Install (editable) MaxText
7489RUN  test -f '/tmp/venv_created'  && "$(tail -n1 /tmp/venv_created)" /bin/activate ; pip install --no-dependencies -e .
0 commit comments