From e0166f29f2972633556e1c4939f43eb4a7cdee98 Mon Sep 17 00:00:00 2001 From: Maksim An Date: Tue, 21 Sep 2021 17:56:48 -0700 Subject: [PATCH] tests: add pmem and scsi unit tests for linear/verity targets Signed-off-by: Maksim An --- internal/guest/storage/pmem/pmem.go | 22 +- internal/guest/storage/pmem/pmem_test.go | 293 +++++++++++++++++++++++ internal/guest/storage/scsi/scsi.go | 37 ++- internal/guest/storage/scsi/scsi_test.go | 98 ++++++++ 4 files changed, 421 insertions(+), 29 deletions(-) diff --git a/internal/guest/storage/pmem/pmem.go b/internal/guest/storage/pmem/pmem.go index 8af7207bad..9291db42f3 100644 --- a/internal/guest/storage/pmem/pmem.go +++ b/internal/guest/storage/pmem/pmem.go @@ -21,9 +21,13 @@ import ( // Test dependencies var ( - osMkdirAll = os.MkdirAll - osRemoveAll = os.RemoveAll - unixMount = unix.Mount + osMkdirAll = os.MkdirAll + osRemoveAll = os.RemoveAll + unixMount = unix.Mount + mountInternal = mount + createLinearTarget = dm.CreateZeroSectorLinearTarget + veritySetup = dm.CreateVerityTarget + removeDevice = dm.RemoveDevice ) const ( @@ -32,8 +36,8 @@ const ( verityDeviceFmt = "dm-verity-pmem%d-%s" ) -// mountInternal mounts source to target via unix.Mount -func mountInternal(ctx context.Context, source, target string) (err error) { +// mount mounts source to target via unix.Mount +func mount(ctx context.Context, source, target string) (err error) { if err := osMkdirAll(target, 0700); err != nil { return err } @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. // device instead of the original VPMem. if mappingInfo != nil { dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) - if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { + if devicePath, err = createLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmLinearName); err != nil { + if err := removeDevice(dmLinearName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName) } } @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot. if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest) - if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil { + if devicePath, err = veritySetup(mCtx, devicePath, dmVerityName, verityInfo); err != nil { return err } defer func() { if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName) } } diff --git a/internal/guest/storage/pmem/pmem_test.go b/internal/guest/storage/pmem/pmem_test.go index d147ede5f2..60b911503d 100644 --- a/internal/guest/storage/pmem/pmem_test.go +++ b/internal/guest/storage/pmem/pmem_test.go @@ -5,6 +5,7 @@ package pmem import ( "context" "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +19,10 @@ func clearTestDependencies() { osMkdirAll = nil osRemoveAll = nil unixMount = nil + createLinearTarget = nil + veritySetup = nil + removeDevice = nil + mountInternal = mount } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -305,3 +310,291 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// device mapper tests +func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName) + createLTCalled := false + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source) + } + return nil + } + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + createLTCalled = true + if source != expectedSource { + t.Errorf("expected createLinearTarget source %s, got %s", expectedSource, source) + } + if name != expectedLinearName { + t.Errorf("expected createLinearTarget name %s, got %s", expectedLinearName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), 0, expectedTarget, mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !createLTCalled { + t.Fatalf("createLinearTarget not called") + } +} + +func Test_VeritySetup_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedSource := "/dev/pmem0" + expectedTarget := "/foo" + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + veritySetupCalled := false + + mountInternal = func(_ context.Context, source, target string) error { + if source != mapperPath { + t.Errorf("expected mountInternal source %s, got %s", mapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target) + } + return nil + } + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + veritySetupCalled = true + if source != expectedSource { + t.Errorf("expected veritySetup source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected veritySetup name %s, got %s", expectedVerityName, name) + } + return mapperPath, nil + } + + if err := Mount( + context.Background(), 0, expectedTarget, nil, verityInfo, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected Mount failure: %s", err) + } + if !veritySetupCalled { + t.Fatal("veritySetup not called") + } +} + +func Test_CreateLinearTarget_And_VeritySetup_Called_Correctly(t *testing.T) { + clearTestDependencies() + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + dmLinearCalled := false + dmVerityCalled := false + mountCalled := false + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + dmLinearCalled = true + if source != expectedPMemDevice { + t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + } + if name != expectedLinearTarget { + t.Errorf("expected createLineartarget name %s, got %s", expectedLinearTarget, name) + } + return mapperLinearPath, nil + } + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + dmVerityCalled = true + if source != mapperLinearPath { + t.Errorf("expected veritySetup source %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected veritySetup target name %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + mountCalled = true + if source != mapperVerityPath { + t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", mapping, verityInfo, openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount call: %s", err) + } + if !dmLinearCalled { + t.Fatal("expected createLinearTarget call") + } + if !dmVerityCalled { + t.Fatal("expected veritySetup call") + } + if !mountCalled { + t.Fatal("expected mountInternal call") + } +} + +func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mappingInfo := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + expectedError := errors.New("mountInternal error") + expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes) + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget) + removeDeviceCalled := false + + createLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, source, target string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedTarget { + t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", mappingInfo, nil, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be callled") + } +} + +func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedError := errors.New("mountInternal error") + mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + removeDeviceCalled := false + + veritySetup = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) { + return mapperPath, nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityTarget { + t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, "/foo", nil, verity, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +} + +func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) { + clearTestDependencies() + + mapping := &prot.DeviceMappingInfo{ + DeviceOffsetInBytes: 0, + DeviceSizeInBytes: 1024, + } + verity := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + expectedError := errors.New("mountInternal error") + expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes) + expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest) + expectedPMemDevice := "/dev/pmem0" + mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget) + mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget) + rmLinearCalled := false + rmVerityCalled := false + + createLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) { + if source != expectedLinearTarget { + t.Errorf("expected createLinearTarget source %s, got %s", expectedPMemDevice, source) + } + return mapperLinearPath, nil + } + veritySetup = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) { + if source != mapperLinearPath { + t.Errorf("expected veritySetup to be called with %s, got %s", mapperLinearPath, source) + } + if name != expectedVerityTarget { + t.Errorf("expected veritySetup target %s, got %s", expectedVerityTarget, name) + } + return mapperVerityPath, nil + } + removeDevice = func(name string) error { + if name != expectedLinearTarget && name != expectedVerityTarget { + t.Errorf("unexpected removeDevice target name %s", name) + } + if name == expectedLinearTarget { + rmLinearCalled = true + } + if name == expectedVerityTarget { + rmVerityCalled = true + } + return nil + } + mountInternal = func(_ context.Context, _, _ string) error { + return expectedError + } + + if err := Mount( + context.Background(), 0, "/foo", mapping, verity, openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !rmLinearCalled { + t.Fatal("expected removeDevice for linear target to be called") + } + if !rmVerityCalled { + t.Fatal("expected removeDevice for verity target to be called") + } +} diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 2591e1121c..a1ac0db085 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -30,6 +30,10 @@ var ( // controllerLunToName is stubbed to make testing `Mount` easier. controllerLunToName = ControllerLunToName + // veritySetup is stubbed for unit testing `Mount` + veritySetup = dm.CreateVerityTarget + // removeDevice is stubbed for unit testing `Mount` + removeDevice = dm.RemoveDevice ) const ( @@ -62,25 +66,8 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b if readonly { // containers only have read-only layers so only enforce for them var deviceHash string - verityHandler := func() error { - return nil - } if verityInfo != nil { deviceHash = verityInfo.RootDigest - verityHandler = func() error { - dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) - if source, err = dm.CreateVerityTarget(ctx, source, dmVerityName, verityInfo); err != nil { - return err - } - defer func() { - if err != nil { - if err := dm.RemoveDevice(dmVerityName); err != nil { - log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") - } - } - }() - return nil - } } err = securityPolicy.EnforceDeviceMountPolicy(target, deviceHash) @@ -88,8 +75,18 @@ func Mount(ctx context.Context, controller, lun uint8, target string, readonly b return errors.Wrapf(err, "won't mount scsi controller %d lun %d onto %s", controller, lun, target) } - if err := verityHandler(); err != nil { - return err + if verityInfo != nil { + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + if source, err = veritySetup(ctx, source, dmVerityName, verityInfo); err != nil { + return err + } + defer func() { + if err != nil { + if err := removeDevice(dmVerityName); err != nil { + log.G(spnCtx).WithError(err).WithField("verityTarget", dmVerityName).Debug("failed to cleanup verity target") + } + } + }() } } @@ -169,7 +166,7 @@ func Unmount(ctx context.Context, controller, lun uint8, target string, encrypte if verityInfo != nil { dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) - if err := dm.RemoveDevice(dmVerityName); err != nil { + if err := removeDevice(dmVerityName); err != nil { return errors.Wrapf(err, "failed to remove dm verity target: %s", dmVerityName) } } diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index 23d3182aed..4b53b62896 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -5,6 +5,8 @@ package scsi import ( "context" "errors" + "fmt" + "github.com/Microsoft/hcsshim/internal/guest/prot" "os" "testing" @@ -18,6 +20,7 @@ func clearTestDependencies() { osRemoveAll = nil unixMount = nil controllerLunToName = nil + veritySetup = nil } func Test_Mount_Mkdir_Fails_Error(t *testing.T) { @@ -501,3 +504,98 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer { func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer { return &policy.MountMonitoringSecurityPolicyEnforcer{} } + +// dm-verity tests +func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { + clearTestDependencies() + + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedSource := "/dev/sdb" + expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) + expectedTarget := "/foo" + veritySetupCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return expectedSource, nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return nil + } + + vInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + veritySetup = func(_ context.Context, source, name string, verityInfo *prot.DeviceVerityInfo) (string, error) { + veritySetupCalled = true + if source != expectedSource { + t.Errorf("expected source %s, got %s", expectedSource, source) + } + if name != expectedVerityName { + t.Errorf("expected verity target name %s, got %s", expectedVerityName, name) + } + return expectedMapperPath, nil + } + + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + if source != expectedMapperPath { + t.Errorf("expected unixMount source %s, got %s", expectedMapperPath, source) + } + if target != expectedTarget { + t.Errorf("expected unixMount target %s, got %s", expectedTarget, target) + } + return nil + } + + if err := Mount( + context.Background(), 0, 0, expectedTarget, true, false, nil, vInfo, + openDoorSecurityPolicyEnforcer(), + ); err != nil { + t.Fatalf("unexpected error during Mount: %s", err) + } + if !veritySetupCalled { + t.Fatalf("expected veritySetup to be called") + } +} + +func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { + clearTestDependencies() + + expectedError := errors.New("osMkdirAll error") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + removeDeviceCalled := false + + controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + return "/dev/sdb", nil + } + + osMkdirAll = func(_ string, _ os.FileMode) error { + return expectedError + } + + verityInfo := &prot.DeviceVerityInfo{ + RootDigest: "hash", + } + + veritySetup = func(_ context.Context, _, _ string, _ *prot.DeviceVerityInfo) (string, error) { + return fmt.Sprintf("/dev/mapper/%s", expectedVerityName), nil + } + + removeDevice = func(name string) error { + removeDeviceCalled = true + if name != expectedVerityName { + t.Errorf("expected RemoveDevice name %s, got %s", expectedVerityName, name) + } + return nil + } + + if err := Mount( + context.Background(), 0, 0, "/foo", true, false, nil, verityInfo, + openDoorSecurityPolicyEnforcer(), + ); err != expectedError { + t.Fatalf("expected Mount error %s, got %s", expectedError, err) + } + if !removeDeviceCalled { + t.Fatal("expected removeDevice to be called") + } +}