From ba9b4e520fd506ba621720fd39ee5f53180e0a8b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 11 Aug 2025 17:50:14 +0000 Subject: [PATCH 1/5] Change dep structure to support cuda-13 Signed-off-by: Kirthi Shankar Sivamani --- build_tools/wheel_utils/Dockerfile.x86 | 2 +- setup.py | 27 +++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 26122eed9b..93dfd2496f 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -21,7 +21,7 @@ RUN dnf -y install --allowerasing cudnn9-cuda-12 RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-12 RUN dnf clean all RUN dnf -y install glog.x86_64 glog-devel.x86_64 diff --git a/setup.py b/setup.py index 0b1b523277..64b677c608 100644 --- a/setup.py +++ b/setup.py @@ -125,10 +125,31 @@ def setup_requirements() -> Tuple[List[str], List[str]]: ext_modules = [] package_data = {} include_package_data = False - install_requires = ([f"transformer_engine_cu12=={__version__}"],) extras_require = { - "pytorch": [f"transformer_engine_torch=={__version__}"], - "jax": [f"transformer_engine_jax=={__version__}"], + "pytorch": [ + f"transformer_engine_torch=={__version__}", + f"transformer_engine_cu12=={__version__}", + ], + "jax": [ + f"transformer_engine_jax=={__version__}", + f"transformer_engine_cu12=={__version__}", + ], + "pytorch_cu12": [ + f"transformer_engine_torch=={__version__}", + f"transformer_engine_cu12=={__version__}", + ], + "jax_cu12": [ + f"transformer_engine_jax=={__version__}", + f"transformer_engine_cu12=={__version__}", + ], + "pytorch_cu13": [ + f"transformer_engine_torch=={__version__}", + f"transformer_engine_cu13=={__version__}", + ], + "jax_cu13": [ + f"transformer_engine_jax=={__version__}", + f"transformer_engine_cu13=={__version__}", + ], } else: install_requires, test_requires = setup_requirements() From 6029e1bfcb9c03fbddc737a9843460033da434d9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 15 Aug 2025 07:31:34 +0000 Subject: [PATCH 2/5] Add CUDA 13 build to wheel scripts Signed-off-by: Tim Moon --- build_tools/wheel_utils/Dockerfile.aarch | 32 +++++++++++------ build_tools/wheel_utils/Dockerfile.x86 | 32 +++++++++++------ build_tools/wheel_utils/build_wheels.sh | 18 +++++----- build_tools/wheel_utils/launch_aarch.sh | 45 ++++++++++++++++++++++-- build_tools/wheel_utils/launch_x86.sh | 45 ++++++++++++++++++++++-- setup.py | 1 + 6 files changed, 137 insertions(+), 36 deletions(-) diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index 223c4a7f1c..292a047a43 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -5,23 +5,26 @@ FROM quay.io/pypa/manylinux_2_28_aarch64 WORKDIR /TransformerEngine/ -COPY ../.. /TransformerEngine/ -ARG VER="12-3" -ARG ARCH="aarch64" -RUN dnf -y install vim +ARG ARCH=aarch64 +ARG CUDA_VERSION_MAJOR=12 +ARG CUDA_VERSION_MINOR=3 +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true -# Cuda toolkit, cudnn, driver. +# CUDA Toolkit, cuDNN, CUDA driver. 
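+# CUDA_VERSION_MAJOR and CUDA_VERSION_MINOR select the toolkit packages installed
+# below, e.g. cuda-compiler-12-3 for the default CUDA 12.3 build or
+# cuda-compiler-13-0 when the launch scripts request a CUDA 13.0 build.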
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \ + cuda-libraries-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \ + cuda-libraries-devel-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_VERSION_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-${CUDA_VERSION_MAJOR} RUN dnf clean all RUN dnf -y install glog.aarch64 glog-devel.aarch64 @@ -32,5 +35,12 @@ ENV CUDA_ROOT=/usr/local/cuda ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 +ENV CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} +ENV CUDA_VERSION_MINOR=${CUDA_VERSION_MINOR} +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] +COPY ../.. /TransformerEngine/ +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 93dfd2496f..1884181d83 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -5,23 +5,26 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 WORKDIR /TransformerEngine/ -COPY ../.. /TransformerEngine/ -ARG VER="12-3" -ARG ARCH="x86_64" -RUN dnf -y install vim +ARG ARCH=x86_64 +ARG CUDA_VERSION_MAJOR=12 +ARG CUDA_VERSION_MINOR=3 +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true -# Cuda toolkit, cudnn, driver. +# CUDA Toolkit, cuDNN, CUDA driver. 
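+# cudnn9-cuda-${CUDA_VERSION_MAJOR} keeps the installed cuDNN 9 build in sync
+# with the CUDA major version selected at build time.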
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \ + cuda-libraries-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \ + cuda-libraries-devel-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_VERSION_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit-12 +RUN dnf -y install cuda-toolkit-${CUDA_VERSION_MAJOR} RUN dnf clean all RUN dnf -y install glog.x86_64 glog-devel.x86_64 @@ -32,5 +35,12 @@ ENV CUDA_ROOT=/usr/local/cuda ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 +ENV CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} +ENV CUDA_VERSION_MINOR=${CUDA_VERSION_MINOR} +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] +COPY ../.. /TransformerEngine/ +CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64"] diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index bf4f9d2bc2..c8e7278225 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -5,13 +5,15 @@ set -e PLATFORM=${1:-manylinux_2_28_x86_64} -BUILD_METAPACKAGE=${2:-true} -BUILD_COMMON=${3:-true} -BUILD_PYTORCH=${4:-true} -BUILD_JAX=${5:-true} export NVTE_RELEASE_BUILD=1 +export CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR:-12} +export BUILD_METAPACKAGE=${BUILD_METAPACKAGE:-true} +export BUILD_COMMON=${BUILD_COMMON:-true} +export BUILD_PYTORCH=${BUILD_PYTORCH:-true} +export BUILD_JAX=${BUILD_JAX:-true} export TARGET_BRANCH=${TARGET_BRANCH:-} +export NVTE_BUILD_THREADS_PER_JOB=4 mkdir -p /wheelhouse/logs # Generate wheels for common library. @@ -39,15 +41,15 @@ if $BUILD_COMMON ; then # Repack the wheel for cuda specific package, i.e. cu12. /opt/python/cp310-cp310/bin/wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_VERSION_MAJOR}/g" "${WHL_BASE}/${WHL_BASE}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_VERSION_MAJOR}/g" "${WHL_BASE}/${WHL_BASE}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_VERSION_MAJOR}-${VERSION}.dist-info" /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} # Rename the wheel to make it python version agnostic. 
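+    # e.g. transformer_engine-${VERSION}-cp310-cp310-<platform>.whl becomes
+    # transformer_engine_cu${CUDA_VERSION_MAJOR}-${VERSION}-py3-none-<platform>.whl:
+    # the package name gains the CUDA suffix and the CPython-specific tags are
+    # replaced with py3-none.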
whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + whl_name_target="${whl_parts[0]}_cu${CUDA_VERSION_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}" rm -rf $WHL_BASE dist mv *.whl /wheelhouse/"$whl_name_target" fi diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh index 04e3cd6916..3e4cd4bc7e 100644 --- a/build_tools/wheel_utils/launch_aarch.sh +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -2,7 +2,46 @@ # # See LICENSE for license information. -docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . -docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" +set -ex + +# Paths +TMP_DIR=$(mktemp --directory) rm -rf aarch_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse +mkdir aarch_wheelhouse + +# CUDA 12 wheels +docker build \ + --no-cache \ + --tag aarch_cu12_wheel \ + --file build_tools/wheel_utils/Dockerfile.aarch \ + --build-arg CUDA_VERSION_MAJOR=12 \ + --build-arg CUDA_VERSION_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=true \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=true \ + --build-arg BUILD_JAX=true \ + . +docker run --runtime=nvidia --gpus=all --ipc=host aarch_cu12_wheel +docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR} +cp -r ${TMP_DIR}/wheelhouse/* aarch_wheelhouse +rm -rf ${TMP_DIR}/wheelhouse + +# CUDA 13 wheels +docker build \ + --no-cache \ + --tag aarch_cu13_wheel \ + --file build_tools/wheel_utils/Dockerfile.aarch \ + --build-arg CUDA_VERSION_MAJOR=13 \ + --build-arg CUDA_VERSION_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + . +docker run --runtime=nvidia --gpus=all --ipc=host aarch_cu13_wheel +docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR} +cp -r ${TMP_DIR}/wheelhouse/* aarch_wheelhouse +rm -rf ${TMP_DIR}/wheelhouse + +# Clean up +rm -rf ${TMP_DIR} diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh index b0d20be3f4..f804422e54 100644 --- a/build_tools/wheel_utils/launch_x86.sh +++ b/build_tools/wheel_utils/launch_x86.sh @@ -2,7 +2,46 @@ # # See LICENSE for license information. -docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . -docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" +set -ex + +# Paths +TMP_DIR=$(mktemp --directory) rm -rf x86_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse +mkdir x86_wheelhouse + +# CUDA 12 wheels +docker build \ + --no-cache \ + --tag x86_cu12_wheel \ + --file build_tools/wheel_utils/Dockerfile.x86 \ + --build-arg CUDA_VERSION_MAJOR=12 \ + --build-arg CUDA_VERSION_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=true \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=true \ + --build-arg BUILD_JAX=true \ + . +docker run --runtime=nvidia --gpus=all --ipc=host x86_cu12_wheel +docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR} +cp -r ${TMP_DIR}/wheelhouse/* x86_wheelhouse +rm -rf ${TMP_DIR}/wheelhouse + +# CUDA 13 wheels +docker build \ + --no-cache \ + --tag x86_cu13_wheel \ + --file build_tools/wheel_utils/Dockerfile.x86 \ + --build-arg CUDA_VERSION_MAJOR=13 \ + --build-arg CUDA_VERSION_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + . 
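+# The CUDA 13 pass only rebuilds the core library wheel (BUILD_COMMON=true);
+# the metapackage and the framework wheels were already produced by the CUDA 12
+# pass above.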
+docker run --runtime=nvidia --gpus=all --ipc=host x86_cu13_wheel +docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR} +cp -r ${TMP_DIR}/wheelhouse/* x86_wheelhouse +rm -rf ${TMP_DIR}/wheelhouse + +# Clean up +rm -rf ${TMP_DIR} diff --git a/setup.py b/setup.py index 64b677c608..20026b2e18 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: ext_modules = [] package_data = {} include_package_data = False + install_requires = [] extras_require = { "pytorch": [ f"transformer_engine_torch=={__version__}", From e9ce4141a23e13a26f7a4030df0aa3f29053ea0a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 15 Aug 2025 09:27:10 +0000 Subject: [PATCH 3/5] Support importing core module from CUDA 13 wheel Signed-off-by: Tim Moon --- transformer_engine/common/__init__.py | 52 ++++++++++++++------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 834c4fe259..1e8815a381 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -130,38 +130,42 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" + # Check if the core package is installed via PyPI. + found_core_module = False + for core_module_name in ("transformer_engine_cu13", "transformer_engine_cu12"): + if _is_pip_package_installed(core_module_name): + found_core_module = True + break + # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework # extension are all installed via PyPI and have matching version. if _is_pip_package_installed(module_name): - assert _is_pip_package_installed( - "transformer_engine" - ), "Could not find `transformer-engine`." - assert _is_pip_package_installed( - "transformer_engine_cu12" - ), "Could not find `transformer-engine-cu12`." - assert ( - version(module_name) - == version("transformer-engine") - == version("transformer-engine-cu12") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" - ) + if not _is_pip_package_installed("transformer_engine"): + raise RuntimeError("Could not find `transformer-engine`.") + if not found_core_module: + raise RuntimeError("Could not find `transformer-engine-cu*`.") + if ( + version(module_name) != version("transformer-engine") + or version(core_module_name) != version("transformer-engine") + ): + raise RuntimeError( + "TransformerEngine package version mismatch. Found" + f" {module_name} v{version(module_name)}, transformer-engine" + f" v{version('transformer-engine')}, and {core_module_name}" + f" v{version(core_module_name)}. Install transformer-engine using " + f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" + ) # If the core package is installed via PyPI, log if # the framework extension is not found from PyPI. # Note: Should we error? This is a rare use case. - if _is_pip_package_installed("transformer-engine-cu12"): - if not _is_pip_package_installed(module_name): - _logger.info( - "Could not find package %s. 
Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", - module_name, - ) + if found_core_module and not _is_pip_package_installed(module_name): + _logger.info( + "Could not find package %s. Install transformer-engine using " + f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", + module_name, + ) # After all checks are completed, load the shared object file. spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) From 8d7723000a70c7bad7d3f1c2915768ee890c254d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Aug 2025 09:27:49 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 1e8815a381..722a6a7bd6 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -145,10 +145,9 @@ def load_framework_extension(framework: str) -> None: raise RuntimeError("Could not find `transformer-engine`.") if not found_core_module: raise RuntimeError("Could not find `transformer-engine-cu*`.") - if ( - version(module_name) != version("transformer-engine") - or version(core_module_name) != version("transformer-engine") - ): + if version(module_name) != version("transformer-engine") or version( + core_module_name + ) != version("transformer-engine"): raise RuntimeError( "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" From 05a47af8cece447524f601ef216f21cbcedb6e2f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 22 Aug 2025 23:53:01 +0000 Subject: [PATCH 5/5] Update docs Signed-off-by: Tim Moon --- README.rst | 9 ++++++--- docs/installation.rst | 9 ++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 19ab1a7d91..d8c58c8e5b 100644 --- a/README.rst +++ b/README.rst @@ -205,13 +205,16 @@ To install the latest stable version with pip: # For PyTorch integration pip install --no-build-isolation transformer_engine[pytorch] - + # For JAX integration pip install --no-build-isolation transformer_engine[jax] - + # For both frameworks pip install --no-build-isolation transformer_engine[pytorch,jax] + # For CUDA 13 support + pip install --no-build-isolation transformer_engine[pytorch_cu13,jax_cu13] + Alternatively, install directly from the GitHub repository: .. code-block:: bash @@ -233,7 +236,7 @@ To install the latest stable version with conda from conda-forge: # For PyTorch integration conda install -c conda-forge transformer-engine-torch - + # JAX integration (coming soon) Source Installation diff --git a/docs/installation.rst b/docs/installation.rst index ecb1e9a0dd..5acdef40c4 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -36,7 +36,14 @@ Transformer Engine can be directly installed from `our PyPI