diff --git a/internal/requests/requests.go b/internal/requests/requests.go index 70cb9d407..a2ce97536 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -2,6 +2,7 @@ package requests import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -58,6 +59,7 @@ func (cl *CustomLogger) Printf(format string, args ...interface{}) { func NewKosliClient(httpProxyURL string, maxAPIRetries int, debug bool, logger *logger.Logger) (*Client, error) { retryClient := retryablehttp.NewClient() retryClient.RetryMax = maxAPIRetries + retryClient.CheckRetry = customCheckRetry if debug { retryClient.Logger = &CustomLogger{ Logger: log.New(os.Stderr, "[debug]", log.Lmsgprefix), @@ -300,3 +302,20 @@ func (c *Client) PayloadOutput(req *http.Request, jsonFields map[string]interfac } return nil } + +func customCheckRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { + // Get the default retry policy for errors and certain status codes. + // It will retry on 5xx, 429 and some special cases + shouldRetry, retryErr := retryablehttp.DefaultRetryPolicy(ctx, resp, err) + if retryErr != nil { + return false, retryErr + } + if shouldRetry { + return true, nil + } + // The sever gives 409 if we have a lock conflict. + if resp != nil && resp.StatusCode == 409 { + return true, nil + } + return false, nil +} diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index c71777012..dec330da8 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -51,6 +51,10 @@ func (suite *RequestsTestSuite) SetupSuite() { Get("/denied/"). Reply(403). BodyString(`"Denied"`) + suite.fakeService.NewHandler(). + Get("/locked/"). + Reply(409). + BodyString("resource temporarily locked") suite.fakeService.NewHandler(). Get("/fail/"). Reply(500). @@ -288,6 +292,15 @@ func (suite *RequestsTestSuite) TestDo() { wantError: true, expectedErrorMsg: "resource not found", }, + { + name: "GET request to 409 endpoint", + params: &RequestParams{ + Method: http.MethodGet, + URL: suite.fakeService.ResolveURL("/locked/"), + }, + wantError: true, + expectedErrorMsg: fmt.Sprintf("Get \"%s\": GET %s giving up after 2 attempt(s)", suite.fakeService.ResolveURL("/locked/"), suite.fakeService.ResolveURL("/locked/")), + }, { name: "GET request to 500 endpoint", params: &RequestParams{