diff --git a/save.go b/save.go index 9aeceaa..4a81d58 100644 --- a/save.go +++ b/save.go @@ -134,6 +134,32 @@ func (p byImportPath) Len() int { return len(p) } func (p byImportPath) Less(i, j int) bool { return p[i].root < p[j].root } func (p byImportPath) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +var majorVersionComponent = regexp.MustCompile(`v[\d]+`) + +// pathWithoutMajorVersion returns path with the 1st major version /component +// (if any) stripped out. If one was found, the 2nd return value is true. +func pathWithoutMajorVersion(path string) (string, bool) { + parts := strings.Split(path, "/") + for idx, part := range strings.Split(path, "/") { + if majorVersionComponent.MatchString(part) { + return strings.Join(append(parts[:idx], parts[idx+1:]...), "/"), true + } + } + return path, false +} + +// tryImport attempts to import the path as-is and, if it fails to be found and +// path contains a major module version, reattempts with the version removed. +func tryImport(ctx build.Context, path, srcDir string, mode build.ImportMode) (*build.Package, error) { + pkg, err := ctx.Import(path, srcDir, mode) + if err != nil && strings.HasPrefix(err.Error(), "cannot find package ") { + if versionlessPath, ok := pathWithoutMajorVersion(path); ok { + return ctx.Import(versionlessPath, srcDir, mode) + } + } + return pkg, err +} + // getAllDeps returns a slice of package import paths for all dependencies // (including test dependencies) of the given import path (and subpackages) and commands. func getAllDeps(importPath string, cmds []string) []string { @@ -169,7 +195,7 @@ func getAllDeps(importPath string, cmds []string) []string { // Add the subpackages. for path := range buildutil.ExpandPatterns(&buildContext, []string{subpackagePrefix + "..."}) { - _, err := buildContext.Import(path, "", 0) + _, err := tryImport(buildContext, path, "", 0) if _, ok := err.(*build.NoGoError); ok { continue } @@ -179,7 +205,7 @@ func getAllDeps(importPath string, cmds []string) []string { var addTransitiveClosure func(string) addTransitiveClosure = func(path string) { - pkg, err := buildContext.Import(path, "", 0) + pkg, err := tryImport(buildContext, path, "", 0) printLoadingError(path, err) importPaths := append([]string(nil), pkg.Imports...) @@ -194,7 +220,7 @@ func getAllDeps(importPath string, cmds []string) []string { } // Resolve the import path relative to the importing package. - if bp2, _ := buildContext.Import(path, pkg.Dir, build.FindOnly); bp2 != nil { + if bp2, _ := tryImport(buildContext, path, pkg.Dir, build.FindOnly); bp2 != nil { path = bp2.ImportPath } diff --git a/save_test.go b/save_test.go index 027fc3c..7c5d7a4 100644 --- a/save_test.go +++ b/save_test.go @@ -204,6 +204,39 @@ var saveTests = []saveTest{ "github.com/test/p2", }, }, + + { + "module major versions", + []pkg{{ + "github.com/test/p1", + []file{ + {"foo.go", false, []string{"github.com/test/p2/v2"}}, + {"foo_test.go", false, []string{"github.com/test/p3/v2"}}, + }}, { + "github.com/test/p2", + []file{ + {"foo.go", false, []string{"github.com/test/p4/v2/subpkg"}}, + }}, { + "github.com/test/p3", + []file{ + {"v2/foo.go", false, []string{"os"}}, + }}, { + "github.com/test/p4", + []file{ + {"subpkg/foo.go", false, []string{"github.com/test/p5"}}, + }}, { + "github.com/test/p5", + []file{ + {"foo.go", false, []string{"os"}}, + }}, + }, + []string{ + "github.com/test/p2", + "github.com/test/p3", + "github.com/test/p4", + "github.com/test/p5", + }, + }, } func TestSave(t *testing.T) { @@ -315,3 +348,26 @@ import ( } } } + +func TestPathWithoutMajorVersion(t *testing.T) { + tests := []struct{ + path string + expectedPath string + expectedBool bool + }{ + {"github.com/p1", "github.com/p1", false}, + {"github.com/p1/v1", "github.com/p1", true}, + {"github.com/p2/p3", "github.com/p2/p3", false}, + {"github.com/p1/v2/p3", "github.com/p1/p3", true}, + } + + for _, test := range tests { + actualPath, actualBool := pathWithoutMajorVersion(test.path) + if actualPath != test.expectedPath { + t.Errorf("%v: expected: %v got: %v", test.path, test.expectedPath, actualPath) + } + if actualBool != test.expectedBool { + t.Errorf("%v: expected: %v got: %v", test.path, test.expectedBool, actualBool) + } + } +}