Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,26 @@ jobs:
include:
# TPU Image Builds
- image_name: maxtext_jax_stable
dockerfile: ./maxtext_dependencies.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
build_args: |
MODE=stable
JAX_VERSION=NONE
LIBTPU_GCS_PATH=NONE
- image_name: maxtext_jax_nightly
dockerfile: ./maxtext_dependencies.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile
build_args: |
MODE=nightly
JAX_VERSION=NONE
LIBTPU_GCS_PATH=NONE
# TPU Image builds using JAX AI Image
- image_name: maxtext_jax_stable_stack
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
- image_name: maxtext_stable_stack_nightly_jax
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
- image_name: maxtext_stable_stack_candidate
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest

# Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90
Expand Down Expand Up @@ -130,13 +130,13 @@ jobs:
# GPU Image Builds using JAX AI Image
include:
- image_name: maxtext_gpu_jax_stable_stack
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:latest
- image_name: maxtext_gpu_stable_stack_nightly_jax
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest
- image_name: maxtext_stable_stack_candidate_gpu
dockerfile: ./maxtext_jax_ai_image.Dockerfile
dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest

steps:
Expand Down
13 changes: 11 additions & 2 deletions .github/workflows/build_and_upload_images.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@
# Example command:
# bash build_and_upload_images.sh PROJECT=<project> MODE=stable DEVICE=tpu CLOUD_IMAGE_NAME=${USER}_runner

if [ "${BASH_SOURCE-}" ]; then
this_file="${BASH_SOURCE[0]}"
elif [ "${ZSH_VERSION-}" ]; then
# shellcheck disable=SC2296
this_file="${(%):-%x}"
else
this_file="${0}"
fi
Comment on lines +27 to +34
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Can we remove this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we need it otherwise the script doesn't know where it is. In Python you have __file__ magic; this is how you get similar in shell.

Paths relative to this enable always-correct specification of [later] dockerfile location.


MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../..' && pwd)}"
export LOCAL_IMAGE_NAME=maxtext_base_image

# Set environment variables
Expand All @@ -40,7 +49,7 @@ if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] || [[ !
fi

gcloud auth configure-docker us-docker.pkg.dev --quiet
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
bash "$MAXTEXT_REPO_ROOT"'/dependencies/scripts/docker_build_dependency_image.sh' LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE="$MODE" DEVICE="$DEVICE"
image_date=$(date +%Y-%m-%d)

# Upload only dependencies image
Expand All @@ -56,7 +65,7 @@ if ! gcloud storage cp gs://maxtext-test-assets/* "${MAXTEXT_TEST_ASSETS_ROOT:-$
fi

# Build then upload "dependencies + code" image
docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxtext_runner.Dockerfile -t ${LOCAL_IMAGE_NAME}_runner .
docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_runner.Dockerfile' -t ${LOCAL_IMAGE_NAME}_runner .
docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest
docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest
docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:${image_date}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_upload_internal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
with:
push: true
context: .
file: ./maxtext_jax_ai_image.Dockerfile
file: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile
tags: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }}
provenance: false
build-args: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/check_docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
cache: 'pip' # caching pip dependencies

- name: Install dependencies
run: pip install -r requirements_docs.txt
run: pip install -r dependencies/requirements/requirements_docs.txt

- name: Build documentation
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_against_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
source .venv/bin/activate
maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest
install_maxtext_github_deps
uv pip install -r dependencies/requirements/extra_deps_from_github.txt
python3 --version
python3 -m pip freeze
- name: Copy test assets files
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ pip install uv

# 3. Install MaxText and its dependencies
uv pip install maxtext --resolution=lowest
install_maxtext_github_deps
uv pip install -r https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/refs/heads/main/dependencies/requirements/extra_deps_from_github.txt
```
> **Note:** The `install_maxtext_github_deps` command is temporarily required to install dependencies directly from GitHub that are not yet available on PyPI.
> **Note:** The `extra_deps_from_github.txt` is temporarily required to install dependencies directly from GitHub that are not yet available on PyPI.

> **Note:** The maxtext package contains a comprehensive list of all direct and transitive dependencies, with lower bounds, generated by [seed-env](https://github.com/google-ml-infra/actions/tree/main/python_seed_env). We highly recommend the `--resolution=lowest` flag. It instructs `uv` to install the specific, tested versions of dependencies defined by MaxText, rather than the latest available ones. This ensures a consistent and reproducible environment, which is critical for stable performance and for running benchmarks.

Expand All @@ -69,7 +69,7 @@ pip install uv
uv pip install -e .[tpu] --resolution=lowest
# or install the gpu package by running the following line
# uv pip install -e .[cuda12] --resolution=lowest
install_maxtext_github_deps
uv pip install -r dependencies/requirements/extra_deps_from_github.txt
```

After installation, you can verify the package is available with `python3 -c "import MaxText"` and run training jobs with `python3 -m MaxText.train ...`.
Expand Down
43 changes: 0 additions & 43 deletions base_requirements/requirements.txt

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ git clone https://github.com/AI-Hypercomputer/maxtext.git

RUN cd maxtext/ && \
git checkout ${MAXTEXT_VERSION} && \
bash setup.sh
bash ./tools/setup/setup.sh

RUN cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ git clone https://github.com/AI-Hypercomputer/JetStream.git

RUN cd maxtext/ && \
git checkout ${MAXTEXT_VERSION} && \
bash setup.sh
bash ./tools/setup/setup.sh

RUN cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
Expand All @@ -45,4 +45,4 @@ COPY maxengine_server_entrypoint.sh /usr/bin/

RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh

ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"]
ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"]
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ ENV MAXTEXT_REPO_ROOT=/deps
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY setup.sh ./
COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./
COPY tools/setup /deps/tools/setup/
COPY dependencies/requirements/ /deps/dependencies/requirements/

# Install dependencies - these steps are cached unless the copied files change
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}

# Now copy the remaining code (source files that may change frequently)
COPY . .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ ENV MAXTEXT_REPO_ROOT=/deps
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY setup.sh ./
COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./
COPY tools/setup /deps/tools/setup/
COPY dependencies/requirements/ /deps/dependencies/requirements/

# Install dependencies - these steps are cached unless the copied files change
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}

# Now copy the remaining code (source files that may change frequently)
COPY . .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ ENV MAXTEXT_REPO_ROOT=/deps
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY setup.sh ./
COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./
COPY tools/setup /deps/tools/setup/
COPY dependencies/requirements/ /deps/dependencies/requirements/

# Install dependencies - these steps are cached unless the copied files change
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}

# Now copy the remaining code (source files that may change frequently)
COPY . .
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ARG JAX_AI_IMAGE_BASEIMAGE

# JAX AI Base Image
FROM $JAX_AI_IMAGE_BASEIMAGE

Check warning on line 4 in dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile

View workflow job for this annotation

GitHub Actions / tpu_image / Build and upload image (v4-8)

Default value for global ARG results in an empty or invalid base image name

InvalidDefaultArgInFrom: Default value for ARG $JAX_AI_IMAGE_BASEIMAGE results in empty or invalid base image name More info: https://docs.docker.com/go/dockerfile/rule/invalid-default-arg-in-from/

Check warning on line 4 in dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile

View workflow job for this annotation

GitHub Actions / gpu_image / Build and upload image (a100-40gb-4)

Default value for global ARG results in an empty or invalid base image name

InvalidDefaultArgInFrom: Default value for ARG $JAX_AI_IMAGE_BASEIMAGE results in empty or invalid base image name More info: https://docs.docker.com/go/dockerfile/rule/invalid-default-arg-in-from/
ARG JAX_AI_IMAGE_BASEIMAGE

ARG COMMIT_HASH
Expand All @@ -16,9 +16,8 @@
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY setup.sh ./
COPY requirements.txt requirements_with_jax_ai_image.txt requirements_with_jax_stable_stack_0_6_1_pipreqs.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./

COPY tools/setup /deps/tools/setup/
COPY dependencies/requirements/ /deps/dependencies/requirements/

# For JAX AI tpu training images 0.4.37 AND 0.4.35
# Orbax checkpoint installs the latest version of JAX,
Expand All @@ -42,9 +41,9 @@
# pipreqs --savepath requirements_with_jax_stable_stack_0_6_1_pipreqs.txt
# Otherwise use general requirements_with_jax_ai_image.txt
RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.6.1-rev1" ]; then \
python3 -m pip install -r /deps/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \
python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \
else \
python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt; \
python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_ai_image.txt; \
fi

# Install google-tunix for TPU devices, skip for GPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ WORKDIR /deps
# Copy assets separately
COPY src/MaxText/assets/ "${MAXTEXT_ASSETS_ROOT}"
COPY src/MaxText/test_assets/ "${MAXTEXT_TEST_ASSETS_ROOT}"
COPY generated_requirements .

# Copy all files except assets from local workspace into docker container
COPY --exclude="${MAXTEXT_ASSETS_ROOT}" --exclude="${MAXTEXT_TEST_ASSETS_ROOT}" . .
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@

# bash docker_build_dependency_image.sh MODE=post-training

if [ "${BASH_SOURCE-}" ]; then
this_file="${BASH_SOURCE[0]}"
elif [ "${ZSH_VERSION-}" ]; then
# shellcheck disable=SC2296
this_file="${(%):-%x}"
else
this_file="${0}"
fi

MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../..' && pwd)}"

# Enable "exit immediately if any command fails" option
set -e

Expand Down Expand Up @@ -101,7 +112,8 @@ build_ai_image() {
--build-arg DEVICE="$DEVICE" \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f ./maxtext_jax_ai_image.Dockerfile .
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile' \
.
}

if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
Expand All @@ -116,24 +128,41 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \
--build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
fi
else
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
build_ai_image
elif [[ ${MANTARAY} == "true" ]]; then
echo "Building with benchmark-db"
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \
--build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then
echo "Installing MaxText stable mode dependencies for Post-Training"
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
echo "Installing MaxText stable mode dependencies for GRPO"
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION \
--build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \
--build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
fi
fi
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \
--build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg CUSTOM_LIBTPU=true \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
fi

if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
Expand All @@ -158,13 +187,15 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
--network host \
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
--build-arg MODE=${MODE} \
-f ./maxtext_post_training_dependencies.Dockerfile \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
fi

if [[ ${CUSTOM_JAX} -eq 1 ]] ; then
echo "Installing custom jax and jaxlib"
docker build --network host -f ./maxtext_custom_wheels.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
fi

echo ""
Expand Down
Loading