Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add operator license key check #4925

Merged
merged 3 commits into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions cmd/manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/elastic/cloud-on-k8s/pkg/controller/beat"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/certificates"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/container"
commonlicense "github.com/elastic/cloud-on-k8s/pkg/controller/common/license"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/operator"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/reconciler"
controllerscheme "github.com/elastic/cloud-on-k8s/pkg/controller/common/scheme"
Expand Down Expand Up @@ -580,12 +581,31 @@ func startOperator(ctx context.Context) error {
"build_hash", operatorInfo.BuildInfo.Hash, "build_date", operatorInfo.BuildInfo.Date,
"build_snapshot", operatorInfo.BuildInfo.Snapshot)

if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
return err
}
exitOnErr := make(chan error)

return nil
// start the manager
go func() {
if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
exitOnErr <- err
}
}()

// check operator license key
go func() {
mgr.GetCache().WaitForCacheSync(ctx)

lc := commonlicense.NewLicenseChecker(mgr.GetClient(), params.OperatorNamespace)
licenseType, err := lc.ValidOperatorLicenseKey()
thbkrkr marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
log.Error(err, "Failed to validate operator license key")
exitOnErr <- err
} else {
log.Info("Operator license key validated", "license_type", licenseType)
}
}()

return <-exitOnErr
}

// asyncTasks schedules some tasks to be started when this instance of the operator is elected
Expand Down
32 changes: 8 additions & 24 deletions pkg/controller/autoscaling/elasticsearch/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("frozen-tier"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "frozen-tier",
Expand All @@ -109,7 +109,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("ml"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "ml",
Expand All @@ -124,7 +124,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withErrorOnDeleteAutoscalingAutoscalingPolicies(),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "min-nodes-increased-by-user",
Expand All @@ -139,7 +139,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("empty-autoscaling-api-response"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "empty-autoscaling-api-response",
Expand All @@ -152,7 +152,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "cluster-creation",
Expand All @@ -165,7 +165,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("max-storage-reached"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "max-storage-reached",
Expand All @@ -182,7 +182,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("storage-scaled-horizontally"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "storage-scaled-horizontally",
Expand All @@ -195,7 +195,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "",
Expand Down Expand Up @@ -374,19 +374,3 @@ func (f *fakeEsClient) GetAutoscalingCapacity(_ context.Context) (esclient.Autos
func (f *fakeEsClient) UpdateMLNodesSettings(_ context.Context, maxLazyMLNodes int32, maxMemory string) error {
return nil
}

// - Fake licence checker

type fakeLicenceChecker struct{}

func (flc *fakeLicenceChecker) CurrentEnterpriseLicense() (*license.EnterpriseLicense, error) {
return nil, nil
}

func (flc *fakeLicenceChecker) EnterpriseFeaturesEnabled() (bool, error) {
return true, nil
}

func (flc *fakeLicenceChecker) Valid(l license.EnterpriseLicense) (bool, error) {
return true, nil
}
38 changes: 29 additions & 9 deletions pkg/controller/common/license/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package license

import (
"context"
"fmt"
"sort"
"time"

Expand All @@ -24,6 +25,7 @@ type Checker interface {
CurrentEnterpriseLicense() (*EnterpriseLicense, error)
EnterpriseFeaturesEnabled() (bool, error)
Valid(l EnterpriseLicense) (bool, error)
ValidOperatorLicenseKey() (OperatorLicenseType, error)
}

// checker contains parameters for license checks.
Expand Down Expand Up @@ -64,7 +66,7 @@ func (lc *checker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
}

sort.Slice(licenses, func(i, j int) bool {
t1, t2 := EnterpriseLicenseTypeOrder[licenses[i].License.Type], EnterpriseLicenseTypeOrder[licenses[j].License.Type]
t1, t2 := OperatorLicenseTypeOrder[licenses[i].License.Type], OperatorLicenseTypeOrder[licenses[j].License.Type]
if t1 != t2 { // sort by type (first the most features)
return t1 > t2
}
Expand Down Expand Up @@ -115,20 +117,38 @@ func (lc *checker) Valid(l EnterpriseLicense) (bool, error) {
return false, nil
}

type MockChecker struct {
MissingLicense bool
// ValidOperatorLicenseKey returns true if the current operator license key is valid
func (lc checker) ValidOperatorLicenseKey() (OperatorLicenseType, error) {
lic, err := lc.CurrentEnterpriseLicense()
if err != nil {
log.V(-1).Info("Invalid Enterprise license, fallback to Basic: " + err.Error())
}

licType := lic.GetOperatorLicenseType()
if _, valid := OperatorLicenseTypeOrder[licType]; !valid {
return licType, fmt.Errorf("invalid license key: %s", licType)
}
return licType, nil
}

func (m MockChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
type MockLicenseChecker struct {
EnterpriseEnabled bool
}

func (m MockLicenseChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
return &EnterpriseLicense{}, nil
}

func (m MockChecker) EnterpriseFeaturesEnabled() (bool, error) {
return !m.MissingLicense, nil
func (m MockLicenseChecker) EnterpriseFeaturesEnabled() (bool, error) {
return m.EnterpriseEnabled, nil
}

func (m MockLicenseChecker) Valid(l EnterpriseLicense) (bool, error) {
return m.EnterpriseEnabled, nil
}

func (m MockChecker) Valid(l EnterpriseLicense) (bool, error) {
return !m.MissingLicense, nil
func (m MockLicenseChecker) ValidOperatorLicenseKey() (OperatorLicenseType, error) {
return LicenseTypeEnterprise, nil
}

var _ Checker = &MockChecker{}
var _ Checker = &MockLicenseChecker{}
99 changes: 99 additions & 0 deletions pkg/controller/common/license/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,102 @@ func Test_CurrentEnterpriseLicense(t *testing.T) {
})
}
}

func Test_ValidOperatorLicenseKey(t *testing.T) {
privKey, err := x509.ParsePKCS1PrivateKey(privateKeyFixture)
require.NoError(t, err)

validLicenseFixture := licenseFixtureV3
validLicenseFixture.License.ExpiryDateInMillis = chrono.ToMillis(time.Now().Add(1 * time.Hour))
signatureBytes, err := NewSigner(privKey).Sign(validLicenseFixture)
require.NoError(t, err)
validLicense := asRuntimeObjects(validLicenseFixture, signatureBytes)

trialState, err := NewTrialState()
require.NoError(t, err)
validTrialLicenseFixture := emptyTrialLicenseFixture
require.NoError(t, trialState.InitTrialLicense(&validTrialLicenseFixture))
validTrialLicense := asRuntimeObject(validTrialLicenseFixture)

statusSecret, err := ExpectedTrialStatus(testNS, types.NamespacedName{}, trialState)
require.NoError(t, err)

type fields struct {
initialObjects []runtime.Object
operatorNamespace string
publicKey []byte
}

tests := []struct {
name string
fields fields
wantErr bool
wantType OperatorLicenseType
}{
{
name: "get valid basic license: OK",
fields: fields{
initialObjects: []runtime.Object{},
operatorNamespace: "test-system",
},
wantType: LicenseTypeBasic,
wantErr: false,
},
{
name: "get valid enterprise license: OK",
fields: fields{
initialObjects: validLicense,
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
{
name: "get valid trial enterprise license: OK",
fields: fields{
initialObjects: []runtime.Object{validTrialLicense, &statusSecret},
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterpriseTrial,
wantErr: false,
},
{
name: "get valid enterprise license among two licenses: OK",
fields: fields{
initialObjects: append(validLicense, validTrialLicense),
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
{
name: "invalid public key: fallback to basic",
fields: fields{
initialObjects: validLicense,
operatorNamespace: "test-system",
publicKey: []byte("not a public key"),
},
wantType: LicenseTypeBasic,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lc := &checker{
k8sClient: k8s.NewFakeClient(tt.fields.initialObjects...),
operatorNamespace: tt.fields.operatorNamespace,
publicKey: tt.fields.publicKey,
}
licenseType, err := lc.ValidOperatorLicenseKey()
if (err != nil) != tt.wantErr {
t.Errorf("Checker.ValidOperatorLicenseKey() err = %v, wantErr %v", err, tt.wantErr)
}
if licenseType != tt.wantType {
t.Errorf("Checker.ValidOperatorLicenseKey() licenseType = %v, wantType %v", licenseType, tt.wantType)
}
})
}
}
13 changes: 11 additions & 2 deletions pkg/controller/common/license/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type OperatorLicenseType string

const (
LicenseTypeBasic OperatorLicenseType = "basic"
LicenseTypeEnterprise OperatorLicenseType = "enterprise"
LicenseTypeEnterpriseTrial OperatorLicenseType = "enterprise_trial"
// LicenseTypeLegacyTrial earlier versions of ECK used this as the trial identifier
Expand Down Expand Up @@ -47,8 +48,9 @@ type LicenseSpec struct {
Version int // not marshalled but part of the signature
}

// EnterpriseLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes.
var EnterpriseLicenseTypeOrder = map[OperatorLicenseType]int{
// OperatorLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes.
var OperatorLicenseTypeOrder = map[OperatorLicenseType]int{
LicenseTypeBasic: -1,
LicenseTypeLegacyTrial: 0,
LicenseTypeEnterpriseTrial: 1,
LicenseTypeEnterprise: 2,
Expand Down Expand Up @@ -107,6 +109,13 @@ func (l EnterpriseLicense) IsMissingFields() error {
return nil
}

func (l *EnterpriseLicense) GetOperatorLicenseType() OperatorLicenseType {
if l == nil {
return LicenseTypeBasic
}
return l.License.Type
}

// LicenseStatus expresses the validity status of a license.
type LicenseStatus string

Expand Down
Loading