Skip to content

Commit 5824b97

Browse files
committed
feat: Allow customizing of AWS retry codes
Allow users to customize the error codes that should be retried by the AWS SDK. This enables advanced workflows such as retrying authentication failures
1 parent b7ca15e commit 5824b97

File tree

4 files changed

+207
-3
lines changed

4 files changed

+207
-3
lines changed

aws_config.go

+11-3
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,18 @@ func resolveRetryer(ctx context.Context, awsConfig *aws.Config) {
8484
})
8585
}
8686

87+
var r aws.Retryer = &networkErrorShortcutter{
88+
RetryerV2: retry.NewStandard(standardOptions...),
89+
}
90+
91+
// Add additional retry codes
92+
if retryCodes := os.Getenv("AWS_RETRY_CODES"); retryCodes != "" {
93+
codes := strings.Split(retryCodes, ",")
94+
r = retry.AddWithErrorCodes(r, codes...)
95+
}
96+
8797
awsConfig.Retryer = func() aws.Retryer {
88-
return &networkErrorShortcutter{
89-
RetryerV2: retry.NewStandard(standardOptions...),
90-
}
98+
return r
9199
}
92100
}
93101

aws_config_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/aws/aws-sdk-go-v2/config"
2121
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
2222
"github.com/aws/aws-sdk-go-v2/service/sts"
23+
"github.com/aws/smithy-go"
2324
"github.com/aws/smithy-go/middleware"
2425
smithyhttp "github.com/aws/smithy-go/transport/http"
2526
"github.com/google/go-cmp/cmp"
@@ -1649,6 +1650,93 @@ max_attempts = 10
16491650
}
16501651
}
16511652

1653+
func TestRetryCodes(t *testing.T) {
1654+
testCases := map[string]struct {
1655+
Config *Config
1656+
EnvironmentVariables map[string]string
1657+
ExpectedRetryableErrors []smithy.APIError
1658+
ExpectedNonRetryableErrors []smithy.APIError
1659+
}{
1660+
"no configuration": {
1661+
Config: &Config{
1662+
AccessKey: servicemocks.MockStaticAccessKey,
1663+
SecretKey: servicemocks.MockStaticSecretKey,
1664+
},
1665+
ExpectedNonRetryableErrors: []smithy.APIError{
1666+
&smithy.GenericAPIError{Code: "error 1"},
1667+
},
1668+
},
1669+
1670+
"AWS_RETRY_CODES single": {
1671+
Config: &Config{
1672+
AccessKey: servicemocks.MockStaticAccessKey,
1673+
SecretKey: servicemocks.MockStaticSecretKey,
1674+
},
1675+
EnvironmentVariables: map[string]string{
1676+
"AWS_RETRY_CODES": "error 1",
1677+
},
1678+
ExpectedRetryableErrors: []smithy.APIError{
1679+
&smithy.GenericAPIError{Code: "error 1"},
1680+
},
1681+
ExpectedNonRetryableErrors: []smithy.APIError{
1682+
&smithy.GenericAPIError{Code: "error 2"},
1683+
},
1684+
},
1685+
1686+
"AWS_RETRY_CODES multiple": {
1687+
Config: &Config{
1688+
AccessKey: servicemocks.MockStaticAccessKey,
1689+
SecretKey: servicemocks.MockStaticSecretKey,
1690+
},
1691+
EnvironmentVariables: map[string]string{
1692+
"AWS_RETRY_CODES": "error 1,error 2",
1693+
},
1694+
ExpectedRetryableErrors: []smithy.APIError{
1695+
&smithy.GenericAPIError{Code: "error 1"},
1696+
&smithy.GenericAPIError{Code: "error 2"},
1697+
},
1698+
ExpectedNonRetryableErrors: []smithy.APIError{
1699+
&smithy.GenericAPIError{Code: "error 3"},
1700+
},
1701+
},
1702+
}
1703+
1704+
for testName, testCase := range testCases {
1705+
testCase := testCase
1706+
1707+
t.Run(testName, func(t *testing.T) {
1708+
oldEnv := servicemocks.InitSessionTestEnv()
1709+
defer servicemocks.PopEnv(oldEnv)
1710+
1711+
for k, v := range testCase.EnvironmentVariables {
1712+
os.Setenv(k, v)
1713+
}
1714+
1715+
testCase.Config.SkipCredsValidation = true
1716+
1717+
awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)
1718+
if err != nil {
1719+
t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err)
1720+
}
1721+
1722+
retryer := awsConfig.Retryer()
1723+
if retryer == nil {
1724+
t.Fatal("no retryer set")
1725+
}
1726+
for _, e := range testCase.ExpectedRetryableErrors {
1727+
if a := retryer.IsErrorRetryable(e); !a {
1728+
t.Errorf(`expected error %q would be retryable, got not retryable`, e)
1729+
}
1730+
}
1731+
for _, e := range testCase.ExpectedNonRetryableErrors {
1732+
if a := retryer.IsErrorRetryable(e); a {
1733+
t.Errorf(`expected error %q would not be retryable, got retryable`, e)
1734+
}
1735+
}
1736+
})
1737+
}
1738+
}
1739+
16521740
func TestServiceEndpointTypes(t *testing.T) {
16531741
testCases := map[string]struct {
16541742
Config *Config

v2/awsv1shim/session.go

+13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim
55
"fmt"
66
"log"
77
"os"
8+
"strings"
89

910
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
1011
"github.com/aws/aws-sdk-go/aws"
@@ -89,6 +90,18 @@ func GetSession(awsC *awsv2.Config, c *awsbase.Config) (*session.Session, error)
8990
sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(retryer.MaxAttempts())})
9091
}
9192

93+
// Add custom error code retries. It's easier to recheck the environment variable
94+
// here as the retry codes aren't available from the original v2 config
95+
if retryCodes := os.Getenv("AWS_RETRY_CODES"); retryCodes != "" {
96+
codes := strings.Split(retryCodes, ",")
97+
log.Printf("[DEBUG] Using additional retry codes: %s", codes)
98+
sess.Handlers.Retry.PushBack(func(r *request.Request) {
99+
if tfawserr.ErrCodeEquals(r.Error, codes...) {
100+
r.Retryable = aws.Bool(true)
101+
}
102+
})
103+
}
104+
92105
SetSessionUserAgent(sess, c.APNInfo, c.UserAgent)
93106

94107
// Add custom input from ENV to the User-Agent request header

v2/awsv1shim/session_test.go

+95
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,101 @@ max_attempts = 10
14701470
}
14711471
}
14721472

1473+
func TestRetryCodes(t *testing.T) {
1474+
testCases := map[string]struct {
1475+
Config *awsbase.Config
1476+
EnvironmentVariables map[string]string
1477+
ExpectedRetryableErrors []awserr.Error
1478+
ExpectedNonRetryableErrors []awserr.Error
1479+
}{
1480+
"no configuration": {
1481+
Config: &awsbase.Config{
1482+
AccessKey: servicemocks.MockStaticAccessKey,
1483+
SecretKey: servicemocks.MockStaticSecretKey,
1484+
},
1485+
ExpectedNonRetryableErrors: []awserr.Error{
1486+
awserr.New("error 1", "", nil),
1487+
},
1488+
},
1489+
1490+
"AWS_RETRY_CODES single": {
1491+
Config: &awsbase.Config{
1492+
AccessKey: servicemocks.MockStaticAccessKey,
1493+
SecretKey: servicemocks.MockStaticSecretKey,
1494+
},
1495+
EnvironmentVariables: map[string]string{
1496+
"AWS_RETRY_CODES": "error 1",
1497+
},
1498+
ExpectedRetryableErrors: []awserr.Error{
1499+
awserr.New("error 1", "", nil),
1500+
},
1501+
ExpectedNonRetryableErrors: []awserr.Error{
1502+
awserr.New("error 2", "", nil),
1503+
},
1504+
},
1505+
1506+
"AWS_RETRY_CODES multiple": {
1507+
Config: &awsbase.Config{
1508+
AccessKey: servicemocks.MockStaticAccessKey,
1509+
SecretKey: servicemocks.MockStaticSecretKey,
1510+
},
1511+
EnvironmentVariables: map[string]string{
1512+
"AWS_RETRY_CODES": "error 1,error 2",
1513+
},
1514+
ExpectedRetryableErrors: []awserr.Error{
1515+
awserr.New("error 1", "", nil),
1516+
awserr.New("error 2", "", nil),
1517+
},
1518+
ExpectedNonRetryableErrors: []awserr.Error{
1519+
awserr.New("error 3", "", nil),
1520+
},
1521+
},
1522+
}
1523+
1524+
for testName, testCase := range testCases {
1525+
testCase := testCase
1526+
1527+
t.Run(testName, func(t *testing.T) {
1528+
oldEnv := servicemocks.InitSessionTestEnv()
1529+
defer servicemocks.PopEnv(oldEnv)
1530+
1531+
for k, v := range testCase.EnvironmentVariables {
1532+
os.Setenv(k, v)
1533+
}
1534+
1535+
testCase.Config.SkipCredsValidation = true
1536+
1537+
awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config)
1538+
if err != nil {
1539+
t.Fatalf("GetAwsConfig() returned error: %s", err)
1540+
}
1541+
actualSession, err := GetSession(&awsConfig, testCase.Config)
1542+
if err != nil {
1543+
t.Fatalf("error in GetSession() '%[1]T': %[1]s", err)
1544+
}
1545+
1546+
for _, e := range testCase.ExpectedRetryableErrors {
1547+
r := &request.Request{
1548+
Error: e,
1549+
}
1550+
actualSession.Handlers.Retry.Run(r)
1551+
if !aws.BoolValue(r.Retryable) {
1552+
t.Errorf(`expected error %q would be retryable, got not retryable`, e)
1553+
}
1554+
}
1555+
for _, e := range testCase.ExpectedNonRetryableErrors {
1556+
r := &request.Request{
1557+
Error: e,
1558+
}
1559+
actualSession.Handlers.Retry.Run(r)
1560+
if aws.BoolValue(r.Retryable) {
1561+
t.Errorf(`expected error %q would not be retryable, got retryable`, e)
1562+
}
1563+
}
1564+
})
1565+
}
1566+
}
1567+
14731568
func TestServiceEndpointTypes(t *testing.T) {
14741569
testCases := map[string]struct {
14751570
Config *awsbase.Config

0 commit comments

Comments
 (0)