Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions api/common/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,28 @@


def is_ampere_gpu():
stdout, exit_code = system.run_command("nvidia-smi -L")
if exit_code == 0:
gpu_list = stdout.split("\n")
if len(gpu_list) >= 1:
try:
from pynvml import nvmlInit
from pynvml import nvmlDeviceGetHandleByIndex
from pynvml import nvmlDeviceGetCudaComputeCapability
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
cc_major,cc_minor = nvmlDeviceGetCudaComputeCapability(handle)
#print(str(cc_major)+"."+str(cc_minor))
# 8.0 or 8.6
return str(cc_major)+"."+str(cc_minor)>"8.0"
except ImportError:
print("Warning: pynvml package is not installed, please install it as follow \"pip install pynvml\"")
stdout, exit_code = system.run_command("nvidia-smi -L")
if exit_code == 0:
gpu_list = stdout.split("\n")
#print(gpu_list[0])
# GPU 0: NVIDIA A100-SXM4-40GB (UUID: xxxx)
return gpu_list[0].find("A100") > 0
return False
else:
print("Error: Failed to run sys command \"nvidia-smi -L\"")
return False
return gpu_list[0].find("A100") > 0



class NvprofRunner(object):
Expand Down