Commits (50)
- 6e0898a "Update entrypoint for jaii" (Rohan-Bierneni, Sep 25, 2025)
- c50ab1b "remove self-hosted tag" (Rohan-Bierneni, Sep 26, 2025)
- 1f22cf5 "wrong tag name" (Rohan-Bierneni, Sep 26, 2025)
- 2deaa7e "Update with command for gpu" (Rohan-Bierneni, Sep 26, 2025)
- d47a6eb "remove pip install ." (Rohan-Bierneni, Sep 26, 2025)
- a6579ec "Check environment" (Rohan-Bierneni, Sep 26, 2025)
- cc5d322 "Update to right image" (Rohan-Bierneni, Sep 26, 2025)
- 26a03c0 "point to allowed bucket" (Rohan-Bierneni, Sep 26, 2025)
- 9236f6f "Install entire TE" (Rohan-Bierneni, Sep 26, 2025)
- 4e0dd7f "uninstall all TE deps" (Rohan-Bierneni, Sep 26, 2025)
- 0888316 "install TE jax" (Rohan-Bierneni, Sep 26, 2025)
- 09c5d6a "fix pip instal typo" (Rohan-Bierneni, Sep 26, 2025)
- 846e176 "Install TE for pytorch and jax" (Rohan-Bierneni, Sep 26, 2025)
- 496f67c "Comment out tflop calc" (Rohan-Bierneni, Sep 26, 2025)
- 37524bb "Test with dot_product attention" (Rohan-Bierneni, Sep 26, 2025)
- 64f30f3 "Test if maxtext has same gpu issue" (Rohan-Bierneni, Sep 26, 2025)
- 53d1d22 "change to pip install maxtext package" (Rohan-Bierneni, Sep 26, 2025)
- fdcec4c "Install with no dependencies" (Rohan-Bierneni, Sep 26, 2025)
- cd9daaf "Use custom maxtext branch" (Rohan-Bierneni, Sep 26, 2025)
- e0615dc "use synthetic data" (Rohan-Bierneni, Sep 26, 2025)
- 323654a "Try with TE flash" (Rohan-Bierneni, Sep 26, 2025)
- 2313d54 "Try with older TE" (Rohan-Bierneni, Sep 26, 2025)
- 2eafa93 "Typo in TE install" (Rohan-Bierneni, Sep 26, 2025)
- a1da601 "Use TE 2.5.0" (Rohan-Bierneni, Sep 26, 2025)
- dfe8fc5 "Use TE 2.6.0" (Rohan-Bierneni, Sep 26, 2025)
- 4243d53 "Test with tensorflow-cpu" (Rohan-Bierneni, Sep 26, 2025)
- d37e0dd "Test with new cuda13 images and TE 2.6.0" (Rohan-Bierneni, Sep 30, 2025)
- c247bee "Update to right image" (Rohan-Bierneni, Sep 30, 2025)
- 67141da "uninstall cuda12 TE as well" (Rohan-Bierneni, Sep 30, 2025)
- 4ac11bb "Test with te-cu13 package" (Rohan-Bierneni, Sep 30, 2025)
- 994d8dc "try with base TE" (Rohan-Bierneni, Sep 30, 2025)
- 6700b53 "uninstall existing TE" (Rohan-Bierneni, Sep 30, 2025)
- 23d8ca2 "test te 2.6.0 jax" (Rohan-Bierneni, Sep 30, 2025)
- 0885a52 "test te 2.6.0 jax" (Rohan-Bierneni, Sep 30, 2025)
- 22cfa35 "remove steps for uninstalling TE" (Rohan-Bierneni, Sep 30, 2025)
- ab478a1 "check host gpu driver" (Rohan-Bierneni, Sep 30, 2025)
- 7ff2c23 "try with cuda 12 on TE cu12" (Rohan-Bierneni, Sep 30, 2025)
- ef2041e "try with last tested TE version" (Rohan-Bierneni, Sep 30, 2025)
- ec29e4e "try with no-cache build image" (Rohan-Bierneni, Sep 30, 2025)
- ecfacad "update image tag to prevent gke caching" (Rohan-Bierneni, Sep 30, 2025)
- 1273c3e "try with dot_product" (Rohan-Bierneni, Sep 30, 2025)
- d1263eb "try without te" (Rohan-Bierneni, Sep 30, 2025)
- 9295b8b "use with no te" (Rohan-Bierneni, Sep 30, 2025)
- 95c1b03 "Check jax.devices output" (Rohan-Bierneni, Sep 30, 2025)
- 333e4f5 "try with workload" (Rohan-Bierneni, Sep 30, 2025)
- 5f5749a "check if torch is causing the issue" (Rohan-Bierneni, Sep 30, 2025)
- 1b5b8c4 "Run script in workflow" (Rohan-Bierneni, Sep 30, 2025)
- c029bdd "Try with new image that works" (Rohan-Bierneni, Oct 1, 2025)
- cc95649 "Use original TE from image" (Rohan-Bierneni, Oct 1, 2025)
- 5dfa7a8 "Use maxtext at head" (Rohan-Bierneni, Oct 1, 2025)
.github/workflows/UnitTests.yml (73 changes: 32 additions & 41 deletions)

@@ -22,47 +22,38 @@ on:
   push:
     branches: [ "main" ]
   workflow_dispatch:
-  schedule:
-    # Run the job every 12 hours
-    - cron: '0 */12 * * *'
 
 jobs:
-  build:
-    strategy:
-      fail-fast: false
-      matrix:
-        tpu-type: ["v5p-8"]
-    name: "TPU test (${{ matrix.tpu-type }})"
-    runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
+  maxtext_workload:
+    name: "Run MaxText Workload"
+    # IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
+    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
+    container:
+      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
     steps:
     - uses: actions/checkout@v4
-    - name: Set up Python 3.12
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.12'
-    - name: Install dependencies
-      run: |
-        pip install -e .
-        pip uninstall jax jaxlib libtpu-nightly libtpu -y
-        bash setup.sh MODE=stable
-        export PATH=$PATH:$HOME/.local/bin
-        pip install ruff
-        pip install isort
-        pip install pytest
-    - name: Analysing the code with ruff
-      run: |
-        ruff check .
-    - name: version check
-      run: |
-        python --version
-        pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
-    - name: PyTest
-      run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-# add_pull_ready:
-#   if: github.ref != 'refs/heads/main'
-#   permissions:
-#     checks: read
-#     pull-requests: write
-#   needs: build
-#   uses: ./.github/workflows/AddLabel.yml
+    - name: Checkout MaxText Repo
+      uses: actions/checkout@v4
+      with:
+        repository: AI-Hypercomputer/maxtext
+        path: maxtext
+
+    - name: Print dependencies
+      run: |
+        pip freeze
+
+    - name: Run MaxText Training
+      run: |
+        # This command is adapted from your DAG for a single-slice configuration.
+        cd maxtext && \
+
+        export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
+        export TF_FORCE_GPU_ALLOW_GROWTH=true
+        export NVTE_FUSED_ATTN=1
+
+        python3 -m MaxText.train MaxText/configs/base.yml \
+          steps=5 \
+          enable_checkpointing=false \
+          attention=cudnn_flash_te \
+          dataset_type=synthetic \
+          run_name=rbierneni-test-maxtext-gpu \
+          base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
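Several commits in this PR (e.g. "Check jax.devices output", "Check environment") debug whether JAX can see the GPUs inside the container before training starts. A minimal preflight sketch of that check, assuming nothing about the environment (the `jax_device_summary` helper name is hypothetical, not part of MaxText), could look like:

```python
# Hypothetical preflight check: report which devices JAX sees, without
# crashing when JAX is missing or fails to initialize.
import importlib.util


def jax_device_summary():
    """Return a short string describing JAX's visible devices, or why we can't tell."""
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax
    try:
        devices = jax.devices()
    except Exception as exc:  # e.g. no usable GPU backend in the container
        return f"jax init failed: {exc}"
    return ", ".join(f"{d.platform}:{d.id}" for d in devices)


print(jax_device_summary())
```

Running this as an early workflow step would surface a missing or broken GPU backend before the much longer training step fails.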
maxdiffusion_jax_ai_image_tpu.Dockerfile (2 changes: 1 addition & 1 deletion)

@@ -19,4 +19,4 @@ COPY . .
 RUN pip install -r /deps/requirements_with_jax_ai_image.txt
 
 # Run the script available in JAX-AI-Image base image to generate the manifest file
-RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
+RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
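The Dockerfile change swaps the manifest-script path from `/jax-stable-stack` to `/jax-ai-image`. A sketch, under the assumption that one wants to confirm which of the two directories actually exists in a given base image before relying on it (both paths are taken from the diff; the `check_dir` helper is invented here):

```shell
# Report which of the two candidate base-image script directories exists.
check_dir() {
  if [ -d "$1" ]; then
    echo "$1: present"
  else
    echo "$1: missing"
  fi
}

result="$(check_dir /jax-ai-image; check_dir /jax-stable-stack)"
echo "$result"
```

On a machine without either base image both lines report "missing"; inside the updated image, `/jax-ai-image` is the one expected to be present.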
verify_conflict.sh (22 additions, new file)

@@ -0,0 +1,22 @@
+print("--- PyTorch vs. JAX Conflict Test ---")
+
+print("\nStep 1: Attempting to import torch...")
+try:
+    import torch
+    print(f"Successfully imported torch version: {torch.__version__}")
+    # This check will confirm you have the CPU-only version
+    print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
+except Exception as e:
+    print(f"Failed to import torch: {e}")
+
+
+print("\nStep 2: Now, attempting to initialize JAX...")
+try:
+    import jax
+    devices = jax.devices()
+    print("\n--- RESULT: SUCCESS ---")
+    print(f"JAX initialized correctly and found devices: {devices}")
+except Exception as e:
+    print("\n--- RESULT: FAILURE ---")
+    print("JAX failed to initialize after PyTorch was imported.")
+    print(f"JAX Error: {e}")
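The new verify_conflict.sh script (Python code, despite the .sh extension) imports torch first and then initializes JAX, to test whether PyTorch's CUDA libraries break JAX's GPU init. A small companion sketch, with a hypothetical `framework_availability` helper not present in the PR, would check that both frameworks are even installed before drawing conclusions from the conflict test:

```python
# Hypothetical pre-check for the conflict test: detect installed frameworks
# without importing them (find_spec avoids triggering any CUDA init).
import importlib.util


def framework_availability():
    """Map each framework name to whether an importable module was found."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in ("torch", "jax")
    }


avail = framework_availability()
for name, present in avail.items():
    print(f"{name} installed: {present}")
```

If either framework is absent, a "FAILURE" from the conflict script would reflect a missing install rather than an actual torch/JAX interaction.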