Skip to content

Commit b1a9aa9

Browse files
WAR for nvidia-nvshmem package
Signed-off-by: Vladimir Cherepanov <[email protected]>
1 parent e5fcc2c commit b1a9aa9

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

transformer_engine/common/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str:
218218
except ModuleNotFoundError:
219219
return ""
220220

221+
# Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia"
222+
# above doesn't through. However, they don't set "__file__" attribute.
223+
if nvidia.__file__ is None:
224+
return ""
225+
221226
include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
222227
return str(include_dir) if include_dir.exists() else ""
223228

0 commit comments

Comments
 (0)