diff --git a/cmd/manager/main.go b/cmd/manager/main.go index c2d53149dd..7d1f4cd2b9 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -35,6 +35,7 @@ import ( "github.com/elastic/cloud-on-k8s/pkg/controller/beat" "github.com/elastic/cloud-on-k8s/pkg/controller/common/certificates" "github.com/elastic/cloud-on-k8s/pkg/controller/common/container" + commonlicense "github.com/elastic/cloud-on-k8s/pkg/controller/common/license" "github.com/elastic/cloud-on-k8s/pkg/controller/common/operator" "github.com/elastic/cloud-on-k8s/pkg/controller/common/reconciler" controllerscheme "github.com/elastic/cloud-on-k8s/pkg/controller/common/scheme" @@ -580,12 +581,31 @@ func startOperator(ctx context.Context) error { "build_hash", operatorInfo.BuildInfo.Hash, "build_date", operatorInfo.BuildInfo.Date, "build_snapshot", operatorInfo.BuildInfo.Snapshot) - if err := mgr.Start(ctx); err != nil { - log.Error(err, "Failed to start the controller manager") - return err - } + exitOnErr := make(chan error) - return nil + // start the manager + go func() { + if err := mgr.Start(ctx); err != nil { + log.Error(err, "Failed to start the controller manager") + exitOnErr <- err + } + }() + + // check operator license key + go func() { + mgr.GetCache().WaitForCacheSync(ctx) + + lc := commonlicense.NewLicenseChecker(mgr.GetClient(), params.OperatorNamespace) + licenseType, err := lc.ValidOperatorLicenseKeyType() + if err != nil { + log.Error(err, "Failed to validate operator license key") + exitOnErr <- err + } else { + log.Info("Operator license key validated", "license_type", licenseType) + } + }() + + return <-exitOnErr } // asyncTasks schedules some tasks to be started when this instance of the operator is elected diff --git a/pkg/controller/autoscaling/elasticsearch/controller_test.go b/pkg/controller/autoscaling/elasticsearch/controller_test.go index fb17d8bb3b..f324d90cb2 100644 --- a/pkg/controller/autoscaling/elasticsearch/controller_test.go +++ b/pkg/controller/autoscaling/elasticsearch/controller_test.go @@ -94,7 +94,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withCapacity("frozen-tier"), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "frozen-tier", @@ -109,7 +109,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withCapacity("ml"), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "ml", @@ -124,7 +124,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withErrorOnDeleteAutoscalingAutoscalingPolicies(), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "min-nodes-increased-by-user", @@ -139,7 +139,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withCapacity("empty-autoscaling-api-response"), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "empty-autoscaling-api-response", @@ -152,7 +152,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "cluster-creation", @@ -165,7 +165,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withCapacity("max-storage-reached"), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "max-storage-reached", @@ -182,7 +182,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t).withCapacity("storage-scaled-horizontally"), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "storage-scaled-horizontally", @@ -195,7 +195,7 @@ func TestReconcile(t *testing.T) { fields: fields{ EsClient: newFakeEsClient(t), recorder: record.NewFakeRecorder(1000), - licenseChecker: &fakeLicenceChecker{}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ esManifest: "", @@ -374,19 +374,3 @@ func (f *fakeEsClient) GetAutoscalingCapacity(_ context.Context) (esclient.Autos func (f *fakeEsClient) UpdateMLNodesSettings(_ context.Context, maxLazyMLNodes int32, maxMemory string) error { return nil } - -// - Fake licence checker - -type fakeLicenceChecker struct{} - -func (flc *fakeLicenceChecker) CurrentEnterpriseLicense() (*license.EnterpriseLicense, error) { - return nil, nil -} - -func (flc *fakeLicenceChecker) EnterpriseFeaturesEnabled() (bool, error) { - return true, nil -} - -func (flc *fakeLicenceChecker) Valid(l license.EnterpriseLicense) (bool, error) { - return true, nil -} diff --git a/pkg/controller/common/license/check.go b/pkg/controller/common/license/check.go index cc436f8f17..7d7f5c3819 100644 --- a/pkg/controller/common/license/check.go +++ b/pkg/controller/common/license/check.go @@ -6,6 +6,7 @@ package license import ( "context" + "fmt" "sort" "time" @@ -24,6 +25,7 @@ type Checker interface { CurrentEnterpriseLicense() (*EnterpriseLicense, error) EnterpriseFeaturesEnabled() (bool, error) Valid(l EnterpriseLicense) (bool, error) + ValidOperatorLicenseKeyType() (OperatorLicenseType, error) } // checker contains parameters for license checks. @@ -64,7 +66,7 @@ func (lc *checker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) { } sort.Slice(licenses, func(i, j int) bool { - t1, t2 := EnterpriseLicenseTypeOrder[licenses[i].License.Type], EnterpriseLicenseTypeOrder[licenses[j].License.Type] + t1, t2 := OperatorLicenseTypeOrder[licenses[i].License.Type], OperatorLicenseTypeOrder[licenses[j].License.Type] if t1 != t2 { // sort by type (first the most features) return t1 > t2 } @@ -115,20 +117,38 @@ func (lc *checker) Valid(l EnterpriseLicense) (bool, error) { return false, nil } -type MockChecker struct { - MissingLicense bool +// ValidOperatorLicenseKeyType returns true if the current operator license key is valid +func (lc checker) ValidOperatorLicenseKeyType() (OperatorLicenseType, error) { + lic, err := lc.CurrentEnterpriseLicense() + if err != nil { + log.V(-1).Info("Invalid Enterprise license, fallback to Basic: " + err.Error()) + } + + licType := lic.GetOperatorLicenseType() + if _, valid := OperatorLicenseTypeOrder[licType]; !valid { + return licType, fmt.Errorf("invalid license key: %s", licType) + } + return licType, nil } -func (m MockChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) { +type MockLicenseChecker struct { + EnterpriseEnabled bool +} + +func (m MockLicenseChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) { return &EnterpriseLicense{}, nil } -func (m MockChecker) EnterpriseFeaturesEnabled() (bool, error) { - return !m.MissingLicense, nil +func (m MockLicenseChecker) EnterpriseFeaturesEnabled() (bool, error) { + return m.EnterpriseEnabled, nil +} + +func (m MockLicenseChecker) Valid(l EnterpriseLicense) (bool, error) { + return m.EnterpriseEnabled, nil } -func (m MockChecker) Valid(l EnterpriseLicense) (bool, error) { - return !m.MissingLicense, nil +func (m MockLicenseChecker) ValidOperatorLicenseKeyType() (OperatorLicenseType, error) { + return LicenseTypeEnterprise, nil } -var _ Checker = &MockChecker{} +var _ Checker = &MockLicenseChecker{} diff --git a/pkg/controller/common/license/check_test.go b/pkg/controller/common/license/check_test.go index 9509340348..0361d56f4d 100644 --- a/pkg/controller/common/license/check_test.go +++ b/pkg/controller/common/license/check_test.go @@ -236,3 +236,102 @@ func Test_CurrentEnterpriseLicense(t *testing.T) { }) } } + +func Test_ValidOperatorLicenseKey(t *testing.T) { + privKey, err := x509.ParsePKCS1PrivateKey(privateKeyFixture) + require.NoError(t, err) + + validLicenseFixture := licenseFixtureV3 + validLicenseFixture.License.ExpiryDateInMillis = chrono.ToMillis(time.Now().Add(1 * time.Hour)) + signatureBytes, err := NewSigner(privKey).Sign(validLicenseFixture) + require.NoError(t, err) + validLicense := asRuntimeObjects(validLicenseFixture, signatureBytes) + + trialState, err := NewTrialState() + require.NoError(t, err) + validTrialLicenseFixture := emptyTrialLicenseFixture + require.NoError(t, trialState.InitTrialLicense(&validTrialLicenseFixture)) + validTrialLicense := asRuntimeObject(validTrialLicenseFixture) + + statusSecret, err := ExpectedTrialStatus(testNS, types.NamespacedName{}, trialState) + require.NoError(t, err) + + type fields struct { + initialObjects []runtime.Object + operatorNamespace string + publicKey []byte + } + + tests := []struct { + name string + fields fields + wantErr bool + wantType OperatorLicenseType + }{ + { + name: "get valid basic license: OK", + fields: fields{ + initialObjects: []runtime.Object{}, + operatorNamespace: "test-system", + }, + wantType: LicenseTypeBasic, + wantErr: false, + }, + { + name: "get valid enterprise license: OK", + fields: fields{ + initialObjects: validLicense, + operatorNamespace: "test-system", + publicKey: publicKeyBytesFixture(t), + }, + wantType: LicenseTypeEnterprise, + wantErr: false, + }, + { + name: "get valid trial enterprise license: OK", + fields: fields{ + initialObjects: []runtime.Object{validTrialLicense, &statusSecret}, + operatorNamespace: "test-system", + publicKey: publicKeyBytesFixture(t), + }, + wantType: LicenseTypeEnterpriseTrial, + wantErr: false, + }, + { + name: "get valid enterprise license among two licenses: OK", + fields: fields{ + initialObjects: append(validLicense, validTrialLicense), + operatorNamespace: "test-system", + publicKey: publicKeyBytesFixture(t), + }, + wantType: LicenseTypeEnterprise, + wantErr: false, + }, + { + name: "invalid public key: fallback to basic", + fields: fields{ + initialObjects: validLicense, + operatorNamespace: "test-system", + publicKey: []byte("not a public key"), + }, + wantType: LicenseTypeBasic, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lc := &checker{ + k8sClient: k8s.NewFakeClient(tt.fields.initialObjects...), + operatorNamespace: tt.fields.operatorNamespace, + publicKey: tt.fields.publicKey, + } + licenseType, err := lc.ValidOperatorLicenseKeyType() + if (err != nil) != tt.wantErr { + t.Errorf("Checker.ValidOperatorLicenseKeyType() err = %v, wantErr %v", err, tt.wantErr) + } + if licenseType != tt.wantType { + t.Errorf("Checker.ValidOperatorLicenseKeyType() licenseType = %v, wantType %v", licenseType, tt.wantType) + } + }) + } +} diff --git a/pkg/controller/common/license/model.go b/pkg/controller/common/license/model.go index 484faec45f..5d0da20738 100644 --- a/pkg/controller/common/license/model.go +++ b/pkg/controller/common/license/model.go @@ -15,6 +15,7 @@ import ( type OperatorLicenseType string const ( + LicenseTypeBasic OperatorLicenseType = "basic" LicenseTypeEnterprise OperatorLicenseType = "enterprise" LicenseTypeEnterpriseTrial OperatorLicenseType = "enterprise_trial" // LicenseTypeLegacyTrial earlier versions of ECK used this as the trial identifier @@ -47,8 +48,9 @@ type LicenseSpec struct { Version int // not marshalled but part of the signature } -// EnterpriseLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes. -var EnterpriseLicenseTypeOrder = map[OperatorLicenseType]int{ +// OperatorLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes. +var OperatorLicenseTypeOrder = map[OperatorLicenseType]int{ + LicenseTypeBasic: -1, LicenseTypeLegacyTrial: 0, LicenseTypeEnterpriseTrial: 1, LicenseTypeEnterprise: 2, @@ -107,6 +109,13 @@ func (l EnterpriseLicense) IsMissingFields() error { return nil } +func (l *EnterpriseLicense) GetOperatorLicenseType() OperatorLicenseType { + if l == nil { + return LicenseTypeBasic + } + return l.License.Type +} + // LicenseStatus expresses the validity status of a license. type LicenseStatus string diff --git a/pkg/controller/elasticsearch/remotecluster/elasticsearch_test.go b/pkg/controller/elasticsearch/remotecluster/elasticsearch_test.go index 85bf15a453..952578f84b 100644 --- a/pkg/controller/elasticsearch/remotecluster/elasticsearch_test.go +++ b/pkg/controller/elasticsearch/remotecluster/elasticsearch_test.go @@ -108,22 +108,6 @@ func newEsWithRemoteClusters( } } -type fakeLicenseChecker struct { - enterpriseFeaturesEnabled bool -} - -func (fakeLicenseChecker) CurrentEnterpriseLicense() (*license.EnterpriseLicense, error) { - return nil, nil -} - -func (f *fakeLicenseChecker) EnterpriseFeaturesEnabled() (bool, error) { - return f.enterpriseFeaturesEnabled, nil -} - -func (fakeLicenseChecker) Valid(_ license.EnterpriseLicense) (bool, error) { - return true, nil -} - func TestUpdateSettings(t *testing.T) { emptySettings := esclient.RemoteClustersSettings{PersistentSettings: &esclient.SettingsGroup{}} type args struct { @@ -144,7 +128,7 @@ func TestUpdateSettings(t *testing.T) { name: "Nothing to create, nothing to delete", args: args{ esClient: &fakeESClient{existingSettings: emptySettings}, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -158,7 +142,7 @@ func TestUpdateSettings(t *testing.T) { name: "Empty annotation", args: args{ esClient: &fakeESClient{existingSettings: emptySettings}, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -172,7 +156,7 @@ func TestUpdateSettings(t *testing.T) { name: "Outdated annotation should be removed", args: args{ esClient: &fakeESClient{existingSettings: emptySettings}, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -186,7 +170,7 @@ func TestUpdateSettings(t *testing.T) { name: "Create a new remote cluster", args: args{ esClient: &fakeESClient{existingSettings: emptySettings}, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -215,7 +199,7 @@ func TestUpdateSettings(t *testing.T) { esClient: &fakeESClient{ existingSettings: esclient.RemoteClustersSettings{PersistentSettings: &esclient.SettingsGroup{}}, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -251,7 +235,7 @@ func TestUpdateSettings(t *testing.T) { }, }, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -290,7 +274,7 @@ func TestUpdateSettings(t *testing.T) { }, }, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -331,7 +315,7 @@ func TestUpdateSettings(t *testing.T) { }, }, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -374,7 +358,7 @@ func TestUpdateSettings(t *testing.T) { }, }, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -404,7 +388,7 @@ func TestUpdateSettings(t *testing.T) { name: "No valid license to create a new remote cluster", args: args{ esClient: &fakeESClient{}, - licenseChecker: &fakeLicenseChecker{false}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: false}, es: newEsWithRemoteClusters( "ns1", "es1", @@ -431,7 +415,7 @@ func TestUpdateSettings(t *testing.T) { }, }, }, - licenseChecker: &fakeLicenseChecker{true}, + licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true}, es: newEsWithRemoteClusters( "ns1", "es1", diff --git a/pkg/controller/license/license_controller_integration_test.go b/pkg/controller/license/license_controller_integration_test.go index 8841a25531..8a101b3c59 100644 --- a/pkg/controller/license/license_controller_integration_test.go +++ b/pkg/controller/license/license_controller_integration_test.go @@ -43,7 +43,7 @@ func TestReconcile(t *testing.T) { c, stop := test.StartManager(t, func(mgr manager.Manager, p operator.Parameters) error { r := &ReconcileLicenses{ Client: mgr.GetClient(), - checker: license.MockChecker{}, + checker: license.MockLicenseChecker{EnterpriseEnabled: true}, } c, err := common.NewController(mgr, name, r, p) if err != nil { diff --git a/pkg/controller/license/license_controller_test.go b/pkg/controller/license/license_controller_test.go index 2f03a58a57..5aec5e285e 100644 --- a/pkg/controller/license/license_controller_test.go +++ b/pkg/controller/license/license_controller_test.go @@ -177,7 +177,7 @@ func TestReconcileLicenses_reconcileInternal(t *testing.T) { client := k8s.NewFakeClient(tt.k8sResources...) r := &ReconcileLicenses{ Client: client, - checker: commonlicense.MockChecker{}, + checker: commonlicense.MockLicenseChecker{EnterpriseEnabled: true}, } nsn := k8s.ExtractNamespacedName(tt.cluster) res, err := r.reconcileInternal(reconcile.Request{NamespacedName: nsn}).Aggregate() diff --git a/pkg/controller/maps/controller_test.go b/pkg/controller/maps/controller_test.go index 5b227da612..5475057ea3 100644 --- a/pkg/controller/maps/controller_test.go +++ b/pkg/controller/maps/controller_test.go @@ -78,7 +78,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { Client: k8s.NewFakeClient(&v1alpha1.ElasticMapsServer{ ObjectMeta: metav1.ObjectMeta{Name: nsnFixture.Name, Namespace: nsnFixture.Namespace, DeletionTimestamp: &timeFixture}, }), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, dynamicWatches: watches.NewDynamicWatches(), }, pre: func(r ReconcileMapsServer) { @@ -115,7 +115,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { Client: k8s.NewFakeClient(&emsFixture), recorder: record.NewFakeRecorder(10), dynamicWatches: watches.NewDynamicWatches(), - licenseChecker: license.MockChecker{MissingLicense: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: false}, Parameters: operator.Parameters{OperatorInfo: about.OperatorInfo{BuildInfo: about.BuildInfo{Version: "1.6.0"}}}, }, post: func(r ReconcileMapsServer) { @@ -132,7 +132,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { Client: k8s.NewFakeClient(&emsFixture), recorder: record.NewFakeRecorder(10), dynamicWatches: watches.NewDynamicWatches(), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, Parameters: operator.Parameters{OperatorInfo: about.OperatorInfo{BuildInfo: about.BuildInfo{Version: "1.6.0"}}}, }, post: func(r ReconcileMapsServer) { @@ -155,7 +155,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { Version: "7.10.0", // unsupported version }, }), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, recorder: record.NewFakeRecorder(10), }, wantErr: true, @@ -174,7 +174,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { }, }), dynamicWatches: watches.NewDynamicWatches(), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, recorder: record.NewFakeRecorder(10), }, post: func(r ReconcileMapsServer) { @@ -200,7 +200,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { }, }), dynamicWatches: watches.NewDynamicWatches(), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, recorder: record.NewFakeRecorder(10), }, post: func(r ReconcileMapsServer) { @@ -215,7 +215,7 @@ func TestReconcileMapsServer_Reconcile(t *testing.T) { Client: k8s.NewFakeClient(&emsFixture), recorder: record.NewFakeRecorder(10), dynamicWatches: watches.NewDynamicWatches(), - licenseChecker: license.MockChecker{}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, Parameters: operator.Parameters{OperatorInfo: about.OperatorInfo{BuildInfo: about.BuildInfo{Version: "1.6.0"}}}, }, post: func(r ReconcileMapsServer) { diff --git a/pkg/controller/remoteca/controller_test.go b/pkg/controller/remoteca/controller_test.go index f639f00240..c449c3b81d 100644 --- a/pkg/controller/remoteca/controller_test.go +++ b/pkg/controller/remoteca/controller_test.go @@ -78,20 +78,6 @@ func (f *fakeAccessReviewer) AccessAllowed(_ context.Context, _ string, _ string return f.allowed, f.err } -type fakeLicenseChecker struct { - enterpriseFeaturesEnabled bool -} - -func (f fakeLicenseChecker) CurrentEnterpriseLicense() (*license.EnterpriseLicense, error) { - return nil, nil -} -func (f fakeLicenseChecker) EnterpriseFeaturesEnabled() (bool, error) { - return f.enterpriseFeaturesEnabled, nil -} -func (f fakeLicenseChecker) Valid(l license.EnterpriseLicense) (bool, error) { - return f.enterpriseFeaturesEnabled, nil -} - func fakePublicCa(namespace, name string) *corev1.Secret { namespacedName := types.NamespacedName{ Name: name, @@ -166,7 +152,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { fakePublicCa("ns2", "es2"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{ @@ -193,7 +179,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { fakePublicCa("ns2", "es2"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{ @@ -222,7 +208,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { remoteCa("ns2", "es2", "ns1", "es1"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{ @@ -263,7 +249,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { withDataCert(remoteCa("ns2", "es2", "ns1", "es1"), []byte("bar")), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{ @@ -295,7 +281,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { remoteCa("ns3", "es3", "ns1", "es1"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{ @@ -334,7 +320,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { fakePublicCa("ns2", "es2"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: false}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: false}, }, args: args{ request: reconcile.Request{ @@ -375,7 +361,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { remoteCa("ns2", "es2", "ns1", "es1"), }, accessReviewer: &fakeAccessReviewer{allowed: true}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: false}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: false}, }, args: args{ request: reconcile.Request{ @@ -404,7 +390,7 @@ func TestReconcileRemoteCa_Reconcile(t *testing.T) { remoteCa("ns2", "es2", "ns1", "es1"), }, accessReviewer: &fakeAccessReviewer{allowed: false}, - licenseChecker: &fakeLicenseChecker{enterpriseFeaturesEnabled: true}, + licenseChecker: license.MockLicenseChecker{EnterpriseEnabled: true}, }, args: args{ request: reconcile.Request{