Skip to content

Commit

Permalink
GODRIVER-3168 Retry KMS requests on transient errors. (#1887)
Browse files Browse the repository at this point in the history
Co-authored-by: Preston Vasquez <[email protected]>
  • Loading branch information
qingyang-hu and prestonvasquez authored Feb 7, 2025
1 parent f99da4d commit a3ad820
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 12 deletions.
40 changes: 35 additions & 5 deletions .evergreen/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,25 @@ functions:
KMS_MOCK_SERVERS_RUNNING: "true"
args: [*task-runner, evg-test-kmip]

# Runs the KMS retry tests in two subprocess steps: the task runner's
# setup-test target first, then evg-test-retry-kms-requests.
run-retry-kms-requests:
# Step 1: run setup-test with the cse build tag and the listed expansions
# exported into the environment.
- command: subprocess.exec
type: test
params:
binary: "bash"
env:
GO_BUILD_TAGS: cse
include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY,
MONGO_GO_DRIVER_COMPRESSOR]
args: [*task-runner, setup-test]
# Step 2: run the retry tests. KMS_FAILPOINT_SERVER_RUNNING and
# KMS_FAILPOINT_CA_FILE are the variables the Go test reads to decide whether
# to run and which CA file to trust.
- command: subprocess.exec
type: test
params:
binary: "bash"
env:
KMS_FAILPOINT_CA_FILE: "${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem"
KMS_FAILPOINT_SERVER_RUNNING: "true"
args: [*task-runner, evg-test-retry-kms-requests]

run-fuzz-tests:
- command: subprocess.exec
type: test
Expand Down Expand Up @@ -1468,7 +1487,7 @@ tasks:
SSL: "nossl"

- name: "test-kms-tls-invalid-cert"
tags: ["kms-tls"]
tags: ["kms-test"]
commands:
- func: bootstrap-mongo-orchestration
vars:
Expand All @@ -1484,7 +1503,7 @@ tasks:
SSL: "nossl"

- name: "test-kms-tls-invalid-hostname"
tags: ["kms-tls"]
tags: ["kms-test"]
commands:
- func: bootstrap-mongo-orchestration
vars:
Expand Down Expand Up @@ -1514,6 +1533,17 @@ tasks:
AUTH: "noauth"
SSL: "nossl"

# KMS retry task: bootstraps a "server" topology without auth or SSL, runs
# start-cse-servers, then runs the run-retry-kms-requests function.
# NOTE(review): presumably start-cse-servers also launches the KMS failpoint
# server the tests contact on port 9003 — confirm in that function.
- name: "test-retry-kms-requests"
tags: ["kms-test"]
commands:
- func: bootstrap-mongo-orchestration
vars:
TOPOLOGY: "server"
AUTH: "noauth"
SSL: "nossl"
- func: start-cse-servers
- func: run-retry-kms-requests

- name: "test-serverless"
tags: ["serverless"]
commands:
Expand Down Expand Up @@ -2201,11 +2231,11 @@ buildvariants:
tasks:
- name: ".versioned-api"

- matrix_name: "kms-tls-test"
- matrix_name: "kms-test"
matrix_spec: { version: ["7.0"], os-ssl-40: ["rhel87-64"] }
display_name: "KMS TLS ${os-ssl-40}"
display_name: "KMS TEST ${os-ssl-40}"
tasks:
- name: ".kms-tls"
- name: ".kms-test"

- matrix_name: "load-balancer-test"
tags: ["pullrequest"]
Expand Down
3 changes: 3 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ tasks:
# Runs the subtests of TestClientSideEncryptionProse in ./internal/integration
# whose name matches kms_tls_tests, appending verbose output to test.suite.
evg-test-kms:
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite

# Runs the subtests of TestClientSideEncryptionProse in ./internal/integration
# whose name matches kms_retry_tests, appending verbose output to test.suite.
evg-test-retry-kms-requests:
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite

evg-test-load-balancers:
# Load balancer should be tested with all unified tests as well as tests in the following
# components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery.
Expand Down
2 changes: 1 addition & 1 deletion etc/install-libmongocrypt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This script installs libmongocrypt into an "install" directory.
set -eux

LIBMONGOCRYPT_TAG="1.11.0"
LIBMONGOCRYPT_TAG="1.12.0"

# Install libmongocrypt based on OS.
if [ "Windows_NT" = "${OS:-}" ]; then
Expand Down
145 changes: 144 additions & 1 deletion internal/integration/client_side_encryption_prose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net"
Expand All @@ -30,6 +31,7 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/handshake"
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
Expand Down Expand Up @@ -2925,7 +2927,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
}
})

mt.RunOpts("22. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
mt.RunOpts("23. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) {
err := mt.Client.Database("keyvault").Collection("datakeys").Drop(context.Background())
assert.Nil(mt, err, "error on Drop: %v", err)

Expand Down Expand Up @@ -2986,6 +2988,147 @@ func TestClientSideEncryptionProse(t *testing.T) {
assert.Greater(t, len(payload.Data), len(payloadDefaults.Data), "the returned payload size is expected to be greater than %d", len(payloadDefaults.Data))
})
})

// Prose test 24: createDataKey and encrypt must succeed when a KMS request
// hits a single transient network or HTTP failure, and createDataKey must fail
// once the failures outnumber the retries. The test talks to a KMS failpoint
// server on port 9003 and is skipped unless KMS_FAILPOINT_SERVER_RUNNING is set.
mt.RunOpts("24. kms retry tests", noClientOpts, func(mt *mtest.T) {
	if os.Getenv("KMS_FAILPOINT_SERVER_RUNNING") == "" {
		mt.Skipf("Skipping test as KMS_FAILPOINT_SERVER_RUNNING is not set")
	}

	mt.Parallel()

	tlsCAFile := os.Getenv("KMS_FAILPOINT_CA_FILE")
	require.NotEqual(mt, "", tlsCAFile, "KMS_FAILPOINT_CA_FILE must be set")

	clientAndCATlsMap := map[string]interface{}{
		"tlsCAFile": tlsCAFile,
	}
	tlsCfg, err := options.BuildTLSConfig(clientAndCATlsMap)
	require.NoError(mt, err, "BuildTLSConfig error: %v", err)

	// setFailPoint POSTs {"count": count} to the failpoint server's
	// /set_failpoint/<failure> endpoint. It returns an error if the payload
	// cannot be encoded, the request fails, or the server answers with a
	// non-2xx status, so a rejected failpoint is reported here rather than
	// showing up later as an unrelated KMS failure.
	setFailPoint := func(failure string, count int) error {
		url := fmt.Sprintf("https://localhost:9003/set_failpoint/%s", failure)

		var payloadBuf bytes.Buffer
		if err := json.NewEncoder(&payloadBuf).Encode(map[string]int{"count": count}); err != nil {
			return fmt.Errorf("encoding failpoint payload: %w", err)
		}

		req, err := http.NewRequest(http.MethodPost, url, &payloadBuf)
		if err != nil {
			return err
		}

		client := &http.Client{
			Transport: &http.Transport{TLSClientConfig: tlsCfg},
		}
		// The transport is created per call, so release any connection it
		// keeps around once the request is done.
		defer client.CloseIdleConnections()

		res, err := client.Do(req)
		if err != nil {
			return err
		}
		closeErr := res.Body.Close()
		if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
			return fmt.Errorf("setting %q failpoint: unexpected status %s", failure, res.Status)
		}
		return closeErr
	}

	kmsProviders := map[string]map[string]interface{}{
		"aws": {
			"accessKeyId":     awsAccessKeyID,
			"secretAccessKey": awsSecretAccessKey,
		},
		"azure": {
			"tenantId":                 azureTenantID,
			"clientId":                 azureClientID,
			"clientSecret":             azureClientSecret,
			"identityPlatformEndpoint": "127.0.0.1:9003",
		},
		"gcp": {
			"email":      gcpEmail,
			"privateKey": gcpPrivateKey,
			"endpoint":   "127.0.0.1:9003",
		},
	}

	// Every master key points its KMS endpoint at the failpoint server.
	dataKeys := []struct {
		provider  string
		masterKey interface{}
	}{
		{"aws", bson.D{
			{"region", "foo"},
			{"key", "bar"},
			{"endpoint", "127.0.0.1:9003"},
		}},
		{"azure", bson.D{
			{"keyVaultEndpoint", "127.0.0.1:9003"},
			{"keyName", "foo"},
		}},
		{"gcp", bson.D{
			{"projectId", "foo"},
			{"location", "bar"},
			{"keyRing", "baz"},
			{"keyName", "qux"},
			{"endpoint", "127.0.0.1:9003"},
		}},
	}

	testCases := []struct {
		name    string
		failure string
	}{
		{"Case 1: createDataKey and encrypt with TCP retry", "network"},
		{"Case 2: createDataKey and encrypt with HTTP retry", "http"},
	}

	// newClientEncryption connects a key vault client to the test cluster and
	// wraps it in a ClientEncryption that uses tlsCfg for the given provider.
	newClientEncryption := func(mt *mtest.T, provider string) *mongo.ClientEncryption {
		keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI()))
		require.NoError(mt, err, "error on Connect: %v", err)

		ceo := options.ClientEncryption().
			SetKeyVaultNamespace(kvNamespace).
			SetKmsProviders(kmsProviders).
			SetTLSConfig(map[string]*tls.Config{provider: tlsCfg})
		clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
		require.NoError(mt, err, "error on NewClientEncryption: %v", err)
		return clientEncryption
	}

	// closeClientEncryption releases the ClientEncryption created for a
	// subtest so that clients do not pile up across the provider matrix.
	// NOTE(review): ClientEncryption.Close is documented to also disconnect
	// the key vault client — confirm for the driver version in use.
	closeClientEncryption := func(mt *mtest.T, clientEncryption *mongo.ClientEncryption) {
		err := clientEncryption.Close(context.Background())
		assert.Nil(mt, err, "error on Close: %v", err)
	}

	for _, tc := range testCases {
		for _, dataKey := range dataKeys {
			mt.Run(fmt.Sprintf("%s_%s", tc.name, dataKey.provider), func(mt *mtest.T) {
				clientEncryption := newClientEncryption(mt, dataKey.provider)
				defer closeClientEncryption(mt, clientEncryption)

				// One injected failure before createDataKey: the request is
				// expected to succeed after a retry.
				err := setFailPoint(tc.failure, 1)
				require.NoError(mt, err, "mock server error: %v", err)

				dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey)
				keyID, err := clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts)
				require.NoError(mt, err, "error in CreateDataKey: %v", err)

				// One injected failure before encrypt as well.
				err = setFailPoint(tc.failure, 1)
				require.NoError(mt, err, "mock server error: %v", err)

				testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)}
				eo := options.Encrypt().
					SetKeyID(keyID).
					SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
				_, err = clientEncryption.Encrypt(context.Background(), testVal, eo)
				require.NoError(mt, err, "error in Encrypt: %v", err)
			})
		}
	}

	for _, dataKey := range dataKeys {
		mt.Run(fmt.Sprintf("Case 3: createDataKey fails after too many retries_%s", dataKey.provider), func(mt *mtest.T) {
			clientEncryption := newClientEncryption(mt, dataKey.provider)
			defer closeClientEncryption(mt, clientEncryption)

			// Four injected failures exceed the three retries named in the
			// expected error message below.
			err := setFailPoint("network", 4)
			require.NoError(mt, err, "mock server error: %v", err)

			dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey)
			_, err = clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts)
			require.ErrorContains(mt, err, "KMS request failed after 3 retries due to a network error")
		})
	}
})
}

func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher {
Expand Down
6 changes: 2 additions & 4 deletions x/mongo/driver/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ package driver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"strings"
"time"

Expand Down Expand Up @@ -399,8 +397,8 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {

res := make([]byte, bytesNeeded)
bytesRead, err := conn.Read(res)
if err != nil && !errors.Is(err, io.EOF) {
return err
if err != nil {
return kmsCtx.RequestError()
}

if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion x/mongo/driver/mongocrypt/mongocrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
if wrapped == nil {
return nil, errors.New("could not create new mongocrypt object")
}
C.mongocrypt_setopt_retry_kms(wrapped, true)
httpClient := opts.HTTPClient
if httpClient == nil {
httpClient = httputil.DefaultHTTPClient
Expand Down Expand Up @@ -85,7 +86,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) {
}

if opts.BypassQueryAnalysis {
C.mongocrypt_setopt_bypass_query_analysis(wrapped)
C.mongocrypt_setopt_bypass_query_analysis(crypt.wrapped)
}

// If loading the crypt_shared library isn't disabled, set the default library search path "$SYSTEM"
Expand Down
11 changes: 11 additions & 0 deletions x/mongo/driver/mongocrypt/mongocrypt_kms_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package mongocrypt

// #include <mongocrypt.h>
import "C"
import "time"

// KmsContext represents a mongocrypt_kms_ctx_t handle.
type KmsContext struct {
Expand Down Expand Up @@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string {

// Message returns the message to send to the KMS.
func (kc *KmsContext) Message() ([]byte, error) {
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)

msgBinary := newBinary()
defer msgBinary.close()

Expand Down Expand Up @@ -74,3 +77,11 @@ func (kc *KmsContext) createErrorFromStatus() error {
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
return errorFromStatus(status)
}

// RequestError calls mongocrypt_kms_ctx_fail on the wrapped KMS context. It
// returns nil when that call reports true, and otherwise the error built from
// the KMS context status.
// NOTE(review): per the libmongocrypt header, a true result means the failed
// request may be retried — confirm against the pinned libmongocrypt version.
func (kc *KmsContext) RequestError() error {
	retryable := bool(C.mongocrypt_kms_ctx_fail(kc.wrapped))
	if !retryable {
		return kc.createErrorFromStatus()
	}
	return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 {
func (kc *KmsContext) FeedResponse([]byte) error {
// Stub: the argument is ignored and the call always panics with
// cseNotSupportedMsg.
panic(cseNotSupportedMsg)
}

// RequestError returns the source of the network error for KMS requests.
// This variant is a stub: it never returns and always panics with
// cseNotSupportedMsg.
func (kc *KmsContext) RequestError() error {
panic(cseNotSupportedMsg)
}

0 comments on commit a3ad820

Please sign in to comment.