# Parse only the arguments this script declares; tolerate unknown extras so
# wrapper tooling can pass its own flags through (parser is defined earlier
# in this file).
args, unknown = parser.parse_known_args()

# Check the Python version first, before importing anything version-sensitive.
# 3.7 required for ordered dicts and stdout/stderr utf8 encoding,
# 3.8 required for python parsing (only needed for local, non-remote runs).
python_version = tuple(sys.version_info[:3])  # (major, minor, micro)
min_python_version = (3, 7, 0) if args.remote else (3, 8, 0)
if python_version < min_python_version:
    print("Error: Python " + '.'.join(map(str, min_python_version)) + " minimum is required.", file=sys.stderr)
    print("This environment has Python " + '.'.join(map(str, python_version)) + ".", file=sys.stderr)
    # sys.exit(1) rather than the site-provided exit(): it is always defined,
    # even when the interpreter runs with -S or in an embedded context.
    sys.exit(1)

print("Checking required packages...\n", file=sys.stderr)
|
|
# NOTE(review): this fragment runs inside the per-module availability loop;
# the loop header (iterating module_check / probing `module`) is outside this hunk.
if module is None:
    # Package is absent: collect it so all missing packages are reported at once.
    missing_modules.append(module_check)
elif module_check == 'torch':
    # importlib.metadata is stdlib from Python 3.8; older interpreters
    # need the importlib_metadata backport package.
    if python_version < (3, 8, 0):
        from importlib_metadata import version
    else:
        from importlib.metadata import version
    # Parse "major.minor.patch" from the installed torch version string.
    # Non-numeric fragments (e.g. the "0+cu117" local build tag) count as 0.
    # Pad to exactly three components so a short string like "1.9" compares
    # as (1, 9, 0) instead of (1, 9), which would sort BELOW (1, 9, 0) and
    # wrongly fail the minimum-version check.
    parsed = tuple(int(part) if part.isdigit() else 0 for part in version('torch').split('.')[:3])
    pytorch_version = (parsed + (0, 0, 0))[:3]
    # 1.9 required for torch.package support, 1.10 preferred for stable
    # torch.fx and profile-directed typing in torchscript.
    min_pytorch_version = (1, 9, 0)
    if pytorch_version < min_pytorch_version:
        print("Error: PyTorch " + '.'.join(map(str, min_pytorch_version)) + " minimum is required.", file=sys.stderr)
        print("This environment has PyTorch " + '.'.join(map(str, pytorch_version)) + ".", file=sys.stderr)
        sys.exit(1)  # always-available replacement for site-provided exit()
|
49 | 49 |
|
50 | 50 | if len(missing_modules)>0:
|
|
# Finally, list available devices.
print("Loading PyTorch...\n", file=sys.stderr)
import torch

print("Listing devices...\n", file=sys.stderr)
# device-id -> {'name': human-readable name, 'pin_memory': bool, 'modes': [precision modes]}
devices = {}
devices['cpu'] = {'name': 'CPU', 'pin_memory': False, 'modes': ['FP32']}
for i in range(torch.cuda.device_count()):
    modes = ['FP32']
    # Same as torch.cuda.is_bf16_supported() but compatible with PyTorch<1.10,
    # and not limited to the current cuda device only.
    cu_vers = torch.version.cuda
    cuda_maj_decide = cu_vers is not None and int(cu_vers.split('.')[0]) >= 11
    # Compute-capability table: https://developer.nvidia.com/cuda-gpus
    # get_device_properties accepts a plain device index directly; no need to
    # wrap it in a torch.cuda.device(...) context-manager object.
    compute_capability = torch.cuda.get_device_properties(i).major
    if compute_capability >= 8 and cuda_maj_decide:  # RTX 3000 (Ampere) and higher
        modes += ['TF32', 'FP16', 'BF16']
    if compute_capability == 7:  # RTX 2000 generation
        modes += ['FP16']
    devices['cuda:' + str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True, 'modes': modes}
if pytorch_version >= (1, 12, 0):
    # The MPS (Apple Metal) backend only exists from PyTorch 1.12 onward;
    # probing torch.backends.mps on older builds would raise AttributeError.
    if torch.backends.mps.is_available():
        devices['mps'] = {'name': 'Metal', 'pin_memory': False, 'modes': ['FP32']}
# Other possible devices:
# 'hpu' (https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html)
# 'dml' (https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows)
devices_string_list = []
for device_id in devices:  # renamed from 'id' to avoid shadowing the builtin
    devices_string_list.append(device_id + ' "' + devices[device_id]['name'] + '" (' + '/'.join(devices[device_id]['modes']) + ')')
print("Ready (" + platform.platform() + ", Python " + '.'.join(map(str, python_version)) + ", PyTorch " + '.'.join(map(str, pytorch_version)) + ", Devices: " + ", ".join(devices_string_list) + ")")
0 commit comments