diff --git a/lib/devicetrust/authz/authz.go b/lib/devicetrust/authz/authz.go index 5a73a1ce602e..a95a8f13a4e4 100644 --- a/lib/devicetrust/authz/authz.go +++ b/lib/devicetrust/authz/authz.go @@ -15,8 +15,6 @@ package authz import ( - "sync" - "github.com/gravitational/trace" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" @@ -24,7 +22,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/devicetrust/config" + dtconfig "github.com/gravitational/teleport/lib/devicetrust/config" "github.com/gravitational/teleport/lib/tlsca" ) @@ -63,9 +61,7 @@ func VerifySSHUser(dt *types.DeviceTrust, cert *ssh.Certificate) error { } func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified bool) error { - mode := config.GetEffectiveMode(dt) - maybeLogModeMismatch(mode, dt) - + mode := dtconfig.GetEnforcementMode(dt) switch { case mode != constants.DeviceTrustModeRequired: return nil // OK, extensions not enforced. @@ -78,15 +74,3 @@ func verifyDeviceExtensions(dt *types.DeviceTrust, username string, verified boo return nil } } - -var logModeOnce sync.Once - -func maybeLogModeMismatch(effective string, dt *types.DeviceTrust) { - if dt == nil || dt.Mode == "" || effective == dt.Mode { - return - } - - logModeOnce.Do(func() { - log.Warnf("Device Trust: mode %q requires Teleport Enterprise. Using effective mode %q.", dt.Mode, effective) - }) -} diff --git a/lib/devicetrust/authz/authz_test.go b/lib/devicetrust/authz/authz_test.go index 469bb3dc5fc5..30cabf7abcfc 100644 --- a/lib/devicetrust/authz/authz_test.go +++ b/lib/devicetrust/authz/authz_test.go @@ -170,13 +170,13 @@ func runVerifyUserTest(t *testing.T, method string, verify func(dt *types.Device assertErr: assertNoErr, }, { - name: "OSS mode never enforced", + name: "OSS mode=required (Enterprise Auth)", buildType: modules.BuildOSS, dt: &types.DeviceTrust{ - Mode: constants.DeviceTrustModeRequired, // Invalid for OSS, treated as "off". + Mode: constants.DeviceTrustModeRequired, }, ext: userWithoutExtensions, - assertErr: assertNoErr, + assertErr: assertDeniedErr, }, { name: "Enterprise mode=off", diff --git a/lib/devicetrust/config/config.go b/lib/devicetrust/config/config.go index 4db5737d1736..f1c91b3b569d 100644 --- a/lib/devicetrust/config/config.go +++ b/lib/devicetrust/config/config.go @@ -38,6 +38,18 @@ func GetEffectiveMode(dt *types.DeviceTrust) string { return dt.Mode } +// GetEnforcementMode returns the configured device trust mode, disregarding the +// provenance of the binary if the mode is set. +// Used for device enforcement checks. Guarantees that OSS binaries paired with +// an Enterprise Auth will correctly enforce device trust. +func GetEnforcementMode(dt *types.DeviceTrust) string { + // If absent use the defaults from GetEffectiveMode. + if dt == nil || dt.Mode == "" { + return GetEffectiveMode(dt) + } + return dt.Mode +} + // ValidateConfigAgainstModules verifies the device trust configuration against // the current modules. // This method exists to provide feedback to users about invalid configurations, diff --git a/lib/devicetrust/config/config_test.go b/lib/devicetrust/config/config_test.go index 1082402f8636..075b0cc457c6 100644 --- a/lib/devicetrust/config/config_test.go +++ b/lib/devicetrust/config/config_test.go @@ -28,6 +28,8 @@ import ( ) func TestValidateConfigAgainstModules(t *testing.T) { + // Don't t.Parallel, depends on modules.SetTestModules. + type testCase struct { name string buildType string @@ -106,3 +108,59 @@ func TestValidateConfigAgainstModules(t *testing.T) { }) } } + +func TestGetEnforcementMode(t *testing.T) { + // Don't t.Parallel, depends on modules.SetTestModules. + + tests := []struct { + name string + buildType string + dt *types.DeviceTrust + want string + }{ + { + name: "OSS default", + buildType: modules.BuildOSS, + want: constants.DeviceTrustModeOff, + }, + { + name: "Enterprise default", + buildType: modules.BuildEnterprise, + want: constants.DeviceTrustModeOptional, + }, + { + name: "dt.Mode empty", + buildType: modules.BuildEnterprise, + dt: &types.DeviceTrust{ + Mode: "", + }, + want: constants.DeviceTrustModeOptional, + }, + { + name: "dt.Mode set", + buildType: modules.BuildEnterprise, + dt: &types.DeviceTrust{ + Mode: constants.DeviceTrustModeRequired, + }, + want: constants.DeviceTrustModeRequired, + }, + { + name: "OSS node with Ent Auth", + buildType: modules.BuildOSS, + dt: &types.DeviceTrust{ + Mode: constants.DeviceTrustModeRequired, + }, + want: constants.DeviceTrustModeRequired, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + modules.SetTestModules(t, &modules.TestModules{ + TestBuildType: test.buildType, + }) + + got := dtconfig.GetEnforcementMode(test.dt) + assert.Equal(t, test.want, got, "dtconfig.GetEnforcementMode mismatch") + }) + } +} diff --git a/lib/srv/session_control_test.go b/lib/srv/session_control_test.go index d01894a2e3eb..ecdd6a9dac3e 100644 --- a/lib/srv/session_control_test.go +++ b/lib/srv/session_control_test.go @@ -159,6 +159,10 @@ func TestSessionController_AcquireSessionContext(t *testing.T) { } return idCtx } + assertTrustedDeviceRequired := func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) { + assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error") + assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err) + } cases := []struct { name string @@ -447,22 +451,17 @@ func TestSessionController_AcquireSessionContext(t *testing.T) { }, }, { - name: "device extensions not enforced for OSS", - cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired), - identity: minimalIdentity, - assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) { - assert.NoError(t, err, "AcquireSessionContext returned an unexpected error") - }, + name: "device extensions enforced for OSS", + cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired), + identity: minimalIdentity, + assertion: assertTrustedDeviceRequired, }, { name: "device extensions enforced for Enterprise", buildType: modules.BuildEnterprise, cfg: cfgWithDeviceMode(constants.DeviceTrustModeRequired), identity: minimalIdentity, - assertion: func(t *testing.T, _ context.Context, err error, _ *eventstest.MockRecorderEmitter) { - assert.ErrorContains(t, err, "device", "AcquireSessionContext returned an unexpected error") - assert.True(t, trace.IsAccessDenied(err), "AcquireSessionContext returned an error other than trace.AccessDeniedError: %T", err) - }, + assertion: assertTrustedDeviceRequired, }, { name: "device extensions valid for Enterprise",