Skip to content

Commit ccc0080

Browse files
committed
Refactored driver package to reduce the number of things to keep track of.
1 parent a63c823 commit ccc0080

File tree

4 files changed

+111
-48
lines changed

4 files changed

+111
-48
lines changed

driver/connector.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ func NewConnector(s Store, driverName string, cfg *Config) (*Connector, error) {
3030
return nil, ErrConfigRequired
3131
}
3232

33-
d, fmt, authErr, err := CreateDriver(driverName)
33+
d, err := CreateDriver(driverName)
3434
if err != nil {
3535
return nil, err
3636
}
3737

3838
// Allow caller to override formatter. This makes it easier to use different DSN
3939
// formats in cases where a default formatter might be difficult to use.
4040
if cfg.Formatter != nil {
41-
fmt = cfg.Formatter
41+
d.Formatter = cfg.Formatter
4242
}
4343

4444
// 0 retries means that it should try once, retry, then don't attempt any more retries
@@ -49,9 +49,9 @@ func NewConnector(s Store, driverName string, cfg *Config) (*Connector, error) {
4949
return &Connector{
5050
store: s,
5151
cfg: cfg,
52-
driver: d,
53-
errHandler: authErr,
54-
formatter: fmt,
52+
driver: d.Driver,
53+
errHandler: d.AuthError,
54+
formatter: d.Formatter,
5555
mu: sync.Mutex{},
5656
}, nil
5757
}

driver/connector_test.go

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ func (c *testCredential) GetPassword() string {
9393

9494
func TestNewConnectorFailsWithNilConfig(t *testing.T) {
9595
unregisterAllDrivers()
96-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
97-
return &testDriver{}, MysqlFormatter, errorTester(MysqlErrorText)
96+
if err := Register("driver", func() *Driver {
97+
return &Driver{
98+
Driver: &testDriver{},
99+
Formatter: MysqlFormatter,
100+
AuthError: errorTester(MysqlErrorText),
101+
}
98102
}); err != nil {
99103
t.Error(err)
100104
}
@@ -132,8 +136,12 @@ func TestNewConnectorWithInvalidDriver(t *testing.T) {
132136

133137
func TestConnectorErrorsIfStoreGetFailsReturnsNilOrIsInvalid(t *testing.T) {
134138
unregisterAllDrivers()
135-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
136-
return &testDriver{}, MysqlFormatter, errorTester(MysqlErrorText)
139+
if err := Register("driver", func() *Driver {
140+
return &Driver{
141+
Driver: &testDriver{},
142+
Formatter: MysqlFormatter,
143+
AuthError: errorTester(MysqlErrorText),
144+
}
137145
}); err != nil {
138146
t.Error(err)
139147
}
@@ -211,8 +219,12 @@ func TestConnectorErrorsIfStoreGetFailsReturnsNilOrIsInvalid(t *testing.T) {
211219
func TestConnectorCanUseAlternateFormatter(t *testing.T) {
212220
unregisterAllDrivers()
213221
d := &testDriver{}
214-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
215-
return d, MysqlFormatter, errorTester(MysqlErrorText)
222+
if err := Register("driver", func() *Driver {
223+
return &Driver{
224+
Driver: d,
225+
Formatter: MysqlFormatter,
226+
AuthError: errorTester(MysqlErrorText),
227+
}
216228
}); err != nil {
217229
t.Error(err)
218230
}
@@ -253,8 +265,12 @@ func TestConnectorCanUseAlternateFormatter(t *testing.T) {
253265
func TestConnectorRefreshesCredentialsCorrectly(t *testing.T) {
254266
unregisterAllDrivers()
255267
d := &testDriver{}
256-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
257-
return d, MysqlFormatter, errorTester(MysqlErrorText)
268+
if err := Register("driver", func() *Driver {
269+
return &Driver{
270+
Driver: d,
271+
Formatter: MysqlFormatter,
272+
AuthError: errorTester(MysqlErrorText),
273+
}
258274
}); err != nil {
259275
t.Error(err)
260276
}
@@ -294,8 +310,12 @@ func TestConnectorFailsToConnectThenReconnects(t *testing.T) {
294310
d := &testFailingDriver{
295311
ConnErr: errors.New(MysqlErrorText),
296312
}
297-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
298-
return d, MysqlFormatter, errorTester(MysqlErrorText)
313+
if err := Register("driver", func() *Driver {
314+
return &Driver{
315+
Driver: d,
316+
Formatter: MysqlFormatter,
317+
AuthError: errorTester(MysqlErrorText),
318+
}
299319
}); err != nil {
300320
t.Error(err)
301321
}
@@ -336,8 +356,12 @@ func TestConnectorFailsToRefreshOnConnectionFailure(t *testing.T) {
336356
d := &testFailingDriver{
337357
ConnErr: errors.New(MysqlErrorText),
338358
}
339-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
340-
return d, MysqlFormatter, errorTester(MysqlErrorText)
359+
if err := Register("driver", func() *Driver {
360+
return &Driver{
361+
Driver: d,
362+
Formatter: MysqlFormatter,
363+
AuthError: errorTester(MysqlErrorText),
364+
}
341365
}); err != nil {
342366
t.Error(err)
343367
}
@@ -390,8 +414,12 @@ func TestConnectorRetriesUntilSuccess(t *testing.T) {
390414
MaxCalled: 3,
391415
}
392416

393-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
394-
return d, MysqlFormatter, errorTester(MysqlErrorText)
417+
if err := Register("driver", func() *Driver {
418+
return &Driver{
419+
Driver: d,
420+
Formatter: MysqlFormatter,
421+
AuthError: errorTester(MysqlErrorText),
422+
}
395423
}); err != nil {
396424
t.Error(err)
397425
}
@@ -457,8 +485,12 @@ func TestConnectorRetriesUntilMax(t *testing.T) {
457485
MaxCalled: maxCalled,
458486
}
459487

460-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
461-
return d, MysqlFormatter, errorTester(MysqlErrorText)
488+
if err := Register("driver", func() *Driver {
489+
return &Driver{
490+
Driver: d,
491+
Formatter: MysqlFormatter,
492+
AuthError: errorTester(MysqlErrorText),
493+
}
462494
}); err != nil {
463495
t.Error(err)
464496
}
@@ -522,8 +554,12 @@ func TestConnectorRetriesUntilNonAuthError(t *testing.T) {
522554
MaxCalled: maxCalled,
523555
}
524556

525-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
526-
return d, MysqlFormatter, errorTester(MysqlErrorText)
557+
if err := Register("driver", func() *Driver {
558+
return &Driver{
559+
Driver: d,
560+
Formatter: MysqlFormatter,
561+
AuthError: errorTester(MysqlErrorText),
562+
}
527563
}); err != nil {
528564
t.Error(err)
529565
}
@@ -577,8 +613,12 @@ func TestConnectorErrorsIfUnknownDBErrorMessage(t *testing.T) {
577613
d := &testFailingDriver{
578614
ConnErr: errors.New(""),
579615
}
580-
if err := Register("driver", func() (driver.Driver, Formatter, AuthError) {
581-
return d, MysqlFormatter, errorTester(MysqlErrorText)
616+
if err := Register("driver", func() *Driver {
617+
return &Driver{
618+
Driver: d,
619+
Formatter: MysqlFormatter,
620+
AuthError: errorTester(MysqlErrorText),
621+
}
582622
}); err != nil {
583623
t.Error(err)
584624
}

driver/driver.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@ import (
1212
"github.com/lib/pq"
1313
)
1414

15-
type factory func() (driver.Driver, Formatter, AuthError)
15+
// Driver carries information along with a database/sql/driver required for creating a Connector
16+
type Driver struct {
17+
Driver driver.Driver
18+
Formatter Formatter
19+
AuthError AuthError
20+
}
21+
22+
type factory func() *Driver
1623

1724
type errFactoryAlreadyRegistered struct {
1825
name string
@@ -88,31 +95,44 @@ func drivers() []string {
8895
}
8996

9097
// CreateDriver creates a Driver.
91-
func CreateDriver(name string) (driver.Driver, Formatter, AuthError, error) {
98+
func CreateDriver(name string) (*Driver, error) {
9299
driverMu.Lock()
93100

94101
driverFactory, ok := driverFactories[name]
95102
if !ok {
96103
// Factory has not been registered.
97104
driverMu.Unlock()
98-
return nil, nil, nil, errInvalidDriverName
105+
106+
return nil, errInvalidDriverName
99107
}
100108
defer driverMu.Unlock()
101109

102110
// Run the factory
103-
d, f, authError := driverFactory()
111+
d := driverFactory()
104112

105-
return d, f, authError, nil
113+
return d, nil
106114
}
107115

108-
func mysqlDriver() (driver.Driver, Formatter, AuthError) {
109-
return &mysql.MySQLDriver{}, MysqlFormatter, MySQLAuthError
116+
func mysqlDriver() *Driver {
117+
return &Driver{
118+
Driver: &mysql.MySQLDriver{},
119+
Formatter: MysqlFormatter,
120+
AuthError: MySQLAuthError,
121+
}
110122
}
111123

112-
func pgxDriver() (driver.Driver, Formatter, AuthError) {
113-
return &stdlib.Driver{}, PgFormatter, PostgreSQLAuthError
124+
func pgxDriver() *Driver {
125+
return &Driver{
126+
Driver: &stdlib.Driver{},
127+
Formatter: PgFormatter,
128+
AuthError: PostgreSQLAuthError,
129+
}
114130
}
115131

116-
func pqDriver() (driver.Driver, Formatter, AuthError) {
117-
return &pq.Driver{}, PgFormatter, PostgreSQLAuthError
132+
func pqDriver() *Driver {
133+
return &Driver{
134+
Driver: &pq.Driver{},
135+
Formatter: PgFormatter,
136+
AuthError: PostgreSQLAuthError,
137+
}
118138
}

driver/driver_test.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package driver
22

33
import (
4-
"database/sql/driver"
54
"errors"
65
"strings"
76
"testing"
@@ -49,8 +48,8 @@ func TestCanRegisterAValidFactory(t *testing.T) {
4948
}
5049
}()
5150

52-
var fn factory = func() (driver.Driver, Formatter, AuthError) {
53-
return nil, nil, nil
51+
var fn factory = func() *Driver {
52+
return &Driver{}
5453
}
5554

5655
Register("a driver", fn) //nolint:errcheck
@@ -62,8 +61,8 @@ func TestCanRegisterAValidFactory(t *testing.T) {
6261

6362
func TestCantRegisterMultipleFactoriesWithTheSameName(t *testing.T) {
6463
unregisterAllDrivers()
65-
var fn factory = func() (driver.Driver, Formatter, AuthError) {
66-
return nil, nil, nil
64+
var fn factory = func() *Driver {
65+
return &Driver{}
6766
}
6867

6968
if err := Register(driverName, fn); err != nil {
@@ -82,8 +81,12 @@ func TestCantRegisterMultipleFactoriesWithTheSameName(t *testing.T) {
8281
func TestCanCreateADriverInstance(t *testing.T) {
8382
unregisterAllDrivers()
8483

85-
if err := Register("a driver", func() (driver.Driver, Formatter, AuthError) {
86-
return nil, MysqlFormatter, func(e error) bool { return true }
84+
if err := Register("a driver", func() *Driver {
85+
return &Driver{
86+
Driver: nil,
87+
Formatter: MysqlFormatter,
88+
AuthError: func(e error) bool { return true },
89+
}
8790
}); err != nil {
8891
t.Error(err)
8992
}
@@ -93,29 +96,29 @@ func TestCanCreateADriverInstance(t *testing.T) {
9396
t.Errorf("expected one driver to be registered but got %d", len(ds))
9497
}
9598

96-
d, f, a, err := CreateDriver("a driver")
99+
d, err := CreateDriver("a driver")
97100
if err != nil {
98101
t.Error(err)
99102
}
100103

101-
if d != nil {
104+
if d.Driver != nil {
102105
t.Errorf("expected a nil driver but got %v", d)
103106
}
104107

105108
// test formatter
106-
if f("user", "pass", "host", 0, "", nil) != MysqlFormatter("user", "pass", "host", 0, "", nil) {
109+
if d.Formatter("user", "pass", "host", 0, "", nil) != MysqlFormatter("user", "pass", "host", 0, "", nil) {
107110
t.Error("Formatter should be mysqlFormatter but wasn't")
108111
}
109112

110-
if !a(errors.New("foo")) {
113+
if !d.AuthError(errors.New("foo")) {
111114
t.Error("AuthError should be true but wasn't")
112115
}
113116
}
114117

115118
func TestCantCreateMissingDriver(t *testing.T) {
116119
unregisterAllDrivers()
117120

118-
_, _, _, err := CreateDriver("a driver") //nolint:dogsled
121+
_, err := CreateDriver("a driver") //nolint:dogsled
119122
if err == nil {
120123
t.Error("expected an error but didn't get one")
121124
}

0 commit comments

Comments
 (0)