diff --git a/lib/auth/webauthnwin/api.go b/lib/auth/webauthnwin/api.go index 14402fc3e70c6..4e8de09029911 100644 --- a/lib/auth/webauthnwin/api.go +++ b/lib/auth/webauthnwin/api.go @@ -28,6 +28,7 @@ import ( "fmt" "io" "os" + "sync" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol/webauthncose" @@ -166,22 +167,42 @@ func Register(_ context.Context, origin string, cc *wantypes.CredentialCreation) const defaultPromptMessage = "Using platform authenticator, follow the OS dialogs" -var ( - // PromptPlatformMessage is the message shown before system prompts. - PromptPlatformMessage = defaultPromptMessage +// promptPlatformMessage is the message shown before system prompts. +var promptPlatformMessage = struct { + mu sync.Mutex + message string +}{ + message: defaultPromptMessage, +} - // PromptWriter is the writer used for prompt messages. - PromptWriter io.Writer = os.Stderr -) +// PromptWriter is the writer used for prompt messages. +var PromptWriter io.Writer = os.Stderr -// ResetPromptPlatformMessage resets [PromptPlatformMessage] to its original state. +// SetPromptPlatformMessage assigns a new prompt platform message. The prompt +// platform message is shown by [Login] or [Register] when prompting for a +// device touch. +// +// See [ResetPromptPlatformMessage]. +func SetPromptPlatformMessage(message string) { + promptPlatformMessage.mu.Lock() + promptPlatformMessage.message = message + promptPlatformMessage.mu.Unlock() +} + +// ResetPromptPlatformMessage resets the prompt platform message to its original +// state. +// +// See [SetPromptPlatformMessage]. func ResetPromptPlatformMessage() { - PromptPlatformMessage = defaultPromptMessage + SetPromptPlatformMessage(defaultPromptMessage) } func promptPlatform() { - if PromptPlatformMessage != "" { - fmt.Fprintln(PromptWriter, PromptPlatformMessage) + promptPlatformMessage.mu.Lock() + defer promptPlatformMessage.mu.Unlock() + + if msg := promptPlatformMessage.message; msg != "" { + fmt.Fprintln(PromptWriter, msg) } } diff --git a/lib/client/mfa.go b/lib/client/mfa.go index dbb1520a0fd64..3dc94018fa8af 100644 --- a/lib/client/mfa.go +++ b/lib/client/mfa.go @@ -23,13 +23,12 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/mfa" - wancli "github.com/gravitational/teleport/lib/auth/webauthncli" - wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" libmfa "github.com/gravitational/teleport/lib/client/mfa" ) -// WebauthnLoginFunc matches the signature of [wancli.Login]. -type WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) +// WebauthnLoginFunc is a function that performs WebAuthn login. +// Mimics the signature of [webauthncli.Login]. +type WebauthnLoginFunc = libmfa.WebauthnLoginFunc // NewMFAPrompt creates a new MFA prompt from client settings. func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) mfa.Prompt { diff --git a/lib/client/mfa/cli.go b/lib/client/mfa/cli.go index 9058ca2822991..1684578c6e0b3 100644 --- a/lib/client/mfa/cli.go +++ b/lib/client/mfa/cli.go @@ -22,11 +22,13 @@ import ( "context" "fmt" "io" + "runtime" "sync" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils/prompt" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" @@ -65,26 +67,42 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng // Depending on the run opts, we may spawn a TOTP goroutine, webauth goroutine, or both. spawnGoroutines := func(ctx context.Context, wg *sync.WaitGroup, respC chan<- MFAGoroutineResponse) { - // Use variables below to cancel OTP reads and make sure the goroutine exited. - otpCtx, otpCancel := context.WithCancel(ctx) - otpDone := make(chan struct{}) - otpCancelAndWait := func() { - otpCancel() - <-otpDone + dualPrompt := runOpts.PromptTOTP && runOpts.PromptWebauthn + + // Print the prompt message directly here in case of dualPrompt. + // This avoids problems with a goroutine failing before any message is + // printed. + if dualPrompt { + var message string + if runtime.GOOS == constants.WindowsOS { + message = "Follow the OS dialogs for platform authentication, or enter an OTP code here:" + webauthnwin.SetPromptPlatformMessage("") + } else { + message = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix()) + } + fmt.Fprintln(c.writer, message) } // Fire TOTP goroutine. + var otpCancelAndWait func() if runOpts.PromptTOTP { + otpCtx, otpCancel := context.WithCancel(ctx) + otpDone := make(chan struct{}) + otpCancelAndWait = func() { + otpCancel() + <-otpDone + } + wg.Add(1) go func() { - defer wg.Done() - defer otpCancel() - defer close(otpDone) - - // Let Webauthn take the prompt below if applicable. - quiet := c.cfg.Quiet || runOpts.PromptWebauthn - - resp, err := c.promptTOTP(otpCtx, chal, quiet) + defer func() { + wg.Done() + otpCancel() + close(otpDone) + }() + + quiet := c.cfg.Quiet || dualPrompt + resp, err := c.promptTOTP(otpCtx, quiet) respC <- MFAGoroutineResponse{Resp: resp, Err: trace.Wrap(err, "TOTP authentication failed")} }() } @@ -93,11 +111,15 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng if runOpts.PromptWebauthn { wg.Add(1) go func() { - defer wg.Done() + defer func() { + wg.Done() + // Important for dual-prompt, harmless otherwise. + webauthnwin.ResetPromptPlatformMessage() + }() // Get webauthn prompt and wrap with otp context handler. prompt := &webauthnPromptWithOTP{ - LoginPrompt: c.getWebauthnPrompt(ctx, runOpts.PromptTOTP), + LoginPrompt: c.getWebauthnPrompt(ctx, dualPrompt), otpCancelAndWait: otpCancelAndWait, } @@ -110,7 +132,7 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng return HandleMFAPromptGoroutines(ctx, spawnGoroutines) } -func (c *CLIPrompt) promptTOTP(ctx context.Context, chal *proto.MFAAuthenticateChallenge, quiet bool) (*proto.MFAAuthenticateResponse, error) { +func (c *CLIPrompt) promptTOTP(ctx context.Context, quiet bool) (*proto.MFAAuthenticateResponse, error) { var msg string if !quiet { msg = fmt.Sprintf("Enter an OTP code from a %sdevice", c.promptDevicePrefix()) @@ -128,7 +150,7 @@ func (c *CLIPrompt) promptTOTP(ctx context.Context, chal *proto.MFAAuthenticateC }, nil } -func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, withTOTP bool) wancli.LoginPrompt { +func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, dualPrompt bool) wancli.LoginPrompt { writer := c.writer if c.cfg.Quiet { writer = io.Discard @@ -138,13 +160,10 @@ func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, withTOTP bool) wancli prompt.SecondTouchMessage = fmt.Sprintf("Tap your %ssecurity key to complete login", c.promptDevicePrefix()) prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key", c.promptDevicePrefix()) - if withTOTP { - prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix()) - - // Customize Windows prompt directly. - // Note that the platform popup is a modal and will only go away if canceled. - webauthnwin.PromptPlatformMessage = "Follow the OS dialogs for platform authentication, or enter an OTP code here:" - defer webauthnwin.ResetPromptPlatformMessage() + // Skip when both OTP and WebAuthn are possible, as the prompt happens + // externally. + if dualPrompt { + prompt.FirstTouchMessage = "" } return prompt @@ -173,12 +192,25 @@ func (c *CLIPrompt) promptDevicePrefix() string { // authenticators out there. type webauthnPromptWithOTP struct { wancli.LoginPrompt - otpCancelAndWait func() + + otpCancelAndWaitOnce sync.Once + otpCancelAndWait func() } -func (w *webauthnPromptWithOTP) PromptPIN() (string, error) { - // If we get to this stage, Webauthn PIN verification is underway. - // Cancel otp goroutine so that it doesn't capture the PIN from stdin. - w.otpCancelAndWait() - return w.LoginPrompt.PromptPIN() +func (w *webauthnPromptWithOTP) PromptTouch() (wancli.TouchAcknowledger, error) { + ack, err := w.LoginPrompt.PromptTouch() + if err != nil { + return nil, trace.Wrap(err) + } + + return func() error { + err := ack() + + // Stop the OTP goroutine when the first touch is acknowledged. + if w.otpCancelAndWait != nil { + w.otpCancelAndWaitOnce.Do(w.otpCancelAndWait) + } + + return trace.Wrap(err) + }, nil } diff --git a/lib/client/mfa/cli_test.go b/lib/client/mfa/cli_test.go index f177212fddaba..0f0d6cc8323b2 100644 --- a/lib/client/mfa/cli_test.go +++ b/lib/client/mfa/cli_test.go @@ -21,6 +21,7 @@ package mfa_test import ( "bytes" "context" + "errors" "testing" "time" @@ -39,12 +40,13 @@ func TestCLIPrompt(t *testing.T) { ctx := context.Background() for _, tc := range []struct { - name string - stdin string - challenge *proto.MFAAuthenticateChallenge - expectErr error - expectStdOut string - expectResp *proto.MFAAuthenticateResponse + name string + stdin string + challenge *proto.MFAAuthenticateChallenge + expectErr error + expectStdOut string + expectResp *proto.MFAAuthenticateResponse + makeWebauthnLoginFunc func(stdin *prompt.FakeReader) mfa.WebauthnLoginFunc }{ { name: "OK empty challenge", @@ -126,6 +128,102 @@ func TestCLIPrompt(t *testing.T) { }, expectErr: context.DeadlineExceeded, }, + { + name: "OK otp and webauthn with PIN", + challenge: &proto.MFAAuthenticateChallenge{ + TOTP: &proto.TOTPChallenge{}, + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + }, + expectStdOut: `Tap any security key or enter a code from a OTP device +Detected security key tap +Enter your security key PIN: +`, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{ + RawId: []byte{1, 2, 3, 4, 5}, + }, + }, + }, + makeWebauthnLoginFunc: func(stdin *prompt.FakeReader) mfa.WebauthnLoginFunc { + return func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + ack, err := prompt.PromptTouch() + if err != nil { + return nil, "", trace.Wrap(err) + } + + // Ack first (so the OTP goroutine stops)... + if err := ack(); err != nil { + return nil, "", trace.Wrap(err) + } + + // ...then send the PIN to stdin... + const pin = "1234" + stdin.AddString(pin) + + // ...then prompt for the PIN. + switch got, err := prompt.PromptPIN(); { + case err != nil: + return nil, "", trace.Wrap(err) + case got != pin: + return nil, "", errors.New("invalid PIN") + } + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{ + RawId: []byte{1, 2, 3, 4, 5}, + }, + }, + }, "", nil + } + }, + }, + { + name: "OK webauthn with PIN", + challenge: &proto.MFAAuthenticateChallenge{ + TOTP: nil, // no TOTP challenge + WebauthnChallenge: &webauthnpb.CredentialAssertion{}, + }, + stdin: "1234", + expectStdOut: `Tap any security key +Detected security key tap +Enter your security key PIN: +`, + expectResp: &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{ + RawId: []byte{1, 2, 3, 4, 5}, + }, + }, + }, + makeWebauthnLoginFunc: func(_ *prompt.FakeReader) mfa.WebauthnLoginFunc { + return func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + ack, err := prompt.PromptTouch() + if err != nil { + return nil, "", trace.Wrap(err) + } + if err := ack(); err != nil { + return nil, "", trace.Wrap(err) + } + + switch got, err := prompt.PromptPIN(); { + case err != nil: + return nil, "", trace.Wrap(err) + case got != "1234": + return nil, "", errors.New("invalid PIN") + } + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthnpb.CredentialAssertionResponse{ + RawId: []byte{1, 2, 3, 4, 5}, + }, + }, + }, "", nil + } + }, + }, } { t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) @@ -143,17 +241,21 @@ func TestCLIPrompt(t *testing.T) { cfg := mfa.NewPromptConfig("proxy.example.com") cfg.AllowStdinHijack = true cfg.WebauthnSupported = true - cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - if _, err := prompt.PromptTouch(); err != nil { - return nil, "", trace.Wrap(err) - } + if tc.makeWebauthnLoginFunc != nil { + cfg.WebauthnLoginFunc = tc.makeWebauthnLoginFunc(stdin) + } else { + cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + if _, err := prompt.PromptTouch(); err != nil { + return nil, "", trace.Wrap(err) + } - if tc.expectResp.GetWebauthn() == nil { - <-ctx.Done() - return nil, "", trace.Wrap(ctx.Err()) - } + if tc.expectResp.GetWebauthn() == nil { + <-ctx.Done() + return nil, "", trace.Wrap(ctx.Err()) + } - return tc.expectResp, "", nil + return tc.expectResp, "", nil + } } buffer := make([]byte, 0, 100) diff --git a/lib/client/mfa/prompt.go b/lib/client/mfa/prompt.go index 9fe7775118266..93a9645e2faf3 100644 --- a/lib/client/mfa/prompt.go +++ b/lib/client/mfa/prompt.go @@ -25,6 +25,7 @@ import ( "sync" "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/mfa" @@ -32,13 +33,23 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" ) +// WebauthnLoginFunc is a function that performs WebAuthn login. +// Mimics the signature of [wancli.Login]. +type WebauthnLoginFunc func( + ctx context.Context, + origin string, + assertion *wantypes.CredentialAssertion, + prompt wancli.LoginPrompt, + opts *wancli.LoginOpts, +) (*proto.MFAAuthenticateResponse, string, error) + // PromptConfig contains common mfa prompt config options. type PromptConfig struct { mfa.PromptConfig // ProxyAddress is the address of the authenticating proxy. required. ProxyAddress string // WebauthnLoginFunc performs client-side Webauthn login. - WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) + WebauthnLoginFunc WebauthnLoginFunc // AllowStdinHijack allows stdin hijack during MFA prompts. // Stdin hijack provides a better login UX, but it can be difficult to reason // about and is often a source of bugs. @@ -147,6 +158,9 @@ func HandleMFAPromptGoroutines(ctx context.Context, startGoroutines func(context // Surface error immediately. return nil, trace.Wrap(resp.Err) case err != nil: + log. + WithError(err). + Debug("MFA goroutine failed, continuing so other goroutines have a chance to succeed") errs = append(errs, err) // Continue to give the other authn goroutine a chance to succeed. // If both have failed, this will exit the loop. diff --git a/tool/tsh/common/mfa.go b/tool/tsh/common/mfa.go index 73ae0b8024a6e..1ce15bd11ec3a 100644 --- a/tool/tsh/common/mfa.go +++ b/tool/tsh/common/mfa.go @@ -350,8 +350,8 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli // of finding out whether it is a Windows prompt or not). const registeredMsg = "Using platform authentication for *registered* device, follow the OS dialogs" const newMsg = "Using platform authentication for *new* device, follow the OS dialogs" + wanwin.SetPromptPlatformMessage(registeredMsg) defer wanwin.ResetPromptPlatformMessage() - wanwin.PromptPlatformMessage = registeredMsg // Prompt for authentication. // Does nothing if no challenges were issued (aka user has no devices). @@ -371,7 +371,7 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli } // Prompt for registration. - wanwin.PromptPlatformMessage = newMsg + wanwin.SetPromptPlatformMessage(newMsg) registerResp, registerCallback, err := promptRegisterChallenge(ctx, tc.WebProxyAddr, c.devType, registerChallenge) if err != nil { return trace.Wrap(err)