diff --git a/common/db/command.go b/common/db/command.go index f2af6ca96..a8c1701c7 100644 --- a/common/db/command.go +++ b/common/db/command.go @@ -104,7 +104,7 @@ func (sp *SessionProvider) ServerVersion() (string, error) { func (sp *SessionProvider) ServerVersionArray() (Version, error) { var version Version out := struct { - VersionArray []int32 `bson:"versionArray"` + VersionArray []int `bson:"versionArray"` }{} err := sp.RunString("buildInfo", &out, "admin") if err != nil { @@ -113,9 +113,17 @@ func (sp *SessionProvider) ServerVersionArray() (Version, error) { if len(out.VersionArray) < 3 { return version, fmt.Errorf("buildInfo.versionArray had fewer than 3 elements") } - for i := 0; i <= 2; i++ { - version[i] = int(out.VersionArray[i]) + + copy(version[:], out.VersionArray[:len(version)]) + + // In development server builds `versionArray`’s 4th member is negative, and + // `versionArray`’s patch version exceeds `version`’s by 1. Since we have + // logic that compares this method’s output to `version` we need to subtract + // one from the patch value. + if len(out.VersionArray) > 3 && out.VersionArray[3] < 0 { + version[2]-- } + return version, nil } diff --git a/common/db/db.go b/common/db/db.go index d88df972f..c70cfac32 100644 --- a/common/db/db.go +++ b/common/db/db.go @@ -331,8 +331,6 @@ func configureClient(opts options.ToolOptions) (*mongo.Client, error) { clientopt := mopt.Client() cs := opts.URI.ParsedConnString() - clientopt.Hosts = cs.Hosts - if opts.RetryWrites != nil { clientopt.SetRetryWrites(*opts.RetryWrites) } @@ -506,6 +504,12 @@ func configureClient(opts options.ToolOptions) (*mongo.Client, error) { clientopt.SetAuth(cred) } + if opts.Kerberos != nil && opts.Kerberos.ServiceHost != "" { + clientopt.Hosts = cs.Hosts + } else { + clientopt.ApplyURI(cs.String()) + } + if opts.SSL != nil && opts.UseSSL { // Error on unsupported features if opts.SSLFipsMode { diff --git a/common/util/mongo.go b/common/util/mongo.go index 972f126c8..d8559cae3 100644 --- a/common/util/mongo.go +++ b/common/util/mongo.go @@ -7,8 +7,13 @@ package util import ( + "context" "fmt" "strings" + + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" ) const ( @@ -228,3 +233,29 @@ func ValidateCollectionGrammar(collection string) error { // collection name is valid return nil } + +func IsConnectionAuthenticated(ctx context.Context, conn *mongo.Client) (bool, error) { + res := conn.Database("admin").RunCommand( + ctx, + bson.D{{"connectionStatus", 1}}, + ) + + if res.Err() != nil { + return false, errors.Wrap(res.Err(), "failed to query for connection information") + } + + body := struct { + AuthInfo struct { + AuthenticatedUsers []bson.D `bson:"authenticatedUsers"` + } `bson:"authInfo"` + }{} + + err := res.Decode(&body) + if err != nil { + raw, _ := res.Raw() + + return false, errors.Wrapf(err, "failed to decode connection information (%+v)", raw) + } + + return len(body.AuthInfo.AuthenticatedUsers) > 0, nil +} diff --git a/mongorestore/mongorestore_test.go b/mongorestore/mongorestore_test.go index d7e81b831..4a48e2449 100644 --- a/mongorestore/mongorestore_test.go +++ b/mongorestore/mongorestore_test.go @@ -30,6 +30,7 @@ import ( "github.com/mongodb/mongo-tools/common/options" "github.com/mongodb/mongo-tools/common/testtype" "github.com/mongodb/mongo-tools/common/testutil" + "github.com/mongodb/mongo-tools/common/util" . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1833,7 +1834,8 @@ func TestCreateIndexes(t *testing.T) { "RestoreOplog() should convert commitIndexBuild op to createIndexes cmd and build index", func() { destColl := session.Database("create_indexes").Collection("test") - indexes, _ := destColl.Indexes().List(context.Background()) + indexes, err := destColl.Indexes().List(context.Background()) + So(err, ShouldBeNil) type indexSpec struct { Name, NS string @@ -2202,7 +2204,14 @@ func TestRestoreTimeseriesCollections(t *testing.T) { session, err := sessionProvider.GetSession() if err != nil { - t.Fatalf("No client available") + t.Fatalf("No client available: %v", err) + } + + isAuthn, err := util.IsConnectionAuthenticated(ctx, session) + require.NoError(t, err, "should query for authentication state") + + if isAuthn { + t.Skip("This test requires a non-authenticated connection.") } fcv := testutil.GetFCV(session) @@ -3424,12 +3433,19 @@ func TestDumpAndRestoreConfigDB(t *testing.T) { testtype.SkipUnlessTestType(t, testtype.IntegrationTestType) - _, err := testutil.GetBareSession() + client, err := testutil.GetBareSession() require.NoError(err, "can connect to server") + isAuthn, err := util.IsConnectionAuthenticated(context.Background(), client) + require.NoError(err, "should query for authentication state") + t.Run( "test dump and restore only config db includes all config collections", func(t *testing.T) { + if isAuthn { + t.Skip("This test requires a non-authenticated connection.") + } + testDumpAndRestoreConfigDBIncludesAllCollections(t) }, )