Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enforce device trust on OSS processes #46940

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions lib/devicetrust/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
package authz

import (
"sync"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/devicetrust/config"
dtconfig "github.com/gravitational/teleport/lib/devicetrust/config"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -73,9 +71,7 @@ func VerifySSHUser(dt *types.DeviceTrust, cert *ssh.Certificate) error {
}

func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified bool) error {
mode := config.GetEffectiveMode(dt)
maybeLogModeMismatch(mode, dt)

mode := dtconfig.GetEnforcementMode(dt)
switch {
case mode != constants.DeviceTrustModeRequired:
return nil // OK, extensions not enforced.
Expand All @@ -88,15 +84,3 @@ func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified boo
return nil
}
}

var logModeOnce sync.Once

func maybeLogModeMismatch(effective string, dt *types.DeviceTrust) {
if dt == nil || dt.Mode == "" || effective == dt.Mode {
return
}

logModeOnce.Do(func() {
log.Warnf("Device Trust: mode %q requires Teleport Enterprise. Using effective mode %q.", dt.Mode, effective)
})
}
6 changes: 3 additions & 3 deletions lib/devicetrust/authz/authz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ func runVerifyUserTest(t *testing.T, method string, verify func(dt *types.Device
assertErr: assertNoErr,
},
{
name: "OSS mode never enforced",
name: "OSS mode=required (Enterprise Auth)",
buildType: modules.BuildOSS,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired, // Invalid for OSS, treated as "off".
Mode: constants.DeviceTrustModeRequired,
},
ext: userWithoutExtensions,
assertErr: assertNoErr,
assertErr: assertDeniedErr,
},
{
name: "Enterprise mode=off",
Expand Down
12 changes: 12 additions & 0 deletions lib/devicetrust/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ func GetEffectiveMode(dt *types.DeviceTrust) string {
return dt.Mode
}

// GetEnforcementMode returns the configured device trust mode, disregarding the
// provenance of the binary if the mode is set.
// Used for device enforcement checks. Guarantees that OSS binaries paired with
// an Enterprise Auth will correctly enforce device trust.
func GetEnforcementMode(dt *types.DeviceTrust) string {
// If absent use the defaults from GetEffectiveMode.
if dt == nil || dt.Mode == "" {
return GetEffectiveMode(dt)
}
return dt.Mode
}

// ValidateConfigAgainstModules verifies the device trust configuration against
// the current modules.
// This method exists to provide feedback to users about invalid configurations,
Expand Down
58 changes: 58 additions & 0 deletions lib/devicetrust/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
)

func TestValidateConfigAgainstModules(t *testing.T) {
// Don't t.Parallel, depends on modules.SetTestModules.

type testCase struct {
name string
buildType string
Expand Down Expand Up @@ -110,3 +112,59 @@ func TestValidateConfigAgainstModules(t *testing.T) {
})
}
}

func TestGetEnforcementMode(t *testing.T) {
// Don't t.Parallel, depends on modules.SetTestModules.

tests := []struct {
name string
buildType string
dt *types.DeviceTrust
want string
}{
{
name: "OSS default",
buildType: modules.BuildOSS,
want: constants.DeviceTrustModeOff,
},
{
name: "Enterprise default",
buildType: modules.BuildEnterprise,
want: constants.DeviceTrustModeOptional,
},
{
name: "dt.Mode empty",
buildType: modules.BuildEnterprise,
dt: &types.DeviceTrust{
Mode: "",
},
want: constants.DeviceTrustModeOptional,
},
{
name: "dt.Mode set",
buildType: modules.BuildEnterprise,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired,
},
want: constants.DeviceTrustModeRequired,
},
{
name: "OSS node with Ent Auth",
buildType: modules.BuildOSS,
dt: &types.DeviceTrust{
Mode: constants.DeviceTrustModeRequired,
},
want: constants.DeviceTrustModeRequired,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{
TestBuildType: test.buildType,
})

got := dtconfig.GetEnforcementMode(test.dt)
assert.Equal(t, test.want, got, "dtconfig.GetEnforcementMode mismatch")
})
}
}
19 changes: 9 additions & 10 deletions lib/srv/session_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ func TestSessionController_AcquireSessionContext(t *testing.T) {
}
return idCtx
}
assertTrustedDeviceRequired := func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error")
assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err)
}

cases := []struct {
name string
Expand Down Expand Up @@ -451,22 +455,17 @@ func TestSessionController_AcquireSessionContext(t *testing.T) {
},
},
{
name: "device extensions not enforced for OSS",
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.NoError(t, err, "AcquireSessionContext returned an unexpected error")
},
name: "device extensions enforced for OSS",
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: assertTrustedDeviceRequired,
},
{
name: "device extensions enforced for Enterprise",
buildType: modules.BuildEnterprise,
cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired),
identity: minimalIdentity,
assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) {
assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error")
assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err)
},
assertion: assertTrustedDeviceRequired,
},
{
name: "device extensions valid for Enterprise",
Expand Down
Loading