diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 5466f3bda68a..b892a1c3c60c 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -3605,6 +3605,11 @@ func onSSH(cf *CLIConf) error { tc.AllowHeadless = true + // Support calling `tsh ssh -- ` (with a double dash before the command) + if len(cf.RemoteCommand) > 0 && strings.TrimSpace(cf.RemoteCommand[0]) == "--" { + cf.RemoteCommand = cf.RemoteCommand[1:] + } + tc.Stdin = os.Stdin err = retryWithAccessRequest(cf, tc, func() error { sshFunc := func() error { diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index c2f96b070a3c..37e06c0ef228 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -2298,6 +2298,183 @@ func TestAccessRequestOnLeaf(t *testing.T) { require.NoError(t, err) } +// TestSSHCommand tests that a user can access a single SSH node and run commands. +func TestSSHCommands(t *testing.T) { + modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + accessRoleName := "access" + sshHostname := "test-ssh-server" + + accessUser, err := types.NewUser(accessRoleName) + require.NoError(t, err) + accessUser.SetRoles([]string{accessRoleName}) + + user, err := user.Current() + require.NoError(t, err) + accessUser.SetLogins([]string{user.Username}) + + traits := map[string][]string{ + constants.TraitLogins: {user.Username}, + } + accessUser.SetTraits(traits) + + connector := mockConnector(t) + rootServerOpts := []testserver.TestServerOptFunc{ + testserver.WithBootstrap(connector, accessUser), + testserver.WithHostname(sshHostname), + testserver.WithClusterName(t, "root"), + testserver.WithSSHLabel(accessRoleName, "true"), + testserver.WithSSHPublicAddrs("127.0.0.1:0"), + testserver.WithConfig(func(cfg *servicecfg.Config) { + cfg.SSH.Enabled = true + cfg.SSH.PublicAddrs = []utils.NetAddr{cfg.SSH.Addr} + cfg.SSH.DisableCreateHostUser = true + }), + } + rootServer := testserver.MakeTestServer(t, rootServerOpts...) + + rootProxyAddr, err := rootServer.ProxyWebAddr() + require.NoError(t, err) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + rootNodes, err := rootServer.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) + if !assert.NoError(t, err) || !assert.Len(t, rootNodes, 1) { + return + } + }, 10*time.Second, 100*time.Millisecond) + + tmpHomePath := t.TempDir() + rootAuth := rootServer.GetAuthServer() + + err = Run(ctx, []string{ + "login", + "--insecure", + "--proxy", rootProxyAddr.String(), + "--user", user.Username, + }, setHomePath(tmpHomePath), setMockSSOLogin(rootAuth, accessUser, connector.GetName())) + require.NoError(t, err) + + tests := []struct { + name string + args []string + expected string + shouldErr bool + }{ + { + // Test that a simple echo works. + name: "ssh simple command", + expected: "this is a test message", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", + "this is a test message", + }, + shouldErr: false, + }, + { + // Test that commands can be prefixed with a double dash. + name: "ssh command with double dash", + expected: "this is a test message", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "--", + "echo", + "this is a test message", + }, + shouldErr: false, + }, + { + // Test that a double dash is not removed from the middle of a command. + name: "ssh command with double dash in the middle", + expected: "-- this is a test message", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", + "--", + "this is a test message", + }, + shouldErr: false, + }, + { + // Test that quoted commands work (e.g. `tsh ssh 'echo test'`) + name: "ssh command literal", + expected: "this is a test message", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo this is a test message", + }, + shouldErr: false, + }, + { + // Test that a double dash is passed as-is in a quoted command (which should fail). + name: "ssh command literal with double dash err", + expected: "", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "-- echo this is a test message", + }, + shouldErr: true, + }, + { + // Test that a double dash is not removed from the middle of a quoted command. + name: "ssh command literal with double dash in the middle", + expected: "-- this is a test message", + args: []string{ + fmt.Sprintf("%s@%s", user.Username, sshHostname), + "echo", "-- this is a test message", + }, + shouldErr: false, + }, + { + // Test tsh ssh -- hostname command + name: "delimiter before host and command", + expected: "this is a test message", + args: []string{ + "--", sshHostname, "echo", "this is a test message", + }, + shouldErr: false, + }, + } + + for _, test := range tests { + test := test + ctx := context.Background() + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + stdout := &output{buf: bytes.Buffer{}} + stderr := &output{buf: bytes.Buffer{}} + args := append( + []string{ + "ssh", + "--insecure", + "--proxy", rootProxyAddr.String(), + }, + test.args..., + ) + + err := Run(ctx, args, setHomePath(tmpHomePath), + func(conf *CLIConf) error { + conf.overrideStdin = &bytes.Buffer{} + conf.OverrideStdout = stdout + conf.overrideStderr = stderr + return nil + }, + ) + + if test.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.expected, strings.TrimSpace(stdout.String())) + require.Empty(t, stderr.String()) + } + }) + } +} + // tryCreateTrustedCluster performs several attempts to create a trusted cluster, // retries on connection problems and access denied errors to let caches // propagate and services to start