Skip to content

Commit

Permalink
DXCDT-541: Persistent rate limit handling (#839)
Browse files Browse the repository at this point in the history
* Adding more bespoke retry handler

* Removing dependency on config struct

* Adding test

---------

Co-authored-by: Will Vedder <will.vedder@okta.com>
  • Loading branch information
willvedd and willvedd authored Sep 12, 2023
1 parent c39fabf commit 04992d8
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 13 deletions.
14 changes: 1 addition & 13 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@ package cli
import (
"context"
"fmt"
"net/http"
"strings"

"github.com/auth0/go-auth0/management"
"github.com/spf13/cobra"
"github.com/spf13/pflag"

"github.com/auth0/auth0-cli/internal/analytics"
"github.com/auth0/auth0-cli/internal/ansi"
"github.com/auth0/auth0-cli/internal/auth0"
"github.com/auth0/auth0-cli/internal/buildinfo"
"github.com/auth0/auth0-cli/internal/config"
"github.com/auth0/auth0-cli/internal/display"
"github.com/auth0/auth0-cli/internal/iostream"
Expand Down Expand Up @@ -108,15 +104,7 @@ func (c *cli) setupWithAuthentication(ctx context.Context) error {
}
}

userAgent := fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))

api, err := management.New(
tenant.Domain,
management.WithStaticToken(tenant.GetAccessToken()),
management.WithUserAgent(userAgent),
management.WithAuth0ClientEnvEntry("Auth0-CLI", strings.TrimPrefix(buildinfo.Version, "v")),
management.WithRetries(5, []int{http.StatusTooManyRequests, http.StatusInternalServerError}),
)
api, err := initializeManagementClient(tenant.Domain, tenant.GetAccessToken())
if err != nil {
return err
}
Expand Down
122 changes: 122 additions & 0 deletions internal/cli/management.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package cli

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/PuerkitoBio/rehttp"
"github.com/auth0/go-auth0/management"

"github.com/auth0/auth0-cli/internal/buildinfo"
)

func initializeManagementClient(tenantDomain string, accessToken string) (*management.Management, error) {
client, err := management.New(
tenantDomain,
management.WithStaticToken(accessToken),
management.WithUserAgent(fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))),
management.WithAuth0ClientEnvEntry("Auth0-CLI", strings.TrimPrefix(buildinfo.Version, "v")),
management.WithNoRetries(),
management.WithClient(customClientWithRetries()),
)

return client, err
}

func customClientWithRetries() *http.Client {
client := &http.Client{
Transport: rateLimitTransport(
retryableErrorTransport(
http.DefaultTransport,
),
),
}

return client
}

func rateLimitTransport(tripper http.RoundTripper) http.RoundTripper {
return rehttp.NewTransport(tripper, rateLimitRetry, rateLimitDelay)
}

func rateLimitRetry(attempt rehttp.Attempt) bool {
if attempt.Response == nil {
return false
}

return attempt.Response.StatusCode == http.StatusTooManyRequests
}

func rateLimitDelay(attempt rehttp.Attempt) time.Duration {
resetAt := attempt.Response.Header.Get("X-RateLimit-Reset")

resetAtUnix, err := strconv.ParseInt(resetAt, 10, 64)
if err != nil {
resetAtUnix = time.Now().Add(5 * time.Second).Unix()
}

return time.Duration(resetAtUnix-time.Now().Unix()) * time.Second
}

func retryableErrorTransport(tripper http.RoundTripper) http.RoundTripper {
retryableCodes := []int{
http.StatusServiceUnavailable,
http.StatusInternalServerError,
http.StatusBadGateway,
http.StatusGatewayTimeout,
// Cloudflare-specific server error that is generated
// because Cloudflare did not receive an HTTP response
// from the origin server after an HTTP Connection was made.
524,
}

return rehttp.NewTransport(
tripper,
rehttp.RetryAll(
rehttp.RetryMaxRetries(3),
rehttp.RetryAny(
rehttp.RetryStatuses(retryableCodes...),
rehttp.RetryIsErr(retryableErrorRetryFunc),
),
),
rehttp.ExpJitterDelay(500*time.Millisecond, 10*time.Second),
)
}

func retryableErrorRetryFunc(err error) bool {
if err == nil {
return false
}

if v, ok := err.(*url.Error); ok {
// Don't retry if the error was due to too many redirects.
if regexp.MustCompile(`stopped after \d+ redirects\z`).MatchString(v.Error()) {
return false
}

// Don't retry if the error was due to an invalid protocol scheme.
if regexp.MustCompile(`unsupported protocol scheme`).MatchString(v.Error()) {
return false
}

// Don't retry if the certificate issuer is unknown.
if _, ok := v.Err.(*tls.CertificateVerificationError); ok {
return false
}

// Don't retry if the certificate issuer is unknown.
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
return false
}
}

// The error is likely recoverable so retry.
return true
}
175 changes: 175 additions & 0 deletions internal/cli/management_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package cli

import (
"crypto/tls"
"crypto/x509"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCustomClientWithRetries(t *testing.T) {
t.Run("it retries on rate limit error", func(t *testing.T) {
apiCalls := 0
fail := true
testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
apiCalls++

if fail {
fail = false
writer.WriteHeader(429)
resetAt := time.Now().Add(time.Second).Unix()
writer.Header().Set("X-RateLimit-Reset", strconv.Itoa(int(resetAt)))
return
}

writer.WriteHeader(200)
}))

client := customClientWithRetries()

request, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
require.NoError(t, err)

response, err := client.Do(request)
require.NoError(t, err)

assert.Equal(t, 200, response.StatusCode)
assert.False(t, fail)
assert.Equal(t, 2, apiCalls)

t.Cleanup(func() {
testServer.Close()
err := response.Body.Close()
require.NoError(t, err)
})
})

t.Run("it retries on server error", func(t *testing.T) {
apiCalls := 0
fail := true
testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
apiCalls++

if fail {
fail = false
writer.WriteHeader(500)
return
}

writer.WriteHeader(200)
}))

client := customClientWithRetries()

request, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
require.NoError(t, err)

response, err := client.Do(request)
require.NoError(t, err)

assert.Equal(t, 200, response.StatusCode)
assert.False(t, fail)
assert.Equal(t, 2, apiCalls)

t.Cleanup(func() {
testServer.Close()
err := response.Body.Close()
require.NoError(t, err)
})
})

t.Run("it does not retry more than 3 times on server error", func(t *testing.T) {
apiCalls := 0
testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
apiCalls++
writer.WriteHeader(500)
}))

client := customClientWithRetries()

request, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
require.NoError(t, err)

response, err := client.Do(request)
require.NoError(t, err)

assert.Equal(t, 500, response.StatusCode)
assert.Equal(t, 3+1, apiCalls) // 3 retries + 1 first call.

t.Cleanup(func() {
testServer.Close()
err := response.Body.Close()
require.NoError(t, err)
})
})
}

func TestRetryableErrorRetryFunc(t *testing.T) {
testCases := []struct {
name string
err error
expected bool
}{
{
name: "NilError",
err: nil,
expected: false,
},
{
name: "TooManyRedirectsError",
err: &url.Error{
Op: "Get",
URL: "http://example.com",
Err: errors.New("stopped after 5 redirects"),
},
expected: false,
},
{
name: "UnsupportedProtocolSchemeError",
err: &url.Error{
Op: "Get",
URL: "ftp://example.com",
Err: errors.New("unsupported protocol scheme"),
},
expected: false,
},
{
name: "CertificateVerificationError",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: &tls.CertificateVerificationError{},
},
expected: false,
},
{
name: "UnknownAuthorityError",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: x509.UnknownAuthorityError{},
},
expected: false,
},
{
name: "OtherError",
err: errors.New("some other error"),
expected: true,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
actual := retryableErrorRetryFunc(testCase.err)
assert.Equal(t, testCase.expected, actual)
})
}
}

0 comments on commit 04992d8

Please sign in to comment.