Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,16 @@ To install the latest stable version with pip:

# For PyTorch integration
pip install --no-build-isolation transformer_engine[pytorch]

# For JAX integration
pip install --no-build-isolation transformer_engine[jax]

# For both frameworks
pip install --no-build-isolation transformer_engine[pytorch,jax]

# For CUDA 13 support
pip install --no-build-isolation transformer_engine[pytorch_cu13,jax_cu13]

Alternatively, install directly from the GitHub repository:

.. code-block:: bash
Expand All @@ -233,7 +236,7 @@ To install the latest stable version with conda from conda-forge:

# For PyTorch integration
conda install -c conda-forge transformer-engine-torch

# JAX integration (coming soon)

Source Installation
Expand Down
32 changes: 21 additions & 11 deletions build_tools/wheel_utils/Dockerfile.aarch
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,26 @@
FROM quay.io/pypa/manylinux_2_28_aarch64

WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/

ARG VER="12-3"
ARG ARCH="aarch64"
RUN dnf -y install vim
ARG ARCH=aarch64
ARG CUDA_VERSION_MAJOR=12
ARG CUDA_VERSION_MINOR=3
ARG BUILD_METAPACKAGE=true
ARG BUILD_COMMON=true
ARG BUILD_PYTORCH=true
ARG BUILD_JAX=true

# Cuda toolkit, cudnn, driver.
# CUDA Toolkit, cuDNN, CUDA driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf -y install cuda-compiler-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \
cuda-libraries-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \
cuda-libraries-devel-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_VERSION_MAJOR}
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf -y install cuda-toolkit-${CUDA_VERSION_MAJOR}
RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64

Expand All @@ -32,5 +35,12 @@ ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
ENV CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR}
ENV CUDA_VERSION_MINOR=${CUDA_VERSION_MINOR}
ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
ENV BUILD_COMMON=${BUILD_COMMON}
ENV BUILD_PYTORCH=${BUILD_PYTORCH}
ENV BUILD_JAX=${BUILD_JAX}

CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
COPY ../.. /TransformerEngine/
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64"]
32 changes: 21 additions & 11 deletions build_tools/wheel_utils/Dockerfile.x86
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,26 @@
FROM quay.io/pypa/manylinux_2_28_x86_64

WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/

ARG VER="12-3"
ARG ARCH="x86_64"
RUN dnf -y install vim
ARG ARCH=x86_64
ARG CUDA_VERSION_MAJOR=12
ARG CUDA_VERSION_MINOR=3
ARG BUILD_METAPACKAGE=true
ARG BUILD_COMMON=true
ARG BUILD_PYTORCH=true
ARG BUILD_JAX=true

# Cuda toolkit, cudnn, driver.
# CUDA Toolkit, cuDNN, CUDA driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf -y install cuda-compiler-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \
cuda-libraries-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH} \
cuda-libraries-devel-${CUDA_VERSION_MAJOR}-${CUDA_VERSION_MINOR}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_VERSION_MAJOR}
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf -y install cuda-toolkit-${CUDA_VERSION_MAJOR}
RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64

Expand All @@ -32,5 +35,12 @@ ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
ENV CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR}
ENV CUDA_VERSION_MINOR=${CUDA_VERSION_MINOR}
ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
ENV BUILD_COMMON=${BUILD_COMMON}
ENV BUILD_PYTORCH=${BUILD_PYTORCH}
ENV BUILD_JAX=${BUILD_JAX}

CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
COPY ../.. /TransformerEngine/
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64"]
18 changes: 10 additions & 8 deletions build_tools/wheel_utils/build_wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
set -e

PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true}

export NVTE_RELEASE_BUILD=1
export CUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR:-12}
export BUILD_METAPACKAGE=${BUILD_METAPACKAGE:-true}
export BUILD_COMMON=${BUILD_COMMON:-true}
export BUILD_PYTORCH=${BUILD_PYTORCH:-true}
export BUILD_JAX=${BUILD_JAX:-true}
export TARGET_BRANCH=${TARGET_BRANCH:-}
export NVTE_BUILD_THREADS_PER_JOB=4
mkdir -p /wheelhouse/logs

# Generate wheels for common library.
Expand Down Expand Up @@ -39,15 +41,15 @@ if $BUILD_COMMON ; then
# Repack the wheel for the CUDA-specific package, e.g. cu12 or cu13.
/opt/python/cp310-cp310/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_VERSION_MAJOR}/g" "${WHL_BASE}/${WHL_BASE}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_VERSION_MAJOR}/g" "${WHL_BASE}/${WHL_BASE}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_VERSION_MAJOR}-${VERSION}.dist-info"
/opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}

# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
whl_name_target="${whl_parts[0]}_cu${CUDA_VERSION_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target"
fi
Expand Down
45 changes: 42 additions & 3 deletions build_tools/wheel_utils/launch_aarch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,46 @@
#
# See LICENSE for license information.

docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
set -ex

# Paths
TMP_DIR=$(mktemp --directory)
rm -rf aarch_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
mkdir aarch_wheelhouse

# CUDA 12 wheels
docker build \
--no-cache \
--tag aarch_cu12_wheel \
--file build_tools/wheel_utils/Dockerfile.aarch \
--build-arg CUDA_VERSION_MAJOR=12 \
--build-arg CUDA_VERSION_MINOR=3 \
--build-arg BUILD_METAPACKAGE=true \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=true \
--build-arg BUILD_JAX=true \
.
docker run --runtime=nvidia --gpus=all --ipc=host aarch_cu12_wheel
docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR}
cp -r ${TMP_DIR}/wheelhouse/* aarch_wheelhouse
rm -rf ${TMP_DIR}/wheelhouse

# CUDA 13 wheels
docker build \
--no-cache \
--tag aarch_cu13_wheel \
--file build_tools/wheel_utils/Dockerfile.aarch \
--build-arg CUDA_VERSION_MAJOR=13 \
--build-arg CUDA_VERSION_MINOR=0 \
--build-arg BUILD_METAPACKAGE=false \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=false \
--build-arg BUILD_JAX=false \
.
docker run --runtime=nvidia --gpus=all --ipc=host aarch_cu13_wheel
docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR}
cp -r ${TMP_DIR}/wheelhouse/* aarch_wheelhouse
rm -rf ${TMP_DIR}/wheelhouse

# Clean up
rm -rf ${TMP_DIR}
45 changes: 42 additions & 3 deletions build_tools/wheel_utils/launch_x86.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,46 @@
#
# See LICENSE for license information.

docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
set -ex

# Paths
TMP_DIR=$(mktemp --directory)
rm -rf x86_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
mkdir x86_wheelhouse

# CUDA 12 wheels
docker build \
--no-cache \
--tag x86_cu12_wheel \
--file build_tools/wheel_utils/Dockerfile.x86 \
--build-arg CUDA_VERSION_MAJOR=12 \
--build-arg CUDA_VERSION_MINOR=3 \
--build-arg BUILD_METAPACKAGE=true \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=true \
--build-arg BUILD_JAX=true \
.
docker run --runtime=nvidia --gpus=all --ipc=host x86_cu12_wheel
docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR}
cp -r ${TMP_DIR}/wheelhouse/* x86_wheelhouse
rm -rf ${TMP_DIR}/wheelhouse

# CUDA 13 wheels
docker build \
--no-cache \
--tag x86_cu13_wheel \
--file build_tools/wheel_utils/Dockerfile.x86 \
--build-arg CUDA_VERSION_MAJOR=13 \
--build-arg CUDA_VERSION_MINOR=0 \
--build-arg BUILD_METAPACKAGE=false \
--build-arg BUILD_COMMON=true \
--build-arg BUILD_PYTORCH=false \
--build-arg BUILD_JAX=false \
.
docker run --runtime=nvidia --gpus=all --ipc=host x86_cu13_wheel
docker cp $(docker ps -aq | head -1):/wheelhouse ${TMP_DIR}
cp -r ${TMP_DIR}/wheelhouse/* x86_wheelhouse
rm -rf ${TMP_DIR}/wheelhouse

# Clean up
rm -rf ${TMP_DIR}
9 changes: 8 additions & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr

pip3 install --no-build-isolation transformer_engine[pytorch]

To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
To obtain the necessary Python bindings for Transformer Engine, the frameworks must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.

By default, the wheels are built with CUDA 12. To specify the CUDA version, append ``_cu12`` or ``_cu13`` to the framework name:

.. code-block:: bash

pip3 install --no-build-isolation transformer_engine[pytorch_cu13]


pip - from GitHub
-----------------------
Expand Down
28 changes: 25 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,32 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
ext_modules = []
package_data = {}
include_package_data = False
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
install_requires = []
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"pytorch": [
f"transformer_engine_torch=={__version__}",
f"transformer_engine_cu12=={__version__}",
],
"jax": [
f"transformer_engine_jax=={__version__}",
f"transformer_engine_cu12=={__version__}",
],
"pytorch_cu12": [
f"transformer_engine_torch=={__version__}",
f"transformer_engine_cu12=={__version__}",
],
"jax_cu12": [
f"transformer_engine_jax=={__version__}",
f"transformer_engine_cu12=={__version__}",
],
"pytorch_cu13": [
f"transformer_engine_torch=={__version__}",
f"transformer_engine_cu13=={__version__}",
],
"jax_cu13": [
f"transformer_engine_jax=={__version__}",
f"transformer_engine_cu13=={__version__}",
],
}
else:
install_requires, test_requires = setup_requirements()
Expand Down
51 changes: 27 additions & 24 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,38 +130,41 @@ def load_framework_extension(framework: str) -> None:
if framework == "torch":
extra_dep_name = "pytorch"

# Check if the core package is installed via PyPI.
found_core_module = False
for core_module_name in ("transformer_engine_cu13", "transformer_engine_cu12"):
if _is_pip_package_installed(core_module_name):
found_core_module = True
break

# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
), "Could not find `transformer-engine`."
assert _is_pip_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
if not _is_pip_package_installed("transformer_engine"):
raise RuntimeError("Could not find `transformer-engine`.")
if not found_core_module:
raise RuntimeError("Could not find `transformer-engine-cu*`.")
if version(module_name) != version("transformer-engine") or version(
core_module_name
) != version("transformer-engine"):
raise RuntimeError(
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and {core_module_name}"
f" v{version(core_module_name)}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)

# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if _is_pip_package_installed("transformer-engine-cu12"):
if not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)
if found_core_module and not _is_pip_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
module_name,
)

# After all checks are completed, load the shared object file.
spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
Expand Down