Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: respond with 401 for programmatic requests. #8795

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions master/internal/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,28 +174,27 @@ func processProxyAuthentication(c echo.Context) (done bool, err error) {
return err != nil, authz.SubIfUnauthorized(err, serviceNotFoundErr)
}

// processAuthWithRedirect is an auth middleware that redirects the requests
// processAuthWithRedirect is an auth middleware that redirects browser requests
// to login page for a set of given paths in case of authentication errors.
func processAuthWithRedirect(redirectPaths []string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
hamidzr marked this conversation as resolved.
Show resolved Hide resolved
err := user.GetService().ProcessAuthentication(next)(c)
if err == nil {
return nil
}
// No web page redirects for programmatic requests.
for _, accept := range c.Request().Header["Accept"] {
hamidzr marked this conversation as resolved.
Show resolved Hide resolved
if strings.Contains(accept, "application/json") {
return err
}
}
path := c.Request().RequestURI
shouldRedirect := false

for _, p := range redirectPaths {
if strings.HasPrefix(path, p) {
shouldRedirect = true
break
return redirectToLogin(c)
}
}

err := user.GetService().ProcessAuthentication(next)(c)

// If there's an authentication error and we should redirect, then do so
if err != nil && shouldRedirect {
return redirectToLogin(c)
}

return err
}
}
Expand Down
79 changes: 79 additions & 0 deletions master/internal/api_user_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ package internal
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/labstack/echo/v4"

"github.com/determined-ai/determined/master/internal/job/jobservice"
"github.com/determined-ai/determined/master/internal/rm/rmevents"
"github.com/determined-ai/determined/master/internal/sproto"
Expand Down Expand Up @@ -148,6 +152,81 @@ func fetchUserIds(ctx context.Context, t *testing.T, api *apiServer, req *apiv1.
return ids
}

func TestProcessAuth(t *testing.T) {
api, _, _ := setupAPITest(t, nil)
extConfig := model.ExternalSessions{}
user.InitService(api.m.db, &extConfig)

e := echo.New()
handler := user.GetService().ProcessAuthentication(
func(c echo.Context) error {
require.Fail(t, "Should not have reached this point")
return nil
},
)
req := httptest.NewRequest(http.MethodGet, "/authed-route", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
require.Error(t, err)
httpError, ok := err.(*echo.HTTPError)
require.True(t, ok)
require.Equal(t, http.StatusUnauthorized, httpError.Code)
}

func TestAuthMiddleware(t *testing.T) {
proxies := []string{"/proxied-path-a"}
api, _, _ := setupAPITest(t, nil)
extConfig := model.ExternalSessions{}
user.InitService(api.m.db, &extConfig)

tests := []struct {
path string
acceptHeader string
expectedCode int
expectedLoc string // Expected location header, empty if no redirect expected
}{
{"/proxied-path-a/anysubroute", "", http.StatusSeeOther, "/det/login?redirect=/proxied-path-a/anysubroute"},
{"/proxied-path-a", "application/json", http.StatusUnauthorized, ""},
{"/non-proxied-path", "", http.StatusUnauthorized, ""},
{"/non-proxied-path", "application/json", http.StatusUnauthorized, ""},
}

e := echo.New()
for _, tc := range tests {
t.Run(fmt.Sprintf("Path: %s, Accept: %s", tc.path, tc.acceptHeader), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

if tc.acceptHeader != "" {
req.Header.Set("Accept", tc.acceptHeader)
}

middleware := processAuthWithRedirect(proxies)
fn := middleware(func(c echo.Context) error { return c.NoContent(http.StatusUnauthorized) })

err := fn(c)

if tc.expectedCode == http.StatusUnauthorized {
require.Error(t, err, "Expected an error but got none")
httpError, ok := err.(*echo.HTTPError) // Cast error to *echo.HTTPError to check code
if ok && httpError != nil {
require.Equal(t, tc.expectedCode, httpError.Code, "HTTP status code does not match expected")
} else {
require.Fail(t, "Error is not an HTTPError as expected")
}
} else {
require.Equal(t, tc.expectedCode, http.StatusSeeOther)
require.Equal(t, tc.expectedCode, rec.Code, "HTTP status code does not match expected")
require.NoError(t, err, "Did not expect an error but got one")
require.Contains(t, rec.Header().Get("Location"), tc.expectedLoc,
"Location header does not match expected redirect")
}
})
}
}

func TestLoginRemote(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)

Expand Down
Loading