diff --git a/pkg/state/manager.go b/pkg/state/manager.go index 1606da27b..f6ba5bd9a 100644 --- a/pkg/state/manager.go +++ b/pkg/state/manager.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" @@ -31,6 +32,7 @@ type Manager interface { templateContext map[string]interface{}, ) error TryLoad() (State, error) + StateUpdate(updater Update) (State, error) RemoveStateFile() error SaveKustomize(kustomize *Kustomize) error SerializeUpstream(URL string) error @@ -55,10 +57,18 @@ type MManager struct { FS afero.Afero V *viper.Viper patcher patch.Patcher + mut sync.Mutex } func (m *MManager) Save(v VersionedState) error { - return m.serializeAndWriteState(v) + debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) + + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state = v + return state, nil + }) + return err } func NewManager( @@ -73,200 +83,185 @@ func NewManager( } } -// SerializeShipMetadata is used by `ship init` to serialize metadata from ship applications to state file -func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicationType string) error { - debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) +type Update func(VersionedState) (VersionedState, error) - debug.Log("event", "tryLoadState") - current, err := m.TryLoad() +// applies the provided updater to the current state. Returns the new state and err +func (m *MManager) StateUpdate(updater Update) (State, error) { + m.mut.Lock() + defer m.mut.Unlock() + + currentState, err := m.TryLoad() if err != nil { - return errors.Wrap(err, "load state") + return nil, errors.Wrap(err, "tryLoad in safe updater") } - versionedState := current.Versioned() - versionedState.V1.Metadata = &Metadata{ - ApplicationType: applicationType, - ReleaseNotes: metadata.ReleaseNotes, - Version: metadata.Version, - Icon: metadata.Icon, - Name: metadata.Name, + updatedState, err := updater(currentState.Versioned()) + if err != nil { + return nil, errors.Wrap(err, "run state update function in safe updater") } - return m.serializeAndWriteState(versionedState) + return updatedState, errors.Wrap(m.serializeAndWriteState(updatedState), "write state in safe updater") +} + +// SerializeShipMetadata is used by `ship init` to serialize metadata from ship applications to state file +func (m *MManager) SerializeShipMetadata(metadata api.ShipAppMetadata, applicationType string) error { + debug := level.Debug(log.With(m.Logger, "method", "SerializeShipMetadata")) + + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Metadata = &Metadata{ + ApplicationType: applicationType, + ReleaseNotes: metadata.ReleaseNotes, + Version: metadata.Version, + Icon: metadata.Icon, + Name: metadata.Name, + } + return state, nil + }) + return err } // SerializeAppMetadata is used by `ship app` to serialize replicated app metadata to state file func (m *MManager) SerializeAppMetadata(metadata api.ReleaseMetadata) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeAppMetadata")) - debug.Log("event", "tryLoadState") - current, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "load state") - } - - versionedState := current.Versioned() - versionedState.V1.Metadata = &Metadata{ - ApplicationType: "replicated.app", - ReleaseNotes: metadata.ReleaseNotes, - Version: metadata.Semver, - CustomerID: metadata.CustomerID, - InstallationID: metadata.InstallationID, - LicenseID: metadata.LicenseID, - AppSlug: metadata.AppSlug, - License: License{ + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + state.V1.Metadata = &Metadata{} + } + state.V1.Metadata.ApplicationType = "replicated.app" + state.V1.Metadata.ReleaseNotes = metadata.ReleaseNotes + state.V1.Metadata.Version = metadata.Semver + state.V1.Metadata.CustomerID = metadata.CustomerID + state.V1.Metadata.InstallationID = metadata.InstallationID + state.V1.Metadata.LicenseID = metadata.LicenseID + state.V1.Metadata.AppSlug = metadata.AppSlug + state.V1.Metadata.License = License{ ID: metadata.License.ID, Assignee: metadata.License.Assignee, CreatedAt: metadata.License.CreatedAt, ExpiresAt: metadata.License.ExpiresAt, Type: metadata.License.Type, - }, - } - - return m.serializeAndWriteState(versionedState) + } + return state, nil + }) + return err } // SerializeUpstream is used by `ship init` to serialize a state file with ChartURL to disk func (m *MManager) SerializeUpstream(upstream string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeUpstream")) - current, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "load state") - } - debug.Log("event", "generateUpstreamURLState") - - toSerialize := current.Versioned() - toSerialize.V1.Upstream = upstream - - return m.serializeAndWriteState(toSerialize) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream = upstream + return state, nil + }) + return err } // SerializeContentSHA writes the contentSHA to the state file func (m *MManager) SerializeContentSHA(contentSHA string) error { debug := level.Debug(log.With(m.Logger, "method", "SerializeContentSHA")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.ContentSHA = contentSHA - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.ContentSHA = contentSHA + return state, nil + }) + return err } // SerializeHelmValues takes user input helm values and serializes a state file to disk func (m *MManager) SerializeHelmValues(values string, defaults string) error { debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.HelmValues = values - versionedState.V1.HelmValuesDefaults = defaults - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.HelmValues = values + state.V1.HelmValuesDefaults = defaults + return state, nil + }) + return err } // SerializeReleaseName serializes to disk the name to use for helm template func (m *MManager) SerializeReleaseName(name string) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) + debug := level.Debug(log.With(m.Logger, "method", "serializeReleaseName")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.ReleaseName = name - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.ReleaseName = name + return state, nil + }) + return err } // SerializeNamespace serializes to disk the namespace to use for helm template func (m *MManager) SerializeNamespace(namespace string) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeHelmValues")) + debug := level.Debug(log.With(m.Logger, "method", "serializeNamespace")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Namespace = namespace - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Namespace = namespace + return state, nil + }) + return err } // SerializeConfig takes the application data and input params and serializes a state file to disk func (m *MManager) SerializeConfig(assets []api.Asset, meta api.ReleaseMetadata, templateContext map[string]interface{}) error { debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Config = templateContext - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Config = templateContext + return state, nil + }) + return err } func (m *MManager) SerializeListsMetadata(list util.List) error { debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - - versionedState := currentState.Versioned() - if versionedState.V1.Metadata == nil { - versionedState.V1.Metadata = &Metadata{} - } - versionedState.V1.Metadata.Lists = append(versionedState.V1.Metadata.Lists, list) - - return m.serializeAndWriteState(versionedState) + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + state.V1.Metadata = &Metadata{} + } + state.V1.Metadata.Lists = append(state.V1.Metadata.Lists, list) + return state, nil + }) + return err } func (m *MManager) ClearListsMetadata() error { - debug := level.Debug(log.With(m.Logger, "method", "serializeListMetadata")) + debug := level.Debug(log.With(m.Logger, "method", "clearListMetadata")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + if state.V1.Metadata == nil { + return state, nil + } - versionedState := currentState.Versioned() - if versionedState.V1.Metadata == nil { - return nil - } - versionedState.V1.Metadata.Lists = []util.List{} - - return m.serializeAndWriteState(versionedState) + state.V1.Metadata.Lists = []util.List{} + return state, nil + }) + return err } // SerializeConfig takes the application data and input params and serializes a state file to disk func (m *MManager) SerializeUpstreamContents(contents *UpstreamContents) error { - debug := level.Debug(log.With(m.Logger, "method", "serializeConfig")) + debug := level.Debug(log.With(m.Logger, "method", "serializeUpstreamContents")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.UpstreamContents = contents + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { - return m.serializeAndWriteState(versionedState) + state.V1.UpstreamContents = contents + return state, nil + }) + return err } // TryLoad will attempt to load a state file from disk, if present @@ -294,15 +289,13 @@ func (m *MManager) TryLoad() (State, error) { func (m *MManager) ResetLifecycle() error { debug := level.Debug(log.With(m.Logger, "method", "ResetLifecycle")) - debug.Log("event", "tryLoadState") - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrap(err, "try load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Lifecycle = nil + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { - return m.serializeAndWriteState(versionedState) + state.V1.Lifecycle = nil + return state, nil + }) + return err } // tryLoadFromSecret will attempt to load the state from a secret @@ -406,18 +399,15 @@ func (m *MManager) tryLoadFromFile() (State, error) { } func (m *MManager) SaveKustomize(kustomize *Kustomize) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - versionedState.V1.Kustomize = kustomize + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) - if err := m.serializeAndWriteState(versionedState); err != nil { - return errors.Wrap(err, "write state") - } + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { - return nil + state.V1.Kustomize = kustomize + return state, nil + }) + return err } // RemoveStateFile will attempt to remove the state file from disk @@ -512,36 +502,37 @@ func (m *MManager) serializeAndWriteStateSecret(state VersionedState) error { } func (m *MManager) AddCert(name string, newCert util.CertType) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - if versionedState.V1.Certs == nil { - versionedState.V1.Certs = make(map[string]util.CertType) - } - if _, ok := versionedState.V1.Certs[name]; ok { - return fmt.Errorf("cert with name %s already exists in state", name) - } - versionedState.V1.Certs[name] = newCert - - return errors.Wrap(m.serializeAndWriteState(versionedState), "write state") + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) + + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + + if state.V1.Certs == nil { + state.V1.Certs = make(map[string]util.CertType) + } + if _, ok := state.V1.Certs[name]; ok { + return state, fmt.Errorf("cert with name %s already exists in state", name) + } + state.V1.Certs[name] = newCert + return state, nil + }) + return err } func (m *MManager) AddCA(name string, newCA util.CAType) error { - currentState, err := m.TryLoad() - if err != nil { - return errors.Wrapf(err, "load state") - } - versionedState := currentState.Versioned() - if versionedState.V1.CAs == nil { - versionedState.V1.CAs = make(map[string]util.CAType) - } - if _, ok := versionedState.V1.CAs[name]; ok { - return fmt.Errorf("cert with name %s already exists in state", name) - } - versionedState.V1.CAs[name] = newCA - - return errors.Wrap(m.serializeAndWriteState(versionedState), "write state") - + debug := level.Debug(log.With(m.Logger, "method", "SaveKustomize")) + + debug.Log("event", "safeStateUpdate") + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + + if state.V1.CAs == nil { + state.V1.CAs = make(map[string]util.CAType) + } + if _, ok := state.V1.CAs[name]; ok { + return state, fmt.Errorf("cert with name %s already exists in state", name) + } + state.V1.CAs[name] = newCA + return state, nil + }) + return err } diff --git a/pkg/state/manager_test.go b/pkg/state/manager_test.go index 962342c17..70de69789 100644 --- a/pkg/state/manager_test.go +++ b/pkg/state/manager_test.go @@ -1,6 +1,8 @@ package state import ( + "fmt" + "sync" "testing" "github.com/go-kit/kit/log" @@ -8,6 +10,8 @@ import ( "github.com/replicatedhq/ship/pkg/api" "github.com/replicatedhq/ship/pkg/constants" "github.com/replicatedhq/ship/pkg/testing/logger" + "github.com/replicatedhq/ship/pkg/util" + "github.com/spf13/afero" "github.com/spf13/viper" "github.com/stretchr/testify/require" @@ -92,11 +96,7 @@ func TestLoadConfig(t *testing.T) { req.NoError(err, "write existing state") } - manager := &MManager{ - Logger: &logger.TestLogger{T: t}, - FS: fs, - V: viper.New(), - } + manager := NewManager(&logger.TestLogger{T: t}, fs, viper.New()) state, err := manager.TryLoad() req.NoError(err) @@ -154,11 +154,7 @@ func TestHelmValue(t *testing.T) { req := require.New(t) fs := afero.Afero{Fs: afero.NewMemMapFs()} - manager := &MManager{ - Logger: &logger.TestLogger{T: t}, - FS: fs, - V: viper.New(), - } + manager := NewManager(&logger.TestLogger{T: t}, fs, viper.New()) err := manager.SerializeHelmValues(test.userInputValues, test.chartValuesOnInit) req.NoError(err) @@ -231,7 +227,7 @@ func TestMManager_SerializeChartURL(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -308,7 +304,7 @@ func TestMManager_SerializeContentSHA(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -386,7 +382,7 @@ func TestMManager_SerializeHelmValues(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -444,7 +440,7 @@ func TestMManager_SerializeShipMetadata(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -497,7 +493,7 @@ func TestMManager_ResetLifecycle(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := require.New(t) m := &MManager{ - Logger: log.NewNopLogger(), + Logger: &logger.TestLogger{T: t}, FS: afero.Afero{Fs: afero.NewMemMapFs()}, V: viper.New(), } @@ -515,3 +511,477 @@ func TestMManager_ResetLifecycle(t *testing.T) { }) } } + +func TestMManager_ParallelUpdates(t *testing.T) { + tests := []struct { + name string + runners []func(*MManager, *require.Assertions, *sync.WaitGroup) + validator func(VersionedState, *require.Assertions) + }{ + { + name: "lists", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + }, + }, + { + name: "emptied lists", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.ClearListsMetadata() + req.NoError(err) + + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + + err = m.ClearListsMetadata() + req.NoError(err) + + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 0) + }, + }, + { + name: "lists and app metadata", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeAppMetadata(api.ReleaseMetadata{Semver: "tested"}) + req.NoError(err) + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + req.Equal("tested", state.V1.Metadata.Version) + }, + }, + { + name: "lists, release name and namespace", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeReleaseName("testedName") + req.NoError(err) + group.Done() + }, + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + err := m.SerializeNamespace("testedNS") + req.NoError(err) + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + req.Equal("testedName", state.CurrentReleaseName()) + req.Equal("testedNS", state.CurrentNamespace()) + }, + }, + { + name: "lists and upstream", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + // lists + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // add the integers 1-20 to the list + for i := 1; i <= 20; i++ { + err := m.SerializeListsMetadata(util.List{APIVersion: fmt.Sprintf("%d", i)}) + req.NoError(err) + } + group.Done() + }, + // first upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" a:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // second upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" b:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // third upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" c:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // fourth upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" d:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + // fifth upstream mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append the integers 1-200 to the upstream + for i := 1; i <= 200; i++ { + _, err := m.StateUpdate(func(state VersionedState) (VersionedState, error) { + state.V1.Upstream += fmt.Sprintf(" e:%d ", i) + return state, nil + }) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + req.Len(state.V1.Metadata.Lists, 20) + + totalUpstream := state.Upstream() + for _, str := range []string{"a", "b", "c", "d", "e"} { + for i := 1; i <= 200; i++ { + req.Contains(totalUpstream, fmt.Sprintf(" %s:%d ", str, i)) + } + } + }, + }, + { + name: "certs and keys", + runners: []func(*MManager, *require.Assertions, *sync.WaitGroup){ + // first cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" a:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // second cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" b:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // third cert mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 certs to the cert list + for i := 1; i <= 100; i++ { + err := m.AddCert(fmt.Sprintf(" c:%d ", i), util.CertType{}) + req.NoError(err) + } + group.Done() + }, + // first ca mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 CAs to the CA list + for i := 1; i <= 100; i++ { + err := m.AddCA(fmt.Sprintf(" a:%d ", i), util.CAType{}) + req.NoError(err) + } + group.Done() + }, + // second ca mutator + func(m *MManager, req *require.Assertions, group *sync.WaitGroup) { + // append 100 CAs to the CA list + for i := 1; i <= 100; i++ { + err := m.AddCA(fmt.Sprintf(" b:%d ", i), util.CAType{}) + req.NoError(err) + } + group.Done() + }, + }, + validator: func(state VersionedState, req *require.Assertions) { + totalCAs := state.CurrentCAs() + for _, str := range []string{"a", "b"} { + for i := 1; i <= 100; i++ { + req.Contains(totalCAs, fmt.Sprintf(" %s:%d ", str, i)) + } + } + totalCerts := state.CurrentCerts() + for _, str := range []string{"a", "b", "c"} { + for i := 1; i <= 100; i++ { + req.Contains(totalCerts, fmt.Sprintf(" %s:%d ", str, i)) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + req := require.New(t) + m := &MManager{ + Logger: &logger.TestLogger{T: t}, + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + initialState := VersionedState{V1: &V1{Lifecycle: nil}} + + group := sync.WaitGroup{} + + err := m.serializeAndWriteState(initialState) + req.NoError(err) + + group.Add(len(tt.runners)) + for _, runner := range tt.runners { + go runner(m, req, &group) + } + + group.Wait() + actualState, err := m.TryLoad() + req.NoError(err) + + tt.validator(actualState.Versioned(), req) + }) + } +} + +func TestMManager_AddCA(t *testing.T) { + tests := []struct { + name string + caName string + newCA util.CAType + wantErr bool + before VersionedState + expected VersionedState + }{ + { + name: "basic test", + caName: "aCA", + newCA: util.CAType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + { + name: "add to existing", + caName: "bCA", + newCA: util.CAType{Cert: "bCert", Key: "bKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + "bCA": {Cert: "bCert", Key: "bKey"}, + }, + }, + }, + }, + { + name: "colliding ca names", + wantErr: true, + caName: "aCA", + newCA: util.CAType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + CAs: map[string]util.CAType{ + "aCA": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := require.New(t) + m := &MManager{ + Logger: &logger.TestLogger{T: t}, + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + err := m.serializeAndWriteState(tt.before) + req.NoError(err) + + err = m.AddCA(tt.caName, tt.newCA) + if !tt.wantErr { + req.NoError(err, "MManager.AddCA() error = %v", err) + } else { + req.Error(err) + } + + actualState, err := m.TryLoad() + req.NoError(err) + + req.Equal(tt.expected, actualState) + }) + } +} + +func TestMManager_AddCert(t *testing.T) { + tests := []struct { + name string + certName string + newCert util.CertType + wantErr bool + before VersionedState + expected VersionedState + }{ + { + name: "basic test", + certName: "aCert", + newCert: util.CertType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + { + name: "add to existing", + certName: "bCert", + newCert: util.CertType{Cert: "bCert", Key: "bKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + "bCert": {Cert: "bCert", Key: "bKey"}, + }, + }, + }, + }, + { + name: "colliding ca names", + wantErr: true, + certName: "aCert", + newCert: util.CertType{Cert: "aCert", Key: "aKey"}, + before: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + expected: VersionedState{ + V1: &V1{ + Upstream: "abc123", + Certs: map[string]util.CertType{ + "aCert": {Cert: "aCert", Key: "aKey"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := require.New(t) + m := &MManager{ + Logger: &logger.TestLogger{T: t}, + FS: afero.Afero{Fs: afero.NewMemMapFs()}, + V: viper.New(), + } + + err := m.serializeAndWriteState(tt.before) + req.NoError(err) + + err = m.AddCert(tt.certName, tt.newCert) + if !tt.wantErr { + req.NoError(err, "MManager.AddCert() error = %v", err) + } else { + req.Error(err) + } + + actualState, err := m.TryLoad() + req.NoError(err) + + req.Equal(tt.expected, actualState) + }) + } +} diff --git a/pkg/test-mocks/state/manager_mock.go b/pkg/test-mocks/state/manager_mock.go index 99def17a9..984d31060 100644 --- a/pkg/test-mocks/state/manager_mock.go +++ b/pkg/test-mocks/state/manager_mock.go @@ -240,6 +240,19 @@ func (mr *MockManagerMockRecorder) SerializeUpstreamContents(arg0 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SerializeUpstreamContents", reflect.TypeOf((*MockManager)(nil).SerializeUpstreamContents), arg0) } +// StateUpdate mocks base method +func (m *MockManager) StateUpdate(arg0 state.Update) (state.State, error) { + ret := m.ctrl.Call(m, "StateUpdate", arg0) + ret0, _ := ret[0].(state.State) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StateUpdate indicates an expected call of StateUpdate +func (mr *MockManagerMockRecorder) StateUpdate(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StateUpdate", reflect.TypeOf((*MockManager)(nil).StateUpdate), arg0) +} + // TryLoad mocks base method func (m *MockManager) TryLoad() (state.State, error) { ret := m.ctrl.Call(m, "TryLoad")