Skip to content

Commit b10f436

Browse files
Fix CUDA version in setup.py (#2132)
* Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov <[email protected]> * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov <[email protected]> * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov <[email protected]> --------- Signed-off-by: Vladimir Cherepanov <[email protected]> Co-authored-by: Tim Moon <[email protected]>
1 parent 11e9d66 commit b10f436

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from build_tools.te_version import te_version
1818
from build_tools.utils import (
1919
cuda_archs,
20+
cuda_version,
2021
get_frameworks,
2122
remove_dups,
2223
)
@@ -70,11 +71,11 @@ def setup_common_extension() -> CMakeExtension:
7071
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
7172
cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
7273
cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
73-
"nvidia-cublasmp-cu12"
74-
).locate_file("nvidia/cublasmp/cu12")
74+
f"nvidia-cublasmp-cu{cuda_version()[0]}"
75+
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
7576
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
7677
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
77-
"nvidia-nvshmem-cu12"
78+
f"nvidia-nvshmem-cu{cuda_version()[0]}"
7879
).locate_file("nvidia/nvshmem")
7980
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
8081
print("CMAKE_FLAGS:", cmake_flags[-2:])

tests/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@ include_directories(${CMAKE_SOURCE_DIR})
4343
find_package(CUDAToolkit REQUIRED)
4444
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
4545

46+
add_subdirectory(comm_gemm)
4647
add_subdirectory(operator)
4748
add_subdirectory(util)

transformer_engine/common/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str:
218218
except ModuleNotFoundError:
219219
return ""
220220

221+
# Installing some nvidia-* packages, like nvshmem, creates the "nvidia" name, so "import nvidia"
222+
# above doesn't throw. However, they don't set the "__file__" attribute.
223+
if nvidia.__file__ is None:
224+
return ""
225+
221226
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
222227
return str(include_dir) if include_dir.exists() else ""
223228

0 commit comments

Comments
 (0)