From 6e0898a907be01b483348f1cb78a531779485b9f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Thu, 25 Sep 2025 18:00:31 +0000 Subject: [PATCH 01/50] Update entrypoint for jaii Test maxdiffusion workload on gpu image --- .github/workflows/UnitTests.yml | 99 ++++++++++++++++-------- maxdiffusion_jax_ai_image_tpu.Dockerfile | 2 +- 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2c588b43..f8544c47 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -22,43 +22,74 @@ on: push: branches: [ "main" ] workflow_dispatch: - schedule: - # Run the job every 12 hours - - cron: '0 */12 * * *' jobs: - build: - strategy: - fail-fast: false - matrix: - tpu-type: ["v5p-8"] - name: "TPU test (${{ matrix.tpu-type }})" - runs-on: ["self-hosted","${{ matrix.tpu-type }}"] + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + maxdiffusion_workload: + name: "Run MaxDiffusion Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.12 - uses: actions/setup-python@v5 - with: - python-version: '3.12' - - name: Install dependencies - run: | - pip install -e . - pip uninstall jax jaxlib libtpu-nightly libtpu -y - bash setup.sh MODE=stable - export PATH=$PATH:$HOME/.local/bin - pip install ruff - pip install isort - pip install pytest - - name: Analysing the code with ruff - run: | - ruff check . - - name: version check - run: | - python --version - pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets - - name: PyTest - run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Run MaxDiffusion Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \ + TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \ + JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \ + pip install . 
&& \ + python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ + pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ + revision=refs/pr/95 \ + activations_dtype=bfloat16 \ + weights_dtype=bfloat16 \ + dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \ + resolution=1024 \ + per_device_batch_size=1 \ + jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \ + max_train_steps=20 \ + attention=flash \ + enable_profiler=True \ + run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \ + output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }} + +# jobs: +# build: +# strategy: +# fail-fast: false +# matrix: +# tpu-type: ["v5p-8"] +# name: "TPU test (${{ matrix.tpu-type }})" +# runs-on: ["self-hosted","${{ matrix.tpu-type }}"] +# steps: +# - uses: actions/checkout@v4 +# - name: Set up Python 3.12 +# uses: actions/setup-python@v5 +# with: +# python-version: '3.12' +# - name: Install dependencies +# run: | +# pip install -e . +# pip uninstall jax jaxlib libtpu-nightly libtpu -y +# bash setup.sh MODE=stable +# export PATH=$PATH:$HOME/.local/bin +# pip install ruff +# pip install isort +# pip install pytest +# - name: Analysing the code with ruff +# run: | +# ruff check . +# - name: version check +# run: | +# python --version +# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets +# - name: PyTest +# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py +# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x # add_pull_ready: # if: github.ref != 'refs/heads/main' # permissions: diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile index cab50fee..301f9b88 100644 --- a/maxdiffusion_jax_ai_image_tpu.Dockerfile +++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile @@ -19,4 +19,4 @@ COPY . . 
RUN pip install -r /deps/requirements_with_jax_ai_image.txt # Run the script available in JAX-AI-Image base image to generate the manifest file -RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file +RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file From c50ab1b036752a4980aa80d05df8d60bb3a70af4 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 15:52:17 +0000 Subject: [PATCH 02/50] remove self-hosted tag --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index f8544c47..a216fd35 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -28,7 +28,7 @@ jobs: maxdiffusion_workload: name: "Run MaxDiffusion Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"] + runs-on: ["linux-x86-a2-48-a100-4gpu"] container: image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: From 1f22cf53f0795c8f113316f90c94b3bdf744859c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 15:53:26 +0000 Subject: [PATCH 03/50] wrong tag name --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a216fd35..bd5cd29e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -28,7 +28,7 @@ jobs: maxdiffusion_workload: name: "Run MaxDiffusion Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a2-48-a100-4gpu"] + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest steps: From 2deaa7e9fdc0e0e83fed085acb8bb1d04d6261ad Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:09:46 +0000 Subject: [PATCH 04/50] Update with command for gpu --- .github/workflows/UnitTests.yml | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index bd5cd29e..d8695c33 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,24 +38,20 @@ jobs: - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \ - TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \ - JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \ - pip install . && \ - python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ - pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ - revision=refs/pr/95 \ + NVTE_FUSED_ATTN=1 pip install . 
&& \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + hardware=gpu \ + train_new_unet=true \ + train_text_encoder=false \ + cache_latents_text_encoder_outputs=true \ + per_device_batch_size=1 \ + attention=cudnn_flash_te \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ - dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \ - resolution=1024 \ - per_device_batch_size=1 \ - jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \ - max_train_steps=20 \ - attention=flash \ + max_train_steps=200 \ enable_profiler=True \ - run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \ - output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }} + run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }} # jobs: # build: From d47a6ebeff3c70aa10fbe07eafea030b075e9da6 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:28:34 +0000 Subject: [PATCH 05/50] remove pip install . --- .github/workflows/UnitTests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index d8695c33..20fad808 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,8 +38,7 @@ jobs: - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 pip install . && \ - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ hardware=gpu \ train_new_unet=true \ train_text_encoder=false \ From a6579ecd68ff5ade6b8be0f9ac73f1da8e196344 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:30:01 +0000 Subject: [PATCH 06/50] Check environment --- .github/workflows/UnitTests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 20fad808..5704c986 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -35,10 +35,15 @@ jobs: - name: Checkout Repository uses: actions/checkout@v4 + - name: Print dependencies + run: | + pip freeze + - name: Run MaxDiffusion Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + NVTE_FUSED_ATTN=1 pip install . 
&& \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ hardware=gpu \ train_new_unet=true \ train_text_encoder=false \ From cc5d322f2f8ebdc3f9008f8ceab0681eb0adff0f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:41:14 +0000 Subject: [PATCH 07/50] Update to right image --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 5704c986..95e261d3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -30,7 +30,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 26a03c0bf1cbcbbf6100136031de4783bfd3b41b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 16:47:49 +0000 Subject: [PATCH 08/50] point to allowed bucket --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 95e261d3..7d36ca48 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -55,7 +55,7 @@ jobs: max_train_steps=200 \ enable_profiler=True \ run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }} + output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 9236f6fcfd6ec2911253dead691daf239f88a7fe Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:25:38 +0000 Subject: [PATCH 09/50] Install entire TE --- .github/workflows/UnitTests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 7d36ca48..86326037 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,6 +37,8 @@ jobs: - name: Print dependencies run: | + pip uninstall -y transformer-engine + pip install transformer-engine pip freeze - name: Run MaxDiffusion Training From 4e0dd7f1df2a1ed8f3b3e87c4bb1c875e8235be9 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:30:24 +0000 Subject: [PATCH 10/50] uninstall all TE deps --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 86326037..31a17048 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,7 +37,7 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine + pip uninstall -y transformer-engine transformer-engine-jax pip install transformer-engine pip freeze From 08883164629b931409a5d6d9fd8f2e4c5b1ee306 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:34:52 +0000 Subject: [PATCH 11/50] install TE jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 31a17048..57475c91 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y 
transformer-engine transformer-engine-jax - pip install transformer-engine + pip install transformer-engine transformer-engine-jax pip freeze - name: Run MaxDiffusion Training From 09c5d6a04212970db02dd7eafde6d6c1f0bb91a4 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:35:39 +0000 Subject: [PATCH 12/50] fix pip instal typo --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 57475c91..bb3c0a22 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install transformer-engine transformer-engine-jax + pip install transformer-engine transformer_engine[jax] pip freeze - name: Run MaxDiffusion Training From 846e1768e7d512861553b78c59c1c1cac484161b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 17:43:00 +0000 Subject: [PATCH 13/50] Install TE for pytorch and jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index bb3c0a22..4c721dec 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -38,7 +38,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install transformer-engine transformer_engine[jax] + pip install -U transformer-engine[pytorch,jax] pip freeze - name: Run MaxDiffusion Training From 496f67cfb2d06772a2cbe7aa398836306ecbf27a Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:02:16 +0000 Subject: [PATCH 14/50] Comment out tflop calc --- .github/workflows/UnitTests.yml | 4 ++-- src/maxdiffusion/trainers/base_stable_diffusion_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 4c721dec..4eb3f862 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,8 +37,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[pytorch,jax] + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[pytorch,jax] pip freeze - name: Run MaxDiffusion Training diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index a9f17adc..e889d816 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -161,8 +161,8 @@ def start_training(self): params["scheduler"] = noise_scheduler_state # Calculate tflops - per_device_tflops = self.calculate_tflops(pipeline, params) - self.per_device_tflops = per_device_tflops + # per_device_tflops = self.calculate_tflops(pipeline, params) + # self.per_device_tflops = per_device_tflops # Load dataset data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) From 37524bba47363e300110a9da4f036de1df378e4f Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:09:40 +0000 Subject: [PATCH 15/50] Test with dot_product attention --- .github/workflows/UnitTests.yml | 2 +- src/maxdiffusion/trainers/base_stable_diffusion_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) 
diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 4eb3f862..2f16eb7f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -51,7 +51,7 @@ jobs: train_text_encoder=false \ cache_latents_text_encoder_outputs=true \ per_device_batch_size=1 \ - attention=cudnn_flash_te \ + attention=dot_product \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ max_train_steps=200 \ diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index e889d816..a9f17adc 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -161,8 +161,8 @@ def start_training(self): params["scheduler"] = noise_scheduler_state # Calculate tflops - # per_device_tflops = self.calculate_tflops(pipeline, params) - # self.per_device_tflops = per_device_tflops + per_device_tflops = self.calculate_tflops(pipeline, params) + self.per_device_tflops = per_device_tflops # Load dataset data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states) From 64f30f36bbc427ffab039a9ef4dfa56429d0f121 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:45:42 +0000 Subject: [PATCH 16/50] Test if maxtext has same gpu issue --- .github/workflows/UnitTests.yml | 78 ++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2f16eb7f..2980841e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,40 +24,66 @@ on: workflow_dispatch: jobs: - # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD - maxdiffusion_workload: - name: "Run MaxDiffusion Workload" + maxtext_workload: + name: "Run MaxText Workload" # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest steps: - - name: Checkout Repository + - name: Checkout MaxText Repo uses: actions/checkout@v4 + with: + repository: AI-Hypercomputer/maxtext + path: maxtext - - name: Print dependencies - run: | - # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[pytorch,jax] - pip freeze - - - name: Run MaxDiffusion Training + - name: Run MaxText Training run: | # This command is adapted from your DAG for a single-slice configuration. - NVTE_FUSED_ATTN=1 pip install . && \ - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - hardware=gpu \ - train_new_unet=true \ - train_text_encoder=false \ - cache_latents_text_encoder_outputs=true \ - per_device_batch_size=1 \ - attention=dot_product \ - activations_dtype=bfloat16 \ - weights_dtype=bfloat16 \ - max_train_steps=200 \ - enable_profiler=True \ - run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + cd maxtext && \ + pip install -e . 
--no-dependencies \ + XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \ + python3 -m MaxText.train MaxText/configs/base.yml \ + steps=2 \ + enable_checkpointing=false \ + attention=dot_product \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + + # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + # maxdiffusion_workload: + # name: "Run MaxDiffusion Workload" + # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + # container: + # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + # steps: + # - name: Checkout Repository + # uses: actions/checkout@v4 + + # - name: Print dependencies + # run: | + # # pip uninstall -y transformer-engine transformer-engine-jax + # # pip install -U transformer-engine[pytorch,jax] + # pip freeze + + # - name: Run MaxDiffusion Training + # run: | + # # This command is adapted from your DAG for a single-slice configuration. + # NVTE_FUSED_ATTN=1 pip install . && \ + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + # hardware=gpu \ + # train_new_unet=true \ + # train_text_encoder=false \ + # cache_latents_text_encoder_outputs=true \ + # per_device_batch_size=1 \ + # attention=dot_product \ + # activations_dtype=bfloat16 \ + # weights_dtype=bfloat16 \ + # max_train_steps=200 \ + # enable_profiler=True \ + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 53d1d22711f9e5a9bde3695d660aa079021c156b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:48:27 +0000 Subject: [PATCH 17/50] change to pip install maxtext package --- .github/workflows/UnitTests.yml | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2980841e..b60c251b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,14 +41,17 @@ jobs: run: | # This command is adapted from your DAG for a single-slice configuration. cd maxtext && \ - pip install -e . --no-dependencies \ - XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true \ + pip install . 
+ + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + export TF_FORCE_GPU_ALLOW_GROWTH=true + python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=dot_product \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + steps=2 \ + enable_checkpointing=false \ + attention=dot_product \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD # maxdiffusion_workload: From fdcec4c018efd261a5c1866ded85c76f2fdb1c48 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:52:28 +0000 Subject: [PATCH 18/50] Install with no dependencies --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index b60c251b..c1033069 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: run: | # This command is adapted from your DAG for a single-slice configuration. cd maxtext && \ - pip install . + pip install . --no-dependencies export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 export TF_FORCE_GPU_ALLOW_GROWTH=true From cd9daaf19d7c46c246a09ec6cc16d5cec3312cf1 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 18:59:16 +0000 Subject: [PATCH 19/50] Use custom maxtext branch --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c1033069..686704b7 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -36,6 +36,7 @@ jobs: with: repository: AI-Hypercomputer/maxtext path: maxtext + ref: rbierneni-test-gpu-run - name: Run MaxText Training run: | From e0615dcf61746e0d010b224f0570e10035955bd9 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 19:04:45 +0000 Subject: [PATCH 20/50] use synthetic data --- .github/workflows/UnitTests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 686704b7..61e7c637 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -51,6 +51,7 @@ jobs: steps=2 \ enable_checkpointing=false \ attention=dot_product \ + dataset_type=synthetic \ run_name=rbierneni-test-maxtext-gpu \ base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} From 323654a7d18329696923cf90ec55a6e87663c218 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 19:10:37 +0000 Subject: [PATCH 21/50] Try with TE flash --- .github/workflows/UnitTests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 61e7c637..49c22a75 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -37,6 +37,10 @@ jobs: repository: AI-Hypercomputer/maxtext path: maxtext ref: rbierneni-test-gpu-run + + - name: Print dependencies + run: | + pip freeze - name: Run MaxText Training run: | @@ -50,7 +54,7 @@ jobs: python3 -m MaxText.train MaxText/configs/base.yml \ steps=2 \ enable_checkpointing=false \ - attention=dot_product \ + attention=cudnn_flash_te \ dataset_type=synthetic \ run_name=rbierneni-test-maxtext-gpu \ base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} From 
2313d541dd1705beec1b22b794b721dd95f489bd Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:07:21 +0000 Subject: [PATCH 22/50] Try with older TE --- .github/workflows/UnitTests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 49c22a75..ebf85225 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,6 +40,8 @@ jobs: - name: Print dependencies run: | + pip uninstall -y transformer-engine transformer-engine-jax + pip install -U transformer-engine[jax]=0.2.4 pip freeze - name: Run MaxText Training From 2eafa9329ec0d8e23ae878644fd2b7745f0cbdda Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:15:01 +0000 Subject: [PATCH 23/50] Typo in TE install --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index ebf85225..cb283a7e 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]=0.2.4 + pip install -U transformer-engine[jax]==2.4.0 pip freeze - name: Run MaxText Training From a1da601045d60dd274df2c468f28f9e8dc60080e Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:16:11 +0000 Subject: [PATCH 24/50] Use TE 2.5.0 --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index cb283a7e..dfd71fa5 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.4.0 + pip install -U transformer-engine[jax]==2.5.0 pip freeze - name: Run MaxText Training From dfe8fc5bff63df3224712f2471a4d1f517ccabf3 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:23:40 +0000 Subject: [PATCH 25/50] Use TE 2.6.0 --- .github/workflows/UnitTests.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index dfd71fa5..2209d61c 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -41,7 +41,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.5.0 + pip install -U transformer-engine[jax]==2.6.0 pip freeze - name: Run MaxText Training @@ -81,20 +81,20 @@ jobs: # - name: Run MaxDiffusion Training # run: | # # This command is adapted from your DAG for a single-slice configuration. - # NVTE_FUSED_ATTN=1 pip install . && \ - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - # hardware=gpu \ - # train_new_unet=true \ - # train_text_encoder=false \ - # cache_latents_text_encoder_outputs=true \ - # per_device_batch_size=1 \ - # attention=dot_product \ - # activations_dtype=bfloat16 \ - # weights_dtype=bfloat16 \ - # max_train_steps=200 \ - # enable_profiler=True \ - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + # NVTE_FUSED_ATTN=1 pip install . 
&& \ + # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + # hardware=gpu \ + # train_new_unet=true \ + # train_text_encoder=false \ + # cache_latents_text_encoder_outputs=true \ + # per_device_batch_size=1 \ + # attention=dot_product \ + # activations_dtype=bfloat16 \ + # weights_dtype=bfloat16 \ + # max_train_steps=200 \ + # enable_profiler=True \ + # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From 4243d537ce26820f5671c70e84fe00765035bb8c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Fri, 26 Sep 2025 21:32:46 +0000 Subject: [PATCH 26/50] Test with tensorflow-cpu --- .github/workflows/UnitTests.yml | 124 ++++++++++++++++---------------- 1 file changed, 64 insertions(+), 60 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 2209d61c..99a07e2f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,77 +24,81 @@ on: workflow_dispatch: jobs: - maxtext_workload: - name: "Run MaxText Workload" - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - steps: - - name: Checkout MaxText Repo - uses: actions/checkout@v4 - with: - repository: AI-Hypercomputer/maxtext - path: maxtext - ref: rbierneni-test-gpu-run - - - name: Print dependencies - run: | - pip uninstall -y transformer-engine transformer-engine-jax - pip install -U transformer-engine[jax]==2.6.0 - pip freeze - - - name: Run MaxText Training - run: | - # This command is adapted from your DAG for a single-slice configuration. - cd maxtext && \ - pip install . --no-dependencies - - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - export TF_FORCE_GPU_ALLOW_GROWTH=true - - python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=cudnn_flash_te \ - dataset_type=synthetic \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} - - # # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD - # maxdiffusion_workload: - # name: "Run MaxDiffusion Workload" + # maxtext_workload: + # name: "Run MaxText Workload" # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] # container: - # image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest # steps: - # - name: Checkout Repository + # - name: Checkout MaxText Repo # uses: actions/checkout@v4 - + # with: + # repository: AI-Hypercomputer/maxtext + # path: maxtext + # ref: rbierneni-test-gpu-run + # - name: Print dependencies # run: | - # # pip uninstall -y transformer-engine transformer-engine-jax - # # pip install -U transformer-engine[pytorch,jax] + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu # pip freeze - # - name: Run MaxDiffusion Training + # - name: Run MaxText Training # run: | # # This command is adapted from your DAG for a single-slice configuration. - # NVTE_FUSED_ATTN=1 pip install . 
&& \ - # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ - # hardware=gpu \ - # train_new_unet=true \ - # train_text_encoder=false \ - # cache_latents_text_encoder_outputs=true \ - # per_device_batch_size=1 \ - # attention=dot_product \ - # activations_dtype=bfloat16 \ - # weights_dtype=bfloat16 \ - # max_train_steps=200 \ - # enable_profiler=True \ - # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ - # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} + # cd maxtext && \ + # pip install . --no-dependencies + + # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + # export TF_FORCE_GPU_ALLOW_GROWTH=true + + # python3 -m MaxText.train MaxText/configs/base.yml \ + # steps=2 \ + # enable_checkpointing=false \ + # attention=cudnn_flash_te \ + # dataset_type=synthetic \ + # run_name=rbierneni-test-maxtext-gpu \ + # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD + maxdiffusion_workload: + name: "Run MaxDiffusion Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Print dependencies + run: | + # pip uninstall -y transformer-engine transformer-engine-jax + # pip install -U transformer-engine[pytorch,jax] + pip uninstall -y tensorflow + pip install tensorflow-cpu + pip freeze + + - name: Run MaxDiffusion Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + NVTE_FUSED_ATTN=1 pip install . && \ + python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ + hardware=gpu \ + train_new_unet=true \ + train_text_encoder=false \ + cache_latents_text_encoder_outputs=true \ + per_device_batch_size=1 \ + attention=dot_product \ + activations_dtype=bfloat16 \ + weights_dtype=bfloat16 \ + max_train_steps=200 \ + enable_profiler=True \ + run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \ + output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }} # jobs: # build: From d37e0dd775ed74d39c8b66f253df967ce0cea64b Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 00:22:37 +0000 Subject: [PATCH 27/50] Test with new cuda13 images and TE 2.6.0 --- .github/workflows/UnitTests.yml | 78 ++++++++++++++++----------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 99a07e2f..b5c0f0c3 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,44 +24,44 @@ on: workflow_dispatch: jobs: - # maxtext_workload: - # name: "Run MaxText Workload" - # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - # container: - # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - # steps: - # - name: Checkout MaxText Repo - # uses: actions/checkout@v4 - # with: - # repository: AI-Hypercomputer/maxtext - # path: maxtext - # ref: rbierneni-test-gpu-run + maxtext_workload: + name: "Run MaxText Workload" + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + container: + image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest 
+ steps: + - name: Checkout MaxText Repo + uses: actions/checkout@v4 + with: + repository: AI-Hypercomputer/maxtext + path: maxtext + ref: rbierneni-test-gpu-run - # - name: Print dependencies - # run: | - # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[jax]==2.6.0 - # pip uninstall -y tensorflow - # pip install tensorflow-cpu - # pip freeze + - name: Print dependencies + run: | + pip uninstall -y transformer-engine transformer-engine-jax + pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu + pip freeze - # - name: Run MaxText Training - # run: | - # # This command is adapted from your DAG for a single-slice configuration. - # cd maxtext && \ - # pip install . --no-dependencies + - name: Run MaxText Training + run: | + # This command is adapted from your DAG for a single-slice configuration. + cd maxtext && \ + pip install . --no-dependencies - # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - # export TF_FORCE_GPU_ALLOW_GROWTH=true + export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + export TF_FORCE_GPU_ALLOW_GROWTH=true - # python3 -m MaxText.train MaxText/configs/base.yml \ - # steps=2 \ - # enable_checkpointing=false \ - # attention=cudnn_flash_te \ - # dataset_type=synthetic \ - # run_name=rbierneni-test-maxtext-gpu \ - # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + python3 -m MaxText.train MaxText/configs/base.yml \ + steps=2 \ + enable_checkpointing=false \ + attention=cudnn_flash_te \ + dataset_type=synthetic \ + run_name=rbierneni-test-maxtext-gpu \ + base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD maxdiffusion_workload: @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev1_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -77,9 +77,9 @@ jobs: - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax - # pip install -U transformer-engine[pytorch,jax] - pip uninstall -y tensorflow - pip install tensorflow-cpu + pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y tensorflow + # pip install tensorflow-cpu pip freeze - name: Run MaxDiffusion Training @@ -92,7 +92,7 @@ jobs: train_text_encoder=false \ cache_latents_text_encoder_outputs=true \ per_device_batch_size=1 \ - attention=dot_product \ + attention=cudnn_flash_te \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ max_train_steps=200 \ From c247bee21d3e22756afe03f332cb9d733e5f551e Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:19:55 +0000 Subject: [PATCH 28/50] Update to right image --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index b5c0f0c3..37df849b 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev2_gpu + image: 
gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev3_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 67141daed96c6ecb0a6424434a7af24d9d8ef2e3 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:29:23 +0000 Subject: [PATCH 29/50] uninstall cuda12 TE as well --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 37df849b..59e4a723 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -40,7 +40,7 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y tensorflow # pip install tensorflow-cpu @@ -76,7 +76,7 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y tensorflow # pip install tensorflow-cpu From 4ac11bb6bd82915f402d8ded35702a3a08c88fa8 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 02:49:52 +0000 Subject: [PATCH 30/50] Test with te-cu13 package --- .github/workflows/UnitTests.yml | 70 +++++++++++++++++---------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 59e4a723..a43497ba 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -24,44 +24,44 @@ on: workflow_dispatch: jobs: - maxtext_workload: - name: "Run MaxText Workload" - # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) - runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest - steps: - - name: Checkout MaxText Repo - uses: actions/checkout@v4 - with: - repository: AI-Hypercomputer/maxtext - path: maxtext - ref: rbierneni-test-gpu-run + # maxtext_workload: + # name: "Run MaxText Workload" + # # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) + # runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] + # container: + # image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest + # steps: + # - name: Checkout MaxText Repo + # uses: actions/checkout@v4 + # with: + # repository: AI-Hypercomputer/maxtext + # path: maxtext + # ref: rbierneni-test-gpu-run - - name: Print dependencies - run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 - # pip uninstall -y tensorflow - # pip install tensorflow-cpu - pip freeze + # - name: Print dependencies + # run: | + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install -U transformer-engine[jax]==2.6.0 + # # pip uninstall -y tensorflow + # # pip install tensorflow-cpu + # pip freeze - - name: Run MaxText Training - run: | - # This command is adapted from your DAG for a single-slice configuration. - cd maxtext && \ - pip install . --no-dependencies + # - name: Run MaxText Training + # run: | + # # This command is adapted from your DAG for a single-slice configuration. + # cd maxtext && \ + # pip install . 
--no-dependencies - export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 - export TF_FORCE_GPU_ALLOW_GROWTH=true + # export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 + # export TF_FORCE_GPU_ALLOW_GROWTH=true - python3 -m MaxText.train MaxText/configs/base.yml \ - steps=2 \ - enable_checkpointing=false \ - attention=cudnn_flash_te \ - dataset_type=synthetic \ - run_name=rbierneni-test-maxtext-gpu \ - base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} + # python3 -m MaxText.train MaxText/configs/base.yml \ + # steps=2 \ + # enable_checkpointing=false \ + # attention=cudnn_flash_te \ + # dataset_type=synthetic \ + # run_name=rbierneni-test-maxtext-gpu \ + # base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }} # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD maxdiffusion_workload: @@ -78,6 +78,8 @@ jobs: run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine[jax]==2.6.0 + pip uninstall -y transformer-engine-cu12 + pip install transformer-engine-cu13 # pip uninstall -y tensorflow # pip install tensorflow-cpu pip freeze From 994d8dcb482aefd0a59b0b629fe2b9a3df3a18f0 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:01:37 +0000 Subject: [PATCH 31/50] try with base TE --- .github/workflows/UnitTests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a43497ba..50977447 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -77,9 +77,9 @@ jobs: - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 - pip uninstall -y transformer-engine-cu12 - pip install transformer-engine-cu13 + pip install -U transformer-engine==2.6.0 + # pip uninstall -y transformer-engine-cu12 + # pip install transformer-engine-cu13 # pip uninstall -y tensorflow # pip install tensorflow-cpu pip freeze From 6700b5364ae3488759561c3f5365b45f2d355ee7 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:02:36 +0000 Subject: [PATCH 32/50] uninstall existing TE --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 50977447..fb32c9a2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -76,7 +76,7 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 pip install -U transformer-engine==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From 23d8ca20426b5930d937a2b0009c391accc19148 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 03:18:32 +0000 Subject: [PATCH 33/50] test te 2.6.0 jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index fb32c9a2..97c5de55 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -77,7 +77,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine==2.6.0 + pip install -U transformer-engine[jax]==2.6.0 # pip 
uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 # pip uninstall -y tensorflow From 0885a52ac7865490f7b02c3a3723d0d51998b001 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:01:59 +0000 Subject: [PATCH 34/50] test te 2.6.0 jax --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 97c5de55..ce5d4be2 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev3_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev4_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 From 22cfa352df04e0d2c28d8581a46e3da013db8d9c Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:12:21 +0000 Subject: [PATCH 35/50] remove steps for uninstalling TE --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index ce5d4be2..69fa4f16 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -76,8 +76,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install -U transformer-engine[jax]==2.6.0 + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 # pip uninstall -y tensorflow From ab478a1f258cc49c45fc3399f17d1b4fd8b22cc1 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 04:29:17 +0000 Subject: [PATCH 36/50] check host gpu driver --- .github/workflows/UnitTests.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 69fa4f16..14ddf4db 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -74,6 +74,19 @@ jobs: - name: Checkout Repository uses: actions/checkout@v4 + - name: Check Host CUDA and GPU Environment + run: | + echo "--- Checking NVIDIA driver and supported CUDA version ---" + nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected." 
+ + echo "" + echo "--- Checking for default CUDA toolkit installation ---" + ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/" + + echo "" + echo "--- Checking dynamic linker library path ---" + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}" + - name: Print dependencies run: | # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 From 7ff2c235db5a165c2960aab08e162a838f3ec280 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 07:32:16 +0000 Subject: [PATCH 37/50] try with cuda 12 on TE cu12 --- .github/workflows/UnitTests.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 14ddf4db..6e92be77 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -69,7 +69,7 @@ jobs: # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] container: - image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_rev4_gpu + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda12_tecu12_gpu steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -89,7 +89,8 @@ jobs: - name: Print dependencies run: | - # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + pip install transformer_engine[jax]==2.6.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ef2041e86993fa0bd03801d9011cb7a368308779 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 07:56:48 +0000 Subject: [PATCH 38/50] try with last tested TE version --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 6e92be77..d83bbf01 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -90,7 +90,7 @@ jobs: - name: Print dependencies run: | pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install transformer_engine[jax]==2.6.0 + pip install transformer_engine[jax]==2.4.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ec29e4e055c5de2f13bd96a95800e0450e30e662 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 08:38:18 +0000 Subject: [PATCH 39/50] try with no-cache build image --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index d83bbf01..a53b892a 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -89,8 +89,8 @@ jobs: - name: Print dependencies run: | - pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 - pip install transformer_engine[jax]==2.4.0 + # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12 + # pip install transformer_engine[jax]==2.4.0 # pip install -U transformer-engine[jax]==2.6.0 # pip uninstall -y transformer-engine-cu12 # pip install transformer-engine-cu13 From ecfacad5bcb02c84ac07c16af22db0feaf39b328 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 30 Sep 2025 08:46:39 +0000 Subject: [PATCH 
40/50] update image tag to prevent gke caching

---
 .github/workflows/UnitTests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index a53b892a..22f4d801 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -69,7 +69,7 @@ jobs:
     # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
     runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
     container:
-      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda12_tecu12_gpu
+      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4

From 1273c3e1f8328e2fd55c27065c0ad8d990a8c95b Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 08:53:46 +0000
Subject: [PATCH 41/50] try with dot_product

---
 .github/workflows/UnitTests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 22f4d801..4e1b3ae7 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -108,7 +108,7 @@ jobs:
             train_text_encoder=false \
             cache_latents_text_encoder_outputs=true \
             per_device_batch_size=1 \
-            attention=cudnn_flash_te \
+            attention=dot_product \
             activations_dtype=bfloat16 \
             weights_dtype=bfloat16 \
             max_train_steps=200 \

From d1263eb3872675f55cde8091816aa8e23b968a14 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:03:17 +0000
Subject: [PATCH 42/50] try without te

---
 .github/workflows/UnitTests.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 4e1b3ae7..915644a4 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -69,7 +69,7 @@ jobs:
     # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
     runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
     container:
-      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
+      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4
@@ -89,7 +89,7 @@ jobs:
       - name: Print dependencies
         run: |
-          # pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
+          pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
           # pip install transformer_engine[jax]==2.4.0
           # pip install -U transformer-engine[jax]==2.6.0
           # pip uninstall -y transformer-engine-cu12
           # pip install transformer-engine-cu13
           # pip uninstall -y tensorflow
           # pip install tensorflow-cpu
           pip freeze
@@ -108,7 +108,7 @@ jobs:
             train_text_encoder=false \
             cache_latents_text_encoder_outputs=true \
             per_device_batch_size=1 \
-            attention=dot_product \
+            attention=cudnn_flash_te \
             activations_dtype=bfloat16 \
             weights_dtype=bfloat16 \
             max_train_steps=200 \

From 9295b8be95c75075c91c98bedbabfc7d3edbcdb3 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:03:58 +0000
Subject: [PATCH 43/50] use with no te

---
 .github/workflows/UnitTests.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 915644a4..e2dd2630 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -69,7 +69,7 @@ jobs:
     # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
     runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
     container:
-      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:
+      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4
@@ -108,7 +108,7 @@ jobs:
             train_text_encoder=false \
             cache_latents_text_encoder_outputs=true \
             per_device_batch_size=1 \
-            attention=cudnn_flash_te \
+            attention=dot_product \
             activations_dtype=bfloat16 \
             weights_dtype=bfloat16 \
             max_train_steps=200 \

From 95c1b03326915c87581835ad8f19cd5a7ce03ef7 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:14:49 +0000
Subject: [PATCH 44/50] Check jax.devices output

---
 .github/workflows/UnitTests.yml | 36 ++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index e2dd2630..123d38c0 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -98,23 +98,27 @@ jobs:
           # pip install tensorflow-cpu
           pip freeze

-      - name: Run MaxDiffusion Training
+      - name: Check per_device_batch_size
         run: |
-          # This command is adapted from your DAG for a single-slice configuration.
-          NVTE_FUSED_ATTN=1 pip install . && \
-          python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
-            hardware=gpu \
-            train_new_unet=true \
-            train_text_encoder=false \
-            cache_latents_text_encoder_outputs=true \
-            per_device_batch_size=1 \
-            attention=dot_product \
-            activations_dtype=bfloat16 \
-            weights_dtype=bfloat16 \
-            max_train_steps=200 \
-            enable_profiler=True \
-            run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
-            output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
+          python -c "import jax; print(jax.devices())"
+
+      # - name: Run MaxDiffusion Training
+      #   run: |
+      #     # This command is adapted from your DAG for a single-slice configuration.
+      #     NVTE_FUSED_ATTN=1 pip install . && \
+      #     python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
+      #       hardware=gpu \
+      #       train_new_unet=true \
+      #       train_text_encoder=false \
+      #       cache_latents_text_encoder_outputs=true \
+      #       per_device_batch_size=1 \
+      #       attention=dot_product \
+      #       activations_dtype=bfloat16 \
+      #       weights_dtype=bfloat16 \
+      #       max_train_steps=200 \
+      #       enable_profiler=True \
+      #       run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
+      #       output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}

 # jobs:
 #   build:
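Patch 44 pares the job down to a single jax.devices() probe. A slightly fuller probe in the same spirit (a sketch only, assuming nothing beyond jax being importable inside the container image) also reports which backend JAX actually picked, which makes a silent fallback to CPU obvious in the job log:

# probe sketch -- not part of any commit in this series
import jax

print("backend:", jax.default_backend())   # should name the GPU backend on the H100 runner
print("device count:", jax.device_count())
print("devices:", jax.devices())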
From 333e4f596922e3234bdba948710f94ae0257f262 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:19:57 +0000
Subject: [PATCH 45/50] try with workload

---
 .github/workflows/UnitTests.yml | 34 ++++++++++++++++----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 123d38c0..5f42ec72 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -102,23 +102,23 @@ jobs:
         run: |
           python -c "import jax; print(jax.devices())"

-      # - name: Run MaxDiffusion Training
-      #   run: |
-      #     # This command is adapted from your DAG for a single-slice configuration.
-      #     NVTE_FUSED_ATTN=1 pip install . && \
-      #     python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
-      #       hardware=gpu \
-      #       train_new_unet=true \
-      #       train_text_encoder=false \
-      #       cache_latents_text_encoder_outputs=true \
-      #       per_device_batch_size=1 \
-      #       attention=dot_product \
-      #       activations_dtype=bfloat16 \
-      #       weights_dtype=bfloat16 \
-      #       max_train_steps=200 \
-      #       enable_profiler=True \
-      #       run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
-      #       output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
+      - name: Run MaxDiffusion Training
+        run: |
+          # This command is adapted from your DAG for a single-slice configuration.
+          NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
+          python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
+            hardware=gpu \
+            train_new_unet=true \
+            train_text_encoder=false \
+            cache_latents_text_encoder_outputs=true \
+            per_device_batch_size=1 \
+            attention=dot_product \
+            activations_dtype=bfloat16 \
+            weights_dtype=bfloat16 \
+            max_train_steps=200 \
+            enable_profiler=True \
+            run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
+            output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}

 # jobs:
 #   build:
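The commits above keep flipping the same two knobs by hand: the attention flag (cudnn_flash_te versus dot_product) and whether Transformer Engine stays installed in the image. The helper below is not part of any commit in this series; it is only a sketch of how that choice could be made at run time, and the sole assumption is that a usable Transformer Engine installation imports as transformer_engine:

# pick_attention.py -- hypothetical helper, not referenced by the workflow above
def pick_attention_kernel() -> str:
    """Return the attention flag to pass to the training command."""
    try:
        import transformer_engine  # noqa: F401  # import failure means TE is unusable here
        return "cudnn_flash_te"
    except Exception:
        return "dot_product"


if __name__ == "__main__":
    # Print the value so a shell step could capture it, e.g. ATTENTION=$(python pick_attention.py)
    print(pick_attention_kernel())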
From 5f5749a9274d0def2e42a9fde46aeaf2fe72e99a Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:27:50 +0000
Subject: [PATCH 46/50] check if torch is causing the issue

---
 .github/workflows/UnitTests.yml | 37 +++++++++++++++++----------------
 verify_conflict.sh              | 22 ++++++++++++++++++++
 2 files changed, 41 insertions(+), 18 deletions(-)
 create mode 100644 verify_conflict.sh

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 5f42ec72..9ac54559 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -98,27 +98,28 @@ jobs:
           # pip install tensorflow-cpu
           pip freeze

-      - name: Check per_device_batch_size
+      - name: Check devices
         run: |
           python -c "import jax; print(jax.devices())"
+          python verify_conflict.py

-      - name: Run MaxDiffusion Training
-        run: |
-          # This command is adapted from your DAG for a single-slice configuration.
-          NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
-          python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
-            hardware=gpu \
-            train_new_unet=true \
-            train_text_encoder=false \
-            cache_latents_text_encoder_outputs=true \
-            per_device_batch_size=1 \
-            attention=dot_product \
-            activations_dtype=bfloat16 \
-            weights_dtype=bfloat16 \
-            max_train_steps=200 \
-            enable_profiler=True \
-            run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
-            output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
+      # - name: Run MaxDiffusion Training
+      #   run: |
+      #     # This command is adapted from your DAG for a single-slice configuration.
+      #     NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
+      #     python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
+      #       hardware=gpu \
+      #       train_new_unet=true \
+      #       train_text_encoder=false \
+      #       cache_latents_text_encoder_outputs=true \
+      #       per_device_batch_size=1 \
+      #       attention=dot_product \
+      #       activations_dtype=bfloat16 \
+      #       weights_dtype=bfloat16 \
+      #       max_train_steps=200 \
+      #       enable_profiler=True \
+      #       run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
+      #       output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}

 # jobs:
 #   build:
diff --git a/verify_conflict.sh b/verify_conflict.sh
new file mode 100644
index 00000000..54454f7c
--- /dev/null
+++ b/verify_conflict.sh
@@ -0,0 +1,22 @@
+print("--- PyTorch vs. JAX Conflict Test ---")
+
+print("\nStep 1: Attempting to import torch...")
+try:
+    import torch
+    print(f"Successfully imported torch version: {torch.__version__}")
+    # This check will confirm you have the CPU-only version
+    print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
+except Exception as e:
+    print(f"Failed to import torch: {e}")
+
+
+print("\nStep 2: Now, attempting to initialize JAX...")
+try:
+    import jax
+    devices = jax.devices()
+    print("\n--- RESULT: SUCCESS ---")
+    print(f"JAX initialized correctly and found devices: {devices}")
+except Exception as e:
+    print("\n--- RESULT: FAILURE ---")
+    print("JAX failed to initialize after PyTorch was imported.")
+    print(f"JAX Error: {e}")
\ No newline at end of file
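As committed, patch 46 stores the probe in verify_conflict.sh while the workflow step invokes python verify_conflict.py, so the step points at a file that is not in the tree; presumably that is why the next commit writes the script inline instead. A more compact, file-less variant of the same check (a sketch only -- it assumes nothing beyond torch and jax being importable in the image) could be fed straight to python -c or a heredoc:

# conflict-probe sketch -- not part of any commit in this series
import importlib

modules = {}
for name in ("torch", "jax"):  # import torch first, to reproduce the ordering being tested
    try:
        modules[name] = importlib.import_module(name)
        print(f"{name} {getattr(modules[name], '__version__', 'unknown')} imported OK")
    except Exception as exc:
        print(f"{name} failed to import: {exc}")

if "jax" in modules:
    # jax.devices() triggers backend initialization, which is where a CUDA clash would surface
    print("devices:", modules["jax"].devices())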
From 1b5b8c4e5d518845e63199654480ac54c895ac16 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Tue, 30 Sep 2025 09:35:08 +0000
Subject: [PATCH 47/50] Run script in workflow

---
 .github/workflows/UnitTests.yml | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 9ac54559..3cebd084 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -101,6 +101,34 @@ jobs:
       - name: Check devices
         run: |
           python -c "import jax; print(jax.devices())"
+
+      - name: Run Conflict Verification Script
+        run: |
+          # This command creates the file inside the runner
+          cat <<'EOF' > verify_conflict.py
+          print("--- PyTorch vs. JAX Conflict Test ---")
+
+          print("\nStep 1: Attempting to import torch...")
+          try:
+              import torch
+              print(f"Successfully imported torch version: {torch.__version__}")
+              print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
+          except Exception as e:
+              print(f"Failed to import torch: {e}")
+
+          print("\nStep 2: Now, attempting to initialize JAX...")
+          try:
+              import jax
+              devices = jax.devices()
+              print("\n--- RESULT: SUCCESS ---")
+              print(f"JAX initialized correctly and found devices: {devices}")
+          except Exception as e:
+              print("\n--- RESULT: FAILURE ---")
+              print("JAX failed to initialize after PyTorch was imported.")
+              print(f"JAX Error: {e}")
+          EOF
+
+          # Now that the file exists, this command will work
+          python verify_conflict.py

From c029bdd12ddd107083f53f06db1f89244b16b7db Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Wed, 1 Oct 2025 19:28:40 +0000
Subject: [PATCH 48/50] Try with new image that works

---
 .github/workflows/UnitTests.yml | 178 +++++---------------------------
 1 file changed, 26 insertions(+), 152 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 3cebd084..bf59bc76 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -24,168 +24,42 @@ on:
   workflow_dispatch:

 jobs:
-  # maxtext_workload:
-  #   name: "Run MaxText Workload"
-  #   # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
-  #   runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
-  #   container:
-  #     image: gcr.io/tpu-prod-env-multipod/maxtext_stable_stack_candidate_gpu:latest
-  #   steps:
-  #     - name: Checkout MaxText Repo
-  #       uses: actions/checkout@v4
-  #       with:
-  #         repository: AI-Hypercomputer/maxtext
-  #         path: maxtext
-  #         ref: rbierneni-test-gpu-run
-
-  #     - name: Print dependencies
-  #       run: |
-  #         pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
-  #         pip install -U transformer-engine[jax]==2.6.0
-  #         # pip uninstall -y tensorflow
-  #         # pip install tensorflow-cpu
-  #         pip freeze
-
-  #     - name: Run MaxText Training
-  #       run: |
-  #         # This command is adapted from your DAG for a single-slice configuration.
-  #         cd maxtext && \
-  #         pip install . --no-dependencies
-
-  #         export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
-  #         export TF_FORCE_GPU_ALLOW_GROWTH=true
-
-  #         python3 -m MaxText.train MaxText/configs/base.yml \
-  #           steps=2 \
-  #           enable_checkpointing=false \
-  #           attention=cudnn_flash_te \
-  #           dataset_type=synthetic \
-  #           run_name=rbierneni-test-maxtext-gpu \
-  #           base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
-
-  # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
-  maxdiffusion_workload:
-    name: "Run MaxDiffusion Workload"
+  maxtext_workload:
+    name: "Run MaxText Workload"
     # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
     runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
     container:
-      image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:jax0.7.2_cuda13_te2.6.0
+      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
     steps:
-      - name: Checkout Repository
+      - name: Checkout MaxText Repo
         uses: actions/checkout@v4
-
-      - name: Check Host CUDA and GPU Environment
-        run: |
-          echo "--- Checking NVIDIA driver and supported CUDA version ---"
-          nvidia-smi || echo "nvidia-smi command not found. No GPU or NVIDIA driver detected."
-
-          echo ""
-          echo "--- Checking for default CUDA toolkit installation ---"
-          ls -l /usr/local/ | grep cuda || echo "No default CUDA toolkit found in /usr/local/"
-
-          echo ""
-          echo "--- Checking dynamic linker library path ---"
-          echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-'Not Set'}"
-
+        with:
+          repository: AI-Hypercomputer/maxtext
+          path: maxtext
+          ref: rbierneni-test-gpu-run
+
       - name: Print dependencies
         run: |
           pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
-          # pip install transformer_engine[jax]==2.4.0
-          # pip install -U transformer-engine[jax]==2.6.0
-          # pip uninstall -y transformer-engine-cu12
-          # pip install transformer-engine-cu13
+          pip install -U transformer-engine[jax]==2.6.0
           # pip uninstall -y tensorflow
           # pip install tensorflow-cpu
           pip freeze

-      - name: Check devices
-        run: |
-          python -c "import jax; print(jax.devices())"
-
-      - name: Run Conflict Verification Script
-        run: |
-          # This command creates the file inside the runner
-          cat <<'EOF' > verify_conflict.py
-          print("--- PyTorch vs. JAX Conflict Test ---")
-
-          print("\nStep 1: Attempting to import torch...")
-          try:
-              import torch
-              print(f"Successfully imported torch version: {torch.__version__}")
-              print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
-          except Exception as e:
-              print(f"Failed to import torch: {e}")
-
-          print("\nStep 2: Now, attempting to initialize JAX...")
-          try:
-              import jax
-              devices = jax.devices()
-              print("\n--- RESULT: SUCCESS ---")
-              print(f"JAX initialized correctly and found devices: {devices}")
-          except Exception as e:
-              print("\n--- RESULT: FAILURE ---")
-              print("JAX failed to initialize after PyTorch was imported.")
-              print(f"JAX Error: {e}")
-          EOF
-
-          # Now that the file exists, this command will work
-          python verify_conflict.py
-
-      # - name: Run MaxDiffusion Training
-      #   run: |
-      #     # This command is adapted from your DAG for a single-slice configuration.
-      #     NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
-      #     python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
-      #       hardware=gpu \
-      #       train_new_unet=true \
-      #       train_text_encoder=false \
-      #       cache_latents_text_encoder_outputs=true \
-      #       per_device_batch_size=1 \
-      #       attention=dot_product \
-      #       activations_dtype=bfloat16 \
-      #       weights_dtype=bfloat16 \
-      #       max_train_steps=200 \
-      #       enable_profiler=True \
-      #       run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
-      #       output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
-
-# jobs:
-#   build:
-#     strategy:
-#       fail-fast: false
-#       matrix:
-#         tpu-type: ["v5p-8"]
-#     name: "TPU test (${{ matrix.tpu-type }})"
-#     runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
-#     steps:
-#     - uses: actions/checkout@v4
-#     - name: Set up Python 3.12
-#       uses: actions/setup-python@v5
-#       with:
-#         python-version: '3.12'
-#     - name: Install dependencies
-#       run: |
-#         pip install -e .
-#         pip uninstall jax jaxlib libtpu-nightly libtpu -y
-#         bash setup.sh MODE=stable
-#         export PATH=$PATH:$HOME/.local/bin
-#         pip install ruff
-#         pip install isort
-#         pip install pytest
-#     - name: Analysing the code with ruff
-#       run: |
-#         ruff check .
-#     - name: version check
-#       run: |
-#         python --version
-#         pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
-#     - name: PyTest
-#       run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-#         HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-# add_pull_ready:
-#   if: github.ref != 'refs/heads/main'
-#   permissions:
-#     checks: read
-#     pull-requests: write
-#   needs: build
-#   uses: ./.github/workflows/AddLabel.yml
+      - name: Run MaxText Training
+        run: |
+          # This command is adapted from your DAG for a single-slice configuration.
+          cd maxtext && \
+          pip install . --no-dependencies
+
+          export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
+          export TF_FORCE_GPU_ALLOW_GROWTH=true
+          export NVTE_FUSED_ATTN=1
+
+          python3 -m MaxText.train MaxText/configs/base.yml \
+            steps=5 \
+            enable_checkpointing=false \
+            attention=cudnn_flash_te \
+            dataset_type=synthetic \
+            run_name=rbierneni-test-maxtext-gpu \
+            base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}

From cc956491f50368356bd6ec98c586e221aa22595c Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Wed, 1 Oct 2025 19:34:10 +0000
Subject: [PATCH 49/50] Use original TE from image

---
 .github/workflows/UnitTests.yml | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index bf59bc76..c5b0c79f 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -40,17 +40,12 @@ jobs:

       - name: Print dependencies
         run: |
-          pip uninstall -y transformer-engine transformer-engine-jax transformer-engine-cu12
-          pip install -U transformer-engine[jax]==2.6.0
-          # pip uninstall -y tensorflow
-          # pip install tensorflow-cpu
           pip freeze

       - name: Run MaxText Training
         run: |
           # This command is adapted from your DAG for a single-slice configuration.
           cd maxtext && \
-          pip install . --no-dependencies

           export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
           export TF_FORCE_GPU_ALLOW_GROWTH=true

From 5dfa7a82eba1e99a88355e74648bb1a9506c7730 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Wed, 1 Oct 2025 19:42:41 +0000
Subject: [PATCH 50/50] Use maxtext at head

---
 .github/workflows/UnitTests.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index c5b0c79f..f3d579de 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -36,7 +36,6 @@ jobs:
         with:
           repository: AI-Hypercomputer/maxtext
           path: maxtext
-          ref: rbierneni-test-gpu-run

       - name: Print dependencies
         run: |