Skip to content

Commit 406e2c9

Browse files
Fix incorrect version checks for atomic GEMM (#2095)
* Fix incorrect version checks for atomic GEMM Signed-off-by: Tim Moon <[email protected]> * Fix typo Signed-off-by: Tim Moon <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 96944a8 commit 406e2c9

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

transformer_engine/common/gemm/cublaslt_gemm.cu

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -517,22 +517,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
517517
&epilogue, sizeof(epilogue)));
518518

519519
if (counter != nullptr) {
520-
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
521-
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
520+
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
521+
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
522522
CUDA_VERSION);
523523
#endif
524524
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
525525
NVTE_ERROR(
526-
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
526+
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
527527
CUBLAS_VERSION);
528528
#endif
529529
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
530530
CUBLAS_VERSION < 130000
531531
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
532-
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
532+
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
533533
cuda::cudart_version());
534534
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
535-
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
535+
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
536536
cublas_version());
537537
if (m_split == 0) m_split = 1;
538538
if (n_split == 0) n_split = 1;
@@ -658,20 +658,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
658658
using namespace transformer_engine;
659659

660660
// Check CUDA and cuBLAS versions
661-
#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
662-
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
661+
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
662+
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
663663
CUDA_VERSION);
664664
#endif
665665
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
666-
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
667-
CUBLAS_VERSION);
666+
NVTE_ERROR(
667+
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
668+
CUBLAS_VERSION);
668669
#endif
669-
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
670-
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
671-
cuda::cudart_version());
670+
NVTE_CHECK(
671+
cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
672+
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
673+
cuda::cudart_version());
672674
NVTE_CHECK(
673675
cublas_version() >= 120205 && cublas_version() < 130000,
674-
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
676+
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
675677
cublas_version());
676678

677679
const Tensor *inputA = convertNVTETensorCheck(A);

0 commit comments

Comments
 (0)