From e43d462a366f56f3ad1604924221a4dd0555e65b Mon Sep 17 00:00:00 2001 From: Maksim An Date: Tue, 19 Oct 2021 17:16:03 -0700 Subject: [PATCH] Extend integrity protection of LCOW layers to SCSI devices (#1170) * extend integrity protection of LCOW layers to SCSI devices LCOW layers can be added both as VPMem and as SCSI devices. Previous work focused on enabling integrity protection for read only VPMem layers, this change enables it for read-only SCSI devices as well. Just like in a VPMem scenario, create dm-verity target when verity information is presented to the guest during SCSI device mounting step. Additionally remove unnecessary unit test, since the guest logic has changed. Add pmem and scsi unit tests for linear/verity device mapper targets Signed-off-by: Maksim An --- guest/storage/pmem/pmem.go | 22 ++- guest/storage/pmem/pmem_test.go | 323 ++++++++++++++++++++++++++++++++ guest/storage/scsi/scsi.go | 38 +++- guest/storage/scsi/scsi_test.go | 317 +++++++++++++++++++++++++------ uvm/scsi.go | 2 +- 5 files changed, 631 insertions(+), 71 deletions(-) diff --git a/guest/storage/pmem/pmem.go b/guest/storage/pmem/pmem.go index 8af7207bad..681659f061 100644 --- a/guest/storage/pmem/pmem.go +++ b/guest/storage/pmem/pmem.go @@ -21,9 +21,13 @@ import ( // Test dependencies var ( - osMkdirAll = os.MkdirAll - osRemoveAll = os.RemoveAll - unixMount = unix.Mount + osMkdirAll = os.MkdirAll + osRemoveAll = os.RemoveAll + unixMount = unix.Mount + mountInternal = mount + createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget + createVerityTarget = dm.CreateVerityTarget + removeDevice = dm.RemoveDevice ) const ( @@ -32,8 +36,8 @@ const ( verityDeviceFmt = "dm-verity-pmem%d-%s" ) -// mountInternal mounts source to target via unix.Mount -func mountInternal(ctx context.Context, source, target string) (err error) { +// mount mounts source to target via unix.Mount +func mount(ctx context.Context, source, target string) (err error) { if err := osMkdirAll(target, 0700); err != nil { return err } @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. // device instead of the original VPMem. if mappingInfo != nil { dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) - if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { + if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmLinearName); err != nil { + if err := removeDevice(dmLinearName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName) } } @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = createVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName) } } diff --git a/guest/storage/pmem/pmem_test.go b/guest/storage/pmem/pmem_test.go index d147ede5f2..be61d70b11 100644 --- a/guest/storage/pmem/pmem_test.go +++ b/guest/storage/pmem/pmem_test.go @@ -5,6 +5,7 @@ package pmem import ( "context" "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +19,10 @@ func clearTestDependencies() { osMkdirAll = nil osRemoveAll = nil unixMount = nil + createZeroSectorLinearTarget = nil + createVerityTarget = nil + removeDevice = nil + mountInternal = mount } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -305,3 +310,321 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// device mapper tests +func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName) + createZSLTCalled := false + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source) + } + return nil + } + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createZSLTCalled = true + if source != expectedSource { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source) + } + if name != expectedLinearName { + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), + 0, + expectedTarget, + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createZSLTCalled { + t.Fatalf("createZeroSectorLinearTarget not called") + } +} + +func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + createVerityTargetCalled := false + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target) + } + return nil + } + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true + if source != expectedSource { + t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), + 0, + expectedTarget, + nil, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected Mount failure: %s", err) + } + if !createVerityTargetCalled { + t.Fatal("createVerityTarget not called") + } +} + +func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + dmLinearCalled := false + dmVerityCalled := false + mountCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + dmLinearCalled = true + if source != expectedPMemDevice { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) + } + if name != expectedLinearTarget { + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name) + } + return mapperLinearPath, nil + } + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + dmVerityCalled = true + if source != mapperLinearPath { + t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + mountCalled = true + if source != mapperVerityPath { + t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mapping, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount call: %s", err) + } + if !dmLinearCalled { + t.Fatal("expected createZeroSectorLinearTarget call") + } + if !dmVerityCalled { + t.Fatal("expected createVerityTarget call") + } + if !mountCalled { + t.Fatal("expected mountInternal call") + } +} + +func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedError := errors.New("mountInternal error") + expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget) + removeDeviceCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedTarget { + t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be callled") + } +} + +func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedError := errors.New("mountInternal error") + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + removeDeviceCalled := false + + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityTarget { + t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + nil, + verity, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} + +func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedError := errors.New("mountInternal error") + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + rmLinearCalled := false + rmVerityCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { + if source != expectedPMemDevice { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) + } + return mapperLinearPath, nil + } + createVerityTarget = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + if source != mapperLinearPath { + t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + removeDevice = func(name string) error { + if name != expectedLinearTarget && name != expectedVerityTarget { + t.Errorf("unexpected removeDevice target name %s", name) + } + if name == expectedLinearTarget { + rmLinearCalled = true + } + if name == expectedVerityTarget { + rmVerityCalled = true + } + return nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mapping, + verity, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !rmLinearCalled { + t.Fatal("expected removeDevice for linear target to be called") + } + if !rmVerityCalled { + t.Fatal("expected removeDevice for verity target to be called") + } +} diff --git a/guest/storage/scsi/scsi.go b/guest/storage/scsi/scsi.go index c95ab9d555..fbcebf3754 100644 --- a/guest/storage/scsi/scsi.go +++ b/guest/storage/scsi/scsi.go @@ -5,6 +5,7 @@ package scsi import ( "context" "fmt" + dm "github.com/Microsoft/hcsshim/internal/guest/storage/devicemapper" "io/ioutil" "os" "path/filepath" @@ -29,10 +30,15 @@ var ( // controllerLunToName is stubbed to make testing `Mount` easier. controllerLunToName = ControllerLunToName + // createVerityTarget is stubbed for unit testing `Mount` + createVerityTarget = dm.CreateVerityTarget + // removeDevice is stubbed for unit testing `Mount` + removeDevice = dm.RemoveDevice ) const ( scsiDevicesPath = "/sys/bus/scsi/devices" + verityDeviceFmt = "verity-scsi-contr%d-lun%d-%s" ) // Mount creates a mount from the SCSI device on `controller` index `lun` to @@ -52,16 +58,36 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b trace.Int64Attribute("controller", int64(controller)), trace.Int64Attribute("lun", int64(lun))) + source, err := controllerLunToName(spnCtx, controller, lun) + if err != nil { + return err + } + if readonly { // containers only have read-only layers so only enforce for them var deviceHash string if verityInfo != nil { deviceHash = verityInfo.RootDigest } + err = securityPolicy.EnforceDeviceMountPolicy(target, deviceHash) if err != nil { return errors.Wrapf(err, "won't mount scsi controller %d lun %d onto %s", controller, lun, target) } + + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil { + return err + } + defer func() { + if err != nil { + if err := removeDevice(dmVerityName); err != nil { + log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") + } + } + }() + } } if err := osMkdirAll(target, 0700); err != nil { @@ -72,10 +98,6 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b osRemoveAll(target) } }() - source, err := controllerLunToName(spnCtx, controller, lun) - if err != nil { - return err - } // we only care about readonly mount option when mounting the device var flags uintptr @@ -147,6 +169,14 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte return errors.Wrapf(err, "unmount failed: "+target) } + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) + if err := removeDevice(dmVerityName); err != nil { + // Ignore failures, since the path has been unmounted at this point. + log.G(ctx).WithError(err).Debugf("failed to remove dm verity target: %s", dmVerityName) + } + } + if encrypted { if err := crypt.CleanupCryptDevice(target); err != nil { return errors.Wrapf(err, "failed to cleanup dm-crypt state: "+target) diff --git a/guest/storage/scsi/scsi_test.go b/guest/storage/scsi/scsi_test.go index 99cc77533d..853cb83df7 100644 --- a/guest/storage/scsi/scsi_test.go +++ b/guest/storage/scsi/scsi_test.go @@ -5,6 +5,8 @@ package scsi import ( "context" "errors" + "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil controllerLunToName = nil + createVerityTarget = nil } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -27,8 +30,22 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } - err := Mount(context.Background(), 0, 0, "", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + + controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + return "", nil + } + + if err := Mount( + context.Background(), + 0, + 0, + "", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } } @@ -54,8 +71,18 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -81,8 +108,18 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -108,8 +145,18 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), expectedController, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + expectedController, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -135,42 +182,19 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { - t.Fatalf("expected nil error got: %v", err) - } -} - -func Test_Mount_Calls_RemoveAll_OnControllerToLunFailure(t *testing.T) { - clearTestDependencies() - osMkdirAll = func(path string, perm os.FileMode) error { - return nil - } - expectedErr := errors.New("expected controller to lun failure") - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { - return "", expectedErr - } - target := "/fake/path" - removeAllCalled := false - osRemoveAll = func(path string) error { - removeAllCalled = true - if path != target { - t.Errorf("expected path: %v, got: %v", target, path) - return errors.New("unexpected path") - } - return nil - } - - // NOTE: Do NOT set unixMount because the controller to lun fails. Expect it - // not to be called. - - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { - t.Fatalf("expected err: %v, got: %v", expectedErr, err) - } - if !removeAllCalled { - t.Fatal("expected os.RemoveAll to be called on mount failure") + if err := Mount( + context.Background(), + 0, + expectedLun, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("expected nil error got: %v", err) } } @@ -198,8 +222,18 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } if !removeAllCalled { @@ -253,8 +287,18 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, expectedTarget, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + expectedTarget, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -279,8 +323,18 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -305,8 +359,18 @@ func Test_Mount_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -331,8 +395,18 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -356,8 +430,18 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -382,8 +466,18 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -529,3 +623,112 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// dm-verity tests +func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedSource := "/dev/sdb" + expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + expectedTarget := "/foo" + createVerityTargetCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return expectedSource, nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + vInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + createVerityTarget = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true + if source != expectedSource { + t.Errorf("expected source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected verity target name %s, got %s", expectedVerityName, name) + } + return expectedMapperPath, nil + } + + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + if source != expectedMapperPath { + t.Errorf("expected unixMount source %s, got %s", expectedMapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected unixMount target %s, got %s", expectedTarget, target) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + 0, + expectedTarget, + true, + false, + nil, + vInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createVerityTargetCalled { + t.Fatalf("expected createVerityTargetCalled to be called") + } +} + +func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { + clearTestDependencies() + + expectedError := errors.New("osMkdirAll error") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + removeDeviceCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return "/dev/sdb", nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return expectedError + } + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + + createVerityTarget = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { + return fmt.Sprintf("/dev/mapper/%s", expectedVerityName), nil + } + + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityName { + t.Errorf("expected RemoveDevice name %s, got %s", expectedVerityName, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + 0, + "/foo", + true, + false, + nil, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} diff --git a/uvm/scsi.go b/uvm/scsi.go index 652cc5b964..f9587e908a 100644 --- a/uvm/scsi.go +++ b/uvm/scsi.go @@ -443,7 +443,7 @@ func (uvm *UtilityVM) addSCSIActual(ctx context.Context, addReq *addSCSIRequest) log.G(ctx).WithFields(logrus.Fields{ "hostPath": sm.HostPath, "rootDigest": v.RootDigest, - }).Debug("adding VPMem with dm-verity") + }).Debug("adding SCSI with dm-verity") } verity = v }