From a260bf60f1e7f4220903ca60f8fe394fff1045fe Mon Sep 17 00:00:00 2001 From: Alan Parra Date: Wed, 2 Oct 2024 18:34:47 -0300 Subject: [PATCH] Fix races on wanwin.PromptPlatformMessage --- lib/auth/webauthnwin/api.go | 39 +++++++++++++++++++++++++++---------- lib/client/mfa/cli.go | 2 +- tool/tsh/common/mfa.go | 4 ++-- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/lib/auth/webauthnwin/api.go b/lib/auth/webauthnwin/api.go index 14402fc3e70c6..74888021ee994 100644 --- a/lib/auth/webauthnwin/api.go +++ b/lib/auth/webauthnwin/api.go @@ -28,6 +28,7 @@ import ( "fmt" "io" "os" + "sync" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol/webauthncose" @@ -166,22 +167,40 @@ func Register(_ context.Context, origin string, cc *wantypes.CredentialCreation) const defaultPromptMessage = "Using platform authenticator, follow the OS dialogs" -var ( - // PromptPlatformMessage is the message shown before system prompts. - PromptPlatformMessage = defaultPromptMessage +// promptPlatformMessage is the message shown before system prompts. +var promptPlatformMessage = struct { + mu sync.Mutex + message string +}{} - // PromptWriter is the writer used for prompt messages. - PromptWriter io.Writer = os.Stderr -) +// PromptWriter is the writer used for prompt messages. +var PromptWriter io.Writer = os.Stderr + +// SetPromptPlatformMessage assigns a new prompt platform message. The prompt +// platform message is shown by [Login] or [Register] when prompting for a +// device touch. +// +// See [ResetPromptPlatformMessage]. +func SetPromptPlatformMessage(message string) { + promptPlatformMessage.mu.Lock() + promptPlatformMessage.message = message + promptPlatformMessage.mu.Unlock() +} -// ResetPromptPlatformMessage resets [PromptPlatformMessage] to its original state. +// ResetPromptPlatformMessage resets the prompt platform message to its original +// state. +// +// See [SetPromptPlatformMessage]. func ResetPromptPlatformMessage() { - PromptPlatformMessage = defaultPromptMessage + SetPromptPlatformMessage(defaultPromptMessage) } func promptPlatform() { - if PromptPlatformMessage != "" { - fmt.Fprintln(PromptWriter, PromptPlatformMessage) + promptPlatformMessage.mu.Lock() + defer promptPlatformMessage.mu.Unlock() + + if msg := promptPlatformMessage.message; msg != "" { + fmt.Fprintln(PromptWriter, msg) } } diff --git a/lib/client/mfa/cli.go b/lib/client/mfa/cli.go index e6558472838d3..1684578c6e0b3 100644 --- a/lib/client/mfa/cli.go +++ b/lib/client/mfa/cli.go @@ -76,7 +76,7 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng var message string if runtime.GOOS == constants.WindowsOS { message = "Follow the OS dialogs for platform authentication, or enter an OTP code here:" - webauthnwin.PromptPlatformMessage = "" + webauthnwin.SetPromptPlatformMessage("") } else { message = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix()) } diff --git a/tool/tsh/common/mfa.go b/tool/tsh/common/mfa.go index 744fdb4ff1a1a..c540cce3b1228 100644 --- a/tool/tsh/common/mfa.go +++ b/tool/tsh/common/mfa.go @@ -340,8 +340,8 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli // TODO(Joerger): this should live in lib/client/mfa/cli.go using the prompt device prefix. const registeredMsg = "Using platform authentication for *registered* device, follow the OS dialogs" const newMsg = "Using platform authentication for *new* device, follow the OS dialogs" + wanwin.SetPromptPlatformMessage(registeredMsg) defer wanwin.ResetPromptPlatformMessage() - wanwin.PromptPlatformMessage = registeredMsg mfaResp, err := tc.NewMFACeremony().Run(ctx, &proto.CreateAuthenticateChallengeRequest{ ChallengeExtensions: &mfav1.ChallengeExtensions{ @@ -363,7 +363,7 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli } // Prompt for registration. - wanwin.PromptPlatformMessage = newMsg + wanwin.SetPromptPlatformMessage(newMsg) registerResp, registerCallback, err := promptRegisterChallenge(ctx, tc.WebProxyAddr, c.devType, registerChallenge) if err != nil { return trace.Wrap(err)