cuBLAS Error #1585

Open
wangli68 opened this issue Mar 18, 2025 · 7 comments

Comments

@wangli68

With PyTorch 2.5 and Transformer Engine 2.1, I run the following original network structure:

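import torch.nn as nn

# RMSNorm, flash_attention and rope_apply are helpers from the surrounding
# model code (diffsynth/models/wan_video_dit.py); they are not shown here.


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Query/key/value and output projections. In the failing run these
        # were replaced by their te.xx counterparts.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        # RMSNorm applied to the projected queries and keys.
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

    def forward(self, x, freqs):
        # Project and normalize q/k (the traceback below points at the q line).
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        # Apply rotary embeddings to q and k, then attend across num_heads heads.
        x = flash_attention(
            q=rope_apply(q, freqs, self.num_heads),
            k=rope_apply(k, freqs, self.num_heads),
            v=v,
            num_heads=self.num_heads
        )
        return self.o(x)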

Note that at this point I had already replaced the layers in the Lightning network with the corresponding te.xx modules, and it raises the following error:

[rank1]: File "/output/DiffSynth-Studio/diffsynth/models/wan_video_dit.py", line 128, in forward
[rank1]: q = self.norm_q(self.q(x))
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in fn
[rank1]: return fn(*args, **kwargs)
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 1085, in forward
[rank1]: out = linear_fn(*args)
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank1]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/module/linear.py", line 231, in forward
[rank1]: out, _, rs_out = general_gemm(
[rank1]: File "/openbayes/home/Test/env/ao25/lib/python3.10/site-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 141, in general_gemm
[rank1]: out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank1]: RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:282 in function cublas_gemm: cuBLAS Error: an unsupported value or parameter was passed to the function

Epoch 0: 0%| | 0/63 [00:04<?, ?it/s]
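The message comes from cuBLASLt itself and does not say which parameter was rejected. cuBLASLt has a built-in logger controlled by the CUBLASLT_LOG_LEVEL and CUBLASLT_LOG_FILE environment variables, which may show the matrix layouts and data types of the failing call. The variables must be set before the library is loaded, so a minimal sketch, assuming the failing script can be relaunched as a child process, builds an environment for it (the helper name is hypothetical):

```python
import os


def cublaslt_debug_env(level=3, log_file=None, base=None):
    """Return a copy of the environment with cuBLASLt logging enabled.

    Levels: 0 off, 1 errors, 2 trace, 3 hints, 4 info, 5 API trace.
    """
    if not 0 <= level <= 5:
        raise ValueError("CUBLASLT_LOG_LEVEL must be between 0 and 5")
    # Copy so the caller's mapping (or os.environ) is never modified.
    env = dict(os.environ if base is None else base)
    env["CUBLASLT_LOG_LEVEL"] = str(level)
    if log_file is not None:
        # cuBLASLt replaces "%i" in the file name with the process id,
        # which keeps per-rank logs apart in a multi-GPU run.
        env["CUBLASLT_LOG_FILE"] = log_file
    return env
```

For example, subprocess.run([sys.executable, "repro.py"], env=cublaslt_debug_env(3, "cublaslt_%i.log")), where repro.py is a hypothetical file holding the failing script. Without a log file, the output goes to stdout.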

With PyTorch 2.4 and Transformer Engine 1.3, no error is reported, but the run stalls: memory usage stays constant, as if the process had frozen.
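Because this run stalls without an error, it can help to see where the Python code is waiting. The standard library's faulthandler module can dump every thread's stack once a step exceeds a timeout. A minimal sketch; the wrapper name, the stand-in step and the timeout are assumptions, and the dump shows Python frames only, not what the GPU is doing:

```python
import faulthandler
import sys
import time


def run_with_hang_watchdog(step_fn, timeout_s=600.0, file=None):
    """Run step_fn() and dump every thread's Python stack if it exceeds timeout_s.

    The process is not terminated (exit=False); the dump only shows where it waits.
    """
    # faulthandler writes to a raw file descriptor, so default to the original stderr.
    out = file if file is not None else sys.__stderr__
    # Arm a one-shot timer that fires from a background thread.
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=out, exit=False)
    try:
        return step_fn()
    finally:
        # Disarm the timer whether the step returned or raised.
        faulthandler.cancel_dump_traceback_later()


def fake_step():
    """Hypothetical stand-in for one training step."""
    time.sleep(0.1)
    return "step finished"


if __name__ == "__main__":
    print(run_with_hang_watchdog(fake_step, timeout_s=5.0))
```

In a multi-rank job each process arms its own timer, so every rank that stalls writes its own dump.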

@Pedrexus

Pedrexus commented Mar 20, 2025

I am also running into this issue with v2.1; I did not have this problem with v2.0. I am using CUDA 12.4 and cuDNN 9.0.0.

Here is a simple snippet to reproduce the issue:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16

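# One TE transformer layer with BF16 parameters; inputs use the (batch, seq, hidden) "bshd" layout.
model = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    attn_input_format="bshd",
    params_dtype=dtype,
).cuda()

x = torch.rand(batch_size, sequence_length, hidden_size, dtype=dtype).cuda()

# Forward pass under FP8 (no fp8_recipe is passed, so TE's default recipe is used);
# backward is called outside the autocast context.
with te.fp8_autocast(enabled=True):
    y = model(x).sum()
y.backward()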

print("Done")

Environment

(llm) pvalois@pegasus02:~/llm$ uv pip list
Package                       Version
----------------------------- --------------
absl-py                       2.1.0
accelerate                    1.5.2
aiohappyeyeballs              2.6.1
aiohttp                       3.11.14
aiosignal                     1.3.2
alabaster                     1.0.0
antlr4-python3-runtime        4.9.3
anyio                         4.9.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
asttokens                     3.0.0
async-lru                     2.0.5
attrs                         25.3.0
babel                         2.17.0
beautifulsoup4                4.13.3
bitsandbytes                  0.44.1
black                         25.1.0
bleach                        6.2.0
boto3                         1.37.16
botocore                      1.37.16
certifi                       2025.1.31
cffi                          1.17.1
cfgv                          3.4.0
chardet                       5.2.0
charset-normalizer            3.4.1
click                         8.1.8
colorama                      0.4.6
comm                          0.2.2
contourpy                     1.3.1
curio                         1.6
cycler                        0.12.1
dataproperty                  1.1.0
datasets                      3.4.1
debugpy                       1.8.13
decorator                     5.2.1
deepspeed                     0.9.3
defusedxml                    0.7.1
dill                          0.3.8
distlib                       0.3.9
docrepr                       0.2.0
docstring-parser              0.16
docutils                      0.21.2
einops                        0.8.1
evaluate                      0.4.3
exceptiongroup                1.2.2
executing                     2.2.0
fastapi                       0.115.11
fastjsonschema                2.21.1
filelock                      3.18.0
fire                          0.7.0
flash-attn                    2.7.3
fonttools                     4.56.0
fqdn                          1.5.1
frozenlist                    1.5.0
fsspec                        2024.12.0
grpcio                        1.71.0
h11                           0.14.0
hf-transfer                   0.1.9
hjson                         3.1.0
httpcore                      1.0.7
httptools                     0.6.4
httpx                         0.28.1
huggingface-hub               0.29.3
hydra-core                    1.3.2
identify                      2.6.9
idna                          3.10
imagesize                     1.4.1
importlib-metadata            8.6.1
importlib-resources           6.5.2
iniconfig                     2.1.0
intersphinx-registry          0.2501.23
ipykernel                     6.29.5
ipyparallel                   9.0.1
ipython                       8.34.0
ipywidgets                    8.1.5
isoduration                   20.11.0
jedi                          0.19.2
jinja2                        3.1.6
jmespath                      1.0.1
joblib                        1.4.2
json5                         0.10.0
jsonargparse                  4.32.1
jsonlines                     4.0.0
jsonpointer                   3.0.0
jsonschema                    4.23.0
jsonschema-specifications     2024.10.1
jupyter-client                8.6.3
jupyter-core                  5.7.2
jupyter-events                0.12.0
jupyter-lsp                   2.2.5
jupyter-server                2.15.0
jupyter-server-terminals      0.5.3
jupyterlab                    4.3.6
jupyterlab-pygments           0.3.0
jupyterlab-server             2.27.3
jupyterlab-widgets            3.0.13
kiwisolver                    1.4.8
lightning                     2.5.0.post0
lightning-thunder             0.2.1
lightning-utilities           0.14.1
litdata                       0.2.17
litgpt                        0.5.7
litserve                      0.2.4
lm-eval                       0.4.8
loguru                        0.7.3
looseversion                  1.3.0
lxml                          5.3.1
markdown                      3.7
markdown-it-py                3.0.0
markupsafe                    3.0.2
matplotlib                    3.10.1
matplotlib-inline             0.1.7
mbstrdecoder                  1.1.4
mdurl                         0.1.2
mistune                       3.1.3
more-itertools                10.6.0
mpi4py                        4.0.3
mpmath                        1.3.0
multidict                     6.2.0
multiprocess                  0.70.16
mypy-extensions               1.0.0
nbclient                      0.10.2
nbconvert                     7.16.6
nbformat                      5.10.4
nest-asyncio                  1.6.0
networkx                      3.4.2
ninja                         1.11.1.3
nltk                          3.9.1
nodeenv                       1.9.1
notebook                      7.3.3
notebook-shim                 0.2.4
numexpr                       2.10.2
numpy                         1.26.4
nvidia-cublas-cu12            12.4.5.8
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
omegaconf                     2.3.0
opt-einsum                    3.4.0
optree                        0.14.1
outcome                       1.3.0.post0
overrides                     7.7.0
packaging                     24.2
pandas                        2.2.3
pandocfilters                 1.5.1
parso                         0.8.4
pathspec                      0.12.1
pathvalidate                  3.2.3
peft                          0.15.0
pexpect                       4.9.0
pickleshare                   0.7.5
pillow                        11.1.0
platformdirs                  4.3.7
pluggy                        1.5.0
portalocker                   3.1.1
pre-commit                    4.2.0
prometheus-client             0.21.1
prompt-toolkit                3.0.50
propcache                     0.3.0
protobuf                      6.30.1
psutil                        7.0.0
ptyprocess                    0.7.0
pure-eval                     0.2.3
py-cpuinfo                    9.0.0
pyarrow                       19.0.1
pybind11                      2.13.6
pycparser                     2.22
pydantic                      1.10.21
pygments                      2.19.1
pyparsing                     3.2.1
pytablewriter                 1.2.1
pytest                        8.3.5
pytest-asyncio                0.21.2
python-dateutil               2.9.0.post0
python-dotenv                 1.0.1
python-json-logger            3.3.0
pytorch-lightning             2.5.1
pytz                          2025.1
pyyaml                        6.0.2
pyzmq                         26.3.0
qtconsole                     5.6.1
qtpy                          2.4.3
referencing                   0.36.2
regex                         2024.11.6
requests                      2.32.3
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          13.9.4
roman-numerals-py             3.1.0
rouge-score                   0.1.2
rpds-py                       0.23.1
s3transfer                    0.11.4
sacrebleu                     2.5.1
safetensors                   0.5.3
scikit-learn                  1.6.1
scipy                         1.15.2
send2trash                    1.8.3
sentencepiece                 0.2.0
setuptools                    77.0.1
six                           1.17.0
sniffio                       1.3.1
snowballstemmer               2.2.0
sortedcontainers              2.4.0
soupsieve                     2.6
sphinx                        8.2.3
sphinx-rtd-theme              3.0.2
sphinxcontrib-applehelp       2.0.0
sphinxcontrib-devhelp         2.0.0
sphinxcontrib-htmlhelp        2.1.0
sphinxcontrib-jquery          4.1
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          2.0.0
sphinxcontrib-serializinghtml 2.0.0
sqlitedict                    2.1.0
stack-data                    0.6.3
starlette                     0.46.1
sympy                         1.13.1
tabledata                     1.3.4
tabulate                      0.9.0
tcolorpy                      0.1.7
tenacity                      9.0.0
tensorboard                   2.19.0
tensorboard-data-server       0.7.2
tensorboardx                  2.6.2.2
termcolor                     2.5.0
terminado                     0.18.1
testpath                      0.6.0
threadpoolctl                 3.6.0
tinycss2                      1.4.0
tokenizers                    0.21.1
torch                         2.5.1
torchmetrics                  1.6.3
torchvision                   0.20.1
tornado                       6.4.2
tqdm                          4.67.1
tqdm-multiprocess             0.0.11
traitlets                     5.14.3
transformer-engine            2.1.0
transformer-engine-cu12       2.1.0
transformer-engine-torch      2.1.0
transformers                  4.47.1
trio                          0.29.0
triton                        3.1.0
typepy                        1.3.4
types-python-dateutil         2.9.0.20241206
typeshed-client               2.7.0
typing-extensions             4.12.2
tzdata                        2025.1
uri-template                  1.3.0
urllib3                       2.3.0
uvicorn                       0.34.0
uvloop                        0.21.0
virtualenv                    20.29.3
watchfiles                    1.0.4
wcwidth                       0.2.13
webcolors                     24.11.1
webencodings                  0.5.1
websocket-client              1.8.0
websockets                    15.0.1
werkzeug                      3.1.3
widgetsnbextension            4.0.13
word2number                   1.1
xxhash                        3.5.0
yarl                          1.18.3
zipp                          3.21.0
zstandard                     0.23.0
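The packages most relevant to this error (torch, the Transformer Engine wheels, and the cuBLAS and cuDNN wheels) are scattered through the full list above. When comparing reports, a short summary is easier to read; PyTorch's python -m torch.utils.collect_env gives a fuller picture. A minimal standard-library sketch, where the choice of packages is an assumption:

```python
from importlib import metadata

# Distribution names as they appear in the `uv pip list` output above.
RELEVANT = (
    "torch",
    "transformer-engine",
    "transformer-engine-cu12",
    "transformer-engine-torch",
    "nvidia-cublas-cu12",
    "nvidia-cudnn-cu12",
    "flash-attn",
)


def relevant_versions(names=RELEVANT):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in relevant_versions().items():
        print(f"{name:<28} {version or 'not installed'}")
```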

@chky1997

Hi, I run into this issue with both TE 2.1.0 and 2.0.0, using torch 2.6.0+cu126, CUDA 12.8 and cuDNN 9.8.0.

@haitian-jiang

haitian-jiang commented Mar 23, 2025

Same issue with TE 2.1.0, torch 2.5.1+cu124, CUDA 12.4 and cuDNN 9.8.0. TE 1.13.0 works fine in the same environment.

@pranayj77

I am also running into this issue with torch 2.6.0+cuda126, CUDA 12.6.2, cuDNN 9.8.0 and TE 2.1.0.

@is

is commented Mar 26, 2025

Same issue with torch 2.6.0, TE 2.1.0 and CUDA 12.4.

@flyingmanPan

Only TE 1.13.0 works fine.

@Pedrexus

FYI, I was able to run 2.0.0 by building it from source, but not 2.1.x or 2.2.x. However, 2.0.x is not available on PyPI.
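Since several participants fell back to TE 1.13.0, a small startup guard can flag the release lines reported as failing in this thread before a long job starts. A minimal sketch; REPORTED_FAILING simply transcribes the reports above (2.1.x and 2.2.x; 2.0.0 is left out because the reports for it conflict), and the function names are hypothetical:

```python
import re
import warnings

# (major, minor) release lines reported as failing in this thread.
REPORTED_FAILING = ((2, 1), (2, 2))


def release_tuple(version):
    """Parse the leading 'major.minor[.patch]' of a version string into ints."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def warn_if_reported_failing(te_version):
    """Warn (without failing) if te_version is in a release line reported above."""
    major, minor, _ = release_tuple(te_version)
    if (major, minor) in REPORTED_FAILING:
        warnings.warn(
            f"Transformer Engine {te_version} is in a release line reported in "
            "issue #1585 to raise 'cuBLAS Error: an unsupported value or parameter'",
            RuntimeWarning,
        )
        return True
    return False
```

For example, call warn_if_reported_failing(importlib.metadata.version("transformer-engine")) at startup. The guard only warns, because the reports above do not establish that every configuration on these versions fails.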

Labels
None yet
Projects
None yet
Development

No branches or pull requests

7 participants