From 5bcc330cb0a983072078a3c11c1d83a4b22dc96f Mon Sep 17 00:00:00 2001 From: Andrew Lavery Date: Tue, 7 May 2019 14:52:59 -0700 Subject: [PATCH 1/4] make the state manager safe to use in parallel add a test that makes many updates to the state in parallel add a function to update and return updated state --- pkg/state/manager.go | 341 +++++++++++++-------------- pkg/state/manager_test.go | 144 +++++++++++ pkg/test-mocks/state/manager_mock.go | 25 ++ 3 files changed, 330 insertions(+), 180 deletions(-) diff --git a/pkg/state/manager.go b/pkg/state/manager.go index 1606da27b..d47e6d3f4 100644 --- a/pkg/state/manager.go +++ b/pkg/state/manager.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" @@ -31,6 +32,8 @@ type Manager interface { templateContext map[string]interface{}, ) error TryLoad() (State, error) + SafeStateUpdate(updater StateUpdate) error + SafeStateUpdateReturn(updater StateUpdate) (State, error) RemoveStateFile() error SaveKustomize(kustomize *Kustomize) error SerializeUpstream(URL string) error @@ -55,10 +58,17 @@ type MManager struct { FS afero.Afero V *viper.Viper patcher patch.Patcher + mut sync.Mutex } func (m *MManager) Save(v VersionedState) error { - return m.serializeAndWriteState(v) + debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) + + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state = v + return state, nil + }) } func NewManager( @@ -73,200 +83,179 @@ func NewManager( } } -// SerializeShipMetadata is used by `ship init` to serialize metadata from ship applications to state file -func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicationType string) error { - debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) +type StateUpdate func(VersionedState) (VersionedState, error) - debug.Log("event", "tryLoadState") - current, err := m.TryLoad() +// applies the provided updater to the current state. Returns error +func (m *MManager) SafeStateUpdate(updater StateUpdate) error { + _, err := m.SafeStateUpdateReturn(updater) + return err +} + +// applies the provided updater to the current state. Returns the new state and err +func (m *MManager) SafeStateUpdateReturn(updater StateUpdate) (State, error) { + m.mut.Lock() + defer m.mut.Unlock() + + currentState, err := m.TryLoad() if err != nil { - return errors.Wrap(err, "load state") + return VersionedState{}, errors.Wrap(err, "tryLoad in safe updater") } - versionedState := current.Versioned() - versionedState.V1.Metadata = &Metadata{ - ApplicationType: applicationType, - ReleaseNotes: metadata.ReleaseNotes, - Version: metadata.Version, - Icon: metadata.Icon, - Name: metadata.Name, + updatedState, err := updater(currentState.Versioned()) + if err != nil { + return VersionedState{}, errors.Wrap(err, "run state update function in safe updater") } - return m.serializeAndWriteState(versionedState) + return updatedState, errors.Wrap(m.serializeAndWriteState(updatedState), "write state in safe updater") +} + +// SerializeShipMetadata is used by `ship init` to serialize metadata from ship applications to state file +func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicationType string) error { + debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) + + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Metadata = &Metadata{ + ApplicationType: applicationType, + ReleaseNotes: metadata.ReleaseNotes, + Version: metadata.Version, + Icon: metadata.Icon, + Name: metadata.Name, + } + return state, nil + }) } // SerializeAppMetadata is used by `ship app` to serialize replicated app metadata to state file func (m *MManager) SerializeAppMetadata(metadata api.ReleaseMetadata) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeAppMetadata")) - debug.Log("event", "tryLoadState") - current, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "load state") - } - - versionedState := current.Versioned() - versionedState.V1.Metadata = &Metadata{ - ApplicationType: "replicated.app", - ReleaseNotes: metadata.ReleaseNotes, - Version: metadata.Semver, - CustomerID: metadata.CustomerID, - InstallationID: metadata.InstallationID, - LicenseID: metadata.LicenseID, - AppSlug: metadata.AppSlug, - License: License{ - ID: metadata.License.ID, - Assignee: metadata.License.Assignee, - CreatedAt: metadata.License.CreatedAt, - ExpiresAt: metadata.License.ExpiresAt, - Type: metadata.License.Type, - }, - } - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Metadata = &Metadata{ + ApplicationType: "replicated.app", + ReleaseNotes: metadata.ReleaseNotes, + Version: metadata.Semver, + CustomerID: metadata.CustomerID, + InstallationID: metadata.InstallationID, + LicenseID: metadata.LicenseID, + AppSlug: metadata.AppSlug, + License: License{ + ID: metadata.License.ID, + Assignee: metadata.License.Assignee, + CreatedAt: metadata.License.CreatedAt, + ExpiresAt: metadata.License.ExpiresAt, + Type: metadata.License.Type, + }, + } + return state, nil + }) } // SerializeUpstream is used by `ship init` to serialize a state file with ChartURL to disk func (m *MManager) SerializeUpstream(upstream string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeUpstream")) - current, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "load state") - } - debug.Log("event", "generateUpstreamURLState") - - toSerialize := current.Versioned() - toSerialize.V1.Upstream = upstream - - return m.serializeAndWriteState(toSerialize) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream = upstream + return state, nil + }) } // SerializeContentSHA writes the contentSHA to the state file func (m *MManager) SerializeContentSHA(contentSHA string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeContentSHA")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.ContentSHA = contentSHA - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.ContentSHA = contentSHA + return state, nil + }) } // SerializeHelmValues takes user input helm values and serializes a state file to disk func (m *MManager) SerializeHelmValues(values string, defaults string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.HelmValues = values - versionedState.V1.HelmValuesDefaults = defaults - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.HelmValues = values + state.V1.HelmValuesDefaults = defaults + return state, nil + }) } // SerializeReleaseName serializes to disk the name to use for helm template func (m *MManager) SerializeReleaseName(name string) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) + debug := level.Debug(log.With(m.Logger, "method", "serializeReleaseName")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.ReleaseName = name - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.ReleaseName = name + return state, nil + }) } // SerializeNamespace serializes to disk the namespace to use for helm template func (m *MManager) SerializeNamespace(namespace string) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) + debug := level.Debug(log.With(m.Logger, "method", "serializeNamespace")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Namespace = namespace - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Namespace = namespace + return state, nil + }) } // SerializeConfig takes the application data and input params and serializes a state file to disk func (m *MManager) SerializeConfig(assets []api.Asset, meta api.ReleaseMetadata, templateContext map[string]interface{}) error { debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Config = templateContext - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Config = templateContext + return state, nil + }) } func (m *MManager) SerializeListsMetadata(list util.List) error { debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - - versionedState := currentState.Versioned() - if versionedState.V1.Metadata == nil { - versionedState.V1.Metadata = &Metadata{} - } - versionedState.V1.Metadata.Lists = append(versionedState.V1.Metadata.Lists, list) - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + state.V1.Metadata = &Metadata{} + } + state.V1.Metadata.Lists = append(state.V1.Metadata.Lists, list) + return state, nil + }) } func (m *MManager) ClearListsMetadata() error { - debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) + debug := level.Debug(log.With(m.Logger, "method", "clearListMetadata")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + return state, nil + } - versionedState := currentState.Versioned() - if versionedState.V1.Metadata == nil { - return nil - } - versionedState.V1.Metadata.Lists = []util.List{} - - return m.serializeAndWriteState(versionedState) + state.V1.Metadata.Lists = []util.List{} + return state, nil + }) } // SerializeConfig takes the application data and input params and serializes a state file to disk func (m *MManager) SerializeUpstreamContents(contents *UpstreamContents) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) + debug := level.Debug(log.With(m.Logger, "method", "serializeUpstreamContents")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.UpstreamContents = contents + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { - return m.serializeAndWriteState(versionedState) + state.V1.UpstreamContents = contents + return state, nil + }) } // TryLoad will attempt to load a state file from disk, if present @@ -294,15 +283,12 @@ func (m *MManager) TryLoad() (State, error) { func (m *MManager) ResetLifecycle() error { debug := level.Debug(log.With(m.Logger, "method", "ResetLifecycle")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Lifecycle = nil + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { - return m.serializeAndWriteState(versionedState) + state.V1.Lifecycle = nil + return state, nil + }) } // tryLoadFromSecret will attempt to load the state from a secret @@ -406,18 +392,14 @@ func (m *MManager) tryLoadFromFile() (State, error) { } func (m *MManager) SaveKustomize(kustomize *Kustomize) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Kustomize = kustomize + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) - if err := m.serializeAndWriteState(versionedState); err != nil { - return errors.Wrap(err, "write state") - } + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { - return nil + state.V1.Kustomize = kustomize + return state, nil + }) } // RemoveStateFile will attempt to remove the state file from disk @@ -512,36 +494,35 @@ func (m *MManager) serializeAndWriteStateSecret(state VersionedState) error { } func (m *MManager) AddCert(name string, newCert util.CertType) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - if versionedState.V1.Certs == nil { - versionedState.V1.Certs = make(map[string]util.CertType) - } - if _, ok := versionedState.V1.Certs[name]; ok { - return fmt.Errorf("cert with name %s already exists in state", name) - } - versionedState.V1.Certs[name] = newCert - - return errors.Wrap(m.serializeAndWriteState(versionedState), "write state") + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) + + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + + if state.V1.Certs == nil { + state.V1.Certs = make(map[string]util.CertType) + } + if _, ok := state.V1.Certs[name]; ok { + return state, fmt.Errorf("cert with name %s already exists in state", name) + } + state.V1.Certs[name] = newCert + return state, nil + }) } func (m *MManager) AddCA(name string, newCA util.CAType) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - if versionedState.V1.CAs == nil { - versionedState.V1.CAs = make(map[string]util.CAType) - } - if _, ok := versionedState.V1.CAs[name]; ok { - return fmt.Errorf("cert with name %s already exists in state", name) - } - versionedState.V1.CAs[name] = newCA - - return errors.Wrap(m.serializeAndWriteState(versionedState), "write state") - + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) + + debug.Log("event", "safeStateUpdate") + return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + + if state.V1.CAs == nil { + state.V1.CAs = make(map[string]util.CAType) + } + if _, ok := state.V1.CAs[name]; ok { + return state, fmt.Errorf("cert with name %s already exists in state", name) + } + state.V1.CAs[name] = newCA + return state, nil + }) } diff --git a/pkg/state/manager_test.go b/pkg/state/manager_test.go index 962342c17..6690ce219 100644 --- a/pkg/state/manager_test.go +++ b/pkg/state/manager_test.go @@ -1,6 +1,8 @@ package state import ( + "fmt" + "sync" "testing" "github.com/go-kit/kit/log" @@ -8,6 +10,8 @@ import ( "github.com/replicatedhq/ship/pkg/api" "github.com/replicatedhq/ship/pkg/constants" "github.com/replicatedhq/ship/pkg/testing/logger" + "github.com/replicatedhq/ship/pkg/util" + "github.com/spf13/afero" "github.com/spf13/viper" "github.com/stretchr/testify/require" @@ -515,3 +519,143 @@ func TestMManager_ResetLifecycle(t *testing.T) { }) } } + +func TestMManager_ParallelUpdates(t *testing.T) { + tests := []struct { + name string + runners []func(*MManager, *require.Assertions, *sync.WaitGroup) + validator func(VersionedState, *require.Assertions) + }{ + { + name: "lists", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + }, + }, + { + name: "lists and upstream", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + // lists + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + // first upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" a:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // second upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" b:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // third upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" c:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // fourth upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" d:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // fifth upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" e:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + + totalUpstream := state.Upstream() + for _, str := range []string{"a", "b", "c", "d", "e"} { + for i := 1; i <= 200; i++ { + req.Contains(totalUpstream, fmt.Sprintf(" %s:%d ", str, i)) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + req := require.New(t) + m := &MManager{ + Logger: log.NewNopLogger(), + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + initialState := VersionedState{V1: &V1{Lifecycle: nil}} + + group := sync.WaitGroup{} + + err := m.serializeAndWriteState(initialState) + req.NoError(err) + + group.Add(len(tt.runners)) + for _, runner := range tt.runners { + go runner(m, req, &group) + } + + group.Wait() + actualState, err := m.TryLoad() + req.NoError(err) + + tt.validator(actualState.Versioned(), req) + }) + } + +} diff --git a/pkg/test-mocks/state/manager_mock.go b/pkg/test-mocks/state/manager_mock.go index 99def17a9..ecf44ce46 100644 --- a/pkg/test-mocks/state/manager_mock.go +++ b/pkg/test-mocks/state/manager_mock.go @@ -96,6 +96,31 @@ func (mr *MockManagerMockRecorder) ResetLifecycle() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetLifecycle", reflect.TypeOf((*MockManager)(nil).ResetLifecycle)) } +// SafeStateUpdate mocks base method +func (m *MockManager) SafeStateUpdate(arg0 state.StateUpdate) error { + ret := m.ctrl.Call(m, "SafeStateUpdate", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SafeStateUpdate indicates an expected call of SafeStateUpdate +func (mr *MockManagerMockRecorder) SafeStateUpdate(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SafeStateUpdate", reflect.TypeOf((*MockManager)(nil).SafeStateUpdate), arg0) +} + +// SafeStateUpdateReturn mocks base method +func (m *MockManager) SafeStateUpdateReturn(arg0 state.StateUpdate) (state.State, error) { + ret := m.ctrl.Call(m, "SafeStateUpdateReturn", arg0) + ret0, _ := ret[0].(state.State) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SafeStateUpdateReturn indicates an expected call of SafeStateUpdateReturn +func (mr *MockManagerMockRecorder) SafeStateUpdateReturn(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SafeStateUpdateReturn", reflect.TypeOf((*MockManager)(nil).SafeStateUpdateReturn), arg0) +} + // Save mocks base method func (m *MockManager) Save(arg0 state.VersionedState) error { ret := m.ctrl.Call(m, "Save", arg0) From 420b2bc1b3d39f7ffc212f2b2a042e622e0b73d1 Mon Sep 17 00:00:00 2001 From: Andrew Lavery Date: Thu, 9 May 2019 13:01:28 -0700 Subject: [PATCH 2/4] all updates should be safe, so rename function there can never be too many tests --- pkg/state/manager.go | 75 +++--- pkg/state/manager_test.go | 368 +++++++++++++++++++++++++-- pkg/test-mocks/state/manager_mock.go | 20 +- 3 files changed, 395 insertions(+), 68 deletions(-) diff --git a/pkg/state/manager.go b/pkg/state/manager.go index d47e6d3f4..afbd496e6 100644 --- a/pkg/state/manager.go +++ b/pkg/state/manager.go @@ -32,8 +32,8 @@ type Manager interface { templateContext map[string]interface{}, ) error TryLoad() (State, error) - SafeStateUpdate(updater StateUpdate) error - SafeStateUpdateReturn(updater StateUpdate) (State, error) + StateUpdate(updater Update) error + StateUpdateReturn(updater Update) (State, error) RemoveStateFile() error SaveKustomize(kustomize *Kustomize) error SerializeUpstream(URL string) error @@ -65,7 +65,7 @@ func (m *MManager) Save(v VersionedState) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state = v return state, nil }) @@ -83,16 +83,16 @@ func NewManager( } } -type StateUpdate func(VersionedState) (VersionedState, error) +type Update func(VersionedState) (VersionedState, error) // applies the provided updater to the current state. Returns error -func (m *MManager) SafeStateUpdate(updater StateUpdate) error { - _, err := m.SafeStateUpdateReturn(updater) +func (m *MManager) StateUpdate(updater Update) error { + _, err := m.StateUpdateReturn(updater) return err } // applies the provided updater to the current state. Returns the new state and err -func (m *MManager) SafeStateUpdateReturn(updater StateUpdate) (State, error) { +func (m *MManager) StateUpdateReturn(updater Update) (State, error) { m.mut.Lock() defer m.mut.Unlock() @@ -114,7 +114,7 @@ func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicati debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Metadata = &Metadata{ ApplicationType: applicationType, ReleaseNotes: metadata.ReleaseNotes, @@ -131,22 +131,23 @@ func (m *MManager) SerializeAppMetadata(metadata api.ReleaseMetadata) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeAppMetadata")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { - state.V1.Metadata = &Metadata{ - ApplicationType: "replicated.app", - ReleaseNotes: metadata.ReleaseNotes, - Version: metadata.Semver, - CustomerID: metadata.CustomerID, - InstallationID: metadata.InstallationID, - LicenseID: metadata.LicenseID, - AppSlug: metadata.AppSlug, - License: License{ - ID: metadata.License.ID, - Assignee: metadata.License.Assignee, - CreatedAt: metadata.License.CreatedAt, - ExpiresAt: metadata.License.ExpiresAt, - Type: metadata.License.Type, - }, + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + state.V1.Metadata = &Metadata{} + } + state.V1.Metadata.ApplicationType = "replicated.app" + state.V1.Metadata.ReleaseNotes = metadata.ReleaseNotes + state.V1.Metadata.Version = metadata.Semver + state.V1.Metadata.CustomerID = metadata.CustomerID + state.V1.Metadata.InstallationID = metadata.InstallationID + state.V1.Metadata.LicenseID = metadata.LicenseID + state.V1.Metadata.AppSlug = metadata.AppSlug + state.V1.Metadata.License = License{ + ID: metadata.License.ID, + Assignee: metadata.License.Assignee, + CreatedAt: metadata.License.CreatedAt, + ExpiresAt: metadata.License.ExpiresAt, + Type: metadata.License.Type, } return state, nil }) @@ -157,7 +158,7 @@ func (m *MManager) SerializeUpstream(upstream string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeUpstream")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream = upstream return state, nil }) @@ -168,7 +169,7 @@ func (m *MManager) SerializeContentSHA(contentSHA string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeContentSHA")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.ContentSHA = contentSHA return state, nil }) @@ -179,7 +180,7 @@ func (m *MManager) SerializeHelmValues(values string, defaults string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.HelmValues = values state.V1.HelmValuesDefaults = defaults return state, nil @@ -191,7 +192,7 @@ func (m *MManager) SerializeReleaseName(name string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeReleaseName")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.ReleaseName = name return state, nil }) @@ -202,7 +203,7 @@ func (m *MManager) SerializeNamespace(namespace string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeNamespace")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Namespace = namespace return state, nil }) @@ -213,7 +214,7 @@ func (m *MManager) SerializeConfig(assets []api.Asset, meta api.ReleaseMetadata, debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Config = templateContext return state, nil }) @@ -223,7 +224,7 @@ func (m *MManager) SerializeListsMetadata(list util.List) error { debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Metadata == nil { state.V1.Metadata = &Metadata{} } @@ -236,7 +237,7 @@ func (m *MManager) ClearListsMetadata() error { debug := level.Debug(log.With(m.Logger, "method", "clearListMetadata")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Metadata == nil { return state, nil } @@ -251,7 +252,7 @@ func (m *MManager) SerializeUpstreamContents(contents *UpstreamContents) error { debug := level.Debug(log.With(m.Logger, "method", "serializeUpstreamContents")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.UpstreamContents = contents return state, nil @@ -284,7 +285,7 @@ func (m *MManager) ResetLifecycle() error { debug := level.Debug(log.With(m.Logger, "method", "ResetLifecycle")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Lifecycle = nil return state, nil @@ -395,7 +396,7 @@ func (m *MManager) SaveKustomize(kustomize *Kustomize) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Kustomize = kustomize return state, nil @@ -497,7 +498,7 @@ func (m *MManager) AddCert(name string, newCert util.CertType) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Certs == nil { state.V1.Certs = make(map[string]util.CertType) @@ -514,7 +515,7 @@ func (m *MManager) AddCA(name string, newCA util.CAType) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + return m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.CAs == nil { state.V1.CAs = make(map[string]util.CAType) diff --git a/pkg/state/manager_test.go b/pkg/state/manager_test.go index 6690ce219..986f951e4 100644 --- a/pkg/state/manager_test.go +++ b/pkg/state/manager_test.go @@ -96,11 +96,7 @@ func TestLoadConfig(t *testing.T) { req.NoError(err, "write existing state") } - manager := &MManager{ - Logger: &logger.TestLogger{T: t}, - FS: fs, - V: viper.New(), - } + manager := NewManager(&logger.TestLogger{T: t}, fs, viper.New()) state, err := manager.TryLoad() req.NoError(err) @@ -158,11 +154,7 @@ func TestHelmValue(t *testing.T) { req := require.New(t) fs := afero.Afero{Fs: afero.NewMemMapFs()} - manager := &MManager{ - Logger: &logger.TestLogger{T: t}, - FS: fs, - V: viper.New(), - } + manager := NewManager(&logger.TestLogger{T: t}, fs, viper.New()) err := manager.SerializeHelmValues(test.userInputValues, test.chartValuesOnInit) req.NoError(err) @@ -235,7 +227,7 @@ func TestMManager_SerializeChartURL(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -312,7 +304,7 @@ func TestMManager_SerializeContentSHA(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -390,7 +382,7 @@ func TestMManager_SerializeHelmValues(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -448,7 +440,7 @@ func TestMManager_SerializeShipMetadata(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -501,7 +493,7 @@ func TestMManager_ResetLifecycle(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -542,6 +534,79 @@ func TestMManager_ParallelUpdates(t *testing.T) { req.Len(state.V1.Metadata.Lists, 20) }, }, + { + name: "emptied lists", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.ClearListsMetadata() + req.NoError(err) + + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + + err = m.ClearListsMetadata() + req.NoError(err) + + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 0) + }, + }, + { + name: "lists and app metadata", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeAppMetadata(api.ReleaseMetadata{Semver: "tested"}) + req.NoError(err) + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + req.Equal("tested", state.V1.Metadata.Version) + }, + }, + { + name: "lists, release name and namespace", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeReleaseName("testedName") + req.NoError(err) + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeNamespace("testedNS") + req.NoError(err) + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + req.Equal("testedName", state.CurrentReleaseName()) + req.Equal("testedNS", state.CurrentNamespace()) + }, + }, { name: "lists and upstream", runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ @@ -558,7 +623,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" a:%d ", i) return state, nil }) @@ -570,7 +635,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" b:%d ", i) return state, nil }) @@ -582,7 +647,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" c:%d ", i) return state, nil }) @@ -594,7 +659,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" d:%d ", i) return state, nil }) @@ -606,7 +671,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.SafeStateUpdate(func(state VersionedState) (VersionedState, error) { + err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" e:%d ", i) return state, nil }) @@ -626,6 +691,70 @@ func TestMManager_ParallelUpdates(t *testing.T) { } }, }, + { + name: "certs and keys", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + // first cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" a:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // second cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" b:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // third cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" c:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // first ca mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 CAs to the CA list + for i := 1; i <= 100; i++ { + err := m.AddCA(fmt.Sprintf(" a:%d ", i), util.CAType{}) + req.NoError(err) + } + group.Done() + }, + // second ca mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 CAs to the CA list + for i := 1; i <= 100; i++ { + err := m.AddCA(fmt.Sprintf(" b:%d ", i), util.CAType{}) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + totalCAs := state.CurrentCAs() + for _, str := range []string{"a", "b"} { + for i := 1; i <= 100; i++ { + req.Contains(totalCAs, fmt.Sprintf(" %s:%d ", str, i)) + } + } + totalCerts := state.CurrentCerts() + for _, str := range []string{"a", "b", "c"} { + for i := 1; i <= 100; i++ { + req.Contains(totalCerts, fmt.Sprintf(" %s:%d ", str, i)) + } + } + }, + }, } for _, tt := range tests { @@ -633,7 +762,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -657,5 +786,202 @@ func TestMManager_ParallelUpdates(t *testing.T) { tt.validator(actualState.Versioned(), req) }) } +} + +func TestMManager_AddCA(t *testing.T) { + tests := []struct { + name string + caName string + newCA util.CAType + wantErr bool + before VersionedState + expected VersionedState + }{ + { + name: "basic test", + caName: "aCA", + newCA: util.CAType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + { + name: "add to existing", + caName: "bCA", + newCA: util.CAType{Cert: "bCert", Key: "bKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + "bCA": {Cert: "bCert", Key: "bKey"}, + }, + }, + }, + }, + { + name: "colliding ca names", + wantErr: true, + caName: "aCA", + newCA: util.CAType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := require.New(t) + m := &MManager{ + Logger: &logger.TestLogger{T: t}, + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + err := m.serializeAndWriteState(tt.before) + req.NoError(err) + + err = m.AddCA(tt.caName, tt.newCA) + if !tt.wantErr { + req.NoError(err, "MManager.AddCA() error = %v", err) + } else { + req.Error(err) + } + + actualState, err := m.TryLoad() + req.NoError(err) + + req.Equal(tt.expected, actualState) + }) + } +} + +func TestMManager_AddCert(t *testing.T) { + tests := []struct { + name string + certName string + newCert util.CertType + wantErr bool + before VersionedState + expected VersionedState + }{ + { + name: "basic test", + certName: "aCert", + newCert: util.CertType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + { + name: "add to existing", + certName: "bCert", + newCert: util.CertType{Cert: "bCert", Key: "bKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + "bCert": {Cert: "bCert", Key: "bKey"}, + }, + }, + }, + }, + { + name: "colliding ca names", + wantErr: true, + certName: "aCert", + newCert: util.CertType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := require.New(t) + m := &MManager{ + Logger: &logger.TestLogger{T: t}, + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + err := m.serializeAndWriteState(tt.before) + req.NoError(err) + err = m.AddCert(tt.certName, tt.newCert) + if !tt.wantErr { + req.NoError(err, "MManager.AddCert() error = %v", err) + } else { + req.Error(err) + } + + actualState, err := m.TryLoad() + req.NoError(err) + + req.Equal(tt.expected, actualState) + }) + } } diff --git a/pkg/test-mocks/state/manager_mock.go b/pkg/test-mocks/state/manager_mock.go index ecf44ce46..7958755bc 100644 --- a/pkg/test-mocks/state/manager_mock.go +++ b/pkg/test-mocks/state/manager_mock.go @@ -96,29 +96,29 @@ func (mr *MockManagerMockRecorder) ResetLifecycle() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetLifecycle", reflect.TypeOf((*MockManager)(nil).ResetLifecycle)) } -// SafeStateUpdate mocks base method -func (m *MockManager) SafeStateUpdate(arg0 state.StateUpdate) error { - ret := m.ctrl.Call(m, "SafeStateUpdate", arg0) +// StateUpdate mocks base method +func (m *MockManager) StateUpdate(arg0 state.Update) error { + ret := m.ctrl.Call(m, "StateUpdate", arg0) ret0, _ := ret[0].(error) return ret0 } -// SafeStateUpdate indicates an expected call of SafeStateUpdate +// StateUpdate indicates an expected call of StateUpdate func (mr *MockManagerMockRecorder) SafeStateUpdate(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SafeStateUpdate", reflect.TypeOf((*MockManager)(nil).SafeStateUpdate), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdate", reflect.TypeOf((*MockManager)(nil).StateUpdate), arg0) } -// SafeStateUpdateReturn mocks base method -func (m *MockManager) SafeStateUpdateReturn(arg0 state.StateUpdate) (state.State, error) { - ret := m.ctrl.Call(m, "SafeStateUpdateReturn", arg0) +// StateUpdateReturn mocks base method +func (m *MockManager) StateUpdateReturn(arg0 state.Update) (state.State, error) { + ret := m.ctrl.Call(m, "StateUpdateReturn", arg0) ret0, _ := ret[0].(state.State) ret1, _ := ret[1].(error) return ret0, ret1 } -// SafeStateUpdateReturn indicates an expected call of SafeStateUpdateReturn +// StateUpdateReturn indicates an expected call of StateUpdateReturn func (mr *MockManagerMockRecorder) SafeStateUpdateReturn(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SafeStateUpdateReturn", reflect.TypeOf((*MockManager)(nil).SafeStateUpdateReturn), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdateReturn", reflect.TypeOf((*MockManager)(nil).StateUpdateReturn), arg0) } // Save mocks base method From 0388f58660d0673cb497db7bbaced078c7c3dac0 Mon Sep 17 00:00:00 2001 From: Andrew Lavery Date: Thu, 9 May 2019 16:16:58 -0700 Subject: [PATCH 3/4] remove version of StateUpdate that did not return state --- pkg/state/manager.go | 59 ++++++++++++++++------------ pkg/state/manager_test.go | 10 ++--- pkg/test-mocks/state/manager_mock.go | 38 ++++++------------ 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/pkg/state/manager.go b/pkg/state/manager.go index afbd496e6..e2842e5e7 100644 --- a/pkg/state/manager.go +++ b/pkg/state/manager.go @@ -32,8 +32,7 @@ type Manager interface { templateContext map[string]interface{}, ) error TryLoad() (State, error) - StateUpdate(updater Update) error - StateUpdateReturn(updater Update) (State, error) + StateUpdate(updater Update) (State, error) RemoveStateFile() error SaveKustomize(kustomize *Kustomize) error SerializeUpstream(URL string) error @@ -65,10 +64,11 @@ func (m *MManager) Save(v VersionedState) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state = v return state, nil }) + return err } func NewManager( @@ -85,14 +85,8 @@ func NewManager( type Update func(VersionedState) (VersionedState, error) -// applies the provided updater to the current state. Returns error -func (m *MManager) StateUpdate(updater Update) error { - _, err := m.StateUpdateReturn(updater) - return err -} - // applies the provided updater to the current state. Returns the new state and err -func (m *MManager) StateUpdateReturn(updater Update) (State, error) { +func (m *MManager) StateUpdate(updater Update) (State, error) { m.mut.Lock() defer m.mut.Unlock() @@ -114,7 +108,7 @@ func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicati debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Metadata = &Metadata{ ApplicationType: applicationType, ReleaseNotes: metadata.ReleaseNotes, @@ -124,6 +118,7 @@ func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicati } return state, nil }) + return err } // SerializeAppMetadata is used by `ship app` to serialize replicated app metadata to state file @@ -131,7 +126,7 @@ func (m *MManager) SerializeAppMetadata(metadata api.ReleaseMetadata) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeAppMetadata")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Metadata == nil { state.V1.Metadata = &Metadata{} } @@ -151,6 +146,7 @@ func (m *MManager) SerializeAppMetadata(metadata api.ReleaseMetadata) error { } return state, nil }) + return err } // SerializeUpstream is used by `ship init` to serialize a state file with ChartURL to disk @@ -158,10 +154,11 @@ func (m *MManager) SerializeUpstream(upstream string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeUpstream")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream = upstream return state, nil }) + return err } // SerializeContentSHA writes the contentSHA to the state file @@ -169,10 +166,11 @@ func (m *MManager) SerializeContentSHA(contentSHA string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeContentSHA")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.ContentSHA = contentSHA return state, nil }) + return err } // SerializeHelmValues takes user input helm values and serializes a state file to disk @@ -180,11 +178,12 @@ func (m *MManager) SerializeHelmValues(values string, defaults string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.HelmValues = values state.V1.HelmValuesDefaults = defaults return state, nil }) + return err } // SerializeReleaseName serializes to disk the name to use for helm template @@ -192,10 +191,11 @@ func (m *MManager) SerializeReleaseName(name string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeReleaseName")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.ReleaseName = name return state, nil }) + return err } // SerializeNamespace serializes to disk the namespace to use for helm template @@ -203,10 +203,11 @@ func (m *MManager) SerializeNamespace(namespace string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeNamespace")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Namespace = namespace return state, nil }) + return err } // SerializeConfig takes the application data and input params and serializes a state file to disk @@ -214,30 +215,32 @@ func (m *MManager) SerializeConfig(assets []api.Asset, meta api.ReleaseMetadata, debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Config = templateContext return state, nil }) + return err } func (m *MManager) SerializeListsMetadata(list util.List) error { debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Metadata == nil { state.V1.Metadata = &Metadata{} } state.V1.Metadata.Lists = append(state.V1.Metadata.Lists, list) return state, nil }) + return err } func (m *MManager) ClearListsMetadata() error { debug := level.Debug(log.With(m.Logger, "method", "clearListMetadata")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Metadata == nil { return state, nil } @@ -245,6 +248,7 @@ func (m *MManager) ClearListsMetadata() error { state.V1.Metadata.Lists = []util.List{} return state, nil }) + return err } // SerializeConfig takes the application data and input params and serializes a state file to disk @@ -252,11 +256,12 @@ func (m *MManager) SerializeUpstreamContents(contents *UpstreamContents) error { debug := level.Debug(log.With(m.Logger, "method", "serializeUpstreamContents")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.UpstreamContents = contents return state, nil }) + return err } // TryLoad will attempt to load a state file from disk, if present @@ -285,11 +290,12 @@ func (m *MManager) ResetLifecycle() error { debug := level.Debug(log.With(m.Logger, "method", "ResetLifecycle")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Lifecycle = nil return state, nil }) + return err } // tryLoadFromSecret will attempt to load the state from a secret @@ -396,11 +402,12 @@ func (m *MManager) SaveKustomize(kustomize *Kustomize) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Kustomize = kustomize return state, nil }) + return err } // RemoveStateFile will attempt to remove the state file from disk @@ -498,7 +505,7 @@ func (m *MManager) AddCert(name string, newCert util.CertType) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.Certs == nil { state.V1.Certs = make(map[string]util.CertType) @@ -509,13 +516,14 @@ func (m *MManager) AddCert(name string, newCert util.CertType) error { state.V1.Certs[name] = newCert return state, nil }) + return err } func (m *MManager) AddCA(name string, newCA util.CAType) error { debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) debug.Log("event", "safeStateUpdate") - return m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { if state.V1.CAs == nil { state.V1.CAs = make(map[string]util.CAType) @@ -526,4 +534,5 @@ func (m *MManager) AddCA(name string, newCA util.CAType) error { state.V1.CAs[name] = newCA return state, nil }) + return err } diff --git a/pkg/state/manager_test.go b/pkg/state/manager_test.go index 986f951e4..70de69789 100644 --- a/pkg/state/manager_test.go +++ b/pkg/state/manager_test.go @@ -623,7 +623,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" a:%d ", i) return state, nil }) @@ -635,7 +635,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" b:%d ", i) return state, nil }) @@ -647,7 +647,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" c:%d ", i) return state, nil }) @@ -659,7 +659,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" d:%d ", i) return state, nil }) @@ -671,7 +671,7 @@ func TestMManager_ParallelUpdates(t *testing.T) { func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { // append the integers 1-200 to the upstream for i := 1; i <= 200; i++ { - err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { state.V1.Upstream += fmt.Sprintf(" e:%d ", i) return state, nil }) diff --git a/pkg/test-mocks/state/manager_mock.go b/pkg/test-mocks/state/manager_mock.go index 7958755bc..984d31060 100644 --- a/pkg/test-mocks/state/manager_mock.go +++ b/pkg/test-mocks/state/manager_mock.go @@ -96,31 +96,6 @@ func (mr *MockManagerMockRecorder) ResetLifecycle() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetLifecycle", reflect.TypeOf((*MockManager)(nil).ResetLifecycle)) } -// StateUpdate mocks base method -func (m *MockManager) StateUpdate(arg0 state.Update) error { - ret := m.ctrl.Call(m, "StateUpdate", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// StateUpdate indicates an expected call of StateUpdate -func (mr *MockManagerMockRecorder) SafeStateUpdate(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdate", reflect.TypeOf((*MockManager)(nil).StateUpdate), arg0) -} - -// StateUpdateReturn mocks base method -func (m *MockManager) StateUpdateReturn(arg0 state.Update) (state.State, error) { - ret := m.ctrl.Call(m, "StateUpdateReturn", arg0) - ret0, _ := ret[0].(state.State) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// StateUpdateReturn indicates an expected call of StateUpdateReturn -func (mr *MockManagerMockRecorder) SafeStateUpdateReturn(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdateReturn", reflect.TypeOf((*MockManager)(nil).StateUpdateReturn), arg0) -} - // Save mocks base method func (m *MockManager) Save(arg0 state.VersionedState) error { ret := m.ctrl.Call(m, "Save", arg0) @@ -265,6 +240,19 @@ func (mr *MockManagerMockRecorder) SerializeUpstreamContents(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SerializeUpstreamContents", reflect.TypeOf((*MockManager)(nil).SerializeUpstreamContents), arg0) } +// StateUpdate mocks base method +func (m *MockManager) StateUpdate(arg0 state.Update) (state.State, error) { + ret := m.ctrl.Call(m, "StateUpdate", arg0) + ret0, _ := ret[0].(state.State) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StateUpdate indicates an expected call of StateUpdate +func (mr *MockManagerMockRecorder) StateUpdate(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdate", reflect.TypeOf((*MockManager)(nil).StateUpdate), arg0) +} + // TryLoad mocks base method func (m *MockManager) TryLoad() (state.State, error) { ret := m.ctrl.Call(m, "TryLoad") From 47581c90778266616c6adf394a38f1dbc030a962 Mon Sep 17 00:00:00 2001 From: Andrew Lavery Date: Fri, 10 May 2019 10:54:30 -0700 Subject: [PATCH 4/4] don't return empty val when we can return nil instead --- pkg/state/manager.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/state/manager.go b/pkg/state/manager.go index e2842e5e7..f6ba5bd9a 100644 --- a/pkg/state/manager.go +++ b/pkg/state/manager.go @@ -92,12 +92,12 @@ func (m *MManager) StateUpdate(updater Update) (State, error) { currentState, err := m.TryLoad() if err != nil { - return VersionedState{}, errors.Wrap(err, "tryLoad in safe updater") + return nil, errors.Wrap(err, "tryLoad in safe updater") } updatedState, err := updater(currentState.Versioned()) if err != nil { - return VersionedState{}, errors.Wrap(err, "run state update function in safe updater") + return nil, errors.Wrap(err, "run state update function in safe updater") } return updatedState, errors.Wrap(m.serializeAndWriteState(updatedState), "write state in safe updater")