Skip to content

Commit 6a82cfa

Browse files
authored
Add RegisterBinaryFromCmdArg function (#1076)
1 parent df45dec commit 6a82cfa

File tree

3 files changed

+67
-18
lines changed

3 files changed

+67
-18
lines changed

drivers/registration.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package drivers
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
"os"
6+
"os/exec"
7+
"path/filepath"
8+
"strings"
9+
)
410

511
// registeredDrivers are all the drivers which are currently registered
612
var registeredDrivers = map[string]Interface{}
@@ -34,3 +40,51 @@ func register(name string, driver Interface) {
3440

3541
registeredDrivers[name] = driver
3642
}
43+
44+
// RegisterBinaryFromCmdArg is used to register drivers from a command line argument
45+
// The argument is either just the driver name or a path to a specific driver
46+
// Panics if a driver with the same name has been previously loaded.
47+
func RegisterBinaryFromCmdArg(arg string) (name, path string, err error) {
48+
path, err = getFullPath(arg)
49+
if err != nil {
50+
return name, path, err
51+
}
52+
53+
name = getNameFromPath(path)
54+
55+
RegisterBinary(name, path)
56+
57+
return name, path, nil
58+
}
59+
60+
// Get the full path to the driver binary from the given path
61+
// the path can also be just the driver name e.g. "psql"
62+
func getFullPath(path string) (string, error) {
63+
var err error
64+
65+
if strings.ContainsRune(path, os.PathSeparator) {
66+
return path, nil
67+
}
68+
69+
path, err = exec.LookPath("sqlboiler-" + path)
70+
if err != nil {
71+
return path, fmt.Errorf("could not find driver executable: %w", err)
72+
}
73+
74+
path, err = filepath.Abs(path)
75+
if err != nil {
76+
return path, fmt.Errorf("could not find absolute path to driver: %w", err)
77+
}
78+
79+
return path, nil
80+
}
81+
82+
// Get the driver name from the path.
83+
// strips the "sqlboiler-" prefix if it exists
84+
// strips the ".exe" suffix if it exits
85+
func getNameFromPath(name string) string {
86+
name = strings.Replace(filepath.Base(name), "sqlboiler-", "", 1)
87+
name = strings.Replace(name, ".exe", "", 1)
88+
89+
return name
90+
}

drivers/registration_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ func TestBinaryRegistration(t *testing.T) {
4444
}
4545
}
4646

47+
func TestBinaryFromArgRegistration(t *testing.T) {
48+
RegisterBinaryFromCmdArg("/bin/true/mock5")
49+
50+
if d, ok := registeredDrivers["mock5"]; !ok {
51+
t.Error("driver was not found")
52+
} else if string(d.(binaryDriver)) != "/bin/true/mock5" {
53+
t.Error("got the wrong driver back")
54+
}
55+
}
56+
4757
func TestGetDriver(t *testing.T) {
4858
didYouPanic := false
4959

main.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package main
44
import (
55
"fmt"
66
"os"
7-
"os/exec"
87
"path/filepath"
98
"strings"
109

@@ -147,24 +146,10 @@ func preRun(cmd *cobra.Command, args []string) error {
147146
return commandFailure("must provide a driver name")
148147
}
149148

150-
driverName := args[0]
151-
driverPath := args[0]
152-
153-
if strings.ContainsRune(driverName, os.PathSeparator) {
154-
driverName = strings.Replace(filepath.Base(driverName), "sqlboiler-", "", 1)
155-
driverName = strings.Replace(driverName, ".exe", "", 1)
156-
} else {
157-
driverPath = "sqlboiler-" + driverPath
158-
if p, err := exec.LookPath(driverPath); err == nil {
159-
driverPath = p
160-
}
161-
}
162-
163-
driverPath, err = filepath.Abs(driverPath)
149+
driverName, driverPath, err := drivers.RegisterBinaryFromCmdArg(args[0])
164150
if err != nil {
165-
return errors.Wrap(err, "could not find absolute path to driver")
151+
return errors.Wrap(err, "could not register driver")
166152
}
167-
drivers.RegisterBinary(driverName, driverPath)
168153

169154
cmdConfig = &boilingcore.Config{
170155
DriverName: driverName,

0 commit comments

Comments
 (0)