Skip to content

Commit

Permalink
Extend integrity protection of LCOW layers to SCSI devices (microsoft…
Browse files Browse the repository at this point in the history
…#1170)

* extend integrity protection of LCOW layers to SCSI devices

LCOW layers can be added both as VPMem and as SCSI devices.
Previous work focused on enabling integrity protection for read
only VPMem layers, this change enables it for read-only SCSI
devices as well.
Just like in a VPMem scenario, create dm-verity target when
verity information is presented to the guest during SCSI device
mounting step.

Additionally remove unnecessary unit test, since the guest logic
has changed.

Add pmem and scsi unit tests for linear/verity device mapper
targets

Signed-off-by: Maksim An <maksiman@microsoft.com>
  • Loading branch information
anmaxvl authored Oct 20, 2021
1 parent 13c6c9a commit e43d462
Show file tree
Hide file tree
Showing 5 changed files with 631 additions and 71 deletions.
22 changes: 13 additions & 9 deletions guest/storage/pmem/pmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ import (

// Test dependencies
var (
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
osMkdirAll = os.MkdirAll
osRemoveAll = os.RemoveAll
unixMount = unix.Mount
mountInternal = mount
createZeroSectorLinearTarget = dm.CreateZeroSectorLinearTarget
createVerityTarget = dm.CreateVerityTarget
removeDevice = dm.RemoveDevice
)

const (
Expand All @@ -32,8 +36,8 @@ const (
verityDeviceFmt = "dm-verity-pmem%d-%s"
)

// mountInternal mounts source to target via unix.Mount
func mountInternal(ctx context.Context, source, target string) (err error) {
// mount mounts source to target via unix.Mount
func mount(ctx context.Context, source, target string) (err error) {
if err := osMkdirAll(target, 0700); err != nil {
return err
}
Expand Down Expand Up @@ -89,12 +93,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.
// device instead of the original VPMem.
if mappingInfo != nil {
dmLinearName := fmt.Sprintf(linearDeviceFmt, device, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
if devicePath, err = dm.CreateZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
if devicePath, err = createZeroSectorLinearTarget(mCtx, devicePath, dmLinearName, mappingInfo); err != nil {
return err
}
defer func() {
if err != nil {
if err := dm.RemoveDevice(dmLinearName); err != nil {
if err := removeDevice(dmLinearName); err != nil {
log.G(mCtx).WithError(err).Debugf("failed to cleanup linear target: %s", dmLinearName)
}
}
Expand All @@ -103,12 +107,12 @@ func Mount(ctx context.Context, device uint32, target string, mappingInfo *prot.

if verityInfo != nil {
dmVerityName := fmt.Sprintf(verityDeviceFmt, device, verityInfo.RootDigest)
if devicePath, err = dm.CreateVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
if devicePath, err = createVerityTarget(mCtx, devicePath, dmVerityName, verityInfo); err != nil {
return err
}
defer func() {
if err != nil {
if err := dm.RemoveDevice(dmVerityName); err != nil {
if err := removeDevice(dmVerityName); err != nil {
log.G(mCtx).WithError(err).Debugf("failed to cleanup verity target: %s", dmVerityName)
}
}
Expand Down
323 changes: 323 additions & 0 deletions guest/storage/pmem/pmem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package pmem
import (
"context"
"fmt"
"github.com/Microsoft/hcsshim/internal/guest/prot"
"os"
"testing"

Expand All @@ -18,6 +19,10 @@ func clearTestDependencies() {
osMkdirAll = nil
osRemoveAll = nil
unixMount = nil
createZeroSectorLinearTarget = nil
createVerityTarget = nil
removeDevice = nil
mountInternal = mount
}

func Test_Mount_Mkdir_Fails_Error(t *testing.T) {
Expand Down Expand Up @@ -305,3 +310,321 @@ func openDoorSecurityPolicyEnforcer() securitypolicy.SecurityPolicyEnforcer {
func mountMonitoringSecurityPolicyEnforcer() *policy.MountMonitoringSecurityPolicyEnforcer {
return &policy.MountMonitoringSecurityPolicyEnforcer{}
}

// device mapper tests
func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

mappingInfo := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
createZSLTCalled := false

osMkdirAll = func(_ string, _ os.FileMode) error {
return nil
}

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
}
if target != expectedTarget {
t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source)
}
return nil
}

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
createZSLTCalled = true
if source != expectedSource {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
}
if name != expectedLinearName {
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(),
0,
expectedTarget,
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount: %s", err)
}
if !createZSLTCalled {
t.Fatalf("createZeroSectorLinearTarget not called")
}
}

func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
expectedSource := "/dev/pmem0"
expectedTarget := "/foo"
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
createVerityTargetCalled := false

mountInternal = func(_ context.Context, source, target string) error {
if source != mapperPath {
t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
}
if target != expectedTarget {
t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target)
}
return nil
}
createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
createVerityTargetCalled = true
if source != expectedSource {
t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source)
}
if name != expectedVerityName {
t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name)
}
return mapperPath, nil
}

if err := Mount(
context.Background(),
0,
expectedTarget,
nil,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected Mount failure: %s", err)
}
if !createVerityTargetCalled {
t.Fatal("createVerityTarget not called")
}
}

func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
clearTestDependencies()

verityInfo := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
mapping := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
expectedPMemDevice := "/dev/pmem0"
mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
dmLinearCalled := false
dmVerityCalled := false
mountCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
dmLinearCalled = true
if source != expectedPMemDevice {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
if name != expectedLinearTarget {
t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
}
return mapperLinearPath, nil
}
createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
dmVerityCalled = true
if source != mapperLinearPath {
t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
mountCalled = true
if source != mapperVerityPath {
t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
mapping,
verityInfo,
openDoorSecurityPolicyEnforcer(),
); err != nil {
t.Fatalf("unexpected error during Mount call: %s", err)
}
if !dmLinearCalled {
t.Fatal("expected createZeroSectorLinearTarget call")
}
if !dmVerityCalled {
t.Fatal("expected createVerityTarget call")
}
if !mountCalled {
t.Fatal("expected mountInternal call")
}
}

func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

mappingInfo := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
expectedError := errors.New("mountInternal error")
expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
removeDeviceCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *prot.DeviceMappingInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, source, target string) error {
return expectedError
}
removeDevice = func(name string) error {
removeDeviceCalled = true
if name != expectedTarget {
t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
mappingInfo,
nil,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !removeDeviceCalled {
t.Fatal("expected removeDevice to be callled")
}
}

func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

verity := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
expectedError := errors.New("mountInternal error")
mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
removeDeviceCalled := false

createVerityTarget = func(_ context.Context, source, name string, verity *prot.DeviceVerityInfo) (string, error) {
return mapperPath, nil
}
mountInternal = func(_ context.Context, _, _ string) error {
return expectedError
}
removeDevice = func(name string) error {
removeDeviceCalled = true
if name != expectedVerityTarget {
t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name)
}
return nil
}

if err := Mount(
context.Background(),
0,
"/foo",
nil,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !removeDeviceCalled {
t.Fatal("expected removeDevice to be called")
}
}

func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) {
clearTestDependencies()

mapping := &prot.DeviceMappingInfo{
DeviceOffsetInBytes: 0,
DeviceSizeInBytes: 1024,
}
verity := &prot.DeviceVerityInfo{
RootDigest: "hash",
}
expectedError := errors.New("mountInternal error")
expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
expectedPMemDevice := "/dev/pmem0"
mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
rmLinearCalled := false
rmVerityCalled := false

createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *prot.DeviceMappingInfo) (string, error) {
if source != expectedPMemDevice {
t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
}
return mapperLinearPath, nil
}
createVerityTarget = func(_ context.Context, source, name string, v *prot.DeviceVerityInfo) (string, error) {
if source != mapperLinearPath {
t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source)
}
if name != expectedVerityTarget {
t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name)
}
return mapperVerityPath, nil
}
removeDevice = func(name string) error {
if name != expectedLinearTarget && name != expectedVerityTarget {
t.Errorf("unexpected removeDevice target name %s", name)
}
if name == expectedLinearTarget {
rmLinearCalled = true
}
if name == expectedVerityTarget {
rmVerityCalled = true
}
return nil
}
mountInternal = func(_ context.Context, _, _ string) error {
return expectedError
}

if err := Mount(
context.Background(),
0,
"/foo",
mapping,
verity,
openDoorSecurityPolicyEnforcer(),
); err != expectedError {
t.Fatalf("expected Mount error %s, got %s", expectedError, err)
}
if !rmLinearCalled {
t.Fatal("expected removeDevice for linear target to be called")
}
if !rmVerityCalled {
t.Fatal("expected removeDevice for verity target to be called")
}
}
Loading

0 comments on commit e43d462

Please sign in to comment.