Skip to content

Commit

Permalink
Unify and simplify MFA Ceremony helpers (#46986) (#47157)
Browse files Browse the repository at this point in the history
* Refactor MFA ceremony helpers.

* Refactor session MFA ceremony to use new MFA ceremony helpers.

* Simplify calls to NewMFACeremony.

* Remove remaining usage of tc.PromptMFA in favor of Ceremony.

* Rename prompt constructor.

* Add godoc to ceremony; update tests.

* Cleanup.

* Resolve comments; fix tests.

* Update comments.

* Fix test.

* Fix lint.
  • Loading branch information
Joerger authored Oct 16, 2024
1 parent e28b8b9 commit 4a59807
Show file tree
Hide file tree
Showing 24 changed files with 396 additions and 398 deletions.
20 changes: 4 additions & 16 deletions api/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package client
import (
"context"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
)
Expand All @@ -29,19 +27,9 @@ import (
// and prompts the user to answer the challenge with the given promptOpts, and ultimately returning
// an MFA challenge response for the user.
func (c *Client) PerformMFACeremony(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
// Don't attempt the MFA ceremony if we can't prompt for a response.
if c.c.MFAPromptConstructor == nil {
return nil, trace.Wrap(&mfa.ErrMFANotSupported, "missing MFAPromptConstructor field, client cannot perform MFA ceremony")
}

return mfa.PerformMFACeremony(ctx, c, challengeRequest, promptOpts...)
}

// PromptMFA prompts the user for MFA. Implements [mfa.MFACeremonyClient].
func (c *Client) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if c.c.MFAPromptConstructor == nil {
return nil, trace.Wrap(&mfa.ErrMFANotSupported, "missing MFAPromptConstructor field, client cannot prompt for MFA")
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: c.CreateAuthenticateChallenge,
PromptConstructor: c.c.MFAPromptConstructor,
}

return c.c.MFAPromptConstructor(promptOpts...).Run(ctx, chal)
return mfaCeremony.Run(ctx, challengeRequest, promptOpts...)
}
66 changes: 41 additions & 25 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,37 @@ import (
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
)

// MFACeremonyClient is a client that can perform an MFA ceremony, from retrieving
// the MFA challenge to prompting for an MFA response from the user.
type MFACeremonyClient interface {
// CreateAuthenticateChallenge creates and returns MFA challenges for a users registered MFA devices.
CreateAuthenticateChallenge(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
// PromptMFA prompts the user for MFA.
PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)
// Ceremony is an MFA ceremony.
type Ceremony struct {
// CreateAuthenticateChallenge creates an authentication challenge.
CreateAuthenticateChallenge func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SolveAuthenticateChallenge solves an authentication challenge. Used in non-interactive settings,
// such as the WebUI with layers abstracting user interaction, and tests.
SolveAuthenticateChallenge func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
}

// PerformMFACeremony retrieves an MFA challenge from the server with the given challenge extensions
// and prompts the user to answer the challenge with the given promptOpts, and ultimately returning
// an MFA challenge response for the user.
func PerformMFACeremony(ctx context.Context, clt MFACeremonyClient, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if challengeRequest == nil {
return nil, trace.BadParameter("missing challenge request")
}

if challengeRequest.ChallengeExtensions == nil {
// Run the MFA ceremony.
//
// req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in
// the moderated session mfa ceremony which uses a custom stream rpc to create challenges.
func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error) {
switch {
case c.CreateAuthenticateChallenge == nil:
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case c.SolveAuthenticateChallenge != nil && c.PromptConstructor != nil:
return nil, trace.BadParameter("mfa ceremony should have SolveAuthenticateChallenge or PromptConstructor set, not both")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
case req.ChallengeExtensions == nil:
return nil, trace.BadParameter("missing challenge extensions")
}

if challengeRequest.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED {
case req.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED:
return nil, trace.BadParameter("mfa challenge scope must be specified")
}

chal, err := clt.CreateAuthenticateChallenge(ctx, challengeRequest)
chal, err := c.CreateAuthenticateChallenge(ctx, req)
if err != nil {
// CreateAuthenticateChallenge returns a bad parameter error when the client
// user is not a Teleport user - for example, the AdminRole. Treat this as an MFA
Expand All @@ -67,21 +72,31 @@ func PerformMFACeremony(ctx context.Context, clt MFACeremonyClient, challengeReq
return nil, &ErrMFANotRequired
}

return clt.PromptMFA(ctx, chal, promptOpts...)
if c.SolveAuthenticateChallenge == nil && c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have SolveAuthenticateChallenge or PromptConstructor set in order to succeed")
}

if c.SolveAuthenticateChallenge != nil {
resp, err := c.SolveAuthenticateChallenge(ctx, chal)
return resp, trace.Wrap(err)
}

resp, err := c.PromptConstructor(promptOpts...).Run(ctx, chal)
return resp, trace.Wrap(err)
}

type MFACeremony func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)
// CeremonyFn is a function that will carry out an MFA ceremony.
type CeremonyFn func(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)

// PerformAdminActionMFACeremony retrieves an MFA challenge from the server for an admin
// action, prompts the user to answer the challenge, and returns the resulting MFA response.
func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony MFACeremony, allowReuse bool) (*proto.MFAAuthenticateResponse, error) {
func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony CeremonyFn, allowReuse bool) (*proto.MFAAuthenticateResponse, error) {
allowReuseExt := mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_NO
if allowReuse {
allowReuseExt = mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_YES
}

challengeRequest := &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{},
MFARequiredCheck: &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_AdminAction{
AdminAction: &proto.AdminAction{},
Expand All @@ -93,5 +108,6 @@ func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony MFACeremony,
},
}

return mfaCeremony(ctx, challengeRequest, WithPromptReasonAdminAction())
resp, err := mfaCeremony(ctx, challengeRequest, WithPromptReasonAdminAction())
return resp, trace.Wrap(err)
}
114 changes: 69 additions & 45 deletions api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"

"github.com/gravitational/teleport/api/client/proto"
Expand All @@ -32,6 +33,9 @@ func TestPerformMFACeremony(t *testing.T) {
t.Parallel()
ctx := context.Background()

testMFAChallenge := &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
}
testMFAResponse := &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Expand All @@ -42,52 +46,103 @@ func TestPerformMFACeremony(t *testing.T) {

for _, tt := range []struct {
name string
ceremonyClient *fakeMFACeremonyClient
ceremony *mfa.Ceremony
assertCeremonyResponse func(*testing.T, *proto.MFAAuthenticateResponse, error, ...interface{})
}{
{
name: "OK ceremony success",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
name: "OK ceremony success prompt",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony success solve",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony not required",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
mfaRequired: proto.MFARequired_MFA_REQUIRED_NO,
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return &proto.MFAAuthenticateChallenge{
MFARequired: proto.MFARequired_MFA_REQUIRED_NO,
}, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.Error(t, err, mfa.ErrMFANotRequired)
assert.Nil(t, mr)
},
}, {
name: "NOK create challenge fail",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
createAuthenticateChallengeErr: errors.New("create authenticate challenge failure"),
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return nil, errors.New("create authenticate challenge failure")
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "create authenticate challenge failure")
assert.Nil(t, mr)
},
}, {
name: "NOK prompt mfa fail",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
promptMFAErr: errors.New("prompt mfa failure"),
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("prompt mfa failure")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "prompt mfa failure")
assert.Nil(t, mr)
},
}, {
name: "NOK solve mfa fail",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("solve mfa failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "solve mfa failure")
assert.Nil(t, mr)
},
},
} {
t.Run(tt.name, func(t *testing.T) {
resp, err := mfa.PerformMFACeremony(ctx, tt.ceremonyClient, &proto.CreateAuthenticateChallengeRequest{
resp, err := tt.ceremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_ADMIN_ACTION,
},
Expand All @@ -97,34 +152,3 @@ func TestPerformMFACeremony(t *testing.T) {
})
}
}

type fakeMFACeremonyClient struct {
createAuthenticateChallengeErr error
promptMFAErr error
mfaRequired proto.MFARequired
challengeResponse *proto.MFAAuthenticateResponse
}

func (c *fakeMFACeremonyClient) CreateAuthenticateChallenge(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
if c.createAuthenticateChallengeErr != nil {
return nil, c.createAuthenticateChallengeErr
}

chal := &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
}

if in.MFARequiredCheck != nil {
chal.MFARequired = c.mfaRequired
}

return chal, nil
}

func (c *fakeMFACeremonyClient) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if c.promptMFAErr != nil {
return nil, c.promptMFAErr
}

return c.challengeResponse, nil
}
2 changes: 1 addition & 1 deletion api/utils/grpc/interceptors/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// to the rpc call when an MFA response is provided through the context. Additionally,
// when the call returns an error that indicates that MFA is required, this interceptor
// will prompt for MFA using the given mfaCeremony and retry.
func WithMFAUnaryInterceptor(mfaCeremony mfa.MFACeremony) grpc.UnaryClientInterceptor {
func WithMFAUnaryInterceptor(mfaCeremony mfa.CeremonyFn) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// Check for MFA response passed through the context.
if mfaResp, err := mfa.MFAResponseFromContext(ctx); err == nil {
Expand Down
19 changes: 9 additions & 10 deletions lib/auth/helpers_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport/api/client/proto"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/mocku2f"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
Expand Down Expand Up @@ -101,24 +102,22 @@ type authClientI interface {
AddMFADeviceSync(context.Context, *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error)
}

func (d *TestDevice) registerDevice(
ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
// Re-authenticate using MFA.
authnChal, err := authClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
func (d *TestDevice) registerDevice(ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(_ context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
},
}

authnSolved, err := mfaCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_MANAGE_DEVICES,
},
})
if err != nil {
return trace.Wrap(err)
}
authnSolved, err := authenticator.SolveAuthn(authnChal)
if err != nil {
return trace.Wrap(err)
}

// Acquire and solve registration challenge.
usage := proto.DeviceUsage_DEVICE_USAGE_MFA
Expand Down
Loading

0 comments on commit 4a59807

Please sign in to comment.