From c8b8de8b1a77767eadffeb9caea00ac5053a7296 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Wed, 29 Oct 2025 10:33:37 -0700 Subject: [PATCH] Refactor: 3.1 phase of RESTRUCTURE.md, focussing on setup.sh and related Python dependency & Dockerfile files --- .github/workflows/UploadDockerImages.yml | 16 ++-- .github/workflows/build_and_upload_images.sh | 13 ++- .github/workflows/build_upload_internal.yml | 2 +- .github/workflows/check_docs_build.yml | 2 +- .../workflows/run_tests_against_package.yml | 2 +- README.md | 6 +- base_requirements/requirements.txt | 43 --------- .../dockerfiles/clean_py_env.Dockerfile | 0 .../dockerfiles/clean_py_env_gpu.Dockerfile | 0 .../dockerfiles/jetstream_pathways.Dockerfile | 2 +- .../dockerfiles/maxengine_server.Dockerfile | 4 +- .../maxtext_custom_wheels.Dockerfile | 0 .../maxtext_db_dependencies.Dockerfile | 6 +- .../maxtext_dependencies.Dockerfile | 6 +- .../maxtext_gpu_dependencies.Dockerfile | 6 +- .../maxtext_jax_ai_image.Dockerfile | 9 +- .../maxtext_libtpu_path.Dockerfile | 0 ...text_post_training_dependencies.Dockerfile | 0 .../dockerfiles/maxtext_runner.Dockerfile | 1 - .../requirements}/cuda12-requirements.txt | 0 .../requirements}/extra_deps_from_github.txt | 0 .../requirements}/gpu-base-requirements.txt | 0 .../requirements/requirements.txt | 0 .../requirements/requirements_docs.txt | 0 .../requirements_with_jax_ai_image.txt | 0 ...ts_with_jax_stable_stack_0_6_1_pipreqs.txt | 0 .../requirements}/tpu-base-requirements.txt | 0 .../requirements}/tpu-requirements.txt | 0 .../scripts/docker_build_dependency_image.sh | 51 +++++++++-- .../scripts/docker_upload_runner.sh | 15 ++- pyproject.toml | 9 +- src/MaxText/examples/demo_decoding.ipynb | 2 +- src/MaxText/examples/sft_llama3_demo.ipynb | 2 +- src/MaxText/examples/sft_qwen3_demo.ipynb | 2 +- src/install_maxtext_extra_deps/__init__.py | 13 --- .../install_github_deps.py | 91 ------------------- setup.sh => tools/setup/setup.sh | 59 ++++++++---- .../setup/setup_gcsfuse.sh | 0 .../setup/setup_with_retries.sh | 0 39 files changed, 141 insertions(+), 221 deletions(-) delete mode 100644 base_requirements/requirements.txt rename clean_py_env.Dockerfile => dependencies/dockerfiles/clean_py_env.Dockerfile (100%) rename clean_py_env_gpu.Dockerfile => dependencies/dockerfiles/clean_py_env_gpu.Dockerfile (100%) rename src/MaxText/inference/jetstream_pathways/Dockerfile => dependencies/dockerfiles/jetstream_pathways.Dockerfile (98%) rename src/MaxText/inference/maxengine_server/Dockerfile => dependencies/dockerfiles/maxengine_server.Dockerfile (94%) rename maxtext_custom_wheels.Dockerfile => dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile (100%) rename maxtext_db_dependencies.Dockerfile => dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile (85%) rename maxtext_dependencies.Dockerfile => dependencies/dockerfiles/maxtext_dependencies.Dockerfile (85%) rename maxtext_gpu_dependencies.Dockerfile => dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile (87%) rename maxtext_jax_ai_image.Dockerfile => dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile (87%) rename maxtext_libtpu_path.Dockerfile => dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile (100%) rename maxtext_post_training_dependencies.Dockerfile => dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile (100%) rename maxtext_runner.Dockerfile => dependencies/dockerfiles/maxtext_runner.Dockerfile (95%) rename {generated_requirements => dependencies/requirements}/cuda12-requirements.txt (100%) rename {src/install_maxtext_extra_deps => dependencies/requirements}/extra_deps_from_github.txt (100%) rename {base_requirements => dependencies/requirements}/gpu-base-requirements.txt (100%) rename requirements.txt => dependencies/requirements/requirements.txt (100%) rename requirements_docs.txt => dependencies/requirements/requirements_docs.txt (100%) rename requirements_with_jax_ai_image.txt => dependencies/requirements/requirements_with_jax_ai_image.txt (100%) rename requirements_with_jax_stable_stack_0_6_1_pipreqs.txt => dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt (100%) rename {base_requirements => dependencies/requirements}/tpu-base-requirements.txt (100%) rename {generated_requirements => dependencies/requirements}/tpu-requirements.txt (100%) rename docker_build_dependency_image.sh => dependencies/scripts/docker_build_dependency_image.sh (74%) rename docker_upload_runner.sh => dependencies/scripts/docker_upload_runner.sh (88%) delete mode 100644 src/install_maxtext_extra_deps/__init__.py delete mode 100644 src/install_maxtext_extra_deps/install_github_deps.py rename setup.sh => tools/setup/setup.sh (84%) rename setup_gcsfuse.sh => tools/setup/setup_gcsfuse.sh (100%) rename setup_with_retries.sh => tools/setup/setup_with_retries.sh (100%) diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index eb4c769678..76945ed6e8 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -52,26 +52,26 @@ jobs: include: # TPU Image Builds - image_name: maxtext_jax_stable - dockerfile: ./maxtext_dependencies.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile build_args: | MODE=stable JAX_VERSION=NONE LIBTPU_GCS_PATH=NONE - image_name: maxtext_jax_nightly - dockerfile: ./maxtext_dependencies.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile build_args: | MODE=nightly JAX_VERSION=NONE LIBTPU_GCS_PATH=NONE # TPU Image builds using JAX AI Image - image_name: maxtext_jax_stable_stack - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest - image_name: maxtext_stable_stack_nightly_jax - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest - image_name: maxtext_stable_stack_candidate - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest # Setup for GKE runners per b/412986220#comment82 and b/412986220#comment90 @@ -130,13 +130,13 @@ jobs: # GPU Image Builds using JAX AI Image include: - image_name: maxtext_gpu_jax_stable_stack - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:latest - image_name: maxtext_gpu_stable_stack_nightly_jax - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest - image_name: maxtext_stable_stack_candidate_gpu - dockerfile: ./maxtext_jax_ai_image.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest steps: diff --git a/.github/workflows/build_and_upload_images.sh b/.github/workflows/build_and_upload_images.sh index 5690abd9df..dd9c974207 100644 --- a/.github/workflows/build_and_upload_images.sh +++ b/.github/workflows/build_and_upload_images.sh @@ -24,7 +24,16 @@ # Example command: # bash build_and_upload_images.sh PROJECT= MODE=stable DEVICE=tpu CLOUD_IMAGE_NAME=${USER}_runner +if [ "${BASH_SOURCE-}" ]; then + this_file="${BASH_SOURCE[0]}" +elif [ "${ZSH_VERSION-}" ]; then + # shellcheck disable=SC2296 + this_file="${(%):-%x}" +else + this_file="${0}" +fi +MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../..' && pwd)}" export LOCAL_IMAGE_NAME=maxtext_base_image # Set environment variables @@ -40,7 +49,7 @@ if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] || [[ ! fi gcloud auth configure-docker us-docker.pkg.dev --quiet -bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE +bash "$MAXTEXT_REPO_ROOT"'/dependencies/scripts/docker_build_dependency_image.sh' LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE="$MODE" DEVICE="$DEVICE" image_date=$(date +%Y-%m-%d) # Upload only dependencies image @@ -56,7 +65,7 @@ if ! gcloud storage cp gs://maxtext-test-assets/* "${MAXTEXT_TEST_ASSETS_ROOT:-$ fi # Build then upload "dependencies + code" image -docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxtext_runner.Dockerfile -t ${LOCAL_IMAGE_NAME}_runner . +docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_runner.Dockerfile' -t ${LOCAL_IMAGE_NAME}_runner . docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest docker tag ${LOCAL_IMAGE_NAME}_runner gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:${image_date} diff --git a/.github/workflows/build_upload_internal.yml b/.github/workflows/build_upload_internal.yml index 3a073ec5ad..0fba599642 100644 --- a/.github/workflows/build_upload_internal.yml +++ b/.github/workflows/build_upload_internal.yml @@ -59,7 +59,7 @@ jobs: with: push: true context: . - file: ./maxtext_jax_ai_image.Dockerfile + file: ./dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile tags: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ inputs.device_type }} provenance: false build-args: | diff --git a/.github/workflows/check_docs_build.yml b/.github/workflows/check_docs_build.yml index 880480c4d4..393bd98fd2 100644 --- a/.github/workflows/check_docs_build.yml +++ b/.github/workflows/check_docs_build.yml @@ -24,7 +24,7 @@ jobs: cache: 'pip' # caching pip dependencies - name: Install dependencies - run: pip install -r requirements_docs.txt + run: pip install -r dependencies/requirements/requirements_docs.txt - name: Build documentation run: | diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 6076efd38a..2e7683e40e 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -85,7 +85,7 @@ jobs: source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest - install_maxtext_github_deps + uv pip install -r dependencies/requirements/extra_deps_from_github.txt python3 --version python3 -m pip freeze - name: Copy test assets files diff --git a/README.md b/README.md index 25c8db54fd..3dd78949c9 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,9 @@ pip install uv # 3. Install MaxText and its dependencies uv pip install maxtext --resolution=lowest -install_maxtext_github_deps +uv pip install -r https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/refs/heads/main/dependencies/requirements/extra_deps_from_github.txt ``` -> **Note:** The `install_maxtext_github_deps` command is temporarily required to install dependencies directly from GitHub that are not yet available on PyPI. +> **Note:** The `extra_deps_from_github.txt` is temporarily required to install dependencies directly from GitHub that are not yet available on PyPI. > **Note:** The maxtext package contains a comprehensive list of all direct and transitive dependencies, with lower bounds, generated by [seed-env](https://github.com/google-ml-infra/actions/tree/main/python_seed_env). We highly recommend the `--resolution=lowest` flag. It instructs `uv` to install the specific, tested versions of dependencies defined by MaxText, rather than the latest available ones. This ensures a consistent and reproducible environment, which is critical for stable performance and for running benchmarks. @@ -69,7 +69,7 @@ pip install uv uv pip install -e .[tpu] --resolution=lowest # or install the gpu package by running the following line # uv pip install -e .[cuda12] --resolution=lowest -install_maxtext_github_deps +uv pip install -r dependencies/requirements/extra_deps_from_github.txt ``` After installation, you can verify the package is available with `python3 -c "import MaxText"` and run training jobs with `python3 -m MaxText.train ...`. diff --git a/base_requirements/requirements.txt b/base_requirements/requirements.txt deleted file mode 100644 index bc3e683451..0000000000 --- a/base_requirements/requirements.txt +++ /dev/null @@ -1,43 +0,0 @@ -absl-py -aqtp -array-record -cloud-accelerator-diagnostics -cloud-tpu-diagnostics -datasets -flax -gcsfs -google-api-python-client -google-cloud-aiplatform -google-cloud-monitoring -grain[parquet] -huggingface_hub -jax -jaxlib -jaxtyping -jsonlines -ml-collections -ml-goodput-measurement -numpy -omegaconf -optax -orbax-checkpoint -pathwaysutils -pillow -pre-commit -protobuf -pyink -pylint -pytest -pytype -sentencepiece -tensorboard-plugin-profile -tensorboardx -tensorflow-datasets -tensorflow-text -tensorflow -tiktoken -tokamax -transformers -qwix -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip -mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/clean_py_env.Dockerfile b/dependencies/dockerfiles/clean_py_env.Dockerfile similarity index 100% rename from clean_py_env.Dockerfile rename to dependencies/dockerfiles/clean_py_env.Dockerfile diff --git a/clean_py_env_gpu.Dockerfile b/dependencies/dockerfiles/clean_py_env_gpu.Dockerfile similarity index 100% rename from clean_py_env_gpu.Dockerfile rename to dependencies/dockerfiles/clean_py_env_gpu.Dockerfile diff --git a/src/MaxText/inference/jetstream_pathways/Dockerfile b/dependencies/dockerfiles/jetstream_pathways.Dockerfile similarity index 98% rename from src/MaxText/inference/jetstream_pathways/Dockerfile rename to dependencies/dockerfiles/jetstream_pathways.Dockerfile index 8fd5dde42c..669961b8af 100644 --- a/src/MaxText/inference/jetstream_pathways/Dockerfile +++ b/dependencies/dockerfiles/jetstream_pathways.Dockerfile @@ -38,7 +38,7 @@ git clone https://github.com/AI-Hypercomputer/maxtext.git RUN cd maxtext/ && \ git checkout ${MAXTEXT_VERSION} && \ -bash setup.sh +bash ./tools/setup/setup.sh RUN cd /JetStream && \ git checkout ${JETSTREAM_VERSION} && \ diff --git a/src/MaxText/inference/maxengine_server/Dockerfile b/dependencies/dockerfiles/maxengine_server.Dockerfile similarity index 94% rename from src/MaxText/inference/maxengine_server/Dockerfile rename to dependencies/dockerfiles/maxengine_server.Dockerfile index ca5efeed27..c0fe5a3820 100644 --- a/src/MaxText/inference/maxengine_server/Dockerfile +++ b/dependencies/dockerfiles/maxengine_server.Dockerfile @@ -35,7 +35,7 @@ git clone https://github.com/AI-Hypercomputer/JetStream.git RUN cd maxtext/ && \ git checkout ${MAXTEXT_VERSION} && \ -bash setup.sh +bash ./tools/setup/setup.sh RUN cd /JetStream && \ git checkout ${JETSTREAM_VERSION} && \ @@ -45,4 +45,4 @@ COPY maxengine_server_entrypoint.sh /usr/bin/ RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh -ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"] diff --git a/maxtext_custom_wheels.Dockerfile b/dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile similarity index 100% rename from maxtext_custom_wheels.Dockerfile rename to dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile diff --git a/maxtext_db_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile similarity index 85% rename from maxtext_db_dependencies.Dockerfile rename to dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile index ab997e2e31..cf1e966697 100644 --- a/maxtext_db_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile @@ -40,12 +40,12 @@ ENV MAXTEXT_REPO_ROOT=/deps WORKDIR /deps # Copy setup files and dependency files separately for better caching -COPY setup.sh ./ -COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./ +COPY tools/setup /deps/tools/setup/ +COPY dependencies/requirements/ /deps/dependencies/requirements/ # Install dependencies - these steps are cached unless the copied files change RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} +RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} # Now copy the remaining code (source files that may change frequently) COPY . . diff --git a/maxtext_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_dependencies.Dockerfile similarity index 85% rename from maxtext_dependencies.Dockerfile rename to dependencies/dockerfiles/maxtext_dependencies.Dockerfile index 9843ba896c..ef72c9fe9c 100644 --- a/maxtext_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_dependencies.Dockerfile @@ -40,12 +40,12 @@ ENV MAXTEXT_REPO_ROOT=/deps WORKDIR /deps # Copy setup files and dependency files separately for better caching -COPY setup.sh ./ -COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./ +COPY tools/setup /deps/tools/setup/ +COPY dependencies/requirements/ /deps/dependencies/requirements/ # Install dependencies - these steps are cached unless the copied files change RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} +RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} # Now copy the remaining code (source files that may change frequently) COPY . . diff --git a/maxtext_gpu_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile similarity index 87% rename from maxtext_gpu_dependencies.Dockerfile rename to dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile index 474a6e8583..e5c89e7708 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile @@ -42,12 +42,12 @@ ENV MAXTEXT_REPO_ROOT=/deps WORKDIR /deps # Copy setup files and dependency files separately for better caching -COPY setup.sh ./ -COPY requirements.txt requirements_with_jax_ai_image.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./ +COPY tools/setup /deps/tools/setup/ +COPY dependencies/requirements/ /deps/dependencies/requirements/ # Install dependencies - these steps are cached unless the copied files change RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} +RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} # Now copy the remaining code (source files that may change frequently) COPY . . diff --git a/maxtext_jax_ai_image.Dockerfile b/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile similarity index 87% rename from maxtext_jax_ai_image.Dockerfile rename to dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile index 0d4942bdcf..0a7504d3bd 100644 --- a/maxtext_jax_ai_image.Dockerfile +++ b/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile @@ -16,9 +16,8 @@ ENV MAXTEXT_REPO_ROOT=/deps WORKDIR /deps # Copy setup files and dependency files separately for better caching -COPY setup.sh ./ -COPY requirements.txt requirements_with_jax_ai_image.txt requirements_with_jax_stable_stack_0_6_1_pipreqs.txt src/install_maxtext_extra_deps/extra_deps_from_github.txt generated_requirements ./ - +COPY tools/setup /deps/tools/setup/ +COPY dependencies/requirements/ /deps/dependencies/requirements/ # For JAX AI tpu training images 0.4.37 AND 0.4.35 # Orbax checkpoint installs the latest version of JAX, @@ -42,9 +41,9 @@ RUN pip install google-cloud-monitoring # pipreqs --savepath requirements_with_jax_stable_stack_0_6_1_pipreqs.txt # Otherwise use general requirements_with_jax_ai_image.txt RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.6.1-rev1" ]; then \ - python3 -m pip install -r /deps/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \ + python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \ else \ - python3 -m pip install -r /deps/requirements_with_jax_ai_image.txt; \ + python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_ai_image.txt; \ fi # Install google-tunix for TPU devices, skip for GPU diff --git a/maxtext_libtpu_path.Dockerfile b/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile similarity index 100% rename from maxtext_libtpu_path.Dockerfile rename to dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile diff --git a/maxtext_post_training_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile similarity index 100% rename from maxtext_post_training_dependencies.Dockerfile rename to dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile diff --git a/maxtext_runner.Dockerfile b/dependencies/dockerfiles/maxtext_runner.Dockerfile similarity index 95% rename from maxtext_runner.Dockerfile rename to dependencies/dockerfiles/maxtext_runner.Dockerfile index 4f499bb7e5..151ac3c353 100644 --- a/maxtext_runner.Dockerfile +++ b/dependencies/dockerfiles/maxtext_runner.Dockerfile @@ -16,7 +16,6 @@ WORKDIR /deps # Copy assets separately COPY src/MaxText/assets/ "${MAXTEXT_ASSETS_ROOT}" COPY src/MaxText/test_assets/ "${MAXTEXT_TEST_ASSETS_ROOT}" -COPY generated_requirements . # Copy all files except assets from local workspace into docker container COPY --exclude="${MAXTEXT_ASSETS_ROOT}" --exclude="${MAXTEXT_TEST_ASSETS_ROOT}" . . diff --git a/generated_requirements/cuda12-requirements.txt b/dependencies/requirements/cuda12-requirements.txt similarity index 100% rename from generated_requirements/cuda12-requirements.txt rename to dependencies/requirements/cuda12-requirements.txt diff --git a/src/install_maxtext_extra_deps/extra_deps_from_github.txt b/dependencies/requirements/extra_deps_from_github.txt similarity index 100% rename from src/install_maxtext_extra_deps/extra_deps_from_github.txt rename to dependencies/requirements/extra_deps_from_github.txt diff --git a/base_requirements/gpu-base-requirements.txt b/dependencies/requirements/gpu-base-requirements.txt similarity index 100% rename from base_requirements/gpu-base-requirements.txt rename to dependencies/requirements/gpu-base-requirements.txt diff --git a/requirements.txt b/dependencies/requirements/requirements.txt similarity index 100% rename from requirements.txt rename to dependencies/requirements/requirements.txt diff --git a/requirements_docs.txt b/dependencies/requirements/requirements_docs.txt similarity index 100% rename from requirements_docs.txt rename to dependencies/requirements/requirements_docs.txt diff --git a/requirements_with_jax_ai_image.txt b/dependencies/requirements/requirements_with_jax_ai_image.txt similarity index 100% rename from requirements_with_jax_ai_image.txt rename to dependencies/requirements/requirements_with_jax_ai_image.txt diff --git a/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt b/dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt similarity index 100% rename from requirements_with_jax_stable_stack_0_6_1_pipreqs.txt rename to dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt diff --git a/base_requirements/tpu-base-requirements.txt b/dependencies/requirements/tpu-base-requirements.txt similarity index 100% rename from base_requirements/tpu-base-requirements.txt rename to dependencies/requirements/tpu-base-requirements.txt diff --git a/generated_requirements/tpu-requirements.txt b/dependencies/requirements/tpu-requirements.txt similarity index 100% rename from generated_requirements/tpu-requirements.txt rename to dependencies/requirements/tpu-requirements.txt diff --git a/docker_build_dependency_image.sh b/dependencies/scripts/docker_build_dependency_image.sh similarity index 74% rename from docker_build_dependency_image.sh rename to dependencies/scripts/docker_build_dependency_image.sh index 0ef64ad55a..695ee29b41 100644 --- a/docker_build_dependency_image.sh +++ b/dependencies/scripts/docker_build_dependency_image.sh @@ -29,6 +29,17 @@ # bash docker_build_dependency_image.sh MODE=post-training +if [ "${BASH_SOURCE-}" ]; then + this_file="${BASH_SOURCE[0]}" +elif [ "${ZSH_VERSION-}" ]; then + # shellcheck disable=SC2296 + this_file="${(%):-%x}" +else + this_file="${0}" +fi + +MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../..' && pwd)}" + # Enable "exit immediately if any command fails" option set -e @@ -101,7 +112,8 @@ build_ai_image() { --build-arg DEVICE="$DEVICE" \ --network=host \ -t ${LOCAL_IMAGE_NAME} \ - -f ./maxtext_jax_ai_image.Dockerfile . + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile' \ + . } if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then @@ -116,24 +128,41 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then else export BASEIMAGE=ghcr.io/nvidia/jax:base fi - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ + --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . fi else if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then build_ai_image elif [[ ${MANTARAY} == "true" ]]; then echo "Building with benchmark-db" - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ + --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then - echo "Installing MaxText stable mode dependencies for Post-Training" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + echo "Installing MaxText stable mode dependencies for GRPO" + docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION \ + --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . else - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ + --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . fi fi else - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . - docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ + --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg CUSTOM_LIBTPU=true \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . fi if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then @@ -158,13 +187,15 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then --network host \ --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ --build-arg MODE=${MODE} \ - -f ./maxtext_post_training_dependencies.Dockerfile \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \ -t ${LOCAL_IMAGE_NAME} . fi if [[ ${CUSTOM_JAX} -eq 1 ]] ; then echo "Installing custom jax and jaxlib" - docker build --network host -f ./maxtext_custom_wheels.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME} . fi echo "" diff --git a/docker_upload_runner.sh b/dependencies/scripts/docker_upload_runner.sh similarity index 88% rename from docker_upload_runner.sh rename to dependencies/scripts/docker_upload_runner.sh index a23628f590..e343006767 100644 --- a/docker_upload_runner.sh +++ b/dependencies/scripts/docker_upload_runner.sh @@ -23,6 +23,17 @@ # Example command: # bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner +if [ "${BASH_SOURCE-}" ]; then + this_file="${BASH_SOURCE[0]}" +elif [ "${ZSH_VERSION-}" ]; then + # shellcheck disable=SC2296 + this_file="${(%):-%x}" +else + this_file="${0}" +fi + +MAXTEXT_REPO_ROOT="${MAXTEXT_REPO_ROOT:-$(CDPATH='' cd -- "$(dirname -- "${this_file}")"'/../..' && pwd)}" + set -e export LOCAL_IMAGE_NAME=maxtext_base_image @@ -77,7 +88,9 @@ if ! docker image inspect "${LOCAL_IMAGE_NAME}" &> /dev/null; then exit 1 fi -docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxtext_runner.Dockerfile -t ${LOCAL_IMAGE_NAME_RUNNER} . +docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_runner.Dockerfile' \ + -t ${LOCAL_IMAGE_NAME_RUNNER} . docker tag ${LOCAL_IMAGE_NAME_RUNNER} gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest diff --git a/pyproject.toml b/pyproject.toml index 974c80313c..fda074c35b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,8 @@ classifiers = [ dependencies = [] [tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] -tpu = ["generated_requirements/tpu-requirements.txt"] -cuda12 = ["generated_requirements/cuda12-requirements.txt"] +tpu = ["dependencies/requirements/tpu-requirements.txt"] +cuda12 = ["dependencies/requirements/cuda12-requirements.txt"] [project.urls] Repository = "https://github.com/AI-Hypercomputer/maxtext.git" @@ -37,7 +37,4 @@ Repository = "https://github.com/AI-Hypercomputer/maxtext.git" allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["src/MaxText", "src/install_maxtext_extra_deps"] - -[project.scripts] -install_maxtext_github_deps = "install_maxtext_extra_deps.install_github_deps:main" +packages = ["src/MaxText"] diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/MaxText/examples/demo_decoding.ipynb index 2c86715bb8..4fed76d5f9 100644 --- a/src/MaxText/examples/demo_decoding.ipynb +++ b/src/MaxText/examples/demo_decoding.ipynb @@ -86,7 +86,7 @@ "\n", "# Install MaxText and dependencies\n", "!uv pip install maxtext --resolution=lowest\n", - "!install_maxtext_github_deps\n", + "!uv pip install -r https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/refs/heads/main/dependencies/requirements/extra_deps_from_github.txt\n", "\n", "# Use nest_asyncio to allow nested event loops in notebooks\n", "!uv pip install nest_asyncio\n", diff --git a/src/MaxText/examples/sft_llama3_demo.ipynb b/src/MaxText/examples/sft_llama3_demo.ipynb index de7d166b37..6ff070b3e5 100644 --- a/src/MaxText/examples/sft_llama3_demo.ipynb +++ b/src/MaxText/examples/sft_llama3_demo.ipynb @@ -109,7 +109,7 @@ "\n", "# 2. Install MaxText and its dependencies\n", "!uv pip install maxtext --resolution=lowest\n", - "!install_maxtext_github_deps" + "!uv pip install -r https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/refs/heads/main/dependencies/requirements/extra_deps_from_github.txt" ] }, { diff --git a/src/MaxText/examples/sft_qwen3_demo.ipynb b/src/MaxText/examples/sft_qwen3_demo.ipynb index 0f6b154c2a..e1b237127d 100644 --- a/src/MaxText/examples/sft_qwen3_demo.ipynb +++ b/src/MaxText/examples/sft_qwen3_demo.ipynb @@ -109,7 +109,7 @@ "\n", "# 2. Install MaxText and its dependencies\n", "!uv pip install maxtext --resolution=lowest\n", - "!install_maxtext_github_deps" + "!uv pip install -r https://raw.githubusercontent.com/AI-Hypercomputer/maxtext/refs/heads/main/dependencies/requirements/extra_deps_from_github.txt" ] }, { diff --git a/src/install_maxtext_extra_deps/__init__.py b/src/install_maxtext_extra_deps/__init__.py deleted file mode 100644 index 715cb9c6e3..0000000000 --- a/src/install_maxtext_extra_deps/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/install_maxtext_extra_deps/install_github_deps.py b/src/install_maxtext_extra_deps/install_github_deps.py deleted file mode 100644 index 10b0b6d632..0000000000 --- a/src/install_maxtext_extra_deps/install_github_deps.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Installs extra dependencies from a requirements file using uv. - -This script is designed to be run to install dependencies specified in -'extra_deps_from_github.txt', which is expected to be in the same directory. -It first ensures 'uv' is installed and then uses it to install the packages -listed in the requirements file. -""" - -import subprocess -import sys -from pathlib import Path - - -def main(): - """ - Installs extra dependencies specified in extra_deps_from_github.txt using uv. - - This script looks for 'extra_deps_from_github.txt' relative to its own location. - It executes 'uv pip install -r --resolution=lowest'. - """ - script_dir = Path(__file__).resolve().parent - - # Adjust this path if your extra_deps_from_github.txt is in a different location, - # e.g., script_dir / "data" / "extra_deps_from_github.txt" - extra_deps_file = script_dir / "extra_deps_from_github.txt" - - if not extra_deps_file.exists(): - print(f"Error: '{extra_deps_file}' not found.") - print("Please ensure 'extra_deps_from_github.txt' is in the correct location relative to the script.") - sys.exit(1) - # Check if 'uv' is available in the environment - try: - subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True) - subprocess.run([sys.executable, "-m", "uv", "--version"], check=True, capture_output=True) - except subprocess.CalledProcessError as e: - print(f"Error checking uv version: {e}") - print(f"Stderr: {e.stderr.decode()}") - sys.exit(1) - - command = [ - sys.executable, # Use the current Python executable's pip to ensure the correct environment - "-m", - "uv", - "pip", - "install", - "-r", - str(extra_deps_file), - "--no-deps", - ] - - print(f"Installing extra dependencies from '{extra_deps_file}' using uv...") - print(f"Running command: {' '.join(command)}") - - try: - # Run the command - process = subprocess.run(command, check=True, capture_output=True, text=True) - print("Extra dependencies installed successfully!") - print("--- Output from uv ---") - print(process.stdout) - if process.stderr: - print("--- Errors/Warnings from uv (if any) ---") - print(process.stderr) - except subprocess.CalledProcessError as e: - print("Failed to install extra dependencies.") - print(f"Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}.") - print("--- Stderr ---") - print(e.stderr) - print("--- Stdout ---") - print(e.stdout) - sys.exit(e.returncode) - except (OSError, FileNotFoundError) as e: - print(f"An OS-level error occurred while trying to run uv: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/setup.sh b/tools/setup/setup.sh similarity index 84% rename from setup.sh rename to tools/setup/setup.sh index 3c3a0d14b1..05cd2f93ae 100644 --- a/setup.sh +++ b/tools/setup/setup.sh @@ -114,8 +114,8 @@ if [[ $LIBTPU_GCS_PATH == NONE ]]; then fi if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE || ($MODE == "nightly" && $DEVICE == "gpu")) ]]; then - echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n" - exit 1 + echo -e "\n\nError: You can only specify a JAX_VERSION with stable mode (plus nightly mode on GPU).\n\n" + exit 1 fi if [[ $DEVICE == "tpu" ]]; then @@ -129,7 +129,6 @@ if [[ $DEVICE == "tpu" ]]; then python3 -m uv pip install -U crcmod # Copy libtpu.so from GCS path gsutil cp "$LIBTPU_GCS_PATH" "$libtpu_path" - exit 0 else echo -e "\n\nError: You must provide a custom libtpu for libtpu-only mode.\n\n" exit 1 @@ -138,7 +137,14 @@ if [[ $DEVICE == "tpu" ]]; then fi if [[ "$MODE" == "nightly" ]]; then - echo "Nightly mode: Installing requirements.txt, stripping commit pins from git+ repos." + if [ "$DEVICE" = "gpu" ]; then + dep_name='dependencies/requirements/cuda12-requirements.txt' + else + dep_name='dependencies/requirements/'"${DEVICE?}"'-requirements.txt' + fi + printf 'Nightly mode: Installing "%s", stripping commit pins from git+ repos.\n' "$dep_name" + nightly_txt="${dep_name##*/}" + nightly_txt="${nightly_txt%.txt}"'-nightly-temp.txt' # Create a temp file, strip commit pins from git+ repos in requirements.txt # Remove/update this section based on the pinned github repo commit in requirements.txt @@ -146,40 +152,53 @@ if [[ "$MODE" == "nightly" ]]; then -e 's|^([^ ]*) @ https?://github.com/([^/]*\/[^/]*)/archive/.*\.zip$|\1@git+https://github.com/\2.git|' \ -e '/JetStream/d' \ -e '/mlperf-logging/d' \ - requirements.txt > requirements.txt.nightly-temp + "$dep_name" > "$nightly_txt" echo "--- Installing modified nightly requirements: ---" - cat requirements.txt.nightly-temp + cat -- "$nightly_txt" echo "-------------------------------------------------" - python3 -m uv pip install --no-cache-dir -U -r requirements.txt.nightly-temp \ - -r "${MAXTEXT_REPO_ROOT?}"'/extra_deps_from_github.txt' - rm requirements.txt.nightly-temp + python3 -m uv pip install --no-cache-dir -U -r "$nightly_txt" \ + -r "${dep_name%/*}"'/extra_deps_from_github.txt' + rm -fv -- "$nightly_txt" else # stable or stable_stack mode: Install with pinned commits - echo "Installing tpu-requirements.txt with pinned commits." - tpu_requirements_txt= - for candidate in 'generated_requirements' "${MAXTEXT_REPO_ROOT?}"'/generated_requirements' "$PWD"; do - if [ -f "$candidate"'/tpu-requirements.txt' ]; then - tpu_requirements_txt="$candidate"'/tpu-requirements.txt' + if [ "$DEVICE" = "gpu" ]; then + dep_basename='cuda12-requirements.txt' + else + dep_basename="${DEVICE?}"'-requirements.txt' + fi + + printf 'Installing "%s" with pinned commits.\n' "$dep_basename" + requirements_txt= + for candidate in 'dependencies/requirements' "${MAXTEXT_REPO_ROOT}"'/dependencies/requirements' "$PWD"; do + if [ -f "$candidate"'/'"$dep_basename" ]; then + requirements_txt="$candidate"'/'"$dep_basename" break else searched="$searched"':' fi done - if [ -z "${tpu_requirements_txt}" ]; then - >&2 printf 'Could not find "tpu-requirements.txt", looked in: %s\n' "${searched%?}" + if [ -z "${requirements_txt}" ]; then + >&2 printf 'Could not find "%s", looked in: %s\n' "$dep_basename" "${searched%?}" exit 2 else - python3 -m uv pip install --resolution=lowest -r "$tpu_requirements_txt" \ - -r "${MAXTEXT_REPO_ROOT?}"'/extra_deps_from_github.txt' + python3 -m uv pip install --resolution=lowest -r "$requirements_txt" \ + -r "${requirements_txt%/*}"'/extra_deps_from_github.txt' fi fi # Install maxtext package if [ -f 'pyproject.toml' ]; then - python3 -m uv pip install -e .[tpu] --no-deps --resolution=lowest - install_maxtext_github_deps + case "$DEVICE" in + 'gpu') python3 -m uv pip install -e .[tpu] --no-deps --resolution=lowest ;; + 'tpu') python3 -m uv pip install -e .[cuda12] --no-deps --resolution=lowest ;; + *) + >&2 printf 'Unsupported device\n' + exit 6 + ;; + esac + python3 -m uv pip install --resolution=lowest -r 'dependencies/requirements/extra_deps_from_github.txt' fi # Delete custom libtpu if it exists diff --git a/setup_gcsfuse.sh b/tools/setup/setup_gcsfuse.sh similarity index 100% rename from setup_gcsfuse.sh rename to tools/setup/setup_gcsfuse.sh diff --git a/setup_with_retries.sh b/tools/setup/setup_with_retries.sh similarity index 100% rename from setup_with_retries.sh rename to tools/setup/setup_with_retries.sh