@Purfview (Contributor) commented Nov 24, 2025

Closes: #1865

I debugged the issue a bit and observed the CUBLAS_STATUS_NOT_SUPPORTED error only when n[fw] is not a multiple of 4, and only for some m[fw] values [I didn't notice a pattern in m].
[Note: ldc=n[fw], and n / m are swapped in CT2's call to GEMM]
The cuBLAS docs don't say that m or ldc must be divisible by 4; they only recommend it for performance:

To use IMMA kernels, one of the following sets of requirements, with the first being the preferred one, must be met:
    1) Using a regular data ordering:
        All matrix pointers must be 4-byte aligned. For even better performance, this condition should hold with 16 instead of 4.
        Leading dimensions of matrices A, B, C must be multiples of 4.
        Dimensions m and k must be multiples of 4.
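As an illustration (the helper name and the example dimensions are mine, not part of cuBLAS or CT2), the shape-based part of the "regular data ordering" requirements quoted above can be checked from the GEMM parameters alone; pointer alignment can't be checked from shapes:

```python
def meets_imma_regular_ordering(m: int, k: int, lda: int, ldb: int, ldc: int) -> bool:
    """Check the shape-based IMMA 'regular data ordering' requirements:
    leading dimensions of A, B, C and the m, k dimensions must all be
    multiples of 4. (The 4-byte pointer alignment rule is not covered here.)"""
    return all(x % 4 == 0 for x in (m, k, lda, ldb, ldc))

# With ldc = n[fw] = 51865 (multilingual Whisper vocab), the check fails:
print(meets_imma_regular_ordering(m=64, k=1280, lda=64, ldb=1280, ldc=51865))  # False
```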

I guess NVIDIA dropped some kernels for sm120, and now cublasGemmEx() fails when n[fw] % 4 ≠ 0 in combination with some m[fw] (batch * sequence_length) values.

In some calls to GEMM, n[fw] = vocab[whisper]. In multilingual models the vocab is 51865 (or 51866 for v3), neither of which is divisible by 4.
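A minimal NumPy sketch of one conceivable workaround (purely illustrative, not necessarily what this PR implements): pad the vocab dimension up to the next multiple of 4 so that ldc = n becomes divisible by 4, then slice the padded columns back off:

```python
import numpy as np

def pad_last_dim_to_multiple(a: np.ndarray, multiple: int = 4) -> np.ndarray:
    """Zero-pad the last dimension up to the next multiple of `multiple`."""
    pad = (-a.shape[-1]) % multiple
    if pad == 0:
        return a
    width = [(0, 0)] * (a.ndim - 1) + [(0, pad)]
    return np.pad(a, width)

logits = np.zeros((3, 51865), dtype=np.float32)  # n = multilingual vocab size
padded = pad_last_dim_to_multiple(logits)
print(padded.shape)  # (3, 51868); 51868 % 4 == 0
unpadded = padded[..., :51865]                   # drop the padding afterwards
```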

A sure way to reproduce the error is to use word_timestamps=True.
It should also be reproducible with word_timestamps=False and a high beam_size, for example 24. [not tested]
I think word_timestamps is irrelevant for reproducing it in batched mode. [not tested]

@jordimas (Collaborator)

Thanks. I am in favor of the change. If somebody has a better approach, please comment in this PR.

@Purfview (Contributor, Author) commented Nov 24, 2025

BTW, the cublasGemmAlgo_t value currently used by CT2 is deprecated; I tried the recommended one, but it had no effect on the error, and no effect on performance either.

It has been deprecated since CUDA 11, so maybe it should be changed before it gets removed.
EDIT: Done -> #1938

@Purfview (Contributor, Author)

Maybe it's a bug; I see a similar known issue in the latest cuBLAS, Release 13.0 Update 2:

Known Issues
    cublasLtMatmul with INT8 inputs, INT32 accumulation, and INT32 outputs might return CUBLAS_STATUS_NOT_SUPPORTED when dimension N is larger than 65,536 or when the batch count is larger than 1. The issue has existed since CUDA Toolkit 13.0 Update 1 and will be fixed in a later release. [5541380]

Development

Successfully merging this pull request may close these issues.

GeForce RTX 50XX - cuBLAS failed with status CUBLAS_STATUS_NOT_SUPPORTED