diff --git a/drivers/registration.go b/drivers/registration.go index c595faf6a..5467a1b6e 100644 --- a/drivers/registration.go +++ b/drivers/registration.go @@ -1,6 +1,12 @@ package drivers -import "fmt" +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) // registeredDrivers are all the drivers which are currently registered var registeredDrivers = map[string]Interface{} @@ -34,3 +40,51 @@ func register(name string, driver Interface) { registeredDrivers[name] = driver } + +// RegisterBinaryFromCmdArg is used to register drivers from a command line argument +// The argument is either just the driver name or a path to a specific driver +// Panics if a driver with the same name has been previously loaded. +func RegisterBinaryFromCmdArg(arg string) (name, path string, err error) { + path, err = getFullPath(arg) + if err != nil { + return name, path, err + } + + name = getNameFromPath(path) + + RegisterBinary(name, path) + + return name, path, nil +} + +// Get the full path to the driver binary from the given path +// the path can also be just the driver name e.g. "psql" +func getFullPath(path string) (string, error) { + var err error + + if strings.ContainsRune(path, os.PathSeparator) { + return path, nil + } + + path, err = exec.LookPath("sqlboiler-" + path) + if err != nil { + return path, fmt.Errorf("could not find driver executable: %w", err) + } + + path, err = filepath.Abs(path) + if err != nil { + return path, fmt.Errorf("could not find absolute path to driver: %w", err) + } + + return path, nil +} + +// Get the driver name from the path. +// strips the "sqlboiler-" prefix if it exists +// strips the ".exe" suffix if it exits +func getNameFromPath(name string) string { + name = strings.Replace(filepath.Base(name), "sqlboiler-", "", 1) + name = strings.Replace(name, ".exe", "", 1) + + return name +} diff --git a/drivers/registration_test.go b/drivers/registration_test.go index a84778840..c57b9775d 100644 --- a/drivers/registration_test.go +++ b/drivers/registration_test.go @@ -44,6 +44,16 @@ func TestBinaryRegistration(t *testing.T) { } } +func TestBinaryFromArgRegistration(t *testing.T) { + RegisterBinaryFromCmdArg("/bin/true/mock5") + + if d, ok := registeredDrivers["mock5"]; !ok { + t.Error("driver was not found") + } else if string(d.(binaryDriver)) != "/bin/true/mock5" { + t.Error("got the wrong driver back") + } +} + func TestGetDriver(t *testing.T) { didYouPanic := false diff --git a/main.go b/main.go index c5ae22be6..f49b956bf 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ package main import ( "fmt" "os" - "os/exec" "path/filepath" "strings" @@ -147,24 +146,10 @@ func preRun(cmd *cobra.Command, args []string) error { return commandFailure("must provide a driver name") } - driverName := args[0] - driverPath := args[0] - - if strings.ContainsRune(driverName, os.PathSeparator) { - driverName = strings.Replace(filepath.Base(driverName), "sqlboiler-", "", 1) - driverName = strings.Replace(driverName, ".exe", "", 1) - } else { - driverPath = "sqlboiler-" + driverPath - if p, err := exec.LookPath(driverPath); err == nil { - driverPath = p - } - } - - driverPath, err = filepath.Abs(driverPath) + driverName, driverPath, err := drivers.RegisterBinaryFromCmdArg(args[0]) if err != nil { - return errors.Wrap(err, "could not find absolute path to driver") + return errors.Wrap(err, "could not register driver") } - drivers.RegisterBinary(driverName, driverPath) cmdConfig = &boilingcore.Config{ DriverName: driverName,