Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Timeout middleware implementation #1743

Merged
merged 1 commit into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions middleware/timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// +build go1.13

package middleware

import (
"context"
"github.com/labstack/echo/v4"
"time"
)

type (
// TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorHandler defines a function which is executed for a timeout
// It can be used to define a custom timeout error
ErrorHandler TimeoutErrorHandlerWithContext
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}

// TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can
// handle the error as we see fit
TimeoutErrorHandlerWithContext func(error, echo.Context) error
)

var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorHandler: nil,
}
)

// Timeout returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
}

// TimeoutWithConfig returns a Timeout middleware with config.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) || config.Timeout == 0 {
return next(c)
}

ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
defer cancel()

// this does a deep clone of the context, wondering if there is a better way to do this?
c.SetRequest(c.Request().Clone(ctx))

done := make(chan error, 1)
go func() {
// This goroutine will keep running even if this middleware times out and
// will be stopped when ctx.Done() is called down the next(c) call chain
done <- next(c)
lammel marked this conversation as resolved.
Show resolved Hide resolved
}()

select {
case <-ctx.Done():
if config.ErrorHandler != nil {
return config.ErrorHandler(ctx.Err(), c)
}
return ctx.Err()
case err := <-done:
return err
}
}
}
}
177 changes: 177 additions & 0 deletions middleware/timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// +build go1.13

package middleware

import (
"context"
"errors"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
)

func TestTimeoutSkipper(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 0,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutIsCancelable(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Minute,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)

assert.NoError(t, err)
}

func TestTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := Timeout()

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
return errors.New("err")
})(c)

assert.Error(t, err)
}

func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
ErrorHandler: func(err error, e echo.Context) error {
assert.EqualError(t, err, context.DeadlineExceeded.Error())
return errors.New("err")
},
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, errors.New("err").Error())
}

func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Timeout: time.Second,
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
time.Sleep(time.Minute)
return nil
})(c)

assert.EqualError(t, err, context.DeadlineExceeded.Error())
}

func TestTimeoutTestRequestClone(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()

m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: time.Second,
})

e := echo.New()
c := e.NewContext(req, rec)

err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
assert.EqualValues(t, "cookie", cookie.Name)
assert.EqualValues(t, "value", cookie.Value)
}

// Form values
if assert.NoError(t, c.Request().ParseForm()) {
assert.EqualValues(t, "value", c.Request().FormValue("form"))
}

// Query string
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
return nil
})(c)

assert.NoError(t, err)

}