Skip to content

Commit

Permalink
Add operator license key check
Browse files Browse the repository at this point in the history
  • Loading branch information
thbkrkr committed Oct 7, 2021
1 parent 710713f commit 5480e1e
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 5 deletions.
30 changes: 25 additions & 5 deletions cmd/manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/elastic/cloud-on-k8s/pkg/controller/beat"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/certificates"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/container"
commonlicense "github.com/elastic/cloud-on-k8s/pkg/controller/common/license"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/operator"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/reconciler"
controllerscheme "github.com/elastic/cloud-on-k8s/pkg/controller/common/scheme"
Expand Down Expand Up @@ -580,12 +581,31 @@ func startOperator(ctx context.Context) error {
"build_hash", operatorInfo.BuildInfo.Hash, "build_date", operatorInfo.BuildInfo.Date,
"build_snapshot", operatorInfo.BuildInfo.Snapshot)

if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
return err
}
exitOnErr := make(chan error)

return nil
// start the manager
go func() {
if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
exitOnErr <- err
}
}()

// check operator license key
go func() {
mgr.GetCache().WaitForCacheSync(ctx)

lc := commonlicense.NewLicenseChecker(mgr.GetClient(), params.OperatorNamespace)
licenseType, err := lc.ValidOperatorLicenseKey()
if err != nil {
log.Error(err, "Failed to validate operator license key")
exitOnErr <- err
} else {
log.Info("Operator license key validated", "license_type", licenseType)
}
}()

return <-exitOnErr
}

// asyncTasks schedules some tasks to be started when this instance of the operator is elected
Expand Down
24 changes: 24 additions & 0 deletions pkg/controller/common/license/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package license

import (
"context"
"fmt"
"sort"
"time"

Expand All @@ -24,6 +25,7 @@ type Checker interface {
CurrentEnterpriseLicense() (*EnterpriseLicense, error)
EnterpriseFeaturesEnabled() (bool, error)
Valid(l EnterpriseLicense) (bool, error)
ValidOperatorLicenseKey() (OperatorLicenseType, error)
}

// checker contains parameters for license checks.
Expand Down Expand Up @@ -115,6 +117,24 @@ func (lc *checker) Valid(l EnterpriseLicense) (bool, error) {
return false, nil
}

// ValidOperatorLicenseKey returns true if the current operator license key is valid
func (lc checker) ValidOperatorLicenseKey() (OperatorLicenseType, error) {
license, _ := lc.CurrentEnterpriseLicense()
licenseType := license.GetOperatorLicenseType()
switch licenseType {
case LicenseTypeBasic:
return licenseType, nil
case LicenseTypeEnterprise:
return licenseType, nil
case LicenseTypeEnterpriseTrial:
return licenseType, nil
case LicenseTypeLegacyTrial:
return licenseType, nil
default:
return licenseType, fmt.Errorf("invalid license key: %v", license)
}
}

type MockChecker struct {
MissingLicense bool
}
Expand All @@ -131,4 +151,8 @@ func (m MockChecker) Valid(l EnterpriseLicense) (bool, error) {
return !m.MissingLicense, nil
}

func (m MockChecker) ValidOperatorLicenseKey() (OperatorLicenseType, error) {
return LicenseTypeBasic, nil
}

var _ Checker = &MockChecker{}
89 changes: 89 additions & 0 deletions pkg/controller/common/license/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,92 @@ func Test_CurrentEnterpriseLicense(t *testing.T) {
})
}
}

func Test_ValidOperatorLicenseKey(t *testing.T) {
privKey, err := x509.ParsePKCS1PrivateKey(privateKeyFixture)
require.NoError(t, err)

validLicenseFixture := licenseFixtureV3
validLicenseFixture.License.ExpiryDateInMillis = chrono.ToMillis(time.Now().Add(1 * time.Hour))
signatureBytes, err := NewSigner(privKey).Sign(validLicenseFixture)
require.NoError(t, err)
validLicense := asRuntimeObjects(validLicenseFixture, signatureBytes)

trialState, err := NewTrialState()
require.NoError(t, err)
validTrialLicenseFixture := emptyTrialLicenseFixture
require.NoError(t, trialState.InitTrialLicense(&validTrialLicenseFixture))
validTrialLicense := asRuntimeObject(validTrialLicenseFixture)

statusSecret, err := ExpectedTrialStatus(testNS, types.NamespacedName{}, trialState)
require.NoError(t, err)

type fields struct {
initialObjects []runtime.Object
operatorNamespace string
publicKey []byte
}

tests := []struct {
name string
fields fields
wantErr bool
wantType OperatorLicenseType
}{
{
name: "get valid basic license: OK",
fields: fields{
initialObjects: []runtime.Object{},
operatorNamespace: "test-system",
},
wantType: LicenseTypeBasic,
wantErr: false,
},
{
name: "get valid enterprise license: OK",
fields: fields{
initialObjects: validLicense,
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
{
name: "get valid trial enterprise license: OK",
fields: fields{
initialObjects: []runtime.Object{validTrialLicense, &statusSecret},
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterpriseTrial,
wantErr: false,
},
{
name: "get valid enterprise license among two licenses: OK",
fields: fields{
initialObjects: append(validLicense, validTrialLicense),
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lc := &checker{
k8sClient: k8s.NewFakeClient(tt.fields.initialObjects...),
operatorNamespace: tt.fields.operatorNamespace,
publicKey: tt.fields.publicKey,
}
licenseType, err := lc.ValidOperatorLicenseKey()
if (err != nil) != tt.wantErr {
t.Errorf("Checker.ValidOperatorLicenseKey() err = %v, wantErr %v", err, tt.wantErr)
}
if licenseType != tt.wantType {
t.Errorf("Checker.ValidOperatorLicenseKey() licenseType = %v, wantType %v", licenseType, tt.wantType)
}
})
}
}
8 changes: 8 additions & 0 deletions pkg/controller/common/license/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type OperatorLicenseType string

const (
LicenseTypeBasic OperatorLicenseType = "basic"
LicenseTypeEnterprise OperatorLicenseType = "enterprise"
LicenseTypeEnterpriseTrial OperatorLicenseType = "enterprise_trial"
// LicenseTypeLegacyTrial earlier versions of ECK used this as the trial identifier
Expand Down Expand Up @@ -107,6 +108,13 @@ func (l EnterpriseLicense) IsMissingFields() error {
return nil
}

func (l *EnterpriseLicense) GetOperatorLicenseType() OperatorLicenseType {
if l == nil {
return LicenseTypeBasic
}
return l.License.Type
}

// LicenseStatus expresses the validity status of a license.
type LicenseStatus string

Expand Down

0 comments on commit 5480e1e

Please sign in to comment.