Skip to content

Commit

Permalink
Add RegisterBinaryFromCmdArg function (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo authored Jan 28, 2022
1 parent df45dec commit 6a82cfa
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 18 deletions.
56 changes: 55 additions & 1 deletion drivers/registration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package drivers

import "fmt"
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
)

// registeredDrivers are all the drivers which are currently registered
var registeredDrivers = map[string]Interface{}
Expand Down Expand Up @@ -34,3 +40,51 @@ func register(name string, driver Interface) {

registeredDrivers[name] = driver
}

// RegisterBinaryFromCmdArg is used to register drivers from a command line argument
// The argument is either just the driver name or a path to a specific driver
// Panics if a driver with the same name has been previously loaded.
func RegisterBinaryFromCmdArg(arg string) (name, path string, err error) {
path, err = getFullPath(arg)
if err != nil {
return name, path, err
}

name = getNameFromPath(path)

RegisterBinary(name, path)

return name, path, nil
}

// Get the full path to the driver binary from the given path
// the path can also be just the driver name e.g. "psql"
func getFullPath(path string) (string, error) {
var err error

if strings.ContainsRune(path, os.PathSeparator) {
return path, nil
}

path, err = exec.LookPath("sqlboiler-" + path)
if err != nil {
return path, fmt.Errorf("could not find driver executable: %w", err)
}

path, err = filepath.Abs(path)
if err != nil {
return path, fmt.Errorf("could not find absolute path to driver: %w", err)
}

return path, nil
}

// Get the driver name from the path.
// strips the "sqlboiler-" prefix if it exists
// strips the ".exe" suffix if it exits
func getNameFromPath(name string) string {
name = strings.Replace(filepath.Base(name), "sqlboiler-", "", 1)
name = strings.Replace(name, ".exe", "", 1)

return name
}
10 changes: 10 additions & 0 deletions drivers/registration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ func TestBinaryRegistration(t *testing.T) {
}
}

func TestBinaryFromArgRegistration(t *testing.T) {
RegisterBinaryFromCmdArg("/bin/true/mock5")

if d, ok := registeredDrivers["mock5"]; !ok {
t.Error("driver was not found")
} else if string(d.(binaryDriver)) != "/bin/true/mock5" {
t.Error("got the wrong driver back")
}
}

func TestGetDriver(t *testing.T) {
didYouPanic := false

Expand Down
19 changes: 2 additions & 17 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"

Expand Down Expand Up @@ -147,24 +146,10 @@ func preRun(cmd *cobra.Command, args []string) error {
return commandFailure("must provide a driver name")
}

driverName := args[0]
driverPath := args[0]

if strings.ContainsRune(driverName, os.PathSeparator) {
driverName = strings.Replace(filepath.Base(driverName), "sqlboiler-", "", 1)
driverName = strings.Replace(driverName, ".exe", "", 1)
} else {
driverPath = "sqlboiler-" + driverPath
if p, err := exec.LookPath(driverPath); err == nil {
driverPath = p
}
}

driverPath, err = filepath.Abs(driverPath)
driverName, driverPath, err := drivers.RegisterBinaryFromCmdArg(args[0])
if err != nil {
return errors.Wrap(err, "could not find absolute path to driver")
return errors.Wrap(err, "could not register driver")
}
drivers.RegisterBinary(driverName, driverPath)

cmdConfig = &boilingcore.Config{
DriverName: driverName,
Expand Down

0 comments on commit 6a82cfa

Please sign in to comment.