diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 2e2d4411594a..854695f64886 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7535,6 +7535,10 @@ type authProviderMock struct { server types.ServerV2 } +func (mock authProviderMock) ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) { + return nil, nil +} + func (mock authProviderMock) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { return &mock.server, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 5e0913199e63..6b06a53832c3 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -22,6 +22,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net" "net/http" @@ -104,6 +105,7 @@ type UserAuthClient interface { CreateAuthenticateChallenge(ctx context.Context, req *authproto.CreateAuthenticateChallengeRequest) (*authproto.MFAAuthenticateChallenge, error) GenerateUserCerts(ctx context.Context, req authproto.UserCertsRequest) (*authproto.Certs, error) MaintainSessionPresence(ctx context.Context) (authproto.AuthService_MaintainSessionPresenceClient, error) + ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) } // NewTerminal creates a web-based terminal based on WebSockets and returns a @@ -911,6 +913,21 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn, // The close error is ignored instead of using [trace.NewAggregate] because // aggregate errors do not allow error inspection with things like [trace.IsAccessDenied]. _ = conn.Close() + + // Since connection attempts are made via UUID and not hostname, any access denied errors + // will not contain the resolved host address. To provide an easier troubleshooting experience + // for users, attempt to resolve the hostname of the server and augment the error message with it. + if trace.IsAccessDenied(err) { + if resp, err := t.userAuthClient.ListUnifiedResources(ctx, &authproto.ListUnifiedResourcesRequest{ + SortBy: types.SortBy{Field: types.ResourceKind}, + Kinds: []string{types.KindNode}, + Limit: 1, + PredicateExpression: fmt.Sprintf(`resource.metadata.name == "%s"`, t.sessionData.ServerID), + }); err == nil && len(resp.Resources) > 0 { + return nil, trace.AccessDenied("access denied to %q connecting to %v", sshConfig.User, resp.Resources[0].GetNode().GetHostname()) + } + } + return nil, trace.Wrap(err) }