1
- #otherwise conda install may fail
2
- del __file__
3
- __package__ = None
4
- __spec__ = None
5
-
6
1
import sys
7
2
import importlib
8
3
import importlib .util
9
4
import argparse
5
+ import subprocess
10
6
parser = argparse .ArgumentParser ()
11
- parser .add_argument ("--channel" , help = "pytorch channel" , type = str , default = 'pytorch' )
12
7
parser .add_argument ("--cuda" , help = "install nvidia gpu support" , action = "store_true" , default = False )
13
8
parser .add_argument ("--package" , help = "install specific package" , action = 'append' , nargs = '+' , default = [])
14
9
args , unknown = parser .parse_known_args ()
15
10
16
- if importlib .util .find_spec ("conda " ) is None :
17
- print ("Error: A Conda environment is required to install the required packages." , file = sys .stderr )
11
+ if importlib .util .find_spec ("pip " ) is None :
12
+ print ("Error: Pip is required to install the required packages." , file = sys .stderr )
18
13
exit ()
19
14
20
- import conda .cli .python_api as Conda
21
-
22
- #increase rows (from default 20 when no terminal is found) to display all parallel packages downloads at once
23
- from tqdm import tqdm
24
- init_source = tqdm .__init__
25
- def init_patch (self , ** kwargs ):
26
- kwargs ['ncols' ]= 80
27
- kwargs ['nrows' ]= 80
28
- init_source (self , ** kwargs )
29
- tqdm .__init__ = init_patch
30
-
31
15
if not args .package :
32
- #https://edcarp.github.io/introduction-to-conda-for-data-scientists/03-using-packages-and-channels/index.html#alternative-syntax-for-installing-packages-from-specific-channels
33
- conda_install = f"pytorch torchvision torchaudio torchtext"
16
+ #first install python-graphviz as it only exist as a conda package, and conda is recommended before pip: https://www.anaconda.com/blog/using-pip-in-a-conda-environment
17
+ if importlib .util .find_spec ("conda" ) is None :
18
+ print ("Error: Conda is required to install the graphviz package." , file = sys .stderr )
19
+ exit ()
20
+ else :
21
+ print ("Downloading and installing the graphviz package..." )
22
+ print ("" )
23
+ result = subprocess .run ([sys .executable , "-m" , "conda" , "install" , "-y" , "python-graphviz" , "-c" , "conda-forge" ])
24
+ if result .returncode != 0 :
25
+ exit (result .returncode )
26
+ print ("" )
27
+
28
+ pip_install = "torch torchvision torchaudio torchtext"
34
29
if (sys .platform .startswith ('win' ) or sys .platform .startswith ('linux' )):
35
30
if args .cuda :
36
31
print ("Checking the latest supported CUDA version..." )
37
- highest_cuda_version = ( 11 , 6 ) # highest supported cuda version for PyTorch 1.12
32
+ highest_cuda_version = 118 #11.8 highest supported cuda version for PyTorch 2.0
38
33
import requests
39
34
try :
40
- pytorch_repo = requests .get ("https://anaconda. org/" + args . channel + "/pytorch/files " )
35
+ pytorch_repo = requests .get ("https://download.pytorch. org/whl/torch " )
41
36
except :
42
37
print ("Could not retrieve the latest supported CUDA version" )
43
38
else :
44
39
import re
45
- regex_request = re .compile ("cuda([0-9]+. [0-9]+)" )
40
+ regex_request = re .compile ("cu( [0-9]+)" )
46
41
results = re .findall (regex_request , pytorch_repo .text )
47
- highest_cuda_version = ( 11 , 6 )
42
+ highest_cuda_version = 118
48
43
for cuda_string in results :
49
- cuda_version = tuple ( int (i ) for i in cuda_string . split ( '.' ) )
44
+ cuda_version = int (cuda_string )
50
45
if cuda_version > highest_cuda_version :
51
46
highest_cuda_version = cuda_version
52
- highest_cuda_string = '.' . join ([ str (value ) for value in highest_cuda_version ])
47
+ highest_cuda_string = str ( highest_cuda_version )[: 2 ] + "." + str (highest_cuda_version )[ 2 :]
53
48
print ("Using CUDA " + highest_cuda_string )
54
49
print ("" )
55
- conda_install += " pytorch-cuda=" + highest_cuda_string + " -c " + args .channel + " -c nvidia"
56
- else :
57
- conda_install += " cpuonly -c " + args .channel
58
- else :
59
- conda_install += " -c " + args .channel
60
- print (f"Downloading and installing { args .channel } packages..." )
50
+ pip_install += " --index-url https://download.pytorch.org/whl/cu" + str (highest_cuda_version )
51
+
52
+ print ("Downloading and installing pytorch packages..." )
61
53
print ("" )
62
- # https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
63
- ( stdout_str , stderr_str , return_code_int ) = Conda . run_command ( Conda . Commands . INSTALL , conda_install .split (), use_exception_handler = True , stdout = sys . stdout , stderr = sys . stderr )
64
- if return_code_int != 0 :
65
- exit (return_code_int )
54
+
55
+ result = subprocess . run ([ sys . executable , "-m" , "pip" , "install" ] + pip_install .split ())
56
+ if result . returncode != 0 :
57
+ exit (result . returncode )
66
58
print ("" )
67
59
68
60
# onnx required for onnx export
@@ -73,16 +65,13 @@ def init_patch(self, **kwargs):
73
65
# python-graphviz required by torchstudio graph
74
66
# paramiko required for ssh connections (+updated cffi required on intel mac)
75
67
# pysoundfile required by torchaudio datasets: https://pytorch.org/audio/stable/backend.html#soundfile-backend
76
- conda_install = "onnx datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
77
- if sys .platform .startswith ('darwin' ):
78
- conda_install += " cffi"
68
+ pip_install = "onnx datasets scipy pandas matplotlib paramiko pysoundfile"
79
69
80
70
else :
81
- conda_install = " " .join (args .package [0 ])
71
+ pip_install = " " .join (args .package [0 ])
82
72
83
- print ("Downloading and installing conda-forge packages..." )
73
+ print ("Downloading and installing additional packages..." )
84
74
print ("" )
85
- conda_install += " -c conda-forge"
86
- (stdout_str , stderr_str , return_code_int ) = Conda .run_command (Conda .Commands .INSTALL ,conda_install .split (),use_exception_handler = True ,stdout = sys .stdout ,stderr = sys .stderr )
87
- if return_code_int != 0 :
88
- exit (return_code_int )
75
+ result = subprocess .run ([sys .executable , "-m" , "pip" , "install" ]+ pip_install .split ())
76
+ if result .returncode != 0 :
77
+ exit (result .returncode )
0 commit comments