diff --git a/guest/storage/pmem/pmem.go b/guest/storage/pmem/pmem.go index 8af7207bad..681659f061 100644 --- a/guest/storage/pmem/pmem.go +++ b/guest/storage/pmem/pmem.go @@ -21,9 +21,13 @@ import ( // Test dependencies var ( - osMkdirAll = os.MkdirAll - osRemoveAll = os.RemoveAll - unixMount = unix.Mount + osMkdirAll = os.MkdirAll + osRemoveAll = os.RemoveAll + unixMount = unix.Mount + mountInternal = mount + createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget + createVerityTarget = dm.CreateVerityTarget + removeDevice = dm.RemoveDevice ) const ( @@ -32,8 +36,8 @@ const ( verityDeviceFmt = "dm-verity-pmem%d-%s" ) -// mountInternal mounts source to target via unix.Mount -func mountInternal(ctx context.Context, source, target string) (err error) { +// mount mounts source to target via unix.Mount +func mount(ctx context.Context, source, target string) (err error) { if err := osMkdirAll(target, 0700); err != nil { return err } @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. // device instead of the original VPMem. if mappingInfo != nil { dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) - if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { + if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmLinearName); err != nil { + if err := removeDevice(dmLinearName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName) } } @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = createVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName) } } diff --git a/guest/storage/pmem/pmem_test.go b/guest/storage/pmem/pmem_test.go index d147ede5f2..be61d70b11 100644 --- a/guest/storage/pmem/pmem_test.go +++ b/guest/storage/pmem/pmem_test.go @@ -5,6 +5,7 @@ package pmem import ( "context" "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +19,10 @@ func clearTestDependencies() { osMkdirAll = nil osRemoveAll = nil unixMount = nil + createZeroSectorLinearTarget = nil + createVerityTarget = nil + removeDevice = nil + mountInternal = mount } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -305,3 +310,321 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// device mapper tests +func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName) + createZSLTCalled := false + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source) + } + return nil + } + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createZSLTCalled = true + if source != expectedSource { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source) + } + if name != expectedLinearName { + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), + 0, + expectedTarget, + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createZSLTCalled { + t.Fatalf("createZeroSectorLinearTarget not called") + } +} + +func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + createVerityTargetCalled := false + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target) + } + return nil + } + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true + if source != expectedSource { + t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), + 0, + expectedTarget, + nil, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected Mount failure: %s", err) + } + if !createVerityTargetCalled { + t.Fatal("createVerityTarget not called") + } +} + +func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + dmLinearCalled := false + dmVerityCalled := false + mountCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + dmLinearCalled = true + if source != expectedPMemDevice { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) + } + if name != expectedLinearTarget { + t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name) + } + return mapperLinearPath, nil + } + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + dmVerityCalled = true + if source != mapperLinearPath { + t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + mountCalled = true + if source != mapperVerityPath { + t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mapping, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount call: %s", err) + } + if !dmLinearCalled { + t.Fatal("expected createZeroSectorLinearTarget call") + } + if !dmVerityCalled { + t.Fatal("expected createVerityTarget call") + } + if !mountCalled { + t.Fatal("expected mountInternal call") + } +} + +func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedError := errors.New("mountInternal error") + expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget) + removeDeviceCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedTarget { + t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mappingInfo, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be callled") + } +} + +func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedError := errors.New("mountInternal error") + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + removeDeviceCalled := false + + createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityTarget { + t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + "/foo", + nil, + verity, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} + +func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedError := errors.New("mountInternal error") + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + rmLinearCalled := false + rmVerityCalled := false + + createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { + if source != expectedPMemDevice { + t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source) + } + return mapperLinearPath, nil + } + createVerityTarget = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + if source != mapperLinearPath { + t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + removeDevice = func(name string) error { + if name != expectedLinearTarget && name != expectedVerityTarget { + t.Errorf("unexpected removeDevice target name %s", name) + } + if name == expectedLinearTarget { + rmLinearCalled = true + } + if name == expectedVerityTarget { + rmVerityCalled = true + } + return nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + + if err := Mount( + context.Background(), + 0, + "/foo", + mapping, + verity, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !rmLinearCalled { + t.Fatal("expected removeDevice for linear target to be called") + } + if !rmVerityCalled { + t.Fatal("expected removeDevice for verity target to be called") + } +} diff --git a/guest/storage/scsi/scsi.go b/guest/storage/scsi/scsi.go index c95ab9d555..fbcebf3754 100644 --- a/guest/storage/scsi/scsi.go +++ b/guest/storage/scsi/scsi.go @@ -5,6 +5,7 @@ package scsi import ( "context" "fmt" + dm "github.com/Microsoft/hcsshim/internal/guest/storage/devicemapper" "io/ioutil" "os" "path/filepath" @@ -29,10 +30,15 @@ var ( // controllerLunToName is stubbed to make testing `Mount` easier. controllerLunToName = ControllerLunToName + // createVerityTarget is stubbed for unit testing `Mount` + createVerityTarget = dm.CreateVerityTarget + // removeDevice is stubbed for unit testing `Mount` + removeDevice = dm.RemoveDevice ) const ( scsiDevicesPath = "/sys/bus/scsi/devices" + verityDeviceFmt = "verity-scsi-contr%d-lun%d-%s" ) // Mount creates a mount from the SCSI device on `controller` index `lun` to @@ -52,16 +58,36 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b trace.Int64Attribute("controller", int64(controller)), trace.Int64Attribute("lun", int64(lun))) + source, err := controllerLunToName(spnCtx, controller, lun) + if err != nil { + return err + } + if readonly { // containers only have read-only layers so only enforce for them var deviceHash string if verityInfo != nil { deviceHash = verityInfo.RootDigest } + err = securityPolicy.EnforceDeviceMountPolicy(target, deviceHash) if err != nil { return errors.Wrapf(err, "won't mount scsi controller %d lun %d onto %s", controller, lun, target) } + + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil { + return err + } + defer func() { + if err != nil { + if err := removeDevice(dmVerityName); err != nil { + log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") + } + } + }() + } } if err := osMkdirAll(target, 0700); err != nil { @@ -72,10 +98,6 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b osRemoveAll(target) } }() - source, err := controllerLunToName(spnCtx, controller, lun) - if err != nil { - return err - } // we only care about readonly mount option when mounting the device var flags uintptr @@ -147,6 +169,14 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte return errors.Wrapf(err, "unmount failed: "+target) } + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) + if err := removeDevice(dmVerityName); err != nil { + // Ignore failures, since the path has been unmounted at this point. + log.G(ctx).WithError(err).Debugf("failed to remove dm verity target: %s", dmVerityName) + } + } + if encrypted { if err := crypt.CleanupCryptDevice(target); err != nil { return errors.Wrapf(err, "failed to cleanup dm-crypt state: "+target) diff --git a/guest/storage/scsi/scsi_test.go b/guest/storage/scsi/scsi_test.go index 99cc77533d..853cb83df7 100644 --- a/guest/storage/scsi/scsi_test.go +++ b/guest/storage/scsi/scsi_test.go @@ -5,6 +5,8 @@ package scsi import ( "context" "errors" + "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil controllerLunToName = nil + createVerityTarget = nil } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -27,8 +30,22 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return expectedErr } - err := Mount(context.Background(), 0, 0, "", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + + controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + return "", nil + } + + if err := Mount( + context.Background(), + 0, + 0, + "", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } } @@ -54,8 +71,18 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -81,8 +108,18 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -108,8 +145,18 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), expectedController, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + expectedController, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil error got: %v", err) } } @@ -135,42 +182,19 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { // Fake the mount success return nil } - err := Mount(context.Background(), 0, expectedLun, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { - t.Fatalf("expected nil error got: %v", err) - } -} - -func Test_Mount_Calls_RemoveAll_OnControllerToLunFailure(t *testing.T) { - clearTestDependencies() - osMkdirAll = func(path string, perm os.FileMode) error { - return nil - } - expectedErr := errors.New("expected controller to lun failure") - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { - return "", expectedErr - } - target := "/fake/path" - removeAllCalled := false - osRemoveAll = func(path string) error { - removeAllCalled = true - if path != target { - t.Errorf("expected path: %v, got: %v", target, path) - return errors.New("unexpected path") - } - return nil - } - - // NOTE: Do NOT set unixMount because the controller to lun fails. Expect it - // not to be called. - - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { - t.Fatalf("expected err: %v, got: %v", expectedErr, err) - } - if !removeAllCalled { - t.Fatal("expected os.RemoveAll to be called on mount failure") + if err := Mount( + context.Background(), + 0, + expectedLun, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("expected nil error got: %v", err) } } @@ -198,8 +222,18 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { // Fake the mount failure to test remove is called return expectedErr } - err := Mount(context.Background(), 0, 0, target, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != expectedErr { + + if err := Mount( + context.Background(), + 0, + 0, + target, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != expectedErr { t.Fatalf("expected err: %v, got: %v", expectedErr, err) } if !removeAllCalled { @@ -253,8 +287,18 @@ func Test_Mount_Valid_Target(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, expectedTarget, false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + expectedTarget, + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -279,8 +323,18 @@ func Test_Mount_Valid_FSType(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -305,8 +359,18 @@ func Test_Mount_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -331,8 +395,18 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -356,8 +430,18 @@ func Test_Mount_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + false, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -382,8 +466,18 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", true, false, nil, nil, openDoorSecurityPolicyEnforcer()) - if err != nil { + + if err := Mount( + context.Background(), + 0, + 0, + "/fake/path", + true, + false, + nil, + nil, + openDoorSecurityPolicyEnforcer(), + ); err != nil { t.Fatalf("expected nil err, got: %v", err) } } @@ -529,3 +623,112 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// dm-verity tests +func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedSource := "/dev/sdb" + expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + expectedTarget := "/foo" + createVerityTargetCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return expectedSource, nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + vInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + createVerityTarget = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { + createVerityTargetCalled = true + if source != expectedSource { + t.Errorf("expected source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected verity target name %s, got %s", expectedVerityName, name) + } + return expectedMapperPath, nil + } + + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + if source != expectedMapperPath { + t.Errorf("expected unixMount source %s, got %s", expectedMapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected unixMount target %s, got %s", expectedTarget, target) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + 0, + expectedTarget, + true, + false, + nil, + vInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createVerityTargetCalled { + t.Fatalf("expected createVerityTargetCalled to be called") + } +} + +func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { + clearTestDependencies() + + expectedError := errors.New("osMkdirAll error") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + removeDeviceCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return "/dev/sdb", nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return expectedError + } + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + + createVerityTarget = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { + return fmt.Sprintf("/dev/mapper/%s", expectedVerityName), nil + } + + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityName { + t.Errorf("expected RemoveDevice name %s, got %s", expectedVerityName, name) + } + return nil + } + + if err := Mount( + context.Background(), + 0, + 0, + "/foo", + true, + false, + nil, + verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} diff --git a/uvm/scsi.go b/uvm/scsi.go index 652cc5b964..f9587e908a 100644 --- a/uvm/scsi.go +++ b/uvm/scsi.go @@ -443,7 +443,7 @@ func (uvm *UtilityVM) addSCSIActual(ctx context.Context, addReq *addSCSIRequest) log.G(ctx).WithFields(logrus.Fields{ "hostPath": sm.HostPath, "rootDigest": v.RootDigest, - }).Debug("adding VPMem with dm-verity") + }).Debug("adding SCSI with dm-verity") } verity = v }