How to debug tex.fused_attn_bwd getting cuDNN Error: [cudnn_frontend] Error: No execution plans support the graph #1591

Open · Ir1d opened this issue Mar 19, 2025 · 1 comment · Labels: bug

Ir1d commented Mar 19, 2025

Describe the bug

Fused attention backward fails with a RuntimeError whose message is uninformative. Setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr does not help.
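
For what it's worth, the `[cudnn_frontend]` prefix suggests the message comes from the cudnn-frontend layer, which has its own logging switches separate from the backend's CUDNN_LOGERR_DBG. A minimal sketch, assuming the standard cudnn-frontend environment variables; they must be set before the libraries are loaded:

```python
# Sketch: enable cudnn-frontend logging (distinct from CUDNN_LOGERR_DBG, which
# drives the cuDNN backend logger). Must be set before torch/TE load cuDNN.
import os

os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"       # turn frontend logging on
os.environ["CUDNN_FRONTEND_LOG_FILE"] = "stderr"  # or stdout, or a file path

import torch                       # import only after the env vars are in place
import transformer_engine.pytorch  # likewise
```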

Error Msg
```
Traceback (most recent call last):
    File "./test.py", line 25, in <module>
      output_fused.backward(out_grad)
    File "/xxx/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 626, in backward
      torch.autograd.backward(
    File "/xxx/.venv/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
      _engine_run_backward(
    File "/xxx/.venv/lib/python3.12/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/xxx/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
      return user_fn(self, *args)
             ^^^^^^^^^^^^^^^^^^^^
    File "/xxx/.venv/lib/python3.12/site-packages/transformer_engine/pytorch/attention.py", line 6340, in backward
      dq, dk, dv, *rest = fused_attn_bwd(
                          ^^^^^^^^^^^^^^^
    File "/xxx/.venv/lib/python3.12/site-packages/transformer_engine/pytorch/cpp_extensions/fused_attn.py", line 451, in fused_attn_bwd
      output_tensors = tex.fused_attn_bwd(
                       ^^^^^^^^^^^^^^^^^^^
  RuntimeError: /xxxx/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:771 in function operator(): cuDNN Error: [cudnn_frontend] Error: No execution plans support the graph.. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment.
```

Steps/Code to reproduce bug

`test.py`
```python
import os

import torch

from transformer_engine.pytorch.attention import DotProductAttention, _attention_backends

seqlen, batch_size, heads, kv_channels = 1024, 2, 16, 192

# Q and K use head dim 192; V uses head dim 128.
q, k = [
    torch.randn(seqlen, batch_size, heads, kv_channels, dtype=torch.float16,
                device="cuda", requires_grad=True)
    for _ in range(2)
]
v = torch.randn(seqlen, batch_size, heads, 128, dtype=torch.float16,
                device="cuda", requires_grad=True)

cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 1024, 2048], device="cuda", dtype=torch.int32)

attention_kernel = DotProductAttention(heads, (192, 128))

# Force the fused-attention backend and make TE re-run backend selection.
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True

output_fused = attention_kernel(
    q, k, v,
    qkv_format="sbhd",
    attn_mask_type="causal",
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
)
print(output_fused.shape)  # (1024, 2, heads * 128) = (1024, 2, 2048)

# Cast the incoming gradient to the output's float16 dtype.
out_grad = (0.001 * torch.randint(0, 200, (1024, 2, 2048), device="cuda")).to(torch.float16)
output_fused.backward(out_grad)  # raises the cuDNN error above
```
Run with:

```bash
CUDNN_LOGERR_DBG=1 CUDNN_LOGDEST_DBG=stderr NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python test.py
```

Expected behavior

The backward pass is expected to complete without error.

Environment overview (please complete the following information)

H100, cuDNN 9.1.0, CUDA 12.3, Python 3.12.6, PyTorch 2.6.0+cu124

TE installed by compiling from source with uv, at commit 8eb1712.

compilation script:

```bash
CMAKE_BUILD_WITH_INSTALL_RPATH=ON \
CMAKE_INSTALL_RPATH_USE_LINK_PATH=ON \
CMAKE_SKIP_BUILD_RPATH=FALSE \
CMAKE_BUILD_WITH_INSTALL_RPATH=TRUE \
CMAKE_INSTALL_RPATH="/xxx/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib/" \
CUDNN_PATH=/xxx/.venv/lib/python3.12/site-packages/nvidia/cudnn/ \
CUDACXX=/usr/local/cuda-12.3/bin/nvcc \
CMAKE_CUDA_COMPILER=/usr/local/cuda-12.3/bin/nvcc \
CUDA_HOME=/usr/local/cuda-12.3 \
NVTE_FRAMEWORK=pytorch \
MAX_JOBS=96 \
CC=gcc \
CXX=g++ \
CMAKE_GENERATOR="Unix Makefiles" \
SKBUILD_CMAKE_ARGS="-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON" \
uv pip install -v "." --no-build-isolation --no-cache-dir
```
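
A quick sanity check of the resulting install (a sketch; note that torch reports the cuDNN it loads for itself, which is not necessarily the one TE was built against):

```python
# Sketch: confirm what the repro environment actually loads.
import torch

print(torch.version.cuda)               # CUDA version torch was built with
print(torch.backends.cudnn.version())   # e.g. 90100 for cuDNN 9.1.0

import transformer_engine.pytorch  # should import without loader errors
```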
Ir1d added the bug label on Mar 19, 2025
liangxuZhang commented

I encountered the same problem. It seems that fused attention does not support cases where K and V have different head dimensions (here 192 for Q/K and 128 for V). My workaround is to pad v to 192 so that fused or flash attention can be used. I hope TE will support attention with different K/V head dimensions in the future.
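
For readers who want to try this, below is a minimal sketch of the padding workaround under the shapes from the reproduction script. Zero-padding V is mathematically a no-op, since softmax(QK^T) @ [V | 0] = [softmax(QK^T) @ V | 0]: the padded output channels stay zero and can be sliced off. The uniform-head-dim DotProductAttention(heads, 192) call is an assumption beyond the original script.

```python
# Sketch of the padding workaround: zero-pad V's head dim (128 -> 192) so Q, K,
# and V share one head dim, then drop the zero-valued output channels.
import torch
import torch.nn.functional as F
from transformer_engine.pytorch.attention import DotProductAttention

seqlen, batch_size, heads = 1024, 2, 16
d_qk, d_v = 192, 128

q, k = [
    torch.randn(seqlen, batch_size, heads, d_qk, dtype=torch.float16,
                device="cuda", requires_grad=True)
    for _ in range(2)
]
v = torch.randn(seqlen, batch_size, heads, d_v, dtype=torch.float16,
                device="cuda", requires_grad=True)

# F.pad pads from the last dim backwards: (left, right) for dim -1.
v_padded = F.pad(v, (0, d_qk - d_v))  # (s, b, h, 192), new channels are zero

attn = DotProductAttention(heads, d_qk)  # uniform head dim for all of Q/K/V
out = attn(q, k, v_padded, qkv_format="sbhd", attn_mask_type="causal")

# Output is (s, b, h * 192); the last 64 channels of each head are zero.
out = out.view(seqlen, batch_size, heads, d_qk)[..., :d_v]
out = out.reshape(seqlen, batch_size, heads * d_v)
print(out.shape)  # torch.Size([1024, 2, 2048])
```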
