TransformerEngine v1.2.1 throws CuDNN frontend error on H100 GPU (AWS p5.48xlarge instance) #651

Open
sirutBuasai opened this issue Feb 2, 2024 · 10 comments
Labels
bug Something isn't working

Comments

@sirutBuasai

Hi, we are running into a TransformerEngine-related error when training a GPT model on H100 GPUs (AWS p5.48xlarge).
The error log is below.

Error:

RuntimeError: /fsx/sbuasai/test_te/deps/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:227 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans built successfully.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.

output_tensors = tex.fused_attn_fwd(

return fn(*args, **kwargs)
  File "/home/ubuntu/.conda/envs/megatron_bench/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 1835, in forward
RuntimeError: /fsx/sbuasai/test_te/deps/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:227 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans built successfully.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
    out, aux_ctx_tensors = fused_attn_fwd(
  File "/home/ubuntu/.conda/envs/megatron_bench/lib/python3.10/site-packages/transformer_engine/pytorch/cpp_extensions/fused_attn.py", line 811, in fused_attn_fwd
    output = FusedAttnFunc.apply(
  File "/home/ubuntu/.conda/envs/megatron_bench/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    output_tensors = tex.fused_attn_fwd(
RuntimeError: /fsx/sbuasai/test_te/deps/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:227 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans built successfully.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/.conda/envs/megatron_bench/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 1616, in forward
    output_tensors = tex.fused_attn_fwd(
RuntimeError: /fsx/sbuasai/test_te/deps/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:227 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans built successfully.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
    out, aux_ctx_tensors = fused_attn_fwd(
  File "/home/ubuntu/.conda/envs/megatron_bench/lib/python3.10/site-packages/transformer_engine/pytorch/cpp_extensions/fused_attn.py", line 811, in fused_attn_fwd
    output_tensors = tex.fused_attn_fwd(
RuntimeError: /fsx/sbuasai/test_te/deps/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:227 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans built successfully.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
[2024-02-02 01:08:10,718] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 24721 closing signal SIGTERM
[2024-02-02 01:08:10,750] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 24718) of binary: /home/ubuntu/.conda/envs/megatron_bench/bin/python
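
The error message itself suggests turning on cuDNN error logging. A minimal sketch of doing that from a Python launcher, assuming the job is started as a subprocess; the helper name `with_cudnn_error_logging` is hypothetical, and only the two variable names and their values come from the error text above:

```python
import os


def with_cudnn_error_logging(base_env=None):
    """Return a copy of the environment with cuDNN error logging enabled.

    The two variable names are taken verbatim from the error message.
    The helper itself is a hypothetical convenience, not part of any library.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["CUDNN_LOGERR_DBG"] = "1"
    env["CUDNN_LOGDEST_DBG"] = "stderr"
    return env


if __name__ == "__main__":
    env = with_cudnn_error_logging()
    # Pass `env` to the launcher, e.g. subprocess.run(["bash", "train.sh"], env=env).
    print({k: env[k] for k in ("CUDNN_LOGERR_DBG", "CUDNN_LOGDEST_DBG")})
```

Exporting the same two variables in train.sh before the torchrun call should have the same effect.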

Steps to reproduce:

  1. Create the conda environment with conda env create -f megatron_bench.yml, then run conda activate megatron_bench.
  2. Install flash-attention, TransformerEngine, apex, and Megatron-LM from source, as declared in install_deps.sh.
  3. Update the path to the data in train.sh.
  4. Run training with train.sh.

megatron_bench.yml:

name: megatron_bench
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.10
  - pip
  - conda:
    - python=3.10
    - pytorch=2.1.2
    - pytorch-cuda=12.1
    - torchvision
    - torchaudio

install_deps.sh:

#!/bin/bash

set -e

# ===================================================
# Set dependencies pin
# ===================================================
FLASH_ATTN_BRANCH='v2.0.4'
TE_BRANCH='v1.2.1'
APEX_HASH='6c8f384b40a596bbed960f5e8d9a808ebd0e93d8'
MEGATRON_LM_HASH='2c3468a49ed51324ae9b442e0d88416f1b29422b'

# ===================================================
# Install megatron python dependencies
# ===================================================
conda install -y regex astunparse ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests libcurl dataclasses packaging
pip install six regex tensorboardX daal4py deepspeed pyarrow pybind11 numpy==1.23.5

# ===================================================
# Install flash-attention
# ===================================================
cd $DEPS_DIR
if [ ! -d "$DEPS_DIR/flash-attention" ]; then
  git clone -b ${FLASH_ATTN_BRANCH} https://github.com/Dao-AILab/flash-attention.git
  cd flash-attention
  python setup.py install
fi

# ===================================================
# Install TransformerEngine
# ===================================================
cd $DEPS_DIR
if [ ! -d "$DEPS_DIR/TransformerEngine" ]; then
  git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git
  cd TransformerEngine
  git checkout ${TE_BRANCH}
  git submodule update --init --recursive
  export NVTE_FRAMEWORK="pytorch"
  export CUDNN_PATH=/usr/local/cuda-12.1
  export CUDNN_INCLUDE_DIR=/usr/local/cuda-12.1/include
  pip install .
fi

# ===================================================
# Install apex
# ===================================================
cd $DEPS_DIR
if [ ! -d "$DEPS_DIR/apex" ]; then
  git clone https://github.com/NVIDIA/apex.git
  cd apex
  git checkout ${APEX_HASH}
  pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
    --config-settings "--global-option=--cpp_ext" \
    --config-settings "--global-option=--cuda_ext" \
    ./
fi


# ===================================================
# Clone Megatron-LM for scripts
# ===================================================
cd $DEPS_DIR
if [ ! -d "$DEPS_DIR/Megatron-LM" ]; then
  git clone https://github.com/NVIDIA/Megatron-LM.git
  cd Megatron-LM
  git checkout ${MEGATRON_LM_HASH}
  cd $DEPS_DIR
fi

train.sh:

#!/bin/bash

# Exported so that install_deps.sh, which runs as a child process, can read it.
export DEPS_DIR="$(pwd)/deps"
DATASET_DIR="<DECLARE YOUR DATASET DIRECTORY HERE>"

bash install_deps.sh

export GPT_HOME="${DATASET_DIR}"
export DATASET="${DATASET_DIR}/my-gpt2_text_document/my-gpt2_text_document"
export CHECKPOINT_PATH="${DATASET_DIR}/checkpoints/gptmodel"
export VOCAB_FILE="${DATASET_DIR}/gpt2-vocab.json"
export MERGES_FILE="${DATASET_DIR}/gpt2-merges.txt"
export DATA_PATH="${DATASET_DIR}/my-gpt2_text_document/my-gpt2_text_document"
export NVTE_BIAS_GELU_NVFUSION=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_DEBUG=INFO
export NCCL_PROTO=LL,LL128,Simple
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export RDMAV_FORK_SAFE=1

# remove previous checkpoints
rm -rf ${DATASET_DIR}/checkpoints/

torchrun --nproc-per-node 8 --nnodes 1 \
  ${DEPS_DIR}/Megatron-LM/pretrain_gpt.py \
  --tensor-model-parallel-size 1 \
  --pipeline-model-parallel-size 1 \
  --sequence-parallel \
  --num-layers 24 \
  --hidden-size 1024 \
  --num-attention-heads 16 \
  --micro-batch-size 1 \
  --global-batch-size 8 \
  --seq-length 2048 \
  --max-position-embeddings 2048 \
  --train-iters 1200 \
  --lr-decay-iters 320000 \
  --save ${CHECKPOINT_PATH} \
  --load ${CHECKPOINT_PATH} \
  --data-path ${DATA_PATH} \
  --vocab-file ${VOCAB_FILE} \
  --merge-file ${MERGES_FILE} \
  --split 949,50,1 \
  --distributed-backend nccl \
  --lr 0.00015 \
  --lr-decay-style cosine \
  --min-lr 1.0e-5 \
  --weight-decay 1e-2 \
  --clip-grad 1.0 \
  --lr-warmup-fraction .01 \
  --log-interval 100 \
  --save-interval 10000 \
  --eval-interval 1000 \
  --eval-iters 10 \
  --adam-beta1 0.9 \
  --adam-beta2 0.95 \
  --init-method-std 0.006 \
  --bf16 \
  --transformer-impl transformer_engine \
  --attention-softmax-in-fp32
@ptrendx
Member

ptrendx commented Feb 2, 2024

Hi @sirutBuasai, what is the cuDNN version you are using?

@sirutBuasai
Author

The cuDNN version installed with torch==2.1.2 is 8.9.2:

(megatron_bench) ubuntu@ip-10-0-0-88:~$ python -c "import torch;print(torch.backends.cudnn.version())"
8902

@cyanguwa
Collaborator

cyanguwa commented Feb 5, 2024

Hi @sirutBuasai , could you try upgrading to cuDNN 8.9.7+ please?
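
A rough way to check an installed version against that minimum is sketched below. It decodes the integer that torch.backends.cudnn.version() returns (8902 above). The encoding is an assumption based on cuDNN's version header: major*1000 + minor*100 + patch up to 8.x, and major*10000 + minor*100 + patch from 9.0 on. Both function names are made up for this example.

```python
def decode_cudnn_version(version):
    """Split the integer from torch.backends.cudnn.version() into (major, minor, patch).

    Assumed encoding: major*1000 + minor*100 + patch up to cuDNN 8.x,
    and major*10000 + minor*100 + patch from cuDNN 9.0 on.
    """
    if version >= 90000:
        major, rest = divmod(version, 10000)
    else:
        major, rest = divmod(version, 1000)
    minor, patch = divmod(rest, 100)
    return (major, minor, patch)


def meets_minimum(version, minimum=(8, 9, 7)):
    """True if the decoded version is at least `minimum` (8.9.7, as suggested above)."""
    return decode_cudnn_version(version) >= minimum


print(decode_cudnn_version(8902))  # (8, 9, 2)
print(meets_minimum(8902))         # False
```

In a real environment the argument would come from torch.backends.cudnn.version() rather than a literal.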

@sirutBuasai
Author

Will do. In the meantime, is there a TE version that is built against cuDNN 8.9.2?

@cyanguwa
Collaborator

cyanguwa commented Feb 5, 2024

I think it's probably v0.10, but I'd rather you roll forward with cuDNN than backward with TE. There's been a lot of development in the last year or so. If it's easier, you can use the NGC pytorch container, which has the latest TE (1.3) and cuDNN (9.0): nvcr.io/nvidia/pytorch:24.01-py3

@ptrendx
Member

ptrendx commented Feb 8, 2024

@cyanguwa I think we still should catch this error from cuDNN Frontend and just disable cuDNN's implementation of attention in this case.
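
The kind of fallback described here could look roughly like the sketch below. This is not TransformerEngine's implementation: `attention_with_fallback`, `fused_fn`, and `fallback_fn` are hypothetical names, and matching on the error text is only one possible way to recognise this failure.

```python
def attention_with_fallback(fused_fn, fallback_fn, *args, **kwargs):
    """Try the fused attention path, and fall back if cuDNN cannot build a plan.

    `fused_fn` and `fallback_fn` are hypothetical callables that take the
    same arguments. Only the specific cuDNN frontend failure is swallowed;
    any other RuntimeError is re-raised unchanged.
    """
    try:
        return fused_fn(*args, **kwargs)
    except RuntimeError as err:
        if "No execution plans built successfully" not in str(err):
            raise
        return fallback_fn(*args, **kwargs)
```

In practice the decision would probably be cached after the first failure, so the fused path is not retried on every forward pass.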

@liu21yd

liu21yd commented Feb 28, 2024

@sirutBuasai Was your problem solved? Could you share the solution? I am running into the same problem.

@sirutBuasai
Author

@liu21yd, we ended up using TE v0.10, but it is pretty old. I haven't tried upgrading cuDNN and TE together, but that would be a place to start.

@odashi

odashi commented Jul 31, 2024

We recently observed similar issues with every combination of TE 1.4/1.7 and cuDNN 8.9.4/8.9.7 that we tried. In our case, the fused_attn test in this repository fails as well, and the frontend toolkit (Megatron-LM) does not work either.

Note that our operating system is Rocky Linux, not a Debian-based distribution.

As a workaround, we eventually set NVTE_FUSED_ATTN=0 to disable the fused attention kernels, and the issue went away.
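
If the variable cannot be exported in the launch script, it can also be set from Python. A minimal sketch, assuming it needs to be in place before transformer_engine is imported (when exactly TE reads it may differ between versions):

```python
import os

# Disable TransformerEngine's fused attention kernels, as in the workaround above.
# Assumption: setting this before transformer_engine is imported is early enough
# on every TE version; the point at which TE reads it may vary.
os.environ["NVTE_FUSED_ATTN"] = "0"

# import transformer_engine.pytorch as te  # import only after the variable is set
```

Under torchrun, each worker process has to see the variable, so exporting it in the shell before launch is the simpler route.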

@cyanguwa
Collaborator

Hi @sirutBuasai, I haven't gone back to TE 1.2.1, but I just tried TE 1.11.0 with the PyTorch 24.10 container, and it seems to work. Would you be able to upgrade to this combination?

Container:

nvcr.io/nvidia/pytorch:24.10-py3
(cuDNN 9.5.0.50, but I also tested 8.9.7 with it)

Install Megatron:

git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM/
git checkout 2c3468a49ed51324ae9b442e0d88416f1b29422b
sed -i -e 's/from pkg_resources import packaging/from pkg_resources\.extern import packaging/g' megatron/model/transformer.py
sed -i -e 's/from pkg_resources import packaging/from pkg_resources\.extern import packaging/g' megatron/core/transformer/attention.py
sed -i -e 's/from pkg_resources import packaging/from pkg_resources\.extern import packaging/g' megatron/core/transformer/custom_layers/transformer_engine.py

Script:

#!/bin/bash

DEPS_DIR=/workspace
DATASET_DIR=/code
WORK_DIR=/workspace

export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=1

export GPT_HOME="${DATASET_DIR}"
export DATASET="${DATASET_DIR}/BookCorpus2-shuf/BookCorpus2_ftfy_cleaned_id_shuf_text_document"
export CHECKPOINT_PATH="${WORK_DIR}/checkpoints/gptmodel"
export VOCAB_FILE="${DATASET_DIR}/gpt2-vocab.json"
export MERGES_FILE="${DATASET_DIR}/gpt2-merges.txt"
export DATA_PATH="${DATASET_DIR}/BookCorpus2-shuf/BookCorpus2_ftfy_cleaned_id_shuf_text_document"
export NVTE_BIAS_GELU_NVFUSION=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_DEBUG=INFO
export NCCL_PROTO=LL,LL128,Simple
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export RDMAV_FORK_SAFE=1

# remove previous checkpoints
rm -rf ${WORK_DIR}/checkpoints/

torchrun --nproc-per-node 8 --nnodes 1 \
  ${WORK_DIR}/Megatron-LM/pretrain_gpt.py \
  --tensor-model-parallel-size 1 \
  --pipeline-model-parallel-size 1 \
  --sequence-parallel \
  --num-layers 24 \
  --hidden-size 1024 \
  --num-attention-heads 16 \
  --micro-batch-size 1 \
  --global-batch-size 8 \
  --seq-length 2048 \
  --max-position-embeddings 2048 \
  --train-iters 1200 \
  --lr-decay-iters 320000 \
  --save ${CHECKPOINT_PATH} \
  --load ${CHECKPOINT_PATH} \
  --data-path ${DATA_PATH} \
  --vocab-file ${VOCAB_FILE} \
  --merge-file ${MERGES_FILE} \
  --split 949,50,1 \
  --distributed-backend nccl \
  --lr 0.00015 \
  --lr-decay-style cosine \
  --min-lr 1.0e-5 \
  --weight-decay 1e-2 \
  --clip-grad 1.0 \
  --lr-warmup-fraction .01 \
  --log-interval 100 \
  --save-interval 10000 \
  --eval-interval 1000 \
  --eval-iters 10 \
  --adam-beta1 0.9 \
  --adam-beta2 0.95 \
  --init-method-std 0.006 \
  --bf16 \
  --transformer-impl transformer_engine \
  --attention-softmax-in-fp32

Results:

......
 iteration      900/    1200 | consumed samples:         7200 | elapsed time per iteration (ms): 85.6 | learning rate: 4.219E-05 | global batch size:     8 | lm loss: 5.917407E+00 | loss scale: 1.0 | grad norm: 2.783 | number of skipped iterations:   0 | number of nan iterations:   0 |
 iteration     1000/    1200 | consumed samples:         8000 | elapsed time per iteration (ms): 79.9 | learning rate: 4.688E-05 | global batch size:     8 | lm loss: 5.788058E+00 | loss scale: 1.0 | grad norm: 1.414 | number of skipped iterations:   0 | number of nan iterations:   0 |
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
.......
